Chaining Custom PySpark DataFrame Transformations

PySpark code should generally be organized as single-purpose DataFrame transformations that can be chained together for production analyses (e.g. generating a datamart).

This blog post demonstrates how to monkey patch the DataFrame object with a transform method, how to define custom DataFrame transformations, and how to chain the function calls.

We’ll also demonstrate how to run multiple custom transformations with function composition using the cytoolz library.

If you’re using the Scala API, read this blog post on chaining DataFrame transformations with Scala.

Accessing the DataFrame transform method

Spark 3 includes a native DataFrame transform method, so Spark 3 users can skip the rest of this section.

Spark 2 users can monkey patch the DataFrame object with a transform method so they can chain DataFrame transformations.

from pyspark.sql.dataframe import DataFrame


def transform(self, f):
    return f(self)


DataFrame.transform = transform

This code snippet is from the quinn project.
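If the same codebase has to run on both Spark 2 and Spark 3, a guarded patch avoids overwriting the native method. The sketch below is an assumption on my part, not quinn's code: the helper name `ensure_transform` is hypothetical, and it forwards extra arguments the way newer PySpark releases (3.3 and later) do. It is demonstrated on a hypothetical stand-in class so it runs without Spark; in a real job you would call `ensure_transform(DataFrame)` with `pyspark.sql.dataframe.DataFrame`.

```python
def ensure_transform(cls):
    """Attach a transform method to cls only if it does not already have one."""
    if hasattr(cls, "transform"):
        # Spark 3+: keep the native implementation untouched
        return False

    def transform(self, f, *args, **kwargs):
        # Call f with this object as the first argument, plus any extra arguments
        return f(self, *args, **kwargs)

    cls.transform = transform
    return True


# Hypothetical stand-in for a Spark 2 DataFrame that lacks transform
class LegacyFrame:
    def __init__(self, columns):
        self.columns = list(columns)

    def withColumn(self, name, value):
        # Mimic the DataFrame API shape: return a new object with one more column
        return LegacyFrame(self.columns + [name])


patched = ensure_transform(LegacyFrame)
result = LegacyFrame(["name", "age"]).transform(
    lambda df, col: df.withColumn(col, None), "greeting"
)
print(patched, result.columns)  # True ['name', 'age', 'greeting']
```

Calling `ensure_transform` a second time returns `False` and leaves the existing method alone, which is the behavior you want on a Spark 3 cluster.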

Chaining DataFrame Transformations with lambda

Let’s define a couple of simple DataFrame transformations to test the transform method.

from pyspark.sql.functions import lit


def with_greeting(df):
    return df.withColumn("greeting", lit("hi"))


def with_something(df, something):
    return df.withColumn("something", lit(something))

Let’s create a DataFrame and then run the with_greeting and with_something DataFrame transformations.

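from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

data = [("jose", 1), ("li", 2), ("liz", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])

actual_df = (source_df
    .transform(lambda df: with_greeting(df))
    .transform(lambda df: with_something(df, "crazy")))
actual_df.show()  # show() prints the table itself and returns None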

+----+---+--------+---------+
|name|age|greeting|something|
+----+---+--------+---------+
|jose|  1|      hi|    crazy|
|  li|  2|      hi|    crazy|
| liz|  3|      hi|    crazy|
+----+---+--------+---------+

The lambda is optional for custom DataFrame transformations that only take a single DataFrame argument, so we can refactor the with_greeting line as follows:

actual_df = (source_df
    .transform(with_greeting)
    .transform(lambda df: with_something(df, "crazy")))

Without the DataFrame#transform method, we would have needed to write code like this:

df1 = with_greeting(source_df)
actual_df = with_something(df1, "crazy")

The transform method improves our code by helping us avoid multiple order-dependent variable assignments. Creating intermediate variables gets especially ugly when five or more transformations need to be run. You don’t want df1, df2, df3, df4, and df5 😑

Let’s define a DataFrame transformation with an alternative method signature to allow for easier chaining 😅

Chaining DataFrame Transformations with functools.partial

Let’s define a with_jacket DataFrame transformation that appends a jacket column to a DataFrame.

def with_jacket(word, df):
    return df.withColumn("jacket", lit(word))

We’ll use the same source_df DataFrame and with_greeting function from before and chain the transformations with functools.partial.

from functools import partial

actual_df = (source_df
    .transform(with_greeting)
    .transform(partial(with_jacket, "warm")))
actual_df.show()

+----+---+--------+------+
|name|age|greeting|jacket|
+----+---+--------+------+
|jose|  1|      hi|  warm|
|  li|  2|      hi|  warm|
| liz|  3|      hi|  warm|
+----+---+--------+------+
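Because partial fills positional parameters from the left, with_jacket has to take the DataFrame last. A df-first function such as with_something can still skip the lambda if its trailing argument is bound by keyword. The sketch below uses plain dicts as a hypothetical stand-in for DataFrames so it runs without Spark; with real DataFrames the same partial objects can be passed straight to transform, e.g. `source_df.transform(partial(with_something, something="crazy"))`.

```python
from functools import partial


def with_something(df, something):
    # Stand-in for the PySpark version: "df" is a plain dict of column -> value here
    return {**df, "something": something}


def with_jacket(word, df):
    # Stand-in for the PySpark version, DataFrame parameter last
    return {**df, "jacket": word}


row = {"name": "jose", "age": 1}

# Positional binding fills the leading parameter, which is why with_jacket takes df last
add_jacket = partial(with_jacket, "warm")
# Keyword binding lets a df-first function be used without a lambda
add_something = partial(with_something, something="crazy")

result = add_something(add_jacket(row))
print(result)  # {'name': 'jose', 'age': 1, 'jacket': 'warm', 'something': 'crazy'}
```

Note that `partial(with_something, "crazy")` would bind `"crazy"` to `df`, not to `something`, so the keyword form is the one to use for df-first signatures.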

functools.partial helps us get rid of the lambda functions, but we can do even better…

Defining DataFrame transformations as nested functions

DataFrame transformations that are defined with nested functions have the most elegant interface for chaining. Let’s define a with_funny function that appends a funny column to a DataFrame.

def with_funny(something_funny):
    return lambda df: (
        df.withColumn("funny", lit(something_funny))
    )

We’ll use the same source_df DataFrame and with_greeting function from before.

actual_df = (source_df
    .transform(with_greeting)
    .transform(with_funny("haha")))
actual_df.show()

+----+---+--------+-----+
|name|age|greeting|funny|
+----+---+--------+-----+
|jose|  1|      hi| haha|
|  li|  2|      hi| haha|
| liz|  3|      hi| haha|
+----+---+--------+-----+

This is much better! 🎊 Thanks for suggesting this implementation, hoffrocket!

We can also define a custom transformation with a named inner function instead of a lambda.

def with_funny(word):
    def _(df):
        return df.withColumn("funny", lit(word))
    return _

The inner function is named _. If you’re going to explicitly name the inner function, using an underscore is a good choice because it’s easy to apply consistently throughout the codebase. This design pattern was suggested by the developer who added the transform method to the DataFrame API, see here.
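If many transformations are already written df-first, a small decorator can generate the nested-function interface for you. `transformation` is a hypothetical helper, not part of PySpark or quinn, and plain dicts stand in for DataFrames so the sketch runs without Spark.

```python
import functools


def transformation(f):
    """Turn f(df, *args, **kwargs) into f(*args, **kwargs)(df) for use with transform."""
    @functools.wraps(f)
    def factory(*args, **kwargs):
        def _(df):
            return f(df, *args, **kwargs)
        return _
    return factory


@transformation
def with_funny(df, word):
    # Stand-in body: a dict plays the role of a DataFrame
    return {**df, "funny": word}


step = with_funny("haha")
print(step({"name": "jose"}))  # {'name': 'jose', 'funny': 'haha'}
```

With real DataFrames the body would be `df.withColumn("funny", lit(word))`, and the call site stays `source_df.transform(with_funny("haha"))`. The trade-off is one extra layer of indirection in exchange for writing each transformation as an ordinary two-argument function.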


Function composition with cytoolz

We can define custom DataFrame transformations with the @curry decorator and run them with function composition provided by cytoolz.

from cytoolz import curry
from cytoolz.functoolz import compose
from pyspark.sql.functions import lit


@curry
def with_stuff1(arg1, arg2, df):
    return df.withColumn("stuff1", lit(f"{arg1} {arg2}"))


@curry
def with_stuff2(arg, df):
    return df.withColumn("stuff2", lit(arg))


data = [("jose", 1), ("li", 2), ("liz", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])

# compose runs the right-most function first: with_stuff2, then with_stuff1
pipeline = compose(
    with_stuff1("nice", "person"),
    with_stuff2("yoyo")
)
actual_df = pipeline(source_df)
actual_df.show()

+----+---+------+-----------+
|name|age|stuff2|     stuff1|
+----+---+------+-----------+
|jose|  1|  yoyo|nice person|
|  li|  2|  yoyo|nice person|
| liz|  3|  yoyo|nice person|
+----+---+------+-----------+

The compose function applies transformations from right to left (bottom to top), so with_stuff2 runs before with_stuff1. We can reverse the list of transformations to apply them from left to right (top to bottom):

pipeline = compose(*reversed([
    with_stuff1("nice", "person"),
    with_stuff2("yoyo")
]))
actual_df = pipeline(source_df)
actual_df.show()

+----+---+-----------+------+
|name|age|     stuff1|stuff2|
+----+---+-----------+------+
|jose|  1|nice person|  yoyo|
|  li|  2|nice person|  yoyo|
| liz|  3|nice person|  yoyo|
+----+---+-----------+------+

Custom transformations are often order-dependent (for example, one transformation may read a column that an earlier one adds), so running them from left to right may be required.
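If you would rather not add a dependency just for left-to-right composition, the standard library is enough. The sketch below is my own assumption, not the cytoolz API: `compose_left` here is a hypothetical local helper built on `functools.reduce`, and dicts stand in for DataFrames so it runs without Spark.

```python
from functools import reduce


def compose_left(*fns):
    """Build a function that applies fns from left to right (first argument runs first)."""
    def pipeline(df):
        return reduce(lambda acc, f: f(acc), fns, df)
    return pipeline


# Hypothetical order-dependent steps on a dict stand-in for a DataFrame:
# with_stuff3 reads the column that with_stuff1 adds.
def with_stuff1(df):
    return {**df, "stuff1": "nice person"}


def with_stuff3(df):
    return {**df, "stuff3": df["stuff1"].upper()}


pipeline = compose_left(with_stuff1, with_stuff3)
print(pipeline({"name": "jose"}))
# {'name': 'jose', 'stuff1': 'nice person', 'stuff3': 'NICE PERSON'}
```

Swapping the arguments would raise a KeyError because with_stuff3 would run before the column it needs exists. Recent toolz/cytoolz releases also ship a `compose_left` function; check that your installed version has it before relying on it.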

Follow the best practices outlined in this post to make it easier to work with dependencies like cytoolz.

Custom transformations make testing easier

Custom transformations encourage developers to write code that’s easy to test. The code logic is broken up into a bunch of single-purpose functions that are easy to understand.

Read this blog post on testing PySpark code for examples of how to test custom transformations.

Chaining custom transformations with the Scala API

The Scala API defines a Dataset#transform method that makes it easy to chain custom transformations. The Scala programming language allows for multiple parameter lists, so you don’t need to define nested functions.

Chaining custom DataFrame transformations is easier with the Scala API, but still necessary when writing PySpark code!

This blog post explains how to chain DataFrame transformations with the Scala API.

Next steps

You should organize your code as single-purpose DataFrame transformations that are tested individually.

Following dependency management and project organization best practices will make your life a lot easier as a PySpark developer. Your development time should be mixed between experimentation in notebooks and coding with software engineering best practices in GitHub repos – both are important.

Use the transform method to chain your DataFrame transformations and run production analyses. Any DataFrame transformations that make assumptions about the underlying schema of a DataFrame should be validated with the quinn DataFrame validation helper methods.
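quinn ships its own validation helpers for this. As a library-independent sketch of the same idea, a hypothetical `requires_columns` decorator can check `df.columns` (PySpark DataFrames expose it as a list of column names) before a transformation runs. The `FakeFrame` class below is a made-up stand-in so the example runs without Spark.

```python
import functools


def requires_columns(*names):
    """Hypothetical guard: fail fast if the input lacks any of the named columns."""
    def decorator(f):
        @functools.wraps(f)
        def wrapper(df, *args, **kwargs):
            missing = [n for n in names if n not in df.columns]
            if missing:
                raise ValueError(f"{f.__name__} is missing columns: {missing}")
            return f(df, *args, **kwargs)
        return wrapper
    return decorator


class FakeFrame:
    # Hypothetical stand-in exposing only .columns and an always-appending withColumn
    def __init__(self, columns):
        self.columns = list(columns)

    def withColumn(self, name, value):
        return FakeFrame(self.columns + [name])


@requires_columns("age")
def with_is_adult(df):
    # A real version would use F.col("age") >= 18; the stand-in only tracks names
    return df.withColumn("is_adult", None)


print(with_is_adult(FakeFrame(["name", "age"])).columns)  # ['name', 'age', 'is_adult']
```

Calling `with_is_adult` on a frame without an `age` column raises a ValueError that names the missing column, instead of failing later inside the Spark plan.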

If you’re writing PySpark code properly, you should be using the transform method quite frequently 😉


5 Comments


  1. Hi Matthew, really like your articles and instructions on medium and here. I wonder if you could give me some advice on my pyspark work with aws glue. I don’t want to explain the whole thing here, so will leave my email and hopefully we can get in touch. Thanks for the article though, I hope to use the .transform method soon. And I’m also learning scala, so hope to be doing that soon.

    Thanks,
    Ron


  2. Hi,
    Thanks for your post. Is there a potential efficiency improvement if transformations are chained using this monkey patch trick in comparison with defining a function and sending the dataframe to the function and returning the function as output?


  3. I’m struggling to understand why to use transform instead of just chaining two .withColumn() calls.

    I’m sure there’s a benefit, but the examples are very basic.
    Genuinely interested in this if anyone is kind enough to reply!


    1. Hi Dee, I think it’s because of the point mentioned at the end about being able to unit test the code:
      You should organize your code as single-purpose DataFrame transformations that are tested individually. Read this post on designing easily testable Spark code.

      And also I think for DRY purposes, as we can reuse those functions if we’re doing similar .withColumn() calls over and over again.


      1. The transform function also makes it easier to keep hardcoded configuration variables out of code. For example:

        CONFIGURATION (read from outside file perhaps?):

        thresholdD = {'math': [('A', 92),
                               ('B', 75),
                               ('C', 60),
                               ('D', 30)],
                      'french': [('A', 88),
                                 ('B', 72),
                                 ('C', 60),
                                 ('D', 50)],
                      'chemistry': [('A', 95),
                                    ('B', 85),
                                    ('C', 70),
                                    ('D', 60)]}

        DATA:
        data = [('a', 77, 64, 57),
                ('b', 85, 51, 73),
                ('c', 45, 78, 55),
                ('d', 94, 34, 90),
                ('e', 46, 93, 54),
                ('f', 72, 61, 63),
                ('g', 73, 41, 86),
                ('h', 62, 76, 93),
                ('i', 65, 43, 61),
                ('j', 93, 74, 99)]

        FUNCTION:

        def chainthresh(df, thresholdD):
            def _(df, att, L):
                # L is ordered from the highest threshold to the lowest
                df = df.withColumn('%s_grade' % att,
                                   F.when(F.col(att) >= L[0][1], L[0][0])
                                   .when(F.col(att) >= L[1][1], L[1][0])
                                   .when(F.col(att) >= L[2][1], L[2][0])
                                   .when(F.col(att) >= L[3][1], L[3][0])
                                   .otherwise(F.lit('F')))
                return df

            for k, L in thresholdD.items():
                # transform calls the lambda immediately, so k and L hold the current loop values
                df = df.transform(lambda df: _(df, k, L))
            return df

        CODE:
        from pyspark.sql import SparkSession
        import pyspark.sql.functions as F

        spark = SparkSession.builder.getOrCreate()

        dfin = spark.createDataFrame(data, ["student", "math", "french", "chemistry"])
        dfout = dfin.transform(lambda df: chainthresh(df, thresholdD))
        dfout.show()

        OUTPUT:
        +-------+----+------+---------+----------+------------+---------------+
        |student|math|french|chemistry|math_grade|french_grade|chemistry_grade|
        +-------+----+------+---------+----------+------------+---------------+
        |      a|  77|    64|       57|         B|           C|              F|
        |      b|  85|    51|       73|         B|           D|              C|
        |      c|  45|    78|       55|         D|           B|              F|
        |      d|  94|    34|       90|         A|           F|              B|
        |      e|  46|    93|       54|         D|           A|              F|
        |      f|  72|    61|       63|         C|           C|              D|
        |      g|  73|    41|       86|         C|           F|              B|
        |      h|  62|    76|       93|         C|           B|              B|
        |      i|  65|    43|       61|         C|           F|              D|
        |      j|  93|    74|       99|         A|           B|              A|
        +-------+----+------+---------+----------+------------+---------------+

