Working with Spark ArrayType columns

Spark DataFrame columns support arrays, which are a good fit when each row needs to hold a variable number of values. This blog post will demonstrate Spark methods that return ArrayType columns, describe how to create your own ArrayType columns, and explain when to use arrays in your analyses.

Scala collections

Scala has different types of collections: lists, sequences, and arrays. Let’s quickly review the different types of Scala collections before jumping into collections for Spark analyses.

Let’s create and sort a collection of numbers.

List(10, 2, 3).sorted // List[Int] = List(2, 3, 10)
Seq(10, 2, 3).sorted // Seq[Int] = List(2, 3, 10)
Array(10, 2, 3).sorted // Array[Int] = Array(2, 3, 10)

List, Seq, and Array differ slightly, but generally work the same. Most Spark programmers don’t need to know about how these collections differ.
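For readers who are curious, here’s a brief sketch of the main differences using only the Scala standard library. The variable names are just for illustration.

```scala
// List is an immutable linked list: prepending returns a new list
// and leaves the original untouched.
val list = List(10, 2, 3)
val longerList = 1 :: list

// Array is a mutable, fixed-size wrapper around a JVM array:
// elements can be updated in place.
val array = Array(10, 2, 3)
array(0) = 99

// Seq is a general interface; the default Seq(...) factory builds a List.
val seq: Seq[Int] = Seq(10, 2, 3)

// Arrays are compared by reference with ==, so use sameElements
// to compare their contents.
val sameContents = Array(1, 2).sameElements(Array(1, 2))
val sameReference = Array(1, 2) == Array(1, 2)

println(longerList)            // List(1, 10, 2, 3)
println(array.mkString(", "))  // 99, 2, 3
println(sameContents)          // true
println(sameReference)         // false
```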

Spark uses arrays for ArrayType columns, so we’ll mainly use arrays in our code snippets.

Splitting a string into an ArrayType column

Let’s create a DataFrame with a name column and a hit_songs column that holds a pipe-delimited string. Then let’s use the split() function to convert hit_songs into an array of strings.

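// Assumes a spark-shell session, or an existing SparkSession named spark
import org.apache.spark.sql.functions._
import spark.implicits._

val singersDF = Seq(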
  ("beatles", "help|hey jude"),
  ("romeo", "eres mia")
).toDF("name", "hit_songs")

val actualDF = singersDF.withColumn(
  "hit_songs",
  split(col("hit_songs"), "\\|")
)
actualDF.show()

+-------+----------------+
|   name|       hit_songs|
+-------+----------------+
|beatles|[help, hey jude]|
|  romeo|      [eres mia]|
+-------+----------------+

actualDF.printSchema()

root
 |-- name: string (nullable = true)
 |-- hit_songs: array (nullable = true)
 |    |-- element: string (containsNull = true)

An ArrayType column is suitable in this example because a singer can have an arbitrary number of hit songs. We don’t want to create a DataFrame with hit_song1, hit_song2, …, hit_songN columns.
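The split() call above escapes the pipe because the second argument to split() is a regular expression, and an unescaped pipe means "or" in a regex. Here’s a small sketch with plain Scala strings, whose split method takes a Java regex in the same way, to show what goes wrong without the escape.

```scala
// String.split(String) takes a Java regular expression, as does Spark's split().
val song = "help|hey jude"

// Escaped pipe: matches the literal "|" character.
val escaped = song.split("\\|")

// Unescaped pipe: the regex "|" means "empty OR empty", so it matches
// between every character and the string falls apart into single characters.
val unescaped = song.split("|")

println(escaped.mkString("[", ", ", "]"))   // [help, hey jude]
println(unescaped.length > escaped.length)  // true
```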

Directly creating an ArrayType column

Let’s use the spark-daria createDF method to create a DataFrame with an ArrayType column directly. See this blog post for more information about the createDF method.

Let’s create another singersDF with some different artists.

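// createDF is added to SparkSession by spark-daria's SparkSessionExt
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._
import org.apache.spark.sql.types._

val singersDF = spark.createDF(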
  List(
    ("bieber", Array("baby", "sorry")),
    ("ozuna", Array("criminal"))
  ), List(
    ("name", StringType, true),
    ("hit_songs", ArrayType(StringType, true), true)
  )
)
singersDF.show()

+------+-------------+
|  name|    hit_songs|
+------+-------------+
|bieber|[baby, sorry]|
| ozuna|   [criminal]|
+------+-------------+

singersDF.printSchema()

root
 |-- name: string (nullable = true)
 |-- hit_songs: array (nullable = true)
 |    |-- element: string (containsNull = true)
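If you’d rather not depend on spark-daria, a similar DataFrame can be built with the standard createDataFrame method, a list of Row objects, and an explicit StructType schema. This is a minimal sketch that assumes a local SparkSession; the app name is arbitrary.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.types.{ArrayType, StringType, StructField, StructType}

// Assumption: a local session is good enough for this demo
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("array-type-create")
  .getOrCreate()

// Same schema as the createDF example: a string column and an array-of-strings column
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("hit_songs", ArrayType(StringType, containsNull = true), nullable = true)
))

// Each Row must line up with the schema; array values are passed as Seq
val rows = Seq(
  Row("bieber", Seq("baby", "sorry")),
  Row("ozuna", Seq("criminal"))
)

val singersDF = spark.createDataFrame(spark.sparkContext.parallelize(rows), schema)
singersDF.show()
singersDF.printSchema()
```

This is wordier than createDF, which is the main reason the rest of this post sticks with the spark-daria helper.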

The ArrayType case class is instantiated with an elementType and a containsNull flag. In ArrayType(StringType, true), StringType is the elementType and true is the containsNull flag.

See the documentation for the class here.
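Here’s a short sketch that pokes at ArrayType directly, without creating a DataFrame. It only needs the Spark SQL types on the classpath, and the value names are made up for the example.

```scala
import org.apache.spark.sql.types.{ArrayType, IntegerType, StringType}

// Two-argument form: element type plus the containsNull flag
val stringArray = ArrayType(StringType, containsNull = true)
println(stringArray.elementType)   // StringType
println(stringArray.containsNull)  // true

// The one-argument factory defaults containsNull to true
val defaulted = ArrayType(StringType)
println(defaulted == stringArray)  // true

// ArrayType can be nested to describe arrays of arrays
val nested = ArrayType(ArrayType(IntegerType, containsNull = false), containsNull = true)
println(nested.simpleString)       // array<array<int>>
```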

array_contains

The Spark functions object provides helper methods for working with ArrayType columns. The array_contains method returns true if the array in a given row contains the specified element.

Let’s create a DataFrame with people and their favorite colors. Then let’s use array_contains to append a likes_red column that returns true if the person likes red.

val peopleDF = spark.createDF(
  List(
    ("bob", Array("red", "blue")),
    ("maria", Array("green", "red")),
    ("sue", Array("black"))
  ), List(
    ("name", StringType, true),
    ("favorite_colors", ArrayType(StringType, true), true)
  )
)

val actualDF = peopleDF.withColumn(
  "likes_red",
  array_contains(col("favorite_colors"), "red")
)
actualDF.show()

+-----+---------------+---------+
| name|favorite_colors|likes_red|
+-----+---------------+---------+
|  bob|    [red, blue]|     true|
|maria|   [green, red]|     true|
|  sue|        [black]|    false|
+-----+---------------+---------+
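Because array_contains returns a boolean Column, it can also be used as a filter condition instead of a new column. Here’s a minimal sketch that keeps only the people who like red. It builds the DataFrame with toDF rather than createDF so that it runs without spark-daria, and it assumes a local SparkSession.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{array_contains, col}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("array-contains-filter")
  .getOrCreate()
import spark.implicits._

val peopleDF = Seq(
  ("bob", Seq("red", "blue")),
  ("maria", Seq("green", "red")),
  ("sue", Seq("black"))
).toDF("name", "favorite_colors")

// Keep only the rows whose favorite_colors array contains "red"
val redFansDF = peopleDF.where(array_contains(col("favorite_colors"), "red"))
redFansDF.show()
```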

explode

Let’s use the same DataFrame as before and the explode() function to create a new row for every element in each array.

val df = peopleDF.select(
  col("name"),
  explode(col("favorite_colors")).as("color")
)
df.show()

+-----+-----+
| name|color|
+-----+-----+
|  bob|  red|
|  bob| blue|
|maria|green|
|maria|  red|
|  sue|black|
+-----+-----+

peopleDF has 3 rows and df has 5 rows. The explode() method adds rows to a DataFrame.
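One thing to watch out for: explode() produces no output rows for a row whose array is empty or null, so it can also remove rows. The explode_outer() function keeps those rows and fills the exploded column with null. Here’s a sketch on a small made-up DataFrame, assuming a local SparkSession.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, explode, explode_outer}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("explode-vs-explode-outer")
  .getOrCreate()
import spark.implicits._

// Illustrative data: one normal array, one empty array, one null array
val data: Seq[(String, Seq[String])] = Seq(
  ("bob", Seq("red", "blue")),
  ("sue", Seq.empty[String]),
  ("ann", null)
)
val colorsDF = data.toDF("name", "favorite_colors")

// explode: sue and ann disappear because their arrays have nothing to explode
val explodedDF = colorsDF.select(
  col("name"),
  explode(col("favorite_colors")).as("color")
)

// explode_outer: sue and ann are kept, each with a null color
val explodedOuterDF = colorsDF.select(
  col("name"),
  explode_outer(col("favorite_colors")).as("color")
)

explodedDF.show()
explodedOuterDF.show()
```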

collect_list

The collect_list method collapses a DataFrame into fewer rows and stores the collapsed data in an ArrayType column.

Let’s create a DataFrame with letter1, letter2, and number1 columns.

val df = Seq(
  ("a", "b", 1),
  ("a", "b", 2),
  ("a", "b", 3),
  ("z", "b", 4),
  ("a", "x", 5)
).toDF("letter1", "letter2", "number1")

df.show()

+-------+-------+-------+
|letter1|letter2|number1|
+-------+-------+-------+
|      a|      b|      1|
|      a|      b|      2|
|      a|      b|      3|
|      z|      b|      4|
|      a|      x|      5|
+-------+-------+-------+

Let’s use the collect_list() method to collapse the rows that share the same letter1 and letter2 values into a single row and collect all of their number1 entries as a list.

df
  .groupBy("letter1", "letter2")
  .agg(collect_list("number1") as "number1s")
  .show()

+-------+-------+---------+
|letter1|letter2| number1s|
+-------+-------+---------+
|      a|      x|      [5]|
|      z|      b|      [4]|
|      a|      b|[1, 2, 3]|
+-------+-------+---------+

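We can see that number1s is an ArrayType column by printing the schema of the aggregated DataFrame.

df
  .groupBy("letter1", "letter2")
  .agg(collect_list("number1") as "number1s")
  .printSchema()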

root
 |-- letter1: string (nullable = true)
 |-- letter2: string (nullable = true)
 |-- number1s: array (nullable = true)
 |    |-- element: integer (containsNull = true)
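collect_list keeps duplicate values. If you only want the distinct values for each group, the related collect_set function drops the duplicates. Here’s a sketch on a slightly different, made-up data set that contains a repeated number, assuming a local SparkSession. Note that Spark doesn’t guarantee the order of the elements in either result.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{collect_list, collect_set}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("collect-list-vs-collect-set")
  .getOrCreate()
import spark.implicits._

// The (a, b) group contains the number 1 twice
val df = Seq(
  ("a", "b", 1),
  ("a", "b", 1),
  ("a", "b", 2),
  ("z", "b", 4)
).toDF("letter1", "letter2", "number1")

val groupedDF = df
  .groupBy("letter1", "letter2")
  .agg(
    collect_list("number1").as("as_list"), // keeps both 1s
    collect_set("number1").as("as_set")    // keeps a single 1
  )

groupedDF.show()
```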

Single column array functions

Spark added a ton of useful array functions in the 2.4 release.

We will start with the functions for a single ArrayType column and then move on to the functions for multiple ArrayType columns.

Let’s start by creating a DataFrame with an ArrayType column.

val df = spark.createDF(
  List(
    (Array(1, 2)),
    (Array(1, 2, 3, 1)),
    (null)
  ), List(
    ("nums", ArrayType(IntegerType, true), true)
  )
)
df.show()

+------------+
|        nums|
+------------+
|      [1, 2]|
|[1, 2, 3, 1]|
|        null|
+------------+

Let’s use the array_distinct() method to remove all of the duplicate array elements in the nums column.

df
  .withColumn("nums_distinct", array_distinct($"nums"))
  .show()

+------------+-------------+
|        nums|nums_distinct|
+------------+-------------+
|      [1, 2]|       [1, 2]|
|[1, 2, 3, 1]|    [1, 2, 3]|
|        null|         null|
+------------+-------------+

Let’s use array_join() to create a pipe delimited string of all elements in the arrays.

df
  .withColumn("nums_joined", array_join($"nums", "|"))
  .show()

+------------+-----------+
|        nums|nums_joined|
+------------+-----------+
|      [1, 2]|        1|2|
|[1, 2, 3, 1]|    1|2|3|1|
|        null|       null|
+------------+-----------+

Let’s use the printSchema method to verify that nums_joined is a StringType column.

df
  .withColumn("nums_joined", array_join($"nums", "|"))
  .printSchema()

root
 |-- nums: array (nullable = true)
 |    |-- element: integer (containsNull = true)
 |-- nums_joined: string (nullable = true)

Let’s use array_max to grab the maximum value from the arrays.

df
  .withColumn("nums_max", array_max($"nums"))
  .show()

+------------+--------+
|        nums|nums_max|
+------------+--------+
|      [1, 2]|       2|
|[1, 2, 3, 1]|       3|
|        null|    null|
+------------+--------+

Let’s use array_min to grab the minimum value from the arrays.

df
  .withColumn("nums_min", array_min($"nums"))
  .show()

+------------+--------+
|        nums|nums_min|
+------------+--------+
|      [1, 2]|       1|
|[1, 2, 3, 1]|       1|
|        null|    null|
+------------+--------+

Let’s use the array_remove method to remove all the 1s from each of the arrays.

df
  .withColumn("nums_sans_1", array_remove($"nums", 1))
  .show()

+------------+-----------+
|        nums|nums_sans_1|
+------------+-----------+
|      [1, 2]|        [2]|
|[1, 2, 3, 1]|     [2, 3]|
|        null|       null|
+------------+-----------+

Let’s use array_sort to sort all of the arrays in ascending order.

df
  .withColumn("nums_sorted", array_sort($"nums"))
  .show()

+------------+------------+
|        nums| nums_sorted|
+------------+------------+
|      [1, 2]|      [1, 2]|
|[1, 2, 3, 1]|[1, 1, 2, 3]|
|        null|        null|
+------------+------------+
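array_sort always sorts in ascending order when called with a single argument. For descending order, the older sort_array function takes an asc flag. Here’s a minimal sketch, assuming a local SparkSession and a DataFrame built with toDF instead of createDF.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, sort_array}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("sort-array-desc")
  .getOrCreate()
import spark.implicits._

val numsDF = Seq(
  Seq(1, 2),
  Seq(1, 2, 3, 1)
).toDF("nums")

// asc = false sorts each array from largest to smallest
val sortedDF = numsDF.withColumn("nums_desc", sort_array(col("nums"), asc = false))
sortedDF.show()
```

The two functions also treat null elements differently: array_sort places them at the end of the array, while sort_array places them first in ascending order and last in descending order.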

Generic single column array functions

Suppose you have an array of strings and would like to see if all elements in the array begin with the letter c. Here’s how you can run this check on a Scala array:

Array("cream", "cookies").forall(_.startsWith("c")) // true
Array("taco", "clam").forall(_.startsWith("c")) // false

You can use the spark-daria forall() method to run this computation on a Spark DataFrame with an ArrayType column.

import com.github.mrpowers.spark.daria.sql.functions._

val df = spark.createDF(
  List(
    (Array("cream", "cookies")),
    (Array("taco", "clam"))
  ), List(
    ("words", ArrayType(StringType, true), true)
  )
)

df.withColumn(
  "all_words_begin_with_c",
  forall[String]((x: String) => x.startsWith("c")).apply(col("words"))
).show()

+----------------+----------------------+
|           words|all_words_begin_with_c|
+----------------+----------------------+
|[cream, cookies]|                  true|
|    [taco, clam]|                 false|
+----------------+----------------------+

The native Spark API doesn’t provide access to all the helpful collection methods provided by Scala. spark-daria uses User Defined Functions to define forall and exists methods. Email me or create an issue if you would like any additional UDFs to be added to spark-daria.
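To give a feel for the UDF approach, here’s a minimal sketch of a hand-written UDF that runs the same check. It is not spark-daria’s actual implementation; the allBeginWithC name is made up, and treating a null array or null element as a failed check is a choice made for this example.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("forall-udf-sketch")
  .getOrCreate()
import spark.implicits._

val wordsDF = Seq(
  Seq("cream", "cookies"),
  Seq("taco", "clam")
).toDF("words")

// Hypothetical UDF: Spark hands the array to the function as a Seq,
// so the regular Scala forall method can be applied to it.
val allBeginWithC = udf((words: Seq[String]) =>
  words != null && words.forall(w => w != null && w.startsWith("c"))
)

val resultDF = wordsDF.withColumn(
  "all_words_begin_with_c",
  allBeginWithC(col("words"))
)
resultDF.show()
```

If you’re on Spark 3.0 or later, org.apache.spark.sql.functions also ships built-in exists and forall functions that take a Column => Column lambda, which avoids the UDF entirely.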

Multiple column array functions

Let’s create a DataFrame with two ArrayType columns so we can try out the built-in Spark array functions that take multiple columns as input.

val numbersDF = spark.createDF(
  List(
    (Array(1, 2), Array(4, 5, 6)),
    (Array(1, 2, 3, 1), Array(2, 3, 4)),
    (null, Array(6, 7))
  ), List(
    ("nums1", ArrayType(IntegerType, true), true),
    ("nums2", ArrayType(IntegerType, true), true)
  )
)

Let’s use array_intersect to get the elements present in both arrays, without any duplication.

numbersDF
  .withColumn("nums_intersection", array_intersect($"nums1", $"nums2"))
  .show()

+------------+---------+-----------------+
|       nums1|    nums2|nums_intersection|
+------------+---------+-----------------+
|      [1, 2]|[4, 5, 6]|               []|
|[1, 2, 3, 1]|[2, 3, 4]|           [2, 3]|
|        null|   [6, 7]|             null|
+------------+---------+-----------------+

Let’s use array_union to get the elements in either array, without duplication.

numbersDF
  .withColumn("nums_union", array_union($"nums1", $"nums2"))
  .show()

+------------+---------+---------------+
|       nums1|    nums2|     nums_union|
+------------+---------+---------------+
|      [1, 2]|[4, 5, 6]|[1, 2, 4, 5, 6]|
|[1, 2, 3, 1]|[2, 3, 4]|   [1, 2, 3, 4]|
|        null|   [6, 7]|           null|
+------------+---------+---------------+

Let’s use array_except to get the elements that are in nums1 and not in nums2, without any duplication.

numbersDF
  .withColumn("nums1_nums2_except", array_except($"nums1", $"nums2"))
  .show()

+------------+---------+------------------+
|       nums1|    nums2|nums1_nums2_except|
+------------+---------+------------------+
|      [1, 2]|[4, 5, 6]|            [1, 2]|
|[1, 2, 3, 1]|[2, 3, 4]|               [1]|
|        null|   [6, 7]|              null|
+------------+---------+------------------+
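If it helps to have a mental model, the three functions behave like set operations that are followed by de-duplication. Here’s a rough sketch of the same logic on plain Scala collections, using the second row of numbersDF. It mirrors the results shown above and is not how Spark implements these functions.

```scala
val nums1 = Seq(1, 2, 3, 1)
val nums2 = Seq(2, 3, 4)

// Like array_intersect: elements of nums1 that also appear in nums2
val intersection = nums1.filter(x => nums2.contains(x)).distinct

// Like array_union: elements of either collection
val union = (nums1 ++ nums2).distinct

// Like array_except: elements of nums1 that do not appear in nums2
val except = nums1.filterNot(x => nums2.contains(x)).distinct

println(intersection) // List(2, 3)
println(union)        // List(1, 2, 3, 4)
println(except)       // List(1)
```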

Split array column into multiple columns

We can split an array column into multiple columns with getItem. Let’s create a DataFrame with a letters column and demonstrate how this single ArrayType column can be split into a DataFrame with three StringType columns.

val df = spark.createDF(
  List(
    (Array("a", "b", "c")),
    (Array("d", "e", "f")),
    (null)
  ), List(
    ("letters", ArrayType(StringType, true), true)
  )
)
df.show()

+---------+
|  letters|
+---------+
|[a, b, c]|
|[d, e, f]|
|     null|
+---------+

This example uses the same data as this Stack Overflow question.

Let’s use getItem to break out the array into col1, col2, and col3.

df
  .select(
    $"letters".getItem(0).as("col1"),
    $"letters".getItem(1).as("col2"),
    $"letters".getItem(2).as("col3")
  )
  .show()

+----+----+----+
|col1|col2|col3|
+----+----+----+
|   a|   b|   c|
|   d|   e|   f|
|null|null|null|
+----+----+----+

Here’s how we can use getItem with a loop.

df
  .select(
    (0 until 3).map(i => $"letters".getItem(i).as(s"col$i")): _*
  )
  .show()

+----+----+----+
|col0|col1|col2|
+----+----+----+
|   a|   b|   c|
|   d|   e|   f|
|null|null|null|
+----+----+----+

Our code snippet above is a little ugly because the 3 is hardcoded. We can calculate the size of every array in the column, take the max size, and use that rather than hardcoding.

val numCols = df
  .withColumn("letters_size", size($"letters"))
  .agg(max($"letters_size"))
  .head()
  .getInt(0)

df
  .select(
    (0 until numCols).map(i => $"letters".getItem(i).as(s"col$i")): _*
  )
  .show()

+----+----+----+
|col0|col1|col2|
+----+----+----+
|   a|   b|   c|
|   d|   e|   f|
|null|null|null|
+----+----+----+
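The two steps above can be wrapped in a small reusable function. Here’s a sketch; splitArrayColumn is a made-up helper name (it isn’t part of Spark or spark-daria), and the code assumes a local SparkSession. It also guards against an empty DataFrame and against size() returning -1 for null arrays, which is Spark’s default behavior outside ANSI mode.

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.{col, max, size}

val spark = SparkSession.builder()
  .master("local[*]")
  .appName("split-array-column")
  .getOrCreate()
import spark.implicits._

// Hypothetical helper: one output column per array position,
// up to the length of the longest array in the column.
def splitArrayColumn(df: DataFrame, arrayCol: String): DataFrame = {
  val maxSizeRow = df.agg(max(size(col(arrayCol)))).head()
  // max() is null for an empty DataFrame; size() can be -1 for null arrays
  val numCols = if (maxSizeRow.isNullAt(0)) 0 else math.max(0, maxSizeRow.getInt(0))
  val itemCols = (0 until numCols).map { i =>
    col(arrayCol).getItem(i).as(s"${arrayCol}_$i")
  }
  df.select(itemCols: _*)
}

val lettersDF = Seq(
  Seq("a", "b", "c"),
  Seq("d", "e", "f")
).toDF("letters")

val splitDF = splitArrayColumn(lettersDF, "letters")
splitDF.show()
```

Keep in mind that computing the max size triggers a full pass over the data before the select runs, so hardcoding the number of columns is cheaper when you already know it.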

Closing thoughts

Spark ArrayType columns make it easy to work with collections at scale.

Master the content covered in this blog to add a powerful skill to your toolset.

For more examples, see this Databricks notebook that covers even more Array / Map functions.
