PySpark currently has pandas_udfs, which can create custom aggregators, but you can only “apply” one pandas_udf at a time. If you want to use more than one, you’ll have to perform multiple groupBys... and there goes avoiding those shuffles.
In this post I describe a little hack that lets you create simple Python UDFs which act on aggregated data (this functionality is only supposed to exist in Scala!).
I then create a UDF which will count all the occurrences of the letter ‘a’ in these lists (this can easily be done without a UDF, but you get the point). This UDF wraps around collect_list, so it acts on the output of collect_list.
def find_a(x):
    """Count 'a's in list."""
    output_count = 0
    for i in x:
        if i == 'a':
            output_count += 1
    return output_count

find_a_udf = F.udf(find_a, T.IntegerType())

a.groupBy('id').agg(find_a_udf(F.collect_list('value')).alias('a_count')).show()
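If you want to sanity-check the counting logic before involving Spark, here is a minimal sketch in plain Python. It only mimics what groupBy plus collect_list hands to the UDF (one Python list per id). The rows below are hypothetical stand-ins, not the dataframe from this post, and the function body is a shorter equivalent of find_a.

```python
from collections import defaultdict


def find_a(x):
    """Count 'a's in list."""
    return sum(1 for i in x if i == 'a')


# Hypothetical (id, value) rows standing in for a dataframe
# with 'id' and 'value' columns.
rows = [(1, 'a'), (1, 'b'), (1, 'a'), (2, 'c'), (2, 'a')]

# Mimic groupBy('id') + collect_list('value'): build one list per id.
collected = defaultdict(list)
for row_id, value in rows:
    collected[row_id].append(value)

# Apply the function to each collected list, as the UDF would per group.
a_counts = {row_id: find_a(values) for row_id, values in collected.items()}
print(a_counts)  # {1: 2, 2: 1}
```

The point is that the function only ever sees an ordinary list, which is why a plain Python UDF is enough here.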
There we go! A UDF that acts on aggregated data! Next, I show the power of this approach when combined with F.when, which lets us control which data enters F.collect_list.
First, let’s create a dataframe with an extra column.