Scala: How to combine map and flatten with flatMap

This is an excerpt from the Scala Cookbook (partially modified for the internet). This is Recipe 10.16, “How to Combine `map` and `flatten` with `flatMap`.”

Problem

When you first come to Scala from an object-oriented programming background, the `flatMap` method can seem very foreign, so you’d like to understand how to use it and see where it can be applied.

Solution

Use `flatMap` in situations where you run `map` followed by `flatten`. The specific situation is this:

• You’re using `map` (or a for/yield expression) to create a new collection from an existing collection.
• The resulting collection is nested, such as a list of lists or a list of `Option` values.
• You call `flatten` immediately after `map` (or a for/yield expression).

When you’re in this situation, you can use `flatMap` instead.
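As a minimal sketch of this situation (the numbers and the `expand` helper are made up for illustration and are not part of the original recipe), the following shows `map` producing a list of lists, `flatten` collapsing it, and `flatMap` doing both in one step:

```scala
// Hypothetical helper: turns one number into a small list.
def expand(n: Int): List[Int] = List(n, n * 10)

val nums = List(1, 2, 3)

// map alone yields a nested collection:
// List(List(1, 10), List(2, 20), List(3, 30))
val nested: List[List[Int]] = nums.map(expand)

// map followed by flatten removes one level of nesting.
val viaFlatten: List[Int] = nums.map(expand).flatten

// flatMap does the same work in a single call.
val viaFlatMap: List[Int] = nums.flatMap(expand)
```

Both `viaFlatten` and `viaFlatMap` hold `List(1, 10, 2, 20, 3, 30)`.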

The following example shows how to use `flatMap` with a function that returns an `Option`. In this example, you’re told that you should calculate the sum of the numbers in a list, with one catch: the numbers are all strings, and some of them won’t convert properly to integers. Here’s the list:

`val bag = List("1", "2", "three", "4", "one hundred seventy five")`

To solve the problem, you begin by creating a “string to integer” conversion method that returns either `Some[Int]` or `None`, based on the `String` it’s given:

```def toInt(in: String): Option[Int] = {
  try {
    Some(Integer.parseInt(in.trim))
  } catch {
    case e: Exception => None
  }
}```
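If you prefer to avoid an explicit `try`/`catch`, one possible alternative is to lean on `scala.util.Try` (available since Scala 2.10). This is only a sketch; the name `toIntViaTry` is hypothetical and does not appear in the original recipe:

```scala
import scala.util.Try

// Alternative sketch of toInt: Try captures the NumberFormatException
// thrown by parseInt, and toOption turns a failure into None.
def toIntViaTry(in: String): Option[Int] =
  Try(Integer.parseInt(in.trim)).toOption
```

Either version behaves the same for the strings in `bag`; the examples that follow use `toInt`.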

With this method in hand, the resulting solution is surprisingly simple:

```scala> bag.flatMap(toInt).sum
res0: Int = 7```

Discussion

To see how this works, break the problem down into smaller steps. First, here’s what happens when you use `map` on the initial collection of strings:

```scala> bag.map(toInt)
res0: List[Option[Int]] = List(Some(1), Some(2), None, Some(4), None)```

The `map` method applies the `toInt` function to each element in the collection, and returns a list of `Some[Int]` and `None` values. But the `sum` method needs a `List[Int]`; how do you get there from here?

As shown in the previous recipe, `flatten` works very well with a list of `Some` and `None` elements. It extracts the values from the `Some` elements while discarding the `None` elements:

```scala> bag.map(toInt).flatten
res1: List[Int] = List(1, 2, 4)```

This makes finding the sum easy:

```scala> bag.map(toInt).flatten.sum
res2: Int = 7```
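One way to picture why `flatten` can handle `Some` and `None` is that an `Option` behaves like a collection of zero or one elements. The sketch below makes that conversion explicit with `toList`; the value names (`someAsList`, `explicit`, and so on) are illustrative, and `bag` and `toInt` are repeated only so the block stands on its own:

```scala
// Definitions from the recipe, repeated so the sketch is self-contained.
val bag = List("1", "2", "three", "4", "one hundred seventy five")

def toInt(in: String): Option[Int] = {
  try {
    Some(Integer.parseInt(in.trim))
  } catch {
    case _: Exception => None
  }
}

// An Option converts to a list of zero or one elements.
val someAsList: List[Int] = Some(1).toList              // one element
val noneAsList: List[Int] = (None: Option[Int]).toList  // no elements

// Converting each Option to a List first, then flattening the
// resulting list of lists...
val explicit: List[Int] = bag.map(s => toInt(s).toList).flatten

// ...gives the same result as flattening the Options directly.
val direct: List[Int] = bag.map(toInt).flatten
```

Seen this way, a list of `Option` values is just another nested collection, which is why the `map`-then-`flatten` pattern applies to it.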

Now, whenever I see `map` followed by `flatten`, I think “flat map,” so I get back to the earlier solution:

```scala> bag.flatMap(toInt).sum
res3: Int = 7```

Actually, I think, “map flat,” but the method is named `flatMap`.

As you can imagine, once you get the original list down to a `List[Int]`, you can call any of the powerful collections methods to get what you want:

```scala> bag.flatMap(toInt).filter(_ > 1)
res4: List[Int] = List(2, 4)

scala> bag.flatMap(toInt).takeWhile(_ < 4)
res5: List[Int] = List(1, 2)

scala> bag.flatMap(toInt).partition(_ > 3)
res6: (List[Int], List[Int]) = (List(4),List(1, 2))```

As a second example of using `flatMap`, imagine you have a method that finds all the sub-words from a word you give it. Skipping the implementation for a moment, if you call the method with the string `"then"`, it should work as follows:

```scala> subWords("then")
res0: List[String] = List(then, hen, the)```

`subWords` should also return the string `he`, but it’s in beta.

With that method working (mostly), it can be called on a list of words with `map`:

```scala> val words = List("band", "start", "then")
words: List[String] = List(band, start, then)

scala> words.map(subWords)
res0: List[List[String]] = List(List(band, and, ban), List(start, tart, star), List(then, hen, the))```

Very cool, you have a list of sub-words for all the given words. One problem, though: `map` gave you a list of lists. What to do? Call `flatten`:

```scala> words.map(subWords).flatten
res1: List[String] = List(band, and, ban, start, tart, star, then, hen, the)```

Success! You have a list of all the sub-words from the original list of words. But notice what you did: You called `map`, then `flatten`. Enter “map flat,” er, `flatMap`:

```scala> words.flatMap(subWords)
res2: List[String] = List(band, and, ban, start, tart, star, then, hen, the)```

General rule: Whenever you find yourself writing `map` followed by `flatten`, use `flatMap`. Eventually your brain will skip over the intermediate steps.
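The same rule applies to the for/yield form mentioned in the Solution. As a sketch (the value names `subWordsViaFor` and `subWordsViaFlatMap` are illustrative), a `for` expression with two generators is translated by the compiler, roughly, into a `flatMap` over the first generator with a `map` over the second, so it produces the flattened list directly:

```scala
// Definitions from the recipe, repeated so the sketch is self-contained.
def subWords(word: String): List[String] =
  List(word, word.tail, word.take(word.length - 1))

val words = List("band", "start", "then")

// Two generators: roughly equivalent to
// words.flatMap(w => subWords(w).map(s => s))
val subWordsViaFor: List[String] =
  for {
    w <- words
    s <- subWords(w)
  } yield s

// The direct flatMap call from the recipe.
val subWordsViaFlatMap: List[String] = words.flatMap(subWords)
```

Both values contain the same nine sub-words, in the same order.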

As for the implementation of `subWords` ... well, it’s a work in progress:

`def subWords(word: String) = List(word, word.tail, word.take(word.length-1))`
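One possible direction for finishing it, offered as a sketch rather than the book’s implementation, is to generate every contiguous substring of a minimum length. Fittingly, this can itself be written with `flatMap`. The name `allSubstrings` and the `minLen` parameter are hypothetical:

```scala
// Hypothetical variant: every contiguous substring of `word`
// that is at least `minLen` characters long.
def allSubstrings(word: String, minLen: Int = 2): List[String] =
  (0 until word.length).toList.flatMap { start =>
    // For each start index, produce the list of substrings that
    // begin there; flatMap merges those lists into one flat list.
    ((start + minLen) to word.length).toList.map { end =>
      word.substring(start, end)
    }
  }
```

Calling `allSubstrings("then")` returns `List(th, the, then, he, hen, en)`, which does include `he`. It also includes fragments such as `th` and `en` that are not words, so a real `subWords` would still need some kind of dictionary check on top of this.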