User-Defined Functions

DataFusion provides powerful expressions and functions, reducing the need for custom Python functions. However, you can still incorporate your own functions, i.e. User-Defined Functions (UDFs), with the udf() function.

In [1]: import pyarrow

In [2]: import datafusion

In [3]: from datafusion import udf, col

In [4]: def is_null(array: pyarrow.Array) -> pyarrow.Array:
   ...:     return array.is_null()
   ...: 

In [5]: is_null_arr = udf(is_null, [pyarrow.int64()], pyarrow.bool_(), 'stable')

In [6]: ctx = datafusion.SessionContext()

In [7]: batch = pyarrow.RecordBatch.from_arrays(
   ...:     [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
   ...:     names=["a", "b"],
   ...: )
   ...: 

In [8]: df = ctx.create_dataframe([[batch]], name="batch_array")

In [9]: df.select(is_null_arr(col("a"))).to_pandas()
Out[9]: 
   is_null(batch_array.a)
0                   False
1                   False
2                   False

Additionally, the udaf() function allows you to define User-Defined Aggregate Functions (UDAFs). The aggregation logic is implemented in a subclass of Accumulator:

import pyarrow
import pyarrow.compute
import datafusion
from datafusion import col, udaf, Accumulator

class MyAccumulator(Accumulator):
    """
    Interface of a user-defined accumulation.

    Sums its input into a single floating-point total. The running total is
    stored as a pyarrow scalar because ``evaluate`` returns one.
    """

    def __init__(self):
        self._sum = pyarrow.scalar(0.0)

    def _add(self, array: pyarrow.Array) -> None:
        """Add the sum of ``array`` to the running total, ignoring null sums."""
        # pyarrow scalars can't be added to each other directly, so round-trip
        # through plain Python values.
        partial = pyarrow.compute.sum(array).as_py()
        # ``pyarrow.compute.sum`` yields a null scalar (``as_py()`` -> None) for
        # empty or all-null input; adding None to a float would raise TypeError.
        if partial is not None:
            self._sum = pyarrow.scalar(self._sum.as_py() + partial)

    def update(self, values: pyarrow.Array) -> None:
        """Add a batch of input values to the running total."""
        self._add(values)

    def merge(self, states: pyarrow.Array) -> None:
        """Fold intermediate states (presumably other accumulators' ``state`` output) into the total."""
        self._add(states)

    def state(self) -> pyarrow.Array:
        """Return the intermediate state: the running total as a one-element array."""
        return pyarrow.array([self._sum.as_py()])

    def evaluate(self) -> pyarrow.Scalar:
        """Return the final aggregate value."""
        return self._sum

ctx = datafusion.SessionContext()
df = ctx.from_pydict(
    {
        "a": [1, 2, 3],
        "b": [4, 5, 6],
    }
)

# Wrap the accumulator class as a UDAF. Arguments after the class are the
# input type, return type, list of state types, and volatility.
my_udaf = udaf(MyAccumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()], 'stable')

# aggregate() returns a DataFrame; convert and print it so the aggregation is
# actually computed and its result is visible when run as a script.
# Summing column "a" (1 + 2 + 3) should yield a single row containing 6.0.
result = df.aggregate([], [my_udaf(col("a"))])
print(result.to_pandas())