Skip to content

Commit

Permalink
Refactor parsing code of UDF signatures
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Mar 19, 2024
1 parent 122e000 commit 0b49a64
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 113 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2817,12 +2817,6 @@ private void prepareVectorizationArgs(
} else {
throw new IllegalStateException("Vectorizability check failed: " + n);
}

// TODO related to core#709, but should be covered by PyCallableWrapper.verifyArguments, needs to verify
// if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) {
// throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " "
// + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName());
// }
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,6 @@ public void parseSignature() {
throw new IllegalStateException("Signature should always be available.");
}

// List<Class<?>> paramTypes = new ArrayList<>();
// for (char numpyTypeChar : signatureString.toCharArray()) {
// if (numpyTypeChar != '-') {
// Class<?> paramType = numpyType2JavaClass.get(numpyTypeChar);
// if (paramType == null) {
// throw new IllegalStateException(
// "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object
// type: "
// + numpyTypeChar + " of " + signatureString);
// }
// paramTypes.add(paramType);
// } else {
// break;
// }
// }
// this.paramTypes = paramTypes;
String pyEncodedParamsStr = signatureString.split("->")[0];
if (!pyEncodedParamsStr.isEmpty()) {
String[] pyEncodedParams = pyEncodedParamsStr.split(",");
Expand Down Expand Up @@ -271,14 +255,13 @@ private boolean isSafelyCastable(Set<Class<?>> types, Class<?> type) {
public void verifyArguments(Class<?>[] argTypes) {
String callableName = pyCallable.getAttribute("__name__").toString();

// if (argTypes.length > parameters.size()) {
// throw new IllegalArgumentException(
// callableName + ": " + "Expected " + parameters.size() + " or fewer arguments, got " + argTypes.length);
// }
for (int i = 0; i < argTypes.length; i++) {
Set<Class<?>> types =
parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes();
// Object is a catch-all type, so we don't need to check for it

// to prevent the unpacking of an array column when calling a Python function, we prefix the column accessor
// with a cast to generic Object type, until we can find a way to convey that info, we'll just skip the
// check for Object type input
if (argTypes[i] == Object.class) {
continue;
}
Expand Down
113 changes: 102 additions & 11 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@
import inspect
import re
import sys
import typing
from dataclasses import dataclass, field
from datetime import datetime
from functools import wraps
from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set
from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set, Optional, Sequence

import pandas as pd

from deephaven._dep import soft_dependency

Expand All @@ -17,17 +21,16 @@
import numpy as np

from deephaven import DHError, dtypes
from deephaven.dtypes import _np_ndarray_component_type, _np_dtype_char, _NUMPY_INT_TYPE_CODES, \
_NUMPY_FLOATING_TYPE_CODES, _component_np_dtype_char, _J_ARRAY_NP_TYPE_MAP, _PRIMITIVE_DTYPE_NULL_MAP, _scalar, \
_BUILDABLE_ARRAY_DTYPE_MAP
from deephaven.dtypes import _NUMPY_INT_TYPE_CODES, _NUMPY_FLOATING_TYPE_CODES, _J_ARRAY_NP_TYPE_MAP, \
_PRIMITIVE_DTYPE_NULL_MAP, _scalar, \
_BUILDABLE_ARRAY_DTYPE_MAP, DType
from deephaven.jcompat import _j_array_to_numpy_array
from deephaven.time import to_np_datetime64

# For unittest vectorization
test_vectorization = False
vectorized_count = 0


_SUPPORTED_NP_TYPE_CODES = {"b", "h", "H", "i", "l", "f", "d", "?", "U", "M", "O"}


Expand Down Expand Up @@ -80,7 +83,7 @@ def _encode_param_type(t: type) -> str:
return "N"

# find the component type if it is numpy ndarray
component_type = _np_ndarray_component_type(t)
component_type = _component_np_dtype_char(t)
if component_type:
t = component_type

Expand All @@ -92,14 +95,100 @@ def _encode_param_type(t: type) -> str:
return tc


def _np_dtype_char(t: Union[type, str]) -> str:
"""Returns the numpy dtype character code for the given type."""
try:
np_dtype = np.dtype(t if t else "object")
if np_dtype.kind == "O":
if t in (datetime, pd.Timestamp):
return "M"
except TypeError:
np_dtype = np.dtype("object")

return np_dtype.char


def _component_np_dtype_char(t: type) -> Optional[str]:
"""Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or
numpy ndarray, otherwise return None. """
component_type = None

if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8:
import types
if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence):
component_type = t.__args__[0]

if not component_type:
if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence):
component_type = t.__args__[0]
# if the component type is a DType, get its numpy type
if isinstance(component_type, DType):
component_type = component_type.np_type

if not component_type:
if t == bytes or t == bytearray:
return "b"

if not component_type:
component_type = _np_ndarray_component_type(t)

if component_type:
return _np_dtype_char(component_type)
else:
return None


def _np_ndarray_component_type(t: type) -> Optional[type]:
"""Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None."""

# Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64])
# is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used,
# the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the
# component type
component_type = None
if sys.version_info.major == 3 and sys.version_info.minor == 8:
if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray:
component_type = t.__args__[1].__args__[0]
# Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a
# specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias.
# when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which
# the 1st argument is the component type
# when np.ndarray is used, the 1st argument is the component type
if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8:
import types
if isinstance(t, types.GenericAlias) and t.__origin__ == np.ndarray:
nargs = len(t.__args__)
if nargs == 1:
component_type = t.__args__[0]
elif nargs == 2: # for npt.NDArray[np.int64], etc.
a0 = t.__args__[0]
a1 = t.__args__[1]
if a0 == typing.Any and isinstance(a1, types.GenericAlias):
component_type = a1.__args__[0]
return component_type


def _is_union_type(t: type) -> bool:
"""Return True if the type is a Union type"""
if sys.version_info.major == 3 and sys.version_info.minor >= 10:
import types
if isinstance(t, types.UnionType):
return True

if isinstance(t, _GenericAlias) and t.__origin__ == Union:
return True

return False


def _parse_param(name: str, annotation: Any) -> _ParsedParam:
""" Parse a parameter annotation in a function's signature """
p_param = _ParsedParam(name)

if annotation is inspect._empty:
p_param.encoded_types.add("O")
p_param.none_allowed = True
elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union:
elif _is_union_type(annotation):
for t in annotation.__args__:
_parse_type_no_nested(annotation, p_param, t)
else:
Expand Down Expand Up @@ -149,7 +238,7 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation:

t = annotation
pra.orig_type = t
if isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union and len(annotation.__args__) == 2:
if _is_union_type(annotation) and len(annotation.__args__) == 2:
# if the annotation is a Union of two types, we'll use the non-None type
if annotation.__args__[1] == type(None): # noqa: E721
t = annotation.__args__[0]
Expand All @@ -170,7 +259,8 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation:


if numba:
def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature:
def _parse_numba_signature(
fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature:
""" Parse a numba function's signature"""
sigs = fn.types # in the format of ll->l, ff->f,dd->d,OO->O, etc.
if sigs:
Expand Down Expand Up @@ -261,7 +351,8 @@ def _parse_signature(fn: Callable) -> _ParsedSignature:
t = eval(p.annotation, fn.__globals__) if isinstance(p.annotation, str) else p.annotation
p_sig.params.append(_parse_param(n, t))

t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation, str) else sig.return_annotation
t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation,
str) else sig.return_annotation
p_sig.ret_annotation = _parse_return_annotation(t)
return p_sig

Expand Down Expand Up @@ -476,4 +567,4 @@ def wrapper(*args):
global vectorized_count
vectorized_count += 1

return wrapper
return wrapper
68 changes: 2 additions & 66 deletions py/server/deephaven/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@
from __future__ import annotations

import datetime
import sys
import typing
from typing import Any, Sequence, Callable, Dict, Type, Union, _GenericAlias, Optional
from typing import Any, Sequence, Callable, Dict, Type, Union, Optional

import jpy
import numpy as np
import numpy._typing as npt
import pandas as pd

from deephaven import DHError
Expand Down Expand Up @@ -304,7 +301,7 @@ def array(dtype: DType, seq: Optional[Sequence], remap: Callable[[Any], Any] = N
raise DHError(e, f"failed to create a Java {dtype.j_name} array.") from e


def from_jtype(j_class: Any) -> DType:
def from_jtype(j_class: Any) -> Optional[DType]:
""" looks up a DType that matches the java type, if not found, creates a DType for it. """
if not j_class:
return None
Expand Down Expand Up @@ -392,64 +389,3 @@ def _scalar(x: Any, dtype: DType) -> Any:
except:
return x


def _np_dtype_char(t: Union[type, str]) -> str:
"""Returns the numpy dtype character code for the given type."""
try:
np_dtype = np.dtype(t if t else "object")
if np_dtype.kind == "O":
if t in (datetime.datetime, pd.Timestamp):
return "M"
except TypeError:
np_dtype = np.dtype("object")

return np_dtype.char


def _component_np_dtype_char(t: type) -> Optional[str]:
"""Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or
numpy ndarray, otherwise return None. """
component_type = None
if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence):
component_type = t.__args__[0]
# if the component type is a DType, get its numpy type
if isinstance(component_type, DType):
component_type = component_type.np_type

if not component_type:
component_type = _np_ndarray_component_type(t)

if component_type:
return _np_dtype_char(component_type)
else:
return None


def _np_ndarray_component_type(t: type) -> Optional[type]:
"""Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None."""

# Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64])
# is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used,
# the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the
# component type
component_type = None
if sys.version_info.major == 3 and sys.version_info.minor == 8:
if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray:
component_type = t.__args__[1].__args__[0]
# Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a
# specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias.
# when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which
# the 1st argument is the component type
# when np.ndarray is used, the 1st argument is the component type
if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8:
import types
if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray):
nargs = len(t.__args__)
if nargs == 1:
component_type = t.__args__[0]
elif nargs == 2: # for npt.NDArray[np.int64], etc.
a0 = t.__args__[0]
a1 = t.__args__[1]
if a0 == typing.Any and isinstance(a1, types.GenericAlias):
component_type = a1.__args__[0]
return component_type
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
import typing
from typing import Optional, Union, Any
from typing import Optional, Union, Any, Sequence
import unittest

import numpy as np
Expand Down Expand Up @@ -452,13 +452,40 @@ def f(x: {p_type}) -> bool: # note typing
t = empty_table(1).update(["X = i", f"Y = f(X)"])
self.assertRegex(str(cm.exception), "f: Expect")

def test_np_typehints_mismatch(self):
def f(x: float) -> bool:
return True
def test_sequence_args(self):
with self.subTest("Sequence"):
def f(x: Sequence[int]) -> bool:
return True

with self.assertRaises(DHError) as cm:
t = empty_table(1).update(["X = i", "Y = f(ii)"])
self.assertRegex(str(cm.exception), "f: Expect")

t = empty_table(1).update(["X = i", "Y = ii"]).group_by("X").update(["Z = f(Y.toArray())"])
self.assertEqual(t.columns[2].data_type, dtypes.bool_)

with self.subTest("bytes"):
def f(x: bytes) -> bool:
return True

with self.assertRaises(DHError) as cm:
t = empty_table(1).update(["X = i", "Y = f(ii)"])
self.assertRegex(str(cm.exception), "f: Expect")

t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"])
self.assertEqual(t.columns[2].data_type, dtypes.bool_)

with self.subTest("bytearray"):
def f(x: bytearray) -> bool:
return True

with self.assertRaises(DHError) as cm:
t = empty_table(1).update(["X = i", "Y = f(ii)"])
self.assertRegex(str(cm.exception), "f: Expect")

t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"])
self.assertEqual(t.columns[2].data_type, dtypes.bool_)

with self.assertRaises(DHError) as cm:
t = empty_table(1).update(["X = i", "Y = f(ii)"])
self.assertRegex(str(cm.exception), "f: Expect")

if __name__ == "__main__":
unittest.main()
3 changes: 2 additions & 1 deletion py/server/tests/test_udf_return_java_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def test_array_return(self):
"np.str_": dtypes.string_array,
"np.uint16": dtypes.char_array,
}
container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"]
# container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"]
container_types = ["list"]
for component_type, dh_dtype in component_types.items():
for container_type in container_types:
with self.subTest(component_type=component_type, container_type=container_type):
Expand Down
Loading

0 comments on commit 0b49a64

Please sign in to comment.