Scala, Folds, and Universal Quantification

Church encoding of data types is a fun exercise. Something that blew my mind while learning functional programming was that a data structure is completely defined by its catamorphism. So foldr, for example, is not just another function over lists: the list data type is the foldr function.
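To make that claim a little more concrete, here is a small sketch using the standard library’s List.foldRight. The object name FoldIsList is made up for illustration. Handing foldRight the list constructors gives back the same list, and handing it anything else gives some other function over lists:

```scala
// FoldIsList is a hypothetical name used only for this illustration.
object FoldIsList {
  val xs: List[Int] = List(1, 2, 3)

  // Folding with the constructors (Nil and ::) rebuilds the list unchanged.
  val rebuilt: List[Int] = xs.foldRight(List.empty[Int])(_ :: _)

  // Swapping in other arguments gives other list functions.
  val sum: Int = xs.foldRight(0)(_ + _)
  val length: Int = xs.foldRight(0)((_, n) => n + 1)
}
```

In that sense the fold already contains everything there is to know about the list.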

In Haskell, we can encode pairs as this function:

pair x y f = f x y

A pair is then a higher-order function. It takes a function with two arguments and yields the result of applying that function to the two elements with which we constructed the pair.

Doing this in Scala is not quite as pretty. Here’s an attempt:

def pair[A,B,C](a: A, b: B) = (f: (A,B) => C) => f(a,b)

Seems to work:

scala> pair(10,20)(_ + _)
res1: Int = 30

However, that C type parameter is going to give us problems if we try this:

scala> val x = pair(10,20)
x: ((Int, Int) => Nothing) => Nothing = <function>

scala> x(_ + _)
<console>:12: error: type mismatch;
 found   : Int
 required: Nothing
       x(_ + _)
           ^

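Oops. It looks like Scala needs all of the type parameters to a function up front. C is a type parameter of the method pair, so it has to be chosen when pair is called; with nothing to constrain it, the compiler picks Nothing, and the function value that comes back is stuck with that choice. In other words, it seems like its type system is unable to unify the types (∀ a. a -> (∀ b. b -> b)) and (∀ a b. a -> b -> b). Is this accurate? If you know the answer, please comment below.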
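One way to see that C really is fixed at the call to pair is to supply it by hand. This is only a sketch: the object name ExplicitC is made up, and the pair definition is the same one as above.

```scala
// ExplicitC is a hypothetical name used only for this illustration.
object ExplicitC {
  // Same definition as above: C is a type parameter of the method pair.
  def pair[A, B, C](a: A, b: B) = (f: (A, B) => C) => f(a, b)

  // Choosing C explicitly at the call site avoids the inferred Nothing.
  val x = pair[Int, Int, Int](10, 20)
  val sum: Int = x(_ + _)

  // x is now monomorphic in its result type: passing it a function that
  // returns a String, for instance, would not typecheck.
}
```

That gets the addition working, but x can now only ever produce an Int, which is not what we want from a pair.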

However, we can be saved with anonymous type syntax, which lets apply take its own type parameter C (and which, I’m told, unfortunately makes use of reflection):

def pair[A,B](a: A, b: B) = new { def apply[C](f: (A,B) => C) = f(a,b) }

Now this works as expected:

scala> val x = pair(1,2)
x: java.lang.Object{def apply[C]((Int, Int) => C): C} = $anon$1@ad8dbc

scala> x(_ + _)
res13: Int = 3
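If the reflection bothers you, one possible alternative is to give the type a name. The sketch below assumes a trait with a polymorphic apply method; the names Pair and ChurchPair are my own and not part of any library.

```scala
// Pair and ChurchPair are hypothetical names for this sketch.
object ChurchPair {
  // The quantifier over C lives on the method, so it is chosen per call.
  trait Pair[A, B] {
    def apply[C](f: (A, B) => C): C
  }

  def pair[A, B](a: A, b: B): Pair[A, B] = new Pair[A, B] {
    def apply[C](f: (A, B) => C): C = f(a, b)
  }

  val x: Pair[Int, Int] = pair(1, 2)

  // The same x can be eliminated at different result types.
  val sum: Int = x(_ + _)
  val shown: String = x((a, b) => s"$a and $b")
  val first: Int = x((a, _) => a)
}
```

The cost is an extra named trait, but calls to apply are ordinary method calls rather than reflective ones.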

Sweet! Let’s do this for lists. Here’s the Haskell.

nil p z = z
cons x xs p z = p x (xs p z)

Testing this in GHCi:

Prelude> let xs = cons 1 (cons 2 (cons 3 nil))
Prelude> xs (+) 0
6

No sweat. Now, let’s do this in Scala. This is the best I’ve come up with:

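// nil p z = z: ignore the combining function and return the seed
def nil =
  new { def apply[A](a: A) =
    new { def apply[B](b: B) = b }}
// cons x xs p z = p x (xs p z); note that xs must be an ordinary function value
def cons[A](x: A) =
  new { def apply[B](xs: ((A => B => B) => B => B)) =
    (f: A => B => B) => (b:B) => f(x)(xs(f)(b)) }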

This doesn’t completely work:

scala> val xs = cons(1)(cons(2)(cons(3)(nil)))
<console>:20: error: type mismatch;
 found   : java.lang.Object{def apply[A](A): java.lang.Object{def apply[B](B): B}}
 required: ((Int) => (?) => ?) => (?) => ?
       val xs = cons(1)(cons(2)(cons(3)(nil)))
                                        ^

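Scala is not able to unify those two types: cons expects an ordinary function value, but nil is an object with a polymorphic apply method, and the compiler will not turn one into the other on its own. Let’s help it along: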

def objToFun[A,B](o: {def apply[X](a: X): {def apply[Y](b: Y): Y}}) = (a: A) => (b: B) => o(a)(b)
def nilp[A,B] = objToFun[A => B => B, B](nil)

With a little annotation, we can now create a list that has the right type:

scala> val xs = cons(3)(cons(2)(cons(1)(nilp[Int,Int])))
xs: ((Int) => (Int) => Int) => (Int) => Int = <function>

scala> xs(x => y => x + y)(0)
res44: Int = 6

Not terrible, but can we do better? Can this be done without the anonymous type syntax? Is using that syntax necessarily bad? Why? Could we modify Scala to get a nicer notation for anonymous function types? Is rank-2 polymorphism required for that, or can we convince the compiler to move universal quantifiers to the left of a single arrow?
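As a partial answer to the first two questions, here is a sketch that avoids the anonymous type syntax by naming the type of the fold. It assumes a trait with a polymorphic method; CList, ChurchList and foldr are names I made up for the sketch. Because cons takes an argument whose foldr is still quantified over B, the trait is, as far as I can tell, standing in for the rank-2 type.

```scala
// CList, ChurchList and foldr are hypothetical names for this sketch.
object ChurchList {
  // A list is whatever can run a right fold for every result type B.
  trait CList[A] {
    def foldr[B](f: A => B => B)(z: B): B
  }

  // nil p z = z
  def nil[A]: CList[A] = new CList[A] {
    def foldr[B](f: A => B => B)(z: B): B = z
  }

  // cons x xs p z = p x (xs p z)
  def cons[A](x: A)(xs: CList[A]): CList[A] = new CList[A] {
    def foldr[B](f: A => B => B)(z: B): B = f(x)(xs.foldr(f)(z))
  }

  val example: CList[Int] = cons(1)(cons(2)(cons(3)(nil[Int])))

  // B is passed explicitly so the curried lambdas need no annotations.
  val sum: Int = example.foldr[Int](x => y => x + y)(0)
  val shown: String = example.foldr[String](x => s => x.toString + s)("")
  val asList: List[Int] = example.foldr[List[Int]](x => t => x :: t)(Nil)
}
```

The same list is folded at three different result types here, which the function-valued encoding above cannot do once B has been fixed. The notation is still heavier than the two lines of Haskell, so the question about nicer syntax stands.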

11 thoughts on “Scala, Folds, and Universal Quantification”

  1. Try this:

    def pair[A, B](x: A, y: B)(f: (A, B) => _) = f(x, y)

    There is no need to actually name the C type in your example since it is exclusively used to determine the return type.

  2. How about this?

    def pair[A,B](a:A,b:B)(f: (A,B) => C forSome{type C}) = f(a,b)

    val x = pair(10,20)

    x(_ + _) //30

  3. I’ve managed to avoid the structural type for pairs and cons, using a new trait “Closure”. See here:

    http://paste.pocoo.org/show/199307

    However, I’ve not figured out how to define the empty list so that the types are unified. If you can get that working, I’ll buy you a beer.

  4. http://gist.github.com/360425

    Usage looks like:


    val xs = cons(3)(cons(2)(cons(1)( nil )))
    val str = xs( (x: Int) => pure( (y: String) => x + y) )("0")
    val num = xs( (x: Int) => pure( (y: Int) => x + y) )(0)

    Printing from the interpreter:

    scala> str
    res0: java.lang.String = 3210

    scala> num
    res1: Int = 6
