Difference between map() and flatMap() in Java 8 Stream

The map() and flatMap() methods are two important operations in the new functional style of Java 8. Both represent functional operations, and both are methods of the java.util.stream.Stream interface. The key difference between them is this: map() applies a function to each element of a stream and puts the value returned by the function into a new Stream. This way one stream is transformed into another, e.g. a Stream of String is transformed into a Stream of Integer where each element is the length of the corresponding String. The key thing to remember is that the function used by map() returns a single value. If the function instead returns a Stream of values, then map() gives you a Stream of Streams of values, and flatMap() is used to flatten that into a single Stream of values.
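To make the "Stream of Streams" point concrete, here is a small sketch that splits each line into words with a function returning a Stream. The sample lines and the class name are made up for illustration. The only APIs used are Stream.map(), Stream.flatMap(), Arrays.stream() and String.split().

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MapReturnsStreamDemo {

    // Illustrative input: each element is a line of words
    static final List<String> LINES = Arrays.asList("map returns one value", "flatMap flattens");

    // map() with a function that returns a Stream yields a Stream of Streams
    static List<Stream<String>> nested() {
        return LINES.stream()
                .map(line -> Arrays.stream(line.split(" ")))
                .collect(Collectors.toList());
    }

    // flatMap() with the same function yields a single Stream of words
    static List<String> flat() {
        return LINES.stream()
                .flatMap(line -> Arrays.stream(line.split(" ")))
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println("map() produced " + nested().size() + " inner streams");
        System.out.println("flatMap() produced " + flat());
    }
}
```

With map() you end up holding two inner streams that still need to be consumed. With flatMap() you get the six words directly.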

For example, suppose we have a Stream of String containing {"12", "34"} and a method getPermutations() which returns a list of permutations of a given String. If you apply that function to each String of the Stream using map(), you get something like [["12","21"],["34","43"]], but if you use flatMap(), you get a flat Stream of Strings, e.g. ["12", "21", "34", "43"]. In this article, we'll see a couple of working examples to understand the difference between map() and flatMap() in Java better.
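Here is a minimal sketch of that example. The article does not show getPermutations(), so the recursive version below is a hypothetical helper written only for illustration. Because it returns a List<String>, the flatMap() call has to turn each list into a Stream with .stream().

```java
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class PermutationsDemo {

    // Hypothetical helper: returns all permutations of the given String
    static List<String> getPermutations(String s) {
        List<String> result = new ArrayList<>();
        if (s.length() <= 1) {
            result.add(s);
            return result;
        }
        for (int i = 0; i < s.length(); i++) {
            char first = s.charAt(i);
            String rest = s.substring(0, i) + s.substring(i + 1);
            for (String perm : getPermutations(rest)) {
                result.add(first + perm); // char + String concatenates
            }
        }
        return result;
    }

    // map(): one List per input String, so the result is nested
    static List<List<String>> withMap() {
        return Stream.of("12", "34")
                .map(PermutationsDemo::getPermutations)
                .collect(Collectors.toList());
    }

    // flatMap(): every permutation ends up in one flat List
    static List<String> withFlatMap() {
        return Stream.of("12", "34")
                .flatMap(s -> getPermutations(s).stream())
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(withMap());     // [[12, 21], [34, 43]]
        System.out.println(withFlatMap()); // [12, 21, 34, 43]
    }
}
```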

I know it's not easy to understand the map() and flatMap() functions, especially if you have not done any functional programming before. I was in the same situation. It took me some time to really understand the purpose of map and flatMap, and I have to thank Java SE 8 for the Really Impatient, which helped me to understand these key functional concepts better. The explanation given in this book is really great. Even if you don't have any functional programming experience, you will understand these new things with a little bit of effort. I highly recommend this book to all Java developers who wish to learn Java 8.


How Stream.map() works in Java 8

The Stream.map() function performs the map functional operation, i.e. it takes a Stream and transforms it into another Stream. It applies a function to each element of the Stream and stores the return value in the new Stream. This way you can transform a Stream of String into a Stream of Integer, where each Integer is the length of the corresponding String, if you supply the String::length function. This makes map() very useful when working with collections in Java.
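Here is a minimal sketch of the String-to-length transformation described above. The input words and the class name are arbitrary examples. String.length() returns an int, which map() boxes into an Integer, so the result is a Stream<Integer> with exactly one output per input.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class MapLengthDemo {

    // Transform a Stream<String> into a Stream<Integer> of lengths
    static List<Integer> lengths(Stream<String> words) {
        return words.map(String::length)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Integer> result = lengths(Stream.of("Java", "Stream", "map"));
        System.out.println(result); // [4, 6, 3]
    }
}
```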

Here is an example of Stream.map() in Java 8:

List<Integer> listOfIntegers = Stream.of("1", "2", "3", "4")
                                     .map(Integer::valueOf).collect(Collectors.toList());

In this example, we have a Stream of String values which represent numbers. Using the map() function, we have converted this Stream into a Stream of Integers. How? By applying Integer.valueOf() to each element of the Stream. That's how "1" is converted to the integer 1, and so on. Once the transformation is done, we collect the result into a List using Collectors.toList().

How Stream.flatMap() works in Java 8

The Stream.flatMap() function, as the name suggests, is the combination of a map and a flat operation. This means you first apply the map function and then flatten the result. The key difference is that the function passed to flatMap() returns a Stream of values rather than a single value. If your function produces a List, you call stream() on it. That's why flattening is needed: when you flatten a Stream of Streams, it gets converted into a Stream of values.
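The "map, then flatten" description can be checked directly. Mapping each list to a Stream gives a Stream<Stream<Integer>>. Flattening that with flatMap(inner -> inner) gives the same elements as a single flatMap() call. The sample lists and class name are arbitrary. This is a sketch to illustrate the idea, not a statement about how flatMap() is implemented internally.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class MapThenFlattenDemo {

    static final List<List<Integer>> INPUT =
            Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3), Arrays.asList(4, 5));

    // Step 1: map each List to a Stream, giving Stream<Stream<Integer>>
    // Step 2: flatten by returning each inner Stream unchanged
    static List<Integer> mapThenFlatten() {
        return INPUT.stream()
                .map(List::stream)
                .flatMap(inner -> inner)
                .collect(Collectors.toList());
    }

    // The same result in one step
    static List<Integer> flatMapOnly() {
        return INPUT.stream()
                .flatMap(List::stream)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(mapThenFlatten()); // [1, 2, 3, 4, 5]
        System.out.println(flatMapOnly());    // [1, 2, 3, 4, 5]
    }
}
```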

To understand what flattening a stream consists of, consider a structure like [ [1,2,3],[4,5,6],[7,8,9] ], which has "two levels". It's basically a big List containing three smaller Lists. Flattening it means transforming it into a "one level" structure, e.g. [ 1,2,3,4,5,6,7,8,9 ], i.e. just one list.

In short,
Before flattening - Stream of List of Integer
After flattening - Stream of Integer
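The two-level structure above can be written in Java 8 as a List of Lists. This sketch flattens it and names the stream types before and after flattening. The class name is illustrative.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FlattenDemo {

    static List<Integer> flatten() {
        // Two levels: a List containing three Lists
        List<List<Integer>> nested = Arrays.asList(
                Arrays.asList(1, 2, 3),
                Arrays.asList(4, 5, 6),
                Arrays.asList(7, 8, 9));

        // Before flattening: Stream of List of Integer
        Stream<List<Integer>> before = nested.stream();

        // After flattening: Stream of Integer
        Stream<Integer> after = before.flatMap(List::stream);

        return after.collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(flatten()); // [1, 2, 3, 4, 5, 6, 7, 8, 9]
    }
}
```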

Here is a code example to understand the flatMap() function better:

List<Integer> evens = Arrays.asList(2, 4, 6);
List<Integer> odds = Arrays.asList(3, 5, 7);
List<Integer> primes = Arrays.asList(2, 3, 5, 7, 11);
List<Integer> numbers = Stream.of(evens, odds, primes)
               .flatMap(list -> list.stream())
               .collect(Collectors.toList());
System.out.println("flattened list: " + numbers);

Output:
flattened list: [2, 4, 6, 3, 5, 7, 2, 3, 5, 7, 11]

You can see that we have three lists which are merged into one by using the flatMap() function. For the mapping step we used list.stream(), which returns multiple values (a Stream) instead of a single value. Finally, we collected the flattened stream into a list. If you want, you can print the final list using the forEach() method.
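As mentioned, the flattened stream can also be printed directly with forEach() instead of being collected first. This sketch reuses the same three lists. It swaps list -> list.stream() for the equivalent method reference List::stream, which is a stylistic choice rather than a requirement. The class name is illustrative.

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Stream;

public class FlatMapForEachDemo {

    public static void main(String[] args) {
        List<Integer> evens = Arrays.asList(2, 4, 6);
        List<Integer> odds = Arrays.asList(3, 5, 7);
        List<Integer> primes = Arrays.asList(2, 3, 5, 7, 11);

        // Flatten the three lists and print each element on its own line
        Stream.of(evens, odds, primes)
                .flatMap(List::stream)
                .forEach(System.out::println);
    }
}
```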

Stream.map() vs Stream.flatMap() in Java 8

In short, here are the key differences between map() and flatMap() in Java 8:
  • The function you pass to the map() operation returns a single value.
  • The function you pass to the flatMap() operation returns a Stream of values.
  • flatMap() is a combination of a map and a flatten operation.
  • map() is used for transformation only, but flatMap() is used for both transformation and flattening.
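The first two points are visible in the types. map() takes a function from T to R, while flatMap() takes a function from T to a Stream of R. The sketch below spells those functions out as java.util.function.Function variables; the names LENGTH, CHARS and the class name are illustrative. It also shows that a plain map() can be written with flatMap() by wrapping each result in a one-element Stream, so flatMap() can do everything map() does.

```java
import java.util.Arrays;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FunctionShapesDemo {

    static final List<String> WORDS = Arrays.asList("map", "flatMap");

    // Function for map(): one input, exactly one output value
    static final Function<String, Integer> LENGTH = String::length;

    // Function for flatMap(): one input, a Stream of output values
    static final Function<String, Stream<Character>> CHARS =
            s -> s.chars().mapToObj(c -> (char) c);

    static List<Integer> viaMap() {
        return WORDS.stream().map(LENGTH).collect(Collectors.toList());
    }

    // map() expressed with flatMap(): wrap each single result in a one-element Stream
    static List<Integer> mapViaFlatMap() {
        return WORDS.stream()
                .flatMap(w -> Stream.of(LENGTH.apply(w)))
                .collect(Collectors.toList());
    }

    static List<Character> viaFlatMap() {
        return WORDS.stream().flatMap(CHARS).collect(Collectors.toList());
    }

    public static void main(String[] args) {
        System.out.println(viaMap());        // [3, 7]
        System.out.println(mapViaFlatMap()); // [3, 7]
        System.out.println(viaFlatMap());    // [m, a, p, f, l, a, t, M, a, p]
    }
}
```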

Now let's see a sample Java program to understand the difference between flatMap() and map() better.


Java Program to show difference between map vs flatMap

Here is our sample Java program to demonstrate the real difference between the map() and flatMap() functions of the Stream interface in Java 8. As I said before, map() is used to transform one Stream into another by applying a function to each element, and flatMap() does both transformation and flattening. The flatMap() function can take a Stream of Lists and return a Stream of the values combined from all those lists. In the example below we have collected the result into a List, but you can also print the values using the forEach() method of Java 8.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/**
 * Java Program to demonstrate difference between map()
 * vs flatMap() function in Java 8. Both are defined
 * in the Stream interface.
 * @author WINDOWS 8
 */
public class Java8Demo {

    public static void main(String[] args) {

        // foods which help in weight loss
        List<String> loseWeight = new ArrayList<>();
        loseWeight.add("avocados");
        loseWeight.add("beans");
        loseWeight.add("salad");
        loseWeight.add("oats");
        loseWeight.add("broccoli");
        System.out.println("list of String : " + loseWeight);

        // let's use map() to convert the list of weight-loss
        // foods (Strings) into a list of ints, where each int
        // is the length of the corresponding food String
        List<Integer> listOfInts = loseWeight.stream()
                .map(s -> s.length())
                .collect(Collectors.toList());
        System.out.println("list of ints generated by map(): " + listOfInts);

        // flatMap() example, let's first create a list of lists
        List<List<Integer>> listOfListOfNumber = new ArrayList<>();
        listOfListOfNumber.add(Arrays.asList(2, 4));
        listOfListOfNumber.add(Arrays.asList(3, 9));
        listOfListOfNumber.add(Arrays.asList(4, 16));
        System.out.println("list of list : " + listOfListOfNumber);

        // let's use flatMap() to flatten this list into a
        // list of integers i.e. 2, 4, 3, 9, 4, 16
        List<Integer> listOfIntegers = listOfListOfNumber.stream()
                .flatMap(list -> list.stream())
                .collect(Collectors.toList());
        System.out.println("list of numbers generated by flatMap : " + listOfIntegers);
    }
}

Output:
list of String : [avocados, beans, salad, oats, broccoli]
list of ints generated by map(): [8, 5, 5, 4, 8]
list of list : [[2, 4], [3, 9], [4, 16]]
list of numbers generated by flatMap : [2, 4, 3, 9, 4, 16]

You can see that in the first example, the function used by map() returns a single value, the length of the String passed to it. In the case of flatMap(), the function returns a Stream, which basically holds multiple values.

That's all about the difference between map() and flatMap() in Java 8. You should use map() if you just want to transform one Stream into another where each element gets converted into a single value. Use flatMap() if the function you would pass to map() returns multiple values and you want just one Stream containing all those values. If you are still confused between map() and flatMap(), then go read Java SE 8 for the Really Impatient by Cay S. Horstmann, one of the best books to learn about the new features of Java 8.
