Right now there are a few ways we can create a UDF:
- With a standalone function:
def _add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1
add_one = udf(_add_one, IntegerType())
This allows for full control flow, including exception handling, but duplicates variables.
- With a lambda expression:
add_one = udf(lambda x: x + 1 if x is not None else None, IntegerType())
No variable duplication, but only pure expressions are possible.
- Using a nested function with an immediate call:
def add_one(c):
    def add_one_(x):
        if x is not None:
            return x + 1
    return udf(add_one_, IntegerType())(c)
Quite verbose, but it enables full control flow and clearly indicates the expected number of arguments.
- Using udf as a decorator:
@udf
def add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1
Possible, but only with the default returnType (or curried, as in @partial(udf, returnType=IntegerType())).
Add a udf decorator which can be used as follows:
from pyspark.sql.decorators import udf
@udf(IntegerType())
def add_one(x):
"""Adds one"""
if x is not None:
return x + 1
Example implementation: https://github.com/zero323/spark/commit/74ffdcd60e36fc915d434a5f7a8b567a5b8ab570
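For illustration, a minimal sketch of such a decorator factory wrapping the existing udf function (names and structure here are my assumptions; see the linked commit for the actual proposal):
from functools import wraps

from pyspark.sql.functions import udf as _udf
from pyspark.sql.types import StringType

def udf(returnType=StringType()):
    # Hypothetical decorator factory: @udf(IntegerType()) turns a plain
    # Python function into a UDF with the given return type.
    def decorator(f):
        udf_obj = _udf(f, returnType)

        @wraps(f)  # keep the original name and docstring
        def wrapper(*cols):
            return udf_obj(*cols)

        return wrapper
    return decorator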
Right now udf returns a UserDefinedFunction object which doesn't provide a meaningful docstring:
In [1]: from pyspark.sql.types import IntegerType
In [2]: from pyspark.sql.functions import udf
In [3]: def _add_one(x):
"""Adds one"""
if x is not None:
return x + 1
...:
In [4]: add_one = udf(_add_one, IntegerType())
In [5]: ?add_one
Type: UserDefinedFunction
String form: <pyspark.sql.functions.UserDefinedFunction object at 0x7f281ed2d198>
File: ~/Spark/spark-2.0/python/pyspark/sql/functions.py
Signature: add_one(*cols)
Docstring:
User defined function in Python
.. versionadded:: 1.3
In [6]: help(add_one)
Help on UserDefinedFunction in module pyspark.sql.functions object:
class UserDefinedFunction(builtins.object)
| User defined function in Python
|
| .. versionadded:: 1.3
|
| Methods defined here:
|
| __call__(self, *cols)
| Call self as a function.
|
| __del__(self)
|
| __init__(self, func, returnType, name=None)
| Initialize self. See help(type(self)) for accurate signature.
|
| ----------------------------------------------------------------------
| Data descriptors defined here:
|
| __dict__
| dictionary for instance variables (if defined)
|
| __weakref__
| list of weak references to the object (if defined)
It is possible to extract the function:
In [7]: ?add_one.func
Signature: add_one.func(x)
Docstring: Adds one
File: ~/Spark/spark-2.0/<ipython-input-3-d2d8e4c530ac>
Type: function
In [8]: help(add_one.func)
Help on function _add_one in module __main__:
_add_one(x)
Adds one
but this assumes that the end user is aware of the distinction between UDFs and built-in functions.
Copy the input function's docstring to the UDF object or a function wrapper:
In [1]: from pyspark.sql.types import IntegerType
In [2]: from pyspark.sql.functions import udf
In [3]: def _add_one(x):
"""Adds one"""
if x is not None:
return x + 1
...:
In [4]: add_one = udf(_add_one, IntegerType())
In [5]: ?add_one
Signature: add_one(x)
Docstring:
Adds one
SQL Type: IntegerType
File: ~/Workspace/spark/<ipython-input-3-d2d8e4c530ac>
Type: function
In [6]: help(add_one)
Help on function _add_one in module __main__:
_add_one(x)
Adds one
SQL Type: IntegerType
Proposed implementation: https://github.com/zero323/spark/commit/aebe3a69609448fdfb438f5f27a58f7e134cb201
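A minimal sketch of the wrapper approach, assuming a helper of our own naming (the linked commit may differ):
import functools

from pyspark.sql.functions import UserDefinedFunction

def _wrap_udf(f, returnType):
    # Hypothetical helper: build the UDF, then expose it through a plain
    # function that carries the original docstring and signature.
    udf_obj = UserDefinedFunction(f, returnType)

    @functools.wraps(f)
    def wrapper(*cols):
        return udf_obj(*cols)

    # Append the SQL return type so it shows up in ? and help().
    wrapper.__doc__ = (f.__doc__ or "") + "\n\nSQL Type: " + type(returnType).__name__
    return wrapper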
Right now UserDefinedFunctions don't perform any input type validation. They will accept non-callable objects, only to fail with a hard-to-understand traceback:
In [1]: from pyspark.sql.functions import udf
In [2]: df = spark.range(0, 1)
In [3]: f = udf(None)
In [4]: df.select(f()).first()
17/01/07 19:30:50 ERROR Executor: Exception in task 2.0 in stage 2.0 (TID 7)
...
Py4JJavaError: An error occurred while calling o51.collectToPython.
...
TypeError: 'NoneType' object is not callable
...
Invalid arguments to a UDF call fail fast, but with somewhat cryptic Py4J errors:
In [5]: g = udf(lambda x: x)
In [6]: df.select(g([]))
---------------------------------------------------------------------------
Py4JError Traceback (most recent call last)
<ipython-input-10-5fb48a5d66d2> in <module>()
----> 1 df.select(g([]))
....
Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.col. Trace:
py4j.Py4JException: Method col([class java.util.ArrayList]) does not exist
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318)
at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:339)
at py4j.Gateway.invoke(Gateway.java:274)
at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
at py4j.commands.CallCommand.execute(CallCommand.java:79)
at py4j.GatewayConnection.run(GatewayConnection.java:214)
at java.lang.Thread.run(Thread.java:745)
Apply basic type validation to both the constructor arguments:
In [7]: udf(None)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-7-0765fbe657a9> in <module>()
----> 1 udf(None)
...
TypeError: func should be a callable object (a function or an instance of a class with __call__). Got <class 'NoneType'>
and the call arguments:
In [8]: f = udf(lambda x: x)
In [9]: f(1)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
...
TypeError: All arguments should be Columns or strings representing column names. Got 1 of type <class 'int'>
Proposed implementations: https://github.com/zero323/spark/commit/f9e481467efe58653d660cb6615291f362079bf3 and https://github.com/zero323/spark/commit/c62168781c7f0259ce7be58f708d32ac8d5c916f
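A sketch of the kind of checks involved (hypothetical code; the linked commits may structure this differently):
from pyspark.sql.column import Column
from pyspark.sql.types import DataType, StringType

def _validate_func(f):
    # Fail fast in the driver instead of with a Py4J error or an
    # executor-side traceback.
    if not callable(f):
        raise TypeError(
            "func should be a callable object (a function or an instance "
            "of a class with __call__). Got %s" % type(f))

def _validate_cols(cols):
    for c in cols:
        if not isinstance(c, (Column, str)):
            raise TypeError(
                "All arguments should be Columns or strings representing "
                "column names. Got %s of type %s" % (c, type(c)))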
Right now UserDefinedFunction eagerly creates _judf and initializes a SparkSession as a side effect. This behavior may have undesired results when udf is imported from a module:
- myudfs.py:
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

def _add_one(x):
    """Adds one"""
    if x is not None:
        return x + 1

add_one = udf(_add_one, IntegerType())
Example session:
In [1]: from pyspark.sql import SparkSession
In [2]: from myudfs import add_one
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
17/01/07 19:55:44 WARN Utils: Your hostname, xxx resolves to a loopback address: 127.0.1.1; using xxx instead (on interface eth0)
17/01/07 19:55:44 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
In [3]: spark = SparkSession.builder.appName("foo").getOrCreate()
In [4]: spark.sparkContext.appName
Out[4]: 'pyspark-shell'
Delay _judf initialization until the first call:
In [1]: from pyspark.sql import SparkSession
In [2]: from myudfs import add_one
In [3]: spark = SparkSession.builder.appName("foo").getOrCreate()
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
17/01/07 19:58:38 WARN Utils: Your hostname, xxx resolves to a loopback address: 127.0.1.1; using xxx instead (on interface eth0)
17/01/07 19:58:38 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
In [4]: spark.sparkContext.appName
Out[4]: 'foo'
Proposed implementation: https://github.com/zero323/spark/commit/ce040b98e52bc38cca6e52bab97370db687a5614
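For illustration, the delay could be expressed as a lazy property, assuming a _create_judf helper that builds the JVM-side function (a sketch; the linked commit may differ):
class UserDefinedFunction(object):
    def __init__(self, func, returnType, name=None):
        self.func = func
        self.returnType = returnType
        # Do not touch the SparkSession here; defer to the first call.
        self._judf_placeholder = None

    @property
    def _judf(self):
        # The JVM UDF is created on first access, so merely importing a
        # module that defines UDFs no longer starts a SparkSession.
        if self._judf_placeholder is None:
            self._judf_placeholder = self._create_judf()
        return self._judf_placeholder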
As far as I can tell, UserDefinedFunction._broadcast is not used at all.
Remove the _broadcast variable along with the corresponding __del__ method.
Proposed implementation: https://github.com/zero323/spark/commit/4f917b2fa15515d7c2636d40d71a94eb78fc05ea
One thing we might be able to do, for Python 3.4+, is implement some pieces with Python 3 type annotations. That should give us some additional simplicity under modern Python.
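As a rough sketch of that idea (purely hypothetical; the mapping and helper below are my assumptions, not an existing API):
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType, StringType

# Hypothetical mapping from Python annotations to Spark SQL types.
_TYPE_MAP = {int: IntegerType(), str: StringType()}

def infer_return_type(f):
    # Read the return annotation (available since Python 3.0) and map it
    # to a Spark SQL type, defaulting to StringType.
    return _TYPE_MAP.get(f.__annotations__.get("return"), StringType())

def add_one(x: int) -> int:
    """Adds one"""
    if x is not None:
        return x + 1

add_one_udf = udf(add_one, infer_return_type(add_one))  # picks IntegerType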