5 Best Ways to Count the Number of BSTs with N Nodes in Python


💡 Problem Formulation: Given a positive integer n, the task is to find the number of distinct Binary Search Trees (BSTs) that can be created using n distinct nodes. Each node contains a unique key. An example input could be n = 3, and the desired output would be 5, as there are five structurally unique BSTs that can be formed with three nodes.
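To see where the answer 5 comes from, it can help to generate the trees themselves. The sketch below is a brute-force check rather than one of the five methods. It assumes the keys are 1..n and represents each tree as a nested tuple (left, key, right), with None for an empty subtree. The helper names all_bsts and in_order are hypothetical.

```python
def all_bsts(lo, hi):
    # Hypothetical helper: build every BST on the keys lo..hi as nested tuples.
    if lo > hi:
        return [None]  # exactly one empty tree
    trees = []
    for root in range(lo, hi + 1):
        # Keys smaller than root must go left, larger keys must go right.
        for left in all_bsts(lo, root - 1):
            for right in all_bsts(root + 1, hi):
                trees.append((left, root, right))
    return trees


def in_order(tree):
    # An in-order traversal of a valid BST visits the keys in sorted order.
    if tree is None:
        return []
    left, key, right = tree
    return in_order(left) + [key] + in_order(right)


trees = all_bsts(1, 3)
print(len(trees))  # 5
for t in trees:
    print(t)
```

Because it materializes every tree, this only makes sense for small n; the methods below count the trees without building them.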

Method 1: Dynamic Programming

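Dynamic programming solves a problem by breaking it into simpler, overlapping subproblems and reusing their results. For counting BSTs, observe that if key j is chosen as the root of a tree with n nodes, the j - 1 smaller keys must form the left subtree and the n - j larger keys must form the right subtree. The number of BSTs with n nodes is therefore the sum, over every choice of root, of the number of possible left subtrees multiplied by the number of possible right subtrees. These counts are the Catalan numbers, so this method effectively computes the nth Catalan number.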
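Written as a recurrence, with G(n) denoting the number of unique BSTs with n nodes and j the key chosen as root, the idea reads as follows; the second line works it through for n = 3:

```latex
G(0) = 1, \qquad G(n) = \sum_{j=1}^{n} G(j-1)\,G(n-j) \quad (n \ge 1)

G(3) = G(0)\,G(2) + G(1)\,G(1) + G(2)\,G(0) = 1 \cdot 2 + 1 \cdot 1 + 2 \cdot 1 = 5
```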

Here’s an example:

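def num_trees(n):
    # G[i] holds the number of structurally unique BSTs with i nodes
    G = [0] * (n + 1)
    G[0] = 1  # the empty tree counts as one BST (also makes n = 0 safe)

    for i in range(1, n + 1):
        # Try each key j as the root: j - 1 keys go left, i - j keys go right
        for j in range(1, i + 1):
            G[i] += G[j - 1] * G[i - j]

    return G[n]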

print(num_trees(3))

Output: 5

The function num_trees initializes a list G, in which G[i] holds the number of unique BSTs with i nodes, and fills it in a bottom-up manner. For each size i up to n, it tries every key j as the root and adds G[j-1] * G[i-j], the product of the left- and right-subtree counts. The double loop takes O(n²) time and the list takes O(n) space.
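Since the loop already computes every G[i] on the way to G[n], returning the whole table costs nothing extra and makes the intermediate values visible. The sketch below is a hypothetical variant named catalan_table, not part of the original method:

```python
def catalan_table(n):
    # Hypothetical variant of num_trees that returns the whole table G[0..n]
    G = [0] * (n + 1)
    G[0] = 1  # one way to build an empty tree
    for i in range(1, n + 1):
        for j in range(1, i + 1):
            # Root j: (j - 1) keys on the left, (i - j) keys on the right
            G[i] += G[j - 1] * G[i - j]
    return G


print(catalan_table(5))  # [1, 1, 2, 5, 14, 42]
```

This form is handy when counts are needed for every size up to n, since one pass answers all of those queries.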

Method 2: Recursive Solution

The recursive solution follows the recursive structure of BSTs directly. A BST with n nodes splits into a left subtree, a root, and a right subtree, so we fix each key in turn as the root, recursively count the possible left and right subtrees, and add up the products of those counts.

Here’s an example:

def num_trees_recursive(n):
    if n == 0 or n == 1:
        return 1
    total_trees = 0
    for i in range(1, n + 1):
        left_trees = num_trees_recursive(i - 1)
        right_trees = num_trees_recursive(n - i)
        total_trees += left_trees * right_trees
    return total_trees

print(num_trees_recursive(3))

Output: 5

The num_trees_recursive function is a straightforward recursive implementation. For each possible root it multiplies the number of left and right subtrees and sums the results. Because the same subproblems are recomputed over and over, its running time grows exponentially with n.
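To make the cost of the repeated calculations concrete, the sketch below instruments the same recursion with a call counter. The name count_calls is hypothetical; the inner function mirrors num_trees_recursive, making one left and one right recursive call for each root.

```python
def count_calls(n):
    # Hypothetical instrumented copy of the plain recursion: returns (result, number of calls)
    calls = 0

    def rec(k):
        nonlocal calls
        calls += 1
        if k <= 1:
            return 1
        # Same structure as num_trees_recursive: root i splits keys into i - 1 and k - i
        return sum(rec(i - 1) * rec(k - i) for i in range(1, k + 1))

    result = rec(n)
    return result, calls


for n in (3, 5, 10):
    print(n, count_calls(n))
# 3 (5, 15)
# 5 (42, 135)
# 10 (16796, 32805)
```

From n = 3 onward the number of calls triples with every extra node, which is why the memoized and bottom-up versions matter even for modest n.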

Method 3: Memoization

Memoization improves the recursive method by storing the result of each subproblem so that it is computed only once. Since there are only n + 1 distinct subproblems and each one does O(n) work, this reduces the running time from exponential to O(n²).

Here’s an example:

def num_trees_memoized(n, memo=None):
    if memo is None:
        memo = {}
    if n in memo:
        return memo[n]
    if n == 0 or n == 1:
        return 1
    total_trees = 0
    for i in range(1, n + 1):
        left_trees = num_trees_memoized(i - 1, memo)
        right_trees = num_trees_memoized(n - i, memo)
        total_trees += left_trees * right_trees
    memo[n] = total_trees
    return total_trees

print(num_trees_memoized(3))

Output: 5

The function num_trees_memoized accepts an integer n and an optional memoization dictionary. Before doing any work, it checks whether the result for n is already in the dictionary; if not, it computes the result recursively and stores it, so every subproblem is solved only once.
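The memo=None default is a deliberate Python idiom. Default argument values are evaluated once, when the def statement runs, so a default of memo={} would be a single dictionary shared by every call that omits the argument. For this pure counting function a shared cache would still return correct numbers, but it would grow silently across unrelated calls; the None check gives each top-level call a fresh dictionary unless the caller passes one in explicitly. The toy functions remember and remember_safe below are hypothetical and only illustrate the behavior:

```python
def remember(x, seen=[]):  # anti-pattern: this default list is created only once
    seen.append(x)
    return seen


print(remember(1))  # [1]
print(remember(2))  # [1, 2]  (the same list, mutated by the previous call)


def remember_safe(x, seen=None):
    # Same idiom as num_trees_memoized: create a fresh container per call
    if seen is None:
        seen = []
    seen.append(x)
    return seen


print(remember_safe(1))  # [1]
print(remember_safe(2))  # [2]
```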

Method 4: Mathematical Formula

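The number of BSTs with n nodes can also be computed directly from the closed-form formula for Catalan numbers. This avoids the O(n²) double loop of the previous methods: computing the factorials takes only O(n) multiplications, although the factorials themselves grow into very large integers.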
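The closed form, together with its equivalent binomial-coefficient version and a worked check for n = 3, is:

```latex
C_n = \frac{(2n)!}{(n+1)!\,n!} = \frac{1}{n+1}\binom{2n}{n},
\qquad
C_3 = \frac{6!}{4!\,3!} = \frac{720}{24 \cdot 6} = 5
```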

Here’s an example:

from math import factorial

def num_trees_math(n):
    return factorial(2 * n) // (factorial(n + 1) * factorial(n))

print(num_trees_math(3))

Output: 5

In the num_trees_math function, we calculate the nth Catalan number directly from its factorial formula. Integer division (//) keeps the result an exact integer, and that integer is the number of unique BSTs that can be made with n nodes.
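Two equivalent formulations avoid dividing one huge factorial by another. The sketch below assumes Python 3.8+ for math.comb; the names catalan_comb, catalan_iterative, and catalan_factorial are hypothetical. The iterative version uses the recurrence C(k+1) = C(k) * 2(2k + 1) / (k + 2), and the division is exact at every step because the result is always the next Catalan number.

```python
from math import comb, factorial


def catalan_factorial(n):
    # Same formula as num_trees_math
    return factorial(2 * n) // (factorial(n + 1) * factorial(n))


def catalan_comb(n):
    # C(n) = binom(2n, n) / (n + 1); the division is always exact
    return comb(2 * n, n) // (n + 1)


def catalan_iterative(n):
    # C(k+1) = C(k) * 2(2k + 1) / (k + 2), starting from C(0) = 1
    c = 1
    for k in range(n):
        c = c * 2 * (2 * k + 1) // (k + 2)
    return c


# All three formulations agree on the first 30 Catalan numbers
assert all(
    catalan_factorial(n) == catalan_comb(n) == catalan_iterative(n)
    for n in range(30)
)
print(catalan_comb(3), catalan_iterative(3))  # 5 5
```

The iterative form keeps every intermediate value close to the size of the final answer, while math.comb delegates the work to the standard library.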

Bonus One-Liner Method 5: functools

The lru_cache decorator from Python's functools module adds memoization with a single line, which keeps the recursive implementation short while still avoiding repeated computation.

Here’s an example:

from functools import lru_cache

@lru_cache(maxsize=None)
def num_trees_oneliner(n):
    if n < 2:
        return 1
    return sum(num_trees_oneliner(i) * num_trees_oneliner(n - i - 1) for i in range(n))

print(num_trees_oneliner(3))

Output: 5

By decorating the function num_trees_oneliner with @lru_cache, Python will automatically memoize the results which allows this recursive function to be written more concisely while still being efficient.
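On Python 3.9+, functools.cache is shorthand for lru_cache(maxsize=None), and the decorated function exposes cache_info() for checking that caching actually happens. The sketch below uses a hypothetical copy named num_trees_cached with the same body as num_trees_oneliner:

```python
from functools import cache  # Python 3.9+; equivalent to lru_cache(maxsize=None)


@cache
def num_trees_cached(n):
    if n < 2:
        return 1
    return sum(num_trees_cached(i) * num_trees_cached(n - i - 1) for i in range(n))


print(num_trees_cached(3))  # 5
print(num_trees_cached.cache_info())  # CacheInfo(hits=7, misses=4, maxsize=None, currsize=4)
```

The four misses correspond to the four distinct subproblems n = 0, 1, 2, 3; every other call is served from the cache. Like the other recursive versions, a first call with n near Python's default recursion limit (typically 1000) can raise RecursionError, so the bottom-up method is safer for very large n.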

Summary/Discussion

  • Method 1: Dynamic Programming. A bottom-up approach that runs in O(n²) time without recursion. Efficient and widely used. The main drawback is that it requires O(n) additional space for the DP array.
  • Method 2: Recursive Solution. Simple and intuitive, but inefficient due to repeated calculations.
  • Method 3: Memoization. An optimized version of the recursive method. It is much more efficient but requires additional code for caching.
  • Method 4: Mathematical Formula. Offers a direct calculation leading to the most efficient solution. However, it requires an understanding of the combinatorics behind Catalan numbers.
  • Method 5: functools. Simplifies the recursive memoization approach through a Python module. Neat and effective but may abstract away the learning process of memoization for beginners.