Spark Column Equality
The term "column equality" refers to two different things in Spark:
- When a column is equal to a particular value (typically when filtering)
- When all the values in two columns are equal for all rows in the dataset (especially common when testing)
This blog post will explore both types of Spark column equality.
Column equality for filtering
Suppose you have a DataFrame with team_name, num_championships, and state columns.
Here's how you can filter to only show the teams from TX (short for Texas).
df.filter(df("state") === "TX")
Here's a sample dataset that you can paste into a Spark console to verify this result yourself.
val df = Seq(
  ("Rockets", 2, "TX"),
  ("Warriors", 6, "CA"),
  ("Spurs", 5, "TX"),
  ("Knicks", 2, "NY")
).toDF("team_name", "num_championships", "state")
Let's filter the DataFrame and make sure it only includes the teams from TX.
df.filter(df("state") === "TX").show()
+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
| Rockets| 2| TX|
| Spurs| 5| TX|
+---------+-----------------+-----+
Writing Beautiful Spark Code is the best way to learn about filtering, including the main pitfall to avoid. Read the book to filter effectively.
More on Spark's Column class
You'll use the Spark Column class all the time, so it's good to understand how it works.
Here's the method signature for the === method defined in the Column class.
def ===(other: Any): Column
The === method takes an Any object as an argument and returns a Column.
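Since === returns a Column, the comparison isn't limited to filter; you can use it anywhere a Column is expected. Here's a quick illustration (the is_tx column name is just made up for this sketch):
// append a boolean column instead of filtering the rows
df.withColumn("is_tx", df("state") === "TX").show()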
In df("state") === "TX"
, the ===
method is supplied a string argument. It can also be supplied a Column
argument.
import org.apache.spark.sql.functions.lit
df.filter(df("state") === lit("TX")).show()
Scala methods can be invoked with spaces instead of dot notation. This is an example of syntactic sugar.
We can use df("state").===(lit("TX"))
to avoid syntactic sugar and invoke the ===
method with standard dot notation.
df.filter(df("state").===("TX")).show()
+---------+-----------------+-----+
|team_name|num_championships|state|
+---------+-----------------+-----+
| Rockets| 2| TX|
| Spurs| 5| TX|
+---------+-----------------+-----+
We could also use the equalTo() Column method, which behaves like ===:
df.filter(df("state").equalTo("TX")).show()
Study the Column methods to become a better Spark programmer!
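For example, === has a negated counterpart, =!=, that keeps the rows where the values differ:
df.filter(df("state") =!= "TX").show()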
Naive Column Equality
Let's create a DataFrame with num and is_even_hardcoded columns.
val df = Seq(
  (2, true),
  (6, true),
  (5, false)
).toDF("num", "is_even_hardcoded")
df.show()
+---+-----------------+
|num|is_even_hardcoded|
+---+-----------------+
| 2| true|
| 6| true|
| 5| false|
+---+-----------------+
Let's create a custom DataFrame transformation called isEven that'll return true if the number is even and false otherwise.
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._
def isEven(inputColName: String, outputColName: String)(df: DataFrame) = {
  // appends a boolean column that's true when the input number is divisible by 2
  df.withColumn(outputColName, col(inputColName) % 2 === lit(0))
}
Check out Writing Beautiful Spark Code if this blog post is moving too fast and you need a gradual introduction to topics like DataFrame transformation functions.
Let's run the isEven() function and visually verify that the hardcoded and programmatic results are the same.
df.transform(isEven("num", "is_even")).show()
+---+-----------------+-------+
|num|is_even_hardcoded|is_even|
+---+-----------------+-------+
| 2| true| true|
| 6| true| true|
| 5| false| false|
+---+-----------------+-------+
Visually verifying column equality isn't good for big datasets or automated processes.
Let's store the full DataFrame in a new variable, so we don't need to keep running the isEven function.
val fullDF = df.transform(isEven("num", "is_even"))
We can use the ColumnComparer trait defined in spark-fast-tests to verify column equality.
import com.github.mrpowers.spark.fast.tests.ColumnComparer

// mix the ColumnComparer trait into your test class to get assertColumnEquality
assertColumnEquality(fullDF, "is_even_hardcoded", "is_even")
When you're writing unit tests, you'll definitely want to use the spark-fast-tests library.
Let's hack together some code that'll return true if two columns are equal.
def areColumnsEqual(df: DataFrame, colName1: String, colName2: String) = {
  // collect both columns to the driver
  val elements = df
    .select(colName1, colName2)
    .collect()
  // split the Rows into two arrays of values and compare them pairwise
  val c1 = elements.map(_(0))
  val c2 = elements.map(_(1))
  c1.sameElements(c2)
}
areColumnsEqual(fullDF, "is_even_hardcoded", "is_even") // true
This function doesn't handle all the edge cases, but it's a good start!
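One edge case worth calling out: collect() pulls the entire dataset onto the driver, which won't work for big DataFrames. Here's a sketch of a variant that keeps the comparison in Spark with the null-safe equality operator <=> (areColumnsEqualSparkSide is a made-up name for this example, and Dataset#isEmpty needs Spark 2.4+; use count() == 0 on older versions):
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

// true when no row has differing values; <=> treats two nulls as equal
def areColumnsEqualSparkSide(df: DataFrame, colName1: String, colName2: String): Boolean = {
  df.where(!(col(colName1) <=> col(colName2))).isEmpty
}

areColumnsEqualSparkSide(fullDF, "is_even_hardcoded", "is_even") // true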
Column equality with ===
We can also evaluate column equality by comparing both columns with the === operator and making sure all values evaluate to true.
As a reminder, here are the contents of fullDF:
fullDF.show()
+---+-----------------+-------+
|num|is_even_hardcoded|is_even|
+---+-----------------+-------+
| 2| true| true|
| 6| true| true|
| 5| false| false|
+---+-----------------+-------+
Let's append a column to fullDF that returns true if is_even_hardcoded and is_even are equal:
fullDF
  .withColumn("are_cols_equal", $"is_even_hardcoded" === $"is_even")
  .show()
+---+-----------------+-------+--------------+
|num|is_even_hardcoded|is_even|are_cols_equal|
+---+-----------------+-------+--------------+
| 2| true| true| true|
| 6| true| true| true|
| 5| false| false| true|
+---+-----------------+-------+--------------+
Let's write a function to verify that all the values in a given column are true.
def areAllRowsTrue(df: DataFrame, colName: String) = {
  // collect the column to the driver and make sure every value is true
  val elements = df
    .select(colName)
    .collect()
  val c1 = elements.map(_(0))
  c1.forall(_ == true)
}
Let's verify that areAllRowsTrue returns true for the are_cols_equal column.
areAllRowsTrue(
  fullDF.withColumn("are_cols_equal", $"is_even_hardcoded" === $"is_even"),
  "are_cols_equal"
) // true
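Watch out for nulls with this approach: === evaluates to null (not false) when either side is null, so areAllRowsTrue would flag two columns as unequal even when both values are null. Here's a small sketch of the gotcha (nullsDF is a made-up example DataFrame), with the null-safe <=> operator as one fix:
val nullsDF = Seq(
  (Some(true), Some(true)),
  (Option.empty[Boolean], Option.empty[Boolean])
).toDF("c1", "c2")

// === yields null for the null/null row, so forall(_ == true) fails
nullsDF.withColumn("eq", $"c1" === $"c2").show()

// <=> (null-safe equality) yields true for the null/null row
nullsDF.withColumn("eq", $"c1" <=> $"c2").show()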
This === approach unfortunately doesn't work for all column types.
ArrayType Column Equality
Let's create a numbersDF with two ArrayType columns that contain integers.
val numbersDF = Seq(
  (Seq(1, 2), Seq(1, 2)),
  (Seq(3, 3), Seq(3, 3)),
  (Seq(3, 3), Seq(6, 6))
).toDF("nums1", "nums2")
We can use === to assess ArrayType column equality.
numbersDF
  .withColumn("nah", $"nums1" === $"nums2")
  .show()
+------+------+-----+
| nums1| nums2| nah|
+------+------+-----+
|[1, 2]|[1, 2]| true|
|[3, 3]|[3, 3]| true|
|[3, 3]|[6, 6]|false|
+------+------+-----+
Let's see if === works for nested arrays. Start by making a fakeDF DataFrame with two nested array columns.
val fakeDF = Seq(
  (Seq(Seq(1, 2)), Seq(Seq(1, 2))),
  (Seq(Seq(1, 3)), Seq(Seq(1, 2)))
).toDF("nums1", "nums2")
The code below confirms that the === operator can handle deep column comparisons gracefully.
fakeDF
  .withColumn("nah", $"nums1" === $"nums2")
  .show()
+--------+--------+-----+
| nums1| nums2| nah|
+--------+--------+-----+
|[[1, 2]]|[[1, 2]]| true|
|[[1, 3]]|[[1, 2]]|false|
+--------+--------+-----+
But watch out: the === operator doesn't work for all complex column types.
MapType Column Equality
The === operator does not work for MapType columns.
val mapsDF = Seq(
  (Map("one" -> 1), Map("one" -> 1))
).toDF("m1", "m2")
mapsDF.withColumn("maps_equal", $"m1" === $"m2").show()
org.apache.spark.sql.AnalysisException: cannot resolve '(m1 = m2)' due to data type mismatch: EqualTo does not support ordering on type map<string,int>;;
Project [m1#373, m2#374, (m1#373 = m2#374) AS maps_equal#384]
Project [_1#370 AS m1#373, _2#371 AS m2#374]
LocalRelation [_1#370, _2#371]
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
... 51 elided
We need to use different tactics for MapType column equality.
The Scala == operator can successfully compare maps:
Map("one" -> 1) == Map("one" -> 1) // true
The sameElements Scala method also works here (note that it compares elements in iteration order, which is only reliable because these maps have a single entry):
Map("one" -> 1) sameElements Map("one" -> 1) // true
Recall that sameElements was used in the areColumnsEqual method we defined earlier:
def areColumnsEqual(df: DataFrame, colName1: String, colName2: String) = {
  val elements = df
    .select(colName1, colName2)
    .collect()
  val c1 = elements.map(_(0))
  val c2 = elements.map(_(1))
  c1.sameElements(c2)
}
So areColumnsEqual will also work for comparing MapType columns:
areColumnsEqual(mapsDF, "m1", "m2") // true
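If you'd rather keep the comparison in Spark instead of collecting, one possible workaround (assuming Spark 2.4+, where map_entries and array_sort were added) is to convert each map into a sorted array of entries, which === can compare:
import org.apache.spark.sql.functions.{array_sort, map_entries}

// arrays of structs support ===, so sorted entry arrays stand in for the maps
mapsDF
  .withColumn("maps_equal", array_sort(map_entries($"m1")) === array_sort(map_entries($"m2")))
  .show()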
StructType Column Equality
That's a whole new can of worms!!
Testing with Column Equality
Column equality is useful when writing unit tests.
Testing Spark Applications is the best way to learn how to test your Spark code. This section provides a great introduction, but you should really read the book if you'd like to accelerate your learning process.
You'll want to leverage a library like spark-fast-tests with column comparison methods, so you don't need to write them yourself.
Let's look at how spark-daria uses the spark-fast-tests assertColumnEquality method to test the removeNonWordCharacters() function that removes all the non-word characters from a string.
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.regexp_replace

def removeNonWordCharacters(col: Column): Column = {
  // strip runs of characters that are neither word characters nor whitespace
  regexp_replace(col, "[^\\w\\s]+", "")
}
Let's write a test to make sure removeNonWordCharacters converts "Bruce &&**||ok88" to "Bruce ok88", "55 oba&&&ma" to "55 obama", etc.
Here's the test:
import utest._

import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.StringType
import com.github.mrpowers.spark.daria.sql.SparkSessionExt._ // provides createDF
import com.github.mrpowers.spark.daria.sql.functions
import com.github.mrpowers.spark.fast.tests.ColumnComparer

object FunctionsTest extends TestSuite with ColumnComparer with SparkSessionTestWrapper {

  val tests = Tests {

    'removeNonWordCharacters - {

      "removes all non-word characters from a string, excluding whitespace" - {
        val df = spark
          .createDF(
            List(
              ("Bruce &&**||ok88", "Bruce ok88"),
              ("55 oba&&&ma", "55 obama"),
              (" ni!!ce h^^air person ", " nice hair person "),
              (null, null)
            ),
            List(
              ("some_string", StringType, true),
              ("expected", StringType, true)
            )
          )
          .withColumn(
            "some_string_remove_non_word_chars",
            functions.removeNonWordCharacters(col("some_string"))
          )
        assertColumnEquality(
          df,
          "some_string_remove_non_word_chars",
          "expected"
        )
      }

    }

  }

}
We hardcode the expected result in the DataFrame and add the some_string_remove_non_word_chars column by running the removeNonWordCharacters() function.
The assertColumnEquality method that's defined in spark-fast-tests verifies that the some_string_remove_non_word_chars and expected columns are equal.
This is how you should design most of your Spark unit tests.
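The test above mixes in a SparkSessionTestWrapper trait that isn't shown in this post. Here's a minimal sketch of what such a trait typically looks like, assuming a local SparkSession is fine for the test suite:
import org.apache.spark.sql.SparkSession

trait SparkSessionTestWrapper {

  // one lazily-created local SparkSession shared across the test suite
  lazy val spark: SparkSession = {
    SparkSession
      .builder()
      .master("local")
      .appName("spark session for tests")
      .getOrCreate()
  }

}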
Conclusion
Spark column equality is a surprisingly deep topic… we haven't even covered all the edge cases!
Make sure you understand how column comparisons work at a high level.
Use spark-fast-tests to write elegant tests and abstract column comparison details from your codebase.
Studying the spark-fast-tests codebase is a great way to learn more about Spark!