# How to Plot a 3D Normal Distribution in Python?

To create a 3D surface plot of a bivariate normal distribution, define two normally distributed random variables `x` and `y`, each with its own mean (`mu_x`, `mu_y`) and variance (`variance_x`, `variance_y`). The random variables are independent, so the covariance between `x` and `y` is 0. Build a grid of `(x, y)` pairs and calculate the probability density function (pdf) of this bivariate normal distribution at each grid point. Those pdf values become the height of the surface.

Here’s the code for copy and paste. I also added comments to explain what each part is doing: 👇

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from mpl_toolkits.mplot3d import Axes3D  # registers the '3d' projection (only required for Matplotlib < 3.2)

# define parameters for x and y distributions
mu_x = 0  # mean of x
variance_x = 3  # variance of x

mu_y = 0  # mean of y
variance_y = 15  # variance of y

# define a grid for x and y values
x = np.linspace(-10, 10, 500)  # generate 500 points between -10 and 10 for x
y = np.linspace(-10, 10, 500)  # generate 500 points between -10 and 10 for y
X, Y = np.meshgrid(x, y)  # create a grid for (x,y) pairs

# create an empty array of the same shape as X to hold the (x, y) coordinates
pos = np.empty(X.shape + (2,))

# fill the pos array with the x and y coordinates
pos[:, :, 0] = X
pos[:, :, 1] = Y

# create a multivariate normal distribution using the defined parameters
rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])

# create a new figure for 3D plot
fig = plt.figure()

# add a 3D subplot to the figure
ax = fig.add_subplot(111, projection='3d')

# create a 3D surface plot of the multivariate normal distribution
ax.plot_surface(X, Y, rv.pdf(pos), cmap='viridis', linewidth=0)

# set labels for the axes
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')

# display the 3D plot
plt.show()
```
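If the shapes of `X`, `Y`, and `pos` feel abstract, it can help to inspect them on a tiny grid. The following is a minimal sketch, not part of the original script: the grid sizes (3 and 2) are arbitrary values chosen for illustration, and `pos_stacked` is a hypothetical name. It also shows that `np.dstack` builds the same array as the three-line `pos` construction above.

```python
import numpy as np

# a tiny grid so the shapes are easy to inspect (3 x-values, 2 y-values)
x = np.linspace(-1, 1, 3)
y = np.linspace(0, 1, 2)
X, Y = np.meshgrid(x, y)

# same construction as in the main example
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y

# alternative: stack X and Y along a new last axis in a single call
pos_stacked = np.dstack((X, Y))

print(X.shape)                           # (2, 3): rows follow y, columns follow x
print(pos.shape)                         # (2, 3, 2): one (x, y) pair per grid point
print(np.array_equal(pos, pos_stacked))  # True
```

Either construction works with `rv.pdf`. The explicit version makes the layout visible, while `np.dstack` is shorter.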

👉 Interactive: I also created a Google Colab Jupyter Notebook where you can plot it yourself. Click here to open it in a new tab.

I went ahead and tried to anticipate some follow-up questions you may have on the code. Here they are: 🤔❓

## FAQ

1. What is a multivariate normal distribution? This concept from probability theory and statistics extends the 1D normal distribution to multiple dimensions. The code uses a 2D or bivariate normal distribution.
2. What does `np.linspace(-10, 10, 500)` do? This function generates 500 evenly spaced points over the interval from -10 to 10, including both endpoints. It’s used here to create a range of values for x and y. I have written a detailed blog tutorial here (with video).
3. What is `np.meshgrid(x, y)` used for? This function generates a two-dimensional grid of coordinates based on two one-dimensional arrays. In this case, it generates a grid of `(x, y)` pairs.
4. What is the purpose of the `pos` array? The `pos` array is used to hold the coordinates of each point in the grid in a format suitable for use with the `multivariate_normal` probability density function. It’s a 3D array where the first two dimensions match the dimensions of the grid, and the third dimension has size 2 to hold the `x` and `y` coordinates.
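5. What does `rv.pdf(pos)` do? This method calculates the value of the probability density function (pdf) of the multivariate normal distribution at each point in the grid. The result is a 2D array with the same shape as `X` and `Y`, so it can be passed directly to `plot_surface` as the height values.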
6. What is `plot_surface` used for? This function is used to create a three-dimensional plot of the distribution. It takes as input the `x` and `y` coordinates and the pdf values at each point, generating a 3D surface plot.
7. What is the purpose of the `cmap` parameter in the `plot_surface` function? The `cmap` parameter is used to specify the color map for the plot. `'viridis'` is one of the predefined color maps in Matplotlib.
8. Why is the covariance matrix diagonal? The covariance matrix is diagonal because the variables `x` and `y` are assumed to be independent in this bivariate distribution. Off-diagonal elements of the covariance matrix represent the covariance between different variables – these are zero if the variables are independent.
9. What does `linewidth=0` do? The `linewidth` parameter in `plot_surface` specifies the line width for the edges of the surface polygons. Setting it to 0 removes these edges.
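
The independence point from question 8 can be checked numerically: with a diagonal covariance matrix, the bivariate pdf should equal the product of the two 1D normal pdfs. Here is a small sketch of that check, assuming the same parameters as the main example. It uses a coarser 50-point grid to keep it fast, and the names `joint` and `product` are just illustrative. Note that `scipy.stats.norm` expects a standard deviation as `scale`, so the variances are passed through `np.sqrt`.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

# same parameters as in the main example
mu_x, variance_x = 0, 3
mu_y, variance_y = 0, 15

# coarser grid than the plot, enough for a numerical check
x = np.linspace(-10, 10, 50)
y = np.linspace(-10, 10, 50)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))

# joint pdf from the bivariate normal with a diagonal covariance matrix
joint = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]]).pdf(pos)

# product of the two 1D pdfs; scale is the standard deviation, not the variance
product = (norm.pdf(X, loc=mu_x, scale=np.sqrt(variance_x))
           * norm.pdf(Y, loc=mu_y, scale=np.sqrt(variance_y)))

print(np.allclose(joint, product))  # True
```

If you replace the zeros in the covariance matrix with a nonzero covariance, the two arrays no longer match, because the joint pdf then no longer factors into the two marginals.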

Feel free to check out our full course on Matplotlib on the Finxter academy.