To create a 3D surface plot of a bivariate normal distribution, define two normally distributed random variables x and y, each with its own mean (`mu_x`, `mu_y`) and variance (`variance_x`, `variance_y`). The two random variables are independent, so the covariance between x and y is 0. Then evaluate the probability density function (pdf) of this bivariate normal distribution at each point of a grid of (x, y) pairs.
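For reference, here is the standard textbook density that the surface represents (stated for orientation, not derived in the article). Writing μ_x, μ_y for `mu_x`, `mu_y` and σ_x², σ_y² for `variance_x`, `variance_y`, the general bivariate normal pdf with covariance matrix Σ = diag(σ_x², σ_y²) has det Σ = σ_x²σ_y² and a diagonal inverse, so it splits into the product of two one-dimensional normal densities:

$$
f(\mathbf{z}) = \frac{1}{2\pi\sqrt{\det\Sigma}}\exp\!\left(-\tfrac{1}{2}(\mathbf{z}-\boldsymbol{\mu})^\top\Sigma^{-1}(\mathbf{z}-\boldsymbol{\mu})\right)
= \frac{1}{2\pi\,\sigma_x\sigma_y}\exp\!\left(-\frac{(x-\mu_x)^2}{2\sigma_x^2}-\frac{(y-\mu_y)^2}{2\sigma_y^2}\right)
= f_X(x)\,f_Y(y)
$$

This factorization is exactly what zero covariance buys you: the joint density is the x-density times the y-density.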
Here’s the code to copy and paste. I added comments to explain what each part does:
```python
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import multivariate_normal
from mpl_toolkits.mplot3d import Axes3D  # only required on Matplotlib < 3.2

# define parameters for x and y distributions
mu_x = 0  # mean of x
variance_x = 3  # variance of x
mu_y = 0  # mean of y
variance_y = 15  # variance of y

# define a grid for x and y values
x = np.linspace(-10, 10, 500)  # generate 500 points between -10 and 10 for x
y = np.linspace(-10, 10, 500)  # generate 500 points between -10 and 10 for y
X, Y = np.meshgrid(x, y)  # create a grid for (x, y) pairs

# create an empty array of the same shape as X to hold the (x, y) coordinates
pos = np.empty(X.shape + (2,))
# fill the pos array with the x and y coordinates
pos[:, :, 0] = X
pos[:, :, 1] = Y

# create a multivariate normal distribution using the defined parameters
rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])

# create a new figure for 3D plot
fig = plt.figure()
# add a 3D subplot to the figure
ax = fig.add_subplot(projection='3d')

# create a 3D surface plot of the multivariate normal distribution
ax.plot_surface(X, Y, rv.pdf(pos), cmap='viridis', linewidth=0)

# set labels for the axes
ax.set_xlabel('X axis')
ax.set_ylabel('Y axis')
ax.set_zlabel('Z axis')

# display the 3D plot
plt.show()
```
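If you want to check that the plotted values really are this density, the optional sketch below (an addition, not part of the original script) compares `rv.pdf(pos)` with the product of two one-dimensional `scipy.stats.norm` densities and approximates the probability mass inside the plotted square with a simple Riemann sum. It builds `pos` with `np.dstack`, which yields the same array as the empty-then-fill approach. The mass should come out slightly below 1 (roughly 0.99), because the ±10 window cuts off part of the y distribution, whose standard deviation is √15 ≈ 3.87.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

# Same parameters and grid as the plotting script
mu_x, variance_x = 0, 3
mu_y, variance_y = 0, 15
x = np.linspace(-10, 10, 500)
y = np.linspace(-10, 10, 500)
X, Y = np.meshgrid(x, y)
pos = np.dstack((X, Y))  # same (500, 500, 2) array as the empty-then-fill version

rv = multivariate_normal([mu_x, mu_y], [[variance_x, 0], [0, variance_y]])
Z = rv.pdf(pos)  # shape (500, 500): one density value per grid point

# With zero covariance the joint pdf equals the product of the 1D pdfs.
# norm takes the standard deviation as its scale, so pass sqrt(variance).
Z_product = norm.pdf(X, mu_x, np.sqrt(variance_x)) * norm.pdf(Y, mu_y, np.sqrt(variance_y))
print(np.allclose(Z, Z_product))  # True

# Riemann-sum approximation of the probability mass inside the plotted square
dx = x[1] - x[0]
dy = y[1] - y[0]
mass = Z.sum() * dx * dy
print(mass)
```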
Interactive: I also created a Google Colab Jupyter Notebook where you can plot it yourself.
I also tried to anticipate some follow-up questions you may have about the code. Here they are:
FAQ
- What is a multivariate normal distribution? This concept from probability theory and statistics extends the 1D normal distribution to multiple dimensions. The code uses a 2D, or bivariate, normal distribution.
- What does `np.linspace(-10, 10, 500)` do? This function generates 500 evenly spaced points over the interval from -10 to 10, including both endpoints. It’s used here to create a range of values for x and y. I have written a detailed blog tutorial on this function (with video).
- What is `np.meshgrid(x, y)` used for? This function generates a two-dimensional grid of coordinates from two one-dimensional arrays. In this case, it generates a grid of (x, y) pairs.
- What is the purpose of the `pos` array? The `pos` array holds the coordinates of each point in the grid in a format suitable for the `multivariate_normal` probability density function. It’s a 3D array whose first two dimensions match the dimensions of the grid and whose third dimension has size 2 to hold the x and y coordinates.
- What does `rv.pdf(pos)` do? This method evaluates the probability density function (pdf) of the multivariate normal distribution at each point in the grid.
- What is `plot_surface` used for? This function creates a three-dimensional plot of the distribution. It takes the grid arrays `X` and `Y` and the pdf values at each point as input and generates a 3D surface plot.
- What is the purpose of the `cmap` parameter in the `plot_surface` function? The `cmap` parameter specifies the color map for the plot. `'viridis'` is one of the predefined color maps in Matplotlib.
- Why is the covariance matrix diagonal? The covariance matrix is diagonal because the variables x and y are assumed to be independent in this bivariate distribution. Off-diagonal elements of the covariance matrix represent the covariance between different variables; these are zero if the variables are independent.
- What does `linewidth=0` do? The `linewidth` parameter in `plot_surface` specifies the line width for the edges of the surface polygons. Setting it to 0 removes these edges.
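To make the shapes in the `np.linspace`, `np.meshgrid`, and `pos` answers concrete, here is a tiny version of the same grid construction. The sizes (3 x-values and 2 y-values instead of 500 each) are chosen only so the arrays are easy to inspect. The last line shows `np.dstack((X, Y))`, an equivalent one-call way to build `pos` that the original code does not use.

```python
import numpy as np

# A tiny grid: 3 x-values and 2 y-values (the article uses 500 of each)
x = np.linspace(-1, 1, 3)   # [-1.  0.  1.], both endpoints included
y = np.linspace(0, 1, 2)    # [0. 1.]
X, Y = np.meshgrid(x, y)    # default indexing='xy' -> shape (len(y), len(x))

# Same construction as in the article
pos = np.empty(X.shape + (2,))
pos[:, :, 0] = X
pos[:, :, 1] = Y

print(X.shape, pos.shape)   # (2, 3) (2, 3, 2)
print(pos[1, 2])            # [1. 1.]  i.e. the point x = 1.0, y = 1.0

# np.dstack stacks X and Y along a new third axis, giving the same array
print(np.array_equal(pos, np.dstack((X, Y))))  # True
```

With the default `indexing='xy'`, `X.shape` is `(len(y), len(x))`, so rows correspond to y values. In the article both arrays have 500 points, so the grid is square and this distinction does not affect the plot.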
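The covariance answer can be made concrete by adding a nonzero off-diagonal term. This sketch is an illustration, not part of the original article: `rho = 0.5` is a hypothetical correlation coefficient, and the off-diagonal entry is the covariance ρ·σ_x·σ_y. Both distributions have the same marginals (variance 3 for x, 15 for y), but only the one with the diagonal covariance matrix equals the product of those marginals.

```python
import numpy as np
from scipy.stats import multivariate_normal, norm

# Variances from the article; rho is a hypothetical correlation for illustration
variance_x, variance_y = 3, 15
rho = 0.5
cov_xy = rho * np.sqrt(variance_x * variance_y)  # covariance = rho * sigma_x * sigma_y

rv_indep = multivariate_normal([0, 0], [[variance_x, 0], [0, variance_y]])
rv_corr = multivariate_normal([0, 0], [[variance_x, cov_xy], [cov_xy, variance_y]])

point = np.array([1.0, 2.0])  # an arbitrary (x, y) location
# Product of the two 1D marginal densities (norm expects the standard deviation)
product = norm.pdf(point[0], 0, np.sqrt(variance_x)) * norm.pdf(point[1], 0, np.sqrt(variance_y))

print(np.isclose(rv_indep.pdf(point), product))  # True: diagonal covariance factorizes
print(np.isclose(rv_corr.pdf(point), product))   # False: nonzero covariance does not
```

If you pass the correlated covariance matrix to `multivariate_normal` in the plotting script instead, the density’s contours become tilted ellipses rather than ellipses aligned with the x and y axes.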
Feel free to check out our full course on Matplotlib at the Finxter Academy:
Academy: Matplotlib – The Complete Guide to Becoming a Data Visualization Wizard