 # Update: Friday, August 23, 2013

This post is from 2011, but has seen a lot of traffic lately and drawn some comments that the solution given here is incomplete or “doesn’t work”. This post is definitely incomplete, but the solution does work. For the most up-to-date code, please see my paper Stackless Scala with Free Monads as well as the source code for scalaz.Free.

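```scala
case class IntReader[A](run: Int => A) {
  def map[B](f: A => B): IntReader[B] =
    IntReader(i => f(run(i)))
  def flatMap[B](f: A => IntReader[B]): IntReader[B] =
    IntReader(i => f(run(i)).run(i))
}
```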

Now say we have a chain of flatMaps of arbitrary length. Let’s say 100,000. Let’s mock that up using a list:

```scala
List.range(0, 100000).foldLeft[IntReader[List[Int]]](
  IntReader(_ => Nil))(
  (a, e) => a.flatMap(xs => IntReader(_ => e :: xs)))
```

The result is a single, deeply nested function that crashes with a StackOverflowError when it is run. The reason is that `flatMap` makes nested calls to `apply` on its argument, and those calls are not in tail position. So the stack trace shows the same call to `apply` in an anonymous function, repeated over and over.
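The failure mode described above can be sketched in a few lines. This is only an illustration: the reader definition is restated so the block stands alone, and `chain` and `overflows` are hypothetical helper names that are not part of the original post.

```scala
// Reader monad restated so the sketch is self-contained.
final case class IntReader[A](run: Int => A) {
  def map[B](f: A => B): IntReader[B] =
    IntReader(i => f(run(i)))
  def flatMap[B](f: A => IntReader[B]): IntReader[B] =
    IntReader(i => f(run(i)).run(i)) // run(i) is a nested, non-tail call
}

// Hypothetical helper: chain n flatMaps, as in the fold above.
def chain(n: Int): IntReader[List[Int]] =
  List.range(0, n).foldLeft(IntReader[List[Int]](_ => Nil)) {
    (a, e) => a.flatMap(xs => IntReader(_ => e :: xs))
  }

// Hypothetical helper: reports whether running a chain of length n
// blows the JVM stack. The answer depends on the configured stack size.
def overflows(n: Int): Boolean =
  try { chain(n).run(0); false }
  catch { case _: StackOverflowError => true }

val small: List[Int] = chain(3).run(0) // List(2, 1, 0)
```

Short chains run fine. With a default JVM stack, `overflows(100000)` would be expected to return `true`, though the exact threshold depends on the stack size the JVM was started with.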

## CPS Transformation

The classical way of avoiding the call stack in this situation is to transform the program to continuation-passing style (CPS). The CPS-transformed version of our reader monad looks like this:

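```scala
trait IntReader[A] {
  def apply[R](k: A => R, i: Int): R
  def map[B](f: A => B): IntReader[B] = new IntReader[B] {
    def apply[R](k: B => R, i: Int): R =
      IntReader.this(a => k(f(a)), i)
  }
  def flatMap[B](f: A => IntReader[B]): IntReader[B] = new IntReader[B] {
    def apply[R](k: B => R, i: Int): R =
      IntReader.this(a => k(f(a)(b => b, i)), i)
  }
}
```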

Instead of returning `A` directly from `apply`, we take a continuation `k` that receives the `A`. You can see how this would be an improvement. The calls to `apply` are now all in tail position, and the calls to the continuation `k` at every point are also in tail position. Unfortunately, Scala has very limited tail call elimination: it can eliminate a tail call only if it’s a recursive call to the current method. But note that `apply` above actually calls a different `apply` method: that of the enclosing `IntReader`. And since `k` is different from the `apply` method of the function it’s called from, the call to the continuation cannot be eliminated either. So if we traverse our list with this CPS-transformed reader monad, we will still get a StackOverflowError.

## Trampolining

What we must do is exchange stack for heap. The idea is simple. Instead of making a tail call, we return a data structure representing what to do next.

```scala
sealed trait Trampoline[+A] {
  // `final` is needed for scalac to compile the self-call into a loop.
  final def run: A = this match {
    case Done(a) => a
    case More(t) => t().run
  }
}
case class Done[A](a: A) extends Trampoline[A]
case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
```

Note that the `run` method is tail recursive. We can now use our trampoline to turn mutual recursion into tail recursion (thanks, Rich):

```scala
def even(n: Int): Trampoline[Boolean] = {
  if (n == 0) Done(true)
  else More(() => odd(n - 1))
}
def odd(n: Int): Trampoline[Boolean] = {
  if (n == 0) Done(false)
  else More(() => even(n - 1))
}
```

No matter how deep the mutual recursion, calling either of these methods simply returns a `Trampoline` that we can unwind tail-recursively with `run`:

```
scala> val b = odd(100000001).run
b: Boolean = true
```
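For comparison, the Scala standard library ships a similar trampoline in `scala.util.control.TailCalls`. The sketch below restates the even/odd example with it, assuming a Scala version that includes that object; the names `isEven` and `isOdd` are chosen here only to avoid clashing with the definitions above.

```scala
import scala.util.control.TailCalls._

// `done` plays the role of Done, `tailcall` the role of More.
def isEven(n: Int): TailRec[Boolean] =
  if (n == 0) done(true) else tailcall(isOdd(n - 1))

def isOdd(n: Int): TailRec[Boolean] =
  if (n == 0) done(false) else tailcall(isEven(n - 1))

// `result` unwinds the trampoline in a loop, like `run` above.
val answer: Boolean = isOdd(100001).result // true
```

As with the hand-written version, each call returns immediately with a description of the next step, so the depth of the mutual recursion does not show up on the JVM stack.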

## Trampolines of Trampolines

Now let’s say we wanted to transform a binary recursion in the same way. For example, the (terribly inefficient) recursive function to find the nth Fibonacci number:

```scala
def fib(n: Int): Int =
  if (n < 2) n else fib(n - 1) + fib(n - 2)
```

There’s a bit of a problem here. If we change this to use our `Trampoline`, the result of `fib` will be `Trampoline[Int]`. So then how do we add two trampolines together? One way is to simply call `run`:

```scala
def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n) else More(() => fib(n - 1).run + fib(n - 2).run)
```

But this defeats the purpose! The call to `run` is not in a tail position here, and so we’re back to getting stack overflows.

Another idea is to make `Trampoline` a monad, by just adding a `flatMap` method to it. Then we can just say:

```scala
for { x <- fib(n - 1)
      y <- fib(n - 2) } yield x + y
```

But with only `Done` and `More` to work with, the obvious way of implementing `flatMap` ends up calling `run`.

```scala
def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
  More(() => f(this.run).run)
```

## Delimited Continuations

The solution, as so often with continuations, is found in delimited control. We bake monadicity into the `Trampoline` data type, with an additional case. Again, instead of making a call to the continuation, we return a data structure representing what we’re doing currently together with what to do with the result:

```scala
case class Cont[A, B](a: Trampoline[A],
                      f: A => Trampoline[B]) extends Trampoline[B]
```

Note that the arguments to this constructor are exactly the arguments to `flatMap`. The idea is that we can now implement `map` and `flatMap` like this:

```scala
def map[B](f: A => B): Trampoline[B] =
  Cont(this, a => More(() => Done(f(a))))
def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
  Cont(this, f)
```

The implementation of `run` becomes a tad more complicated now:

```scala
def run: A = {
  var cur: Trampoline[_] = this
  var stack: List[Any => Trampoline[A]] = List()
  var result: Option[A] = None
  while (result.isEmpty) {
    cur match {
      case Done(a) => stack match {
        case Nil => result = Some(a.asInstanceOf[A])
        case c :: cs => {
          cur = c(a)
          stack = cs
        }
      }
      case More(t) => cur = t()
      case Cont(a, f) => {
        cur = a
        stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
      }
    }
  }
  result.get
}
```

We’re essentially breaking out of Scala here and dropping into a Java-level loop. Firstly, a `Cont` has two parts: an intermediate computation whose type is not known, and a continuation for which we only know the return type. Secondly, we have to keep our own stack of continuations as we descend into the monadic binds. So we must cast, just as if we were working in a language without generics. Don’t worry, the continuation type matches the intermediate computation by construction. Lastly, I’m using my own while loop instead of relying on Scala to translate tail recursion into a loop for me.

It’s possible to write this code with better types, using existentials and heterogeneous lists (left as an exercise for the hardened type-level programmer). But this is pretty self-contained, and we can be confident that it’s well typed without Scala’s help. It’s also possible to use the Delimited Continuations compiler plugin (also an exercise for the reader) to hide the casts, but that plugin makes these exact same casts.

Here is the whole trampoline code again:

```scala
sealed trait Trampoline[A] {
  def map[B](f: A => B): Trampoline[B] =
    flatMap(a => More(() => Done(f(a))))
  def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
    Cont(this, f)
  def run: A = {
    var cur: Trampoline[_] = this
    var stack: List[Any => Trampoline[A]] = List()
    var result: Option[A] = None
    while (result.isEmpty) {
      cur match {
        case Done(a) => stack match {
          case Nil => result = Some(a.asInstanceOf[A])
          case c :: cs => {
            cur = c(a)
            stack = cs
          }
        }
        case More(t) => cur = t()
        case Cont(a, f) => {
          cur = a
          stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
        }
      }
    }
    result.get
  }
}
case class Done[A](a: A) extends Trampoline[A]
case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
case class Cont[A, B](a: Trampoline[A],
                      f: A => Trampoline[B]) extends Trampoline[B]
```

Now we can write a binary-recursive Fibonacci function that uses constant stack:

```scala
def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n) else for {
    x <- fib(n - 1)
    y <- fib(n - 2)
  } yield (x + y)
```

Even with millions of recursive calls, we don’t overflow the stack:

```
scala> fib(40).run
res23: Int = 102334155
```
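One caveat is worth a sketch. In `fib` as written, the leftmost recursive call `fib(n - 1)` is evaluated directly while the trampoline is being built, so building it still recurses about `n` frames deep. That is harmless for `n = 40`, but a deeply recursive function should wrap its recursive calls in `More`, so that they happen only inside the `run` loop. The block below restates the trampoline so that it stands alone; `sumTo` is a hypothetical example, not something from the original post.

```scala
// Trampoline restated from the full listing so the sketch is self-contained.
sealed trait Trampoline[A] {
  def map[B](f: A => B): Trampoline[B] =
    flatMap(a => More(() => Done(f(a))))
  def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
    Cont(this, f)
  def run: A = {
    var cur: Trampoline[_] = this
    var stack: List[Any => Trampoline[A]] = List()
    var result: Option[A] = None
    while (result.isEmpty) {
      cur match {
        case Done(a) => stack match {
          case Nil => result = Some(a.asInstanceOf[A])
          case c :: cs =>
            cur = c(a)
            stack = cs
        }
        case More(t) => cur = t()
        case Cont(a, f) =>
          cur = a
          stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
      }
    }
    result.get
  }
}
case class Done[A](a: A) extends Trampoline[A]
case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
case class Cont[A, B](a: Trampoline[A],
                      f: A => Trampoline[B]) extends Trampoline[B]

// Hypothetical example: a non-tail recursion that is 100000 calls deep.
// The recursive call sits behind More, so sumTo itself returns at once.
def sumTo(n: Int): Trampoline[Long] =
  if (n == 0) Done(0L)
  else More(() => sumTo(n - 1)).map(_ + n)

// The same idea applied to fib: both recursive calls are suspended.
def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n) else for {
    x <- More(() => fib(n - 1))
    y <- More(() => fib(n - 2))
  } yield x + y

val total: Long = sumTo(100000).run // 5000050000
```

The pending continuations still have to live somewhere: here they accumulate in the `stack` list inside `run`, which is on the heap rather than on the JVM call stack.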

We’re now ready to come back to our original problem, which was tail call elimination in arbitrary monads. Remember that original reader monad?

As long as there exists a monad transformer version of the monad in question, we can transform our `Trampoline` monad, resulting in a new tail-recursive monad. For example, `type IntReader[A] = Int => A` is a monad, but `IntReaderT[M[_], A] = Int => M[A]` is also a monad, for any monad `M`, including `Trampoline`.

To illustrate this, I will use the `Kleisli` monad transformer from scalaz. Here, `Kleisli[M, A, B]` is isomorphic to the type `A => M[B]`, and so our trampolined `IntReader[A]` will be written `Kleisli[Trampoline, Int, A]`.

This lets us traverse, in our reader monad, a list with millions of elements:

```scala
val x = List.range(0, 1000000).
  foldLeft[Kleisli[Trampoline, Int, List[Int]]](
    kleisli(i => Done(List(i))))(
    (a, e) => kleisli(r => for {
      x <- More(() => a.apply(r))
      y <- Done(e + r :: x)
    } yield y))
```

And it’s all in constant stack:

```
scala> x(0).run
res28: List[Int] = List(999999, 999998, 999997, 999996, 999995, ...
```

This can be employed with any monad transformer, e.g. `StateT`, `WriterT`, Iteratees, etc.
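For readers without scalaz at hand, the same idea can be sketched without any library. The block below is an illustration under assumptions, not the scalaz implementation: `TrampolinedReader` is a hypothetical name for a reader specialised to `Int => Trampoline[A]`, and the trampoline is restated so the block stands alone.

```scala
// Trampoline restated from the full listing so the sketch is self-contained.
sealed trait Trampoline[A] {
  def map[B](f: A => B): Trampoline[B] =
    flatMap(a => More(() => Done(f(a))))
  def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
    Cont(this, f)
  def run: A = {
    var cur: Trampoline[_] = this
    var stack: List[Any => Trampoline[A]] = List()
    var result: Option[A] = None
    while (result.isEmpty) {
      cur match {
        case Done(a) => stack match {
          case Nil => result = Some(a.asInstanceOf[A])
          case c :: cs =>
            cur = c(a)
            stack = cs
        }
        case More(t) => cur = t()
        case Cont(a, f) =>
          cur = a
          stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
      }
    }
    result.get
  }
}
case class Done[A](a: A) extends Trampoline[A]
case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
case class Cont[A, B](a: Trampoline[A],
                      f: A => Trampoline[B]) extends Trampoline[B]

// Hypothetical reader over Trampoline: a wrapper around Int => Trampoline[A].
final case class TrampolinedReader[A](run: Int => Trampoline[A]) {
  // The call to run(i) is suspended in More, so flatMap never nests calls
  // on the JVM stack; the nesting is replayed later by Trampoline.run.
  def flatMap[B](f: A => TrampolinedReader[B]): TrampolinedReader[B] =
    TrampolinedReader(i => More(() => run(i)).flatMap(a => f(a).run(i)))
  def map[B](f: A => B): TrampolinedReader[B] =
    TrampolinedReader(i => More(() => run(i)).map(f))
}

// The fold from the start of the post, now over the trampolined reader.
val chained: TrampolinedReader[List[Int]] =
  List.range(0, 100000).foldLeft(
    TrampolinedReader[List[Int]](_ => Done(List.empty[Int]))) {
    (acc, e) => acc.flatMap(xs => TrampolinedReader(r => Done((e + r) :: xs)))
  }

val xs: List[Int] = chained.run(0).run // List(99999, 99998, ..., 0)
```

The first `run` supplies the `Int` environment and returns a `Trampoline` immediately. The second `run` is the trampoline loop that does the actual work.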

## 27 thoughts on “Tail Call Elimination in Scala Monads”

1. huynhjl says:

Very interesting, although I don’t grasp the details, yet. Is this available in scalaz? Can this handle:

List.fill(10000)(1).traverse[({type S[x]=State[Int,x]})#S, Int](i => state[Int, Int](s => (s+1, i))) ~> 0

• Rúnar says:

Yes it does. Traverse with StateT[Trampoline, Int, _]

• Rúnar says:

This is not in Scalaz yet.

2. okomok says:

Greatest article ever!
I didn’t know how to implement flatMap.
Thanks.

3. Yo Eight says:

Thanks, that’s what I have been looking for !

Very informative post

4. Kenji Yoshida (@xuwei_k) says:

Awesome ! thanks

5. Paul Chiusano says:

Looking forward to this being in scalaz. For the tail recursive case, heap usage (for the explicit stack you are managing) will actually be constant. It’s only when you use this to write a non-tail recursive function like fib will you actually grow that stack.

Also, I’d use ArrayStack . It’s actually mutable and faster than using a var List .

6. Richard Wallace says:

Sweet! This solves a problem I’ve been having with Iteratees not being directly tail recursive. Now I can run

(length[Unit, Int, Trampoline] >>== enumStream(Stream.range(1, 100000, 1))) apply (_ => Done(-1)) run

without getting a SOE! Trying to do it with just Id before always gave a SOE after about 6.5k elements.

Now I just need to figure out how to combine it with IO.

7. okomok says:

Though it’s a minor issue, the first Trampoline’s `run` method might miss the `final` modifier?

8. Lachlan O'Dea says:

Very cool. Do you think an approach like this could be used to implement

9. Lachlan O'Dea says:

Argh, miss tap. I meant to ask if this could be used to implement a practical difference list in Scala.

• Rúnar says:
10. Josh Suereth (@jsuereth) says:

Just curious, what’s the difference between this style and TailRec in trunk?

• Rúnar says:

Josh,

There are several improvements in scalaz.Free.Trampoline over scala.util.control.TailCalls.TailRec:

1. This trampoline is a monad. That means you can combine two trampolines into one. It also means that you can turn any call into a tail-call.
2. The monad is codense. This means that left-associated binds are reassociated to the right.
3. A method `resume` is provided to step through the trampoline if needed.
4. A method `zipWith` is provided to run two trampolines cooperatively.
5. A method `mapSuspension` is provided to change the type of the suspension from Function0 to something else.

I’d be happy to discuss this further with you.

11. Alex Repain says:

Will this TCE technique work on nested recursion schemes ? e.g. the Ackermann function ?

• Rúnar says:

George, you need to wrap all of your recursive calls in suspend.

• George (@folone) says:

Right, seems to work. Thanks! I’ve updated the gist.

12. Kris Nuttycombe (@nuttycom) says:

This obviously uses constant stack, but does it use constant heap? If so, why so, and if not, is this not just as bad as a solution that uses the stack (albeit with a few more seconds of runtime before you run out of memory?)

• Rúnar says:

Yes, it does use constant heap. Because each step sits behind a lambda, it is not constructed on the heap until the previous step runs. And each step can be garbage collected as soon as the trampoline loop goes on to the next iteration. So it does a lot of heap allocation, and it makes JIT impossible, but it does really use constant memory (both stack and heap).

You can try this out by running the Ackermann function from the scalaz.examples project. It will run for a very long time without growing the stack or the heap.

• mb says:

Try and test your Trampoline implementation from above.
(In Scala) it doesn’t work! It doesn’t do what you think it should do.
The use with the binary-recursive Fibonacci function above is misleading, because that doesn’t require a lot of stack.

Try something like that;
final def rec(n:Int) : Trampoline[Int] = {
if ( n == 1 ) Done(n) else for {
x <- rec(n)
} yield (x)

rec(1); rec(2);

• Rúnar says:

You need to wrap your recursive case in a suspension. It really does work.

• Evgeny Kotelnikov (@aztek) says:

Is it really constant? Executing `run` increases the `stack: List[Any => Trampoline[A]]` along the way and in case of Fibonacci example it seems that the size of it grows linearly.

• Rúnar says:

Why does it seem that way?

13. mb says:

It’s not our fault:

Compare the above binary-recursive Fibonacci function in Scala:

fib(40);

with the ocaml – implementation

let fib n =
let rec fib_aux n a b =
match n with
| 0 -> a
| _ -> fib_aux (n-1) b (a+b)
in
fib_aux n 0 1;;

fib 10000;;

see any difference – except for the numbers?