Simple Scala recursion examples (recursive programming)

As I’ve been learning more about Scala and functional programming, I’ve been looking at accomplishing more tasks with recursive programming techniques. As part of my studies I put together a number of Scala recursion examples below, including:

  • Sum
  • Product
  • Max
  • Fibonacci
  • Factorial

I won’t write too much about recursion theory today, just some basic thoughts. I’ll come back here and add more when I have some good thoughts or better examples to share.

Thinking in recursion

When I’m going to write a recursive method, I usually think about it like this:

  • I know I want to do something with a collection of data elements.
  • Therefore, my function will usually take this collection as an argument.
  • Within the function I usually have two branches:
    • In one case, when I’ve reached the end of the collection, I do some “ending” operation. For instance, in the Sum example below, when I reach the end of a List (the empty list, Nil), I return 0 and let the recursive method calls unroll.
    • In the second case, when the function is not yet at the end of the list, I write the code for my main algorithm: it operates on the current element in the collection (the ‘head’ element), and then recursively calls the function, passing it the remainder of the collection (the ‘tail’).
  • When the recursive calls unroll, the function returns whatever value I’m calculating. For instance, the sum, product, and max functions that follow return an Int. The first Fibonacci example prints its results as it goes along, so it doesn’t return a useful value.
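To make the two-branch pattern concrete, here’s a minimal sketch of a hypothetical `length` function that counts the elements in a List. It’s my own illustration rather than one of the examples that follow. The Nil case is the “ending” operation, and the `head :: tail` case handles one element and recurses on the rest:

```scala
object LengthExample {

  // Counts the elements in a List using the two-branch recursive pattern.
  // Like the first `sum` example below, this is not tail-recursive, so a
  // very large list can cause a StackOverflowError.
  def length[A](xs: List[A]): Int = xs match {
    case Nil       => 0                 // end of the list: the "ending" operation
    case _ :: tail => 1 + length(tail)  // count the head, then recurse on the tail
  }

  def main(args: Array[String]): Unit = {
    println(length(List("a", "b", "c")))  // prints 3
    println(length(Nil))                  // prints 0
  }
}
```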

In some cases it also helps to define an “accumulator” function inside your main function: a nested function that carries the intermediate result along as an extra parameter. I show this in the examples that follow, and I’ll describe it more at some point in the future.
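An accumulator doesn’t have to be a number. As a sketch, here’s a hypothetical `reverse` helper (the standard List class already has its own `reverse` method; this is just for illustration) where the accumulator is a List that collects elements in reverse order. Because the recursive call is the last thing the inner function does, the `@tailrec` annotation compiles, and the function runs in constant stack space:

```scala
import scala.annotation.tailrec

object ReverseExample {

  // Reverses a List using an inner accumulator function.
  def reverse[A](xs: List[A]): List[A] = {
    @tailrec
    def loop(remaining: List[A], acc: List[A]): List[A] = remaining match {
      case Nil          => acc                     // nothing left: acc holds the result
      case head :: tail => loop(tail, head :: acc) // move head onto the accumulator
    }
    loop(xs, Nil)
  }

  def main(args: Array[String]): Unit = {
    println(reverse(List(1, 2, 3)))  // prints List(3, 2, 1)
  }
}
```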

A recursive ‘sum’ function

The following code shows three ways to calculate the sum of a List[Int] recursively. The first approach is simple, but I don’t think it’s practical: because it isn’t tail-recursive, each call adds a frame to the stack, so it results in a StackOverflowError when the list is large.

The second approach shows how to fix the first approach by using a tail-recursive algorithm. This solution uses the “accumulator” I mentioned above.

The third approach shows how to use an if/else construct instead of a match expression. It’s taken from the URL shown.

With that introduction, here’s the code:

package recursion

import scala.annotation.tailrec

/**
 * Different ways to calculate the sum of a list using 
 * recursive Scala methods.
 */
object Sum extends App {

  val list = List.range(1, 100)
  println(sum(list))
  println(sum2(list))
  println(sum3(list))

  // (1) not tail-recursive: each element adds a stack frame, so this
  // yields a "java.lang.StackOverflowError" with large lists
  def sum(ints: List[Int]): Int = ints match { 
    case Nil => 0
    case x :: tail => x + sum(tail)
  }

  // (2) tail-recursive solution
  def sum2(ints: List[Int]): Int = {
    @tailrec
    def sumAccumulator(ints: List[Int], accum: Int): Int = {
      ints match {
        case Nil => accum
        case x :: tail => sumAccumulator(tail, accum + x)
      }
    }
    sumAccumulator(ints, 0)
  }
  
  // (3) good descriptions of recursion here:
  // stackoverflow.com/questions/12496959/summing-values-in-a-list
  // this example is from that page; like (1), it is not tail-recursive
  def sum3(xs: List[Int]): Int = {
    if (xs.isEmpty) 0
    else xs.head + sum3(xs.tail)
  }
  
}

I don’t want to stray too far from the point of this article, but while I’m talking about “sum” algorithms, another way you can calculate the sum of a List[Int] in Scala is to use the reduceLeft method on the List:

def sumWithReduce(ints: List[Int]): Int = {
  // note: reduceLeft throws an UnsupportedOperationException on an empty list
  ints.reduceLeft(_ + _)
}

(That’s all I’ll say about reduceLeft today.)
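A related method, foldLeft, takes an explicit starting value, much like the `accum` parameter in sum2, so it returns that value for an empty list instead of throwing an exception. Here’s a minimal sketch; the name `sumWithFold` is my own:

```scala
object SumWithFold {

  // foldLeft starts with 0 and combines each element with the running total,
  // much like the `accum` parameter in the tail-recursive sum2 example.
  // For an empty list it simply returns the starting value, 0.
  def sumWithFold(ints: List[Int]): Int =
    ints.foldLeft(0)(_ + _)

  def main(args: Array[String]): Unit = {
    println(sumWithFold(List.range(1, 100)))  // prints 4950
    println(sumWithFold(Nil))                 // prints 0
  }
}
```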

Calculating the “product” of a List[Int] recursively

Calculating the product of a List[Int] is very similar to calculating the sum; you just multiply the values inside the function, and return 1 in the Nil case. Therefore I’ll just show the following code without discussing it:

package recursion

import scala.annotation.tailrec

/**
 * Different ways to calculate the product of a List[Int] recursively.
 */
object Product extends App {

  val list = List(1, 2, 3, 4)
  println(product(list))
  println(product2(list))

  // (1) basic recursion; yields a "java.lang.StackOverflowError" with large lists
  def product(ints: List[Int]): Int = ints match {
    case Nil => 1
    case x :: tail => x * product(tail)
  }

  // (2) tail-recursive solution
  def product2(ints: List[Int]): Int = {
    @tailrec
    def productAccumulator(ints: List[Int], accum: Int): Int = {
      ints match {
        case Nil => accum
        case x :: tail => productAccumulator(tail, accum * x)
      }
    }
    productAccumulator(ints, 1)
  }

}
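Since the tail-recursive sum and product functions differ only in their starting value and the operation they apply, one way to see that similarity is to pull those two things out as parameters. This is a sketch with a hypothetical `combineAll` function (my name, not a standard library method); it’s essentially a hand-written version of foldLeft:

```scala
import scala.annotation.tailrec

object CombineAll {

  // Tail-recursively combines all elements of `ints`, starting from `start`
  // and applying `op` to the running result and each element in turn.
  def combineAll(ints: List[Int], start: Int)(op: (Int, Int) => Int): Int = {
    @tailrec
    def loop(remaining: List[Int], accum: Int): Int = remaining match {
      case Nil       => accum
      case x :: tail => loop(tail, op(accum, x))
    }
    loop(ints, start)
  }

  // sum starts at 0 and adds; product starts at 1 and multiplies
  def sum(ints: List[Int]): Int     = combineAll(ints, 0)(_ + _)
  def product(ints: List[Int]): Int = combineAll(ints, 1)(_ * _)

  def main(args: Array[String]): Unit = {
    println(sum(List(1, 2, 3, 4)))      // prints 10
    println(product(List(1, 2, 3, 4)))  // prints 24
  }
}
```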

Calculating the “max” of a List[Int] recursively

Calculating the “max” of a List[Int] recursively is a little different from calculating the sum or product. In this algorithm you need to keep track of the highest value found as you go along, so I jump right into using an accumulator function inside the outer function.

I show two approaches in the source code below, the first using a match expression and the second using an if/else expression:

package recursion

import scala.annotation.tailrec

object Max extends App {

  val list = List.range(1, 100000)
  println(max(list))
  println(max2(list))

  // 1 - using `match`
  def max(ints: List[Int]): Int = {
    @tailrec
    def maxAccum(ints: List[Int], theMax: Int): Int = {
      ints match {
        case Nil => theMax
        case x :: tail =>
          val newMax = if (x > theMax) x else theMax
          maxAccum(tail, newMax)
      }
    }
    // start with the first element rather than 0, so a list of negative
    // numbers works; an empty list has no max, so reject it
    require(ints.nonEmpty, "max of an empty list is undefined")
    maxAccum(ints.tail, ints.head)
  }

  // 2 - using if/else (no `return` needed; the if/else is the result)
  def max2(ints: List[Int]): Int = {
    @tailrec
    def maxAccum2(ints: List[Int], theMax: Int): Int = {
      if (ints.isEmpty) {
        theMax
      } else {
        val newMax = if (ints.head > theMax) ints.head else theMax
        maxAccum2(ints.tail, newMax)
      }
    }
    require(ints.nonEmpty, "max of an empty list is undefined")
    maxAccum2(ints.tail, ints.head)
  }

}
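Another way to handle an empty list, which has no maximum, is to return an Option[Int]. Here’s a sketch using a hypothetical `maxOpt` function; it starts the accumulator at the first element so that lists of negative numbers work too. (If you’re on Scala 2.13 or later, the standard library’s `maxOption` method on List does the same job.)

```scala
import scala.annotation.tailrec

object MaxOption {

  // Returns None for an empty list, otherwise Some(largest element).
  // The accumulator starts at the first element, so lists containing only
  // negative numbers are handled correctly.
  def maxOpt(ints: List[Int]): Option[Int] = {
    @tailrec
    def maxAccum(remaining: List[Int], theMax: Int): Int = remaining match {
      case Nil       => theMax
      case x :: tail => maxAccum(tail, if (x > theMax) x else theMax)
    }
    ints match {
      case Nil          => None
      case head :: tail => Some(maxAccum(tail, head))
    }
  }

  def main(args: Array[String]): Unit = {
    println(maxOpt(List(3, -7, 12, 5)))  // prints Some(12)
    println(maxOpt(List(-4, -2, -9)))    // prints Some(-2)
    println(maxOpt(Nil))                 // prints None
  }
}
```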

A Scala Fibonacci recursion example

The code below shows one way to calculate a Fibonacci sequence recursively using Scala:

package recursion

/**
 * Calculating a Fibonacci sequence recursively using Scala.
 */
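object Fibonacci extends App {

  fib(1, 2)

  // prints each new number in the sequence, stopping after the first
  // value greater than 1,000,000 (instead of killing the JVM with System.exit)
  def fib(prevPrev: Int, prev: Int): Unit = {
    val next = prevPrev + prev
    println(next)
    if (next <= 1000000) fib(prev, next)
  }

}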

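There are other ways to calculate a Fibonacci sequence, but this one works because each call receives the two previous numbers as arguments, prints their sum, and then passes the last two numbers on to the next call. (Because it starts with 1 and 2, the first number it prints is 3.)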
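If you’d rather get the numbers back than print them, here’s a sketch of a tail-recursive variation that collects the sequence in a List. The names `fibsUpTo` and `limit` are my own. The accumulator is built in reverse (prepending to a List is cheap) and reversed once at the end:

```scala
import scala.annotation.tailrec

object FibonacciList {

  // Returns the Fibonacci numbers 1, 1, 2, 3, 5, ... that are <= limit.
  // The limit is capped so that prevPrev + prev can never overflow an Int.
  def fibsUpTo(limit: Int): List[Int] = {
    require(limit <= Int.MaxValue / 2, "limit too large for Int arithmetic")
    @tailrec
    def loop(prevPrev: Int, prev: Int, acc: List[Int]): List[Int] =
      if (prev > limit) acc.reverse                  // done: restore ascending order
      else loop(prev, prevPrev + prev, prev :: acc)  // keep `prev`, then step forward
    loop(0, 1, Nil)
  }

  def main(args: Array[String]): Unit = {
    println(fibsUpTo(20))  // prints List(1, 1, 2, 3, 5, 8, 13)
  }
}
```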

A tail-recursive Fibonacci recursion example

Here’s another example of how to write a Fibonacci method, this time using a tail-recursive algorithm:

import scala.annotation.tailrec

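/**
 * The `fibHelper` code comes from this URL: rosettacode.org/wiki/Fibonacci_sequence#Scala
 */
object FibonacciTailRecursive extends App {

    println(fib(9))  // prints 34

    // returns the x-th Fibonacci number, where fib(0) = 0 and fib(1) = 1
    def fib(x: Int): BigInt = {
        require(x >= 0, "x must be non-negative")
        // `prev` and `next` hold two consecutive Fibonacci numbers; each call
        // shifts them one step forward and decrements the counter `x`
        @tailrec def fibHelper(x: Int, prev: BigInt = 0, next: BigInt = 1): BigInt = x match {
            case 0 => prev
            case 1 => next
            case _ => fibHelper(x - 1, next, (next + prev))
        }
        fibHelper(x)
    }

}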

As I’ve learned more about functional programming in Scala, I’ve come to prefer approaches like this. One nice thing about the fib method is that it defines another method, fibHelper, inside of it. Besides being a neat feature, this nesting limits the scope of fibHelper, so it can only be called from within fib. The default parameter values (prev = 0 and next = 1) also let fib start the recursion with a simple fibHelper(x) call.

Recursive factorial algorithms

Finally, without much discussion, the following Scala code shows two different recursive factorial algorithms, the second of which is tail-recursive:

package recursion

import scala.annotation.tailrec

object Factorial extends App {

    println(factorial(5))
    println(factorial2(5))

    // 1 - basic recursive factorial method
    // (a negative n never reaches the base case, so it overflows the stack)
    def factorial(n: Int): Int = {
        if (n == 0) 1
        else n * factorial(n - 1)
    }

    // 2 - tail-recursive factorial method
    def factorial2(n: Long): Long = {
        @tailrec
        def factorialAccumulator(acc: Long, n: Long): Long = {
            if (n == 0) acc
            else factorialAccumulator(n * acc, n - 1)
        }
        factorialAccumulator(1, n)
    }

}
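Both factorial methods overflow quickly: an Int can hold at most 12!, and a Long at most 20!. As a sketch, switching the accumulator to BigInt (part of the Scala standard library) avoids the overflow; the name `factorialBig` is hypothetical:

```scala
import scala.annotation.tailrec

object FactorialBig {

  // Tail-recursive factorial with a BigInt accumulator, so results larger
  // than Long.MaxValue are computed exactly. Negative input is rejected.
  def factorialBig(n: Int): BigInt = {
    require(n >= 0, "n must be non-negative")
    @tailrec
    def loop(acc: BigInt, k: Int): BigInt =
      if (k == 0) acc
      else loop(acc * k, k - 1)
    loop(BigInt(1), n)
  }

  def main(args: Array[String]): Unit = {
    println(factorialBig(5))   // prints 120
    println(factorialBig(25))  // prints 15511210043330985984000000
  }
}
```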

Summary

I hope these examples of recursive programming techniques in Scala have been helpful. As mentioned, I’ll try to add more when I can.