From b14b349748335bcbd52d244d313a578f7d77bec1 Mon Sep 17 00:00:00 2001 From: Devin Smith Date: Fri, 13 Oct 2023 13:29:46 -0700 Subject: [PATCH] Revert "auto convert Python return val to Java val (#4579)" (#4638) This reverts commit 77587f4ff3d346da1a7c2947aeb456bd2b0b9657. --- .../table/impl/lang/QueryLanguageParser.java | 125 +------------ .../engine/util/PyCallableWrapperJpyImpl.java | 57 ++---- .../impl/lang/PyCallableWrapperDummyImpl.java | 2 +- py/server/deephaven/column.py | 28 ++- py/server/deephaven/dtypes.py | 121 +----------- py/server/deephaven/table.py | 94 ++-------- py/server/tests/test_numba_guvectorize.py | 94 ---------- .../tests/test_numba_vectorized_column.py | 2 +- .../tests/test_numba_vectorized_filter.py | 2 +- .../tests/test_pyfunc_return_java_values.py | 174 ------------------ py/server/tests/test_vectorization.py | 6 +- 11 files changed, 82 insertions(+), 623 deletions(-) delete mode 100644 py/server/tests/test_numba_guvectorize.py delete mode 100644 py/server/tests/test_pyfunc_return_java_values.py diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index f7e1e03de43..6b166b2d797 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -20,6 +20,7 @@ import com.github.javaparser.ast.comments.BlockComment; import com.github.javaparser.ast.comments.JavadocComment; import com.github.javaparser.ast.comments.LineComment; +import com.github.javaparser.ast.comments.Comment; import com.github.javaparser.ast.expr.ArrayAccessExpr; import com.github.javaparser.ast.expr.ArrayCreationExpr; import com.github.javaparser.ast.expr.ArrayInitializerExpr; @@ -122,7 +123,6 @@ import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; -import java.time.Instant; import java.util.*; import java.util.function.Function; import java.util.stream.Collectors; @@ -1214,7 +1214,7 @@ private Expression[] convertParameters(final Executable executable, // as a single argument if (ObjectVector.class.isAssignableFrom(expressionTypes[ei])) { expressions[ei] = new CastExpr( - StaticJavaParser.parseClassOrInterfaceType("java.lang.Object"), + new ClassOrInterfaceType("java.lang.Object"), expressions[ei]); expressionTypes[ei] = Object.class; } else { @@ -2023,7 +2023,7 @@ public Class visit(ConditionalExpr n, VisitArgs printer) { if (classA == boolean.class && classB == Boolean.class) { // a little hacky, but this handles the null case where it unboxes. very weird stuff final Expression uncastExpr = n.getThenExpr(); - final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr); + final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr); n.setThenExpr(castExpr); // fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr) uncastExpr.setParentNode(castExpr); @@ -2032,7 +2032,7 @@ public Class visit(ConditionalExpr n, VisitArgs printer) { if (classA == Boolean.class && classB == boolean.class) { // a little hacky, but this handles the null case where it unboxes. very weird stuff final Expression uncastExpr = n.getElseExpr(); - final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr); + final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr); n.setElseExpr(castExpr); // fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr) uncastExpr.setParentNode(castExpr); @@ -2159,8 +2159,7 @@ public Class visit(FieldAccessExpr n, VisitArgs printer) { printer.append(", " + clsName + ".class"); final ClassExpr targetType = - new ClassExpr( - StaticJavaParser.parseClassOrInterfaceType(printer.pythonCastContext.getSimpleName())); + new ClassExpr(new ClassOrInterfaceType(null, printer.pythonCastContext.getSimpleName())); getAttributeArgs.add(targetType); // Let's advertise to the caller the cast context type @@ -2338,35 +2337,6 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { callMethodCall.setData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS, new PyCallableDetails(null, methodName)); - if (PyCallableWrapper.class.isAssignableFrom(method.getDeclaringClass())) { - final Optional> optionalRetType = pyCallableReturnType(callMethodCall); - if (optionalRetType.isPresent()) { - Class retType = optionalRetType.get(); - final Optional optionalCastExpr = - makeCastExpressionForPyCallable(retType, callMethodCall); - if (optionalCastExpr.isPresent()) { - final CastExpr castExpr = optionalCastExpr.get(); - replaceChildExpression( - n.getParentNode().orElseThrow(), - n, - castExpr); - - callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).setCasted(true); - try { - return castExpr.accept(this, printer); - } catch (Exception e) { - // exceptions could be thrown by {@link #tryVectorizePythonCallable} - replaceChildExpression( - castExpr.getParentNode().orElseThrow(), - castExpr, - callMethodCall); - callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS) - .setCasted(false); - return callMethodCall.accept(this, printer); - } - } - } - } replaceChildExpression( n.getParentNode().orElseThrow(), n, @@ -2406,7 +2376,7 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { final ObjectCreationExpr newPyCallableExpr = new ObjectCreationExpr( null, - StaticJavaParser.parseClassOrInterfaceType(pyCallableWrapperImplName), + new ClassOrInterfaceType(null, pyCallableWrapperImplName), NodeList.nodeList(getAttributeCall)); final MethodCallExpr callMethodCall = new MethodCallExpr( @@ -2460,60 +2430,6 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { typeArguments); } - private Optional makeCastExpressionForPyCallable(Class retType, MethodCallExpr callMethodCall) { - if (retType.isPrimitive()) { - return Optional.of(new CastExpr( - new PrimitiveType(PrimitiveType.Primitive - .valueOf(retType.getSimpleName().toUpperCase())), - callMethodCall)); - } else if (retType.getComponentType() != null) { - final Class componentType = retType.getComponentType(); - if (componentType.isPrimitive()) { - ArrayType arrayType; - if (componentType == boolean.class) { - arrayType = new ArrayType(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean")); - } else { - arrayType = new ArrayType(new PrimitiveType(PrimitiveType.Primitive - .valueOf(retType.getComponentType().getSimpleName().toUpperCase()))); - } - return Optional.of(new CastExpr(arrayType, callMethodCall)); - } else if (retType.getComponentType() == String.class || retType.getComponentType() == Boolean.class - || retType.getComponentType() == Instant.class) { - ArrayType arrayType = - new ArrayType( - StaticJavaParser.parseClassOrInterfaceType(retType.getComponentType().getSimpleName())); - return Optional.of(new CastExpr(arrayType, callMethodCall)); - } - } else if (retType == Boolean.class) { - return Optional - .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"), callMethodCall)); - } else if (retType == String.class) { - return Optional - .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.String"), callMethodCall)); - } else if (retType == Instant.class) { - return Optional - .of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.time.Instant"), callMethodCall)); - } - - return Optional.empty(); - } - - private Optional> pyCallableReturnType(@NotNull MethodCallExpr n) { - final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS); - final String pyMethodName = pyCallableDetails.pythonMethodName; - final QueryScope queryScope = ExecutionContext.getContext().getQueryScope(); - final Object paramValueRaw = queryScope.readParamValue(pyMethodName, null); - if (paramValueRaw == null) { - return Optional.empty(); - } - if (!(paramValueRaw instanceof PyCallableWrapper)) { - return Optional.empty(); - } - final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw; - pyCallableWrapper.parseSignature(); - return Optional.ofNullable(pyCallableWrapper.getReturnType()); - } - @NotNull private static Expression[] getExpressionsArray(final NodeList exprNodeList) { return exprNodeList == null ? new Expression[0] @@ -2647,27 +2563,13 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, // expression evaluation code that expects singular values. This check makes sure that numba/dh vectorized // functions must be used alone as the entire expression after removing the enclosing parentheses. Node n1 = n; - boolean autoCastChecked = false; while (n1.hasParentNode()) { n1 = n1.getParentNode().orElseThrow(); Class cls = n1.getClass(); if (cls == CastExpr.class) { - if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()) { - autoCastChecked = true; - } else { - throw new PythonCallVectorizationFailure( - "The return values of Python vectorized functions can't be cast: " + n1); - } - } else if (cls == MethodCallExpr.class) { - String methodName = ((MethodCallExpr) n1).getNameAsString(); - if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted() - && methodName.endsWith("Cast")) { - autoCastChecked = true; - } else { - throw new PythonCallVectorizationFailure( - "Python vectorized function can't be used in another expression: " + n1); - } + throw new PythonCallVectorizationFailure( + "The return values of Python vectorized function can't be cast: " + n1); } else if (cls != EnclosedExpr.class && cls != WrapperNode.class) { throw new PythonCallVectorizationFailure( "Python vectorized function can't be used in another expression: " + n1); @@ -3331,17 +3233,6 @@ private static class PyCallableDetails { @NotNull private final String pythonMethodName; - @NotNull - private boolean isCasted = false; - - public boolean isCasted() { - return isCasted; - } - - public void setCasted(boolean casted) { - isCasted = casted; - } - private PyCallableDetails(@Nullable String pythonScopeExpr, @NotNull String pythonMethodName) { this.pythonScopeExpr = pythonScopeExpr; this.pythonMethodName = pythonMethodName; diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index 6f3ae8b9699..5f72ec781ac 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -6,7 +6,6 @@ import org.jpy.PyModule; import org.jpy.PyObject; -import java.time.Instant; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; @@ -21,7 +20,6 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private static final Logger log = LoggerFactory.getLogger(PyCallableWrapperJpyImpl.class); private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType(); - private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType(); private static final PyModule dh_table_module = PyModule.importModule("deephaven.table"); @@ -36,7 +34,6 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaClass.put('b', byte.class); numpyType2JavaClass.put('?', boolean.class); numpyType2JavaClass.put('U', String.class); - numpyType2JavaClass.put('M', Instant.class); numpyType2JavaClass.put('O', Object.class); } @@ -50,7 +47,6 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private Collection chunkArguments; private boolean numbaVectorized; private PyObject unwrapped; - private PyObject pyUdfDecoratedCallable; public PyCallableWrapperJpyImpl(PyObject pyCallable) { this.pyCallable = pyCallable; @@ -95,37 +91,23 @@ private static PyObject getNumbaVectorizedFuncType() { } } - private static PyObject getNumbaGUVectorizedFuncType() { - try { - return PyModule.importModule("numba.np.ufunc.gufunc").getAttribute("GUFunc"); - } catch (Exception e) { - if (log.isDebugEnabled()) { - log.debug("Numba isn't installed in the Python environment."); - } - return null; - } - } - private void prepareSignature() { - boolean isNumbaVectorized = pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE); - boolean isNumbaGUVectorized = pyCallable.equals(NUMBA_GUVECTORIZED_FUNC_TYPE); - if (isNumbaGUVectorized || isNumbaVectorized) { + if (pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE)) { List params = pyCallable.getAttribute("types").asList(); if (params.isEmpty()) { throw new IllegalArgumentException( - "numba vectorized/guvectorized function must have an explicit signature: " + pyCallable); + "numba vectorized function must have an explicit signature: " + pyCallable); } // numba allows a vectorized function to have multiple signatures if (params.size() > 1) { throw new UnsupportedOperationException( pyCallable - + " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions"); + + " has multiple signatures; this is not currently supported for numba vectorized functions"); } signature = params.get(0).getStringValue(); - unwrapped = pyCallable; - // since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized - numbaVectorized = isNumbaVectorized; - vectorized = isNumbaVectorized; + unwrapped = null; + numbaVectorized = true; + vectorized = true; } else if (pyCallable.hasAttribute("dh_vectorized")) { signature = pyCallable.getAttribute("signature").toString(); unwrapped = pyCallable.getAttribute("callable"); @@ -137,7 +119,6 @@ private void prepareSignature() { numbaVectorized = false; vectorized = false; } - pyUdfDecoratedCallable = dh_table_module.call("_py_udf", unwrapped); } @Override @@ -154,6 +135,14 @@ public void parseSignature() { throw new IllegalStateException("Signature should always be available."); } + char numpyTypeCode = signature.charAt(signature.length() - 1); + Class returnType = numpyType2JavaClass.get(numpyTypeCode); + if (returnType == null) { + throw new IllegalStateException( + "Vectorized functions should always have an integral, floating point, boolean, String, or Object return type: " + + numpyTypeCode); + } + List> paramTypes = new ArrayList<>(); for (char numpyTypeChar : signature.toCharArray()) { if (numpyTypeChar != '-') { @@ -170,31 +159,25 @@ public void parseSignature() { } this.paramTypes = paramTypes; - - returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); - if (returnType == null) { - throw new IllegalStateException( - "Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type"); - } - - if (returnType == boolean.class) { + if (returnType == Object.class) { + this.returnType = PyObject.class; + } else if (returnType == boolean.class) { this.returnType = Boolean.class; + } else { + this.returnType = returnType; } } - // In vectorized mode, we want to call the vectorized function directly. public PyObject vectorizedCallable() { - if (numbaVectorized || vectorized) { + if (numbaVectorized) { return pyCallable; } else { return dh_table_module.call("dh_vectorize", unwrapped); } } - // In non-vectorized mode, we want to call the udf decorated function or the original function. @Override public Object call(Object... args) { - PyObject pyCallable = this.pyUdfDecoratedCallable != null ? this.pyUdfDecoratedCallable : this.pyCallable; return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args)); } diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java index 672a47b5158..b1e405b67d0 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java @@ -69,6 +69,6 @@ public void addChunkArgument(ChunkArgument ignored) {} @Override public Class getReturnType() { - return Object.class; + throw new UnsupportedOperationException(); } } diff --git a/py/server/deephaven/column.py b/py/server/deephaven/column.py index 81fd5ffa202..635f60794f1 100644 --- a/py/server/deephaven/column.py +++ b/py/server/deephaven/column.py @@ -9,11 +9,13 @@ from typing import Sequence, Any import jpy +import numpy as np +import pandas as pd import deephaven.dtypes as dtypes -from deephaven import DHError +from deephaven import DHError, time from deephaven.dtypes import DType -from deephaven.dtypes import _instant_array +from deephaven.time import to_j_instant _JColumnHeader = jpy.get_type("io.deephaven.qst.column.header.ColumnHeader") _JColumn = jpy.get_type("io.deephaven.qst.column.Column") @@ -204,7 +206,27 @@ def datetime_col(name: str, data: Sequence) -> InputColumn: Returns: a new input column """ - data = _instant_array(data) + + # try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on + # it to reduce the number of round trips to the JVM + if not isinstance(data, np.ndarray): + try: + data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64) + except Exception as e: + ... + + if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'): + if data.dtype.kind == 'M': + longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64')) + elif data.dtype.kind == 'i': + longs = jpy.array('long', data.astype('int64')) + else: # data.dtype.kind == 'U' + longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data]) + data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) + + if not isinstance(data, dtypes.instant_array.j_type): + data = [to_j_instant(d) for d in data] + return InputColumn(name=name, data_type=dtypes.Instant, input_data=data) diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py index cedf9c67138..a5f940f38eb 100644 --- a/py/server/deephaven/dtypes.py +++ b/py/server/deephaven/dtypes.py @@ -8,9 +8,7 @@ """ from __future__ import annotations -import datetime -import sys -from typing import Any, Sequence, Callable, Dict, Type, Union, _GenericAlias, Optional +from typing import Any, Sequence, Callable, Dict, Type, Union import jpy import numpy as np @@ -122,8 +120,6 @@ def __call__(self, *args, **kwargs): """Python object type""" JObject = DType(j_name="java.lang.Object") """Java Object type""" -bool_array = DType(j_name='[Z') -"""boolean array type""" byte_array = DType(j_name='[B') """Byte array type""" int8_array = byte_array @@ -132,8 +128,6 @@ def __call__(self, *args, **kwargs): """Short array type""" int16_array = short_array """Short array type""" -char_array = DType(j_name='[C') -"""char array type""" int32_array = DType(j_name='[I') """32bit integer array type""" long_array = DType(j_name='[J') @@ -142,7 +136,7 @@ def __call__(self, *args, **kwargs): """64bit integer array type""" int_array = long_array """64bit integer array type""" -single_array = DType(j_name='[F') +single_array = DType(j_name='[S') """Single-precision floating-point array type""" float32_array = single_array """Single-precision floating-point array type""" @@ -154,8 +148,6 @@ def __call__(self, *args, **kwargs): """Double-precision floating-point array type""" string_array = DType(j_name='[Ljava.lang.String;') """Java String array type""" -boolean_array = DType(j_name='[Ljava.lang.Boolean;') -"""Java Boolean array type""" instant_array = DType(j_name='[Ljava.time.Instant;') """Java Instant array type""" zdt_array = DType(j_name='[Ljava.time.ZonedDateTime;') @@ -172,19 +164,6 @@ def __call__(self, *args, **kwargs): float64: NULL_DOUBLE, } -_BUILDABLE_ARRAY_DTYPE_MAP = { - bool_: bool_array, - byte: int8_array, - char: char_array, - int16: int16_array, - int32: int32_array, - int64: int64_array, - float32: float32_array, - float64: float64_array, - string: string_array, - Instant: instant_array, -} - def null_remap(dtype: DType) -> Callable[[Any], Any]: """ Creates a null value remap function for the provided DType. @@ -205,33 +184,6 @@ def null_remap(dtype: DType) -> Callable[[Any], Any]: return lambda v: null_value if v is None else v -def _instant_array(data: Sequence) -> jpy.JType: - """Converts a sequence of either datetime64[ns], datetime.datetime, pandas.Timestamp, datetime strings, - or integers in nanoseconds, to a Java array of Instant values. """ - # try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on - # it to reduce the number of round trips to the JVM - if not isinstance(data, np.ndarray): - try: - data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64) - except Exception as e: - ... - - if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'): - if data.dtype.kind == 'M': - longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64')) - elif data.dtype.kind == 'i': - longs = jpy.array('long', data.astype('int64')) - else: # data.dtype.kind == 'U' - longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data]) - data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) - - if not isinstance(data, instant_array.j_type): - from deephaven.time import to_j_instant - data = [to_j_instant(d) for d in data] - - return data - - def array(dtype: DType, seq: Sequence, remap: Callable[[Any], Any] = None) -> jpy.JType: """ Creates a Java array of the specified data type populated with values from a sequence. @@ -263,14 +215,14 @@ def array(dtype: DType, seq: Sequence, remap: Callable[[Any], Any] = None) -> jp raise ValueError("Not a callable") seq = [remap(v) for v in seq] - if dtype == Instant: - return _instant_array(seq) - if isinstance(seq, np.ndarray): if dtype == bool_: bytes_ = seq.astype(dtype=np.int8) j_bytes = array(byte, bytes_) seq = _JPrimitiveArrayConversionUtility.translateArrayByteToBoolean(j_bytes) + elif dtype == Instant: + longs = jpy.array('long', seq.astype('datetime64[ns]').astype('int64')) + seq = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs) return jpy.array(dtype.j_type, seq) except Exception as e: @@ -314,66 +266,3 @@ def from_np_dtype(np_dtype: Union[np.dtype, pd.api.extensions.ExtensionDtype]) - return dtype return PyObject - - -_NUMPY_INT_TYPE_CODES = ["i", "l", "h", "b"] -_NUMPY_FLOATING_TYPE_CODES = ["f", "d"] - - -def _scalar(x): - """Converts a Python value to a Java scalar value. It converts the numpy primitive types, string to - their Python equivalents so that JPY can handle them. For datetime values, it converts them to Java Instant. - Otherwise, it returns the value as is.""" - if hasattr(x, "dtype"): - if x.dtype.char in _NUMPY_INT_TYPE_CODES: - return int(x) - elif x.dtype.char in _NUMPY_FLOATING_TYPE_CODES: - return float(x) - elif x.dtype.char == '?': - return bool(x) - elif x.dtype.char == 'U': - return str(x) - elif x.dtype.char == 'O': - return x - elif x.dtype.char == 'M': - from deephaven.time import to_j_instant - return to_j_instant(x) - else: - raise TypeError(f"Unsupported dtype: {x.dtype}") - else: - if isinstance(x, (datetime.datetime, pd.Timestamp)): - from deephaven.time import to_j_instant - return to_j_instant(x) - return x - - -def _np_dtype_char(t: Union[type, str]) -> str: - """Returns the numpy dtype character code for the given type.""" - try: - np_dtype = np.dtype(t if t else "object") - if np_dtype.kind == "O": - if t in (datetime.datetime, pd.Timestamp): - return "M" - except TypeError: - np_dtype = np.dtype("object") - - return np_dtype.char - - -def _component_np_dtype_char(t: type) -> Optional[str]: - """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or - numpy ndarray, otherwise return None. """ - component_type = None - if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): - component_type = t.__args__[0] - - # np.ndarray as a generic alias is only supported in Python 3.9+ - if not component_type and sys.version_info.minor > 8: - import types - if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray): - component_type = t.__args__[0] - - if component_type: - return _np_dtype_char(component_type) - else: - return None diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index 78849e4cf46..15181e87691 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -11,12 +11,10 @@ import inspect from enum import Enum from enum import auto -from functools import wraps from typing import Any, Optional, Callable, Dict from typing import Sequence, List, Union, Protocol import jpy -import numba import numpy as np from deephaven import DHError @@ -31,8 +29,6 @@ from deephaven.jcompat import to_sequence, j_array_list from deephaven.update_graph import auto_locking_ctx, UpdateGraph from deephaven.updateby import UpdateByOperation -from deephaven.dtypes import _BUILDABLE_ARRAY_DTYPE_MAP, _scalar, _np_dtype_char, \ - _component_np_dtype_char # Table _J_Table = jpy.get_type("io.deephaven.engine.table.Table") @@ -363,93 +359,40 @@ def _j_py_script_session() -> _JPythonScriptSession: return None -_SUPPORTED_NP_TYPE_CODES = ["i", "l", "h", "f", "d", "b", "?", "U", "M", "O"] +_numpy_type_codes = ["i", "l", "h", "f", "d", "b", "?", "U", "O"] def _encode_signature(fn: Callable) -> str: """Encode the signature of a Python function by mapping the annotations of the parameter types and the return - type to numpy dtype chars (i,l,h,f,d,b,?,U,M,O), and pack them into a string with parameter type chars first, + type to numpy dtype chars (i,l,h,f,d,b,?,U,O), and pack them into a string with parameter type chars first, in their original order, followed by the delimiter string '->', then the return type_char. If a parameter or the return of the function is not annotated, the default 'O' - object type, will be used. """ sig = inspect.signature(fn) - np_type_codes = [] + parameter_types = [] for n, p in sig.parameters.items(): - np_type_codes.append(_np_dtype_char(p.annotation)) + try: + np_dtype = np.dtype(p.annotation if p.annotation else "object") + parameter_types.append(np_dtype) + except TypeError: + parameter_types.append(np.dtype("object")) + + try: + return_type = np.dtype(sig.return_annotation if sig.return_annotation else "object") + except TypeError: + return_type = np.dtype("object") - return_type_code = _np_dtype_char(sig.return_annotation) - np_type_codes = [c if c in _SUPPORTED_NP_TYPE_CODES else "O" for c in np_type_codes] - return_type_code = return_type_code if return_type_code in _SUPPORTED_NP_TYPE_CODES else "O" + np_type_codes = [np.dtype(p).char for p in parameter_types] + np_type_codes = [c if c in _numpy_type_codes else "O" for c in np_type_codes] + return_type_code = np.dtype(return_type).char + return_type_code = return_type_code if return_type_code in _numpy_type_codes else "O" np_type_codes.extend(["-", ">", return_type_code]) return "".join(np_type_codes) -def _py_udf(fn: Callable): - """A decorator that acts as a transparent translator for Python UDFs used in Deephaven query formulas between - Python and Java. This decorator is intended for use by the Deephaven query engine and should not be used by - users. - - For now, this decorator is only capable of converting Python function return values to Java values. It - does not yet convert Java values in arguments to usable Python object (e.g. numpy arrays) or properly translate - Deephaven primitive null values. - - For properly annotated functions, including numba vectorized and guvectorized ones, this decorator inspects the - signature of the function and determines its return type, including supported primitive types and arrays of - the supported primitive types. It then converts the return value of the function to the corresponding Java value - of the same type. For unsupported types, the decorator returns the original Python value which appears as - org.jpy.PyObject in Java. - """ - - if hasattr(fn, "return_type"): - return fn - - if isinstance(fn, (numba.np.ufunc.dufunc.DUFunc, numba.np.ufunc.gufunc.GUFunc)) and hasattr(fn, "types"): - dh_dtype = dtypes.from_np_dtype(np.dtype(fn.types[0][-1])) - else: - dh_dtype = dtypes.from_np_dtype(np.dtype(_encode_signature(fn)[-1])) - - return_array = False - - # If the function is a numba guvectorized function, examine the signature of the function to determine if it - # returns an array. - if isinstance(fn, numba.np.ufunc.gufunc.GUFunc): - sig = fn.signature - rtype = sig.split("->")[-1].strip("()") - if rtype: - return_array = True - else: - component_type = _component_np_dtype_char(inspect.signature(fn).return_annotation) - if component_type: - dh_dtype = dtypes.from_np_dtype(np.dtype(component_type)) - if dh_dtype in _BUILDABLE_ARRAY_DTYPE_MAP: - return_array = True - - @wraps(fn) - def wrapper(*args, **kwargs): - ret = fn(*args, **kwargs) - if return_array: - return dtypes.array(dh_dtype, ret) - elif dh_dtype == dtypes.PyObject: - return ret - else: - return _scalar(ret) - - wrapper.j_name = dh_dtype.j_name - ret_dtype = _BUILDABLE_ARRAY_DTYPE_MAP.get(dh_dtype) if return_array else dh_dtype - - if hasattr(dh_dtype.j_type, 'jclass'): - j_class = ret_dtype.j_type.jclass - else: - j_class = ret_dtype.qst_type.clazz() - - wrapper.return_type = j_class - - return wrapper - - def dh_vectorize(fn): """A decorator to vectorize a Python function used in Deephaven query formulas and invoked on a row basis. @@ -467,7 +410,6 @@ def dh_vectorize(fn): """ signature = _encode_signature(fn) - @wraps(fn) def wrapper(*args): if len(args) != len(signature) - len("->?") + 2: raise ValueError( @@ -481,7 +423,7 @@ def wrapper(*args): vectorized_args = zip(*args[2:]) for i in range(chunk_size): scalar_args = next(vectorized_args) - chunk_result[i] = _scalar(fn(*scalar_args)) + chunk_result[i] = fn(*scalar_args) else: for i in range(chunk_size): chunk_result[i] = fn() diff --git a/py/server/tests/test_numba_guvectorize.py b/py/server/tests/test_numba_guvectorize.py deleted file mode 100644 index c82b92296e3..00000000000 --- a/py/server/tests/test_numba_guvectorize.py +++ /dev/null @@ -1,94 +0,0 @@ -# -# Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending -# - -import unittest - -import numpy as np -from numba import guvectorize, int64 - -from deephaven import empty_table, dtypes -from tests.testbase import BaseTestCase - -a = np.arange(5, dtype=np.int64) - - -class NumbaGuvectorizeTestCase(BaseTestCase): - def test_scalar_return(self): - # vector input to scalar output function (m)->() - @guvectorize([(int64[:], int64[:])], "(m)->()", nopython=True) - def g(x, res): - res[0] = 0 - for xi in x: - res[0] += xi - - t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y)") - m = t.meta_table - self.assertEqual(t.columns[2].data_type, dtypes.int64) - - def test_vector_return(self): - # vector and scalar input to vector ouput function - @guvectorize([(int64[:], int64, int64[:])], "(m),()->(m)", nopython=True) - def g(x, y, res): - for i in range(len(x)): - res[i] = x[i] + y - - t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,2)") - self.assertEqual(t.columns[2].data_type, dtypes.long_array) - - def test_fixed_length_vector_return(self): - # NOTE: the following does not work according to this thread from 7 years ago: - # https://numba-users.continuum.narkive.com/7OAX8Suv/numba-guvectorize-with-fixed-size-output-array - # but the latest numpy Generalized Universal Function API does seem to support frozen dimensions - # https://numpy.org/doc/stable/reference/c-api/generalized-ufuncs.html#generalized-universal-function-api - # There is an old numba ticket about it - # https://github.com/numba/numba/issues/1668 - # Possibly we could contribute a fix - - # fails with: bad token in signature "2" - - # #vector input to fixed-length vector ouput function - # @guvectorize([(int64[:],int64[:])],"(m)->(2)",nopython=True) - # def g3(x, res): - # res[0] = min(x) - # res[1] = max(x) - - # t3 = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g3(Y)") - # m3 = t3.meta_table - - # ** Workaround ** - - dummy = np.array([0, 0], dtype=np.int64) - - # vector input to fixed-length vector ouput function -- second arg is a dummy just to get a fixed size output - @guvectorize([(int64[:], int64[:], int64[:])], "(m),(n)->(n)", nopython=True) - def g(x, dummy, res): - res[0] = min(x) - res[1] = max(x) - - t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") - self.assertEqual(t.columns[2].data_type, dtypes.long_array) - - def test_np_on_java_array(self): - dummy = np.array([0, 0], dtype=np.int64) - - # vector input to fixed-length vector output function -- second arg is a dummy just to get a fixed size output - @guvectorize([(int64[:], int64[:], int64[:])], "(m),(n)->(n)", nopython=True) - def g(x, dummy, res): - res[0] = np.min(x) - res[1] = np.max(x) - - t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") - self.assertEqual(t.columns[2].data_type, dtypes.long_array) - - def test_np_on_java_array2(self): - @guvectorize([(int64[:], int64[:])], "(m)->(m)", nopython=True) - def g(x, res): - res[:] = x + 5 - - t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y)") - self.assertEqual(t.columns[2].data_type, dtypes.long_array) - - -if __name__ == '__main__': - unittest.main() diff --git a/py/server/tests/test_numba_vectorized_column.py b/py/server/tests/test_numba_vectorized_column.py index 6f14eb7b659..d47687c96d1 100644 --- a/py/server/tests/test_numba_vectorized_column.py +++ b/py/server/tests/test_numba_vectorized_column.py @@ -16,7 +16,7 @@ def vectorized_func(x, y): return x % 3 + y -class NumbaVectorizedColumnTestCase(BaseTestCase): +class TestNumbaVectorizedColumnClass(BaseTestCase): def test_part_of_expr(self): with self.assertRaises(Exception): diff --git a/py/server/tests/test_numba_vectorized_filter.py b/py/server/tests/test_numba_vectorized_filter.py index 468cc9e04c3..98155dadff4 100644 --- a/py/server/tests/test_numba_vectorized_filter.py +++ b/py/server/tests/test_numba_vectorized_filter.py @@ -22,7 +22,7 @@ def vectorized_func_wrong_return_type(x, y): return x % 2 > y % 5 -class NumbaVectorizedFilterTestCase(BaseTestCase): +class TestNumbaVectorizedFilterClass(BaseTestCase): def test_wrong_return_type(self): with self.assertRaises(Exception): diff --git a/py/server/tests/test_pyfunc_return_java_values.py b/py/server/tests/test_pyfunc_return_java_values.py deleted file mode 100644 index 34501242130..00000000000 --- a/py/server/tests/test_pyfunc_return_java_values.py +++ /dev/null @@ -1,174 +0,0 @@ -# -# Copyright (c) 2016-2023 Deephaven Data Labs and Patent Pending -# -import time -import unittest -from typing import List, Tuple, Sequence, Union - -import numpy as np -import pandas as pd -import datetime -from deephaven import empty_table, dtypes -from deephaven import time as dhtime -from tests.testbase import BaseTestCase - -_J_TYPE_NP_DTYPE_MAP = { - dtypes.double: "np.float64", - dtypes.float32: "np.float32", - dtypes.int32: "np.int32", - dtypes.long: "np.int64", - dtypes.short: "np.int16", - dtypes.byte: "np.int8", - dtypes.bool_: "np.bool_", - dtypes.string: "np.str_", - # dtypes.char: "np.uint16", -} - - -class PyFuncReturnJavaTestCase(BaseTestCase): - def test_scalar_return(self): - for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): - with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): - func_str = f""" -def fn(col) -> {np_dtype}: - return {np_dtype}(col) -""" - exec(func_str, globals()) - - t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") - self.assertEqual(t.columns[1].data_type, dh_dtype) - - def test_array_return(self): - component_types = { - "int": dtypes.long_array, - "float": dtypes.double_array, - "np.int8": dtypes.byte_array, - "np.int16": dtypes.short_array, - "np.int32": dtypes.int32_array, - "np.int64": dtypes.long_array, - "np.float32": dtypes.float32_array, - "np.float64": dtypes.double_array, - "bool": dtypes.boolean_array, - "np.str_": dtypes.string_array, - # "np.uint16": dtypes.char_array, - } - container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] - for component_type, dh_dtype in component_types.items(): - for container_type in container_types: - with self.subTest(component_type=component_type, container_type=container_type): - func_decl_str = f"""def fn(col) -> {container_type}[{component_type}]:""" - if container_type == "np.ndarray": - func_body_str = f""" return np.array([{component_type}(c) for c in col])""" - else: - func_body_str = f""" return [{component_type}(c) for c in col]""" - exec("\n".join([func_decl_str, func_body_str]), globals()) - t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") - self.assertEqual(t.columns[2].data_type, dh_dtype) - - def test_scalar_return_class_method_not_supported(self): - for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): - with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): - func_str = f""" -class Foo: - def fn(self, col) -> {np_dtype}: - return {np_dtype}(col) -foo = Foo() -""" - exec(func_str, globals()) - - t = empty_table(10).update("X = i").update(f"Y= foo.fn(X + 1)") - self.assertNotEqual(t.columns[1].data_type, dh_dtype) - - def test_datetime_scalar_return(self): - dt_dtypes = [ - "np.dtype('datetime64[ns]')", - "np.dtype('datetime64[ms]')", - "datetime.datetime", - "pd.Timestamp" - ] - - for np_dtype in dt_dtypes: - with self.subTest(np_dtype=np_dtype): - func_decl_str = f"""def fn(col) -> {np_dtype}:""" - if np_dtype == "np.dtype('datetime64[ns]')": - func_body_str = f""" return pd.Timestamp(col).to_numpy()""" - elif np_dtype == "datetime.datetime": - func_body_str = f""" return pd.Timestamp(col).to_pydatetime()""" - elif np_dtype == "pd.Timestamp": - func_body_str = f""" return pd.Timestamp(col)""" - - exec("\n".join([func_decl_str, func_body_str]), globals()) - - t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") - self.assertEqual(t.columns[1].data_type, dtypes.Instant) - # vectorized - t = empty_table(10).update("X = i").update(f"Y= fn(X)") - self.assertEqual(t.columns[1].data_type, dtypes.Instant) - - def test_datetime_array_return(self): - - dt = datetime.datetime.now() - ts = pd.Timestamp(dt) - np_dt = np.datetime64(dt) - dt_list = [ts, np_dt, dt] - - # test if we can convert to numpy datetime64 array - np_array = np.array([pd.Timestamp(dt).to_numpy() for dt in dt_list], dtype=np.datetime64) - - dt_dtypes = [ - "np.ndarray[np.dtype('datetime64[ns]')]", - "List[datetime.datetime]", - "Tuple[pd.Timestamp]" - ] - - dt_data = [ - "dt_list", - "np_array", - ] - - # we are capable of mapping all datetime array (Sequence, np.ndarray) to instant array, so even if the actual - # return value of the function doesn't match its type hint (which will be caught by a static type checker. we - # still can convert it to instant array - for np_dtype in dt_dtypes: - for data in dt_data: - with self.subTest(np_dtype=np_dtype, data=data): - func_decl_str = f"""def fn(col) -> {np_dtype}:""" - func_body_str = f""" return {data}""" - exec("\n".join([func_decl_str, func_body_str]), globals().update( - {"dt_list": dt_list, "np_array": np_array})) - - t = empty_table(10).update("X = i").update(f"Y= fn(X + 1)") - self.assertEqual(t.columns[1].data_type, dtypes.instant_array) - - def test_return_value_errors(self): - def fn(col) -> List[object]: - return [col] - - def fn1(col) -> List: - return [col] - - def fn2(col): - return col - - def fn3(col) -> List[Union[datetime.datetime, int]]: - return [col] - - with self.subTest(fn): - t = empty_table(1).update("X = i").update(f"Y= fn(X + 1)") - self.assertEqual(t.columns[1].data_type, dtypes.JObject) - - with self.subTest(fn1): - t = empty_table(1).update("X = i").update(f"Y= fn1(X + 1)") - self.assertEqual(t.columns[1].data_type, dtypes.JObject) - - with self.subTest(fn2): - t = empty_table(1).update("X = i").update(f"Y= fn2(X + 1)") - self.assertEqual(t.columns[1].data_type, dtypes.JObject) - - with self.subTest(fn3): - t = empty_table(1).update("X = i").update(f"Y= fn3(X + 1)") - self.assertEqual(t.columns[1].data_type, dtypes.JObject) - - -if __name__ == '__main__': - unittest.main() diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index 4bb941788fd..b38004f8dcb 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -143,7 +143,7 @@ def pyfunc(p1, p2, p3) -> int: return p1 + p2 + p3 t = empty_table(1).update("X = i").update(["Y = pyfunc(X, i, 33)", "Z = pyfunc(X, ii, 66)"]) - self.assertEqual(deephaven.table._vectorized_count, 1) + self.assertEqual(deephaven.table._vectorized_count, 3) self.assertIn("33", t.to_string(cols=["Y"])) self.assertIn("66", t.to_string(cols=["Z"])) @@ -186,11 +186,11 @@ def pyfunc_bool(p1, p2, p3) -> bool: conditions = ["pyfunc_bool(I, 3, J)", "pyfunc_bool(i, 10, ii)"] filters = Filter.from_(conditions) t = empty_table(10).view(formulas=["I=ii", "J=(ii * 2)"]).where(filters) - self.assertEqual(1, deephaven.table._vectorized_count) + self.assertEqual(3, deephaven.table._vectorized_count) filter_and = and_(filters) t1 = empty_table(10).view(formulas=["I=ii", "J=(ii * 2)"]).where(filter_and) - self.assertEqual(1, deephaven.table._vectorized_count) + self.assertEqual(5, deephaven.table._vectorized_count) self.assertEqual(t1.size, t.size) self.assertEqual(9, t.size)