How curried functions and partially-applied functions compile in Scala

This morning I was curious about how Scala curried functions and partially-applied functions are really compiled at a bytecode level.

Before that, I wrote a post, Higher order functions are the Haskell experience, which is also implicitly about curried functions. It got me thinking about Scala, in particular why we might use one function syntax versus another, i.e., why would I use this syntax:

 


// one param list
def max1(a: Int, b: Int) = if (a > b) a else b

and not this syntax for everything:

// two param lists
def max2(a: Int)(b: Int) = if (a > b) a else b
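
At the source level, the two forms differ mainly in how you call them and how you partially apply them. Here's a minimal sketch of that difference; the object name ParamListDemo and the val names are made up for illustration:

```scala
object ParamListDemo {
  def max1(a: Int, b: Int): Int = if (a > b) a else b
  def max2(a: Int)(b: Int): Int = if (a > b) a else b

  // Full application: only the call-site syntax differs
  val r1: Int = max1(3, 7)
  val r2: Int = max2(3)(7)

  // Partial application: with one parameter list you need a typed placeholder,
  // with two lists you can stop after the first list when a function type is expected
  val atLeastTen1: Int => Int = max1(10, _: Int)
  val atLeastTen2: Int => Int = max2(10)

  // A one-list method can still be turned into curried form explicitly
  val max1AsFunction: (Int, Int) => Int = max1
  val curriedMax1: Int => Int => Int = max1AsFunction.curried
}
```

So the choice between the two is mostly about how the call sites read, which is part of why I wanted to see whether the compiler treats them differently.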

A curried function test class

This got me curious about how these functions work at a bytecode level. To check it out I wrote this little Scala test class:

class ScalaCurriedFunctions {

    def max1(a: Int, b: Int) = if (a > b) a else b
    def max2(a: Int)(b: Int) = if (a > b) a else b

    val m1a = max1(1, _: Int)
    val m2a = max2(1)(_)

    val m1b = max1(1,2)
    val m2b = max2(1)(2)

}

I compiled the source code with scalac, then decompiled the resulting JVM class files with Jad. After looking at the decompiled results, I commented my source code like this:

class ScalaCurriedFunctions {

    // jad shows these two functions generate the same jvm bytecode
    def max1(a: Int, b: Int) = if (a > b) a else b
    def max2(a: Int)(b: Int) = if (a > b) a else b

    val m1a = max1(1, _: Int)   //jad: `public Function1 m1a()`
    val m2a = max2(1)(_)        //jad: `public Function1 m2a()`

    val m1b = max1(1,2)    //jad: `public int m1b()`
    val m2b = max2(1)(2)   //jad: `public int m2b()`

}

Result: The bytecode is the same

I was surprised to see that when the .class file is decompiled, the resulting Java source code for max1 and max2 is identical. I thought there might be some difference in how these functions are compiled because of their different syntax, but at a bytecode level there is no difference.
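
If you don't have a decompiler handy, you can get a rough confirmation of the same thing with Java reflection. This is only a sketch: the SignatureCheck object, the nested Demo class, and the jvmSignature helper are names I made up for this example, and it only compares method signatures, not the method bodies.

```scala
object SignatureCheck {
  // The same two methods as in the test class above
  class Demo {
    def max1(a: Int, b: Int): Int = if (a > b) a else b
    def max2(a: Int)(b: Int): Int = if (a > b) a else b
  }

  // Render a method's JVM signature as "returnType(paramType, ...)"
  def jvmSignature(name: String): String = {
    val m = classOf[Demo].getDeclaredMethods.find(_.getName == name).get
    val params = m.getParameterTypes.map(_.getName).mkString(", ")
    s"${m.getReturnType.getName}($params)"
  }

  def main(args: Array[String]): Unit = {
    println(jvmSignature("max1"))
    println(jvmSignature("max2"))
  }
}
```

Both lines should print int(int, int), which matches what Jad shows: on the JVM, max2 is just a method that takes two int parameters.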

(FWIW, I wrote about some things like this in the Scala Cookbook, but I don’t think I covered this specific case.)

Be careful looking at -Xprint output

As a “note to self,” I have to be careful when looking at the output from scalac -Xprint. Of course it’s correct for the phase the compiler is at, but if you look at early or intermediate phases, you can see output that will mislead you.

As an example of this, here’s the output from scalac -Xprint:parse:

$ scalac -Xprint:parse ScalaCurriedFunctions.scala 

[[syntax trees at end of                    parser]] // ScalaCurriedFunctions.scala
package <empty> {
  class ScalaCurriedFunctions extends scala.AnyRef {
    def <init>() = {
      super.<init>();
      ()
    };
    def max1(a: Int, b: Int) = if (a.$greater(b))
      a
    else
      b;
    def max2(a: Int)(b: Int) = if (a.$greater(b))
      a
    else
      b;
    val m1a = ((x$1: Int) => max1(1, (x$1: Int)));
    val m2a = ((x$2) => max2(1)(x$2));
    val m1b = max1(1, 2);
    val m2b = max2(1)(2)
  }
}

That output shows a difference between max1 and max2, and between m1a and m2a.

This output from scalac -Xprint:typer also shows a difference:

$ scalac -Xprint:typer ScalaCurriedFunctions.scala 

[[syntax trees at end of                     typer]] // ScalaCurriedFunctions.scala
package <empty> {
  class ScalaCurriedFunctions extends scala.AnyRef {
    def <init>(): ScalaCurriedFunctions = {
      ScalaCurriedFunctions.super.<init>();
      ()
    };
    def max1(a: Int, b: Int): Int = if (a.>(b))
      a
    else
      b;
    def max2(a: Int)(b: Int): Int = if (a.>(b))
      a
    else
      b;
    private[this] val m1a: Int => Int = ((x$1: Int) => ScalaCurriedFunctions.this.max1(1, (x$1: Int)));
    <stable> <accessor> def m1a: Int => Int = ScalaCurriedFunctions.this.m1a;
    private[this] val m2a: Int => Int = ((x$2: Int) => ScalaCurriedFunctions.this.max2(1)(x$2));
    <stable> <accessor> def m2a: Int => Int = ScalaCurriedFunctions.this.m2a;
    private[this] val m1b: Int = ScalaCurriedFunctions.this.max1(1, 2);
    <stable> <accessor> def m1b: Int = ScalaCurriedFunctions.this.m1b;
    private[this] val m2b: Int = ScalaCurriedFunctions.this.max2(1)(2);
    <stable> <accessor> def m2b: Int = ScalaCurriedFunctions.this.m2b
  }
}

But if you look at the output from Jad, or the end of the output from -Xprint:all, you’ll see that the end-result code is the same. As proof of this, here’s a little bit of output from -Xprint:all, taken at the end of the flatten phase:

[[syntax trees at end of                   flatten]] // ScalaCurriedFunctions.scala
package <empty> {
  class ScalaCurriedFunctions extends Object {
    def max1(a: Int, b: Int): Int = if (a.>(b))
      a
    else
      b;
    def max2(a: Int, b: Int): Int = if (a.>(b))
      a
    else
      b;
    private[this] val m1a: Function1 = _;
    <stable> <accessor> def m1a(): Function1 = ScalaCurriedFunctions.this.m1a;
    private[this] val m2a: Function1 = _;
    <stable> <accessor> def m2a(): Function1 = ScalaCurriedFunctions.this.m2a;
    private[this] val m1b: Int = _;
    <stable> <accessor> def m1b(): Int = ScalaCurriedFunctions.this.m1b;
    private[this] val m2b: Int = _;
    <stable> <accessor> def m2b(): Int = ScalaCurriedFunctions.this.m2b;
    def <init>(): ScalaCurriedFunctions = {
      ScalaCurriedFunctions.super.<init>();
      ScalaCurriedFunctions.this.m1a = {
        (new <$anon: Function1>(ScalaCurriedFunctions.this): Function1)
      };
      ScalaCurriedFunctions.this.m2a = {
        (new <$anon: Function1>(ScalaCurriedFunctions.this): Function1)
      };
      ScalaCurriedFunctions.this.m1b = ScalaCurriedFunctions.this.max1(1, 2);
      ScalaCurriedFunctions.this.m2b = ScalaCurriedFunctions.this.max2(1, 2);
      ()
    }
  };

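In that compiler-generated source code you can see that all of the differences between the curried and partially-applied functions are gone: max2 now takes a single (a: Int, b: Int) parameter list, just like max1, and m1a and m2a are both plain Function1 fields.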

Final decompiled source code

Finally, if you’re really interested in the gory details, this is the source code that’s generated by Jad when I use it to decompile the ScalaCurriedFunctions.class JVM bytecode file:

import scala.Function1;
import scala.Serializable;
import scala.runtime.BoxesRunTime;

public class ScalaCurriedFunctions
{

    public int max1(int a, int b)
    {
        return a <= b ? b : a;
    }

    public int max2(int a, int b)
    {
        return a <= b ? b : a;
    }

    public Function1 m1a()
    {
        return m1a;
    }

    public Function1 m2a()
    {
        return m2a;
    }

    public int m1b()
    {
        return m1b;
    }

    public int m2b()
    {
        return m2b;
    }

    public ScalaCurriedFunctions()
    {
    }

    private final Function1 m1a = new Serializable() {

        public final int apply(int x$1)
        {
            return apply$mcII$sp(x$1);
        }

        public int apply$mcII$sp(int x$1)
        {
            return $outer.max1(1, x$1);
        }

        public final volatile Object apply(Object v1)
        {
            return BoxesRunTime.boxToInteger(apply(BoxesRunTime.unboxToInt(v1)));
        }

        public static final long serialVersionUID = 0L;
        private final ScalaCurriedFunctions $outer;

            public 
            {
                if(ScalaCurriedFunctions.this == null)
                {
                    throw null;
                } else
                {
                    this.$outer = ScalaCurriedFunctions.this;
                    super();
                    return;
                }
            }
    };

    private final Function1 m2a = new Serializable() {

        public final int apply(int x$2)
        {
            return apply$mcII$sp(x$2);
        }

        public int apply$mcII$sp(int x$2)
        {
            return $outer.max2(1, x$2);
        }

        public final volatile Object apply(Object v1)
        {
            return BoxesRunTime.boxToInteger(apply(BoxesRunTime.unboxToInt(v1)));
        }

        public static final long serialVersionUID = 0L;
        private final ScalaCurriedFunctions $outer;

            public 
            {
                if(ScalaCurriedFunctions.this == null)
                {
                    throw null;
                } else
                {
                    this.$outer = ScalaCurriedFunctions.this;
                    super();
                    return;
                }
            }
    };

    private final int m1b = max1(1, 2);
    private final int m2b = max2(1, 2);

}
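
To relate that decompiled code back to Scala, the two anonymous classes are roughly what you'd get if you wrote the Function1 instances by hand. This is a sketch of the idea, not the compiler's actual output; the object name HandWrittenFunction1 is made up, and it leaves out the Serializable and specialization details you see in the Jad output (at least for the compiler version used in this post).

```scala
object HandWrittenFunction1 {
  def max1(a: Int, b: Int): Int = if (a > b) a else b
  def max2(a: Int)(b: Int): Int = if (a > b) a else b

  // Roughly what `max1(1, _: Int)` turns into: an object with an apply method
  val m1a: Function1[Int, Int] = new Function1[Int, Int] {
    def apply(x: Int): Int = max1(1, x)
  }

  // Roughly what `max2(1)(_)` turns into: the same shape, calling max2 instead
  val m2a: Function1[Int, Int] = new Function1[Int, Int] {
    def apply(x: Int): Int = max2(1)(x)
  }
}
```

Written this way, it's easier to see why the two vals end up with the same structure: each one is just an object whose apply method forwards to the underlying method with the first argument fixed to 1.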

scalac compiler phases

On a related note, if you haven’t seen them before, you can list the scalac compiler phases with the following command:

$ scalac -Xshow-phases

    phase name  id  description
    ----------  --  -----------
        parser   1  parse source into ASTs, perform simple desugaring
         namer   2  resolve names, attach symbols to named trees
packageobjects   3  load package objects
         typer   4  the meat and potatoes: type the trees
        patmat   5  translate match expressions
superaccessors   6  add super accessors in traits and nested classes
    extmethods   7  add extension methods for inline classes
       pickler   8  serialize symbol tables
     refchecks   9  reference/override checking, translate nested objects
       uncurry  10  uncurry, translate function values to anonymous classes
     tailcalls  11  replace tail calls by jumps
    specialize  12  @specialized-driven class and method specialization
 explicitouter  13  this refs to outer pointers
       erasure  14  erase types, add interfaces for traits
   posterasure  15  clean up erased inline classes
      lazyvals  16  allocate bitmaps, translate lazy vals into lazified defs
    lambdalift  17  move nested functions to top level
  constructors  18  move field definitions into constructors
       flatten  19  eliminate inner classes
         mixin  20  mixin composition
       cleanup  21  platform-specific cleanups, generate reflective calls
    delambdafy  22  remove lambdas
         icode  23  generate portable intermediate code
           jvm  24  generate JVM bytecode
      terminal  25  the last phase during a compilation run

You can also see them by running the man command on scalac, i.e., man scalac.

Summary

I don’t have any major conclusions here. I just wanted to look at how the Scala compiler treats “regular” functions, curried functions, and partially-applied functions, and as you can see, when you get to a bytecode level all of these examples look identical.