JIT: Recursion (Scala 3 Video)
The other programming technique you’ll need some awareness of before the next lesson is recursion. Recursion can be really fun, and it’s an important FP technique, which is why I spend more than 50 pages explaining it in Functional Programming, Simplified.
Update: You can now learn about recursion for free in my PDF, Learning Recursion.
Writing a recursive function in Scala
In short, a recursive function is a function that calls itself. To demonstrate this, let’s create a recursive function that “counts down” to 0. The idea is that you give it some integer, and it prints each number from that integer down to 1, stopping when it reaches 0, as shown in the REPL:
```
scala> countdown(3)
3
2
1
```
Right away you can see that this function takes an Int parameter:
```scala
def countdown(i: Int)
```
You can also see that it prints its result and doesn’t return a meaningful value, so its return type is Unit (which is like void in Java and other languages):
```scala
def countdown(i: Int): Unit =
```
I implement most recursive functions using match expressions, but to keep this familiar for OOP developers, I’ll use an if/then expression here.
The first thing I usually do when I write a recursive function is to specify its end or stop condition. In this case, I know that I want the recursion to stop when the Int parameter is 0, so I write that condition first:
```scala
def countdown(i: Int): Unit =
  if i == 0 then
    // STOP HERE
```
At this point I know two things: (a) I want the recursive calls to stop here, and (b) the function needs to return Unit. Therefore, I return the Unit value, (), here:
```scala
def countdown(i: Int): Unit =
  if i == 0 then
    ()
```
The value () is the one and only instance of the Unit type, and as you’ll see, this is how the recursive calls stop: they reach this if branch, () is yielded, and the recursive calls unroll.
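To make the idea of () being yielded a little more concrete, here’s a small standalone sketch for the Scala 3 REPL or a scala-cli script. The names u, x, and r are just for illustration. It shows that () is an ordinary value of type Unit, and that an if/then/else whose branches are both Unit yields it:

```scala
// () is an ordinary value, and it's the only value of type Unit
val u: Unit = ()
println(u)   // prints: ()

// When both branches have type Unit, the whole if/then/else is a Unit expression.
// Because x is 0, the first branch runs and the expression yields ().
val x = 0
val r: Unit = if x == 0 then () else println(x)
println(r)   // prints: ()
```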
Now that I have the stop condition I want, I write the rest of the algorithm, i.e., the portion of the code that calls itself. Here I know from the function’s output that I want to do two things: (a) print the current value of i, and (b) make the recursive call.
An important part of the second step is that I reduce the value of i by 1 when making the call. If I fail to do this, the recursion never reaches the stop condition and continues forever, so this is another key point:
```scala
def countdown(i: Int): Unit =
  if i == 0 then
    ()
  else
    println(i)
    countdown(i - 1) // the recursive call
```
I don’t want to make light of it, but basically that’s all there is to recursion. Please verify that code in the REPL, and make sure you’re comfortable with it.
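To watch the recursive calls unroll, here’s a hypothetical tracing variant of countdown. It’s a sketch, not part of the lesson’s code, and the name countdownTrace and its messages are my own choices. It prints one line on the way into each call and another line only after that call returns:

```scala
// Hypothetical tracing variant of countdown (not from the original lesson).
// "enter" lines print before each recursive call; "leave" lines print only
// after the deeper call has returned, so they appear in reverse order.
def countdownTrace(i: Int): Unit =
  if i == 0 then
    println("stop: yielding ()")
  else
    println(s"enter $i")
    countdownTrace(i - 1)   // the recursive call
    println(s"leave $i")    // runs as the calls unroll

countdownTrace(3)
```

Running countdownTrace(3) prints enter 3, enter 2, enter 1, then the stop line, then leave 1, leave 2, and leave 3: the deepest call finishes first, and each caller resumes after it.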
A recursion example with a Scala match expression
A fun recursive algorithm to write is a “sum” algorithm for a List[Int]. A key here is to know that a Scala List is implemented just like a Lisp list, as a series of cons cells that end with a Nil value:
```scala
1 :: 2 :: 3 :: Nil
```
I believe this is the type of linked list you’re taught in college, so I’ll assume you’ve at least seen or heard of it before. The key is that ending Nil value.
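A quick REPL check of that structure, as a sketch (the names a and b are just for illustration): building the list with :: and Nil gives the same value as List(1, 2, 3), and head and tail split off the first cons cell:

```scala
// Build the same list two ways: with cons cells (::) ending in Nil, and with List(...)
val a = 1 :: 2 :: 3 :: Nil
val b = List(1, 2, 3)

println(a == b)       // true: both are the same linked list
println(a.head)       // 1: the value in the first cons cell
println(a.tail)       // List(2, 3): everything after the first cell
println(Nil.isEmpty)  // true: Nil is the empty list that ends every List
```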
Getting into the algorithm, we know that we want to write a sum function:
```scala
def sum
```
It takes a List[Int] parameter:
```scala
def sum(list: List[Int])
```
and it returns the sum of those integers:
```scala
def sum(list: List[Int]): Int
```
I also stated that I want to implement this with a match expression:
```scala
def sum(list: List[Int]): Int = list match
```
Again, I like to write the stop condition first, and the stop condition when you work on a List is always the Nil value:
```scala
def sum(list: List[Int]): Int = list match
  case Nil =>
```
Identity value
One reason I’m sharing this example is that it also helps to know a mathematical concept called an identity value (or element). In short, when you’re working with (a) a set of integers and (b) a sum algorithm, the identity element is the value 0. This means that when you sum a list of integers, the value 0 adds nothing to that sum.
Similarly, the value 1 is the identity value for (a) a list of integers along with (b) a product algorithm: when you multiply an integer by 1, it doesn’t modify the value. This is a key concept to know when working with recursion.
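Here’s a small sketch of both identity values in action. It uses the standard library’s sum, product, and foldLeft methods on List purely for illustration, not as part of the recursive algorithm; the name xs is arbitrary. Notice that on an empty list, sum and product return exactly these identity values:

```scala
val xs = List(4, 5, 6)

// 0 is the identity for addition, and 1 is the identity for multiplication
println(xs.sum + 0 == xs.sum)           // true
println(xs.product * 1 == xs.product)   // true

// For an empty list, the built-in methods fall back to the identity values
println(List.empty[Int].sum)            // 0
println(List.empty[Int].product)        // 1

// foldLeft makes the starting (identity) value explicit
println(xs.foldLeft(0)(_ + _))          // 15
println(xs.foldLeft(1)(_ * _))          // 120
```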
Back to sum
Getting back to the sum function, the identity element for a sum algorithm tells me that the stop condition for the recursion is to yield the value 0:
```scala
def sum(list: List[Int]): Int = list match
  case Nil => 0
```
Next, I need to add the recursive call. The way you do this with match expressions looks like this:
```scala
def sum(list: List[Int]): Int = list match
  case Nil =>
    0
  case head :: tail =>
    head + sum(tail)
```
Here’s how that works:
- On the left side of the second case, I break list into two elements, head and tail
- Notice that those are separated by the :: (cons) symbol
- When working with a List, that means head is a variable that contains a single element (an Int), and tail is a variable that contains the rest of the list (a List[Int])
- After that, this last line of code can be read as, “Add the head element to the sum of all remaining elements”:

```scala
head + sum(tail)
```
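One way to see why this works is to expand a call by hand. The sketch below repeats the sum function from above so that it runs on its own, then writes out each step of evaluating sum(List(1, 2, 3)). The step0 through step4 names are only for illustration; every step evaluates to the same value:

```scala
def sum(list: List[Int]): Int = list match
  case Nil =>
    0
  case head :: tail =>
    head + sum(tail)

// Each step replaces one call to sum with what its matching case yields
val step0 = sum(List(1, 2, 3))
val step1 = 1 + sum(List(2, 3))           // head = 1, tail = List(2, 3)
val step2 = 1 + (2 + sum(List(3)))        // head = 2, tail = List(3)
val step3 = 1 + (2 + (3 + sum(Nil)))      // head = 3, tail = Nil
val step4 = 1 + (2 + (3 + 0))             // case Nil yields the identity value 0

println(List(step0, step1, step2, step3, step4))  // List(6, 6, 6, 6, 6)
```

No addition can finish until sum(Nil) yields 0; the additions then complete from the innermost parentheses outward as the calls unroll.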
Now if you paste that function into the REPL and call it with a sample List[Int], you’ll see a result like this:
```
scala> sum(List(1,2,3))
val res0: Int = 6
```
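The same pattern works for a product algorithm, where the stop condition yields the product identity, 1, instead of 0. This is a sketch by analogy with sum, and the name product is my own choice:

```scala
// Same shape as sum, but the Nil case yields 1 (the identity for multiplication)
def product(list: List[Int]): Int = list match
  case Nil =>
    1
  case head :: tail =>
    head * product(tail)

println(product(List(1, 2, 3, 4)))   // 24
println(product(Nil))                // 1
```

If the Nil case yielded 0 instead, every product would collapse to 0, because the innermost multiplication would always be by 0.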
Tail recursion
As you get into recursion, you’ll also want to learn about tail recursion, and if you’re programming on the JVM, you’ll want to make sure you understand the call stack and stack frames. I cover all of this in more than 50 pages of recursion content in my book, Functional Programming, Simplified.
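As a preview, and only as a sketch (the names sumTailRec and loop are my own), here’s one common way to make sum tail-recursive: carry the running total in an accumulator parameter that starts at the identity value 0, so the recursive call is the last thing each case does. The @tailrec annotation from scala.annotation asks the compiler to verify this, and compilation fails if the call isn’t in tail position:

```scala
import scala.annotation.tailrec

def sumTailRec(list: List[Int]): Int =
  // loop carries the running total in acc, so nothing is left to do after the
  // recursive call; the compiler turns it into a loop, so the stack doesn't grow
  @tailrec
  def loop(remaining: List[Int], acc: Int): Int = remaining match
    case Nil =>
      acc                               // the running total is the answer
    case head :: tail =>
      loop(tail, acc + head)            // tail call: the last action in this case

  loop(list, 0)                         // start from the identity value 0

println(sumTailRec(List(1, 2, 3)))           // 6
println(sumTailRec(List.fill(100000)(1)))    // 100000
```

The second call handles a 100,000-element list; a non-tail-recursive version like sum can overflow the JVM stack on lists that long, depending on the configured stack size.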
Other recursion links
If you’re interested in more details on recursion in Scala, I’ve created a few resources about it:
Update: All of my new videos are now on LearnScala.dev.