Friday, November 25, 2011

Memoization of recursive functions in C#

Today we discuss strategies for memoization of recursive functions in C#.

Consider the problem of computing the n'th Fibonacci number (0, 1, 1, 2, 3, 5, 8, 13, ...).  This problem can be cast as a recursive function, namely:
  • F(0) = 0
  • F(1) = 1
  • F(n) = F(n-2) + F(n-1) for all n > 1
This can be computed by a simple recursive function:

static BigInteger simpleFib(int n)
{
    if( n < 1 )
        return 0;
    if( n == 1 )
        return 1;
    
    return simpleFib(n-2) + simpleFib(n-1);
}

However, this function performs poorly. Every call with n > 1 makes two further recursive calls, so the same values are recalculated (expensively!) many times over, and the number of calls grows exponentially with n. This problem can be solved with memoization: the program keeps the results of earlier calls in a static cache (a Dictionary<,> object) and looks them up instead of recalculating them.  This generates a significant speedup:

//The cache is seeded with the base cases F(0) = 0 and F(1) = 1
static Dictionary<int, BigInteger> cache = new Dictionary<int, BigInteger>
{
    { 0, BigInteger.Zero },
    { 1, BigInteger.One }
};

static BigInteger recursiveFib(int n)
{
    //Return a value from the cache if available
    BigInteger result;
    if( cache.TryGetValue(n, out result) )
        return result;
    
    //Otherwise recursively calculate, add to cache and return
    result = recursiveFib(n-2) + recursiveFib(n-1);
    cache.Add(n, result);
    return result;
}
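
To get a feel for the size of the speedup, one option is to count how often each version is entered. The following is a minimal sketch, not a benchmark: the class name CallCountDemo, the two counters and the Counting... method names are hypothetical and exist only for this illustration; the function bodies mirror the two listings above.

```csharp
using System;
using System.Collections.Generic;
using System.Numerics;

static class CallCountDemo
{
    // Hypothetical counters, added only to make the difference visible.
    public static long NaiveCalls = 0;
    public static long MemoCalls = 0;

    // Cache seeded with the base cases F(0) = 0 and F(1) = 1.
    static readonly Dictionary<int, BigInteger> cache =
        new Dictionary<int, BigInteger>
        {
            { 0, BigInteger.Zero },
            { 1, BigInteger.One }
        };

    // Same logic as simpleFib, plus a call counter.
    public static BigInteger CountingSimpleFib(int n)
    {
        NaiveCalls++;
        if (n < 1) return 0;
        if (n == 1) return 1;
        return CountingSimpleFib(n - 2) + CountingSimpleFib(n - 1);
    }

    // Same logic as recursiveFib, plus a call counter. Assumes n >= 0.
    public static BigInteger CountingRecursiveFib(int n)
    {
        MemoCalls++;
        BigInteger result;
        if (cache.TryGetValue(n, out result))
            return result;

        result = CountingRecursiveFib(n - 2) + CountingRecursiveFib(n - 1);
        cache.Add(n, result);
        return result;
    }

    static void Main()
    {
        BigInteger a = CountingSimpleFib(25);
        BigInteger b = CountingRecursiveFib(25);

        Console.WriteLine("Results agree: " + (a == b));
        Console.WriteLine("Naive calls:    " + NaiveCalls);
        Console.WriteLine("Memoized calls: " + MemoCalls);
    }
}
```

The naive count grows exponentially with n, while the memoized count grows only linearly, because each value is computed once and every later request for it is a cache hit.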

Although this achieves good runtime performance, the function is still limited by the size of the call stack: a large enough n causes a stack overflow.  We can shift the memory burden away from the call stack and onto the heap by converting the recursive calls into a loop that keeps track of the arguments of the pending 'recursive calls' in a Stack<int>.  Instead of recursing, the loop calls Stack.Push(n) for each value of n that still needs to be calculated and saved in the memoized cache.  For each value popped off the stack there are three cases:
  1. The value is already in the memoized cache.  If so, continue with the next item on the stack immediately.
  2. The value is not in the cache, but the recursive terms (F(n-2) and F(n-1)) are already in the cache.  If so, calculate F(n) immediately, save it to the cache and continue with the next item on the stack.
  3. Either F(n-2) or F(n-1) (or both) is not yet in the cache.  In that case, push the current n back onto the stack so that it will be re-examined, then push whichever of n-1 and n-2 is not yet known, so that those values are computed first (before the loop reaches n again).
When the loop terminates, the value of F(n) must be in the memoized cache, so it can be returned directly.  The code for this is:

//The cache is seeded with the base cases F(0) = 0 and F(1) = 1
static Dictionary<int, BigInteger> cache = new Dictionary<int, BigInteger>
{
    { 0, BigInteger.Zero },
    { 1, BigInteger.One }
};

static BigInteger fib(int n)
{
    Stack<int> parameterStack = new Stack<int>();
    parameterStack.Push(n);

    while( parameterStack.Count > 0 )
    {
        int current = parameterStack.Pop();
    
        //Continue if already in the cache
        BigInteger result = 0;
        if( cache.TryGetValue(current, out result) )
            continue;
        
        //Are fib(n-1) and fib(n-2) already in the cache?
        BigInteger oneResult;
        BigInteger twoResult;
        bool twoInCache = cache.TryGetValue(current - 2,
                                            out twoResult);
        bool oneInCache = cache.TryGetValue(current - 1,
                                            out oneResult);
        
        //If so, store this result and move on to the next item
        if( twoInCache && oneInCache )
        {
            cache.Add(current, oneResult + twoResult);
            continue;
        }
        
        //Otherwise, push the value back onto the stack
        //(plus whichever of n-1 and n-2 is missing) so they are calculated first
        parameterStack.Push(current);
        if( !oneInCache ) parameterStack.Push(current - 1);
        if( !twoInCache ) parameterStack.Push(current - 2);
    }
    
    //Return the value from the cache
    return cache[n];
}
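
The three cases do not depend on anything specific to Fibonacci, so the same loop can be written once for any recurrence whose dependencies can be listed up front. The following is a sketch under that assumption, not a definitive implementation: StackMemo.Evaluate and its dependencies and combine parameters are hypothetical names introduced here, and the caller is assumed to seed the cache with the base cases.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Numerics;

static class StackMemo
{
    // Hypothetical helper: evaluates a recurrence with an explicit stack.
    // 'dependencies' lists the arguments that a value needs, and 'combine'
    // computes the value once all of those arguments are in the cache.
    public static TResult Evaluate<TArg, TResult>(
        TArg start,
        Dictionary<TArg, TResult> cache, // pre-seeded with the base cases
        Func<TArg, IEnumerable<TArg>> dependencies,
        Func<TArg, Dictionary<TArg, TResult>, TResult> combine)
    {
        var pending = new Stack<TArg>();
        pending.Push(start);

        while (pending.Count > 0)
        {
            TArg current = pending.Pop();

            // Case 1: the value is already known.
            if (cache.ContainsKey(current))
                continue;

            // Find the dependencies that still have to be computed.
            List<TArg> missing = dependencies(current)
                .Where(d => !cache.ContainsKey(d))
                .ToList();

            // Case 2: everything needed is cached, so compute and store.
            if (missing.Count == 0)
            {
                cache.Add(current, combine(current, cache));
                continue;
            }

            // Case 3: re-examine 'current' after its missing dependencies.
            pending.Push(current);
            foreach (TArg d in missing)
                pending.Push(d);
        }

        return cache[start];
    }

    static void Main()
    {
        var cache = new Dictionary<int, BigInteger>
        {
            { 0, BigInteger.Zero },
            { 1, BigInteger.One }
        };

        // Fibonacci expressed through the helper; assumes n >= 0.
        BigInteger f = Evaluate(
            10000, cache,
            n => new[] { n - 1, n - 2 },
            (n, c) => c[n - 2] + c[n - 1]);

        Console.WriteLine("Decimal digits: " + f.ToString().Length);
    }
}
```

The pending arguments live in a Stack<TArg> on the heap, so the depth of the dependency chain is bounded by available memory rather than by the call stack.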

This keeps the performance benefit of memoization without risking a stack overflow, because the pending work is held on the heap rather than on the call stack.

Note that for the Fibonacci sequence there is a trivial O(n) way to calculate it (i.e. starting from the first two terms and generating the sequence up to n).  It has been used here only as a simple example of a recursive function that can be computed with memoization.  There are other functions (such as the Ackermann function) for which memoization may be the best strategy to compute the result.
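
For comparison, the direct O(n) approach can be sketched as follows. The names IterativeDemo and IterativeFib are hypothetical and chosen for this illustration only; the loop keeps just the last two terms, so it needs neither a cache nor a stack.

```csharp
using System;
using System.Numerics;

static class IterativeDemo
{
    // Builds the sequence from F(0) and F(1) up to F(n),
    // remembering only the two most recent terms.
    public static BigInteger IterativeFib(int n)
    {
        if (n < 1)
            return BigInteger.Zero;

        BigInteger previous = BigInteger.Zero; // F(0)
        BigInteger current = BigInteger.One;   // F(1)

        for (int i = 2; i <= n; i++)
        {
            BigInteger next = previous + current;
            previous = current;
            current = next;
        }

        return current;
    }

    static void Main()
    {
        Console.WriteLine(IterativeFib(10)); // prints 55
    }
}
```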
