How to create and use partial functions in Scala

This is an excerpt from the 1st Edition of the Scala Cookbook (partially modified for the internet). This is Recipe 9.8, “How to create and use partial functions in Scala.”

Problem

You want to (a) define a Scala function that works only for a subset of possible input values, or (b) define a series of functions that each work only for a subset of input values, and combine those functions to completely solve a problem.

Solution: Scala partial functions

A Scala partial function is a function that does not provide an answer for every possible input value it can be given. It provides an answer only for a subset of possible data, and defines the data it can handle. In Scala, a partial function can also be queried to determine if it can handle a particular value.

To demonstrate this, imagine a normal function that divides one number by another:

val divide = (x: Int) => 42 / x

As defined, this function blows up when the input parameter is zero:

scala> divide(0)
java.lang.ArithmeticException: / by zero

Although you can handle this particular situation by catching the exception, Scala also lets you define the divide function as a PartialFunction. When you do so, you explicitly state that the function is defined only when the input parameter is not zero:

val divide = new PartialFunction[Int, Int] {
    def apply(x: Int) = 42 / x
    def isDefinedAt(x: Int) = x != 0
}

This approach has several benefits. For one, you can test whether the function is defined for an input before you attempt to call it:

scala> divide.isDefinedAt(1)
res0: Boolean = true

scala> if (divide.isDefinedAt(1)) divide(1)
res1: AnyVal = 42

scala> divide.isDefinedAt(0)
res2: Boolean = false
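The test-before-calling pattern can be wrapped in a small helper that returns an Option instead of the AnyVal that the bare if expression produces. This is a minimal sketch: the helper name safeDivide is hypothetical and not part of the original recipe. The standard library's PartialFunction also provides a lift method, which turns the partial function into an ordinary function returning an Option.

```scala
// The explicit PartialFunction from the recipe
val divide = new PartialFunction[Int, Int] {
  def apply(x: Int) = 42 / x
  def isDefinedAt(x: Int) = x != 0
}

// Hypothetical helper: consult isDefinedAt before calling apply,
// so an input outside the domain yields None instead of an exception
def safeDivide(x: Int): Option[Int] =
  if (divide.isDefinedAt(x)) Some(divide(x)) else None

// The built-in lift method gives a total function of type Int => Option[Int]
val liftedDivide: Int => Option[Int] = divide.lift

println(safeDivide(2))    // Some(21)
println(safeDivide(0))    // None
println(liftedDivide(0))  // None
```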

This isn’t all you can do with partial functions. You’ll see shortly that other code can take advantage of partial functions to provide elegant and concise solutions.

Whereas that divide function is explicit about what data it handles, partial functions are often written using case statements:

val divide2: PartialFunction[Int, Int] = {
    case d: Int if d != 0 => 42 / d
}

Although this code doesn’t explicitly implement the isDefinedAt method, it works exactly the same as the previous divide function definition, because the compiler generates isDefinedAt from the case pattern and its guard:

scala> divide2.isDefinedAt(0)
res0: Boolean = false

scala> divide2.isDefinedAt(1)
res1: Boolean = true
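One thing isDefinedAt does not do is protect a direct call. If you call divide2 with a value outside its domain, no case matches and a scala.MatchError is thrown. A short sketch, assuming you run it in the REPL or as a Scala script:

```scala
val divide2: PartialFunction[Int, Int] = {
  case d: Int if d != 0 => 42 / d
}

// isDefinedAt reflects the case pattern and its guard
println(divide2.isDefinedAt(0))  // false

// Calling apply outside the domain fails the pattern match
val outcome: String =
  try s"result: ${divide2(0)}"
  catch { case _: MatchError => "threw MatchError" }

println(outcome)  // threw MatchError
```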

The PartialFunction explained

The PartialFunction Scaladoc describes a partial function in this way:

“A partial function of type PartialFunction[A, B] is a unary function where the domain does not necessarily include all values of type A. The function isDefinedAt allows [you] to test dynamically if a value is in the domain of the function.”

This helps to explain why the last example with the match expression (case statement) works: the isDefinedAt method dynamically tests to see if the given value is in the domain of the function (i.e., it is handled, or accounted for).
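When a partial function literal has several case clauses, its domain is the union of what those clauses match. The sketch below uses a hypothetical function named describe, declared over Any, to show isDefinedAt returning true only for inputs that some clause accounts for.

```scala
// Hypothetical example: handles Ints and Strings, nothing else
val describe: PartialFunction[Any, String] = {
  case i: Int    => s"an Int: $i"
  case s: String => s"a String of length ${s.length}"
}

println(describe.isDefinedAt(42))     // true
println(describe.isDefinedAt("cat"))  // true
println(describe.isDefinedAt(1.5))    // false: no clause matches a Double
```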

The signature of the PartialFunction trait looks like this:

trait PartialFunction[-A, +B] extends (A) => B

As discussed in other recipes, the => symbol can be thought of as a transformer, and in this case, the (A) => B can be interpreted as a function that transforms a type A into a resulting type B.
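Because PartialFunction[A, B] extends A => B, a partial function can be passed anywhere an ordinary function of that type is expected. In this sketch, applyTwice and halveEven are hypothetical names. Note that code expecting a plain Int => Int won't call isDefinedAt for you, so the inputs have to stay inside the domain.

```scala
// Hypothetical method that expects an ordinary function
def applyTwice(f: Int => Int, x: Int): Int = f(f(x))

// Defined only for even numbers
val halveEven: PartialFunction[Int, Int] = {
  case n if n % 2 == 0 => n / 2
}

// A PartialFunction[Int, Int] is also an Int => Int
val asFunction: Int => Int = halveEven

println(applyTwice(halveEven, 20))  // 5 (20 -> 10 -> 5)
```

By contrast, applyTwice(halveEven, 6) would throw a MatchError on the second call, because 6 is halved to 3, which is odd.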

The divide examples transform an input Int into an output Int, but if a partial function returned a String instead, it would be declared like this:

PartialFunction[Int, String]

For example, the following function uses this signature:

// converts 1 to "one", etc., up to 5
val convertLowNumToString = new PartialFunction[Int, String] {
    val nums = List("one", "two", "three", "four", "five")
    def apply(i: Int) = nums(i-1)
    def isDefinedAt(i: Int) = i > 0 && i < 6
}
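Here is how convertLowNumToString behaves, alongside an equivalent definition written with a case clause. The names lowNames and convertWithCase are hypothetical; in the case version, the guard plays the role of isDefinedAt.

```scala
val convertLowNumToString = new PartialFunction[Int, String] {
  val nums = List("one", "two", "three", "four", "five")
  def apply(i: Int) = nums(i - 1)
  def isDefinedAt(i: Int) = i > 0 && i < 6
}

// Same behavior written as a case-based partial function literal
val lowNames = List("one", "two", "three", "four", "five")
val convertWithCase: PartialFunction[Int, String] = {
  case i if i > 0 && i < 6 => lowNames(i - 1)
}

println(convertLowNumToString(3))              // three
println(convertLowNumToString.isDefinedAt(6))  // false
println(convertWithCase(5))                    // five
```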

orElse and andThen

A terrific feature of partial functions is that you can chain them together. For instance, one function may work only with even numbers, and another may work only with odd numbers; together they can handle every integer.

In the following example, two functions are defined that can each handle a small number of Int inputs, and convert them to String results:

// converts 1 to "one", etc., up to 5
val convert1to5 = new PartialFunction[Int, String] {
    val nums = List("one", "two", "three", "four", "five")
    def apply(i: Int) = nums(i-1)
    def isDefinedAt(i: Int) = i > 0 && i < 6
}

// converts 6 to "six", etc., up to 10
val convert6to10 = new PartialFunction[Int, String] {
    val nums = List("six", "seven", "eight", "nine", "ten")
    def apply(i: Int) = nums(i-6)
    def isDefinedAt(i: Int) = i > 5 && i < 11
}

Taken separately, they can each handle only five numbers. But combined with orElse, they can handle ten:

scala> val handle1to10 = convert1to5 orElse convert6to10
handle1to10: PartialFunction[Int,String] = <function1>

scala> handle1to10(3)
res0: String = three

scala> handle1to10(8)
res1: String = eight
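To see what orElse does, here is a hand-written combinator with the same observable behavior: the combined function is defined wherever either side is defined, and it tries the first function before falling back to the second. This is an illustrative sketch; combineOrElse is a hypothetical name, and the standard library's orElse may be implemented differently internally.

```scala
val convert1to5 = new PartialFunction[Int, String] {
  val nums = List("one", "two", "three", "four", "five")
  def apply(i: Int) = nums(i - 1)
  def isDefinedAt(i: Int) = i > 0 && i < 6
}

val convert6to10 = new PartialFunction[Int, String] {
  val nums = List("six", "seven", "eight", "nine", "ten")
  def apply(i: Int) = nums(i - 6)
  def isDefinedAt(i: Int) = i > 5 && i < 11
}

// Hypothetical hand-rolled version of orElse
def combineOrElse[A, B](first: PartialFunction[A, B],
                        second: PartialFunction[A, B]): PartialFunction[A, B] =
  new PartialFunction[A, B] {
    // Defined wherever either function is defined
    def isDefinedAt(a: A) = first.isDefinedAt(a) || second.isDefinedAt(a)
    // Prefer the first function; fall back to the second
    def apply(a: A) = if (first.isDefinedAt(a)) first(a) else second(a)
  }

val handMade = combineOrElse(convert1to5, convert6to10)
val builtIn  = convert1to5 orElse convert6to10

println(handMade(8))              // eight
println(builtIn.isDefinedAt(11))  // false
```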

The orElse method comes from the Scala PartialFunction trait, which also includes an andThen method. Where orElse combines two partial functions side by side, andThen passes the result of a partial function on to another function.
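A minimal andThen sketch, assuming a hypothetical post-processing function named shout. When the function passed to andThen is an ordinary String => String, the composed value is still a PartialFunction, and it is defined exactly where convert1to5 is.

```scala
val convert1to5 = new PartialFunction[Int, String] {
  val nums = List("one", "two", "three", "four", "five")
  def apply(i: Int) = nums(i - 1)
  def isDefinedAt(i: Int) = i > 0 && i < 6
}

// Hypothetical post-processing step: an ordinary String => String
val shout: String => String = s => s.toUpperCase + "!"

// Keeps convert1to5's domain and transforms its output
val shoutLowNum: PartialFunction[Int, String] = convert1to5 andThen shout

println(shoutLowNum(2))             // TWO!
println(shoutLowNum.isDefinedAt(7)) // false
```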

Discussion

It’s important to know about partial functions, not just to have another tool in your toolbox, but because they are used in the APIs of some libraries, including the Scala collections library.

One example of where you’ll run into partial functions is the collect method on the collection classes. The collect method takes a partial function as input, and as its Scaladoc describes, collect “Builds a new collection by applying a partial function to all elements of this list on which the function is defined.”

For instance, the divide function shown earlier is a partial function that is not defined at the Int value zero. Here’s that function again:

val divide: PartialFunction[Int, Int] = {
    case d: Int if d != 0 => 42 / d
}

If you attempt to use this function with the map method, it will explode with a MatchError:

scala> List(0,1,2) map { divide }
scala.MatchError: 0 (of class java.lang.Integer)
stack trace continues ...

However, if you use the same function with the collect method, it works fine:

scala> List(0,1,2) collect { divide }
res0: List[Int] = List(42, 21)

This is because collect checks whether the partial function is defined for each element before applying it. As a result, it doesn’t run the divide algorithm when the input value is 0, but it does run it for every other element.
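Conceptually, collect behaves like filtering by isDefinedAt and then mapping with the partial function. The sketch below shows that equivalence for the divide example; it illustrates the behavior only, and the library's own collect may be implemented differently internally.

```scala
val divide: PartialFunction[Int, Int] = {
  case d: Int if d != 0 => 42 / d
}

val nums = List(0, 1, 2)

// Built-in: keeps only the elements where divide is defined
val viaCollect = nums.collect(divide)

// Conceptual equivalent: filter by the domain, then apply
val viaFilterMap = nums.filter(divide.isDefinedAt).map(divide)

println(viaCollect)    // List(42, 21)
println(viaFilterMap)  // List(42, 21)
```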

You can see the collect method work in other situations, such as passing it a List that contains a mix of data types, with a function that works only with Int values:

scala> List(42, "cat") collect { case i: Int => i + 1 }
res0: List[Int] = List(43)

Because it checks the isDefinedAt method under the covers, collect can handle the fact that your anonymous function can’t work with a String as input.
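The same idea works with more than one case clause. In this sketch (the list contents are made up for illustration), each element is handled by whichever clause matches it, and the Double is skipped because no clause covers it:

```scala
val mixed: List[Any] = List(42, "cat", 1.5)

// Ints are doubled, Strings become their length; anything else is dropped
val sizes: List[Int] = mixed.collect {
  case i: Int    => i * 2
  case s: String => s.length
}

println(sizes)  // List(84, 3)
```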

The PartialFunction Scaladoc demonstrates this same technique in a slightly different way. In the first example, it shows how to create a list of even numbers by defining a PartialFunction named isEven, and using that function with the collect method:

scala> val sample = 1 to 5
sample: scala.collection.immutable.Range.Inclusive = Range(1, 2, 3, 4, 5)

scala> val isEven: PartialFunction[Int, String] = {
     |  case x if x % 2 == 0 => x + " is even"
     | }
isEven: PartialFunction[Int,String] = <function1>

scala> val evenNumbers = sample collect isEven
evenNumbers: scala.collection.immutable.IndexedSeq[String] =
  Vector(2 is even, 4 is even)

Similarly, an isOdd function can be defined, and the two functions can be joined by orElse to work with the map method:

scala> val isOdd: PartialFunction[Int, String] = {
     |    case x if x % 2 == 1 => x + " is odd"
     | }
isOdd: PartialFunction[Int,String] = <function1>

scala> val numbers = sample map (isEven orElse isOdd)
numbers: scala.collection.immutable.IndexedSeq[String] =
    Vector(1 is odd, 2 is even, 3 is odd, 4 is even, 5 is odd)
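One caution about isOdd: on the JVM, the % operator keeps the sign of the dividend, so -3 % 2 is -1, not 1. As written, isOdd is not defined for negative odd numbers, and mapping with isEven orElse isOdd over such values would throw a MatchError. Here is a sketch of a variant that tests for a nonzero remainder instead; the name isOddAnySign is hypothetical.

```scala
val isEven: PartialFunction[Int, String] = {
  case x if x % 2 == 0 => s"$x is even"
}

// The recipe's version: misses negative odd numbers
val isOdd: PartialFunction[Int, String] = {
  case x if x % 2 == 1 => s"$x is odd"
}

// x % 2 != 0 also covers negative odd numbers, where x % 2 == -1
val isOddAnySign: PartialFunction[Int, String] = {
  case x if x % 2 != 0 => s"$x is odd"
}

println(isOdd.isDefinedAt(-3))  // false

val results = (-3 to 3).map(isEven orElse isOddAnySign)
println(results.mkString(", "))
// -3 is odd, -2 is even, -1 is odd, 0 is even, 1 is odd, 2 is even, 3 is odd
```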

Portions of this recipe were inspired by Erik Bruchez’s blog post titled “Scala partial functions (without a PhD).”

See Also