
Tail-recursion

Intro

The canonical example of recursion is the Fibonacci sequence, in which each element is the sum of the two preceding elements (with $a_1 = a_2 = 1$ by convention). List-based recursion is slightly different: one typically begins with a list $[x_1,\ldots,x_n]$ and a property of the list, such as its length or its maximum, that we would like to compute.
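For reference, here is a direct Haskell transcription of the Fibonacci recurrence $a_n = a_{n-1} + a_{n-2}$ with $a_1 = a_2 = 1$. It is a minimal sketch: the name fib is ours, and this naive version makes an exponential number of calls.

```haskell
-- Naive Fibonacci with the convention a_1 = a_2 = 1 (indices start at 1).
-- Each call spawns two more, so the running time grows exponentially in n.
fib :: Int -> Integer
fib n
  | n < 1     = error "fib: index must be at least 1"
  | n <= 2    = 1
  | otherwise = fib (n - 1) + fib (n - 2)

main :: IO ()
main = print (map fib [1 .. 10]) -- [1,1,2,3,5,8,13,21,34,55]
```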

Length

Suppose we would like to compute the length of a list $\ell = [x_1,\ldots,x_n]$. If $\ell$ is empty ($\ell = [\ ]$), its length is 0. If $\ell$ contains a single element ($\ell = [x]$), its length is 1. With these base cases (strictly, the empty-list case alone suffices, and it is the only one the Haskell code below uses) we can recursively define the length of a list as:

$$\text{len}(\ell) = 1 + \text{len}(\ell_{n-1})$$

where $\ell_{n-1} := [x_1,\ldots,x_{n-1}]$.

Expanding this out we get:

$$
\begin{aligned}
\text{len}(\ell) &= 1 + \text{len}(\ell_{n-1})\\
&= 1 + (1 + \text{len}(\ell_{n-2}))\\
&= 1 + (1 + (1 + \text{len}(\ell_{n-3})))\\
&\;\;\vdots\\
&= \underbrace{1 + 1 + \cdots + 1}_{n \text{ ones}} + \underbrace{\text{len}(\ell_0)}_{=0}\\
&= n
\end{aligned}
$$

Haskell

len :: [Int] -> Int
len [] = 0                 -- base case: the empty list has length 0
len (x:xs) = 1 + len xs    -- one for the head, plus the length of the tail

main :: IO ()
main = print (len [1, 2, 3]) -- 3

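Note that in the above example x is the current element (the head of the list) and xs is the list of remaining elements (the tail). The code therefore peels off the first element rather than the last one, as $\ell_{n-1}$ does above; for computing the length this makes no difference.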

Space complexity

Unfortunately, the above recursion is not very forgiving on stack space. In an ordinary while loop the only variable stored in memory is the running count, so the space complexity is $O(1)$. In the above recursion, every function call occupies a frame on the stack, and each frame must stay alive until the recursive call below it returns, because its pending 1 + cannot be computed before then. The space complexity is therefore $O(n)$. Simply computing the length of an especially long list may very well trigger a stack overflow for this reason. Not good!

Tail-recursion

The workaround to the above problem is to use an accumulator to pass an intermediate result to the recursive call. The len function, for example, then becomes a function of two arguments: the list $[x_1,\ldots,x_n]$ and an accumulator parameter holding the count so far.

Haskell

len :: [Int] -> Int -> Int
len [] acc = acc                    -- empty list: the accumulator holds the answer
len (x:xs) acc = len xs (1 + acc)   -- count x, then recurse on the tail

main :: IO ()
main = print (len [1, 2, 3] 0) -- 3

If the list is empty, the accumulator is returned unchanged as the result. Otherwise, the accumulator is incremented by 1 and passed, along with the sublist xs (the current list minus the current element), back to the len function.
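Tracing both definitions on $[1,2,3]$ makes the difference concrete. The first line is the original len, whose pending additions pile up until the empty list is reached; the second is the accumulator version, assuming the accumulator is evaluated at each step (in Haskell this is not automatic, as the Thunks section below explains).

$$
\begin{aligned}
\text{len}\ [1,2,3] &= 1 + \text{len}\ [2,3] = 1 + (1 + \text{len}\ [3]) = 1 + (1 + (1 + \text{len}\ [\,])) = 1 + (1 + (1 + 0)) = 3\\
\text{len}\ [1,2,3]\ 0 &= \text{len}\ [2,3]\ 1 = \text{len}\ [3]\ 2 = \text{len}\ [\,]\ 3 = 3
\end{aligned}
$$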

What makes this tail recursion is that the tail call (the recursive call that passes the sublist xs back to len) is the last thing each invocation does: its result is returned directly, with no work left pending. Since nothing remains to be done in the caller once the tail call returns, the compiler can reuse the current stack frame instead of pushing a new one. In contrast, the original recursion is not tail-recursive because its recursive call is not the final operation. Why? Look closely at the 1 + in len (x:xs) = 1 + len xs. This addition can only be performed after the recursive call returns, so it, rather than the recursive call, is the final operation at each step. From this point of view it becomes even clearer why the stack space issue occurs: each sum cannot resolve until the next one resolves, and so on, until the empty list is reached and 1 + 0 can be resolved. Then 1 + (1 + 0) can be resolved, etc.
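Callers usually should not have to supply the initial 0 themselves. A common Haskell idiom hides the accumulator in a local helper; the sketch below calls it go, a conventional but hypothetical name. It keeps the plain accumulator of the version above, so the laziness caveat discussed under Thunks applies to it as well.

```haskell
-- One-argument interface; the accumulator lives in the local helper 'go'.
len :: [Int] -> Int
len list = go list 0
  where
    -- Same step as before: go calls itself in tail position.
    go :: [Int] -> Int -> Int
    go []     acc = acc
    go (_:xs) acc = go xs (1 + acc)

main :: IO ()
main = print (len [1, 2, 3]) -- 3
```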

Scala

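import scala.annotation.tailrec

object Length extends App {
  @tailrec // compilation fails if the recursive call is not in tail position
  def len(list: List[Int], acc: Int): Int = list match {
    case Nil     => acc              // empty list: return the count so far
    case _ :: xs => len(xs, 1 + acc) // tail call: nothing left to do after it
  }
  println(len(List(1, 2, 3), 0)) // 3
}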

Thunks

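The two-argument Haskell len above is indeed tail-recursive. But we are not out of the woods yet memory-wise. Due to lazy evaluation in Haskell, 1 + acc is not guaranteed to be evaluated at each step; instead, a nested expression 1 + (1 + (1 + ...)) may build up and only be evaluated at the very end, which again takes stack space proportional to the length of the list. Such values which are yet to be evaluated are called thunks. To force the evaluation at each step we can use the $! operator (strict application), which evaluates its argument to weak head normal form (for an Int, a plain number) before passing it to the function.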
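Without forcing, the accumulator-passing len reduces roughly as follows. This is a sketch of the evaluation order, not of GHC's actual heap layout, and with optimizations enabled GHC's strictness analysis may avoid the thunks. The recursion itself is a tail call, but the accumulator grows into a nested expression that is only evaluated once the list is exhausted.

$$
\begin{aligned}
\text{len}\ [1,2,3]\ 0 &\to \text{len}\ [2,3]\ (1+0)\\
&\to \text{len}\ [3]\ (1+(1+0))\\
&\to \text{len}\ [\,]\ (1+(1+(1+0)))\\
&\to 1+(1+(1+0))\\
&\to 3
\end{aligned}
$$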

len :: [Int] -> Int -> Int
len [] acc = acc
len (x:xs) acc = len xs $! (1 + acc)   -- $! evaluates 1 + acc before the call

main :: IO ()
main = print (len [1, 2, 3] 0) -- 3

This way the accumulator is reduced to a plain number before each recursive call, so no chain of thunks builds up, and since the call is in tail position no stack frames accumulate either. This solves the stack space problem: the extra space used is now $O(1)$.
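Another common way to force the accumulator, not used in the original, is GHC's BangPatterns extension: a ! in front of a pattern variable evaluates that argument whenever the function is entered. A minimal sketch (the name lenStrict is hypothetical), run here on a million-element list:

```haskell
{-# LANGUAGE BangPatterns #-}

-- The ! on acc forces it to weak head normal form on every call,
-- so each (1 + acc) is computed immediately instead of piling up as a thunk.
lenStrict :: [Int] -> Int -> Int
lenStrict []     !acc = acc
lenStrict (_:xs) !acc = lenStrict xs (1 + acc)

main :: IO ()
main = print (lenStrict [1 .. 1000000] 0) -- 1000000
```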

Fold

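There is a concise way to black-box tail recursion, usually referred to as a fold. The accumulation operation (anonymous or named) and the initial value are passed as arguments, and the direction is chosen by picking a left or a right fold. With initial value $z$ and accumulation operation $\oplus$, a left fold combines the elements from left to right, $((z \oplus x_1) \oplus x_2) \oplus \cdots$, while a right fold combines them from right to left, $x_1 \oplus (x_2 \oplus (\cdots \oplus z))$. In Haskell, foldl' (from Data.List) is the strict variant of foldl: it forces the accumulator at each step instead of relying on lazy evaluation.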
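To see that a fold really is the accumulator pattern with the step abstracted out, here is a hand-rolled strict left fold (the name myFoldl' is hypothetical; it is not the library's implementation) written exactly like the strict len above. The last two lines show that the direction matters for a non-associative operator such as subtraction.

```haskell
-- Hand-rolled strict left fold: tail recursion with the step function as a parameter.
myFoldl' :: (b -> a -> b) -> b -> [a] -> b
myFoldl' _ acc []     = acc
myFoldl' f acc (x:xs) = let acc' = f acc x
                        in acc' `seq` myFoldl' f acc' xs   -- force, then tail call

-- len is one instance: the step ignores the element and adds 1.
lenViaFold :: [Int] -> Int
lenViaFold = myFoldl' (\acc _ -> acc + 1) 0

main :: IO ()
main = do
  print (lenViaFold [1, 2, 3])       -- 3
  print (myFoldl' (-) 10 [1, 2, 3])  -- ((10 - 1) - 2) - 3 = 4
  print (foldr (-) 10 [1, 2, 3])     -- 1 - (2 - (3 - 10)) = -8
```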

Haskell

import Data.List (foldl')

len :: [Int] -> Int
len = foldl' (\acc x -> 1 + acc) 0   -- add 1 to the count for each element x

main :: IO ()
main = print (len [1, 2, 3]) -- 3

(Note that \acc x -> 1 + acc is an anonymous function, or lambda. foldl' is a higher-order function because it takes a function as one of its arguments.)

Scala

object LengthFold extends App {
  def len(list: List[Int]): Int = list.foldLeft(0)((acc, _) => acc + 1) // add 1 per element
  println(len(List(1, 2, 3))) // 3
}

Note that while Haskell is lazy by default, Scala is strict: it does not delay a computation unless explicitly asked to, for example with the lazy modifier (lazy val). This is why the Scala versions above need no counterpart to $! or foldl'.

Warning

In Scala foldRight is not tail-recursive (to be discussed in a later post).
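Although the warning concerns Scala, the structural reason can be sketched in Haskell. In a hand-written right fold (the name myFoldr is hypothetical), the recursive call is nested inside the combining function, just as len xs is nested inside 1 + len xs in the very first definition, so it cannot be a tail call. Instantiated to count elements, it reproduces that original non-tail-recursive len.

```haskell
-- Hand-written right fold, mirroring the shape of foldr.
myFoldr :: (a -> b -> b) -> b -> [a] -> b
myFoldr _ z []     = z
myFoldr f z (x:xs) = f x (myFoldr f z xs) -- recursive call is an argument to f, so f runs last

-- Counting with a right fold builds 1 + (1 + (1 + 0)), like the first len.
lenR :: [Int] -> Int
lenR = myFoldr (\_ acc -> 1 + acc) 0

main :: IO ()
main = print (lenR [1, 2, 3]) -- 3
```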

Published 20 Jun 2018