From 77587f4ff3d346da1a7c2947aeb456bd2b0b9657 Mon Sep 17 00:00:00 2001 From: Jianfeng Mao <4297243+jmao-denver@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:15:18 -0600 Subject: [PATCH] auto convert Python return val to Java val (#4579) * Auto conv Py UDF returns and numba guvectorized * Add numba guvectorize tests * Harden the code for bad type annotations * More meaningful names * Clean up test code * Apply suggestions from code review Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> * Add support for datetime data * Update py/server/deephaven/dtypes.py Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> * Respond to review comments and array testcases * Respond to reivew comments * Respond to new review comments * new test case for complex(unsupported) typehint --------- Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> --- .../table/impl/lang/QueryLanguageParser.java | 125 ++++++++++++- .../engine/util/PyCallableWrapperJpyImpl.java | 57 ++++-- .../impl/lang/PyCallableWrapperDummyImpl.java | 2 +- py/server/deephaven/column.py | 28 +-- py/server/deephaven/dtypes.py | 121 +++++++++++- py/server/deephaven/table.py | 94 ++++++++-- py/server/tests/test_numba_guvectorize.py | 94 ++++++++++ .../tests/test_numba_vectorized_column.py | 2 +- .../tests/test_numba_vectorized_filter.py | 2 +- .../tests/test_pyfunc_return_java_values.py | 174 ++++++++++++++++++ py/server/tests/test_vectorization.py | 6 +- 11 files changed, 623 insertions(+), 82 deletions(-) create mode 100644 py/server/tests/test_numba_guvectorize.py create mode 100644 py/server/tests/test_pyfunc_return_java_values.py diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 6b166b2d797..f7e1e03de43 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -20,7 +20,6 @@ import com.github.javaparser.ast.comments.BlockComment; import com.github.javaparser.ast.comments.JavadocComment; import com.github.javaparser.ast.comments.LineComment; -import com.github.javaparser.ast.comments.Comment; import com.github.javaparser.ast.expr.ArrayAccessExpr; import com.github.javaparser.ast.expr.ArrayCreationExpr; import com.github.javaparser.ast.expr.ArrayInitializerExpr; @@ -123,6 +122,7 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; +import java.time.Instant; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -1214,7 +1214,7 @@ private Expression[] convertParameters(final Executable executable, // as a single argument if (ObjectVector.class.isAssignableFrom(expressionTypes[ei])) { expressions[ei] = new CastExpr( - new ClassOrInterfaceType("java.lang.Object"), + StaticJavaParser.parseClassOrInterfaceType("java.lang.Object"), expressions[ei]); expressionTypes[ei] = Object.class; } else { @@ -2023,7 +2023,7 @@ public Class visit(ConditionalExpr n, VisitArgs printer) { if (classA == boolean.class && classB == Boolean.class) { // a little hacky, but this handles the null case where it unboxes. 
very weird stuff final Expression uncastExpr = n.getThenExpr(); - final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr); + final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr); n.setThenExpr(castExpr); // fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr) uncastExpr.setParentNode(castExpr); @@ -2032,7 +2032,7 @@ public Class visit(ConditionalExpr n, VisitArgs printer) { if (classA == Boolean.class && classB == boolean.class) { // a little hacky, but this handles the null case where it unboxes. very weird stuff final Expression uncastExpr = n.getElseExpr(); - final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr); + final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr); n.setElseExpr(castExpr); // fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr) uncastExpr.setParentNode(castExpr); @@ -2159,7 +2159,8 @@ public Class visit(FieldAccessExpr n, VisitArgs printer) { printer.append(", " + clsName + ".class"); final ClassExpr targetType = - new ClassExpr(new ClassOrInterfaceType(null, printer.pythonCastContext.getSimpleName())); + new ClassExpr( + StaticJavaParser.parseClassOrInterfaceType(printer.pythonCastContext.getSimpleName())); getAttributeArgs.add(targetType); // Let's advertise to the caller the cast context type @@ -2337,6 +2338,35 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { callMethodCall.setData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS, new PyCallableDetails(null, methodName)); + if (PyCallableWrapper.class.isAssignableFrom(method.getDeclaringClass())) { + final Optional> optionalRetType = pyCallableReturnType(callMethodCall); + if (optionalRetType.isPresent()) { + Class retType = optionalRetType.get(); + final Optional optionalCastExpr = + makeCastExpressionForPyCallable(retType, callMethodCall); + if (optionalCastExpr.isPresent()) { + final CastExpr castExpr = optionalCastExpr.get(); + replaceChildExpression( + n.getParentNode().orElseThrow(), + n, + castExpr); + + callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).setCasted(true); + try { + return castExpr.accept(this, printer); + } catch (Exception e) { + // exceptions could be thrown by {@link #tryVectorizePythonCallable} + replaceChildExpression( + castExpr.getParentNode().orElseThrow(), + castExpr, + callMethodCall); + callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS) + .setCasted(false); + return callMethodCall.accept(this, printer); + } + } + } + } replaceChildExpression( n.getParentNode().orElseThrow(), n, @@ -2376,7 +2406,7 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { final ObjectCreationExpr newPyCallableExpr = new ObjectCreationExpr( null, - new ClassOrInterfaceType(null, pyCallableWrapperImplName), + StaticJavaParser.parseClassOrInterfaceType(pyCallableWrapperImplName), NodeList.nodeList(getAttributeCall)); final MethodCallExpr callMethodCall = new MethodCallExpr( @@ -2430,6 +2460,60 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { typeArguments); } + private Optional makeCastExpressionForPyCallable(Class retType, MethodCallExpr callMethodCall) { + if (retType.isPrimitive()) { + return Optional.of(new CastExpr( + new PrimitiveType(PrimitiveType.Primitive + .valueOf(retType.getSimpleName().toUpperCase())), + callMethodCall)); + } else if (retType.getComponentType() != null) { + final Class 
componentType = retType.getComponentType(); + if (componentType.isPrimitive()) { + ArrayType arrayType; + if (componentType == boolean.class) { + arrayType = new ArrayType(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean")); + } else { + arrayType = new ArrayType(new PrimitiveType(PrimitiveType.Primitive + .valueOf(retType.getComponentType().getSimpleName().toUpperCase()))); + } + return Optional.of(new CastExpr(arrayType, callMethodCall)); + } else if (retType.getComponentType() == String.class || retType.getComponentType() == Boolean.class + || retType.getComponentType() == Instant.class) { + ArrayType arrayType = + new ArrayType( + StaticJavaParser.parseClassOrInterfaceType(retType.getComponentType().getSimpleName())); + return Optional.of(new CastExpr(arrayType, callMethodCall)); + } + } else if (retType == Boolean.class) { + return Optional + .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"), callMethodCall)); + } else if (retType == String.class) { + return Optional + .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.String"), callMethodCall)); + } else if (retType == Instant.class) { + return Optional + .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.time.Instant"), callMethodCall)); + } + + return Optional.empty(); + } + + private Optional> pyCallableReturnType(@NotNull MethodCallExpr n) { + final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS); + final String pyMethodName = pyCallableDetails.pythonMethodName; + final QueryScope queryScope = ExecutionContext.getContext().getQueryScope(); + final Object paramValueRaw = queryScope.readParamValue(pyMethodName, null); + if (paramValueRaw == null) { + return Optional.empty(); + } + if (!(paramValueRaw instanceof PyCallableWrapper)) { + return Optional.empty(); + } + final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw; + pyCallableWrapper.parseSignature(); + return Optional.ofNullable(pyCallableWrapper.getReturnType()); + } + @NotNull private static Expression[] getExpressionsArray(final NodeList exprNodeList) { return exprNodeList == null ? new Expression[0] @@ -2563,13 +2647,27 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, // expression evaluation code that expects singular values. This check makes sure that numba/dh vectorized // functions must be used alone as the entire expression after removing the enclosing parentheses. 
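                // Illustrative examples (assumed query strings, not taken from this patch): "Y = py_func(X)" and
                // "Y = (py_func(X))" remain vectorizable, while "Y = py_func(X) + 1" or a user-written cast such as
                // "Y = (int) py_func(X)" must still fail; the only cast tolerated below is the implicit one the
                // parser itself inserts around the call for the Python return type (tracked via isCasted).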
Node n1 = n; + boolean autoCastChecked = false; while (n1.hasParentNode()) { n1 = n1.getParentNode().orElseThrow(); Class cls = n1.getClass(); if (cls == CastExpr.class) { - throw new PythonCallVectorizationFailure( - "The return values of Python vectorized function can't be cast: " + n1); + if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()) { + autoCastChecked = true; + } else { + throw new PythonCallVectorizationFailure( + "The return values of Python vectorized functions can't be cast: " + n1); + } + } else if (cls == MethodCallExpr.class) { + String methodName = ((MethodCallExpr) n1).getNameAsString(); + if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted() + && methodName.endsWith("Cast")) { + autoCastChecked = true; + } else { + throw new PythonCallVectorizationFailure( + "Python vectorized function can't be used in another expression: " + n1); + } } else if (cls != EnclosedExpr.class && cls != WrapperNode.class) { throw new PythonCallVectorizationFailure( "Python vectorized function can't be used in another expression: " + n1); @@ -3233,6 +3331,17 @@ private static class PyCallableDetails { @NotNull private final String pythonMethodName; + @NotNull + private boolean isCasted = false; + + public boolean isCasted() { + return isCasted; + } + + public void setCasted(boolean casted) { + isCasted = casted; + } + private PyCallableDetails(@Nullable String pythonScopeExpr, @NotNull String pythonMethodName) { this.pythonScopeExpr = pythonScopeExpr; this.pythonMethodName = pythonMethodName; diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index 5f72ec781ac..6f3ae8b9699 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -6,6 +6,7 @@ import org.jpy.PyModule; import org.jpy.PyObject; +import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -20,6 +21,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private static final Logger log = LoggerFactory.getLogger(PyCallableWrapperJpyImpl.class); private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType(); + private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType(); private static final PyModule dh_table_module = PyModule.importModule("deephaven.table"); @@ -34,6 +36,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaClass.put('b', byte.class); numpyType2JavaClass.put('?', boolean.class); numpyType2JavaClass.put('U', String.class); + numpyType2JavaClass.put('M', Instant.class); numpyType2JavaClass.put('O', Object.class); } @@ -47,6 +50,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private Collection chunkArguments; private boolean numbaVectorized; private PyObject unwrapped; + private PyObject pyUdfDecoratedCallable; public PyCallableWrapperJpyImpl(PyObject pyCallable) { this.pyCallable = pyCallable; @@ -91,23 +95,37 @@ private static PyObject getNumbaVectorizedFuncType() { } } + private static PyObject getNumbaGUVectorizedFuncType() { + try { + return PyModule.importModule("numba.np.ufunc.gufunc").getAttribute("GUFunc"); + } catch (Exception e) { + if (log.isDebugEnabled()) { + log.debug("Numba isn't 
installed in the Python environment."); + } + return null; + } + } + private void prepareSignature() { - if (pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE)) { + boolean isNumbaVectorized = pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE); + boolean isNumbaGUVectorized = pyCallable.equals(NUMBA_GUVECTORIZED_FUNC_TYPE); + if (isNumbaGUVectorized || isNumbaVectorized) { List params = pyCallable.getAttribute("types").asList(); if (params.isEmpty()) { throw new IllegalArgumentException( - "numba vectorized function must have an explicit signature: " + pyCallable); + "numba vectorized/guvectorized function must have an explicit signature: " + pyCallable); } // numba allows a vectorized function to have multiple signatures if (params.size() > 1) { throw new UnsupportedOperationException( pyCallable - + " has multiple signatures; this is not currently supported for numba vectorized functions"); + + " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions"); } signature = params.get(0).getStringValue(); - unwrapped = null; - numbaVectorized = true; - vectorized = true; + unwrapped = pyCallable; + // since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized + numbaVectorized = isNumbaVectorized; + vectorized = isNumbaVectorized; } else if (pyCallable.hasAttribute("dh_vectorized")) { signature = pyCallable.getAttribute("signature").toString(); unwrapped = pyCallable.getAttribute("callable"); @@ -119,6 +137,7 @@ private void prepareSignature() { numbaVectorized = false; vectorized = false; } + pyUdfDecoratedCallable = dh_table_module.call("_py_udf", unwrapped); } @Override @@ -135,14 +154,6 @@ public void parseSignature() { throw new IllegalStateException("Signature should always be available."); } - char numpyTypeCode = signature.charAt(signature.length() - 1); - Class returnType = numpyType2JavaClass.get(numpyTypeCode); - if (returnType == null) { - throw new IllegalStateException( - "Vectorized functions should always have an integral, floating point, boolean, String, or Object return type: " - + numpyTypeCode); - } - List> paramTypes = new ArrayList<>(); for (char numpyTypeChar : signature.toCharArray()) { if (numpyTypeChar != '-') { @@ -159,25 +170,31 @@ public void parseSignature() { } this.paramTypes = paramTypes; - if (returnType == Object.class) { - this.returnType = PyObject.class; - } else if (returnType == boolean.class) { + + returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); + if (returnType == null) { + throw new IllegalStateException( + "Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type"); + } + + if (returnType == boolean.class) { this.returnType = Boolean.class; - } else { - this.returnType = returnType; } } + // In vectorized mode, we want to call the vectorized function directly. public PyObject vectorizedCallable() { - if (numbaVectorized) { + if (numbaVectorized || vectorized) { return pyCallable; } else { return dh_table_module.call("dh_vectorize", unwrapped); } } + // In non-vectorized mode, we want to call the udf decorated function or the original function. @Override public Object call(Object... args) { + PyObject pyCallable = this.pyUdfDecoratedCallable != null ? 
this.pyUdfDecoratedCallable : this.pyCallable; return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args)); } diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java index b1e405b67d0..672a47b5158 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java @@ -69,6 +69,6 @@ public void addChunkArgument(ChunkArgument ignored) {} @Override public Class getReturnType() { - throw new UnsupportedOperationException(); + return Object.class; } } diff --git a/py/server/deephaven/column.py b/py/server/deephaven/column.py index 635f60794f1..81fd5ffa202 100644 --- a/py/server/deephaven/column.py +++ b/py/server/deephaven/column.py @@ -9,13 +9,11 @@ from typing import Sequence, Any import jpy -import numpy as np -import pandas as pd import deephaven.dtypes as dtypes -from deephaven import DHError, time +from deephaven import DHError from deephaven.dtypes import DType -from deephaven.time import to_j_instant +from deephaven.dtypes import _instant_array _JColumnHeader = jpy.get_type("io.deephaven.qst.column.header.ColumnHeader") _JColumn = jpy.get_type("io.deephaven.qst.column.Column") @@ -206,27 +204,7 @@ def datetime_col(name: str, data: Sequence) -> InputColumn: Returns: a new input column """ - - # try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on - # it to reduce the number of round trips to the JVM - if not isinstance(data, np.ndarray): - try: - data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64) - except Exception as e: - ... 
- - if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'): - if data.dtype.kind == 'M': - longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64')) - elif data.dtype.kind == 'i': - longs = jpy.array('long', data.astype('int64')) - else: # data.dtype.kind == 'U' - longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data]) - data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) - - if not isinstance(data, dtypes.instant_array.j_type): - data = [to_j_instant(d) for d in data] - + data = _instant_array(data) return InputColumn(name=name, data_type=dtypes.Instant, input_data=data) diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py index a5f940f38eb..cedf9c67138 100644 --- a/py/server/deephaven/dtypes.py +++ b/py/server/deephaven/dtypes.py @@ -8,7 +8,9 @@ """ from __future__ import annotations -from typing import Any, Sequence, Callable, Dict, Type, Union +import datetime +import sys +from typing import Any, Sequence, Callable, Dict, Type, Union, _GenericAlias, Optional import jpy import numpy as np @@ -120,6 +122,8 @@ def __call__(self, *args, **kwargs): """Python object type""" JObject = DType(j_name="java.lang.Object") """Java Object type""" +bool_array = DType(j_name='[Z') +"""boolean array type""" byte_array = DType(j_name='[B') """Byte array type""" int8_array = byte_array @@ -128,6 +132,8 @@ def __call__(self, *args, **kwargs): """Short array type""" int16_array = short_array """Short array type""" +char_array = DType(j_name='[C') +"""char array type""" int32_array = DType(j_name='[I') """32bit integer array type""" long_array = DType(j_name='[J') @@ -136,7 +142,7 @@ def __call__(self, *args, **kwargs): """64bit integer array type""" int_array = long_array """64bit integer array type""" -single_array = DType(j_name='[S') +single_array = DType(j_name='[F') """Single-precision floating-point array type""" float32_array = single_array """Single-precision floating-point array type""" @@ -148,6 +154,8 @@ def __call__(self, *args, **kwargs): """Double-precision floating-point array type""" string_array = DType(j_name='[Ljava.lang.String;') """Java String array type""" +boolean_array = DType(j_name='[Ljava.lang.Boolean;') +"""Java Boolean array type""" instant_array = DType(j_name='[Ljava.time.Instant;') """Java Instant array type""" zdt_array = DType(j_name='[Ljava.time.ZonedDateTime;') @@ -164,6 +172,19 @@ def __call__(self, *args, **kwargs): float64: NULL_DOUBLE, } +_BUILDABLE_ARRAY_DTYPE_MAP = { + bool_: bool_array, + byte: int8_array, + char: char_array, + int16: int16_array, + int32: int32_array, + int64: int64_array, + float32: float32_array, + float64: float64_array, + string: string_array, + Instant: instant_array, +} + def null_remap(dtype: DType) -> Callable[[Any], Any]: """ Creates a null value remap function for the provided DType. @@ -184,6 +205,33 @@ def null_remap(dtype: DType) -> Callable[[Any], Any]: return lambda v: null_value if v is None else v +def _instant_array(data: Sequence) -> jpy.JType: + """Converts a sequence of either datetime64[ns], datetime.datetime, pandas.Timestamp, datetime strings, + or integers in nanoseconds, to a Java array of Instant values. 
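    A minimal usage sketch (the input values are illustrative and deliberately mixed):

        _instant_array(["2021-01-01T00:00:00", pd.Timestamp("2021-01-02"), np.datetime64(0, "ns")])
        # -> a java.time.Instant[] of length 3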
""" + # try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on + # it to reduce the number of round trips to the JVM + if not isinstance(data, np.ndarray): + try: + data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64) + except Exception as e: + ... + + if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'): + if data.dtype.kind == 'M': + longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64')) + elif data.dtype.kind == 'i': + longs = jpy.array('long', data.astype('int64')) + else: # data.dtype.kind == 'U' + longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data]) + data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) + + if not isinstance(data, instant_array.j_type): + from deephaven.time import to_j_instant + data = [to_j_instant(d) for d in data] + + return data + + def array(dtype: DType, seq: Sequence, remap: Callable[[Any], Any] = None) -> jpy.JType: """ Creates a Java array of the specified data type populated with values from a sequence. @@ -215,14 +263,14 @@ def array(dtype: DType, seq: Sequence, remap: Callable[[Any], Any] = None) -> jp raise ValueError("Not a callable") seq = [remap(v) for v in seq] + if dtype == Instant: + return _instant_array(seq) + if isinstance(seq, np.ndarray): if dtype == bool_: bytes_ = seq.astype(dtype=np.int8) j_bytes = array(byte, bytes_) seq = _JPrimitiveArrayConversionUtility.translateArrayByteToBoolean(j_bytes) - elif dtype == Instant: - longs = jpy.array('long', seq.astype('datetime64[ns]').astype('int64')) - seq = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) return jpy.array(dtype.j_type, seq) except Exception as e: @@ -266,3 +314,66 @@ def from_np_dtype(np_dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]) - return dtype return PyObject + + +_NUMPY_INT_TYPE_CODES = ["i", "l", "h", "b"] +_NUMPY_FLOATING_TYPE_CODES = ["f", "d"] + + +def _scalar(x): + """Converts a Python value to a Java scalar value. It converts the numpy primitive types, string to + their Python equivalents so that JPY can handle them. For datetime values, it converts them to Java Instant. + Otherwise, it returns the value as is.""" + if hasattr(x, "dtype"): + if x.dtype.char in _NUMPY_INT_TYPE_CODES: + return int(x) + elif x.dtype.char in _NUMPY_FLOATING_TYPE_CODES: + return float(x) + elif x.dtype.char == '?': + return bool(x) + elif x.dtype.char == 'U': + return str(x) + elif x.dtype.char == 'O': + return x + elif x.dtype.char == 'M': + from deephaven.time import to_j_instant + return to_j_instant(x) + else: + raise TypeError(f"Unsupported dtype: {x.dtype}") + else: + if isinstance(x, (datetime.datetime, pd.Timestamp)): + from deephaven.time import to_j_instant + return to_j_instant(x) + return x + + +def _np_dtype_char(t: Union[type, str]) -> str: + """Returns the numpy dtype character code for the given type.""" + try: + np_dtype = np.dtype(t if t else "object") + if np_dtype.kind == "O": + if t in (datetime.datetime, pd.Timestamp): + return "M" + except TypeError: + np_dtype = np.dtype("object") + + return np_dtype.char + + +def _component_np_dtype_char(t: type) -> Optional[str]: + """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or + numpy ndarray, otherwise return None. 
""" + component_type = None + if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): + component_type = t.__args__[0] + + # np.ndarray as a generic alias is only supported in Python 3.9+ + if not component_type and sys.version_info.minor > 8: + import types + if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray): + component_type = t.__args__[0] + + if component_type: + return _np_dtype_char(component_type) + else: + return None diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index 15181e87691..78849e4cf46 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -11,10 +11,12 @@ import inspect from enum import Enum from enum import auto +from functools import wraps from typing import Any, Optional, Callable, Dict from typing import Sequence, List, Union, Protocol import jpy +import numba import numpy as np from deephaven import DHError @@ -29,6 +31,8 @@ from deephaven.jcompat import to_sequence, j_array_list from deephaven.update_graph import auto_locking_ctx, UpdateGraph from deephaven.updateby import UpdateByOperation +from deephaven.dtypes import _BUILDABLE_ARRAY_DTYPE_MAP, _scalar, _np_dtype_char, \ + _component_np_dtype_char # Table _J_Table = jpy.get_type("io.deephaven.engine.table.Table") @@ -359,40 +363,93 @@ def _j_py_script_session() -> _JPythonScriptSession: return None -_numpy_type_codes = ["i", "l", "h", "f", "d", "b", "?", "U", "O"] +_SUPPORTED_NP_TYPE_CODES = ["i", "l", "h", "f", "d", "b", "?", "U", "M", "O"] def _encode_signature(fn: Callable) -> str: """Encode the signature of a Python function by mapping the annotations of the parameter types and the return - type to numpy dtype chars (i,l,h,f,d,b,?,U,O), and pack them into a string with parameter type chars first, + type to numpy dtype chars (i,l,h,f,d,b,?,U,M,O), and pack them into a string with parameter type chars first, in their original order, followed by the delimiter string '->', then the return type_char. If a parameter or the return of the function is not annotated, the default 'O' - object type, will be used. """ sig = inspect.signature(fn) - parameter_types = [] + np_type_codes = [] for n, p in sig.parameters.items(): - try: - np_dtype = np.dtype(p.annotation if p.annotation else "object") - parameter_types.append(np_dtype) - except TypeError: - parameter_types.append(np.dtype("object")) - - try: - return_type = np.dtype(sig.return_annotation if sig.return_annotation else "object") - except TypeError: - return_type = np.dtype("object") + np_type_codes.append(_np_dtype_char(p.annotation)) - np_type_codes = [np.dtype(p).char for p in parameter_types] - np_type_codes = [c if c in _numpy_type_codes else "O" for c in np_type_codes] - return_type_code = np.dtype(return_type).char - return_type_code = return_type_code if return_type_code in _numpy_type_codes else "O" + return_type_code = _np_dtype_char(sig.return_annotation) + np_type_codes = [c if c in _SUPPORTED_NP_TYPE_CODES else "O" for c in np_type_codes] + return_type_code = return_type_code if return_type_code in _SUPPORTED_NP_TYPE_CODES else "O" np_type_codes.extend(["-", ">", return_type_code]) return "".join(np_type_codes) +def _py_udf(fn: Callable): + """A decorator that acts as a transparent translator for Python UDFs used in Deephaven query formulas between + Python and Java. This decorator is intended for use by the Deephaven query engine and should not be used by + users. 
+ + For now, this decorator is only capable of converting Python function return values to Java values. It + does not yet convert Java values in arguments to usable Python object (e.g. numpy arrays) or properly translate + Deephaven primitive null values. + + For properly annotated functions, including numba vectorized and guvectorized ones, this decorator inspects the + signature of the function and determines its return type, including supported primitive types and arrays of + the supported primitive types. It then converts the return value of the function to the corresponding Java value + of the same type. For unsupported types, the decorator returns the original Python value which appears as + org.jpy.PyObject in Java. + """ + + if hasattr(fn, "return_type"): + return fn + + if isinstance(fn, (numba.np.ufunc.dufunc.DUFunc, numba.np.ufunc.gufunc.GUFunc)) and hasattr(fn, "types"): + dh_dtype = dtypes.from_np_dtype(np.dtype(fn.types[0][-1])) + else: + dh_dtype = dtypes.from_np_dtype(np.dtype(_encode_signature(fn)[-1])) + + return_array = False + + # If the function is a numba guvectorized function, examine the signature of the function to determine if it + # returns an array. + if isinstance(fn, numba.np.ufunc.gufunc.GUFunc): + sig = fn.signature + rtype = sig.split("->")[-1].strip("()") + if rtype: + return_array = True + else: + component_type = _component_np_dtype_char(inspect.signature(fn).return_annotation) + if component_type: + dh_dtype = dtypes.from_np_dtype(np.dtype(component_type)) + if dh_dtype in _BUILDABLE_ARRAY_DTYPE_MAP: + return_array = True + + @wraps(fn) + def wrapper(*args, **kwargs): + ret = fn(*args, **kwargs) + if return_array: + return dtypes.array(dh_dtype, ret) + elif dh_dtype == dtypes.PyObject: + return ret + else: + return _scalar(ret) + + wrapper.j_name = dh_dtype.j_name + ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(dh_dtype) if return_array else dh_dtype + + if hasattr(dh_dtype.j_type, 'jclass'): + j_class = ret_dtype.j_type.jclass + else: + j_class = ret_dtype.qst_type.clazz() + + wrapper.return_type = j_class + + return wrapper + + def dh_vectorize(fn): """A decorator to vectorize a Python function used in Deephaven query formulas and invoked on a row basis. 
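A concrete sketch of what the _py_udf wrapper enables from the query side (the function, column
names, and values below are assumed for illustration, not taken from this patch):

    import numpy as np
    from deephaven import empty_table

    def to_short(x) -> np.int16:
        return np.int16(x * 2)

    t = empty_table(5).update(["X = i", "Y = to_short(X)"])
    # The engine wraps to_short with _py_udf, reads wrapper.return_type (Java short),
    # and Y materializes as a primitive short column instead of org.jpy.PyObject.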
@@ -410,6 +467,7 @@ def dh_vectorize(fn): """ signature = _encode_signature(fn) + @wraps(fn) def wrapper(*args): if len(args) != len(signature) - len("->?") + 2: raise ValueError( @@ -423,7 +481,7 @@ def wrapper(*args): vectorized_args = zip(*args[2:]) for i in range(chunk_size): scalar_args = next(vectorized_args) - chunk_result[i] = fn(*scalar_args) + chunk_result[i] = _scalar(fn(*scalar_args)) else: for i in range(chunk_size): chunk_result[i] = fn() diff --git a/py/server/tests/test_numba_guvectorize.py b/py/server/tests/test_numba_guvectorize.py new file mode 100644 index 00000000000..c82b92296e3 --- /dev/null +++ b/py/server/tests/test_numba_guvectorize.py @@ -0,0 +1,94 @@ +# +# Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending +# + +import unittest + +import numpy as np +from numba import guvectorize, int64 + +from deephaven import empty_table, dtypes +from tests.testbase import BaseTestCase + +a = np.arange(5, dtype=np.int64) + + +class NumbaGuvectorizeTestCase(BaseTestCase): + def test_scalar_return(self): + # vector input to scalar output function (m)->() + @guvectorize([(int64[:], int64[:])], "(m)->()", nopython=True) + def g(x, res): + res[0] = 0 + for xi in x: + res[0] += xi + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y)") + m = t.meta_table + self.assertEqual(t.columns[2].data_type, dtypes.int64) + + def test_vector_return(self): + # vector and scalar input to vector ouput function + @guvectorize([(int64[:], int64, int64[:])], "(m),()->(m)", nopython=True) + def g(x, y, res): + for i in range(len(x)): + res[i] = x[i] + y + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,2)") + self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + def test_fixed_length_vector_return(self): + # NOTE: the following does not work according to this thread from 7 years ago: + # https://numba-users.continuum.narkive.com/7OAX8Suv/numba-guvectorize-with-fixed-size-output-array + # but the latest numpy Generalized Universal Function API does seem to support frozen dimensions + # https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html#generalized-universal-function-api + # There is an old numba ticket about it + # https://github.com/numba/numba/issues/1668 + # Possibly we could contribute a fix + + # fails with: bad token in signature "2" + + # #vector input to fixed-length vector ouput function + # @guvectorize([(int64[:],int64[:])],"(m)->(2)",nopython=True) + # def g3(x, res): + # res[0] = min(x) + # res[1] = max(x) + + # t3 = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g3(Y)") + # m3 = t3.meta_table + + # ** Workaround ** + + dummy = np.array([0, 0], dtype=np.int64) + + # vector input to fixed-length vector ouput function -- second arg is a dummy just to get a fixed size output + @guvectorize([(int64[:], int64[:], int64[:])], "(m),(n)->(n)", nopython=True) + def g(x, dummy, res): + res[0] = min(x) + res[1] = max(x) + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") + self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + def test_np_on_java_array(self): + dummy = np.array([0, 0], dtype=np.int64) + + # vector input to fixed-length vector output function -- second arg is a dummy just to get a fixed size output + @guvectorize([(int64[:], int64[:], int64[:])], "(m),(n)->(n)", nopython=True) + def g(x, dummy, res): + res[0] = np.min(x) + res[1] = np.max(x) + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") + 
self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + def test_np_on_java_array2(self): + @guvectorize([(int64[:], int64[:])], "(m)->(m)", nopython=True) + def g(x, res): + res[:] = x + 5 + + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y)") + self.assertEqual(t.columns[2].data_type, dtypes.long_array) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/server/tests/test_numba_vectorized_column.py b/py/server/tests/test_numba_vectorized_column.py index d47687c96d1..6f14eb7b659 100644 --- a/py/server/tests/test_numba_vectorized_column.py +++ b/py/server/tests/test_numba_vectorized_column.py @@ -16,7 +16,7 @@ def vectorized_func(x, y): return x % 3 + y -class TestNumbaVectorizedColumnClass(BaseTestCase): +class NumbaVectorizedColumnTestCase(BaseTestCase): def test_part_of_expr(self): with self.assertRaises(Exception): diff --git a/py/server/tests/test_numba_vectorized_filter.py b/py/server/tests/test_numba_vectorized_filter.py index 98155dadff4..468cc9e04c3 100644 --- a/py/server/tests/test_numba_vectorized_filter.py +++ b/py/server/tests/test_numba_vectorized_filter.py @@ -22,7 +22,7 @@ def vectorized_func_wrong_return_type(x, y): return x % 2 > y % 5 -class TestNumbaVectorizedFilterClass(BaseTestCase): +class NumbaVectorizedFilterTestCase(BaseTestCase): def test_wrong_return_type(self): with self.assertRaises(Exception): diff --git a/py/server/tests/test_pyfunc_return_java_values.py b/py/server/tests/test_pyfunc_return_java_values.py new file mode 100644 index 00000000000..34501242130 --- /dev/null +++ b/py/server/tests/test_pyfunc_return_java_values.py @@ -0,0 +1,174 @@ +# +# Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending +# +import time +import unittest +from typing import List, Tuple, Sequence, Union + +import numpy as np +import pandas as pd +import datetime +from deephaven import empty_table, dtypes +from deephaven import time as dhtime +from tests.testbase import BaseTestCase + +_J_TYPE_NP_DTYPE_MAP = { + dtypes.double: "np.float64", + dtypes.float32: "np.float32", + dtypes.int32: "np.int32", + dtypes.long: "np.int64", + dtypes.short: "np.int16", + dtypes.byte: "np.int8", + dtypes.bool_: "np.bool_", + dtypes.string: "np.str_", + # dtypes.char: "np.uint16", +} + + +class PyFuncReturnJavaTestCase(BaseTestCase): + def test_scalar_return(self): + for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): + with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): + func_str = f""" +def fn(col) -> {np_dtype}: + return {np_dtype}(col) +""" + exec(func_str, globals()) + + t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dh_dtype) + + def test_array_return(self): + component_types = { + "int": dtypes.long_array, + "float": dtypes.double_array, + "np.int8": dtypes.byte_array, + "np.int16": dtypes.short_array, + "np.int32": dtypes.int32_array, + "np.int64": dtypes.long_array, + "np.float32": dtypes.float32_array, + "np.float64": dtypes.double_array, + "bool": dtypes.boolean_array, + "np.str_": dtypes.string_array, + # "np.uint16": dtypes.char_array, + } + container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] + for component_type, dh_dtype in component_types.items(): + for container_type in container_types: + with self.subTest(component_type=component_type, container_type=container_type): + func_decl_str = f"""def fn(col) -> {container_type}[{component_type}]:""" + if container_type == "np.ndarray": + func_body_str = f""" return 
np.array([{component_type}(c) for c in col])""" + else: + func_body_str = f""" return [{component_type}(c) for c in col]""" + exec("\n".join([func_decl_str, func_body_str]), globals()) + t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") + self.assertEqual(t.columns[2].data_type, dh_dtype) + + def test_scalar_return_class_method_not_supported(self): + for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): + with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): + func_str = f""" +class Foo: + def fn(self, col) -> {np_dtype}: + return {np_dtype}(col) +foo = Foo() +""" + exec(func_str, globals()) + + t = empty_table(10).update("X = i").update(f"Y= foo.fn(X + 1)") + self.assertNotEqual(t.columns[1].data_type, dh_dtype) + + def test_datetime_scalar_return(self): + dt_dtypes = [ + "np.dtype('datetime64[ns]')", + "np.dtype('datetime64[ms]')", + "datetime.datetime", + "pd.Timestamp" + ] + + for np_dtype in dt_dtypes: + with self.subTest(np_dtype=np_dtype): + func_decl_str = f"""def fn(col) -> {np_dtype}:""" + if np_dtype == "np.dtype('datetime64[ns]')": + func_body_str = f""" return pd.Timestamp(col).to_numpy()""" + elif np_dtype == "datetime.datetime": + func_body_str = f""" return pd.Timestamp(col).to_pydatetime()""" + elif np_dtype == "pd.Timestamp": + func_body_str = f""" return pd.Timestamp(col)""" + + exec("\n".join([func_decl_str, func_body_str]), globals()) + + t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.Instant) + # vectorized + t = empty_table(10).update("X = i").update(f"Y= fn(X)") + self.assertEqual(t.columns[1].data_type, dtypes.Instant) + + def test_datetime_array_return(self): + + dt = datetime.datetime.now() + ts = pd.Timestamp(dt) + np_dt = np.datetime64(dt) + dt_list = [ts, np_dt, dt] + + # test if we can convert to numpy datetime64 array + np_array = np.array([pd.Timestamp(dt).to_numpy() for dt in dt_list], dtype=np.datetime64) + + dt_dtypes = [ + "np.ndarray[np.dtype('datetime64[ns]')]", + "List[datetime.datetime]", + "Tuple[pd.Timestamp]" + ] + + dt_data = [ + "dt_list", + "np_array", + ] + + # we are capable of mapping all datetime array (Sequence, np.ndarray) to instant array, so even if the actual + # return value of the function doesn't match its type hint (which will be caught by a static type checker. 
we + # still can convert it to instant array + for np_dtype in dt_dtypes: + for data in dt_data: + with self.subTest(np_dtype=np_dtype, data=data): + func_decl_str = f"""def fn(col) -> {np_dtype}:""" + func_body_str = f""" return {data}""" + exec("\n".join([func_decl_str, func_body_str]), globals().update( + {"dt_list": dt_list, "np_array": np_array})) + + t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.instant_array) + + def test_return_value_errors(self): + def fn(col) -> List[object]: + return [col] + + def fn1(col) -> List: + return [col] + + def fn2(col): + return col + + def fn3(col) -> List[Union[datetime.datetime, int]]: + return [col] + + with self.subTest(fn): + t = empty_table(1).update("X = i").update(f"Y= fn(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + with self.subTest(fn1): + t = empty_table(1).update("X = i").update(f"Y= fn1(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + with self.subTest(fn2): + t = empty_table(1).update("X = i").update(f"Y= fn2(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + with self.subTest(fn3): + t = empty_table(1).update("X = i").update(f"Y= fn3(X + 1)") + self.assertEqual(t.columns[1].data_type, dtypes.JObject) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index b38004f8dcb..4bb941788fd 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -143,7 +143,7 @@ def pyfunc(p1, p2, p3) -> int: return p1 + p2 + p3 t = empty_table(1).update("X = i").update(["Y = pyfunc(X, i, 33)", "Z = pyfunc(X, ii, 66)"]) - self.assertEqual(deephaven.table._vectorized_count, 3) + self.assertEqual(deephaven.table._vectorized_count, 1) self.assertIn("33", t.to_string(cols=["Y"])) self.assertIn("66", t.to_string(cols=["Z"])) @@ -186,11 +186,11 @@ def pyfunc_bool(p1, p2, p3) -> bool: conditions = ["pyfunc_bool(I, 3, J)", "pyfunc_bool(i, 10, ii)"] filters = Filter.from_(conditions) t = empty_table(10).view(formulas=["I=ii", "J=(ii * 2)"]).where(filters) - self.assertEqual(3, deephaven.table._vectorized_count) + self.assertEqual(1, deephaven.table._vectorized_count) filter_and = and_(filters) t1 = empty_table(10).view(formulas=["I=ii", "J=(ii * 2)"]).where(filter_and) - self.assertEqual(5, deephaven.table._vectorized_count) + self.assertEqual(1, deephaven.table._vectorized_count) self.assertEqual(t1.size, t.size) self.assertEqual(9, t.size)
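For reference, a minimal end-to-end sketch of the array-return path exercised by the new tests
(the function and column names below are illustrative):

    from typing import List
    import numpy as np
    from deephaven import empty_table, dtypes

    def plus_one(col) -> List[np.int64]:
        return [np.int64(c) + 1 for c in col]

    t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update("Z = plus_one(Y)")
    assert t.columns[2].data_type == dtypes.long_array  # Z is long[], not org.jpy.PyObject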