PySpark has a great set of aggregate functions (e.g., count, countDistinct, min, max, avg, sum), but these are not enough for all cases (particularly if you’re trying to avoid costly Shuffle operations).
PySpark currently has pandas_udfs, which can create custom aggregators, but you can only “apply” one pandas_udf at a time. If you want to use more than one, you’ll have to perform multiple groupBys…and there goes avoiding those shuffles.
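To make that cost concrete, here is a minimal sketch of the pandas_udf route when you need two custom aggregations. The names (df, count_a, count_values) and the schemas are placeholders I'm assuming for illustration, not code from this post:

from pyspark.sql.functions import pandas_udf, PandasUDFType
import pandas as pd

# Hypothetical grouped-map pandas UDFs: each receives one group as a pandas
# DataFrame and returns a one-row pandas DataFrame matching the declared schema.
@pandas_udf('id long, a_count long', PandasUDFType.GROUPED_MAP)
def count_a(pdf):
    return pd.DataFrame({'id': [pdf['id'].iloc[0]],
                         'a_count': [int((pdf['value'] == 'a').sum())]})

@pandas_udf('id long, n_values long', PandasUDFType.GROUPED_MAP)
def count_values(pdf):
    return pd.DataFrame({'id': [pdf['id'].iloc[0]],
                         'n_values': [len(pdf)]})

# Each grouped pandas_udf needs its own groupBy().apply() call, so combining
# the two results means two shuffles plus a join.
both = df.groupBy('id').apply(count_a).join(df.groupBy('id').apply(count_values), on='id')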
In this post I describe a little hack that lets you create simple Python UDFs which act on aggregated data (this functionality is only supposed to exist in Scala!).
I then create a UDF that counts all the occurrences of the letter ‘a’ in each group’s list of values (this could easily be done without a UDF, but you get the point). The UDF wraps around collect_list, so it acts on the lists that collect_list produces.
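For reference, the snippet below assumes a dataframe a with an id column and a value column of single letters, built earlier in the post; something along these lines would reproduce the output shown further down (the exact rows here are my guess, not the original data), and it also sets up the F and T aliases the snippet uses:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql import types as T

spark = SparkSession.builder.getOrCreate()

# Hypothetical stand-in for the dataframe built earlier in the post:
# id 1 has a single 'a' among its values, id 2 has none.
a = spark.createDataFrame([(1, 'a'), (1, 'b'), (2, 'c'), (2, 'd')], ['id', 'value'])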
def find_a(x):
    """Count 'a's in list."""
    output_count = 0
    for i in x:
        if i == 'a':
            output_count += 1
    return output_count

find_a_udf = F.udf(find_a, T.IntegerType())

a.groupBy('id').agg(find_a_udf(F.collect_list('value')).alias('a_count')).show()
+---+-------+
| id|a_count|
+---+-------+
|  1|      1|
|  2|      0|
+---+-------+
There we go! A UDF that acts on aggregated data! Next, I show the power of this approach when combined with when, which lets us control which data enters F.collect_list.
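Before the full walkthrough, here is a minimal sketch of the idea; df here stands for a dataframe like a but with an extra value2 column (the kind of dataframe built next), and the 'b' condition is a placeholder I'm assuming for illustration. The trick is that F.when with no otherwise returns null for non-matching rows, and collect_list drops nulls, so only the matching rows ever reach the UDF:

# Hypothetical preview: rows where value2 != 'b' become null inside when(),
# and collect_list drops nulls, so only value2 == 'b' rows reach find_a_udf.
df.groupBy('id').agg(
    find_a_udf(
        F.collect_list(F.when(F.col('value2') == 'b', F.col('value')))
    ).alias('a_count')
).show()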
First, let’s create a dataframe with an extra column.