How to Color a Scatter Plot by Category using Matplotlib in Python


Problem Formulation

Given three arrays:

  • The first two arrays, x and y, each of length n, contain the (x_i, y_i) coordinates of the points in a 2D coordinate system.
  • The third array, c, holds a categorical label for each point, so together the arrays form n data bundles (x_i, y_i, c_i), where the labels c_i can take any number of distinct values.
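To make the data bundles concrete, here is a small sketch that zips three example arrays into (x_i, y_i, c_i) tuples and lists the distinct categories. The values are the same ones used in the pandas code example in this post.

```python
# Example arrays (the same values as in the pandas example in this post)
x = [1, 2, 3, 4, 5, 6]
y = [42, 41, 40, 39, 38, 37]
c = ['a', 'b', 'a', 'b', 'b', 'a']

# Bundle the three arrays into n tuples (x_i, y_i, c_i)
bundles = list(zip(x, y, c))
print(bundles[:3])  # [(1, 42, 'a'), (2, 41, 'b'), (3, 40, 'a')]

# Distinct category labels: each one should end up with its own color
categories = sorted(set(c))
print(categories)  # ['a', 'b']
```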

💬 Question: How to plot the data so that (x_i, y_i) and (x_j, y_j) with the same category c_i == c_j have the same color?
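Whatever library you use, the idea is the same: make one plotting call per distinct category, so that Matplotlib assigns each category its own color. As a minimal sketch that assumes NumPy arrays instead of a pandas DataFrame, you can select each category's points with a boolean mask and call ax.scatter() once per label. The output filename is a hypothetical placeholder.

```python
import numpy as np
import matplotlib.pyplot as plt

x = np.array([1, 2, 3, 4, 5, 6])
y = np.array([42, 41, 40, 39, 38, 37])
c = np.array(['a', 'b', 'a', 'b', 'b', 'a'])

fig, ax = plt.subplots()
# np.unique returns the distinct labels in sorted order: 'a', then 'b'
for label in np.unique(c):
    mask = c == label  # True where the point belongs to this category
    # No color is passed, so each call takes the next color from the color cycle
    ax.scatter(x[mask], y[mask], label=label)

ax.legend(title="Category")
fig.savefig("scatter_by_category.png")  # hypothetical filename; use plt.show() interactively
```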

Solution: Use Pandas groupby() and Call plt.plot() Separately for Each Group

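To color the data by category, group the rows with data.groupby() and iterate over the resulting groups. For each group, call plt.plot() to draw only that group's points. Because each plt.plot() call without an explicit color takes the next color from Matplotlib's default color cycle, every category ends up with its own color.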

In particular, you perform the following steps:

  1. Call data.groupby("Category"), assuming that data is a Pandas DataFrame with the columns X, Y, and Category for the n data points (rows).
  2. Iterate over all (name, group) tuples in the grouping operation result obtained from step one.
  3. Use plt.plot(group["X"], group["Y"], marker="o", linestyle="", label=name) to plot each group separately using the x, y data and name as a label.
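Steps 1 and 2 can be checked on their own, without any plotting. This sketch builds a small DataFrame with the column names used in this post and prints what each (name, group) tuple contains. By default, groupby() sorts the group keys, so a comes before b.

```python
import pandas as pd

data = pd.DataFrame({
    "X": [1, 2, 3, 4, 5, 6],
    "Y": [42, 41, 40, 39, 38, 37],
    "Category": ['a', 'b', 'a', 'b', 'b', 'a'],
})

names = []
x_by_group = {}
# Step 1: group the rows by label; step 2: iterate over (name, group) tuples
for name, group in data.groupby("Category"):
    names.append(name)
    # Each group is itself a DataFrame containing only this label's rows
    x_by_group[name] = group["X"].tolist()
    print(name, x_by_group[name])  # prints "a [1, 3, 6]", then "b [2, 4, 5]"
```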

Here’s what that looks like in code:

import pandas as pd
import matplotlib.pyplot as plt

# Generate the categorical data
x = [1, 2, 3, 4, 5, 6]
y = [42, 41, 40, 39, 38, 37]
c = ['a', 'b', 'a', 'b', 'b', 'a']

data = pd.DataFrame({"X": x, "Y": y, "Category": c})
print(data)

# Plot data by category: one plt.plot() call per group,
# so each category gets its own color from the color cycle
groups = data.groupby("Category")
for name, group in groups:
    plt.plot(group["X"], group["Y"], marker="o", linestyle="", label=name)

plt.legend()  # map each color to its category label
plt.show()


Before I show you how the resulting plot looks, allow me to show you the data output from the print() function. Here’s the output of the categorical data:

   X   Y Category
0  1  42        a
1  2  41        b
2  3  40        a
3  4  39        b
4  5  38        b
5  6  37        a

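Now, what does the colored category plot look like? Each category is drawn in its own color: the three points labeled a share one color, and the three points labeled b share another.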
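To choose the color of each category yourself instead of relying on the default color cycle, one option is to look it up in a dictionary inside the same groupby() loop and pass it to the color parameter of plot(). The colors dictionary and the output filename below are hypothetical choices, not part of the original example.

```python
import pandas as pd
import matplotlib.pyplot as plt

data = pd.DataFrame({
    "X": [1, 2, 3, 4, 5, 6],
    "Y": [42, 41, 40, 39, 38, 37],
    "Category": ['a', 'b', 'a', 'b', 'b', 'a'],
})

# Hypothetical category-to-color mapping; any Matplotlib color spec works
colors = {"a": "tab:red", "b": "tab:green"}

fig, ax = plt.subplots()
for name, group in data.groupby("Category"):
    # Same marker-only plot as before, but with an explicit color per category
    ax.plot(group["X"], group["Y"], marker="o", linestyle="",
            color=colors[name], label=name)

ax.legend(title="Category")
fig.savefig("colored_by_category.png")  # hypothetical filename; use plt.show() interactively
```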
