This morning I was curious about how Scala curried functions and partially-applied functions are really compiled at a bytecode level.
Prior to that, I wrote a post titled Higher order functions are the Haskell experience, which is also implicitly about curried functions, and it got me thinking about Scala. In particular, I wondered why we might use one function syntax versus another, i.e., why would I use this syntax:
    // one param list
    def max1(a: Int, b: Int) = if (a > b) a else b
and not this syntax for everything:
    // two param lists
    def max2(a: Int)(b: Int) = if (a > b) a else b
A curried function test class
This got me curious about how these functions work at a bytecode level. To check it out I wrote this little Scala test class:
    class ScalaCurriedFunctions {

        def max1(a: Int, b: Int) = if (a > b) a else b
        def max2(a: Int)(b: Int) = if (a > b) a else b

        val m1a = max1(1, _: Int)
        val m2a = max2(1)(_)

        val m1b = max1(1,2)
        val m2b = max2(1)(2)

    }
I then compiled the source code with scalac and then decompiled the resulting JVM class files with Jad. After decompiling that code and looking at the results, I commented my source code like this:
    class ScalaCurriedFunctions {

        // jad shows these two functions generate the same jvm bytecode
        def max1(a: Int, b: Int) = if (a > b) a else b
        def max2(a: Int)(b: Int) = if (a > b) a else b

        val m1a = max1(1, _: Int)   //jad: `public Function1 m1a()`
        val m2a = max2(1)(_)        //jad: `public Function1 m2a()`

        val m1b = max1(1,2)         //jad: `public int m1b()`
        val m2b = max2(1)(2)        //jad: `public int m2b()`

    }
Result: The bytecode is the same
I was surprised to see that when the .class file is decompiled, the resulting Java source code for max1 and max2 is identical. I thought there might be some difference in how these functions are compiled because of their different syntax, but at a bytecode level there is no difference.
(FWIW, I wrote about some things like this in the Scala Cookbook, but I don’t think I covered this specific case.)
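If you don't have a decompiler handy, the same result can be cross-checked with plain Java reflection. The following is only a minimal sketch: the names CurriedDemo, SignatureCheck, and signature are made up for this illustration, and I added explicit type annotations that the original test class doesn't have.

```scala
// Hypothetical demo class mirroring part of the test class above.
class CurriedDemo {
  def max1(a: Int, b: Int): Int = if (a > b) a else b
  def max2(a: Int)(b: Int): Int = if (a > b) a else b

  val m2a: Int => Int = max2(1)(_)   // partially applied
  val m2b: Int = max2(1)(2)          // fully applied
}

object SignatureCheck {
  private val cls = classOf[CurriedDemo]

  // Render the JVM-level signature of the method with the given name,
  // using only java.lang.reflect (no decompiler needed).
  def signature(name: String): String = {
    val m = cls.getDeclaredMethods.find(_.getName == name).get
    val params = m.getParameterTypes.map(_.getName).mkString(", ")
    s"($params): ${m.getReturnType.getName}"
  }

  def main(args: Array[String]): Unit = {
    // One line per method: the name followed by its JVM signature.
    Seq("max1", "max2", "m2a", "m2b").foreach { n =>
      println(s"$n${signature(n)}")
    }
  }
}
```

Running SignatureCheck, I'd expect max1 and max2 to report the same two-int parameter list, and the m2a accessor to report a scala.Function1 return type, which matches what Jad shows.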
Be careful looking at -Xprint output
As a “note to self,” I have to be careful when looking at the output from scalac -Xprint. Of course it’s correct for the stage the compiler is at, but if you look at early or intermediate stages, you can see output that will mislead you.
As an example of this, here’s the output from scalac -Xprint:parse:
    $ scalac -Xprint:parse ScalaCurriedFunctions.scala
    [[syntax trees at end of parser]] // ScalaCurriedFunctions.scala
    package <empty> {
      class ScalaCurriedFunctions extends scala.AnyRef {
        def <init>() = {
          super.<init>();
          ()
        };
        def max1(a: Int, b: Int) = if (a.$greater(b)) a else b;
        def max2(a: Int)(b: Int) = if (a.$greater(b)) a else b;
        val m1a = ((x$1: Int) => max1(1, (x$1: Int)));
        val m2a = ((x$2) => max2(1)(x$2));
        val m1b = max1(1, 2);
        val m2b = max2(1)(2)
      }
    }
That shows a difference between max1 and max2, and between m1a and m2a.
This output from scalac -Xprint:typer also shows a difference:
    $ scalac -Xprint:typer ScalaCurriedFunctions.scala
    [[syntax trees at end of typer]] // ScalaCurriedFunctions.scala
    package <empty> {
      class ScalaCurriedFunctions extends scala.AnyRef {
        def <init>(): ScalaCurriedFunctions = {
          ScalaCurriedFunctions.super.<init>();
          ()
        };
        def max1(a: Int, b: Int): Int = if (a.>(b)) a else b;
        def max2(a: Int)(b: Int): Int = if (a.>(b)) a else b;
        private[this] val m1a: Int => Int = ((x$1: Int) => ScalaCurriedFunctions.this.max1(1, (x$1: Int)));
        <stable> <accessor> def m1a: Int => Int = ScalaCurriedFunctions.this.m1a;
        private[this] val m2a: Int => Int = ((x$2: Int) => ScalaCurriedFunctions.this.max2(1)(x$2));
        <stable> <accessor> def m2a: Int => Int = ScalaCurriedFunctions.this.m2a;
        private[this] val m1b: Int = ScalaCurriedFunctions.this.max1(1, 2);
        <stable> <accessor> def m1b: Int = ScalaCurriedFunctions.this.m1b;
        private[this] val m2b: Int = ScalaCurriedFunctions.this.max2(1)(2);
        <stable> <accessor> def m2b: Int = ScalaCurriedFunctions.this.m2b
      }
    }
But if you look at the output from Jad, or the end of the output from -Xprint:all, you’ll see that the end-result code is the same. As proof of this, here’s a little bit of output from -Xprint:all:
    [[syntax trees at end of flatten]] // ScalaCurriedFunctions.scala
    package <empty> {
      class ScalaCurriedFunctions extends Object {
        def max1(a: Int, b: Int): Int = if (a.>(b)) a else b;
        def max2(a: Int, b: Int): Int = if (a.>(b)) a else b;
        private[this] val m1a: Function1 = _;
        <stable> <accessor> def m1a(): Function1 = ScalaCurriedFunctions.this.m1a;
        private[this] val m2a: Function1 = _;
        <stable> <accessor> def m2a(): Function1 = ScalaCurriedFunctions.this.m2a;
        private[this] val m1b: Int = _;
        <stable> <accessor> def m1b(): Int = ScalaCurriedFunctions.this.m1b;
        private[this] val m2b: Int = _;
        <stable> <accessor> def m2b(): Int = ScalaCurriedFunctions.this.m2b;
        def <init>(): ScalaCurriedFunctions = {
          ScalaCurriedFunctions.super.<init>();
          ScalaCurriedFunctions.this.m1a = {
            (new <$anon: Function1>(ScalaCurriedFunctions.this): Function1)
          };
          ScalaCurriedFunctions.this.m2a = {
            (new <$anon: Function1>(ScalaCurriedFunctions.this): Function1)
          };
          ScalaCurriedFunctions.this.m1b = ScalaCurriedFunctions.this.max1(1, 2);
          ScalaCurriedFunctions.this.m2b = ScalaCurriedFunctions.this.max2(1, 2);
          ()
        }
      };
In that compiler-generated source code you can see that all of the differences between the curried and partially-applied functions are gone.
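To make the `new <$anon: Function1>` lines in that output a little more concrete, here's a hand-written sketch of roughly what the compiler is doing for m2a. The names UncurrySketch, byPlaceholder, and byHand are mine, not the compiler's, and this is an approximation of the idea rather than the exact generated code. (Newer compilers, Scala 2.12 and later, generally emit lambdas via invokedynamic instead of anonymous classes, so the details will vary by version.)

```scala
object UncurrySketch {
  def max2(a: Int)(b: Int): Int = if (a > b) a else b

  // Placeholder syntax, as in the test class (type annotation added here).
  val byPlaceholder: Int => Int = max2(1)(_)

  // Hand-written stand-in for the anonymous Function1 in the flatten output:
  // an object whose apply method forwards to max2 with the first argument fixed.
  val byHand: Function1[Int, Int] = new Function1[Int, Int] {
    def apply(x: Int): Int = max2(1)(x)
  }

  def main(args: Array[String]): Unit = {
    // Both function values should agree for any input.
    println(byPlaceholder(5) == byHand(5))
    println(byPlaceholder(0) == byHand(0))
  }
}
```

Both values are plain Function1 instances, so calling code can't tell whether the underlying method was written with one parameter list or two.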
Final decompiled source code
Finally, if you’re really interested in the gory details, this is the source code that’s generated by Jad when I use it to decompile the ScalaCurriedFunctions.class JVM bytecode file:
    import scala.Function1;
    import scala.Serializable;
    import scala.runtime.BoxesRunTime;

    public class ScalaCurriedFunctions
    {

        public int max1(int a, int b)
        {
            return a <= b ? b : a;
        }

        public int max2(int a, int b)
        {
            return a <= b ? b : a;
        }

        public Function1 m1a()
        {
            return m1a;
        }

        public Function1 m2a()
        {
            return m2a;
        }

        public int m1b()
        {
            return m1b;
        }

        public int m2b()
        {
            return m2b;
        }

        public ScalaCurriedFunctions()
        {
        }

        private final Function1 m1a = new Serializable() {

            public final int apply(int x$1)
            {
                return apply$mcII$sp(x$1);
            }

            public int apply$mcII$sp(int x$1)
            {
                return $outer.max1(1, x$1);
            }

            public final volatile Object apply(Object v1)
            {
                return BoxesRunTime.boxToInteger(apply(BoxesRunTime.unboxToInt(v1)));
            }

            public static final long serialVersionUID = 0L;
            private final ScalaCurriedFunctions $outer;

            public
            {
                if(ScalaCurriedFunctions.this == null)
                {
                    throw null;
                } else
                {
                    this.$outer = ScalaCurriedFunctions.this;
                    super();
                    return;
                }
            }
        };

        private final Function1 m2a = new Serializable() {

            public final int apply(int x$2)
            {
                return apply$mcII$sp(x$2);
            }

            public int apply$mcII$sp(int x$2)
            {
                return $outer.max2(1, x$2);
            }

            public final volatile Object apply(Object v1)
            {
                return BoxesRunTime.boxToInteger(apply(BoxesRunTime.unboxToInt(v1)));
            }

            public static final long serialVersionUID = 0L;
            private final ScalaCurriedFunctions $outer;

            public
            {
                if(ScalaCurriedFunctions.this == null)
                {
                    throw null;
                } else
                {
                    this.$outer = ScalaCurriedFunctions.this;
                    super();
                    return;
                }
            }
        };

        private final int m1b = max1(1, 2);
        private final int m2b = max2(1, 2);
    }
scalac -Xprint compiler phases
In a related note, if you haven’t seen them before, you can see the scalac compiler phases (the names you can pass to -Xprint) with the following command:
    $ scalac -Xshow-phases
        phase name  id  description
        ----------  --  -----------
            parser   1  parse source into ASTs, perform simple desugaring
             namer   2  resolve names, attach symbols to named trees
    packageobjects   3  load package objects
             typer   4  the meat and potatoes: type the trees
            patmat   5  translate match expressions
    superaccessors   6  add super accessors in traits and nested classes
        extmethods   7  add extension methods for inline classes
           pickler   8  serialize symbol tables
         refchecks   9  reference/override checking, translate nested objects
           uncurry  10  uncurry, translate function values to anonymous classes
         tailcalls  11  replace tail calls by jumps
        specialize  12  @specialized-driven class and method specialization
     explicitouter  13  this refs to outer pointers
           erasure  14  erase types, add interfaces for traits
       posterasure  15  clean up erased inline classes
          lazyvals  16  allocate bitmaps, translate lazy vals into lazified defs
        lambdalift  17  move nested functions to top level
      constructors  18  move field definitions into constructors
           flatten  19  eliminate inner classes
             mixin  20  mixin composition
           cleanup  21  platform-specific cleanups, generate reflective calls
        delambdafy  22  remove lambdas
             icode  23  generate portable intermediate code
               jvm  24  generate JVM bytecode
          terminal  25  the last phase during a compilation run
You can also see them by running the man command on scalac, i.e., man scalac.
Summary
I don’t have any major conclusions here. I just wanted to look at how the Scala compiler treats “regular” functions, curried functions, and partially-applied functions, and as you can see, when you get to a bytecode level all of these examples look identical.