Tail Call Elimination in Scala Monads

Update: Friday, August 23, 2013

This post is from 2011, but has seen a lot of traffic lately and drawn some comments that the solution given here is incomplete or “doesn’t work”. This post is definitely incomplete, but the solution does work. For the most up-to-date code, please see my paper Stackless Scala with Free Monads as well as the source code for scalaz.Free.

 

Consider a simple reader monad:

case class IntReader[A](run: Int => A) {
  def map[B](f: A => B): IntReader[B] =
    IntReader(i => f(run(i)))
  def flatMap[B](f: A => IntReader[B]): IntReader[B] =
    IntReader(i => f(run(i)).run(i))
}

Now say we have a chain of flatMaps of arbitrary length. Let’s say 100,000. Let’s mock that up using a list:

List.range(0, 100000).foldLeft(
  IntReader(List(_)))(
    (a, e) => a.flatMap(xs => IntReader(_ => e :: xs)))

This builds a single IntReader whose run function throws a StackOverflowError as soon as it is applied. The reason is that each flatMap wraps the previous reader’s run in another call to apply that is not in tail position, so running the result nests 100,000 calls deep. The stack trace shows the same call to apply in an anonymous function, repeated over and over.
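To see what the fold actually builds, here is a minimal sketch of the same chain with only three elements, which is small enough to run safely. The names small and result are illustrative only.

```scala
final case class IntReader[A](run: Int => A) {
  def map[B](f: A => B): IntReader[B] =
    IntReader(i => f(run(i)))
  def flatMap[B](f: A => IntReader[B]): IntReader[B] =
    IntReader(i => f(run(i)).run(i))
}

// Same shape as the 100,000-element fold, but only three binds deep.
val small: IntReader[List[Int]] =
  List.range(0, 3).foldLeft(IntReader[List[Int]](i => List(i))) {
    (a, e) => a.flatMap(xs => IntReader(_ => e :: xs))
  }

// Running the reader with 7 evaluates the innermost reader first,
// then prepends 0, 1 and 2 on the way back out.
val result: List[Int] = small.run(7) // List(2, 1, 0, 7)
```

Each flatMap adds one more level of nested calls at run time, so the stack depth grows linearly with the length of the chain.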

CPS Transformation

The classical way of avoiding the call stack in this situation is to transform the program to continuation-passing style (CPS). The CPS-transformed version of our reader monad looks like this:

trait IntReader[A] {
  def apply[R](k: A => R, i: Int): R
  def map[B](f: A => B): IntReader[B] = new IntReader[B] {
    def apply[R](k: B => R, i: Int): R =
      IntReader.this(a => k(f(a)), i)
  }
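  def flatMap[B](f: A => IntReader[B]): IntReader[B] = new IntReader[B] {
    // Hand the continuation k straight to the next reader, so that
    // the call to its apply is itself in tail position.
    def apply[R](k: B => R, i: Int): R =
      IntReader.this(a => f(a)(k, i), i)
  }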
}

Instead of returning A directly from apply, we take a continuation k that receives the A. You can see how this would be an improvement: the calls to apply are now all in tail position, and so are the calls to the continuation k. Unfortunately, Scala has very limited tail call elimination. It can eliminate a tail call only if it is a recursive call to the current method. But note that apply above actually calls a different apply method: that of the enclosing IntReader instance. And since k is a different function from the apply method it’s called from, the call to the continuation cannot be eliminated either. So if we traverse our list with this CPS-transformed reader monad, we will still get a StackOverflowError.
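For concreteness, here is a small sketch of how a CPS reader of this kind might be constructed and run. The names CpsReader, point and ask are hypothetical helpers that do not appear in the original code, and this flatMap passes k straight through to the next reader.

```scala
trait CpsReader[A] { self =>
  def apply[R](k: A => R, i: Int): R

  def flatMap[B](f: A => CpsReader[B]): CpsReader[B] = new CpsReader[B] {
    // The next reader receives the caller's continuation unchanged.
    def apply[R](k: B => R, i: Int): R =
      self.apply[R]((a: A) => f(a).apply[R](k, i), i)
  }
}

object CpsReader {
  // Hypothetical helper: ignore the environment and return a constant.
  def point[A](a: A): CpsReader[A] = new CpsReader[A] {
    def apply[R](k: A => R, i: Int): R = k(a)
  }
  // Hypothetical helper: return the environment itself.
  val ask: CpsReader[Int] = new CpsReader[Int] {
    def apply[R](k: Int => R, i: Int): R = k(i)
  }
}

// Read the environment, add one, and run with the identity continuation.
val program: CpsReader[Int] = CpsReader.ask.flatMap(i => CpsReader.point(i + 1))
val answer: Int = program((x: Int) => x, 41) // 42
```

Every step here ends in a call to some other function, which is exactly the kind of tail call the JVM will not eliminate for us.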

Trampolining

What we must do is exchange stack for heap. The idea is simple. Instead of making a tail call, we return a data structure representing what to do next.

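sealed trait Trampoline[+A] {
  // final, so that scalac can turn the self-call into a loop
  @annotation.tailrec final def run: A = this match {
    case Done(a) => a
    case More(t) => t().run
  }
}
case class Done[A](a: A) extends Trampoline[A]
case class More[A](a: () => Trampoline[A]) extends Trampoline[A]

Note that the run method is tail recursive. It has to be declared final, because Scala only eliminates a self-recursive tail call in a method that cannot be overridden. We can now use our trampoline to turn mutual recursion into tail recursion (thanks, Rich):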

def even(n: Int): Trampoline[Boolean] = {
  if (n == 0) Done(true)
  else More(() => odd(n - 1))
}
def odd(n: Int): Trampoline[Boolean] = {
  if (n == 0) Done(false)
  else More(() => even(n - 1))
}

No matter how deep the mutual recursion, calling either of these methods simply returns a Trampoline that we can unwind tail-recursively with run:

scala> val b = odd(100000001).run
b: Boolean = true
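The Scala standard library ships a trampoline of this shape in scala.util.control.TailCalls, where done and tailcall play the roles of Done and More, and result plays the role of run. As a sketch, the same even/odd pair written against that API looks like this:

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

// tailcall suspends the next step instead of calling it directly.
def even(n: Int): TailRec[Boolean] =
  if (n == 0) done(true) else tailcall(odd(n - 1))

def odd(n: Int): TailRec[Boolean] =
  if (n == 0) done(false) else tailcall(even(n - 1))

// result drives the suspended steps in a loop.
val oddAnswer: Boolean = odd(100001).result
```

The recursion depth here is well beyond what direct mutual recursion would survive on a default JVM stack.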

Trampolines of Trampolines

Now let’s say we wanted to transform a binary recursion in the same way. For example, the (terribly inefficient) recursive function to find the nth Fibonacci number:

def fib(n: Int): Int =
  if (n < 2) n else fib(n - 1) + fib(n - 2)

There’s a bit of a problem here. If we change this to use our Trampoline, the result of fib will be Trampoline[Int]. So then how do we add two trampolines together? One way is to simply call run:

def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n) else More(() => Done(fib(n - 1).run + fib(n - 2).run))

But this defeats the purpose! The call to run is not in a tail position here, and so we’re back to getting stack overflows.

Another idea is to make Trampoline a monad, by just adding a flatMap method to it. Then we can just say:

for { x <- fib(n - 1)
      y <- fib(n - 2) } yield x + y

But with only Done and More to work with, there is no way of implementing flatMap without calling run:

def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
  More(() => f(this.run))

Delimited Continuations

The solution, as so often with continuations, is found in delimited control. We bake monadicity into the Trampoline data type, with an additional case. Again, instead of making a call to the continuation, we return a data structure representing what we’re doing currently together with what to do with the result:

case class Cont[A, B](a: Trampoline[A],
                      f: A => Trampoline[B]) extends Trampoline[B]

Note that the arguments to this constructor are exactly the arguments to flatMap. The idea is that we can now implement map and flatMap like this:

def map[B](f: A => B): Trampoline[B] =
  flatMap(a => More(() => Done(f(a))))
def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
  Cont(this, f)

The implementation of run becomes a tad more complicated now:

def run: A = {
  var cur: Trampoline[_] = this
  var stack: List[Any => Trampoline[A]] = List()
  var result: Option[A] = None
  while (result.isEmpty) {
    cur match {
      case Done(a) => stack match {
        case Nil => result = Some(a.asInstanceOf[A])
        case c :: cs => {
          cur = c(a)
          stack = cs
        }
      }
      case More(t) => cur = t()
      case Cont(a, f) => {
        cur = a
        stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
      }
    }
  }
  result.get
}

We’re essentially breaking out of Scala here and dropping into a Java-level loop. Firstly, a Cont has two parts: an intermediate computation whose type is not known, and a continuation for which we only know the return type. Secondly, we have to keep our own stack of continuations as we descend into the monadic binds. So we must cast, just as if we were working in a language without generics. Don’t worry, the continuation type matches the intermediate computation by construction. Lastly, I’m using my own while loop instead of relying on Scala to translate tail recursion into a loop for me.

It’s possible to write this code with better types, using existentials and heterogeneous lists (left as an exercise for the hardened type-level programmer). But this is pretty self-contained, and we can be confident that it’s well typed without Scala’s help. It’s also possible to use the Delimited Continuations compiler plugin (also an exercise for the reader) to hide the casts, but that plugin makes these exact same casts.
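One possible cast-free alternative, sketched here under my own assumptions rather than taken from the code above, is to make run tail recursive by reassociating nested binds to the right before stepping. This is roughly the approach of the standard library’s TailRec. The names Tramp, Pure, Suspend and Bind are hypothetical stand-ins for Trampoline, Done, More and Cont.

```scala
import scala.annotation.tailrec

sealed abstract class Tramp[A] {
  def flatMap[B](f: A => Tramp[B]): Tramp[B] = Bind(this, f)
  def map[B](f: A => B): Tramp[B] =
    flatMap(a => Suspend(() => Pure(f(a))))

  @tailrec final def run: A = this match {
    case Pure(a)    => a
    case Suspend(k) => k().run
    case Bind(sub, f) => sub match {
      case Pure(a)    => f(a).run
      case Suspend(k) => k().flatMap(f).run
      // (sub2 flatMap g) flatMap f  ==>  sub2 flatMap (x => g(x) flatMap f)
      case Bind(sub2, g) => sub2.flatMap(x => g(x).flatMap(f)).run
    }
  }
}
final case class Pure[A](a: A) extends Tramp[A]
final case class Suspend[A](k: () => Tramp[A]) extends Tramp[A]
final case class Bind[A, B](sub: Tramp[A], f: A => Tramp[B]) extends Tramp[B]

// 100,000 left-nested binds, each adding one to the running total.
val leftNested: Tramp[Int] =
  (1 to 100000).foldLeft(Pure(0): Tramp[Int])((t, _) => t.flatMap(x => Pure(x + 1)))
val total: Int = leftNested.run
```

The trade-off is that the pending continuations live in nested closures on the heap instead of in an explicit list, and the compiler checks the types that the loop version casts away.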

Here is the whole trampoline code again:

  sealed trait Trampoline[A] {
    def map[B](f: A => B): Trampoline[B] =
      flatMap(a => More(() => Done(f(a))))
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
      Cont(this, f)
    def run: A = {
      var cur: Trampoline[_] = this
      var stack: List[Any => Trampoline[A]] = List()
      var result: Option[A] = None
      while (result.isEmpty) {
        cur match {
          case Done(a) => stack match {
            case Nil => result = Some(a.asInstanceOf[A])
            case c :: cs => {
              cur = c(a)
              stack = cs
            }
          }
          case More(t) => cur = t()
          case Cont(a, f) => {
            cur = a
            stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
          }
        }
      }
      result.get
    }
  }
  case class Done[A](a: A) extends Trampoline[A]
  case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
  case class Cont[A, B](a: Trampoline[A],
                        f: A => Trampoline[B]) extends Trampoline[B]

Now we can write a binary-recursive Fibonacci function that uses constant stack:

def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n) else for {
    x <- fib(n - 1)
    y <- fib(n - 2)
  } yield (x + y)

Even with millions of recursive calls, we don’t overflow the stack:

scala> fib(40).run
res23: Int = 102334155
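One caveat worth sketching: in the for-comprehension above, fib(n - 1) is still an ordinary method call made while the structure is being built, so the left spine of the recursion uses JVM stack proportional to n. That is harmless for n = 40, but a recursion that goes 100,000 levels deep would overflow before run is ever called. Wrapping each recursive call in More defers it to the run loop. The sketch below repeats the trampoline so that it is self-contained; depth is a hypothetical example function.

```scala
sealed trait Trampoline[A] {
  def map[B](f: A => B): Trampoline[B] =
    flatMap(a => More(() => Done(f(a))))
  def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
    Cont(this, f)
  // Same loop as in the listing above: an explicit, heap-allocated stack.
  def run: A = {
    var cur: Trampoline[_] = this
    var stack: List[Any => Trampoline[A]] = List()
    var result: Option[A] = None
    while (result.isEmpty) {
      cur match {
        case Done(a) => stack match {
          case Nil => result = Some(a.asInstanceOf[A])
          case c :: cs =>
            cur = c(a)
            stack = cs
        }
        case More(t) => cur = t()
        case Cont(a, f) =>
          cur = a
          stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
      }
    }
    result.get
  }
}
case class Done[A](a: A) extends Trampoline[A]
case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
case class Cont[A, B](a: Trampoline[A],
                      f: A => Trampoline[B]) extends Trampoline[B]

// Both recursive calls are suspended, so building fib(n) is O(1).
def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n)
  else for {
    x <- More(() => fib(n - 1))
    y <- More(() => fib(n - 2))
  } yield x + y

// A deliberately deep, non-tail recursion: 1 + (1 + (1 + ...)).
def depth(n: Int): Trampoline[Int] =
  if (n == 0) Done(0) else More(() => depth(n - 1)).map(_ + 1)

val fib20: Int = fib(20).run
val deep: Int = depth(100000).run
```

In depth, the 100,000 pending additions end up in the explicit stack list inside run, so they cost heap rather than JVM stack.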

Trampolining other monads

We’re now ready to come back to our original problem, which was tail call elimination in arbitrary monads. Remember that original reader monad?

As long as there exists a monad transformer version of the monad in question, we can transform our Trampoline monad, resulting in a new tail-recursive monad. For example, type IntReader[A] = Int => A is a monad, but IntReaderT[M[_], A] = Int => M[A] is also a monad, for any monad M, including Trampoline.

To illustrate this, I will use the Kleisli monad transformer from scalaz. Here, Kleisli[M, A, B] is isomorphic to the type A => M[B], and so our trampolined IntReader[A] will be written Kleisli[Trampoline, Int, A].

This lets us traverse, in our reader monad, a list with millions of elements:

val x = List.range(0, 1000000).
  foldLeft[Kleisli[Trampoline, Int, List[Int]]](
    kleisli(i => Done(List(i))))(
      (a, e) => kleisli(r => for {
        x <- More(() => a.apply(r))
        y <- Done(e + r :: x)
      } yield y))

And it’s all in constant stack:

scala> x(0).run
res28: List[Int] = List(999999, 999998, 999997, 999996, 999995, ...

This can be employed with any monad transformer, e.g. StateT, WriterT, Iteratees, etc.
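For readers without scalaz at hand, the same idea can be sketched with a hand-rolled transformer. This is only an illustration: IntReaderT here is a hypothetical class fixed to Int and to the standard library’s TailRec, which stands in for both Kleisli and the Trampoline above.

```scala
import scala.util.control.TailCalls.{TailRec, done, tailcall}

// Int => M[A], with M fixed to TailRec.
final case class IntReaderT[A](run: Int => TailRec[A]) {
  def flatMap[B](f: A => IntReaderT[B]): IntReaderT[B] =
    // tailcall suspends the call to the previous reader, so nothing
    // recurses while the result for this step is being built.
    IntReaderT(i => tailcall(run(i)).flatMap(a => f(a).run(i)))
  def map[B](f: A => B): IntReaderT[B] =
    IntReaderT(i => tailcall(run(i)).map(f))
}

// The same fold as before, 100,000 binds long.
val chain: IntReaderT[List[Int]] =
  List.range(0, 100000).foldLeft(IntReaderT[List[Int]](i => done(List(i)))) {
    (acc, e) => acc.flatMap(xs => IntReaderT(r => done((e + r) :: xs)))
  }

// Supplying the environment yields a TailRec; result runs it in a loop.
val out: List[Int] = chain.run(0).result
```

The reader itself is unchanged in spirit. Only the result type of run moved from A to a trampolined A.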

27 thoughts on “Tail Call Elimination in Scala Monads”

  1. Very interesting, although I don’t grasp the details, yet. Is this available in scalaz? Can this handle:

    List.fill(10000)(1).traverse[({type S[x]=State[Int,x]})#S, Int](i => state[Int, Int](s => (s+1, i))) ~> 0

  2. Looking forward to this being in scalaz. For the tail recursive case, heap usage (for the explicit stack you are managing) will actually be constant. It’s only when you use this to write a non-tail recursive function like fib will you actually grow that stack.

    Also, I’d use ArrayStack [1]. It’s actually mutable and faster than using a var List [2].

    [1]: http://www.scala-lang.org/api/current/scala/collection/mutable/ArrayStack.html.
    [2]: http://www.drmaciver.com/2008/08/new-collections-in-scala-272/

  3. Sweet! This solves a problem I’ve been having with Iteratees not being directly tail recursive. Now I can run

    (length[Unit, Int, Trampoline] >>== enumStream(Stream.range(1, 100000, 1))) apply (_ => Done(-1)) run

    without getting a SOE! Trying to do it with just Id before always gave a SOE after about 6.5k elements.

    Now I just need to figure out how to combine it with IO.

    • Josh,

      There are several improvements in scalaz.Free.Trampoline over scala.util.control.TailCalls.TailRec:

      1. This trampoline is a monad. That means you can combine two trampolines into one. It also means that you can turn any call into a tail-call.
      2. The monad is codense. This means that left-associated binds are reassociated to the right.
      3. A method resume is provided to step through the trampoline if needed.
      4. A method zipWith is provided to run two trampolines cooperatively.
      5. A method mapSuspension is provided to change the type of the suspension from Function0 to something else.

      I’d be happy to discuss this further with you.

  4. This obviously uses constant stack, but does it use constant heap? If so, why so, and if not, is this not just as bad as a solution that uses the stack (albeit with a few more seconds of runtime before you run out of memory?)

    • Yes, it does use constant heap. Because each step sits behind a lambda, it is not constructed on the heap until the previous step runs. And each step can be garbage collected as soon as the trampoline loop goes on to the next iteration. So it does a lot of heap allocation, and it makes JIT impossible, but it does really use constant memory (both stack and heap).

      You can try this out by running the Ackermann function from the scalaz.examples project. It will run for a very long time without growing the stack or the heap.

  5. It’s not our fault:

    Compare the above binary-recursive Fibonacci function in Scala:

    fib(40);

    with the ocaml – implementation

    let fib n =
      let rec fib_aux n a b =
        match n with
        | 0 -> a
        | _ -> fib_aux (n-1) b (a+b)
      in
      fib_aux n 0 1;;

    fib 10000;;

    see any difference – except for the numbers?
