Towards an Effect System in Scala, Part 2: IO Monad

In the previous post, we looked at the ST monad and how we can use it to encapsulate in-place mutation as first-class referentially transparent expressions. ST gives us mutable references and arrays, with the following guarantees:

  1. Mutations of an object are guaranteed to run in sequence.
  2. The holder of a mutable object is guaranteed to hold the only reference to it.
  3. Once a sequence of mutations of an object terminates, that object is never mutated again.

These invariants are enforced by the type system. I think that’s pretty cool, and sort of stands as a testament to the power of Scala’s type system.

This is much more than “delay side-effects until the last possible moment”. Remember, it is always safe to call runST on an ST action, anywhere in your program. As long as the call typechecks, it will have no observable side-effects with regard to STRefs and STArrays.

OK, so far we only have these guarantees for references and arrays, but that’s really not bad at all. We can eliminate an enormous class of bugs that have to do with shared mutable state, by guaranteeing that mutable state is never shared. Of course, you can still mutate other objects to your heart’s content (ones that are neither STRefs nor STArrays). But imagine for a second if Scala were modified in such a way that the var keyword actually constructed an STRef, and the only arrays provided by the library were of type STArray. Wouldn’t that be something? Wouldn’t that basically make Scala a purely functional language?

Well, no. There’s I/O. While ST gives us guarantees that mutable memory is never shared, it says nothing about reading/writing files, throwing exceptions, opening network sockets, database connections, etc.

The IO Data Type

We’re going to represent I/O actions as state transition functions, just like ST actions. Remember that ST is essentially a type like this:

type ST[S, A] = World[S] => (World[S], A)

The IO data type is very similar, except that we fix the world-state to be of a specific type:

type IO[A] = ST[RealWorld, A]

RealWorld is totally abstract. It’s an uninhabited type (there are no values of type RealWorld). And we will understand a value of type World[RealWorld] to represent the current state of the entire universe. Sequencing is guaranteed, just like with ST, since the IO monad has to pass the world state in order to execute the next action.

In Scalaz, it’s not possible (without cheating) to create a value of type World[RealWorld]. There are no values of this type. So how do you run an IO action? Well, the IO[A] data type has the following method defined:

def unsafePerformIO: A = this(null)._2

This is actually cheating a little bit, because we’re faking the value of type World[RealWorld] by passing null. We’re about to potentially destroy the universe anyway, so this is OK just once. Besides, there’s a reason why this method has “unsafe” in the name. You only ever want to call this method once. The idea is that you construct your program with as large a purely functional core as possible, and an outer shell written in the IO monad. Then at the “end of the universe”, you call unsafePerformIO:

import scalaz._; import Scalaz._; import scalaz.effects._

def main(args: Array[String]): Unit =
  myProgram(ImmutableArray.fromArray(args)).unsafePerformIO

def myProgram(args: ImmutableArray[String]): IO[Unit] =
  error("Your IO program goes here.")

Again, imagine if Scala were modified in such a way that instead of looking for def main(args: Array[String]): Unit, it would look for def main(args: ImmutableArray[String]): IO[Unit].
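
To make the world-passing idea in this section concrete, here is a minimal self-contained sketch. It is not the Scalaz implementation: World, RealWorld, IO and the io helper are simplified stand-ins written for illustration, and the world token carries no real state. What it shows is that flatMap cannot run the second action until the first has handed back a world.

```scala
object WorldPassingSketch {
  // Hypothetical stand-ins; the real Scalaz definitions are not shown in this post.
  sealed trait RealWorld            // never instantiated
  final class World[S] private ()   // opaque token with a private constructor

  // An IO action is a state transition over the world token.
  final case class IO[A](step: World[RealWorld] => (World[RealWorld], A)) {
    def flatMap[B](f: A => IO[B]): IO[B] =
      IO { w0 =>
        val (w1, a) = step(w0)   // run this action first ...
        f(a).step(w1)            // ... then pass the resulting world to the next one
      }

    def map[B](f: A => B): IO[B] =
      flatMap(a => IO(w => (w, f(a))))

    // Fakes the world with null, as the post describes; call it once, at the edge.
    def unsafePerformIO(): A = step(null)._2
  }

  // Suspends an arbitrary side effect as an IO action (hypothetical helper).
  def io[A](a: => A): IO[A] = IO(w => (w, a))
}
```

Building a value with io or flatMap performs no effects; they happen only when unsafePerformIO() is called, in the order in which the actions were composed.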

Benefits and Weaknesses of the IO Monad

Because the World type argument is fixed in the definition of IO, we don’t have the same guarantees that we did with ST. Basically we can’t guarantee that the world state will never escape from unsafePerformIO. But we do have some other nice benefits.

For example, sequencing is still guaranteed, so no part of an action that depends on another will ever run before its dependency. Ordering can be a problem if IO[A] is modeled as simply () => A. Also, IO actions are first-class values, so they are freely composable and reusable.

IO With the Scalaz Library

Scalaz includes a bunch of IO combinators for manipulating standard input and output, throwing/catching errors, mutating variables, etc. For example, here are some combinators for standard I/O:

def getChar: IO[Char]
def putChar(c: Char): IO[Unit]
def putStr(s: String): IO[Unit]
def putStrLn(s: String): IO[Unit]
def readLn: IO[String]
def putOut[A](a: A): IO[Unit]

Composing these into programs is done monadically. So we can use for-comprehensions. Here’s a program that reads a line of input and prints it out again:

def program: IO[Unit] = for {
  line <- readLn
  _    <- putStrLn(line)
} yield ()

Or equivalently:

def program: IO[Unit] = readLn flatMap putStrLn

And if we wanted to write another program that re-uses our existing program, we can. Here’s a program that runs our previous program forever:

def program2: IO[Unit] = program |+| program2

IO[Unit] is an instance of Monoid, so we can re-use the monoid addition function |+|. Because everything is pure, we can concatenate programs just as easily as we concatenate Strings.
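
One plausible reading of that monoid, sketched without Scalaz: the identity is an action that does nothing, and |+| runs the left action and then the right one. To keep the sketch short, IO is modeled here as a plain suspended thunk, which is a simplification and not the world-passing representation described earlier. The names io, zero, IoUnitOps and replicate are hypothetical helpers for this illustration only.

```scala
object IoMonoidSketch {
  // Simplified stand-in for IO: a suspended effect (assumption, not Scalaz's IO).
  final class IO[A](thunk: () => A) {
    def unsafePerformIO(): A = thunk()
    def flatMap[B](f: A => IO[B]): IO[B] =
      new IO(() => f(thunk()).unsafePerformIO())
    def map[B](f: A => B): IO[B] =
      new IO(() => f(thunk()))
  }

  def io[A](a: => A): IO[A] = new IO(() => a)

  // Identity element: an action that does nothing.
  val zero: IO[Unit] = io(())

  implicit class IoUnitOps(left: IO[Unit]) {
    // Run left, then right. The right operand is by-name, so a recursive
    // definition such as program2 does not loop while it is being built.
    def |+|(right: => IO[Unit]): IO[Unit] = left.flatMap(_ => right)
  }

  // Concatenate n copies of the same program, much like repeating a String.
  def replicate(n: Int, p: IO[Unit]): IO[Unit] =
    if (n <= 0) zero else p |+| replicate(n - 1, p)
}
```

Because an IO[Unit] is just a value, the same program can appear any number of times in a larger one, and nothing runs until unsafePerformIO() is called on the result.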

It’s also important to note that we’ve gained type safety. If you try to do this, you will get a type error:

scala> (readLn flatMap putStrLn) |+| System.exit(0)
<console>:17: error: type mismatch;
found   : Unit
required: scalaz.effects.IO[Unit]

Conclusion

We can gain a lot of static safety by separating values that produce I/O effects from values that have no effects, differentiating them via the type system. We also gain modularity by treating I/O actions as pure, compositional, first-class values that we can freely reuse in a completely deterministic way. Instead of running I/O effects everywhere in our code, we build programs through the IO DSL, compose them like ordinary values, and then run them with unsafePerformIO as part of our main.

 

Pascal’s Wager and the Problem of Computability

I was on a cable car yesterday morning and I was intellectually tickled by an advertisement posted on one of the walls in the car. It was an ad for a Unitarian church, and in it they quoted Blaise Pascal:

It is incomprehensible that God should exist, and it is incomprehensible that He should not exist.

Pascal was a philosopher and mathematician, and he was trying to understand the question of whether or not one should believe in the existence of a God (specifically the Christian god, but that’s not important). He concluded that since the benefits of belief are supposedly enormous (or infinite), and one cannot actually know, one should wager in favor of there being a God, since then the benefit of being right is maximized and the cost of being wrong is minimized.

But Pascal made an error in his premises, which touches on computability theory. He sets out assuming that the statement “God exists” is either true or false. This is an unwarranted premise. Pascal was a rationalist, so he assumes the dichotomy of truth, that every statement is either true or false. But any programmer can tell you that this isn’t the case.

You see, Aristotle understood that not every statement is either true or false. The law of excluded middle says that a thing either is or is not. And truth or falsehood is an attribute only of statements that refer to things which are. For this reason, some statements are simply absurd, or arbitrary. Such statements cannot be examined for truth or falsehood. They can only be dismissed out of hand, as if nothing had been said. Because, in a very strict sense, nothing really has.

In computation, this is equivalent to the fact that not every program has an answer. Some programs simply bottom (crash or hang). There are types that are nonsensical and have no implementation except for programs whose answer is bottom. And there are questions that have no answer because they are nonsensical, and statements that are neither true nor false because they do not refer to any attributes or configurations of things which exist.

But it is absurd and impossible to suppose that the unknowable and indeterminate should contain and determine.

- Aristotle

Not only is it not right, it’s not even wrong!

- Wolfgang Pauli

On two occasions I have been asked, “Pray, Mr. Babbage, if you put into the machine wrong figures, will the right answers come out?” … I am not able rightly to apprehend the kind of confusion of ideas that could provoke such a question.

- Charles Babbage

Tail Call Elimination in Scala Monads

Consider a simple reader monad:

case class IntReader[A](run: Int => A) {
  def map[B](f: A => B): IntReader[B] =
    IntReader(i => f(run(i)))
  def flatMap[B](f: A => IntReader[B]): IntReader[B] =
    IntReader(i => f(run(i)).run(i))
}

Now say we have a chain of flatMaps of arbitrary length. Let’s say 100,000. Let’s mock that up using a list:

List.range(0, 100000).foldLeft(
  IntReader(List(_)))(
    (a, e) => a.flatMap(xs => IntReader(_ => e :: xs)))

Building this value succeeds, but the result is a single function that crashes with a StackOverflowError when it is run. The reason is that each flatMap calls the run function of the reader it was built from, and that call is not in tail position. So the call stack repeatedly shows a call to apply in an anonymous function, once per flatMap in the chain.
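
A rough way to observe this nesting without crashing anything is to measure how deep the stack is when the innermost reader finally runs. The sketch below repeats the reader so that it is self-contained; stackDepth and depthAtBottom are hypothetical helpers introduced here, and the exact frame counts depend on the compiler and the JVM, so only the trend is meaningful.

```scala
object StackDepthSketch {
  // Same reader monad as above, repeated so that the sketch stands alone.
  case class IntReader[A](run: Int => A) {
    def map[B](f: A => B): IntReader[B] =
      IntReader(i => f(run(i)))
    def flatMap[B](f: A => IntReader[B]): IntReader[B] =
      IntReader(i => f(run(i)).run(i))
  }

  // Hypothetical helper: number of frames currently on this thread's stack.
  def stackDepth(): Int = Thread.currentThread.getStackTrace.length

  // Build a left-nested chain of n flatMaps whose innermost reader reports
  // the stack depth at the moment it runs, then run the whole chain.
  def depthAtBottom(n: Int): Int =
    List.range(0, n)
      .foldLeft(IntReader(_ => stackDepth()))(
        (a, _) => a.flatMap(d => IntReader(_ => d)))
      .run(0)
}
```

Comparing depthAtBottom(10) with depthAtBottom(100) should show the depth growing with the length of the chain, which is why a chain of 100,000 runs out of stack.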

CPS Transformation

The classical way of avoiding the call stack in this situation is to transform the program to continuation-passing style (CPS). The CPS-transformed version of our reader monad looks like this:

trait IntReader[A] {
  def apply[R](k: A => R, i: Int): R
  def map[B](f: A => B): IntReader[B] = new IntReader[B] {
    def apply[R](k: B => R, i: Int): R = 
      IntReader.this(a => k(f(a)), i)
  }
  def flatMap[B](f: A => IntReader[B]): IntReader[B] = new IntReader[B] {
    def apply[R](k: B => R, i: Int): R =
      IntReader.this(a => f(a)(k, i), i)
  }
}

Instead of returning A directly from apply, we take a continuation k that receives the A. You can see how this would be an improvement. The calls to apply are now all in tail position, and so are the calls to the continuation k. Unfortunately, Scala has very limited tail call elimination: it can eliminate a tail call only if it is a recursive call to the current method. But note that apply above actually calls a different apply method: that of the enclosing IntReader instance. And since k is different from the apply method of the function it’s called from, the call to the continuation cannot be eliminated either. So if we traverse our list with this CPS-transformed reader monad, we will still get a StackOverflowError.

Trampolining

What we must do is exchange stack for heap. The idea is simple. Instead of making a tail call, we return a data structure representing what to do next.

sealed trait Trampoline[+A] {
  // final, so that the compiler can turn the self-call in tail position into a loop
  final def run: A = this match {
    case Done(a) => a
    case More(t) => t().run
  }
}
case class Done[A](a: A) extends Trampoline[A]
case class More[A](a: () => Trampoline[A]) extends Trampoline[A]

Note that the run method is tail recursive. We can now use our trampoline to turn mutual recursion into tail recursion (thanks, Rich):

def even(n: Int): Trampoline[Boolean] = {
  if (n == 0) Done(true)
  else More(() => odd(n - 1))
}
def odd(n: Int): Trampoline[Boolean] = {
  if (n == 0) Done(false)
  else More(() => even(n - 1))
}

No matter how deep the mutual recursion, calling either of these methods simply returns a Trampoline that we can unwind tail-recursively with run:

scala> val b = odd(100000001).run
b: Boolean = true

Trampolines of Trampolines

Now let’s say we wanted to transform a binary recursion in the same way. For example, the (terribly inefficient) recursive function to find the nth Fibonacci number:

def fib(n: Int): Int =
  if (n < 2) n else fib(n - 1) + fib(n - 2)

There’s a bit of a problem here. If we change this to use our Trampoline, the result of fib will be Trampoline[Int]. So then how do we add two trampolines together? One way is to simply call run:

def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n) else More(() => fib(n - 1).run + fib(n - 2).run)

But this defeats the purpose! The call to run is not in a tail position here, and so we’re back to getting stack overflows.

Another idea is to make Trampoline a monad, by just adding a flatMap method to it. Then we can just say:

for { x <- fib(n - 1)
      y <- fib(n - 2) } yield x + y

But there is no way of implementing flatMap without calling run.

def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
  More(() => f(this.run).run)

Delimited Continuations

The solution, as ever so often with continuations, is found in delimited control. We bake monadicity into the Trampoline data type, with an additional case. Again, instead of making a call to the continuation, we return a data structure representing what we’re doing currently together with what to do with the result:

case class Cont[A, B](a: Trampoline[A],
                      f: A => Trampoline[B]) extends Trampoline[B]

Note that the arguments to this constructor are exactly the arguments to flatMap. The idea is that we can now implement map and flatMap like this:

def map[B](f: A => B): Trampoline[B] =
  Cont(this, a => More(() => Done(f(a))))
def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
  Cont(this, f)

The implementation of run becomes a tad more complicated now:

def run: A = {
  var cur: Trampoline[_] = this
  var stack: List[Any => Trampoline[A]] = List()
  var result: Option[A] = None
  while (result.isEmpty) {
    cur match {
      case Done(a) => stack match {
        case Nil => result = Some(a.asInstanceOf[A])
        case c :: cs => { 
          cur = c(a)
          stack = cs
        }
      }
      case More(t) => cur = t()
      case Cont(a, f) => {
        cur = a
        stack = f.asInstanceOf[Any => Trampoline[A]] :: stack 
      }
    }
  }
  result.get
}

We’re essentially breaking out of Scala here and dropping into a Java-level loop. Firstly, a Cont has two parts: an intermediate computation whose type is not known, and a continuation for which we only know the return type. Secondly, we have to keep our own stack of continuations as we descend into the monadic binds. So we must cast, just as if we were working in a language without generics. Don’t worry, the continuation type matches the intermediate computation by construction. Lastly, I’m using my own while loop instead of relying on Scala to translate tail recursion into a loop for me.

It’s possible to write this code with better types, using existentials and heterogeneous lists (left as an exercise for the hardened type-level programmer). But this is pretty self-contained, and we can be confident that it’s well typed without Scala’s help. It’s also possible to use the Delimited Continuations compiler plugin (also an exercise for the reader) to hide the casts, but that plugin makes these exact same casts.

Here is the whole trampoline code again:

  sealed trait Trampoline[A] {
    def map[B](f: A => B): Trampoline[B] =
      flatMap(a => More(() => Done(f(a))))
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
      Cont(this, f)
    def run: A = {
      var cur: Trampoline[_] = this
      var stack: List[Any => Trampoline[A]] = List()
      var result: Option[A] = None
      while (result.isEmpty) {
        cur match {
          case Done(a) => stack match {
            case Nil => result = Some(a.asInstanceOf[A])
            case c :: cs => { 
              cur = c(a)
              stack = cs
            }
          }
          case More(t) => cur = t()
          case Cont(a, f) => {
            cur = a
            stack = f.asInstanceOf[Any => Trampoline[A]] :: stack 
          }
        }
      }
      result.get
    }
  }
  case class Done[A](a: A) extends Trampoline[A]
  case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
  case class Cont[A, B](a: Trampoline[A],
                        f: A => Trampoline[B]) extends Trampoline[B]

Now we can write a binary-recursive Fibonacci function that uses constant stack:

def fib(n: Int): Trampoline[Int] =
  if (n < 2) Done(n) else for {
    x <- fib(n - 1)
    y <- fib(n - 2)
  } yield (x + y)

Even with millions of recursive calls, we don’t overflow the stack:

scala> fib(40).run
res23: Int = 102334155

Trampolining Other Monads

We’re now ready to come back to our original problem, which was tail call elimination in arbitrary monads. Remember that original reader monad?

As long as there exists a monad transformer version of the monad in question, we can transform our Trampoline monad, resulting in a new tail-recursive monad. For example, type IntReader[A] = Int => A is a monad, but IntReaderT[M[_], A] = Int => M[A] is also a monad, for any monad M, including Trampoline.

To illustrate this, I will use the Kleisli monad transformer from scalaz. Here, Kleisli[M, A, B] is isomorphic to the type A => M[B], and so our trampolined IntReader[A] will be written Kleisli[Trampoline, Int, A].

This lets us traverse, in our reader monad, a list with millions of elements:

val x = List.range(0, 1000000).
  foldLeft[Kleisli[Trampoline, Int, List[Int]]](
    kleisli(i => Done(List(i))))(
      (a, e) => kleisli(r => for {
        x <- More(() => a.apply(r))
        y <- Done(e + r :: x)
    } yield y))

And it’s all in constant stack:

scala> x(0).run
res28: List[Int] = List(999999, 999998, 999997, 999996, 999995, ...

This can be employed with any monad transformer, e.g. StateT, WriterT, Iteratees, etc.
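
For readers who want to try the idea without Scalaz, here is a self-contained sketch of the same construction. ReaderT below is a hand-written stand-in specialised to an Int environment, not Scalaz's Kleisli, and collect is a hypothetical helper that mirrors the fold above. The Trampoline is the one from this post, repeated in compact form.

```scala
object TrampolinedReaderSketch {
  sealed trait Trampoline[A] {
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] = Cont(this, f)
    def map[B](f: A => B): Trampoline[B] =
      flatMap(a => More(() => Done(f(a))))

    // Interpreter loop: continuations are kept on a heap-allocated list
    // instead of the call stack. The casts are safe by construction.
    def run: A = {
      var cur: Trampoline[_] = this
      var stack: List[Any => Trampoline[_]] = Nil
      var result: Option[A] = None
      while (result.isEmpty) {
        cur match {
          case Done(a) => stack match {
            case Nil => result = Some(a.asInstanceOf[A])
            case c :: cs =>
              cur = c(a)
              stack = cs
          }
          case More(t) => cur = t()
          case Cont(a, f) =>
            cur = a
            stack = f.asInstanceOf[Any => Trampoline[_]] :: stack
        }
      }
      result.get
    }
  }
  case class Done[A](a: A) extends Trampoline[A]
  case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
  case class Cont[A, B](a: Trampoline[A],
                        f: A => Trampoline[B]) extends Trampoline[B]

  // Reader over Trampoline: Int => Trampoline[A] (stand-in for Kleisli[Trampoline, Int, A]).
  final case class ReaderT[A](run: Int => Trampoline[A]) {
    // Wrapping run(i) in More defers the call to the previous reader, so
    // running a long chain only ever returns data to the interpreter loop.
    def flatMap[B](f: A => ReaderT[B]): ReaderT[B] =
      ReaderT[B](i => More(() => run(i)).flatMap(a => f(a).run(i)))
    def map[B](f: A => B): ReaderT[B] =
      ReaderT[B](i => More(() => run(i)).map(f))
  }

  // Left-nested chain of n flatMaps, each prepending e + r to the list.
  def collect(n: Int): ReaderT[List[Int]] =
    List.range(0, n).foldLeft[ReaderT[List[Int]]](
      ReaderT[List[Int]](r => Done(List(r))))(
        (acc, e) => acc.flatMap(xs =>
          ReaderT[List[Int]](r => Done((e + r) :: xs))))
}
```

Running collect(n) takes two steps: apply the reader to an environment, then unwind the resulting trampoline, as in collect(3).run(10).run. The design choice that matters is the More wrapper in flatMap; without it, each reader would call the previous one directly and the stack would grow with the length of the chain again.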