Dan Vatterott

Data Scientist

Python Aggregate UDFs in PySpark

PySpark has a great set of aggregate functions (e.g., count, countDistinct, min, max, avg, sum), but these aren't always enough (particularly if you're trying to avoid costly shuffle operations).
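For context, the built-in aggregates compose freely within a single groupBy, so one shuffle covers as many of them as you like (the dataframe df and its columns here are just placeholders):

from pyspark.sql import functions as F

df.groupBy('id').agg(F.count('value').alias('n'),
                     F.countDistinct('value').alias('n_distinct'),
                     F.max('value').alias('max_value')).show()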

PySpark currently has pandas_udfs, which can create custom aggregators, but you can only "apply" one pandas_udf at a time. If you want to use more than one, you'll have to perform multiple groupBys...and there goes avoiding those shuffles.
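For reference, here's a minimal sketch of the grouped-map pandas_udf pattern I mean (the function, schema, and dataframe df are hypothetical). Each groupBy().apply() call takes exactly one pandas_udf, so a second custom aggregation means a second groupBy:

import pandas as pd
from pyspark.sql.functions import pandas_udf, PandasUDFType

@pandas_udf('id long, n_values long', PandasUDFType.GROUPED_MAP)
def count_values(pdf):
    """Collapse each group into a single row with a custom count."""
    return pd.DataFrame({'id': [pdf['id'].iloc[0]], 'n_values': [len(pdf)]})

df.groupBy('id').apply(count_values).show()  # one pandas_udf per apply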

In this post, I describe a little hack which enables you to create simple Python UDFs that act on aggregated data (this functionality is only supposed to exist in Scala!).

from pyspark.sql import functions as F
from pyspark.sql import types as T

a = sc.parallelize([[1, 'a'],
                    [1, 'b'],
                    [1, 'b'],
                    [2, 'c']]).toDF(['id', 'value'])
a.show()
id value
1 'a'
1 'b'
1 'b'
2 'c'

I use collect_list to bring all data from a given group into a single row. I print the output of this operation below.

a.groupBy('id').agg(F.collect_list('value').alias('value_list')).show()
id value_list
1 ['a', 'b', 'b']
2 ['c']

I then create a UDF which will count all the occurrences of the letter 'a' in these lists (this can be easily done without a UDF, but you get the point). This UDF wraps around collect_list, so it acts on the output of collect_list.

def find_a(x):
  """Count 'a's in list."""
  output_count = 0
  for i in x:
    if i == 'a':
      output_count += 1
  return output_count

find_a_udf = F.udf(find_a, T.IntegerType())

a.groupBy('id').agg(find_a_udf(F.collect_list('value')).alias('a_count')).show()
id a_count
1 1
2 0
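(As the parenthetical above notes, this particular count doesn't actually need a UDF; for comparison, a rough built-in equivalent:)

a.groupBy('id').agg(F.sum(F.when(F.col('value') == 'a', 1).otherwise(0)).alias('a_count')).show()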

There we go! A UDF that acts on aggregated data! Next, I show the power of this approach when combined with when, which lets us control which data enters F.collect_list.

First, let’s create a dataframe with an extra column.

from pyspark.sql import functions as F
from pyspark.sql import types as T

a = sc.parallelize([[1, 1, 'a'],
                    [1, 2, 'a'],
                    [1, 1, 'b'],
                    [1, 2, 'b'],
                    [2, 1, 'c']]).toDF(['id', 'value1', 'value2'])
a.show()
id value1 value2
1 1 'a'
1 2 'a'
1 1 'b'
1 2 'b'
2 1 'c'

Notice how I included a when inside the collect_list. Note that the UDF still wraps around collect_list.

a.groupBy('id').agg(find_a_udf(F.collect_list(F.when(F.col('value1') == 1, F.col('value2')))).alias('a_count')).show()
id a_count
1 1
2 0
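The real payoff is that several of these UDF-based aggregations can share a single groupBy (and therefore a single shuffle). A minimal sketch, where the second UDF, list_len_udf, is a hypothetical helper that just counts the collected items:

list_len_udf = F.udf(lambda x: len(x), T.IntegerType())

a.groupBy('id').agg(
    find_a_udf(F.collect_list(F.when(F.col('value1') == 1, F.col('value2')))).alias('a_count'),
    list_len_udf(F.collect_list('value2')).alias('n_values')).show()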

There we go! Hope you find this info helpful!
