Sequence: Enabling Multiple Generators in 'for' Loops (Scala 3 Video)
One cool thing about Scala for expressions is that you can use multiple generators inside of them. This lets you combine elements from two or more collections, which is handy when you want to analyze how the data in those collections is related.
For instance, suppose you have some data like this:
case class Person(name: String)
val myFriends = Sequence(
  Person("Adam"),
  Person("David"),
  Person("Frank")
)
val adamsFriends = Sequence(
  Person("Nick"),
  Person("David"),
  Person("Frank")
)
If I want to find out which friends of mine are also friends of Adam, I can write a for expression like this:
val mutualFriends = for {
  myFriend <- myFriends         // generator
  adamsFriend <- adamsFriends   // generator
  if (myFriend.name == adamsFriend.name)
} yield myFriend
mutualFriends.foreach(println)
Notice how I use two Sequence instances as generators in that for expression.
Sadly, the compiler tells us that this code won’t work, but happily it again tells us why:
<console>:17: error: value flatMap is not a member of Sequence[Person]
myFriend <- myFriends
^
Since you’re used to reading these error messages now, you know the compiler is telling us that we need to implement a flatMap method in the Sequence class for this code to work.
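To see why flatMap is needed, it helps to look at how the compiler translates a for expression with two generators and a guard. The following sketch runs the same query against a standard List (not the Sequence class, so it runs on its own) and then writes out the method calls by hand: the first generator becomes a flatMap call, the if guard becomes withFilter, and the last generator together with yield becomes map. The names viaFor and viaMethods are only for illustration, and the hand-written version approximates the compiler's translation rather than reproducing its exact output.

```scala
case class Person(name: String)

val myFriends    = List(Person("Adam"), Person("David"), Person("Frank"))
val adamsFriends = List(Person("Nick"), Person("David"), Person("Frank"))

// The lesson's for expression, but over standard Lists
val viaFor = for {
  myFriend    <- myFriends
  adamsFriend <- adamsFriends
  if myFriend.name == adamsFriend.name
} yield myFriend

// Roughly what the compiler rewrites it into:
// outer generator -> flatMap, guard -> withFilter, last generator + yield -> map
val viaMethods = myFriends.flatMap { myFriend =>
  adamsFriends
    .withFilter(adamsFriend => myFriend.name == adamsFriend.name)
    .map(_ => myFriend)
}
```

Because the outer generator turns into a flatMap call, any class used as the first of several generators needs a flatMap method, which is what the error message is pointing at.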
flatMap, a curious creature
The flatMap method is an interesting creature, and experienced functional programmers seem to use it a lot.
As a bit of background, when you think about map’s signature for a moment, you’ll remember that it looks like this:
def map[B](f: A => B): Sequence[B]
As shown, map takes a function that transforms a type A to a type B, applies that function to all of the elements in the Sequence, and returns the results as a Sequence[B].
flatMap’s signature is similar to map:
def flatMap[B](f: A => Sequence[B]): Sequence[B]
As this shows, flatMap is similar to map, but it’s also different. The function that flatMap takes transforms a type A into a Sequence of type B (a Sequence[B]), and when flatMap is finished it also returns a Sequence[B]. The type signatures tell us that the difference between map and flatMap is the type of function they take as an input parameter: map’s function returns a B, while flatMap’s function returns a Sequence[B].
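A small example with a standard List shows what that difference means in practice. When the function you pass returns a collection for each element, map keeps those results nested, while flatMap merges them into a single collection. The upTo helper is a hypothetical function used only for this illustration.

```scala
// Hypothetical function of shape A => List[B]: for n, return the numbers 1 to n
def upTo(n: Int): List[Int] = (1 to n).toList

val nums = List(1, 2, 3)

// map with an Int => List[Int] function keeps the nesting
val mapped: List[List[Int]] = nums.map(upTo)   // List(List(1), List(1, 2), List(1, 2, 3))

// flatMap with the same function merges the inner Lists into one
val flatMapped: List[Int] = nums.flatMap(upTo) // List(1, 1, 2, 1, 2, 3)
```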
flatMap background
If you come to Scala from a background like Java, after a while it becomes apparent that the map function is very cool (and you’ll use it all the time), but it can be hard to find a use for flatMap.
As I wrote in the Scala Cookbook, I like to think of flatMap as “map flat,” because on collection classes it works similarly to a) calling map and then b) calling flatten. As an example of what I mean, these lines of code show the difference between calling map and flatMap on a Seq[String]:
scala> val fruits = Seq("apple", "banana", "orange")
fruits: Seq[java.lang.String] = List(apple, banana, orange)
scala> fruits.map(_.toUpperCase)
res0: Seq[java.lang.String] = List(APPLE, BANANA, ORANGE)
scala> fruits.flatMap(_.toUpperCase)
res1: Seq[Char] = List(A, P, P, L, E, B, A, N, A, N, A, O, R, A, N, G, E)
map applies the function to each element in the input Seq to create a transformed Seq, but flatMap takes the process a step further. In fact, you can show that calling flatMap gives the same result as calling map and then calling flatten:
scala> val mapResult = fruits.map(_.toUpperCase)
mapResult: Seq[String] = List(APPLE, BANANA, ORANGE)
scala> val flattenResult = mapResult.flatten
flattenResult: Seq[Char] = List(A, P, P, L, E, B, A, N, A, N, A, O, R, A, N, G, E)
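As a quick check of that equivalence, the sketch below makes the String-to-characters step explicit by converting each uppercased String to a List[Char], so it doesn't depend on Scala implicitly treating a String as a collection of Char. The helper name toChars is illustrative.

```scala
val fruits = Seq("apple", "banana", "orange")

// Explicitly turn each String into a List[Char] (illustrative helper)
val toChars: String => List[Char] = s => s.toUpperCase.toList

// flatMap in one step ...
val viaFlatMap: Seq[Char] = fruits.flatMap(toChars)

// ... versus map (giving a Seq[List[Char]]) followed by flatten
val viaMapThenFlatten: Seq[Char] = fruits.map(toChars).flatten
```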
I won’t show any more examples of flatMap here, but if you want to see more examples of how it works, please see my article, “A collection of Scala flatMap examples.” I also demonstrate more ways to use flatMap in lessons later in this book.
Starting to write a flatMap method
Earlier I showed flatMap’s signature to be:
def flatMap[B](f: A => Sequence[B]): Sequence[B]
Given that signature, and knowing that flatMap works like a map call followed by a flatten call, I can implement flatMap’s function body by calling map and then flatten:
def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
  val mapRes: Sequence[Sequence[B]] = map(f)   //map
  flatten(mapRes)                              //flatten
}
In the first line of the function body I call the map method we developed in the previous lessons, and I also explicitly show the type of mapRes:
val mapRes: Sequence[Sequence[B]] = map(f) //map
Because this result is a little complicated, I like to make its type obvious. That’s a great thing about Scala: you don’t have to declare variable types, but you can show them when you want to.
In the second line of this function I quickly run into a problem: I haven’t defined a flatten method yet! Let’s fix that.
Aside: Why I wrote the code as shown
It’s important to note that I wrote the function body like this:
val mapRes: Sequence[Sequence[B]] = map(f) //map
flatten(mapRes) //flatten
I did this because the function input parameter that flatMap takes looks like this:
f: A => Sequence[B]
Because that function transforms a type A into a Sequence[B], I can’t just call map and flatten on elems. For example, this code won’t work:
// this won't work
def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
  val mapRes = elems.map(f)
  mapRes.flatten
}
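The reason I can’t cheat like this is that elems.map(f) returns an ArrayBuffer[Sequence[B]], and what I really need is a Sequence[Sequence[B]]. ArrayBuffer’s own flatten method doesn’t help either, because it can only flatten elements that can be treated as standard Scala collections, and my Sequence class isn’t one. Because of this I need to take the approach I showed earlier: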
def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
  val mapRes: Sequence[Sequence[B]] = map(f)   //map
  flatten(mapRes)                              //flatten
}
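Here’s a small, self-contained illustration of that problem, using a hypothetical Box class as a stand-in for Sequence. ArrayBuffer’s flatten only works when each element can be treated as a standard Scala collection, and a custom class like Box (or Sequence) can’t be. Supplying the conversion yourself through flatMap works, but the result is still an ArrayBuffer rather than your own type, so you’d have to wrap it again anyway.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-in for Sequence: a custom class that is NOT a Scala collection
final case class Box[A](items: A*)

val nested: ArrayBuffer[Box[Int]] = ArrayBuffer(Box(1, 2), Box(3))

// nested.flatten
// ^ does not compile: flatten needs an implicit conversion from Box[Int] to IterableOnce[Int]

// Converting each Box by hand (box => box.items) lets flatMap do the work,
// but the result type is ArrayBuffer[Int], not Box[Int]
val flat: ArrayBuffer[Int] = nested.flatMap(box => box.items)
```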
How flatten works
Getting back to the problem at hand, I need to write a flatten method. If you read the Scala Cookbook you know how a flatten method should work: it converts a “list of lists” to a single list.
A little example demonstrates this. In the REPL you can create a list of lists like this:
val a = List( List(1,2), List(3,4) )
Now when you call flatten on that data structure you get a combined (or flattened) result of List(1,2,3,4). Here’s what it looks like in the REPL:
scala> val a = List( List(1,2), List(3,4) )
a: List[List[Int]] = List(List(1, 2), List(3, 4))
scala> a.flatten
res0: List[Int] = List(1, 2, 3, 4)
Our flatten function should do the same thing.
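One detail worth knowing: flatten removes exactly one level of nesting. The sketch below, again using standard Lists, shows that a list of lists of lists only becomes a list of lists after one call, and needs a second call to become fully flat.

```scala
val twoLevels   = List(List(1, 2), List(3, 4))
val threeLevels = List(List(List(1), List(2)), List(List(3)))

val once: List[Int] = twoLevels.flatten                // List(1, 2, 3, 4)
val stillNested: List[List[Int]] = threeLevels.flatten // List(List(1), List(2), List(3))
val fullyFlat: List[Int] = threeLevels.flatten.flatten // List(1, 2, 3)
```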
Writing a flatten function
Seeing how flatten works, I can write pseudocode for a flatten function like this:
create an empty list 'xs'
for each list 'a' in the original listOfLists
  for each element 'e' in the list 'a'
    add 'e' to 'xs'
return 'xs'
In Scala/OOP you can implement that pseudocode like this (assuming listOfLists and the type B are already in scope):
val xs = ArrayBuffer[B]()
for (listB: Sequence[B] <- listOfLists) {
  for (e <- listB) {
    xs += e
  }
}
xs
Because I’m working with my custom Sequence I need to modify that code slightly, but when I wrap it inside a function named flatten, it still looks similar:
def flatten[B](seqOfSeq: Sequence[Sequence[B]]): Sequence[B] = {
  val xs = ArrayBuffer[B]()
  for (listB: Sequence[B] <- seqOfSeq) {
    for (e <- listB) {
      xs += e
    }
  }
  Sequence(xs: _*)
}
The biggest difference here is that I convert the temporary ArrayBuffer to a Sequence in the last step, like this:
Sequence(xs: _*)
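If you’d rather not use a temporary mutable buffer, the same pseudocode can be written without mutation by folding over the outer collection and appending each inner collection to an accumulator. This sketch works on the standard Seq rather than the Sequence class, and flattenSeqs is a hypothetical name. It uses a Vector as the accumulator because appending to a Vector is efficient, unlike appending to a List.

```scala
// Mutation-free version of the flatten pseudocode:
// start with an empty Vector and append each inner Seq, in order
def flattenSeqs[B](seqOfSeqs: Seq[Seq[B]]): Seq[B] =
  seqOfSeqs.foldLeft(Vector.empty[B]) { (acc, inner) =>
    acc ++ inner
  }
```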
From flatten to flattenLike
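There’s one problem with this function: the type signature for a flatten function on a Scala List looks roughly like this (simplified; the real method also takes an implicit parameter that converts each element into a collection):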
def flatten[B]: List[B]
Because my type signature isn’t the same as that, I’m not comfortable naming it flatten. Therefore I’m going to rename it and also make it private, so the new signature looks like this:
private def flattenLike[B](seqOfSeq: Sequence[Sequence[B]]): Sequence[B]
The short explanation for why the type signatures don’t match up is that my cheating ways have finally caught up with me. Creating Sequence as a wrapper around an ArrayBuffer creates a series of problems if I try to define flatten like this:
def flatten[B](): Sequence[B] = ...
Rather than go into those problems in detail, I’ll leave that as an exercise for the reader. Focusing on the problem at hand (getting a flatMap function working), I’m going to move forward using my flattenLike function to get flatMap working.
Making flatMap work
Now that I have flattenLike written, I can go back and update flatMap to call it:
def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
  val mapRes: Sequence[Sequence[B]] = map(f)   //map
  flattenLike(mapRes)                          //flatten
}
Testing flatMap
Now that I think I have a working flatMap function, I can add it to the Sequence class. Here’s the complete source code for Sequence, including the functions I just wrote:
import scala.collection.mutable.ArrayBuffer
case class Sequence[A](private val initialElems: A*) {
  // this is a book, don't do this at home
  private val elems = ArrayBuffer[A]()
  // initialize
  elems ++= initialElems
  def flatMap[B](f: A => Sequence[B]): Sequence[B] = {
    val mapRes: Sequence[Sequence[B]] = map(f)   //map
    flattenLike(mapRes)                          //flatten
  }
  private def flattenLike[B](seqOfSeq: Sequence[Sequence[B]]): Sequence[B] = {
    val xs = ArrayBuffer[B]()
    for (listB: Sequence[B] <- seqOfSeq) {
      for (e <- listB) {
        xs += e
      }
    }
    Sequence(xs: _*)
  }
  def withFilter(p: A => Boolean): Sequence[A] = {
    val tmpArrayBuffer = elems.filter(p)
    Sequence(tmpArrayBuffer: _*)
  }
  def map[B](f: A => B): Sequence[B] = {
    val abMap = elems.map(f)
    Sequence(abMap: _*)
  }
  def foreach(block: A => Unit): Unit = {
    elems.foreach(block)
  }
}
When I paste that code into the REPL and then paste in the code I showed at the beginning of this lesson:
case class Person(name: String)
val myFriends = Sequence(
  Person("Adam"),
  Person("David"),
  Person("Frank")
)
val adamsFriends = Sequence(
  Person("Nick"),
  Person("David"),
  Person("Frank")
)
I can confirm that my for expression with multiple generators now works:
val mutualFriends = for {
  myFriend <- myFriends         // generator
  adamsFriend <- adamsFriends   // generator
  if (myFriend.name == adamsFriend.name)
} yield myFriend
mutualFriends.foreach(println)
The output of that last line looks like this:
scala> mutualFriends.foreach(println)
Person(David)
Person(Frank)
Admittedly I did some serious cheating in this lesson to get a flatMap function working, but as you can see, once flatMap is implemented, you can use multiple generators in a for expression.
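For comparison, here’s a minimal sketch of an immutable variant that doesn’t wrap an ArrayBuffer. It assumes a List as the backing store, and the class name ListSequence is hypothetical. Because private members in Scala are private to the class rather than to the instance, flatMap can read the elems of the ListSequence that f returns, so this version doesn’t need a separate flattenLike helper.

```scala
// Hypothetical immutable alternative to the lesson's Sequence, backed by a List
final class ListSequence[A] private (private val elems: List[A]) {
  def foreach(block: A => Unit): Unit = elems.foreach(block)

  def map[B](f: A => B): ListSequence[B] = new ListSequence(elems.map(f))

  def withFilter(p: A => Boolean): ListSequence[A] = new ListSequence(elems.filter(p))

  // Apply f to each element and concatenate the resulting inner lists
  def flatMap[B](f: A => ListSequence[B]): ListSequence[B] =
    new ListSequence(elems.flatMap(a => f(a).elems))

  def toList: List[A] = elems
}

object ListSequence {
  def apply[A](items: A*): ListSequence[A] = new ListSequence(items.toList)
}

case class Person(name: String)

val myFriends    = ListSequence(Person("Adam"), Person("David"), Person("Frank"))
val adamsFriends = ListSequence(Person("Nick"), Person("David"), Person("Frank"))

// Two generators plus a guard: compiles because flatMap, withFilter, and map exist
val mutualFriends = for {
  myFriend    <- myFriends
  adamsFriend <- adamsFriends
  if myFriend.name == adamsFriend.name
} yield myFriend
```

The design trade-off is that every operation builds a new List instead of mutating a shared buffer, which keeps each ListSequence value unchanged after it is created.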
Summary
I can summarize what I showed in all of the for expression lessons with these lines of code:
// (1) works because `foreach` is defined
for (p <- peeps) println(p)
// (2) `yield` works because `map` is defined
val res: Sequence[Int] = for {
  i <- ints
} yield i * 2
res.foreach(println)
// (3) `if` works because `withFilter` is defined
val res2 = for {
  i <- ints
  if i > 2
} yield i * 2
// (4) works because `flatMap` is defined
val mutualFriends = for {
  myFriend <- myFriends         // generator
  adamsFriend <- adamsFriends   // generator
  if (myFriend.name == adamsFriend.name)
} yield myFriend
Update: All of my new videos are now on LearnScala.dev.