Counting Unique Binary Search Trees for 0 to n Values in Python


💡 Problem Formulation: Given an integer n, the task is to determine the number of structurally unique binary search trees (BSTs) that can be built from the n distinct values 1 to n. For example, if n = 3, five unique BSTs can be constructed. This article explores several methods to calculate this number efficiently in Python.

Method 1: Dynamic Programming

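The dynamic programming approach builds the answer from the stored answers to smaller subproblems. The key insight is that if value i is chosen as the root, the i - 1 smaller values must form the left subtree and the n - i larger values must form the right subtree, so the number of trees with that root is the product of the two subtree counts. Summing this product over every possible root gives the total. Function specification: given an integer n, return the count of unique BSTs.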
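Written out as a formula, with G(n) denoting the number of BSTs on n distinct values and G(0) = 1 for the empty tree, the root-splitting argument gives the recurrence below; expanding it for n = 3 reproduces the count of five. The symbol G is notation introduced here for illustration only.

```latex
G(0) = 1, \qquad G(n) = \sum_{i=1}^{n} G(i-1)\,G(n-i) \quad (n \ge 1)

G(3) = G(0)\,G(2) + G(1)\,G(1) + G(2)\,G(0) = 1 \cdot 2 + 1 \cdot 1 + 2 \cdot 1 = 5
```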

Here’s an example:

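def num_trees(n):
    # dp[i] holds the number of unique BSTs that can be built from i distinct values
    dp = [0] * (n + 1)
    dp[0] = 1  # one way to build an empty tree; starting here also makes n = 0 work
    for i in range(1, n + 1):
        for j in range(1, i + 1):
            # j is the root: j - 1 values go left, i - j values go right
            dp[i] += dp[j - 1] * dp[i - j]
    return dp[n]

print(num_trees(3))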


Output: 5

This code uses a bottom-up dynamic programming approach. It initializes a list dp in which index i holds the number of unique BSTs that can be formed with i nodes. The nested loops try each value j from 1 to i as the root and add the product of the left-subtree count dp[j - 1] and the right-subtree count dp[i - j], both of which have already been computed. This takes O(n²) time and one table entry per size.
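Because dp already holds the count for every size from 0 to n, a small variation can return the whole table instead of only its last entry. The sketch below is a hypothetical helper, num_trees_all, that is not part of the original method; it assumes n ≥ 0 and uses the same recurrence as num_trees.

```python
from typing import List


def num_trees_all(n: int) -> List[int]:
    """Return [G(0), G(1), ..., G(n)], the BST counts for every size up to n.

    Same recurrence as num_trees, but the whole dp table is returned
    instead of only its last entry. Assumes n >= 0.
    """
    dp = [0] * (n + 1)
    dp[0] = 1  # the empty tree
    for i in range(1, n + 1):
        # root j splits the i values into j - 1 (left) and i - j (right)
        dp[i] = sum(dp[j - 1] * dp[i - j] for j in range(1, i + 1))
    return dp


print(num_trees_all(5))  # [1, 1, 2, 5, 14, 42]
```

Returning the list costs nothing extra, since the table has to be filled anyway.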

Method 2: Mathematical Deduction (Catalan Number)

The count of unique BSTs on n distinct values is exactly the nth Catalan number, which has a closed form. This method computes it directly, without recursion or dynamic programming, which makes it efficient for large n. The function takes an integer n and returns the nth Catalan number.
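For reference, the closed form of the nth Catalan number, written here as C_n, and the arithmetic for n = 3 are:

```latex
C_n = \frac{1}{n+1}\binom{2n}{n} = \frac{(2n)!}{(n+1)!\,n!}

C_3 = \frac{1}{4}\binom{6}{3} = \frac{20}{4} = 5
```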

Here’s an example:

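import math

def num_trees_catalan(n):
    # nth Catalan number: C(2n, n) / (n + 1); math.comb requires Python 3.8+
    return math.comb(2 * n, n) // (n + 1)

print(num_trees_catalan(3))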

Output: 5

This snippet exploits the combinatorial nature of the problem. The nth Catalan number equals the binomial coefficient C(2n, n) divided by n + 1. Python’s math.comb() computes the binomial coefficient, and integer division (//) is used because C(2n, n) is always exactly divisible by n + 1, so the result stays an exact integer instead of becoming a float.
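If math.comb is not available (it was added in Python 3.8), the same Catalan number can be built up with the multiplicative recurrence C_{k+1} = C_k · 2(2k + 1) / (k + 2), starting from C_0 = 1. The function name num_trees_catalan_iter is hypothetical, and this is a minimal sketch assuming n ≥ 0, not a replacement for the method above.

```python
def num_trees_catalan_iter(n: int) -> int:
    """Return the nth Catalan number without math.comb (assumes n >= 0)."""
    c = 1  # C_0 = 1: one (empty) tree on zero values
    for k in range(n):
        # C_{k+1} = C_k * 2(2k + 1) / (k + 2); multiply first so // is exact
        c = c * 2 * (2 * k + 1) // (k + 2)
    return c


print(num_trees_catalan_iter(3))  # 5
```

Multiplying before dividing keeps every intermediate value an exact integer, because the product c * 2 * (2k + 1) equals C_{k+1} · (k + 2) and is therefore divisible by k + 2.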

Method 3: Memoization

Memoization stores previously computed values to avoid redundant calculations. Applied to the recursive approach, it caches each result so that every subproblem is computed only once, which greatly reduces the number of computations. Function specification: given an integer n, return the count of unique BSTs using memoization.

Here’s an example:

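def num_trees_memo(n, memo=None):
    # memo=None instead of memo={}: a mutable default is created once and
    # shared by every call, which is a common source of bugs
    if memo is None:
        memo = {}
    if n in memo:
        return memo[n]
    if n <= 1:
        return 1
    total = 0
    for i in range(1, n + 1):
        total += num_trees_memo(i - 1, memo) * num_trees_memo(n - i, memo)
    memo[n] = total
    return total

print(num_trees_memo(3))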


Output: 5

This code uses a recursive function with a memoization dictionary. It checks if the result for a given n is already in the dictionary to avoid redundant calculations. If not, it computes it recursively and stores the result in the memo dictionary.

Method 4: Recursion

Plain recursion is the most straightforward method but also the least efficient. It breaks the problem into the same subproblems as dynamic programming, but without storing intermediate results, so the same subproblems are recomputed many times and the running time grows exponentially with n. Function specification: given an integer n, return the count of unique BSTs by recursion.

Here’s an example:

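def num_trees_recursive(n):
    if n <= 1:
        return 1  # an empty tree or a single node
    total = 0
    for i in range(1, n + 1):
        # i is the root: i - 1 values on the left, n - i values on the right
        total += num_trees_recursive(i - 1) * num_trees_recursive(n - i)
    return total

print(num_trees_recursive(3))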


Output: 5

In this recursive approach, each number is treated in turn as the root of the tree. For each root, the number of possible left subtrees is multiplied by the number of possible right subtrees, and these products are summed to give the total for n.
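To make the cost of the repeated calculations concrete, the sketch below counts how many function calls the plain recursion makes and compares it with a memoized variant of the same recurrence. The helper count_calls and its inner functions are hypothetical, written only for this comparison.

```python
from typing import Dict, Tuple


def count_calls(n: int) -> Tuple[int, int, int]:
    """Return (bst_count, plain_calls, memo_calls) for n distinct values.

    Both inner functions use the same recurrence as num_trees_recursive;
    the second one additionally stores results for k >= 2 in a dict.
    """
    plain_calls = 0
    memo_calls = 0
    memo: Dict[int, int] = {}

    def plain(k: int) -> int:
        nonlocal plain_calls
        plain_calls += 1
        if k <= 1:
            return 1
        return sum(plain(i - 1) * plain(k - i) for i in range(1, k + 1))

    def cached(k: int) -> int:
        nonlocal memo_calls
        memo_calls += 1
        if k <= 1:
            return 1
        if k not in memo:
            memo[k] = sum(cached(i - 1) * cached(k - i) for i in range(1, k + 1))
        return memo[k]

    result = plain(n)
    cached(n)  # returns the same value; only its call count is of interest
    return result, plain_calls, memo_calls


print(count_calls(10))  # (16796, 32805, 109)
```

For n = 10 the plain recursion makes 32,805 calls while the memoized version makes 109; from n = 3 on, the plain count triples with each increment of n.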

Bonus One-Liner Method 5: Use of LRU Cache

The functools.lru_cache decorator can be used to add memoization to the recursive solution in a single line. This approach is simple to implement and maintains the clarity of the recursive solution with the efficiency of memoization.

Here’s an example:

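from functools import lru_cache

@lru_cache(maxsize=None)  # cache every result; without this line it is plain recursion
def num_trees_lru(n):
    if n <= 1:
        return 1
    return sum(num_trees_lru(i - 1) * num_trees_lru(n - i) for i in range(1, n + 1))

print(num_trees_lru(3))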


Output: 5

By adding @lru_cache(maxsize=None) as a decorator to the recursive function, Python caches the results of the function calls, which prevents redundant calculations while keeping the code simple and easy to understand.
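A function decorated with lru_cache also exposes statistics through cache_info(), which shows how much work the cache saves. The sketch below uses a separately named function, num_trees_cached (a hypothetical name), so it can be run on its own; cache_clear() resets the cache and its statistics before the measurement.

```python
from functools import lru_cache


@lru_cache(maxsize=None)  # unbounded cache: every distinct n is stored
def num_trees_cached(n: int) -> int:
    if n <= 1:
        return 1
    return sum(num_trees_cached(i - 1) * num_trees_cached(n - i) for i in range(1, n + 1))


num_trees_cached.cache_clear()  # start from an empty cache and zeroed counters
print(num_trees_cached(10))  # 16796
print(num_trees_cached.cache_info())  # CacheInfo(hits=98, misses=11, maxsize=None, currsize=11)
```

Only the 11 distinct arguments 0 through 10 are ever computed (the misses); every other call is answered from the cache.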


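  • Method 1: Dynamic Programming. Efficient for large n: it runs in O(n²) time and avoids redundant computations by reusing stored values. Its table holds one count per size, so memory grows with n, and the counts themselves become very large integers.
  • Method 2: Mathematical Deduction. Provides direct computation and is fast. Python integers have arbitrary precision, so there is no overflow; in languages with fixed-width integers, the intermediate binomial coefficient can overflow before the final result does, so it must be implemented with care.
  • Method 3: Memoization. Optimizes the recursive approach by solving each subproblem only once. It delivers substantial performance benefits but has more function-call overhead than dynamic programming, and its recursion depth grows with n, so very large n can hit Python’s recursion limit.
  • Method 4: Recursion. The most intuitive method and the simplest to implement, but the least efficient: repeated calculations make its running time grow exponentially with n.
  • Bonus Method 5: Use of LRU Cache. Offers the simplicity of the recursive method with performance close to dynamic programming, although it shares memoization’s recursion-depth limit. Best for keeping the code concise without sacrificing speed.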