Understanding Array Axis Summations with Einstein Summation Convention in Python


💡 Problem Formulation: When working with multidimensional arrays in Python, we often need to sum efficiently over specific axes. Summing along one or more axes reduces the array's dimensionality, and such reductions can be written compactly in Einstein summation notation. For instance, summing a 3×3×3 array over its second axis yields a 3×3 result. Operations like this are common in scientific computing, data analysis, and machine learning.
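In index notation, with A the 3×3×3 input and R the 3×3 result, the second-axis sum is the reduction below; this is what the einsum subscripts 'ijk->ik' describe. For the array np.arange(27).reshape((3, 3, 3)) used in the examples, A_{ijk} = 9i + 3j + k, so the result can also be worked out by hand (row i = 0 gives 9, 12, 15):

```latex
R_{ik} = \sum_{j=0}^{2} A_{ijk}, \qquad A_{ijk} = 9i + 3j + k
\quad\Longrightarrow\quad
R_{ik} = 27i + 3(0 + 1 + 2) + 3k = 27i + 9 + 3k
```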

Method 1: Using NumPy’s einsum Function

The einsum function in NumPy is a versatile and efficient way to compute Einstein summations. It specifies the operation with subscript notation, giving control over which axes are summed and in what order the remaining axes appear. Its compact syntax can handle complex operations on high-dimensional arrays and is typically much faster than an equivalent loop written in pure Python.

Here’s an example:

import numpy as np

array = np.arange(27).reshape((3, 3, 3))
result = np.einsum('ijk->ik', array)
print(result)

Output:

[[ 9 12 15]
 [36 39 42]
 [63 66 69]]

In this snippet, einsum sums a 3×3×3 array over its second axis (j). In the subscript string 'ijk->ik', the indices before the arrow label the input's axes and those after it label the output's; because j does not appear in the output, einsum sums over it. The result is a 3×3 array in which row i holds the column sums of the matrix array[i], that is, result[i, k] = array[i, 0, k] + array[i, 1, k] + array[i, 2, k].
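The same rule extends naturally: every index left out of the output subscripts is summed over. The short sketch below drops one, two, and all three indices, and checks each result against the corresponding np.sum call.

```python
import numpy as np

array = np.arange(27).reshape((3, 3, 3))

# Every index left out of the output subscripts is summed over.
over_j = np.einsum('ijk->ik', array)    # sum axis 1 (the example above)
over_j_k = np.einsum('ijk->i', array)   # sum axes 1 and 2 together
total = np.einsum('ijk->', array)       # sum every axis: a single number

# The same reductions written with np.sum, for comparison.
assert np.array_equal(over_j, array.sum(axis=1))
assert np.array_equal(over_j_k, array.sum(axis=(1, 2)))
assert total == array.sum()

print(over_j_k)  # [ 36 117 198]
print(total)     # 351
```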

Method 2: Multi-Axis Summation Using np.sum

The np.sum function is a straightforward way to sum elements along specified axes. Its axis parameter accepts a single integer or a tuple of integers, so you can reduce one axis or several at once in a clear and intuitive manner. Although it lacks einsum's subscript notation, it is user-friendly for simpler summation tasks.

Here’s an example:

import numpy as np

array = np.arange(27).reshape((3, 3, 3))
result = np.sum(array, axis=1)
print(result)

Output:

[[ 9 12 15]
 [36 39 42]
 [63 66 69]]

Calling np.sum(array, axis=1) gives the same result as the previous method by summing over the second axis of the 3D array. It is simpler and may be more approachable for those unfamiliar with Einstein notation, and it is well suited to cases where we only need to sum over one or more entire axes.
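The example above reduces a single axis. For the multi-axis case the heading refers to, a tuple can be passed to axis; keepdims=True (a standard np.sum parameter) keeps the reduced axes as length-1 dimensions so the result still broadcasts against the original array. A minimal sketch:

```python
import numpy as np

array = np.arange(27).reshape((3, 3, 3))

# Pass a tuple to `axis` to reduce axes 0 and 2 in one call.
over_i_k = np.sum(array, axis=(0, 2))

# keepdims=True leaves the reduced axes in place with length 1,
# so the result can broadcast back against `array`.
over_i_k_kept = np.sum(array, axis=(0, 2), keepdims=True)

print(over_i_k)             # [ 90 117 144]
print(over_i_k_kept.shape)  # (1, 3, 1)
```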

Method 3: Summation with Array Methods

Every NumPy array also has a sum method that performs the same reduction as np.sum. Calling it directly on the array keeps the syntax short and reads naturally when chained with other array methods.

Here’s an example:

import numpy as np

array = np.arange(27).reshape((3, 3, 3))
result = array.sum(axis=1)
print(result)

Output:

[[ 9 12 15]
 [36 39 42]
 [63 66 69]]

Here we call the sum method that every NumPy array provides. array.sum(axis=1) is equivalent to np.sum(array, axis=1), so it returns the same column sums as the previous examples; the difference is purely one of syntax.
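To make the meaning of axis=1 concrete, the reduction can also be written out with plain slicing: array[:, j, :] selects position j along the second axis, and adding those slices reproduces the method's result. This is only an illustration of the semantics; in practice array.sum(axis=1) is shorter and faster.

```python
import numpy as np

array = np.arange(27).reshape((3, 3, 3))

# array[:, j, :] picks index j along axis 1, giving a 3x3 slice.
manual = array[:, 0, :] + array[:, 1, :] + array[:, 2, :]

# The same idea for any length along axis 1, using a generator.
manual_general = sum(array[:, j, :] for j in range(array.shape[1]))

print(np.array_equal(manual, array.sum(axis=1)))          # True
print(np.array_equal(manual_general, array.sum(axis=1)))  # True
```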

Method 4: Summation using Tensordot

The tensordot function of NumPy computes the tensor dot product along specified axes, which can also be used to perform array summations. It is particularly useful for performing complex multi-dimensional array operations that go beyond simple axis summations.

Here’s an example:

import numpy as np

array = np.arange(27).reshape((3, 3, 3))
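# np.ones defaults to float64; an integer array of ones keeps the result integer
sum_over_axes = np.tensordot(array, np.ones(3, dtype=array.dtype), axes=([1], [0]))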
print(sum_over_axes)

Output:

[[ 9 12 15]
 [36 39 42]
 [63 66 69]]

Here tensordot sums over one axis by contracting the array with a vector of ones: multiplying each element by 1 leaves it unchanged, so the contraction reduces to a plain sum. The argument axes=([1], [0]) pairs the second axis of array with the only axis of the ones vector, and that paired axis is the one summed away. While powerful, this method is less intuitive for a simple sum and is better suited to complex multi-array operations.
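The same trick covers several axes at once: contracting with a 2D block of ones sums over two axes in a single call. The sketch below sums over axes 0 and 2 and cross-checks the result with the equivalent two-operand einsum, where the second operand's subscripts name the summed axes.

```python
import numpy as np

array = np.arange(27).reshape((3, 3, 3))

# Contract axes 0 and 2 of `array` with both axes of a 3x3 block of ones.
# An integer dtype keeps the result integer instead of float64.
ones = np.ones((3, 3), dtype=array.dtype)
over_i_k = np.tensordot(array, ones, axes=([0, 2], [0, 1]))

# Equivalent einsum: i and k appear in both inputs but not the output.
via_einsum = np.einsum('ijk,ik->j', array, ones)

print(over_i_k)                              # [ 90 117 144]
print(np.array_equal(over_i_k, via_einsum))  # True
```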

Bonus One-Liner Method 5: Summation with List Comprehensions

List comprehensions offer a Pythonic way to express summations and similar operations as a concise, readable one-liner. Because the loop runs in the Python interpreter rather than in NumPy's compiled code, this approach is slower for large arrays, but it is handy for quick calculations on small datasets.

Here’s an example:

import numpy as np

array = np.arange(27).reshape((3, 3, 3))
result = [sum(matrix) for matrix in array]
print(result)

Output:

[array([ 9, 12, 15]), array([36, 39, 42]), array([63, 66, 69])]

This one-liner iterates over the 3D array one 2D sub-array ('matrix') at a time. Python's built-in sum adds the rows of each matrix element-wise, so each list entry holds the column sums of one matrix, which is the same reduction over axis 1 as before. The result is a list of 1D arrays rather than a single 2D array. Although it is less efficient and less flexible than NumPy's functions, it is simple and readable for straightforward summation tasks.
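Two small variations may be useful. Wrapping the comprehension in np.array turns the list of 1D arrays back into the 3×3 array the other methods return. And if the data starts as nested Python lists instead of a NumPy array, the reduction can be written without NumPy: zip(*matrix) transposes each 3×3 list, so every tuple it yields is one column of that matrix.

```python
import numpy as np

array = np.arange(27).reshape((3, 3, 3))

# Built-in sum() adds the rows of each 2D slice element-wise (axis 1 of `array`).
as_array = np.array([sum(matrix) for matrix in array])

# Pure-Python version on nested lists: zip(*matrix) yields the columns.
nested = array.tolist()
pure = [[sum(column) for column in zip(*matrix)] for matrix in nested]

print(as_array.shape)  # (3, 3)
print(pure)            # [[9, 12, 15], [36, 39, 42], [63, 66, 69]]
```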

Summary/Discussion

  • Method 1: NumPy’s einsum. Extremely flexible and efficient for complex operations. Can be unintuitive for those not familiar with Einstein summation notation.
  • Method 2: np.sum Function. User-friendly and clear syntax for straightforward axis summations. Less suitable for complex, advanced summation operations.
  • Method 3: Summation with Array Methods. The sum method attached to every array is equivalent to np.sum, with slightly shorter syntax.
  • Method 4: Summation using tensordot. Offers advanced functionality for tensor products but may overcomplicate simple summation tasks.
  • Bonus Method 5: List Comprehensions. Quick and Pythonic one-liners for small-scale summations but lack the performance benefits of NumPy functions.
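As a quick sanity check, the sketch below runs all five approaches on the same array and confirms they agree. It uses np.array_equal, which compares shapes and values but not dtypes, so a float result from tensordot would also compare equal.

```python
import numpy as np

array = np.arange(27).reshape((3, 3, 3))

results = {
    "einsum": np.einsum('ijk->ik', array),
    "np.sum": np.sum(array, axis=1),
    "ndarray.sum": array.sum(axis=1),
    "tensordot": np.tensordot(array, np.ones(3, dtype=array.dtype), axes=([1], [0])),
    "list comprehension": np.array([sum(matrix) for matrix in array]),
}

reference = results["np.sum"]
for name, value in results.items():
    # array_equal checks that shape and element values match.
    print(f"{name:<20} matches: {np.array_equal(value, reference)}")
```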