Scala: Passing a function literal as a function argument

Note: The first three Scala examples below, which show how to pass a function as an argument to another function, were taken from the PDFs on the Scala website. The only things I've added to them are comments in the source code and the detailed discussions in this article.

1) A simple Scala function argument example

Having gone directly from C programming to Java (I skipped most of C++), for the most part I missed out on working with “function pointers.” That being said, I'm used to a variety of ways to simulate callbacks in Java, Drupal, and other languages and tools. So when I saw the following Scala example, which happens to use the word callback, the light went off in my head about how passing a function as a function argument works in Scala:

// a class/object to demonstrate how to pass a function
// to another function in scala
object Timer {

  // this function takes another function as an argument.
  // that function takes no args, and doesn't return anything.
  def oncePerSecond(callback: () => Unit) {
    while (true) { callback(); Thread.sleep(1000) }
  }

  // the function we'll pass in to oncePerSecond.
  // this can be any function that takes no args and doesn't
  // return anything.
  def timeFlies(): Unit = {
    println("time flies like an arrow ...")
  }

  // the main() method, where we pass timeFlies into oncePerSecond.
  def main(args: Array[String]): Unit = {
    oncePerSecond(timeFlies)
  }
}

Because of the simplicity of that demo and the use of the word callback as a variable name, this example was very easy for me to digest.

As you can see from the example:

  • The function named oncePerSecond takes a function as its only argument.
  • That function takes no parameters and returns no useful value, which is what the Unit return type means.
  • The oncePerSecond function invokes the callback function within its while loop.
  • The oncePerSecond function doesn't know the name of the function it's given; it simply refers to it by the parameter name callback.
  • The timeFlies function is passed into the oncePerSecond function in the main method of the Timer object.
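Because oncePerSecond loops forever, it's awkward to experiment with. The sketch below uses a hypothetical bounded variant named repeat, which is not part of the original example. It takes the same kind of callback parameter, () => Unit, but stops after a fixed number of calls. It also shows that the argument doesn't have to be a def: a val holding a function of type () => Unit works the same way.

```scala
// hypothetical example object, not from the Scala PDFs
object BoundedTimer {

  // same parameter type as oncePerSecond (a function taking no args and
  // returning Unit), but the loop stops after `times` calls
  def repeat(times: Int, delayMillis: Long)(callback: () => Unit): Unit = {
    var i = 0
    while (i < times) {
      callback()                 // invoke whatever function was passed in
      Thread.sleep(delayMillis)
      i += 1
    }
  }

  // a function value of type () => Unit, stored in a val
  val timeFlies: () => Unit = () => println("time flies like an arrow ...")

  def main(args: Array[String]): Unit = {
    // prints the message three times, about one second apart
    repeat(3, 1000)(timeFlies)
  }
}
```

Passing a delay of 0 turns repeat into a plain "call this function N times" helper, which makes the callback behavior easy to check.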

2) Passing an anonymous function as a function argument

As a second example of passing a function as an argument to another function in Scala, the first example is modified so that instead of passing a named function into the oncePerSecond function, we pass in an anonymous function directly from the main method. Here's that source code:

// a class/object to demonstrate how to pass an anonymous function
// to another function in scala
object Timer {

  // takes a function that receives no args and doesn't return anything
  def oncePerSecond(callback: () => Unit): Unit = {
    while (true) { callback(); Thread.sleep(1000) }
  }

  def main(args: Array[String]): Unit = {
    // pass in an anonymous function to the oncePerSecond function
    oncePerSecond(() =>
      println("time flies like an arrow ...")
    )
  }
}

As you can see, the timeFlies function has been removed and replaced by the anonymous function () => println(...), which is written directly in the call to oncePerSecond in the main method.
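An anonymous function passed this way can contain more than one statement, and it can use variables from the method it's written in. The sketch below assumes a hypothetical helper named repeat, which is not from the original, that calls its callback a fixed number of times. The anonymous function is passed in curly braces and updates a local counter on each call.

```scala
// hypothetical example object, not from the Scala PDFs
object AnonymousCallback {

  // hypothetical helper: calls `callback` exactly `times` times
  def repeat(times: Int)(callback: () => Unit): Unit =
    for (_ <- 1 to times) callback()

  def main(args: Array[String]): Unit = {
    var count = 0
    // the anonymous function is written inline; its body spans several
    // lines and updates `count`, a variable from the enclosing method
    repeat(3) { () =>
      count += 1
      println(s"time flies like an arrow ... ($count)")
    }
    println(s"callback ran $count times")   // prints "callback ran 3 times"
  }
}
```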

3) A more complicated Scala function to function example

As a final example of passing one function as an argument to another Scala function, the next example shows how several different functions are passed into a function named sum(). As you can see from its definition, the first parameter of sum is a function named f, which takes one Int as a parameter and returns an Int:

object SumExample {

  // three args are passed in:
  // 1) 'f' - a function that takes an Int and returns an Int
  // 2) 'a' - an Int
  // 3) 'b' - an Int
  // sum applies f to every Int from a to b (inclusive) and adds the results
  def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a + 1, b)

  // three functions that are passed into the sum() function
  def id(x: Int): Int = x
  def square(x: Int): Int = x * x
  def powerOfTwo(x: Int): Int = if (x == 0) 1 else 2 * powerOfTwo(x - 1)

  // these three functions use the sum() function
  def sumInts(a: Int, b: Int): Int = sum(id, a, b)
  def sumSquares(a: Int, b: Int): Int = sum(square, a, b)
  def sumPowersOfTwo(a: Int, b: Int): Int = sum(powerOfTwo, a, b)

  def main(args: Array[String]): Unit = {
    // prints "sum ints 1 to 4 = 10"
    println("sum ints 1 to 4 = " + sumInts(1, 4))
  }
}

The functions sumInts, sumSquares, and sumPowersOfTwo all call the sum function, and pass in different functions as the first argument in their call to sum.
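The named functions id, square, and powerOfTwo aren't required. sum accepts any function of type Int => Int, including anonymous functions written directly in the call. This sketch repeats the sum definition so that it runs on its own. The object name SumWithLiterals is just for illustration. As a worked trace of the recursion, sum(x => x, 1, 4) expands to 1 + (2 + (3 + (4 + 0))), which is 10.

```scala
// hypothetical example object for illustration
object SumWithLiterals {

  // same sum function as above: applies f to every Int from a to b and adds the results
  def sum(f: Int => Int, a: Int, b: Int): Int =
    if (a > b) 0 else f(a) + sum(f, a + 1, b)

  def main(args: Array[String]): Unit = {
    // anonymous functions passed directly, instead of named methods
    println(sum(x => x, 1, 4))          // 1 + 2 + 3 + 4  = 10
    println(sum(x => x * x, 1, 4))      // 1 + 4 + 9 + 16 = 30
    println(sum(x => x * x * x, 1, 3))  // 1 + 8 + 27     = 36
    // placeholder syntax: `_ * 2` is shorthand for x => x * 2
    println(sum(_ * 2, 1, 4))           // 2 + 4 + 6 + 8  = 20
  }
}
```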

4) Passing a function literal that has arguments and returns a value

In the previous examples I started to demonstrate how to pass a function literal as an argument to another function. If it helps to see another example, my next source code example demonstrates (a) how to create a function literal that accepts an argument and returns a value, (b) how to pass that function literal to another function, (c) how to define the function that accepts that function literal, and (d) how to invoke the function literal from inside the other function.

Here's a quick discussion of how this will work:

  • The main method below creates a function literal and assigns it to a val named fx.
  • That function literal accepts an argument (a Double), and returns a value (a Double).
  • The main method passes fx to another function named halveTheInterval.
  • halveTheInterval accepts the function literal through a parameter of type Double => Double, then invokes it from inside its body.

Here's the Scala source code that demonstrates all of this. The most interesting lines are the definition of fx in main, the call to halveTheInterval, and the fx(...) calls inside halveTheInterval.

object IntervalHalving2 {

  def main(args: Array[String]): Unit = {
    val x1 = 1.0
    val x2 = 2.0
    val tolerance = 0.00005

    // define the f(x) function here, as a function literal assigned to a val
    val fx = (x: Double) => x*x*x + x*x - 3*x - 3

    // pass the f(x) function as a parameter to the halveTheInterval function
    val answer = halveTheInterval(fx, x1, x2, tolerance)

    // print the answer (an approximation of the root between x1 and x2)
    println(answer)
  }

  /**
   * the first argument to this function is a function that takes a Double
   * and returns a Double, such as the fx function literal defined in main.
   */
  def halveTheInterval(fx: Double => Double, x1: Double, x2: Double, tolerance: Double): Double = {
    var x1wkg = x1
    var x2wkg = x2
    while (math.abs(x1wkg - x2wkg) > tolerance) {
      val x3 = (x1wkg + x2wkg) / 2.0
      // invoke the function that was passed in, just like any other function
      if (signsAreOpposite(fx(x3), fx(x1wkg))) x2wkg = x3 else x1wkg = x3
    }
    x1wkg
  }

  def signsAreOpposite(x: Double, y: Double): Boolean =
    (x < 0 && y > 0) || (x > 0 && y < 0)

}

My intent in this tutorial is to demonstrate Scala functions and function literals, but if you're interested in how this algorithm works, see my Interval halving (bisection) method in Scala tutorial.
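Because halveTheInterval only depends on the Double => Double type of its first parameter, the same code can look for a root of other continuous functions whose values at x1 and x2 have opposite signs. This sketch is a condensed, hypothetical variant. The object name ReuseHalving and the lo/hi variable names are mine. It replaces signsAreOpposite with a product test, fx(mid) * fx(lo) < 0, which is true when the two values have opposite signs (ignoring floating-point underflow). It then passes in two different function literals.

```scala
// hypothetical example object; a condensed variant of halveTheInterval
object ReuseHalving {

  // bisection: keep halving [lo, hi] while it still brackets a sign change of fx
  def halveTheInterval(fx: Double => Double, x1: Double, x2: Double, tolerance: Double): Double = {
    var lo = x1
    var hi = x2
    while (math.abs(lo - hi) > tolerance) {
      val mid = (lo + hi) / 2.0
      // the root is in [lo, mid] when fx changes sign across that half
      if (fx(mid) * fx(lo) < 0) hi = mid else lo = mid
    }
    lo
  }

  def main(args: Array[String]): Unit = {
    // different function literals, same halveTheInterval function
    val sqrtOfTwo = halveTheInterval(x => x * x - 2, 1.0, 2.0, 0.00005)
    val cubeRootOfTen = halveTheInterval(x => x * x * x - 10, 2.0, 3.0, 0.00005)
    println(sqrtOfTwo)      // approximately 1.4142
    println(cubeRootOfTen)  // approximately 2.1544
  }
}
```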

Summary

Here’s a quick summary of what I showed in these Scala function literal examples:

  1. One function can be passed to another function as a function argument (i.e., a function input parameter).
  2. The parameter that receives the function is declared with syntax that looks like this:

    "f: Int => Int", or this: "callback: () => Unit".
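Here is a short sketch of both parameter types in use, plus a two-argument form, (Int, Int) => Int, for comparison. The method names applyTwice, runIt, and combine are hypothetical and were chosen only for this illustration.

```scala
// hypothetical example object for illustration
object FunctionTypes {

  // a parameter of type Int => Int: one Int in, one Int out
  def applyTwice(f: Int => Int, x: Int): Int = f(f(x))

  // a parameter of type () => Unit: no arguments, no useful result
  def runIt(callback: () => Unit): Unit = callback()

  // a parameter of type (Int, Int) => Int: two Ints in, one Int out
  def combine(op: (Int, Int) => Int, a: Int, b: Int): Int = op(a, b)

  def main(args: Array[String]): Unit = {
    println(applyTwice(x => x + 3, 10))   // 16
    runIt(() => println("called back"))   // prints "called back"
    println(combine(_ + _, 4, 5))         // 9
  }
}
```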

I’ll try to add some more original content here over time, but until then, I hope the additional documentation on these Scala function argument examples has been helpful. (See this link for more information on Scala and functional programming.)