How to Enable the Use of Multiple Generators in a ‘for’ Expression

This is a page from my book, Learning Functional Programming in Scala.

One cool thing about for expressions is that you can use multiple generators inside them. This lets you analyze relationships between different collections of data.

For instance, suppose you have some data like this:

case class Person(name: String)

val myFriends = Sequence(
    Person("Adam"),
    Person("David"), 
    Person("Frank")
)

val adamsFriends = Sequence(
    Person("Nick"), 
    Person("David"), 
    Person("Frank")
)

If I want to find out which friends of mine are also friends of Adam, I can write a for expression like this:

val mutualFriends = for {
    myFriend    <- myFriends     // generator
    adamsFriend <- adamsFriends  // generator
    if (myFriend.name == adamsFriend.name)
} yield myFriend
mutualFriends.foreach(println)

Notice how I use two Sequence instances as generators in that for expression.

Sadly, the compiler tells us that this code won’t work, but happily it again tells us why:

<console>:17: error: value flatMap is not a member of Sequence[Person]
           myFriend <- myFriends
                       ^

Since you’re used to reading these error messages now, you know the compiler is telling us that we need to implement a flatMap method in the Sequence class for this code to work.
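The compiler asks for flatMap because a for expression is rewritten into method calls before it's compiled. The sketch below uses Scala's built-in List in place of Sequence, an assumption made so the snippet runs on its own. It puts the lesson's for expression next to a hand-written approximation of that rewrite. In the rewrite, the first generator becomes a flatMap call, the if guard becomes withFilter, and yield becomes map.

```scala
// Sketch: the built-in List stands in for the lesson's custom Sequence class
case class Person(name: String)

val myFriends    = List(Person("Adam"), Person("David"), Person("Frank"))
val adamsFriends = List(Person("Nick"), Person("David"), Person("Frank"))

// The for expression as written in the lesson
val viaFor = for {
  myFriend    <- myFriends
  adamsFriend <- adamsFriends
  if myFriend.name == adamsFriend.name
} yield myFriend

// A hand-written approximation of the compiler's rewrite:
// outer generator -> flatMap, guard -> withFilter, yield -> map
val viaMethods = myFriends.flatMap { myFriend =>
  adamsFriends
    .withFilter(adamsFriend => myFriend.name == adamsFriend.name)
    .map(_ => myFriend)
}

println(viaFor)      // List(Person(David), Person(Frank))
println(viaMethods)  // List(Person(David), Person(Frank))
```

Because the outer generator turns into a flatMap call, any class used as that generator needs a flatMap method.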

flatMap, a curious creature

The flatMap method is an interesting creature, and experienced functional programmers seem to use it a lot.

As a bit of background, recall that map’s signature looks like this:

def map[B](f: A => B): Sequence[B]

As shown, map takes a function that transforms a type A to a type B, applies it to every element in the Sequence, and returns the results as a Sequence[B].

flatMap’s signature is similar to map:

def flatMap[B](f: A => Sequence[B]): Sequence[B]

As this shows, flatMap is similar to map, but it’s also different. The function that flatMap takes transforms a type A into a Sequence of type B (a Sequence[B]), yet when flatMap finishes it still returns a single Sequence[B] rather than a Sequence[Sequence[B]]. The type signatures tell us that the difference between map and flatMap is the type of function they take as an input parameter.
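A small sketch makes the difference concrete. It uses List in place of Sequence, and the function name withTenfold is purely illustrative. The same A => List[B] function is passed to both map and flatMap, and only the shape of the result changes.

```scala
// Sketch: List stands in for Sequence
val nums = List(1, 2, 3)

// A function of the shape A => Sequence[B], here Int => List[Int]
val withTenfold: Int => List[Int] = n => List(n, n * 10)

// map keeps every inner List, so the result is nested
val mapped: List[List[Int]] = nums.map(withTenfold)
// List(List(1, 10), List(2, 20), List(3, 30))

// flatMap takes the same function but returns one flat List
val flatMapped: List[Int] = nums.flatMap(withTenfold)
// List(1, 10, 2, 20, 3, 30)

// Viewing flatMap as "map, then flatten" gives the same answer
assert(flatMapped == mapped.flatten)
```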

flatMap background

If you come to Scala from a background like Java, it soon becomes apparent that map is very useful, and you’ll use it all the time, but it can be hard to find a use for flatMap.

As I wrote in the Scala Cookbook, I like to think of flatMap as “map flat,” because on collection classes it works much like a) calling map and then b) calling flatten. As an example of what I mean, these lines of code show the difference between calling map and flatMap on a Seq[String]:

scala> val fruits = Seq("apple", "banana", "orange")
fruits: Seq[java.lang.String] = List(apple, banana, orange)

scala> fruits.map(_.toUpperCase)
res0: Seq[java.lang.String] = List(APPLE, BANANA, ORANGE)

scala> fruits.flatMap(_.toUpperCase)
res1: Seq[Char] = List(A, P, P, L, E, B, A, N, A, N, A, O, R, A, N, G, E)

map applies the function to each element in the input Seq to create a transformed Seq, but flatMap takes the process a step further. In fact, you can show that calling flatMap is just like calling map and then calling flatten:

scala> val mapResult = fruits.map(_.toUpperCase)
mapResult: Seq[String] = List(APPLE, BANANA, ORANGE)

scala> val flattenResult = mapResult.flatten
flattenResult: Seq[Char] = List(A, P, P, L, E, B, A, N, A, N, A, O, R, A, N, G, E)

I won’t show any more examples of flatMap here, but if you want to see more examples of how it works, please see my article, “A collection of Scala flatMap examples.” I also demonstrate more ways to use flatMap in lessons later in this book.

Starting to write a flatMap method

Earlier I showed flatMap’s signature to be:

def flatMap[B](f: A => Sequence[B]): Sequence[B]

Given that signature, and knowing that flatMap works like a map call followed by a flatten call, I can implement flatMap’s function body by calling map and then flatten:

def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
    val mapRes: Sequence[Sequence[B]] = map(f)   //map
    flatten(mapRes)                              //flatten
}

In the first line of the function body I call the map method we developed in the previous lessons, and I also explicitly show the type of mapRes:

val mapRes: Sequence[Sequence[B]] = map(f)   //map

Because this result is a little complicated, I like to make its type obvious. That’s a great thing about Scala: you don’t have to declare variable types, but you can show them when you want to.

In the second line of this function I quickly run into a problem: I haven’t defined a flatten method yet! Let’s fix that.

Aside: Why I wrote the code as shown

It’s important to note that I wrote the function body like this:

val mapRes: Sequence[Sequence[B]] = map(f)   //map
flatten(mapRes)                              //flatten

I did this because the function input parameter that flatMap takes looks like this:

f: A => Sequence[B]

Because that function transforms a type A into a Sequence[B], I can’t just call map and flatten on elems. For example, this code won’t work:

// this won't work
def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
    val mapRes = elems.map(f)
    mapRes.flatten
}

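The reason I can’t cheat like this is that elems.map(f) returns an ArrayBuffer[Sequence[B]], and ArrayBuffer’s flatten method only knows how to flatten elements that are themselves standard Scala collections. My custom Sequence isn’t one, so mapRes.flatten doesn’t compile. Even if it did, the result would be an ArrayBuffer[B], not the Sequence[B] that flatMap must return. What I really need is a Sequence[Sequence[B]] that I can flatten myself, so I take the approach I showed earlier: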

def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
    val mapRes: Sequence[Sequence[B]] = map(f)   //map
    flatten(mapRes)                              //flatten
}
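The shortcut fails because Sequence isn’t a standard Scala collection. If it exposed its contents as one, a shortcut along those lines could be made to work. The sketch below assumes exactly that. Wrapper is a hypothetical stand-in for Sequence, and its toList method is an invented helper that the lesson’s Sequence does not have.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical cut-down wrapper, similar in spirit to the lesson's Sequence.
// toList is an assumed helper; the lesson's Sequence has no such method.
final class Wrapper[A](initial: List[A]) {
  private val elems = ArrayBuffer[A](initial: _*)

  def toList: List[A] = elems.toList

  def flatMap[B](f: A => Wrapper[B]): Wrapper[B] = {
    // elems.map(f).flatten would not compile: ArrayBuffer's flatten needs
    // each element to be viewable as a Scala collection, and Wrapper isn't.
    // Turning each inner Wrapper into a List first makes flattening possible.
    val flat: ArrayBuffer[B] = elems.map(f).flatMap(inner => inner.toList)
    new Wrapper(flat.toList)
  }
}

val w = new Wrapper(List(1, 2, 3))
val result = w.flatMap(n => new Wrapper(List(n, n * 10)))
println(result.toList)  // List(1, 10, 2, 20, 3, 30)
```

The cost of this design is an extra public method that exposes the wrapper’s contents as a standard collection.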

How flatten works

Getting back to the problem at hand, I need to write a flatten method. If you read the Scala Cookbook you know how a flatten method should work — it converts a “list of lists” to a single list.

A little example demonstrates this. In the REPL you can create a list of lists like this:

val a = List( List(1,2), List(3,4) )

Now when you call flatten on that data structure you get a combined (or flattened) result of List(1,2,3,4). Here’s what it looks like in the REPL:

scala> val a = List( List(1,2), List(3,4) )
a: List[List[Int]] = List(List(1, 2), List(3, 4))

scala> a.flatten
res0: List[Int] = List(1, 2, 3, 4)

Our flatten function should do the same thing.
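As a point of comparison, the same behavior can be written without a mutable buffer. This sketch uses a hypothetical helper named flattenLists and works on ordinary Lists rather than Sequence. It folds from the right, so each inner list is prepended to the already-flattened remainder.

```scala
// Hypothetical helper: an immutable way to flatten a list of lists
def flattenLists[B](listOfLists: List[List[B]]): List[B] =
  listOfLists.foldRight(List.empty[B])((inner, acc) => inner ++ acc)

val a = List(List(1, 2), List(3, 4))
println(flattenLists(a))  // List(1, 2, 3, 4)
```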

Writing a flatten function

Seeing how flatten works, I can write pseudocode for a flatten function like this:

create an empty list 'xs'
for each list 'a' in the original listOfLists
    for each element 'e' in the list 'a'
        add 'e' to 'xs'
return 'xs'

In Scala/OOP you can implement that pseudocode like this:

val xs = ArrayBuffer[B]()   // the buffer itself is mutable, so a val is enough
for (listB: Sequence[B] <- listOfLists) {
    for (e <- listB) {
        xs += e
    }
}
xs

Because I’m working with my custom Sequence I need to modify that code slightly, but when I wrap it inside a function named flatten, it still looks similar:

def flatten[B](seqOfSeq: Sequence[Sequence[B]]): Sequence[B] = {
    val xs = ArrayBuffer[B]()
    for (listB: Sequence[B] <- seqOfSeq) {
        for (e <- listB) {
            xs += e
        }
    }
    Sequence(xs: _*)
}

The biggest difference here is that I convert the temporary ArrayBuffer to a Sequence in the last step, like this:

Sequence(xs: _*)
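The xs: _* syntax tells the compiler to pass a sequence’s elements as individual arguments to a varargs parameter, which is how Sequence’s A* constructor parameter receives them. Here is a minimal sketch with a hypothetical sumAll method.

```scala
// Hypothetical varargs method: accepts any number of Int arguments
def sumAll(nums: Int*): Int = nums.sum

val one  = sumAll(1, 2, 3)        // arguments passed one by one
val many = Vector(1, 2, 3)
val two  = sumAll(many: _*)       // `: _*` spreads the Vector into the varargs slot

println(one)  // 6
println(two)  // 6
```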

From flatten to flattenLike

There’s one problem with this function: the type signature for a flatten function on a Scala List looks like this (leaving out its implicit parameter):

def flatten[B]: List[B]

Because my type signature isn’t the same as that, I’m not comfortable naming it flatten. Therefore I’m going to rename it and also make it private, so the new signature looks like this:

private def flattenLike[B](seqOfSeq: Sequence[Sequence[B]]): Sequence[B]

The short explanation for why the type signatures don’t match up is that my cheating ways have finally caught up with me. Creating Sequence as a wrapper around an ArrayBuffer creates a series of problems if I try to define flatten like this:

def flatten[B](): Sequence[B] = ...

Rather than go into those problems in detail, I’ll leave them as an exercise for the reader. Focusing on the problem at hand, getting flatMap working, I’m going to move forward using my flattenLike function.
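One possible approach to that exercise, sketched here and not taken from the lesson, is to have a no-argument flatten demand implicit evidence that the element type is itself a nested container. The standard library’s Option.flatten uses this same pattern, with the signature flatten[B](implicit ev: A <:< Option[B]). Box below is a hypothetical, cut-down stand-in for Sequence.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical Box: a cut-down stand-in for the lesson's Sequence
case class Box[A](initialElems: A*) {
  private val elems = ArrayBuffer[A](initialElems: _*)

  def toList: List[A] = elems.toList

  def foreach(f: A => Unit): Unit = elems.foreach(f)

  def map[B](f: A => B): Box[B] = Box(elems.map(f).toList: _*)

  // No explicit argument: the implicit evidence `ev` proves that each
  // element of type A is a Box[B] and converts it to one.
  // This only compiles when A really is some Box[B].
  def flatten[B](implicit ev: A <:< Box[B]): Box[B] = {
    val xs = ArrayBuffer[B]()
    for (outer <- elems; b <- ev(outer)) xs += b
    Box(xs.toList: _*)
  }

  def flatMap[B](f: A => Box[B]): Box[B] = map(f).flatten
}

val nested = Box(Box(1, 2), Box(3, 4))
println(nested.flatten.toList)                          // List(1, 2, 3, 4)
println(Box(1, 2).flatMap(n => Box(n, n * 10)).toList)  // List(1, 10, 2, 20)
```

Because the evidence only exists when A really is a Box[B], calling flatten on a Box[Int] is rejected at compile time rather than failing at runtime.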

Making flatMap work

Now that I have flattenLike written, I can go back and update flatMap to call it:

def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
    val mapRes: Sequence[Sequence[B]] = map(f)   //map
    flattenLike(mapRes)                          //flatten
}

Testing flatMap

Now that I think I have a working flatMap function, I can add it to the Sequence class. Here’s the complete source code for Sequence, including the functions I just wrote:

import scala.collection.mutable.ArrayBuffer

case class Sequence[A](private val initialElems: A*) {

    // this is a book, don't do this at home
    private val elems = ArrayBuffer[A]()

    // initialize
    elems ++= initialElems

    def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
        val mapRes: Sequence[Sequence[B]] = map(f)   //map
        flattenLike(mapRes)                          //flatten
    }

    private def flattenLike[B](seqOfSeq: Sequence[Sequence[B]]): Sequence[B] = {
        val xs = ArrayBuffer[B]()
        for (listB: Sequence[B] <- seqOfSeq) {
            for (e <- listB) {
                xs += e
            }
        }
        Sequence(xs: _*)
    }

    def withFilter(p: A => Boolean): Sequence[A] = {
        val tmpArrayBuffer = elems.filter(p)
        Sequence(tmpArrayBuffer: _*)
    }

    def map[B](f: A => B): Sequence[B] = {
        val abMap = elems.map(f)
        Sequence(abMap: _*)
    }

    def foreach(block: A => Unit): Unit = {
        elems.foreach(block)
    }

}

When I paste that code into the REPL, and then paste in the code I showed at the beginning of this lesson:

case class Person(name: String)

val myFriends = Sequence(
    Person("Adam"),
    Person("David"), 
    Person("Frank")
)

val adamsFriends = Sequence(
    Person("Nick"), 
    Person("David"), 
    Person("Frank")
)

I can confirm that my for expression with multiple generators now works:

val mutualFriends = for {
    myFriend <- myFriends        // generator
    adamsFriend <- adamsFriends  // generator
    if (myFriend.name == adamsFriend.name)
} yield myFriend

mutualFriends.foreach(println)

The output of that last line looks like this:

scala> mutualFriends.foreach(println)
Person(David)
Person(Frank)

Admittedly I did some serious cheating in this lesson to get a flatMap function working, but as you see, once flatMap is implemented, you can use multiple generators in a for expression.

Summary

I can summarize what I showed in all of the for expression lessons with these lines of code:

// (1) works because `foreach` is defined
for (p <- peeps) println(p)

// (2) `yield` works because `map` is defined
val res: Sequence[Int] = for {
    i <- ints
} yield i * 2
res.foreach(println)

// (3) `if` works because `withFilter` is defined
val filtered = for {
    i <- ints
    if i > 2
} yield i*2

// (4) works because `flatMap` is defined
val mutualFriends = for {
    myFriend <- myFriends        // generator
    adamsFriend <- adamsFriends  // generator
    if (myFriend.name == adamsFriend.name)
} yield myFriend
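The summary’s peeps and ints come from earlier lessons that aren’t shown here. To see all four rules in one runnable place, here is a sketch with a hypothetical Tiny class. Unlike the ArrayBuffer-based Sequence, it is backed by an immutable List. It defines exactly foreach, map, withFilter, and flatMap, and the values of ints and peeps are illustrative stand-ins.

```scala
// Hypothetical minimal container defining only the four methods
// that for expressions rely on
final case class Tiny[A](elems: List[A]) {
  def foreach(f: A => Unit): Unit = elems.foreach(f)              // plain for loops
  def map[B](f: A => B): Tiny[B] = Tiny(elems.map(f))             // yield
  def withFilter(p: A => Boolean): Tiny[A] = Tiny(elems.filter(p)) // if guards
  def flatMap[B](f: A => Tiny[B]): Tiny[B] =                      // extra generators
    Tiny(elems.flatMap(a => f(a).elems))
}

// Stand-in values for the summary's ints and peeps
val ints  = Tiny(List(1, 2, 3, 4))
val peeps = Tiny(List("Adam", "David", "Frank"))

for (p <- peeps) println(p)                              // (1) uses foreach

val doubled  = for { i <- ints } yield i * 2             // (2) uses map
val filtered = for { i <- ints if i > 2 } yield i * 2    // (3) withFilter, then map

val pairs = for {                                        // (4) flatMap, withFilter, map
  p <- peeps
  i <- ints
  if i == p.length
} yield (p, i)

println(pairs)  // Tiny(List((Adam,4)))
```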
