How to avoid recursion?

The recursion problem

I remember that when I was learning programming languages in college, we started with Caml and wrote a lot of recursive functions.

Recursion is sometimes a very elegant and easy way to write a function.

However, it can become a big problem when the number of recursive calls gets large.

Consider a textbook recursive function: Fibonacci.

You can code it this way:

using System.Numerics;

public static BigInteger Fibonacci(int n)
{
    switch (n)
    {
        case 0: return 0;
        case 1: return 1;
        default: return Fibonacci(n - 1) + Fibonacci(n - 2);
    }
}

In my first college programming courses, I learned to write it like this.

However, I now think this is bad for two reasons:

  • Computing Fibonacci(n - 1) recomputes Fibonacci(n - 2), which Fibonacci(n) also computes separately. This duplication repeats at every level, so the running time is O(exp(n)).
    Indeed, to calculate Fibonacci(2) this algorithm needs to do 1 addition.
    To calculate Fibonacci(3), it needs to do 2 additions.
    To calculate Fibonacci(4), it needs to do 4 additions.
    To calculate Fibonacci(5), it needs to do 7 additions.
    To calculate Fibonacci(6), it needs to do 12 additions.
    To calculate Fibonacci(7), it needs to do 20 additions.
    To calculate Fibonacci(8), it needs to do 33 additions.
    To calculate Fibonacci(25), it needs to do 121392 additions.
    To calculate Fibonacci(50), it needs to do 2.04E+10 additions.
    etc.

    With this algorithm, Fibonacci(40) takes more than 12 seconds to run on my machine, and Fibonacci(41) takes almost 27 seconds.

  • If n equals 100,000, we get a StackOverflowException because the call stack grows too deep.
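The counts in the list above follow the recurrence additions(n) = additions(n - 1) + additions(n - 2) + 1, with no additions for n = 0 or 1. This works out to Fibonacci(n + 1) - 1. One way to check them is to instrument the naive version with a counter, as in the sketch below. The InstrumentedFibonacci class and its static Additions field are illustrative names and are not part of the original code.

```csharp
using System.Numerics;

public static class InstrumentedFibonacci
{
    // Illustrative counter: incremented once per "+" the naive version executes.
    public static long Additions;

    // Same algorithm as the naive recursive Fibonacci, plus the counter.
    public static BigInteger Naive(int n)
    {
        switch (n)
        {
            case 0: return 0;
            case 1: return 1;
            default:
                Additions++;
                return Naive(n - 1) + Naive(n - 2);
        }
    }
}
```

If you reset Additions to 0 and call Naive(8), Additions ends up at 33, which matches the list. Naive(25) leaves it at 121392.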

How to avoid recursion?

To avoid recursion, we can use two data structures that will help us a lot: Stack and Queue.

For example, for the Fibonacci function we can use the following code:

public static BigInteger Fibonacci(int n)
{
    if (n < 2)
    {
        return n;
    }
 
    var q = new Queue<BigInteger>();
    q.Enqueue(0);
    q.Enqueue(1);
    for (int i = 2; i <= n; i++)
    {
        BigInteger n2 = q.Dequeue();
        BigInteger n1 = q.Peek();
        q.Enqueue(n2 + n1);
    }
    q.Dequeue();
    return q.Dequeue();
}

The idea is to always keep the last two values. I used a queue for the sample, but since we only ever need two elements, we can remove the queue:

public static BigInteger Fibonacci(int n)
{
    if (n < 2)
    {
        return n;
    }
 
    BigInteger n2 = 0;
    BigInteger n1 = 1;
    for (int i = 2; i <= n; i++)
    {
        BigInteger n0 = n2 + n1;
        n2 = n1;
        n1 = n0;
    }
    return n1;
}
 

With this algorithm, the number of additions is O(n) and there is no more StackOverflowException. Fibonacci(41), which took almost 27 seconds with the recursive version, now takes less than 1 millisecond! (And of course, the bigger n is, the bigger the improvement ratio.)

With the non-recursive version, Fibonacci(100000) takes only 15 milliseconds on my machine.
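The Stack mentioned earlier can also replace the call stack mechanically, without changing the algorithm. The sketch below (StackFibonacci is a hypothetical name) keeps the pending calls in a heap-allocated Stack<int>, so it cannot throw a StackOverflowException. However, it still performs an exponential number of steps. Removing the recursion alone is therefore not enough: the two-variable version is fast because it computes each value only once.

```csharp
using System.Collections.Generic;
using System.Numerics;

public static class StackFibonacci
{
    // Same computation as the naive recursive version, but pending "calls"
    // live in an explicit Stack<int> instead of the call stack.
    public static BigInteger Fibonacci(int n)
    {
        var pending = new Stack<int>();
        pending.Push(n);
        BigInteger sum = 0;
        while (pending.Count > 0)
        {
            int k = pending.Pop();
            if (k < 2)
            {
                // Leaves of the recursion tree: Fibonacci(0) = 0, Fibonacci(1) = 1.
                sum += k;
            }
            else
            {
                // Expand Fibonacci(k) into Fibonacci(k - 1) + Fibonacci(k - 2).
                pending.Push(k - 1);
                pending.Push(k - 2);
            }
        }
        return sum;
    }
}
```

StackFibonacci.Fibonacci(20) returns 6765, but Fibonacci(41) would still take exponential time.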

Real world scenario

The relationship between primes and Fibonacci numbers is very interesting, and I’m sure it has plenty of uses. In my case, though, I will probably never use it outside of student, recruiting or blogging work.

As some of you know, I’m a developer on the U-SQL compiler team, and U-SQL uses Roslyn (the C# compiler).

Since code is generally represented as a tree, tree manipulation is one of the most important tasks when you work on a compiler.

We can either visit the tree (walk it) or update it (rewrite it). Both generally require going through the full tree.

The tree can be traversed recursively. In that case, we don’t get the performance problem of the Fibonacci sample, because each node is visited only once. However, a very deep tree can still make the call stack grow until it throws a StackOverflowException.
In fact, the CSharpSyntaxRewriter implemented in Roslyn is recursive, and we have seen some customer scripts fail because of it.

Concretely, it works as follows:

When we visit a node, it visits its children first and then updates itself.

For example:

public override CSharpSyntaxNode VisitBinaryExpression(BinaryExpressionSyntax node)
{
    var left = (ExpressionSyntax)this.Visit(node.Left);
    var operatorToken = (SyntaxToken)this.Visit(node.OperatorToken);
    var right = (ExpressionSyntax)this.Visit(node.Right);
    return node.Update(left, operatorToken, right);
}
 

For example, for System.Console.WriteLine("Hello" + " " + "{0}", "world");, we visit this tree in the following index order:

Note that if the arguments did not change, the Update method returns the original node.
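Here is a minimal sketch of the visit-children-then-update pattern outside Roslyn, using a hypothetical expression tree. Expr, Literal, Binary, RecursiveRewriter and PlaceholderRewriter are all invented for illustration. Binary.Update mimics the behavior described above: it returns the same instance when no child changed, so untouched subtrees are shared between the old tree and the new one. Each nesting level costs one call-stack frame, which is why very deep trees cause problems.

```csharp
using System;

// Hypothetical miniature syntax tree, loosely modeled on Roslyn's immutable nodes.
public abstract class Expr { }

public sealed class Literal : Expr
{
    public Literal(string text) { Text = text; }
    public string Text { get; }
}

public sealed class Binary : Expr
{
    public Binary(Expr left, string op, Expr right) { Left = left; Op = op; Right = right; }
    public Expr Left { get; }
    public string Op { get; }
    public Expr Right { get; }

    // Like Update in Roslyn: return this node when nothing changed.
    public Binary Update(Expr left, string op, Expr right) =>
        ReferenceEquals(left, Left) && op == Op && ReferenceEquals(right, Right)
            ? this
            : new Binary(left, op, right);
}

// Recursive rewriter: rewrite the children first, then update the node itself.
public class RecursiveRewriter
{
    public virtual Expr Visit(Expr node) => node switch
    {
        Binary b => VisitBinary(b),
        Literal l => VisitLiteral(l),
        _ => throw new ArgumentException("Unknown node type", nameof(node)),
    };

    public virtual Expr VisitBinary(Binary node)
    {
        var left = Visit(node.Left);   // one call-stack frame per tree level
        var right = Visit(node.Right);
        return node.Update(left, node.Op, right);
    }

    public virtual Expr VisitLiteral(Literal node) => node;
}

// Example rewriter: replaces the literal "{0}" with "world".
public sealed class PlaceholderRewriter : RecursiveRewriter
{
    public override Expr VisitLiteral(Literal node) =>
        node.Text == "{0}" ? new Literal("world") : node;
}
```

If you rewrite the tree for "Hello" + " " + "{0}" with RecursiveRewriter, you get back the original instance. PlaceholderRewriter returns a new root, but its left subtree is still the original one.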

To avoid recursion, I created two new classes, CSharpNonRecursiveSyntaxWalker and CSharpNonRecursiveSyntaxRewriter, which visit or update the tree without recursion (https://github.com/dotnet/roslyn/pull/12494).

How to make a syntax walker non-recursive?

Making the tree walker non-recursive is simple: instead of recursion, we use a Stack.

Roslyn provides a visitor, CSharpSyntaxVisitor. This visitor does nothing except provide a virtual method for each type of node, and each of these methods calls DefaultVisit, which does nothing.

I created a new class, CSharpNonRecursiveSyntaxWalker. It does not visit the children inside each specific Visit method. Instead, it pushes the current node’s children onto a stack and keeps popping until the stack is back to the size it had when Visit was called:

public override void Visit(SyntaxNode node)
{
    int stackStart = _stack.Count;
    _stack.Push(node);
    while (_stack.Count > stackStart)
    {
        SyntaxNodeOrToken n = _stack.Pop();
        if (n.IsToken)
        {
            this.VisitToken(n.AsToken());
        }
        else if (!this.Skip(n.AsNode()))
        {
            this.VisitNode(n.AsNode());
            var children = n.ChildNodesAndTokens();
            // push children in reverse order so the first child is visited first
            for (int i = children.Count - 1; i >= 0; i--)
            {
                _stack.Push(children[i]);
            }
        }
    }
}
 

You can easily use CSharpNonRecursiveSyntaxWalker instead of CSharpSyntaxWalker. The only difference is that you override VisitNode instead of Visit to do something for every node.
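Here is a sketch of the walker’s core loop without the Roslyn types. TreeNode and NonRecursiveWalker are hypothetical names. The real CSharpNonRecursiveSyntaxWalker also distinguishes tokens from nodes and checks Skip. Because the pending nodes live in a Stack<T> on the heap, the walker can handle a tree a million levels deep without a StackOverflowException.

```csharp
using System;
using System.Collections.Generic;

// Hypothetical minimal tree node; SyntaxNodeOrToken plays this role in Roslyn.
public sealed class TreeNode
{
    public TreeNode(string label, params TreeNode[] children)
    {
        Label = label;
        Children = children;
    }
    public string Label { get; }
    public IReadOnlyList<TreeNode> Children { get; }
}

public static class NonRecursiveWalker
{
    // Visits nodes in the same pre-order a recursive walker would use,
    // keeping pending nodes in an explicit stack instead of the call stack.
    public static void Walk(TreeNode root, Action<TreeNode> visit)
    {
        var stack = new Stack<TreeNode>();
        stack.Push(root);
        while (stack.Count > 0)
        {
            var node = stack.Pop();
            visit(node);
            // Push children in reverse so the first child is popped next.
            for (int i = node.Children.Count - 1; i >= 0; i--)
            {
                stack.Push(node.Children[i]);
            }
        }
    }
}
```

Take a root a with children b (which has child c) and d. The visit order is a, b, c, d, which is the same pre-order a recursive walker produces.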

How to make a syntax rewriter non-recursive?

This class needs to visit the full tree and update nodes.

This one is more complex because we need the transformed children to transform a node.

If you look at my previous sample graph, the BinaryExpressionSyntax for "Hello" + " " + "{0}" needs the 18th, 24th and 25th visited nodes, which are not consecutive.

My first solution used a stack of delegates (to summarize). It also used a second stack: I pushed each result onto it and later popped the results to transform the parent node.

However, Matt Warren suggested a better solution. It is much closer to the current CSharpSyntaxRewriter than mine, and it also reduces the number of object instantiations.

This is how the CSharpNonRecursiveSyntaxRewriter class works.

We have three stacks (_undeconstructedStack, _untransformedStack and _transformedStack) and two private helpers derived from CSharpSyntaxRewriter, stored in _deconstructor and _reassembler.

_undeconstructedStack is the same stack that we used in the CSharpNonRecursiveSyntaxWalker.

When we pop a node from _undeconstructedStack, we push its children onto that same stack. We also push the node, its number of children and the current count of _transformedStack (together, in a struct) onto _untransformedStack.

However, we cannot use ChildNodesAndTokens as we did in CSharpNonRecursiveSyntaxWalker, because reconstruction also needs the null children.

For this, we use _deconstructor. The NodeDeconstructor class overrides Visit and VisitToken to collect every child needed to reconstruct the node, without visiting the children’s own children.

private class NodeDeconstructor : CSharpSyntaxRewriter
{
    private List<SyntaxNodeOrToken> _elements;
 
    public void Deconstruct(SyntaxNode node, List<SyntaxNodeOrToken> elements)
    {
        _elements = elements;
        ((CSharpSyntaxNode)node).Accept(this);
    }
 
    public override SyntaxNode Visit(SyntaxNode node)
    {
        _elements.Add(node);
        return node;
    }
 
    public override SyntaxToken VisitToken(SyntaxToken token)
    {
        _elements.Add(token);
        return token;
    }
}
 

Then, once all children are transformed (_transformedStack.Count == childCount + originalTransformedStackCount), we rebuild the node using _reassembler.
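The source of the struct pushed onto _untransformedStack is not shown. Rewrite below constructs it as UntransformedNode(node, _children.Count, _transformedStack.Count) and reads Node, ChildCount and HasAllChildrenOnStack. From that usage and the condition above, it could look roughly like the sketch below. The generic parameters stand in for SyntaxNode and SyntaxNodeOrToken so that the sketch compiles without Roslyn, and the TransformedStackStart name is a guess.

```csharp
using System.Collections.Generic;

// Hypothetical reconstruction of UntransformedNode, inferred from its usage.
public readonly struct UntransformedNode<TNode, TElement>
{
    public UntransformedNode(TNode node, int childCount, int transformedStackStart)
    {
        Node = node;
        ChildCount = childCount;
        TransformedStackStart = transformedStackStart;
    }

    public TNode Node { get; }

    // Number of child elements (nodes, tokens and nulls) from the deconstructor.
    public int ChildCount { get; }

    // _transformedStack.Count when the node was deconstructed (name is a guess).
    public int TransformedStackStart { get; }

    // True once every child has been transformed and sits on top of the stack.
    public bool HasAllChildrenOnStack(Stack<TElement> transformedStack) =>
        transformedStack.Count == TransformedStackStart + ChildCount;
}
```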

Like NodeDeconstructor, the NodeReassembler class hijacks CSharpSyntaxRewriter, here to rebuild the current node from its already transformed children.

private class NodeReassembler : CSharpSyntaxRewriter
{
    private List<SyntaxNodeOrToken> _elements;
    private int _index;
 
    public SyntaxNodeOrToken Reassemble(SyntaxNodeOrToken original, List<SyntaxNodeOrToken> rewrittenElements)
    {
        _elements = rewrittenElements;
        _index = 0;
        return ((CSharpSyntaxNode)original).Accept(this);
    }
 
    public override SyntaxNode Visit(SyntaxNode node)
    {
        return _elements[_index++].AsNode();
    }
 
    public override SyntaxToken VisitToken(SyntaxToken token)
    {
        return _elements[_index++].AsToken();
    }
}
 

To summarize, Accept calls the specific Visit method (for example, the VisitBinaryExpression we saw earlier for a BinaryExpressionSyntax). That method calls Visit on each child, and because Visit is overridden, each call simply returns the next already transformed child.

This way, the last transformed node is the result of the rewrite.

public SyntaxNode Rewrite(SyntaxNode node)
{
    int undeconstructedStart = _undeconstructedStack.Count;
    int untransformedStart = _untransformedStack.Count;
    int transformedStart = _transformedStack.Count;
 
    // add initial node so we have something to work on
    _undeconstructedStack.Push(node);
 
    // as long as there is more to deconstruct, there is more work to do
    while (_undeconstructedStack.Count > undeconstructedStart)
    {
        var nodeOrToken = _undeconstructedStack.Pop();
        if (nodeOrToken.IsNode)
        {
            node = nodeOrToken.AsNode();
 
            if (node == null)
            {
                // nulls just stay nulls, they don’t get transformed
                _transformedStack.Push(nodeOrToken);
            }
            else
            {
                SyntaxNodeOrToken rewriten;
                if (this.Skip(node, out rewriten))
                {
                    _transformedStack.Push(rewriten);
                }
                else
                {
                    // deconstruct node into child elements
                    _children.Clear();
                    _deconstructor.Deconstruct(node, _children);
 
                    // add child elements to undeconstructed stack in reverse order so
                    // the first child gets operated on next
                    for (int i = _children.Count - 1; i >= 0; i--)
                    {
                        _undeconstructedStack.Push(_children[i]);
                    }
 
                    // remember the node that will be transformed later, after its children are transformed
                    _untransformedStack.Push(new UntransformedNode(node, _children.Count, _transformedStack.Count));
                }
            }
        }
        else if (nodeOrToken.IsToken)
        {
            // we can transform tokens immediately
            var original = nodeOrToken.AsToken();
            SyntaxNodeOrToken rewriten;
            if (this.Skip(original, out rewriten))
            {
                _transformedStack.Push(rewriten);
            }
            else
            {
                var rewrittenToken = _trivializer.VisitToken(original); // rewrite trivia
                var transformed = this.VisitToken(original, rewrittenToken);
                _transformedStack.Push(transformed);
            }
        }
 
        // transform any nodes that can be transformed now
        while (_untransformedStack.Count > untransformedStart
            && _untransformedStack.Peek().HasAllChildrenOnStack(_transformedStack))
        {
            var untransformed = _untransformedStack.Pop();
 
            // gather transformed children for this node
            _children.Clear();
            for (int i = 0; i < untransformed.ChildCount; i++)
            {
                _children.Add(_transformedStack.Pop());
            }
 
            _children.Reverse();
 
            // reassemble original node with transformed children
            var rewritten = _reassembler.Reassemble(untransformed.Node, _children);
 
            // now transform the node
            var save = this.Original;
            this.Original = untransformed.Node;
            var transformed = this.VisitNode(untransformed.Node, rewritten.AsNode());
            this.Original = save;
 
            // add newly transformed node to the transformed stack
            _transformedStack.Push(transformed);
        }
    }
 
    Debug.Assert(_untransformedStack.Count == untransformedStart);
    Debug.Assert(_transformedStack.Count == transformedStart + 1);
 
    return _transformedStack.Pop().AsNode();
}
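The whole Rewrite loop can be reproduced in miniature without Roslyn. In the sketch below, Node, NonRecursiveRewriter and OrFolder are hypothetical names. Node.WithChildren stands in for both NodeDeconstructor/NodeReassembler and Update, and Skip mirrors the Skip(node, out rewritten) hook used in Rewrite above. The sketch keeps the three stacks (as locals rather than fields) and the same "all children are on the transformed stack" test.

```csharp
using System.Collections.Generic;
using System.Linq;

// Hypothetical immutable node; Kind is "or", "true", "false" or "var" here.
public sealed class Node
{
    public Node(string kind, params Node[] children) { Kind = kind; Children = children; }
    public string Kind { get; }
    public IReadOnlyList<Node> Children { get; }

    // Stands in for reassembly + Update: reuse this node if no child changed.
    public Node WithChildren(IReadOnlyList<Node> children) =>
        children.SequenceEqual(Children) ? this : new Node(Kind, children.ToArray());
}

public abstract class NonRecursiveRewriter
{
    // Return true to replace a node without deconstructing its children.
    protected virtual bool Skip(Node node, out Node rewritten) { rewritten = null; return false; }

    // Called once per node, after all of its children have been rewritten.
    protected abstract Node VisitNode(Node original, Node rewritten);

    public Node Rewrite(Node root)
    {
        var undeconstructed = new Stack<Node>();
        var untransformed = new Stack<(Node Node, int ChildCount, int Start)>();
        var transformed = new Stack<Node>();

        undeconstructed.Push(root);
        while (undeconstructed.Count > 0)
        {
            var node = undeconstructed.Pop();
            if (Skip(node, out var replacement))
            {
                transformed.Push(replacement);
            }
            else
            {
                // Reverse order so the first child is processed next.
                for (int i = node.Children.Count - 1; i >= 0; i--)
                    undeconstructed.Push(node.Children[i]);
                untransformed.Push((node, node.Children.Count, transformed.Count));
            }

            // Rebuild every pending node whose children are all transformed.
            while (untransformed.Count > 0
                && transformed.Count == untransformed.Peek().Start + untransformed.Peek().ChildCount)
            {
                var (original, childCount, _) = untransformed.Pop();
                var children = new Node[childCount];
                for (int i = childCount - 1; i >= 0; i--)
                    children[i] = transformed.Pop();
                transformed.Push(VisitNode(original, original.WithChildren(children)));
            }
        }
        return transformed.Pop();
    }
}

// Example: a tiny "||" constant folder that records which nodes it visited.
public sealed class OrFolder : NonRecursiveRewriter
{
    public List<string> Visited { get; } = new List<string>();

    protected override bool Skip(Node node, out Node rewritten)
    {
        // Short-circuit "true || anything": the right operand is never visited.
        if (node.Kind == "or" && node.Children[0].Kind == "true")
        {
            rewritten = node.Children[0];
            return true;
        }
        rewritten = null;
        return false;
    }

    protected override Node VisitNode(Node original, Node rewritten)
    {
        Visited.Add(original.Kind);
        if (rewritten.Kind != "or") return rewritten;
        Node left = rewritten.Children[0], right = rewritten.Children[1];
        if (left.Kind == "true") return left;    // true || x  => true
        if (left.Kind == "false") return right;  // false || x => x
        if (right.Kind == "false") return left;  // x || false => x
        return rewritten;
    }
}
```

OrFolder folds true || x to true without ever visiting x. It can also fold a left-nested chain of a million false || ... terms without recursion. Note that Skip only sees a node before its children are rewritten. In (false || true) || x, the left operand becomes true only after it has been rewritten, so x is still visited.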
 

You can now inherit from this new class instead of CSharpSyntaxRewriter, with the same change as for CSharpNonRecursiveSyntaxWalker: override VisitNode instead of Visit. Its default implementation is:

public virtual SyntaxNode VisitNode(SyntaxNode original, SyntaxNode rewritten)
{
    return ((CSharpSyntaxNode)rewritten).Accept(this);
}

Using the non-recursive implementation makes your walkers and rewriters more reliable on trees that can be very deep.

Note, however, that avoiding recursion in the rewriter means you cannot visit some of a node’s children now and the rest later. For example, in U-SQL I wrote a constant folder that transforms
“anExpressionThatReturnsTrue || whateverExpression” into “true”, skipping the visit of whateverExpression.

To allow short-circuiting scenarios, you can override the ShouldRewriteChildren method:

protected virtual bool ShouldRewriteChildren(SyntaxNodeOrToken nodeOrToken, out SyntaxNodeOrToken rewritten)  

Note that using this rewriter may imply changes in your implementation.

Indeed, a Visit override often follows this pattern:

public override SyntaxNode VisitBinaryExpression(BinaryExpressionSyntax originalNode)
{
    var rewritten = base.VisitBinaryExpression(originalNode);
    //…
}
 

With the CSharpNonRecursiveSyntaxRewriter, this becomes:

public override SyntaxNode VisitBinaryExpression(BinaryExpressionSyntax rewritten)
{
    var originalNode = this.Original;
    //…
}
 

8 thoughts on “How to avoid recursion?”

  1. I can’t imagine any scenario where it can be useful.
    And can’t imagine such big syntax tree which can cause stack overflow.
    Could you please provide more examples?

    1. The scenarios we got were (very probably generated) queries with thousands of ||. Ex: c == 1 || c == 2 || … || c == 1000000
      In U-SQL it would be more efficient to use a constant table and to use a join on it but anyway when customers provide queries like these we don’t want to fail with an internal exception because of stack overflow.

  2. IMHO one of the biggest advantages of a Recursive approach is its ability to cut down on the amount of code one must write in order to achieve the solution, and this is a great example of that. You’ve managed to turn what should be a single line method:

    return n < 2 ? n : Fibonacci(n-1) + Fibonacci(n-2)

    into 11-12 lines of C# that in turn requires heap allocation and extra objects in order to support it. Sure, recursion isn't the answer to everything, and may not always provide the most efficient algorithm, but that's no reason to start arbitrarily excising recursion from one's toolkit.

    1. If you look at my method without the queue, I don’t have any heap allocation.
      Then, you’re right: code is shorter. But is it really a real argument for you vs performance and StackOverflowException?

      1. As I see it, it’s often the matter of how the recursion is designed. In the fibonacci example You can easily create a well performing recursion like so:

        public static BigInteger TailFibonacci(int n, BigInteger prev1, BigInteger prev2)
        {
            switch (n)
            {
                case 1: return 0;
                case 2: return 1;
                case 3: return prev1 + prev2;
                default: return TailFibonacci(n - 1, prev1 + prev2, prev1);
            }
        }

        [All the code comparing three versions I placed on http://share.linqpad.net/gk2g46.linq – to be executed in LinqPad]

        On my machine it runs for n = 10000 in 13 milliseconds, while Your version using queue runs in 8 ms. For bigger n values my version obviously blows the stack but since it’s tail recursive, You could rewrite it in F# and all would be fine even for really big numbers.

        I believe that this improved recursion is still more readable than imperative version. And it actually uses the same trick of climbing up with the numbers, starting from fib(0) and fib(1).

        The difference is that while You created Your own stack explicitly, tail recursive version still leverages the one provided by execution environment.

        I’m not sure whether the same could be done for Your real life example, and if it would – whether it’d be readable.
