How to Write a ‘map’ Function in Scala

“He lunged for the maps. I grabbed the chair and hit him with it. He went down. I hit him again to make sure he stayed that way, stepped over him, and picked up the maps.”

Ilona Andrews, Magic Burns

(This is a chapter from my book, Functional Programming, Simplified.)

In the previous lesson I showed how to write higher-order functions (HOFs). In this lesson you’ll use that knowledge to write a Scala map function that can work with a List.

Writing a Scala 'map' function

Imagine a world in which you know the concept of “mapping,” but sadly a map method isn’t built into Scala’s List class. Further imagine that you’re not worried about all lists; you just want a map function for a List[Int].

Knowing that life is better with map, you sit down to write your own map method.

First steps

As I got better at FP, I came to learn that my first actions in writing most functions are:

  1. Accurately state the problem as a sentence
  2. Sketch the function signature

I’ll follow that approach to solve this problem.

Accurately state the problem

For the first step, I’ll state the problem like this:

I want to write a map function that can be used to apply other functions to each element in a List[Int] that it’s given.

Sketch the function signature

My second step is to sketch a function signature that matches that statement. A blank canvas is always hard to look at, so I start with the obvious; I want a map function:

def map

Looking back at the problem statement, what do I know? Well, first, I know that map is going to take a function as an input parameter, and it’s also going to take a List[Int]. Without thinking too much about the input parameters just yet, I can now sketch this:

def map(f: (?) => ?, xs: List[Int]): ???

Knowing how map works, I know that it should return a List that contains the same number of elements as the input List. For the moment, the important point is that map will return a List of some sort:

def map(f: (?) => ?, xs: List[Int]): List...
                                     ----

Given how map works — it applies a function to every element in the input list — the type of the output List can be anything: a List[Double], List[Float], List[Foo], etc. This tells me that the List that map returns needs to be a generic type, so I add that at the end of the function declaration:

def map(f: (?) => ?, xs: List[Int]): List[A]
                                     -------

Because of Scala’s syntax, I also need to declare that generic type as a type parameter, in square brackets right after the function name:

def map[A](f: (?) => ?, xs: List[Int]): List[A]
       ---

Going through that thought process tells me everything I need to know about the signature for the function input parameter f:

  • Because f’s input parameter will come from the List[Int], the parameter type must be Int
  • Because the overall map function returns a List of the generic type A, f must also return the generic type A

The first statement lets me make this change to the definition of f:

def map[A](f: (Int) => ?, xs: List[Int]): List[A]
               ---

and the second statement lets me make this change:

def map[A](f: (Int) => A, xs: List[Int]): List[A]
                       -

When a function input parameter (FIP) like f takes only one parameter, the parentheses around that parameter’s type are optional, so if you prefer that syntax, the finished function signature looks like this:

def map[A](f: Int => A, xs: List[Int]): List[A]
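To confirm that the parentheses really are optional, here is a small sketch showing that both spellings describe the same function type, so a value declared with one can be assigned to a variable declared with the other. The names withParens and withoutParens are made up for this illustration.

```scala
// A one-parameter function type written with parentheses around Int
val withParens: (Int) => String = i => s"n = $i"

// The same type written without parentheses; this assignment compiles
// because (Int) => String and Int => String are identical types
val withoutParens: Int => String = withParens
```
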

Cool. That seems right. Now let’s work on the function body.

The map function body

A map function works on every element in a list, and because I haven’t covered recursion yet, this means that we’re going to need a for loop to loop over every element in the input list.

Because I know that map returns a list that has one element for each element in the input list, I further know that this loop is going to be a for/yield loop without any filters:

def map[A](f: (Int) => A, xs: List[Int]): List[A] = {
    for {
        x <- xs
    } yield ???
}

The only question now is, what exactly should the loop yield?

(I’ll pause for a moment here to let you think about that.)

The answer is that the for loop should yield the result of applying the input function f to the current element in the loop. Therefore, I can finish the yield expression like this:

def map[A](f: (Int) => A, xs: List[Int]): List[A] = {
    for {
        x <- xs
    } yield f(x)   //<-- apply 'f' to each element 'x'
}

And that is the solution for the problem that was stated.

You can use the REPL to confirm that this solution works as desired. First, paste the map function into the REPL. Then create a list of integers:

scala> val nums = List(1,2,3)
nums: List[Int] = List(1, 2, 3)

Then write a function that matches the signature map expects:

scala> def double(i: Int): Int = i * 2
double: (i: Int)Int

Then you can use map to apply double to each element in nums:

scala> map(double, nums)
res0: List[Int] = List(2, 4, 6)

The map function works.
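It can also help to check the generic return type A with functions whose result isn’t an Int. This sketch repeats the List[Int] map from this lesson so that it’s self-contained; describe and isEven are hypothetical helper functions added only for this test.

```scala
// The List[Int]-only map from this lesson
def map[A](f: (Int) => A, xs: List[Int]): List[A] = {
    for {
        x <- xs
    } yield f(x)
}

// Hypothetical helpers that return different types
def double(i: Int): Int = i * 2
def describe(i: Int): String = s"number $i"
def isEven(i: Int): Boolean = i % 2 == 0

val nums = List(1, 2, 3)
val doubled   = map(double, nums)     // A is inferred as Int
val described = map(describe, nums)   // A is inferred as String
val evens     = map(isEven, nums)     // A is inferred as Boolean
```

In each call, Scala infers A from the return type of the function you pass in, which is why the same map produces a List[Int], a List[String], and a List[Boolean].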

Bonus: Make it generic

I started off by making map work only for a List[Int], but at this point it’s easy to make it work for any List. This is because there’s nothing inside the map function body that depends on the given List being a List[Int]:

for {
    x <- xs
} yield f(x)

That’s as “generic” as code gets; there are no Int references in there. Therefore, you can make map work with generic types by replacing each Int reference in the function signature with a generic type. Because the input element type appears first in the signature, and by convention the first type parameter is named A, I’ll first rename the existing A’s to B’s:

def map[B](f: (Int) => B, xs: List[Int]): List[B] = ...
        _              _                       _

Then I replace the Int references with A, and add A to the type parameter list in the square brackets, resulting in this signature:

def map[A,B](f: (A) => B, xs: List[A]): List[B] = ...
        _        _                 _

If you want to take this even further, there’s also nothing in this code that depends on the input sequence being a List. Because map simply works its way from the first element to the last, it doesn’t matter whether the input is an IndexedSeq (such as a Vector) or a LinearSeq (such as a List), so you can use their parent trait, Seq, instead of List:

def map[A,B](f: (A) => B, xs: Seq[A]): Seq[B] = ...
                              ---      ---

With this new signature, the complete, generic map function looks like this:

def map[A,B](f: (A) => B, xs: Seq[A]): Seq[B] = {
    for {
        x <- xs
    } yield f(x)
}
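As a quick check of the generic version, this sketch calls it with a List (a LinearSeq) and a Vector (an IndexedSeq), using functions that change the element type. The lambda parameters are annotated explicitly because, at least in Scala 2, A can’t be inferred from xs while the compiler is typing a lambda in the same parameter list.

```scala
// The generic map from this lesson: any Seq, any input and output element types
def map[A, B](f: (A) => B, xs: Seq[A]): Seq[B] = {
    for {
        x <- xs
    } yield f(x)
}

// A List is a LinearSeq; here A = String and B = Int
val lengths = map((s: String) => s.length, List("a", "bb", "ccc"))

// A Vector is an IndexedSeq; here A = Double and B = Double
val halves = map((d: Double) => d / 2, Vector(1.0, 3.0))
```
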

I hope you enjoyed that process. It’s a good example of how I design functions these days, starting with the signature first, and then implementing the function body.

Exercise: Write a 'filter' function

Now that you’ve seen how to write a map function, I encourage you to take the time to write a filter function. Because filter doesn’t return a sequence that’s the same size as the input sequence, its algorithm will be a little different, but it still needs to return a sequence in the end.
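If you’d like something to compare your answer against, here is one possible sketch, not the only approach. It follows the same for/yield pattern as map but adds an if guard, which is what lets the result contain fewer elements than the input. The parameter name p (for “predicate”) is my own choice for this illustration.

```scala
// One possible generic filter: keep only the elements for which p returns true
def filter[A](p: (A) => Boolean, xs: Seq[A]): Seq[A] = {
    for {
        x <- xs
        if p(x)   // the guard drops elements, so the result can be shorter than xs
    } yield x
}

val nums = List(1, 2, 3, 4, 5)
val evens = filter((i: Int) => i % 2 == 0, nums)
```
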

What’s next

While this lesson provided a detailed example of how to write a function that takes other functions as input parameters, the next lesson will show how to write functions that take “blocks of code” as input parameters. That technique and syntax are similar to what I just showed, but the “use case” for this other technique, known as “by-name parameters,” is a little different.

After that lesson, I’ll demonstrate how to combine these techniques with a Scala feature that lets a function have multiple input parameter groups.