Skip to content

Commit

Permalink
feat(udf): support batch mode
Browse files Browse the repository at this point in the history
  • Loading branch information
sundy-li committed Sep 1, 2024
1 parent ae390cd commit e0ccad9
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
24 changes: 18 additions & 6 deletions python/databend_udf/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ class ScalarFunction(UserDefinedFunction):
_batch_mode: bool

def __init__(
self, func, input_types, result_type, name=None, io_threads=None, skip_null=None, batch_mode=False
self,
func,
input_types,
result_type,
name=None,
io_threads=None,
skip_null=None,
batch_mode=False,
):
self._func = func
self._input_schema = pa.schema(
Expand Down Expand Up @@ -100,8 +107,8 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
_input_process_func(_list_field(field))(array)
for array, field in zip(inputs, self._input_schema)
]
# evaluate the function for each row

# evaluate the function for each row
if self._batch_mode:
column = self._func(*inputs)
elif self._executor is not None:
Expand Down Expand Up @@ -177,7 +184,7 @@ def external_api(x):
response = requests.get(my_endpoint + '?param=' + x)
return response.json()["data"]
```
Batch mode example:
```
@udf(input_types=['INT', 'INT'], result_type='INT', batch_mode=True)
Expand All @@ -194,11 +201,16 @@ def gcd(x, y):
name,
io_threads=io_threads,
skip_null=skip_null,
batch_mode=batch_mode
batch_mode=batch_mode,
)
else:
return lambda f: ScalarFunction(
f, input_types, result_type, name, skip_null=skip_null, batch_mode=batch_mode
f,
input_types,
result_type,
name,
skip_null=skip_null,
batch_mode=batch_mode,
)


Expand Down
3 changes: 3 additions & 0 deletions python/example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def gcd(x: int, y: int) -> int:
(x, y) = (y, x % y)
return x


@udf(
name="gcd_batch",
input_types=["INT", "INT"],
Expand All @@ -67,8 +68,10 @@ def gcd_single(x_i, y_i):
while y_i != 0:
(x_i, y_i) = (y_i, x_i % y_i)
return x_i

return [gcd_single(x_i, y_i) for x_i, y_i in zip(x, y)]


@udf(input_types=["VARCHAR", "VARCHAR", "VARCHAR"], result_type="VARCHAR")
def split_and_join(s: str, split_s: str, join_s: str) -> str:
    """Split ``s`` on ``split_s`` and glue the pieces back together with ``join_s``."""
    pieces = s.split(split_s)
    return join_s.join(pieces)
Expand Down

0 comments on commit e0ccad9

Please sign in to comment.