Scala: How to combine map and flatten with flatMap

This is an excerpt from the Scala Cookbook (partially modified for the internet). This is Recipe 10.16, “How to Combine map and flatten with flatMap.”


When you first come to Scala from an object-oriented programming background, the flatMap method can seem very foreign, so you’d like to understand how to use it and see where it can be applied.


Use flatMap in situations where you run map followed by flatten. The specific situation is this:

  • You’re using map (or a for/yield expression) to create a new collection from an existing collection.
  • The resulting collection is a list of lists.
  • You call flatten immediately after map (or a for/yield expression).

When you’re in this situation, you can use flatMap instead.
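As a minimal sketch of that situation, the three forms below build the same flattened list. The names nums, nested, viaMapFlatten, viaFlatMap, and viaFor are illustrative only. The for/yield version has two generators, and the compiler rewrites it into a flatMap call (followed by a map).

```scala
// Each element maps to a small list, so map alone yields a List[List[Int]]
val nums = List(1, 2, 3)
val nested: List[List[Int]] = nums.map(n => List(n, n * 10))

// map followed by flatten ...
val viaMapFlatten: List[Int] = nested.flatten

// ... is what flatMap does in a single call
val viaFlatMap: List[Int] = nums.flatMap(n => List(n, n * 10))

// A for/yield with two generators desugars to flatMap (plus a final map)
val viaFor: List[Int] = for {
  n <- nums
  m <- List(n, n * 10)
} yield m

println(viaFlatMap) // List(1, 10, 2, 20, 3, 30)
```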

The following example shows how to use flatMap with an Option. In this example, you’re asked to calculate the sum of the numbers in a list, with one catch: the numbers are all strings, and some of them won’t convert to integers. Here’s the list:

val bag = List("1", "2", "three", "4", "one hundred seventy five")

To solve the problem, you begin by creating a “string to integer” conversion method that returns either Some[Int] or None, based on the String it’s given:

def toInt(in: String): Option[Int] = {
    try {
        Some(Integer.parseInt(in.trim))
    } catch {
        case e: Exception => None
    }
}

With this method in hand, the resulting solution is surprisingly simple:

scala> bag.flatMap(toInt).sum
res0: Int = 7
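If you are on Scala 2.13 or later, the standard library’s StringOps.toIntOption already returns an Option[Int], so you may not need a hand-written conversion method at all. This is a swapped-in alternative, not the recipe’s approach; the sketch assumes a 2.13+ compiler.

```scala
// Assumes Scala 2.13+, where StringOps provides toIntOption
val bag = List("1", "2", "three", "4", "one hundred seventy five")

// toIntOption yields Some(n) for "1", "2", "4" and None for the other strings
val total: Int = bag.flatMap(_.toIntOption).sum

println(total) // 7
```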


To see how this works, break the problem down into smaller steps. First, here’s what happens when you use map on the initial collection of strings:

scala> bag.map(toInt)
res0: List[Option[Int]] = List(Some(1), Some(2), None, Some(4), None)

The map method applies the toInt function to each element in the collection, and returns a list of Some[Int] and None values. But the sum method needs a List[Int]; how do you get there from here?

As shown in the previous recipe, flatten works very well with a list of Some and None elements. It extracts the values from the Some elements while discarding the None elements:

scala> bag.map(toInt).flatten
res1: List[Int] = List(1, 2, 4)

This makes finding the sum easy:

scala> bag.map(toInt).flatten.sum
res2: Int = 7
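The intermediate steps can be collected into one self-contained script. In this sketch, toInt catches only NumberFormatException rather than every Exception. That narrower catch is a choice made here, not in the recipe. The names mapped, flattened, and total are illustrative.

```scala
// toInt converts a String to Some(Int), or None when parsing fails
def toInt(in: String): Option[Int] =
  try Some(Integer.parseInt(in.trim))
  catch { case _: NumberFormatException => None }

val bag = List("1", "2", "three", "4", "one hundred seventy five")

// Step 1: map gives a List[Option[Int]]
val mapped: List[Option[Int]] = bag.map(toInt)

// Step 2: flatten keeps the values inside Some and drops every None
val flattened: List[Int] = mapped.flatten

// Step 3: sum now works because the element type is Int
val total: Int = flattened.sum

// flatMap performs steps 1 and 2 in one call
assert(bag.flatMap(toInt) == flattened)
println(total) // 7
```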

Now, whenever I see map followed by flatten, I think “flat map,” so I get back to the earlier solution:

scala> bag.flatMap(toInt).sum
res3: Int = 7

Actually, I think, “map flat,” but the method is named flatMap.

As you can imagine, once you get the original list down to a List[Int], you can call any of the powerful collection methods to get what you want:

scala> bag.flatMap(toInt).filter(_ > 1)
res4: List[Int] = List(2, 4)

scala> bag.flatMap(toInt).takeWhile(_ < 4)
res5: List[Int] = List(1, 2)

scala> bag.flatMap(toInt).partition(_ > 3)
res6: (List[Int], List[Int]) = (List(4),List(1, 2))

As a second example of using flatMap, imagine you have a method that finds all the sub-words from a word you give it. Skipping the implementation for a moment, if you call the method with the string “then,” it should work as follows:

scala> subWords("then")
res0: List[String] = List(then, hen, the)

subWords should also return the string “he,” but it’s still in beta.

With that method working (mostly), it can be called on a list of words with map:

scala> val words = List("band", "start", "then")
words: List[java.lang.String] = List(band, start, then)

scala> words.map(subWords)
res0: List[List[String]] = List(List(band, and, ban), List(start, tart, star), List(then, hen, the))

Very cool, you have a list of sub-words for all the given words. One problem, though: map gave you a list of lists. What to do? Call flatten:

scala> words.map(subWords).flatten
res1: List[String] = List(band, and, ban, start, tart, star, then, hen, the)

Success! You have a list of all the sub-words from the original list of words. But notice what you did: You called map, then flatten. Enter “map flat,” er, flatMap:

scala> words.flatMap(subWords)
res2: List[String] = List(band, and, ban, start, tart, star, then, hen, the)
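To reproduce this example end to end, the script below reuses the subWords definition given at the end of this recipe (with an explicit return type added) and checks that map plus flatten and flatMap produce the same list. The names nested, flat, and viaFlatMap are illustrative.

```scala
// subWords as defined at the end of this recipe:
// the word itself, the word without its first char, and the word without its last char
def subWords(word: String): List[String] =
  List(word, word.tail, word.take(word.length - 1))

val words = List("band", "start", "then")

// map produces one inner list per word
val nested: List[List[String]] = words.map(subWords)

// flatten merges the inner lists, preserving their order
val flat: List[String] = nested.flatten

// flatMap does both steps in one call
val viaFlatMap: List[String] = words.flatMap(subWords)

println(viaFlatMap) // List(band, and, ban, start, tart, star, then, hen, the)
```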

General rule: Whenever you think map followed by flatten, use flatMap. Eventually your brain will skip over the intermediate steps.

As for the implementation of subWords ... well, it’s a work in progress:

def subWords(word: String) = List(word, word.tail, word.take(word.length-1))
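One caveat about this work in progress: calling tail on an empty string throws an UnsupportedOperationException, so subWords("") fails. The variant below, safeSubWords, is a hypothetical guarded version sketched for illustration and is not part of the recipe. Returning Nil for an empty word also shows a useful property of flatMap: an empty inner list simply contributes nothing to the result.

```scala
// Hypothetical guarded variant: an empty word has no sub-words,
// so return Nil instead of letting "".tail throw
def safeSubWords(word: String): List[String] =
  if (word.isEmpty) Nil
  else List(word, word.tail, word.take(word.length - 1))

// The empty string contributes nothing to the flattened result
val result = List("then", "").flatMap(safeSubWords)
println(result) // List(then, hen, the)
```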

See Also