JIT: Recursion

The other programming technique you’ll need to be aware of before the next lesson is recursion. Recursion can be really fun, and it’s an important FP technique, so I spend over 50 pages explaining it in Functional Programming, Simplified.

Update: You can now learn about recursion in my free PDF, Learning Recursion.

Writing a recursive function in Scala

In short, a recursive function is a function that calls itself. To demonstrate this, let’s create a recursive function that “counts down” to 0. The idea is that you give it an integer, and it prints each number as it counts down, stopping when it reaches 0, as shown in the REPL:

scala> countdown(3)
3
2
1

Right away you can see that this function takes an Int parameter:

def countdown(i: Int)

You can also see that it prints its result, and doesn’t return anything, so its return type is Unit (which is like void in Java and other languages):

def countdown(i: Int): Unit =

Normally I implement most recursive functions using match expressions, but to keep this familiar for OOP developers, I’ll use an if/then expression here.

The first thing I usually do when I write a recursive function is to specify its end or stop condition. In this case, I know that I want the recursion to stop when the Int parameter is 0, so I write that condition first:

def countdown(i: Int): Unit =
    if i == 0 then
        // STOP HERE

At this point I know two things: (a) I want the recursive calls to stop here, and (b) the function needs to return Unit. Therefore, I return the Unit value, (), here:

def countdown(i: Int): Unit =
    if i == 0 then
        ()
The symbol () is the one and only instance of the Unit type, and as you’ll see, this is how the recursive calls stop: they hit this if condition, () is yielded, and the recursive calls unroll.

Now that I have the stop condition I want, I write the rest of the algorithm, i.e., the portion of the code that calls itself. Here I know from the function’s output that I want to do two things: (a) print the current value of i, and (b) make the recursive call.

An important part of the second step is that I reduce the value of i by 1 when making the call. If I fail to do this, the recursion will continue forever, so this is another key point:

def countdown(i: Int): Unit =
    if i == 0 then
        ()
    else
        println(i)         // print the current value
        countdown(i - 1)   // the recursive call

I don’t want to make light of it, but basically that’s all there is to recursion. Please verify that code in the REPL, and make sure you’re comfortable with it.
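Since I normally write recursive functions with match expressions, here’s a sketch of how countdown might look in that style. It’s an illustrative variation rather than code from this lesson: the stop condition becomes the `case 0` pattern, and every other value falls through to the case that prints and recurses.

```scala
// Match-expression variant of countdown (illustrative rewrite)
def countdown(i: Int): Unit = i match
    case 0 => ()                // stop condition: yield the Unit value
    case n =>
        println(n)              // print the current value
        countdown(n - 1)        // recurse with a smaller value
```

Calling `countdown(3)` with this version prints the same three lines, 3, 2, and 1, as the if/then version.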

A recursion example with a Scala match expression

A fun recursive algorithm to write is a “sum” algorithm for a List[Int]. A key here is to know that a Scala List is implemented just like a Lisp list, as a series of cons cells that end with a Nil value:

1 :: 2 :: 3 :: Nil
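To make the cons-cell idea concrete, here’s a small sketch showing that building a list with `::` and `Nil` gives the same value as the familiar `List(...)` factory. The value names are just for illustration. Because `::` ends in a colon it is right-associative, so the first two forms are identical.

```scala
// A List built explicitly from cons cells (::) that end with Nil
val consList: List[Int] = 1 :: 2 :: 3 :: Nil

// :: is right-associative, so the line above means exactly this
val nested: List[Int] = 1 :: (2 :: (3 :: Nil))

// The factory method builds the same structure
val factory: List[Int] = List(1, 2, 3)
```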

I believe this is the type of linked list you’re taught in college, so I’m going to assume that you’ve at least seen or heard of it before; the key point here is that ending Nil value.

Getting into the algorithm, we know that we want to write a sum function:

def sum

It takes a List[Int] parameter:

def sum(list: List[Int])

and it returns a sum of the ints:

def sum(list: List[Int]): Int

I also stated that I want to implement this with a match expression:

def sum(list: List[Int]): Int = list match

Again, I like to write the stop condition first, and the stop condition when you work on a List is always the Nil value:

def sum(list: List[Int]): Int = list match
    case Nil =>

Identity value

One reason I’m sharing this example is that it also helps to know a mathematical concept called an identity value (or element). In short, when you’re working with (a) a set of integers and (b) a sum algorithm, the identity element is the value 0. This means that when you sum a list of integers, the value 0 adds nothing to that sum.

Similarly, the value 1 is the identity value for (a) a list of integers along with (b) a product algorithm: when you multiply an integer by 1, it doesn’t change the value. This is a key concept to know when working with recursion, because the identity value is often what a recursive function yields in its stop condition.
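As a sketch of how the identity value shapes a stop condition, here’s a hypothetical `product` function (not part of this lesson) written in the same style as the sum function below. The only changes are the operator and the value yielded for `Nil`, which is 1, the identity value for multiplication.

```scala
// Hypothetical product function: same shape as sum, different identity value
def product(list: List[Int]): Int = list match
    case Nil => 1                                // identity value for multiplication
    case head :: tail => head * product(tail)    // multiply head by the product of the rest
```

With this definition, `product(List(1, 2, 3, 4))` is 24, and `product(Nil)` is 1, which follows from using the identity value as the stop condition.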

Back to sum

Getting back to the sum function, the identity element for a sum algorithm tells me that the stop condition for the recursion is to yield the value 0:

def sum(list: List[Int]): Int = list match
    case Nil => 0

Next, I need to add the recursive call. The way you do this with match expressions looks like this:

def sum(list: List[Int]): Int = list match
    case Nil => 0
    case head :: tail =>
        head + sum(tail)

Here’s how that works:

  • On the left side of the case, I break list into two parts, head and tail
  • Notice that those are separated by the :: symbol
  • When working with a List, that means head is a variable that contains a single element (an Int), and tail is a variable that contains the rest of the list (a List[Int])
  • After that, this last line of code can be read as, “Add the head element to the sum of all remaining elements”:
head + sum(tail)

Now if you paste that function into the REPL and call it with a sample List[Int], you’ll see a result like this:

scala> sum(List(1,2,3))
val res0: Int = 6
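To see how the recursive calls unroll, here’s the evaluation of `sum(List(1, 2, 3))` written out step by step as Scala expressions. The `step` names are only for illustration; each step replaces one call to sum with the body of the case it matches, and every step evaluates to the same result.

```scala
def sum(list: List[Int]): Int = list match
    case Nil => 0
    case head :: tail => head + sum(tail)

// Each step expands one recursive call
val step0 = sum(List(1, 2, 3))
val step1 = 1 + sum(List(2, 3))
val step2 = 1 + (2 + sum(List(3)))
val step3 = 1 + (2 + (3 + sum(Nil)))    // Nil hits the stop condition
val step4 = 1 + (2 + (3 + 0))           // 0 is the identity value for sum
```

Notice that none of the additions can happen until the Nil case is reached, which is why each recursive call has to wait for the one after it.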

Tail recursion

As you get into recursion, you’ll also want to learn about tail recursion, and if you’re programming on the JVM, you’ll want to make sure you know about the call stack and stack frames. I cover all of this in more than 50 pages of recursion content in my book, Functional Programming, Simplified.
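As a preview, here’s one common way to make the sum algorithm tail-recursive: carry the running total in an accumulator so the recursive call is the last thing the function does. The names `sumTailRec`, `loop`, and `acc` are hypothetical, and this accumulator approach is just one technique, not necessarily the one the book uses. The `@tailrec` annotation asks the compiler to fail if the call is not in tail position.

```scala
import scala.annotation.tailrec

// Hypothetical tail-recursive sum using an accumulator
def sumTailRec(list: List[Int]): Int =
    @tailrec
    def loop(remaining: List[Int], acc: Int): Int = remaining match
        case Nil => acc                                // stop condition: return the running total
        case head :: tail => loop(tail, acc + head)    // recursive call is in tail position
    loop(list, 0)    // start the accumulator at 0, the identity value for sum
```

Because the compiler turns the tail call into a loop, this version can sum a list with a million elements without growing the stack.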

Other recursion links

If you’re interested in more details on recursion in Scala, I’ve created a few resources about it: