
Commit

fix: string and binary type
everpcpc committed Jan 17, 2024
1 parent 4d12fed commit 0ee7509
Showing 6 changed files with 463 additions and 83 deletions.
1 change: 1 addition & 0 deletions .github/workflows/python.yml
@@ -34,6 +34,7 @@ jobs:
python -m pip install build
python -m build
- uses: pypa/gh-action-pypi-publish@release/v1
if: github.event_name == 'push'
with:
packages-dir: python/dist/
skip-existing: true
1 change: 1 addition & 0 deletions python/.gitignore
@@ -2,3 +2,4 @@ venv/
*.egg-info/
dist/
__pycache__/
Pipfile.lock
14 changes: 14 additions & 0 deletions python/Pipfile
@@ -0,0 +1,14 @@
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
databend-udf = {file = "."}

[dev-packages]
flake8 = "*"
black = "*"

[requires]
python_version = "3.12"
206 changes: 124 additions & 82 deletions python/databend_udf/udf.py
@@ -13,8 +13,8 @@
# limitations under the License.

import json
import logging
import inspect
import traceback
from concurrent.futures import ThreadPoolExecutor
from typing import Iterator, Callable, Optional, Union, List, Dict

@@ -24,11 +24,13 @@
# comes from Databend
MAX_DECIMAL128_PRECISION = 38
MAX_DECIMAL256_PRECISION = 76
EXTENSION_KEY = "Extension"
ARROW_EXT_TYPE_VARIANT = "Variant"
EXTENSION_KEY = b"Extension"
ARROW_EXT_TYPE_VARIANT = b"Variant"

TIMESTAMP_UINT = "us"

logger = logging.getLogger(__name__)
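
Note on the two extension constants above: they changed from str to bytes because pyarrow stores field metadata with bytes keys and values, so the lookup in _field_is_variant further down must compare bytes against bytes. A quick illustrative check (assumes only pyarrow):

import pyarrow as pa

f = pa.field("v", pa.large_binary(), False, metadata={"Extension": "Variant"})
print(f.metadata)                                  # {b'Extension': b'Variant'}
print(f.metadata.get(b"Extension") == b"Variant")  # True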


class UserDefinedFunction:
"""
@@ -92,8 +94,8 @@ def __init__(
def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
inputs = [[v.as_py() for v in array] for array in batch]
inputs = [
_process_func(pa.list_(type), False)(array)
for array, type in zip(inputs, self._input_schema.types)
_input_process_func(_list_field(field))(array)
for array, field in zip(inputs, self._input_schema)
]
if self._executor is not None:
# concurrently evaluate the function for each row
@@ -122,7 +124,7 @@ def eval_batch(self, batch: pa.RecordBatch) -> Iterator[pa.RecordBatch]:
for row in range(batch.num_rows)
]

column = _process_func(pa.list_(self._result_schema.types[0]), True)(column)
column = _output_process_func(_list_field(self._result_schema.field(0)))(column)

array = pa.array(column, type=self._result_schema.types[0])
yield pa.RecordBatch.from_arrays([array], schema=self._result_schema)
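
For context on what eval_batch drives, a minimal handler and server wired up with this package might look like the sketch below. The @udf decorator, UDFServer, and their arguments follow the project README; the handler name and address are illustrative and not part of this commit.

from databend_udf import udf, UDFServer

@udf(input_types=["STRING", "BINARY"], result_type="STRING")
def describe(s: str, b: bytes) -> str:
    # after input processing, STRING arrives as str and BINARY as bytes
    return f"{s}:{len(b)}"

if __name__ == "__main__":
    server = UDFServer("0.0.0.0:8815")
    server.add_function(describe)
    server.serve()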
@@ -231,7 +233,7 @@ def do_exchange(self, context, descriptor, reader, writer):
for output_batch in udf.eval_batch(batch.data):
writer.write_batch(output_batch)
except Exception as e:
print(traceback.print_exc())
logger.exception(e)
raise e

def add_function(self, udf: UserDefinedFunction):
@@ -249,104 +251,138 @@ def add_function(self, udf: UserDefinedFunction):
f"RETURNS {output_type} LANGUAGE python "
f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';"
)
print(f"added function: {name}, corresponding SQL:\n{sql}\n")
logger.info(f"added function: {name}, SQL:\n{sql}\n")

def serve(self):
"""Start the server."""
print(f"listening on {self._location}")
logger.info(f"listening on {self._location}")
super(UDFServer, self).serve()
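
Since the module now logs through the standard logging package instead of print, the embedding script has to configure logging for the messages above to appear; a minimal, assumed setup:

import logging

logging.basicConfig(level=logging.INFO)  # surfaces the "added function ..." and "listening on ..." messages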


def _null_func(*args):
return None


def _process_func(type: pa.DataType, output: bool) -> Callable:
def _input_process_func(field: pa.Field) -> Callable:
"""
Return a function to process input or output value.
For input type:
- String=pa.string(): bytes -> str
- Tuple=pa.struct(): dict -> tuple
- Json=pa.large_binary(): bytes -> Any
- Map=pa.map_(): list[tuple(k,v)] -> dict
Return a function to process input value.
For output type:
- Json=pa.large_binary(): Any -> str
- Map=pa.map_(): dict -> list[tuple(k,v)]
- Tuple=pa.struct(): dict -> tuple
- Json=pa.large_binary(): bytes -> Any
- Map=pa.map_(): list[tuple(k,v)] -> dict
"""
if pa.types.is_list(type):
func = _process_func(type.value_type, output)
if pa.types.is_list(field.type):
func = _input_process_func(field.type.value_field)
return (
lambda array: [(func(v) if v is not None else None) for v in array]
lambda array: [func(v) if v is not None else None for v in array]
if array is not None
else None
)
if pa.types.is_struct(type):
funcs = [_process_func(field.type, output) for field in type]
if output:
return (
lambda tup: tuple(
(func(v) if v is not None else None) for v, func in zip(tup, funcs)
)
if tup is not None
else None
)
else:
# the input value of struct type is a dict
# we convert it into tuple here
return (
lambda map: tuple(
(func(v) if v is not None else None)
for v, func in zip(map.values(), funcs)
)
if map is not None
else None
if pa.types.is_struct(field.type):
funcs = [_input_process_func(f) for f in field.type]
# the input value of struct type is a dict
# we convert it into tuple here
return (
lambda map: tuple(
func(v) if v is not None else None
for v, func in zip(map.values(), funcs)
)
if pa.types.is_map(type):
if map is not None
else None
)
if pa.types.is_map(field.type):
funcs = [
_process_func(type.key_type, output),
_process_func(type.item_type, output),
_input_process_func(field.type.key_field),
_input_process_func(field.type.item_field),
]
if output:
# dict -> list[tuple[k,v]]
return (
lambda map: [
tuple(func(v) for v, func in zip(item, funcs))
for item in map.items()
]
if map is not None
else None
# list[tuple[k,v]] -> dict
return (
lambda array: dict(
tuple(func(v) for v, func in zip(item, funcs)) for item in array
)
else:
# list[tuple[k,v]] -> dict
return (
lambda array: dict(
tuple(func(v) for v, func in zip(item, funcs)) for item in array
)
if array is not None
else None
if array is not None
else None
)
if pa.types.is_large_binary(field.type):
if _field_is_variant(field):
return lambda v: json.loads(v) if v is not None else None

return lambda v: v


def _output_process_func(field: pa.Field) -> Callable:
"""
Return a function to process output value.
- Json=pa.large_binary(): Any -> str
- Map=pa.map_(): dict -> list[tuple(k,v)]
"""
if pa.types.is_list(field.type):
func = _output_process_func(field.type.value_field)
return (
lambda array: [func(v) if v is not None else None for v in array]
if array is not None
else None
)
if pa.types.is_struct(field.type):
funcs = [_output_process_func(f) for f in field.type]
return (
lambda tup: tuple(
func(v) if v is not None else None for v, func in zip(tup, funcs)
)
if tup is not None
else None
)
if pa.types.is_map(field.type):
funcs = [
_output_process_func(field.type.key_field),
_output_process_func(field.type.item_field),
]
# dict -> list[tuple[k,v]]
return (
lambda map: [
tuple(func(v) for v, func in zip(item, funcs)) for item in map.items()
]
if map is not None
else None
)
if pa.types.is_large_binary(field.type):
if _field_is_variant(field):
return lambda v: json.dumps(_ensure_str(v)) if v is not None else None

if pa.types.is_string(type) and not output:
# string type is converted to LargeBinary in Databend,
# we cast it back to string here
return lambda v: v.decode("utf-8") if v is not None else None
if pa.types.is_large_binary(type):
if output:
return lambda v: json.dumps(v) if v is not None else None
else:
return lambda v: json.loads(v) if v is not None else None
return lambda v: v


def _null_func(*args):
return None


def _list_field(field: pa.Field) -> pa.Field:
return pa.field("", pa.list_(field))


def _to_list(x):
if isinstance(x, list):
return x
else:
return [x]


def _ensure_str(x):
if isinstance(x, bytes):
return x.decode("utf-8")
elif isinstance(x, list):
return [_ensure_str(v) for v in x]
elif isinstance(x, dict):
return {_ensure_str(k): _ensure_str(v) for k, v in x.items()}
else:
return x
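
_ensure_str recursively decodes bytes (including inside lists and dicts) so that json.dumps can serialize VARIANT results that still carry raw bytes, for example:

_ensure_str({b"key": [b"v1", "v2", 3]})  # -> {'key': ['v1', 'v2', 3]}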


def _field_is_variant(field: pa.Field) -> bool:
if field.metadata is None:
return False
if field.metadata.get(EXTENSION_KEY) == ARROW_EXT_TYPE_VARIANT:
return True
return False


def _to_arrow_field(t: Union[str, pa.DataType]) -> pa.Field:
"""
Convert a string or pyarrow.DataType to pyarrow.Field.
@@ -401,7 +437,9 @@ def _type_str_to_arrow_field_inner(type_str: str) -> pa.Field:
elif type_str in ("DATETIME", "TIMESTAMP"):
return pa.field("", pa.timestamp(TIMESTAMP_UINT), False)
elif type_str in ("STRING", "VARCHAR", "CHAR", "CHARACTER", "TEXT"):
return pa.field("", pa.string(), False)
return pa.field("", pa.large_utf8(), False)
elif type_str in ("BINARY",):
return pa.field("", pa.large_binary(), False)
elif type_str in ("VARIANT", "JSON"):
# In Databend, JSON type is identified by the "EXTENSION" key in the metadata.
return pa.field(
@@ -460,20 +498,21 @@ def _arrow_field_to_string(field: pa.Field) -> str:
"""
Convert a `pyarrow.Field` to a SQL data type string.
"""
type_str = _data_type_to_string(field.type)
type_str = _field_type_to_string(field)
return f"{type_str} NOT NULL" if not field.nullable else type_str


def _inner_field_to_string(field: pa.Field) -> str:
# inner field default is NOT NULL in databend
type_str = _data_type_to_string(field.type)
type_str = _field_type_to_string(field)
return f"{type_str} NULL" if field.nullable else type_str


def _data_type_to_string(t: pa.DataType) -> str:
def _field_type_to_string(field: pa.Field) -> str:
"""
Convert a `pyarrow.DataType` to a SQL data type string.
"""
t = field.type
if pa.types.is_boolean(t):
return "BOOLEAN"
elif pa.types.is_int8(t):
@@ -502,10 +541,13 @@ def _data_type_to_string(t: pa.DataType) -> str:
return "DATE"
elif pa.types.is_timestamp(t):
return "TIMESTAMP"
elif pa.types.is_string(t):
elif pa.types.is_large_unicode(t):
return "VARCHAR"
elif pa.types.is_large_binary(t):
return "VARIANT"
if _field_is_variant(field):
return "VARIANT"
else:
return "BINARY"
elif pa.types.is_list(t):
return f"ARRAY({_inner_field_to_string(t.value_field)})"
elif pa.types.is_map(t):
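Taken together, the type changes in this commit make the declared SQL types and the generated Arrow fields round-trip. A hedged spot-check using the private helpers from this file (illustrative only, not part of the diff):

from databend_udf.udf import _type_str_to_arrow_field_inner, _field_type_to_string

s = _type_str_to_arrow_field_inner("STRING")   # large_utf8
b = _type_str_to_arrow_field_inner("BINARY")   # large_binary, no extension metadata
j = _type_str_to_arrow_field_inner("VARIANT")  # large_binary tagged with Extension=Variant

print(_field_type_to_string(s))  # VARCHAR
print(_field_type_to_string(b))  # BINARY
print(_field_type_to_string(j))  # VARIANT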

