5 Best Ways to Plot a Confusion Matrix with String Axis in Python

💡 Problem Formulation: Confusion matrices are a vital tool for evaluating classification algorithms. By default, a confusion matrix plot labels its axes with integer class indices. For better readability and interpretation, it often helps to label the axes with the string names of the classes instead. This article demonstrates five ways to plot a confusion matrix in Python with string axis labels rather than integers, from plain Matplotlib to higher-level libraries such as Seaborn and Plotly.
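To make the problem concrete, here is a minimal sketch using made-up animal labels chosen purely for illustration. It shows that scikit-learn's confusion_matrix() returns a bare NumPy array with no class names attached, and that rows and columns follow the sorted order of the classes unless you pass labels explicitly.

```python
from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

# Illustrative data only, not the output of a real model
y_true = ['dog', 'cat', 'bird', 'dog']
y_pred = ['cat', 'cat', 'bird', 'dog']

# Without `labels`, the classes are ordered by their sorted unique values
default_order = [str(c) for c in unique_labels(y_true, y_pred)]
cm_default = confusion_matrix(y_true, y_pred)

# With `labels`, rows (true) and columns (predicted) follow the given order
custom_order = ['dog', 'cat', 'bird']
cm_custom = confusion_matrix(y_true, y_pred, labels=custom_order)

print(default_order)
print(cm_default)
print(cm_custom)
```

Whichever plotting method you choose, the list passed as labels is the one to reuse for the tick labels, so that the names line up with the rows and columns.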

Method 1: Matplotlib with xticklabels and yticklabels

This method uses Matplotlib’s imshow() function to plot the confusion matrix and sets the tick positions and labels manually with set_xticks(), set_yticks(), set_xticklabels(), and set_yticklabels(). It is straightforward and customizable for simple use cases.

Here’s an example:

import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix

y_true = ['apple', 'banana', 'cherry', 'apple']
y_pred = ['banana', 'banana', 'cherry', 'apple']
labels = ['apple', 'banana', 'cherry']

cm = confusion_matrix(y_true, y_pred, labels=labels)
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.set_xticks(range(len(labels)))
ax.set_yticks(range(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)

plt.show()

The output is a graphical confusion matrix with string labels on both axes.

This code snippet generates a confusion matrix from the true and predicted labels, then uses Matplotlib to visualize it. The axes are labeled with the actual names of the classes to provide clarity on which class corresponds to which axis.
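The plain imshow() plot shows colors only. As a possible extension, the sketch below wraps the same steps in a helper (the name plot_confusion_matrix is hypothetical, not a library function) that also writes the count into every cell, names the axes, and adds a colorbar.

```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix


def plot_confusion_matrix(cm, labels):
    """Hypothetical helper: draw `cm` with string ticks and a count per cell."""
    fig, ax = plt.subplots()
    im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
    fig.colorbar(im, ax=ax)

    # One tick per class, labeled with the class name
    ax.set_xticks(range(len(labels)))
    ax.set_yticks(range(len(labels)))
    ax.set_xticklabels(labels, rotation=45, ha='right')
    ax.set_yticklabels(labels)
    ax.set_xlabel('Predicted label')
    ax.set_ylabel('True label')

    # White text on dark cells, black text on light cells
    threshold = cm.max() / 2
    for i in range(cm.shape[0]):
        for j in range(cm.shape[1]):
            color = 'white' if cm[i, j] > threshold else 'black'
            ax.text(j, i, str(cm[i, j]), ha='center', va='center', color=color)

    fig.tight_layout()
    return fig, ax


y_true = ['apple', 'banana', 'cherry', 'apple']
y_pred = ['banana', 'banana', 'cherry', 'apple']
labels = ['apple', 'banana', 'cherry']

cm = confusion_matrix(y_true, y_pred, labels=labels)
fig, ax = plot_confusion_matrix(cm, labels)
```

Call plt.show() to display the figure or fig.savefig() to write it to disk.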

Method 2: Seaborn’s heatmap

Seaborn is a data visualization library that builds on Matplotlib and offers a higher-level interface with attractive defaults. Its heatmap() function can annotate each cell with its numerical value when annot=True is passed, and it accepts the string class names directly through the xticklabels and yticklabels parameters.

Here’s an example:

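import matplotlib.pyplot as plt  # needed for plt.show()
import seaborn as sns
from sklearn.metrics import confusion_matrix

y_true = ['dog', 'cat', 'bird', 'dog']
y_pred = ['cat', 'cat', 'bird', 'dog']
labels = ['dog', 'cat', 'bird']

# Rows are true classes, columns are predicted classes, both in `labels` order
cm = confusion_matrix(y_true, y_pred, labels=labels)
# fmt='d' prints the counts as integers
sns.heatmap(cm, annot=True, fmt='d', xticklabels=labels, yticklabels=labels)
plt.show()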

The output is a heatmap representing the confusion matrix with strings as axis ticks.

In this code, a confusion matrix is computed and then plotted as a heatmap using Seaborn. The classes for the axes are specified as strings directly within the heatmap() function call, making this method very concise.
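Raw counts can be hard to compare when classes have different sizes. One possible variation, sketched below, asks scikit-learn to normalize each row (normalize='true') so that every cell shows the fraction of a true class, and names the axes explicitly. The axis titles and color range are illustrative choices.

```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

y_true = ['dog', 'cat', 'bird', 'dog']
y_pred = ['cat', 'cat', 'bird', 'dog']
labels = ['dog', 'cat', 'bird']

# normalize='true' divides each row by the number of samples in that true class
cm_norm = confusion_matrix(y_true, y_pred, labels=labels, normalize='true')

fig, ax = plt.subplots()
sns.heatmap(cm_norm, annot=True, fmt='.2f', cmap='Blues', vmin=0, vmax=1,
            xticklabels=labels, yticklabels=labels, ax=ax)
ax.set_xlabel('Predicted label')
ax.set_ylabel('True label')
```

Here the 'dog' row reads 0.50 and 0.50 because one of the two dogs was predicted as 'cat'. Call plt.show() to display the figure.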

Method 3: Pandas DataFrame with Seaborn and Matplotlib

By converting the confusion matrix to a Pandas DataFrame, you attach the class names to the data itself as the index and column labels. Seaborn reads those labels automatically when plotting, and you can use Pandas to rename or reorder them beforehand.

Here’s an example:

import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns  # needed for sns.heatmap()
from sklearn.metrics import confusion_matrix

y_true = ['spring', 'summer', 'fall', 'spring']
y_pred = ['summer', 'summer', 'fall', 'spring']
classes = ['spring', 'summer', 'fall']

cm = confusion_matrix(y_true, y_pred, labels=classes)
# The index labels the rows (true classes), the columns label the predictions
df_cm = pd.DataFrame(cm, index=classes, columns=classes)

plt.figure(figsize=(10, 7))
# Seaborn takes the tick labels from the DataFrame's index and columns
sns.heatmap(df_cm, annot=True, fmt='d')
plt.show()

The output is a Matplotlib figure with a heatmap displaying the confusion matrix, with string labels derived from a Pandas DataFrame.

This approach involves creating a DataFrame from the confusion matrix data, which then provides the axes labels when plotted. It makes it simple to work with more complicated layouts and manipulate data labels before plotting.
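As an illustration of manipulating the labels before plotting, the sketch below reorders the classes, swaps in display names, and names the two axes on the DataFrame itself. The display names and the chosen order are assumptions made up for this example.

```python
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix

y_true = ['spring', 'summer', 'fall', 'spring']
y_pred = ['summer', 'summer', 'fall', 'spring']
classes = ['spring', 'summer', 'fall']

cm = confusion_matrix(y_true, y_pred, labels=classes)
df_cm = pd.DataFrame(cm, index=classes, columns=classes)

# Hypothetical display names and ordering, chosen only for illustration
display_names = {'spring': 'Spring', 'summer': 'Summer', 'fall': 'Autumn'}
order = ['fall', 'spring', 'summer']

df_plot = (
    df_cm.loc[order, order]                                  # reorder rows and columns together
    .rename(index=display_names, columns=display_names)      # swap in display names
    .rename_axis(index='True label', columns='Predicted label')  # name both axes
)

fig, ax = plt.subplots(figsize=(6, 4))
sns.heatmap(df_plot, annot=True, fmt='d', cmap='Blues', ax=ax)
```

Reordering rows and columns with the same list keeps the diagonal meaningful, since each class still sits at the same position on both axes.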

Method 4: Plotly’s Interactive Heatmap

Plotly produces interactive figures. Its figure factory function create_annotated_heatmap() draws a confusion matrix with string axis labels that supports hovering and zooming, which is useful for web-based presentations and reports.

Here’s an example:

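import plotly.figure_factory as ff
from sklearn.metrics import confusion_matrix

y_true = ['red', 'blue', 'green', 'red']
y_pred = ['blue', 'blue', 'green', 'red']
colors = ['red', 'blue', 'green']

cm = confusion_matrix(y_true, y_pred, labels=colors)
# x labels the columns (predicted), y labels the rows (true)
fig = ff.create_annotated_heatmap(cm, x=colors, y=colors)
# Plotly draws the first row at the bottom by default; flip the y-axis so
# the matrix reads top to bottom like the other plots in this article
fig.update_yaxes(autorange='reversed')
fig.show()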

The output is an interactive heatmap that displays the confusion matrix with strings on both axes.

The code uses Plotly to plot the confusion matrix as an interactive annotated heatmap. The string labels are set directly in the x and y parameters of the create_annotated_heatmap() function.

Bonus One-Liner Method 5: Using ConfusionMatrixDisplay from scikit-learn

Scikit-learn’s ConfusionMatrixDisplay class plots a confusion matrix with string labels in just a few lines of code, and the resulting plot can be customized further if needed.

Here’s an example:

import matplotlib.pyplot as plt  # needed for plt.show()
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

y_true = ['cow', 'pig', 'hen', 'cow']
y_pred = ['pig', 'pig', 'hen', 'cow']
labels = ['cow', 'pig', 'hen']

cm = confusion_matrix(y_true, y_pred, labels=labels)
# display_labels supplies the string tick labels for both axes
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot()
plt.show()

The output is a confusion matrix with the specified string labels on the axes.

This method is highly efficient, leveraging the ConfusionMatrixDisplay class to handle the plotting of the confusion matrix with provided display labels without additional boilerplate code.

Summary/Discussion

  • Method 1: Matplotlib manual tick labeling. Strengths: high customization, direct control of labels. Weaknesses: more boilerplate code.
  • Method 2: Seaborn heatmap. Strengths: simplicity, integrated labeling, attractive defaults. Weaknesses: less customization for complex needs.
  • Method 3: DataFrame with Seaborn/Matplotlib. Strengths: use of DataFrame functionality, good for complex data manipulation. Weaknesses: heavier code compared to direct plotting methods.
  • Method 4: Plotly’s interactive heatmap. Strengths: interactive features, useful for web reports. Weaknesses: may require more resources, not as straightforward for quick static plots.
  • Bonus Method 5: Scikit-learn’s ConfusionMatrixDisplay. Strengths: quickest and simplest method for standard use cases. Weaknesses: less customizable than manual methods.