How to Plot a 3D Normal Distribution in Python?


To create a 3D surface plot of a bivariate normal distribution, define two normally distributed random variables x and y, each with its own mean (mu_x, mu_y) and variance (variance_x, variance_y). Because the random variables are independent, the covariance between x and y is 0. Then build a grid of (x, y) pairs and evaluate the probability density function (pdf) of this bivariate normal distribution at each point of the grid.
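Because the covariance is 0, the joint density factorizes into the product of the two one-dimensional normal densities. The minimal sketch below, assuming the same parameters used in the code that follows, checks this numerically with scipy.stats.norm at a few arbitrary sample points. Note that norm takes the standard deviation as its scale argument, so the variances have to be square-rooted first.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

# Same parameters as in the main script
mu_x, variance_x = 0, 3
mu_y, variance_y = 0, 15

# A few arbitrary sample points (x, y), one per row
points = np.array([[0.0, 0.0], [1.0, -2.0], [-3.5, 4.0]])

# Joint pdf with a diagonal covariance matrix (zero covariance)
rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])
joint = rv.pdf(points)  # one density value per row of points

# Product of the two 1D pdfs; norm's scale parameter is the standard deviation
product = (norm.pdf(points[:, 0], loc=mu_x, scale=np.sqrt(variance_x))
           * norm.pdf(points[:, 1], loc=mu_y, scale=np.sqrt(variance_y)))

print(np.allclose(joint, product))  # True: independence means the pdf factorizes
```

If the check ever prints False with a diagonal covariance, the usual culprit is passing a variance where norm expects a standard deviation.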

Here’s the code for copy and paste. I also added comments to explain what each part is doing: 👇

```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from mpl_toolkits.mplot3d import Axes3D  # only needed on Matplotlib < 3.2 to register the '3d' projection

# define parameters for x and y distributions
mu_x = 0  # mean of x
variance_x = 3  # variance of x

mu_y = 0  # mean of y
variance_y = 15  # variance of y

# define a grid for x and y values
x = np.linspace(-10, 10, 500)  # generate 500 points between -10 and 10 for x
y = np.linspace(-10, 10, 500)  # generate 500 points between -10 and 10 for y
X, Y = np.meshgrid(x, y)  # create a grid for (x,y) pairs

# create an empty array of the same shape as X to hold the (x, y) coordinates
pos = np.empty(X.shape + (2,))

# fill the pos array with the x and y coordinates
pos[:, :, 0] = X
pos[:, :, 1] = Y

# create a multivariate normal distribution using the defined parameters
rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])

# create a new figure for 3D plot
fig = plt.figure()

# add a 3D subplot to the figure
ax = fig.add_subplot(projection='3d')

# create a 3D surface plot of the multivariate normal distribution
ax.plot_surface(X, Y, rv.pdf(pos), cmap='viridis', linewidth=0)

# set labels for the axes
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')

# display the 3D plot
plt.show()
```

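To see what the intermediate arrays in the script above look like, and whether the plotted surface is plausible, here is a short inspection sketch. It rebuilds the same grid but, as a swapped-in alternative, creates pos with np.dstack instead of filling an empty array; both approaches give the same (500, 500, 2) array. It then compares the surface with the closed-form peak height 1 / (2π·sqrt(variance_x·variance_y)) and approximates the volume under the surface with a simple Riemann sum.

```python
import numpy as np
from scipy.stats import multivariate_normal

# Rebuild the grid exactly as in the script above
mu_x, variance_x = 0, 3
mu_y, variance_y = 0, 15
x = np.linspace(-10, 10, 500)
y = np.linspace(-10, 10, 500)
X, Y = np.meshgrid(x, y)

# Stack X and Y along a new last axis (same result as filling pos by hand)
pos = np.dstack((X, Y))

rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])
Z = rv.pdf(pos)

print(X.shape, pos.shape, Z.shape)  # (500, 500) (500, 500, 2) (500, 500)

# Closed-form density at the mean: 1 / (2*pi*sqrt(variance_x * variance_y))
peak = 1 / (2 * np.pi * np.sqrt(variance_x * variance_y))
print(round(peak, 4))  # 0.0237

# Riemann-sum approximation of the volume under the surface
dx = x[1] - x[0]
dy = y[1] - y[0]
volume = Z.sum() * dx * dy
print(round(volume, 3))  # about 0.99, not 1.0: the window clips the tails of y
```

The volume falls slightly short of 1 because the standard deviation of y is sqrt(15) ≈ 3.87, so the ±10 window only reaches about 2.6 standard deviations in the y direction. Widening the y range should bring the volume closer to 1; the x direction (standard deviation ≈ 1.73) is already well covered.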
👉 Interactive: I also created a Google Colab Jupyter Notebook where you can plot it yourself.

I went ahead and tried to anticipate some follow-up questions you may have about the code. Here they are: 🤔❓


  1. What is a multivariate normal distribution? This concept from probability theory and statistics extends the 1D normal distribution to multiple dimensions; it is fully specified by a mean vector and a covariance matrix. The code uses a 2D, or bivariate, normal distribution.
  2. What does np.linspace(-10, 10, 500) do? This function generates 500 evenly spaced points from -10 to 10, including both endpoints. It's used here to create a range of values for x and y. I have written a detailed blog tutorial on np.linspace (with video).
  3. What is np.meshgrid(x, y) used for? This function takes two one-dimensional arrays and returns two two-dimensional coordinate arrays, X and Y. Together, X[i, j] and Y[i, j] give the (x, y) pair at each grid point; here both arrays have shape (500, 500).
  4. What is the purpose of the pos array? The pos array is used to hold the coordinates of each point in the grid in a format suitable for use with the multivariate_normal probability density function. It’s a 3D array where the first two dimensions match the dimensions of the grid, and the third dimension has size 2 to hold the x and y coordinates.
  5. What does rv.pdf(pos) do? This method evaluates the probability density function (pdf) of the multivariate normal distribution at each point in the grid and returns a 2D array with the same shape as X.
  6. What is plot_surface used for? This method draws the distribution as a three-dimensional surface. It takes the X and Y coordinate arrays and the pdf values at each point, all with the same 2D shape.
  7. What is the purpose of the 'cmap' parameter in the plot_surface function? The 'cmap' parameter specifies the colormap used to color the surface according to its height (the z value). 'viridis' is one of the predefined colormaps in Matplotlib.
  8. Why is the covariance matrix diagonal? The covariance matrix is diagonal because the variables x and y are assumed to be independent in this bivariate distribution. Off-diagonal elements of the covariance matrix represent the covariance between different variables – these are zero if the variables are independent.
  9. What does linewidth=0 do? The linewidth parameter in plot_surface specifies the line width for the edges of the surface polygons. Setting it to 0 removes these edges.
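To see what nonzero off-diagonal entries do, the sketch below swaps in a correlated covariance matrix. The correlation coefficient rho = 0.6 is a hypothetical value chosen only for illustration; the covariance is rho * sqrt(variance_x * variance_y), which keeps the matrix symmetric and positive definite as long as -1 < rho < 1. The resulting Z can be passed to plot_surface exactly like rv.pdf(pos) in the script, and the surface should then appear stretched along a diagonal instead of aligned with the axes.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

mu_x, variance_x = 0, 3
mu_y, variance_y = 0, 15

# Hypothetical correlation coefficient, for illustration only (-1 < rho < 1)
rho = 0.6
covariance_xy = rho * np.sqrt(variance_x * variance_y)

# Symmetric covariance matrix with nonzero off-diagonal entries
cov = np.array([[variance_x, covariance_xy],
                [covariance_xy, variance_y]])

# All eigenvalues positive means the matrix is positive definite
print(np.all(np.linalg.eigvalsh(cov) > 0))  # True

# Same grid as in the main script
x = np.linspace(-10, 10, 500)
y = np.linspace(-10, 10, 500)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))

rv_corr = multivariate_normal([mu_x, mu_y], cov)
Z = rv_corr.pdf(pos)  # pass to ax.plot_surface(X, Y, Z, ...) as in the script

# With nonzero covariance, the joint pdf no longer equals the product of marginals
point = np.array([2.0, 5.0])
product = (norm.pdf(point[0], mu_x, np.sqrt(variance_x))
           * norm.pdf(point[1], mu_y, np.sqrt(variance_y)))
print(np.isclose(rv_corr.pdf(point), product))  # False
```

The marginal distributions of x and y are unchanged by rho; only the way they vary together changes, which is exactly what the off-diagonal entries encode.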

Feel free to check out our full course on Matplotlib at the Finxter Academy.

🚀 Academy: Matplotlib – The Complete Guide to Becoming a Data Visualization Wizard