New Spark 3 Array Functions (exists, forall, transform, aggregate, zip_with)

Spark 3 has new array functions that make working with ArrayType columns much easier.

Spark developers previously needed to use UDFs to perform complicated array functions. The new Spark functions make it easy to process array columns with native Spark.

Some of these higher order functions were accessible in SQL as of Spark 2.4, but they didn’t become part of the org.apache.spark.sql.functions object until Spark 3.0.

The transform and aggregate array functions are especially powerful general purpose functions. They provide functionality that’s equivalent to map and fold in Scala and make it a lot easier to work with ArrayType columns.

You no longer need to revert to ugly UDFs to perform complex array manipulation. You can use native Spark!

exists

exists returns true if the predicate function returns true for any value in an array.

Suppose you have the following data and would like identify all arrays that contain at least one even number.

+---------+------------+
|person_id|best_numbers|
+---------+------------+
|        a|   [3, 4, 5]|
|        b|     [8, 12]|
|        c|     [7, 13]|
|        d|        null|
+---------+------------+

People “a” and “b” have at least one favorite number that’s even, person “c” only has favorite odd numbers, and person “d” doesn’t have any data.

Start by creating an isEven column function that returns true is a number is even:

def isEven(col: Column): Column = {
  col % 2 === lit(0)
}

Let’s create a DataFrame and then run the org.apache.spark.sql.functions.exists function to append a even_best_number_exists column.

val df = spark.createDF(
  List(
    ("a", Array(3, 4, 5)),
    ("b", Array(8, 12)),
    ("c", Array(7, 13)),
    ("d", null),
  ), List(
    ("person_id", StringType, true),
    ("best_numbers", ArrayType(IntegerType, true), true)
  )
)

val resDF = df.withColumn(
  "even_best_number_exists",
  exists(col("best_numbers"), isEven)
)

Print the contents of the DataFrame and verify that even_best_number_exists contains the expected values.

resDF.show()

+---------+------------+-----------------------+
|person_id|best_numbers|even_best_number_exists|
+---------+------------+-----------------------+
|        a|   [3, 4, 5]|                   true|
|        b|     [8, 12]|                   true|
|        c|     [7, 13]|                  false|
|        d|        null|                   null|
+---------+------------+-----------------------+

You don’t have to defined isEven as a named function. You can also use an anonymous function and get the same result.

df.withColumn(
  "even_best_number_exists",
  exists(
    col("best_numbers"),
    (col: Column) => col % 2 === lit(0)
  )
)

Here’s the exists method signature in the Spark 3 docs.

forall

Let’s take a look at some arrays that contain words:

+--------------------------+
|words                     |
+--------------------------+
|[ants, are, animals]      |
|[italy, is, interesting]  |
|[brazilians, love, soccer]|
|null                      |
+--------------------------+

Let’s use forall to identify the arrays with words that all begin with the letter “a”:

val df = spark.createDF(
  List(
    (Array("ants", "are", "animals")),
    (Array("italy", "is", "interesting")),
    (Array("brazilians", "love", "soccer")),
    (null),
  ), List(
    ("words", ArrayType(StringType, true), true)
  )
)

val resDF = df.withColumn(
  "uses_alliteration_with_a",
  forall(
    col("words"),
    (col: Column) => col.startsWith("a")
  )
)

Let’s check out the contents of resDF and confirm it returns true for “ants are animals”:

resDF.show(false)

+--------------------------+------------------------+
|words                     |uses_alliteration_with_a|
+--------------------------+------------------------+
|[ants, are, animals]      |true                    |
|[italy, is, interesting]  |false                   |
|[brazilians, love, soccer]|false                   |
|null                      |null                    |
+--------------------------+------------------------+

A more interesting function would be one that returns true for any array that uses alliteration. We’ll solve that problem with the more advanced array functions.

Here’s the forall method signature in the docs.

filter

Suppose you have the following data:

+-----------------------+
|words                  |
+-----------------------+
|[bad, bunny, is, funny]|
|[food, is, bad, tasty] |
|null                   |
+-----------------------+

Let’s filter out all the array values equal to “bad”:

val df = spark.createDF(
  List(
    (Array("bad", "bunny", "is", "funny")),
    (Array("food", "is", "bad", "tasty")),
    (null),
  ), List(
    ("words", ArrayType(StringType, true), true)
  )
)

val resDF = df.withColumn(
  "filtered_words",
  filter(
    col("words"),
    (col: Column) => col =!= lit("bad")
  )
)

Print the contents of resDF and make sure the filtered_words column does not contain the word “bad”.

resDF.show(false)

+-----------------------+------------------+
|words                  |filtered_words    |
+-----------------------+------------------+
|[bad, bunny, is, funny]|[bunny, is, funny]|
|[food, is, bad, tasty] |[food, is, tasty] |
|null                   |null              |
+-----------------------+------------------+

The filter method is overloaded to take a function that accepts either two or one column argument.

Send me an example of a filter invocation with a column function that takes two arguments if you can figure it out.

transform

Suppose we have a dataset with arrays that contain fun cities.

+----------------------+
|places                |
+----------------------+
|[New York, Seattle]   |
|[Barcelona, Bangalore]|
|null                  |
+----------------------+

Let’s add a fun_places column that makes it clear how fun all of these cities really are!

val df = spark.createDF(
  List(
    (Array("New York", "Seattle")),
    (Array("Barcelona", "Bangalore")),
    (null),
  ), List(
    ("places", ArrayType(StringType, true), true)
  )
)


val resDF = df.withColumn(
  "fun_places",
  transform(
    col("places"),
    (col: Column) => concat(col, lit(" is fun!"))
  )
)

Print out resDF and confirm that is fun! has been appended all the elements in each array.

resDF.show(false)

+----------------------+--------------------------------------+
|places                |fun_places                            |
+----------------------+--------------------------------------+
|[New York, Seattle]   |[New York is fun!, Seattle is fun!]   |
|[Barcelona, Bangalore]|[Barcelona is fun!, Bangalore is fun!]|
|null                  |null                                  |
+----------------------+--------------------------------------+

transform works similar to the map function in Scala. I’m not sure why they chose to name this function transform… I think array_map would have been a better name, especially because the Dataset#transform function is commonly used to chain DataFrame transformations.

Using a method name that already exists confuses folks that don’t understand OOP. It also makes stuff hard to Google.

Let’s not focus on the negative. org.apache.spark.functions.transform now exists and is an absolute joy to work with. This is a great addition to the API.

aggregate

Suppose you have a DataFrame with an array of numbers:

+------------+
|     numbers|
+------------+
|[1, 2, 3, 4]|
|   [5, 6, 7]|
|        null|
+------------+

You can calculate the sum of the numbers in the array with the aggregate function.

val df = spark.createDF(
  List(
    (Array(1, 2, 3, 4)),
    (Array(5, 6, 7)),
    (null),
  ), List(
    ("numbers", ArrayType(IntegerType, true), true)
  )
)

val resDF = df.withColumn(
  "numbers_sum",
  aggregate(
    col("numbers"),
    lit(0),
    (col1: Column, col2: Column) => col1 + col2
  )
)

Let’s check out the result:

resDF.show()

+------------+-----------+
|     numbers|numbers_sum|
+------------+-----------+
|[1, 2, 3, 4]|         10|
|   [5, 6, 7]|         18|
|        null|       null|
+------------+-----------+

aggregate isn’t the best name. This concept is referred to as reduce in Python, inject in Ruby, and fold in Scala. The function name aggregate makes you think about database aggregations, not reducing an array.

The aggregate function is amazingly awesome, despite the name. It opens the door for all types of interesting array reductions.

The aggregate docs are hard to follow because there are so many column arguments:

Let me know if you have a good example of an aggregate function that uses the finish function.

zip_with

Suppose we have a DataFrame with letters1 and letters2 columns that contain arrays of letters.

+--------+--------+
|letters1|letters2|
+--------+--------+
|  [a, b]|  [c, d]|
|  [x, y]|  [p, o]|
|    null|  [e, r]|
+--------+--------+

Let’s zip the letters1 and letters2 arrays and join them with a *** delimiter. We want to convert [a, b] and [c, d] into a single array: [a***c, b***d].

val df = spark.createDF(
  List(
    (Array("a", "b"), Array("c", "d")),
    (Array("x", "y"), Array("p", "o")),
    (null, Array("e", "r"))
  ), List(
    ("letters1", ArrayType(StringType, true), true),
    ("letters2", ArrayType(StringType, true), true)
  )
)

val resDF = df.withColumn(
  "zipped_letters",
  zip_with(
    col("letters1"),
    col("letters2"),
    (left: Column, right: Column) => concat_ws("***", left, right)
  )
)

Let’s take a look at the results.

resDF.show()

+--------+--------+--------------+
|letters1|letters2|zipped_letters|
+--------+--------+--------------+
|  [a, b]|  [c, d]|[a***c, b***d]|
|  [x, y]|  [p, o]|[x***p, y***o]|
|    null|  [e, r]|          null|
+--------+--------+--------------+

Ugly UDFs from the past

spark-daria implemented exists as a UDF and the code is pretty gross:

def exists[T: TypeTag](f: (T => Boolean)) = udf[Boolean, Seq[T]] { (arr: Seq[T]) =>
  arr.exists(f(_))
}

The spark-daria forall UDF implementation was equally unappealing:

def forall[T: TypeTag](f: T => Boolean) = udf[Boolean, Seq[T]] { arr: Seq[T] =>
  arr.forall(f(_))
}

It’ll be cool to get rid of this cruft in spark-daria.

The more spark-daria functions that are supported natively in Spark, the better.

Let’s hope we can get the spark-daria createDF function merged in with the Spark codebase some day…

Conclusion

Spark 3 has added some new high level array functions that’ll make working with ArrayType columns a lot easier.

The transform and aggregate functions don’t seem quite as flexible as map and fold in Scala, but they’re a lot better than the Spark 2 alternatives.

The Spark core developers really “get it”. They’re doing a great job continuing to make Spark better.

Registration

2 Comments


  1. (Just an FYI I posted the same comment on your “Working with dates and times in Spark” but realized this might be a better place for my question). Im using Pyspark btw.

    At any rate, I want to convert ArrayType(StringType()) to ArrayType(DateType()). I have an array like the following [“22/01/2021”, “13/10/2018”] and I want to convert it to ISO-8601 like this [“2021-01-22”, “2018-10-13”].

    Some suggestions on the web included exploding the array column and then using pyspark.sql.functions.to_date(), but this is inefficient (millions of rows exploded will give hundreds of millions of rows) and quite frankly not elegant. I look forward to hearing your suggestion on this–it seems simple but Im astounded by the lack of resources Ive found thus far, cheers!


  2. Hi, Can we use the array functions on array of StructType
    My schema is:
    root
    |– parent_key: long (nullable = true)
    |– id: string (nullable = true)
    |– value: struct (nullable = true)
    | |– id: integer (nullable = true)
    | |– first_name: string (nullable = true)
    | |– is_active: integer (nullable = true)
    | |– source: string (nullable = true)
    | |– created_at: timestamp (nullable = true)
    | |– updated_at: timestamp (nullable = true)
    |– existing_values: array (nullable = true)
    | |– element: struct (containsNull = true)
    | | |– id: integer (nullable = true)
    | | |– first_name: string (nullable = true)
    | | |– is_active: integer (nullable = true)
    | | |– source: string (nullable = true)
    | | |– created_at: timestamp (nullable = true)
    | | |– updated_at: timestamp (nullable = true)
    |– op: timestamp (nullable = true)

    and depending upon op column which can have values [c,u,d], I need to add/update/delete the values from existing_values

    Right now, I am doing it by first exploding the existing_values and using group by with agg functions. But the transformation is taking a lot of time.

Comments are closed, but trackbacks and pingbacks are open.