Creating Beautiful Heatmaps with Seaborn


Heatmaps are a specific type of plot that combines a color scheme with numerical values to represent complex, articulated datasets. They are widely used in data science applications that involve large amounts of numerical data, such as biology, economics and medicine.

In this article we will see how to create a heatmap representing the total number of COVID-19 cases in the different US states on different days. To achieve this, we will use Seaborn, a Python package that provides many fancy and powerful functions for plotting data.

Here’s the code to be discussed:

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# path (or URL) of the .csv file
url = r"path of the .csv file"

# import the .csv file into a pandas DataFrame
df = pd.read_csv(url, sep = ';', thousands = ',')

# define the array containing the states present in the study
states = np.array(df['state'].drop_duplicates())[:40]

# extract the total cases for each day and each state
overall_cases = []
for state in states:
    tot_cases = []
    for i in range(len(df['state'])):
        if df['state'][i] == state:
            tot_cases.append(df['tot_cases'][i])
    overall_cases.append(tot_cases[:30])

data = pd.DataFrame(overall_cases).T
data.columns = states

# Plotting
fig, ax = plt.subplots()
sns.heatmap(data, ax=ax, annot=True, fmt="d", linewidths=0, cmap='viridis', xticklabels=True)
ax.invert_yaxis()  # put day 0 at the bottom
ax.set_xlabel('States')
ax.set_ylabel('Day n°')
plt.show()

Let’s dive into the code to learn Seaborn’s heatmap functionality in a step-by-step manner.

Importing the required libraries for this example

We start our script by importing the libraries required for this example, namely NumPy, Pandas, Matplotlib and Seaborn.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

What’s in the data?

As mentioned in the introduction, we will use the COVID-19 data that were also used in the article about the SciPy curve_fit() function. The data were downloaded as a .csv file from the official website of the Centers for Disease Control and Prevention (CDC).

The file reports various information about the COVID-19 pandemic in the different US states, such as the total number of cases, the number of new cases and the number of deaths; all of these quantities were recorded daily for multiple US states.

We will generate a heatmap in which each cell displays the total number of cases recorded on a particular day in a particular US state. To do that, the first step is to import the .csv file and store it in a Pandas DataFrame.

Importing the data with Pandas

The data are stored in a .csv file in which the values are separated by a semicolon, while thousands are grouped with a comma. To import the .csv file into our Python script, we use the Pandas function pd.read_csv(), which accepts the path of the file as input and converts its content into a Pandas DataFrame.

It is important to note that, when calling pd.read_csv(), we specify the separator, which in our case is “;”, by writing “sep = ';'”, and the symbol used for the thousands, by writing “thousands = ','”. Without the latter, a value such as “1,234” would not be recognized as a number. All this is contained in the following code lines:

# path (or URL) of the .csv file
url = r"path of the .csv file"
# import the .csv file into a pandas DataFrame
df = pd.read_csv(url, sep = ';', thousands = ',')
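The effect of the two parameters can be checked without the real file. The minimal sketch below feeds pd.read_csv() a small, hypothetical in-memory sample (made-up values, only mimicking the file's format) through io.StringIO, and reads it twice: once with thousands=",", and once without it.

```python
import io

import pandas as pd

# Hypothetical sample mimicking the file's format (not real CDC data):
# ';' separates the fields, ',' groups the thousands.
sample = io.StringIO(
    "state;tot_cases\n"
    "NY;1,234\n"
    "CA;56\n"
    "NY;12,345\n"
)

# With thousands=",", "1,234" is parsed as the integer 1234.
df_sample = pd.read_csv(sample, sep=";", thousands=",")
print(df_sample.columns.tolist())  # ['state', 'tot_cases']
print(df_sample["tot_cases"].tolist())  # [1234, 56, 12345]

# Without thousands=",", "1,234" is not a valid number, so the whole column
# is kept as text instead of integers.
sample.seek(0)
raw = pd.read_csv(sample, sep=";")
print(pd.api.types.is_integer_dtype(raw["tot_cases"]))  # False
```

Printing df_sample.columns here plays the same role as typing df.columns on the real DataFrame: it shows the header of the file.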

Creating the arrays that will be used in the heatmap

At this point, we have to process the DataFrame in order to extract only the information that will be used to create the heatmap.

The first values that we extract are the names of the states in which the data were recorded. To see all the columns of the DataFrame, we can type “df.columns”, which prints out the header of the file. Among the different columns, the one we are interested in is “state”, which contains the names of all the states involved in this chart.

Since the data are recorded on a daily basis, each row corresponds to the data collected on a single day in a specific state; as a result, the names of the states are repeated along this column. Since we do not want any repetition in our heatmap, we also have to remove the duplicates.

We proceed by defining a NumPy array called “states”, in which we store the values of the “state” column of the DataFrame; in the same line of code, we apply the method .drop_duplicates() to that column, so that each state appears only once. Since the “state” column contains 60 distinct entries, we limit our analysis to the first 40, to avoid overlapping labels along the x-axis of the heatmap due to the limited window space.

# define the array containing the states present in the study
states = np.array(df['state'].drop_duplicates())[:40]
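The deduplication step can be checked on a small, hypothetical frame (toy state codes and values, not the CDC file). The sketch below shows that .drop_duplicates() keeps the first occurrence of each value, so the states come out in the order in which they first appear in the file, before the slice keeps only the first ones.

```python
import numpy as np
import pandas as pd

# Hypothetical toy frame: each state appears once per recorded day.
df_toy = pd.DataFrame({
    "state": ["CO", "FL", "AZ", "CO", "FL", "AZ", "CO"],
    "tot_cases": [0, 0, 1, 2, 3, 4, 5],
})

# drop_duplicates() keeps the first occurrence of each state, preserving the
# order of first appearance; the slice then keeps only the first two states.
states_toy = np.array(df_toy["state"].drop_duplicates())[:2]
print(states_toy)  # ['CO' 'FL']

# Series.unique() yields the same values in the same order.
print(list(df_toy["state"].unique()[:2]))
```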

The next step is to extract the number of total cases recorded on each day in each state. To do that, we use two nested for loops, which allow us to create, for every state in the “states” array, a list containing the number of total cases (one integer per day), and to append it to another list called “overall_cases”, which needs to be defined before the loops start.

# extract the total cases for each day and each state
overall_cases = []

As you can see in the following code, the outer for loop iterates over the states previously stored in the “states” array; for each state, we define an empty list called “tot_cases”, to which we will append the total cases recorded on each day.

for state in states:
    tot_cases = []

Within the outer loop (i.e. while dealing with a single state), we start an inner for loop which runs over every row of the DataFrame, from index 0 up to the last row. We achieve this by combining the functions range() and len(), applied to the “state” column.

    for i in range(len(df['state'])):

Within the inner loop, we want to append to the list “tot_cases” only the values that refer to the state we are currently interested in (i.e. the one selected by the outer loop and stored in the variable “state”); we do this with the following if statement:

        if df['state'][i] == state:
            tot_cases.append(df['tot_cases'][i])
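The inner loop scans every row of the DataFrame once for every state, so the work grows with the number of states times the number of rows. As an alternative (not the approach used in this article), a boolean mask selects the same rows in a single vectorized comparison. The minimal sketch below uses hypothetical toy data and checks that both versions give the same values. Note that the loop version relies on the default integer row labels assigned by pd.read_csv(), because df['state'][i] looks rows up by label.

```python
import pandas as pd

# Hypothetical toy data in the same long format as the CDC file.
df_toy = pd.DataFrame({
    "state": ["CO", "FL", "CO", "FL", "CO"],
    "tot_cases": [0, 1, 2, 3, 4],
})
state = "CO"

# Loop version, as in the article: scan every row and keep the matching ones.
tot_cases_loop = []
for i in range(len(df_toy["state"])):
    if df_toy["state"][i] == state:
        tot_cases_loop.append(df_toy["tot_cases"][i])

# Boolean-mask version: one comparison selects the same rows at once.
tot_cases_mask = df_toy.loc[df_toy["state"] == state, "tot_cases"].tolist()

print(tot_cases_mask)  # [0, 2, 4]
print(tot_cases_loop == tot_cases_mask)  # True
```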

When we have finished appending the total cases of a particular state to the “tot_cases” list, we exit the inner loop and append this list to “overall_cases”, which thus becomes a list of lists. In this case too, we limit our analysis to the first 30 days (treating the order of the rows in the file as chronological order); otherwise there would not be enough space in our heatmap for all the 286 values present in the DataFrame.

    overall_cases.append(tot_cases[:30])

In the next iteration, the code moves on to the second element of the “states” array, i.e. another state: it initializes an empty “tot_cases” list, enters the inner loop to collect all the values referring to that state on the different days and, once finished, appends the entire list to “overall_cases”. This procedure is repeated for every state stored in the “states” array. At the end, we have extracted all the values needed to generate our heatmap.
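The two nested loops, the [:30] slice and the later transposition can also be expressed with pandas' grouping tools. The sketch below is an alternative, not the article's method: it numbers the rows of each state with groupby().cumcount() and reshapes the long table with DataFrame.pivot(). Like the loops, it assumes that rows are in chronological order within each state; if a state had fewer than n_days rows, its missing cells would become NaN. The data and the n_days name are hypothetical.

```python
import pandas as pd

# Hypothetical long-format data: one row per state per day. Like the article's
# loops, this assumes rows are already in chronological order within each state.
df_toy = pd.DataFrame({
    "state": ["CO", "FL", "CO", "FL", "CO", "FL"],
    "tot_cases": [0, 1, 2, 3, 4, 5],
})
states = ["CO", "FL"]
n_days = 2  # the article keeps the first 30 days

# cumcount() numbers the rows of each state 0, 1, 2, ... in their original
# order, which gives the day number of every record.
day = df_toy.groupby("state").cumcount()

# pivot() turns the long table into a day-by-state grid in a single step.
data = df_toy.assign(day=day).pivot(index="day", columns="state", values="tot_cases")

# Keep the first n_days rows and the chosen states, in the order of `states`.
data = data.loc[: n_days - 1, states]
data.index.name = None    # drop the "day" label left over from pivot()
data.columns.name = None  # drop the "state" label left over from pivot()
print(data)
```

The resulting table has one row per day and one column per state, the same layout as the “data” DataFrame used for the heatmap in this article.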

Creating the DataFrame for the heatmap

As already mentioned in the introduction, we use the Seaborn function sns.heatmap() to generate our heatmap.

This function can take as input a pandas DataFrame that contains the rows, the columns and the values of every cell that we want to display in our plot. We hence generate a new pandas DataFrame (called “data”) from the list “overall_cases”; in this DataFrame, each row corresponds to a specific state and each column to a specific day.

We then transpose this DataFrame by adding “.T” at the end of the line, so that each row corresponds to a day and each column to a state; in this way, we can use the names of the states as the header of our DataFrame.

data = pd.DataFrame(overall_cases).T

Since the names of the states were previously stored in the array “states”, we can set them as the header of the DataFrame with the following code:

data.columns = states

The DataFrame used for generating the heatmap will look like this (first five rows shown):

   CO  FL  AZ  SC  CT  NE  KY  WY  IA  ...  LA  ID  NV  GA  IN  AR  MD  NY  OR
 0   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
 1   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
 2   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
 3   0   0   0   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0
 4   0   0   1   0   0   0   0   0   0  ...   0   0   0   0   0   0   0   0   0 

The row indexes represent the day number on which the data were recorded, while the column headers are the names of the states.
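To make the effect of .T concrete, here is a minimal sketch on a hypothetical two-state, three-day list of lists: before the transposition each row is a state, afterwards each row is a day, and the state names can be assigned to the columns.

```python
import pandas as pd

# Hypothetical list of lists: one inner list per state, one value per day.
overall_cases = [
    [0, 0, 1],  # first state
    [0, 2, 5],  # second state
]
states = ["CO", "FL"]

wide = pd.DataFrame(overall_cases)  # rows = states, columns = days
data = wide.T                       # rows = days, columns = states
data.columns = states               # state names become the header

print(wide.shape)  # (2, 3)
print(data.shape)  # (3, 2)
print(data)
```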

Generating the heatmap

After creating the figure and its axes with the usual Matplotlib functions, we call the Seaborn function sns.heatmap() to draw the heatmap.

The only mandatory input of this function is the pandas DataFrame that we created in the previous section. Several optional parameters can then improve our heatmap:

  • linewidths sets the width of the lines separating the cells (white by default, as set by linecolor); in our code it is 0, so no separating lines are drawn;
  • xticklabels controls the labels along the x-axis; if it is True, every column name is displayed, whereas by default Seaborn may show only a subset to avoid overlapping labels;
  • cmap selects the colormap of the heatmap, by passing the name of any available colormap (“viridis” or “magma” are very fancy, but the Seaborn default one is also really nice);
  • finally, annot = True displays the numerical value of each cell at its center, and fmt sets the string format used for those values (“d” prints them as integers).
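The parameters listed above can be tried on a tiny table before running the full script. The sketch below uses hypothetical values, a non-interactive Matplotlib backend so that it runs without a display, and assumes a recent Seaborn version. It also checks that Seaborn already draws row 0 at the top of the y-axis, which is why an extra ax.invert_yaxis() call moves it to the bottom.

```python
import matplotlib

matplotlib.use("Agg")  # off-screen backend, so the sketch runs without a display

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Hypothetical 3-day x 2-state table of integer counts (not real CDC data).
data = pd.DataFrame({"CO": [0, 2, 5], "FL": [1, 3, 8]})

fig, ax = plt.subplots()
sns.heatmap(
    data,
    ax=ax,              # draw on this Axes explicitly
    annot=True,         # write each cell's value at its center
    fmt="d",            # format the annotations as integers
    linewidths=0.5,     # thin lines between cells, so they are easier to tell apart
    linecolor="white",  # color of those lines (also Seaborn's default)
    cmap="magma",       # any Matplotlib colormap name works here
    xticklabels=True,   # label every column, even when there are many
)
annotations = sorted(t.get_text() for t in ax.texts)
print(annotations)  # ['0', '1', '2', '3', '5', '8']

# Seaborn already flips the y-axis so that row 0 is drawn at the top, like a table.
inverted_by_seaborn = ax.yaxis_inverted()
# Flipping it once more puts row 0 (day 0 in the article) at the bottom.
ax.invert_yaxis()
inverted_after = ax.yaxis_inverted()
print(inverted_by_seaborn, inverted_after)  # True False
plt.close(fig)
```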

The following lines contain the code for plotting the heatmap. One final observation concerns ax.invert_yaxis(). Since we plot the heatmap directly from a pandas DataFrame, the row index corresponds to the day number, and Seaborn draws it like a table: day 0 is at the top and the day number increases as we move down the rows. By calling .invert_yaxis() we reverse the y-axis, so that day 0 sits at the bottom of the heatmap.

# Plotting
fig, ax = plt.subplots()
sns.heatmap(data, ax=ax, annot=True, fmt="d", linewidths=0, cmap='viridis', xticklabels=True)
ax.invert_yaxis()  # put day 0 at the bottom
ax.set_xlabel('States')
ax.set_ylabel('Day n°')
plt.show()

Figure 1 displays the heatmap obtained by this code snippet.

Figure 1: Heatmap representing the number of COVID-19 total cases for the first 30 days of measurement (y-axis) in the different US states (x-axis).

As you can see in Figure 1, there are a lot of zeroes; this is because we decided to plot the data related to the first 30 days of measurement, when the number of recorded cases was still very low. If we plotted the results for all the days of measurement (from day 0 to 286), we would obtain the result displayed in Figure 2 (in this case, we set annot = False, since the numbers would be too large for the cell size):

Figure 2: Heatmap representing the number of COVID-19 total cases for the first 286 days of measurement (y-axis) in the different US states (x-axis); this time annot = False, since the cells are too small to accommodate the number of total cases (which becomes very large towards the upper part of the heatmap).