From 102d7225e7505e0bf9a75d163e23bea96477311e Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Fri, 15 Mar 2024 10:44:34 -0600 Subject: [PATCH 01/10] WIP check arg type against defs at parsing time --- .../table/impl/lang/QueryLanguageParser.java | 44 +++++-- .../engine/util/PyCallableWrapper.java | 44 ++++++- .../engine/util/PyCallableWrapperJpyImpl.java | 119 ++++++++++++++---- .../impl/lang/PyCallableWrapperDummyImpl.java | 15 +++ py/server/deephaven/_udf.py | 5 +- py/server/tests/test_numba_guvectorize.py | 13 +- py/server/tests/test_udf_numpy_args.py | 8 ++ 7 files changed, 205 insertions(+), 43 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 0f636d9a5b5..e6a081297b5 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -1939,7 +1939,7 @@ private static boolean isAssociativitySafeExpression(Expression expr) { * @return {@code true} if a conversion from {@code original} to {@code target} is a widening conversion; otherwise, * {@code false}. */ - static boolean isWideningPrimitiveConversion(Class original, Class target) { + public static boolean isWideningPrimitiveConversion(Class original, Class target) { if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() || original.equals(void.class) || target.equals(void.class)) { throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); @@ -2498,6 +2498,7 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { // Attempt python function call vectorization. if (scopeType != null && PyCallableWrapper.class.isAssignableFrom(scopeType)) { + verifyPyCallableArguments(n, argTypes); tryVectorizePythonCallable(n, scopeType, convertedArgExpressions, argTypes); } @@ -2505,6 +2506,28 @@ public Class visit(MethodCallExpr n, VisitArgs printer) { typeArguments); } + private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class[] argTypes) { + final String invokedMethodName = n.getNameAsString(); + + if (GET_ATTRIBUTE_METHOD_NAME.equals(invokedMethodName)) { + // Only PyCallableWrapper.getAttribute()/PyCallableWrapper.call() may be invoked from the query language. + // UDF type checks are not currently supported for getAttribute() calls. + return; + } + if (!n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) { + return; + } + final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS); + final String pyMethodName = pyCallableDetails.pythonMethodName; + final Object paramValueRaw = queryScopeVariables.get(pyMethodName); + if (!(paramValueRaw instanceof PyCallableWrapper)) { + return; + } + final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw; + pyCallableWrapper.parseSignature(); + pyCallableWrapper.verifyArguments(argTypes); + } + private Optional makeCastExpressionForPyCallable(Class retType, MethodCallExpr callMethodCall) { if (retType.isPrimitive()) { return Optional.of(new CastExpr( @@ -2726,11 +2749,10 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, } } - List> paramTypes = pyCallableWrapper.getParamTypes(); - if (paramTypes.size() != expressions.length) { + if (pyCallableWrapper.getNumParameters() != expressions.length) { // note vectorization doesn't handle Python variadic arguments throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " " - + paramTypes.size() + " vs. " + expressions.length); + + pyCallableWrapper.getNumParameters() + " vs. " + expressions.length); } } @@ -2739,10 +2761,9 @@ private void prepareVectorizationArgs( Expression[] expressions, Class[] argTypes, PyCallableWrapper pyCallableWrapper) { - List> paramTypes = pyCallableWrapper.getParamTypes(); - if (paramTypes.size() != expressions.length) { + if (pyCallableWrapper.getNumParameters() != expressions.length) { throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " " - + paramTypes.size() + " vs. " + expressions.length); + + pyCallableWrapper.getNumParameters() + " vs. " + expressions.length); } pyCallableWrapper.initializeChunkArguments(); @@ -2764,10 +2785,11 @@ private void prepareVectorizationArgs( throw new IllegalStateException("Vectorizability check failed: " + n); } - if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) { - throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " " - + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName()); - } + // TODO related to core#709, but should be covered by PyCallableWrapper.verifyArguments, needs to verify +// if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) { +// throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " " +// + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName()); +// } } } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java index b9c24773981..d1baf80903c 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java @@ -6,6 +6,7 @@ import org.jpy.PyObject; import java.util.List; +import java.util.Set; /** * Created by rbasralian on 8/12/23 @@ -19,7 +20,9 @@ public interface PyCallableWrapper { Object call(Object... args); - List> getParamTypes(); + List getParameters(); + + int getNumParameters(); boolean isVectorized(); @@ -33,6 +36,8 @@ public interface PyCallableWrapper { Class getReturnType(); + void verifyArguments(Class[] argTypes); + abstract class ChunkArgument { private final Class type; @@ -88,4 +93,41 @@ public Object getValue() { } boolean isVectorizableReturnType(); + + class Signature { + private final List parameters; + private final Class returnType; + + public Signature(List parameters, Class returnType) { + this.parameters = parameters; + this.returnType = returnType; + } + + public List getParameters() { + return parameters; + } + + public Class getReturnType() { + return returnType; + } + } + class Parameter { + private final String name; + private final Set> possibleTypes; + + + public Parameter(String name, Set> possibleTypes) { + this.name = name; + this.possibleTypes = possibleTypes; + } + + public Set> getPossibleTypes() { + return possibleTypes; + } + + public String getName() { + return name; + } + } + } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index 6ba399302cb..e7f7c977918 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -10,12 +10,9 @@ import org.jpy.PyObject; import java.time.Instant; -import java.util.ArrayList; -import java.util.Collection; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; +import java.util.*; + +import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isWideningPrimitiveConversion; /** * When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs @@ -30,6 +27,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { private static final PyModule dh_udf_module = PyModule.importModule("deephaven._udf"); private static final Map> numpyType2JavaClass = new HashMap<>(); + private static final Map> numpyType2JavaArrayClass = new HashMap<>(); static { numpyType2JavaClass.put('b', byte.class); @@ -43,8 +41,22 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaClass.put('U', String.class); numpyType2JavaClass.put('M', Instant.class); numpyType2JavaClass.put('O', Object.class); + + numpyType2JavaArrayClass.put('b', byte[].class); + numpyType2JavaArrayClass.put('h', short[].class); + numpyType2JavaArrayClass.put('H', char[].class); + numpyType2JavaArrayClass.put('i', int[].class); + numpyType2JavaArrayClass.put('l', long[].class); + numpyType2JavaArrayClass.put('f', float[].class); + numpyType2JavaArrayClass.put('d', double[].class); + numpyType2JavaArrayClass.put('?', boolean[].class); + numpyType2JavaArrayClass.put('U', String[].class); + numpyType2JavaArrayClass.put('M', Instant[].class); + numpyType2JavaArrayClass.put('O', Object[].class); } + + /** * Ensure that the class initializer runs. */ @@ -75,8 +87,8 @@ public boolean isVectorizableReturnType() { private final PyObject pyCallable; - private String signature = null; - private List> paramTypes; + private String signatureString = null; + private List parameters = new ArrayList<>(); private Class returnType; private boolean vectorizable = false; private boolean vectorized = false; @@ -168,12 +180,13 @@ private void prepareSignature() { vectorized = false; } pyUdfDecoratedCallable = dh_udf_module.call("_py_udf", unwrapped); - signature = pyUdfDecoratedCallable.getAttribute("signature").toString(); + signatureString = pyUdfDecoratedCallable.getAttribute("signature").toString(); } + @Override public void parseSignature() { - if (signature != null) { + if (signatureString != null) { return; } @@ -181,27 +194,47 @@ public void parseSignature() { // the 'types' field of a vectorized function follows the pattern of '[ilhfdb?O]*->[ilhfdb?O]', // eg. [ll->d] defines two int64 (long) arguments and a double return type. - if (signature == null || signature.isEmpty()) { + if (signatureString == null || signatureString.isEmpty()) { throw new IllegalStateException("Signature should always be available."); } - List> paramTypes = new ArrayList<>(); - for (char numpyTypeChar : signature.toCharArray()) { - if (numpyTypeChar != '-') { - Class paramType = numpyType2JavaClass.get(numpyTypeChar); - if (paramType == null) { - throw new IllegalStateException( - "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object type: " - + numpyTypeChar + " of " + signature); +// List> paramTypes = new ArrayList<>(); +// for (char numpyTypeChar : signatureString.toCharArray()) { +// if (numpyTypeChar != '-') { +// Class paramType = numpyType2JavaClass.get(numpyTypeChar); +// if (paramType == null) { +// throw new IllegalStateException( +// "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object type: " +// + numpyTypeChar + " of " + signatureString); +// } +// paramTypes.add(paramType); +// } else { +// break; +// } +// } +// this.paramTypes = paramTypes; + String pyEncodedParamsStr = signatureString.split("->")[0]; + if (!pyEncodedParamsStr.isEmpty()){ + String[] pyEncodedParams = pyEncodedParamsStr.split(","); + for (int i = 0; i < pyEncodedParams.length; i++) { + String[] paramDetail = pyEncodedParams[i].split(":"); + String paramName = paramDetail[0]; + String paramTypeCodes = paramDetail[1]; + Set> possibleTypes = new HashSet<>(); + for (int ti = 0; ti < paramTypeCodes.length(); ti++) { + char typeCode = paramTypeCodes.charAt(ti); + if (typeCode == '[') { + // skip the array type code + ti++; + possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti))); + } else { + possibleTypes.add(numpyType2JavaClass.get(typeCode)); + } } - paramTypes.add(paramType); - } else { - break; + parameters.add(new Parameter(paramName, possibleTypes)); } } - this.paramTypes = paramTypes; - returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); if (returnType == null) { throw new IllegalStateException( @@ -213,6 +246,35 @@ public void parseSignature() { } } + private boolean isSafelyCastable(Set> types, Class type) { + for (Class t : types) { + if (t.isAssignableFrom(type)) { + return true; + } + if (t.isPrimitive() && type.isPrimitive() && isWideningPrimitiveConversion(type, t)) { + return true; + } + } + return false; + } + + + public void verifyArguments(Class[] argTypes) { + String callableName = pyCallable.getAttribute("__name__").toString(); + + if (argTypes.length != parameters.size()) { + throw new IllegalArgumentException(callableName + ": " + "Expected " + parameters.size() + " arguments, got " + argTypes.length); + } + for (int i = 0; i < argTypes.length; i++) { + Set> types = parameters.get(i).getPossibleTypes(); + if (!types.contains(argTypes[i]) && !types.contains(Object.class) && !isSafelyCastable(types, argTypes[i])) { + throw new IllegalArgumentException( + callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + parameters.get(i).getPossibleTypes() + ", got " + + argTypes[i]); + } + } + }; + // In vectorized mode, we want to call the vectorized function directly. public PyObject vectorizedCallable() { if (numbaVectorized || vectorized) { @@ -230,8 +292,13 @@ public Object call(Object... args) { } @Override - public List> getParamTypes() { - return paramTypes; + public List getParameters() { + return parameters; + } + + @Override + public int getNumParameters() { + return parameters.size(); } @Override diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java index 76ef69a0687..a0fb1278949 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java @@ -6,6 +6,7 @@ import io.deephaven.engine.util.PyCallableWrapper; import org.jpy.PyObject; +import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -45,6 +46,15 @@ public Object call(Object... args) { } @Override + public List getParameters() { + return new ArrayList<>(); + } + + @Override + public int getNumParameters() { + return 0; + } + public List> getParamTypes() { return parameterTypes; } @@ -75,6 +85,11 @@ public Class getReturnType() { return Object.class; } + @Override + public void verifyArguments(Class[] argTypes) { + + } + @Override public boolean isVectorizableReturnType() { return false; diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 2bfe947e3c5..71f684f1ef9 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -64,7 +64,7 @@ def encoded(self) -> str: then the return type char. If a parameter or the return of the function is not annotated, the default 'O' - object type, will be used. """ - param_str = ",".join(["".join(p.encoded_types) for p in self.params]) + param_str = ",".join([str(p.name) + ":" + "".join(p.encoded_types) for p in self.params]) # ret_annotation has only one parsed annotation, and it might be Optional which means it contains 'N' in the # encoded type. We need to remove it. return_type_code = re.sub(r"[N]", "", self.ret_annotation.encoded_type) @@ -414,7 +414,8 @@ def wrapper(*args, **kwargs): j_class = real_ret_dtype.qst_type.clazz() wrapper.return_type = j_class - wrapper.signature = sig_str_vectorization + # wrapper.signature = sig_str_vectorization + wrapper.signature = p_sig.encoded return wrapper diff --git a/py/server/tests/test_numba_guvectorize.py b/py/server/tests/test_numba_guvectorize.py index 096d88902f3..2fa6c42b0e9 100644 --- a/py/server/tests/test_numba_guvectorize.py +++ b/py/server/tests/test_numba_guvectorize.py @@ -66,7 +66,10 @@ def g(x, dummy, res): res[0] = min(x) res[1] = max(x) - t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,dummy)") + # convert dummy to a Java array + # TODO this is a hack, we might want to add a helper function for QLP to call to get the type of a PyObject arg + j_array = dtypes.array(dtypes.int64, dummy) + t = empty_table(10).update(["X=i%3", "Y=i"]).group_by("X").update("Z=g(Y,j_array)") self.assertEqual(t.columns[2].data_type, dtypes.long_array) def test_np_on_java_array(self): @@ -78,7 +81,11 @@ def g(x, dummy, res): res[0] = np.min(x) res[1] = np.max(x) - t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y,dummy)") + # convert dummy to a Java array + # TODO this is a hack, we might want to add a helper function for QLP to call to get the type of a PyObject arg + j_array = dtypes.array(dtypes.int64, dummy) + + t = empty_table(10).update(["X=i%3", "Y=ii"]).group_by("X").update("Z=g(Y,j_array)") self.assertEqual(t.columns[2].data_type, dtypes.long_array) def test_np_on_java_array2(self): @@ -99,7 +106,7 @@ def g(x, res): with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X=i%3", "Y=(double)ii"]).group_by("X").update("Z=g(Y)") - self.assertIn("Argument 1", str(cm.exception)) + self.assertIn("g: Expected argument (1)", str(cm.exception)) if __name__ == '__main__': diff --git a/py/server/tests/test_udf_numpy_args.py b/py/server/tests/test_udf_numpy_args.py index f77bc615def..55c688b4a8f 100644 --- a/py/server/tests/test_udf_numpy_args.py +++ b/py/server/tests/test_udf_numpy_args.py @@ -428,5 +428,13 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f((float)X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) + def test_np_typehints_mismatch(self): + def f(x: float) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + if __name__ == "__main__": unittest.main() From d54b5c69d695fd48a3f1ba7512381fd083c6d92d Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Fri, 15 Mar 2024 10:51:58 -0600 Subject: [PATCH 02/10] Fix spotless check errors --- .../table/impl/lang/QueryLanguageParser.java | 8 ++-- .../engine/util/PyCallableWrapperJpyImpl.java | 42 ++++++++++--------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index e6a081297b5..e3e5fc50851 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -2786,10 +2786,10 @@ private void prepareVectorizationArgs( } // TODO related to core#709, but should be covered by PyCallableWrapper.verifyArguments, needs to verify -// if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) { -// throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " " -// + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName()); -// } + // if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) { + // throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " " + // + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName()); + // } } } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index e7f7c977918..c3171176cc7 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -198,23 +198,24 @@ public void parseSignature() { throw new IllegalStateException("Signature should always be available."); } -// List> paramTypes = new ArrayList<>(); -// for (char numpyTypeChar : signatureString.toCharArray()) { -// if (numpyTypeChar != '-') { -// Class paramType = numpyType2JavaClass.get(numpyTypeChar); -// if (paramType == null) { -// throw new IllegalStateException( -// "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object type: " -// + numpyTypeChar + " of " + signatureString); -// } -// paramTypes.add(paramType); -// } else { -// break; -// } -// } -// this.paramTypes = paramTypes; + // List> paramTypes = new ArrayList<>(); + // for (char numpyTypeChar : signatureString.toCharArray()) { + // if (numpyTypeChar != '-') { + // Class paramType = numpyType2JavaClass.get(numpyTypeChar); + // if (paramType == null) { + // throw new IllegalStateException( + // "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object + // type: " + // + numpyTypeChar + " of " + signatureString); + // } + // paramTypes.add(paramType); + // } else { + // break; + // } + // } + // this.paramTypes = paramTypes; String pyEncodedParamsStr = signatureString.split("->")[0]; - if (!pyEncodedParamsStr.isEmpty()){ + if (!pyEncodedParamsStr.isEmpty()) { String[] pyEncodedParams = pyEncodedParamsStr.split(","); for (int i = 0; i < pyEncodedParams.length; i++) { String[] paramDetail = pyEncodedParams[i].split(":"); @@ -263,13 +264,16 @@ public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); if (argTypes.length != parameters.size()) { - throw new IllegalArgumentException(callableName + ": " + "Expected " + parameters.size() + " arguments, got " + argTypes.length); + throw new IllegalArgumentException( + callableName + ": " + "Expected " + parameters.size() + " arguments, got " + argTypes.length); } for (int i = 0; i < argTypes.length; i++) { Set> types = parameters.get(i).getPossibleTypes(); - if (!types.contains(argTypes[i]) && !types.contains(Object.class) && !isSafelyCastable(types, argTypes[i])) { + if (!types.contains(argTypes[i]) && !types.contains(Object.class) + && !isSafelyCastable(types, argTypes[i])) { throw new IllegalArgumentException( - callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + parameters.get(i).getPossibleTypes() + ", got " + callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + + parameters.get(i).getPossibleTypes() + ", got " + argTypes[i]); } } From 37deaded1e020caa1d9fa641a9e3032287c3b56e Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Mon, 18 Mar 2024 13:39:18 -0600 Subject: [PATCH 03/10] Fix bugs and old test cases --- .../table/impl/lang/QueryLanguageParser.java | 35 ++++++++++++- .../engine/util/PyCallableWrapperJpyImpl.java | 35 +++++++++---- py/server/tests/test_udf_numpy_args.py | 50 ++++++++++++++----- .../tests/test_udf_return_java_values.py | 2 +- py/server/tests/test_vectorization.py | 2 +- 5 files changed, 98 insertions(+), 26 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index e3e5fc50851..88faecd3d46 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -153,7 +153,7 @@ public final class QueryLanguageParser extends GenericVisitorAdapter, Q private final Map> staticImportLookupCache = new HashMap<>(); // We need some class to represent null. We know for certain that this one won't be used... - private static final Class NULL_CLASS = QueryLanguageParser.class; + public static final Class NULL_CLASS = QueryLanguageParser.class; /** * The result of the QueryLanguageParser for the expression passed given to the constructor. @@ -1968,6 +1968,39 @@ public static boolean isWideningPrimitiveConversion(Class original, Class return false; } + public static boolean isLosslessWideningPrimitiveConversion(Class original, Class target) { + if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() + || original.equals(void.class) || target.equals(void.class)) { + throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); + } + + if (original.equals(target)) { + return true; + } + + LanguageParserPrimitiveType originalEnum = LanguageParserPrimitiveType.getPrimitiveType(original); + + switch (originalEnum) { + case BytePrimitive: + if (target == short.class) + return true; + case ShortPrimitive: + case CharPrimitive: + if (target == int.class) + return true; + case IntPrimitive: + if (target == long.class) + return true; + break; + case FloatPrimitive: + if (target == double.class) + return true; + break; + } + + return false; + } + private enum LanguageParserPrimitiveType { // Including "Enum" (or really, any differentiating string) in these names is important. They're used // in a switch() statement, which apparently does not support qualified names. And we can't use diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index c3171176cc7..aadce5e6422 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -3,6 +3,7 @@ // package io.deephaven.engine.util; +import io.deephaven.engine.table.impl.lang.QueryLanguageParser; import io.deephaven.engine.table.impl.select.python.ArgumentsChunked; import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; @@ -12,7 +13,9 @@ import java.time.Instant; import java.util.*; -import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isWideningPrimitiveConversion; +import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.NULL_CLASS; +import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isLosslessWideningPrimitiveConversion; +import static io.deephaven.util.type.TypeUtils.getUnboxedType; /** * When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs @@ -228,6 +231,11 @@ public void parseSignature() { // skip the array type code ti++; possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti))); + if (paramTypeCodes.charAt(ti) == '?') { + possibleTypes.add(Boolean[].class); + } + } else if (typeCode == 'N') { + possibleTypes.add(NULL_CLASS); } else { possibleTypes.add(numpyType2JavaClass.get(typeCode)); } @@ -252,7 +260,7 @@ private boolean isSafelyCastable(Set> types, Class type) { if (t.isAssignableFrom(type)) { return true; } - if (t.isPrimitive() && type.isPrimitive() && isWideningPrimitiveConversion(type, t)) { + if (t.isPrimitive() && type.isPrimitive() && isLosslessWideningPrimitiveConversion(type, t)) { return true; } } @@ -263,18 +271,25 @@ private boolean isSafelyCastable(Set> types, Class type) { public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); - if (argTypes.length != parameters.size()) { - throw new IllegalArgumentException( - callableName + ": " + "Expected " + parameters.size() + " arguments, got " + argTypes.length); - } + // if (argTypes.length > parameters.size()) { + // throw new IllegalArgumentException( + // callableName + ": " + "Expected " + parameters.size() + " or fewer arguments, got " + argTypes.length); + // } for (int i = 0; i < argTypes.length; i++) { - Set> types = parameters.get(i).getPossibleTypes(); - if (!types.contains(argTypes[i]) && !types.contains(Object.class) - && !isSafelyCastable(types, argTypes[i])) { + Set> types = + parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes(); + // Object is a catch-all type, so we don't need to check for it + if (argTypes[i] == Object.class) { + continue; + } + + Class t = getUnboxedType(argTypes[i]) == null ? argTypes[i] : getUnboxedType(argTypes[i]); + if (!types.contains(t) && !types.contains(Object.class) + && !isSafelyCastable(types, t)) { throw new IllegalArgumentException( callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + parameters.get(i).getPossibleTypes() + ", got " - + argTypes[i]); + + (argTypes[i].equals(NULL_CLASS) ? "null" : argTypes[i])); } } }; diff --git a/py/server/tests/test_udf_numpy_args.py b/py/server/tests/test_udf_numpy_args.py index 55c688b4a8f..fbb913520ec 100644 --- a/py/server/tests/test_udf_numpy_args.py +++ b/py/server/tests/test_udf_numpy_args.py @@ -218,7 +218,7 @@ def f11(p1: Union[float, np.float32]) -> bool: with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f11(i)"]) - def f2(p1: Union[np.int16, np.float64]) -> Union[Optional[bool]]: + def f2(p1: Union[np.int32, np.float64]) -> Union[Optional[bool]]: return bool(p1) t = empty_table(10).update(["X1 = f2(i)"]) @@ -231,7 +231,7 @@ def f21(p1: Union[np.int16, np.float64]) -> Union[Optional[bool], int]: with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f21(i)"]) - def f3(p1: Union[np.int16, np.float64], p2=None) -> bool: + def f3(p1: Union[np.int32, np.float64], p2=None) -> bool: return bool(p1) t = empty_table(10).update(["X1 = f3(i)"]) @@ -244,7 +244,7 @@ def f4(p1: Union[np.int16, np.float64], p2=None) -> bool: self.assertEqual(t.columns[0].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t = empty_table(10).update(["X1 = f4(now())"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f4: Expected .* got .*Instant") def f41(p1: Union[np.int16, np.float64, Union[Any]], p2=None) -> bool: return bool(p1) @@ -266,7 +266,7 @@ def f5(col1, col2: np.ndarray[np.int32]) -> bool: t = t.update(["X1 = f5(X, Y)"]) with self.assertRaises(DHError) as cm: t = t.update(["X1 = f5(X, null)"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f5: Expected .* got null") def f51(col1, col2: Optional[np.ndarray[np.int32]]) -> bool: return np.nanmean(col2) == np.mean(col2) @@ -287,7 +287,7 @@ def f6(*args: np.int32, col2: np.ndarray[np.int32]) -> bool: with self.assertRaises(DHError) as cm: t1 = t.update(["X1 = f6(X, Y=null)"]) - self.assertIn("not compatible with annotation", str(cm.exception)) + self.assertIn("f6: Expected argument (col2) to be one of [class [I], got boolean", str(cm.exception)) def test_str_bool_datetime_array(self): with self.subTest("str"): @@ -299,7 +299,7 @@ def f1(p1: np.ndarray[str], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f1(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f1: Expected .* got null") def f11(p1: Union[np.ndarray[str], None], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -315,7 +315,7 @@ def f2(p1: np.ndarray[np.datetime64], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f2(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*") + self.assertRegex(str(cm.exception), "f2: Expected .* got null") def f21(p1: Union[np.ndarray[np.datetime64], None], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -337,7 +337,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool: self.assertEqual(t1.columns[2].data_type, dtypes.bool_) with self.assertRaises(DHError) as cm: t2 = t.update(["X1 = f3(null, Y )"]) - self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation") + self.assertRegex(str(cm.exception), "f3: Expected .* got null") def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool: return bool(len(p1)) if p1 is not None else False @@ -405,9 +405,9 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(({p_type})X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) - - np_int_types = {"np.int8", "np.int16", "np.int32", "np.int64"} - for p_type in np_int_types: + def test_np_typehints(self): + widening_np_int_types = {"np.int32", "np.int64"} + for p_type in widening_np_int_types: with self.subTest(p_type): func_str = f""" def f(x: {p_type}) -> bool: # note typing @@ -417,8 +417,20 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) - np_floating_types = {"np.float32", "np.float64"} - for p_type in np_floating_types: + narrowing_np_int_types = {"np.int8", "np.int16"} + for p_type in narrowing_np_int_types: + with self.subTest(p_type): + func_str = f""" +def f(x: {p_type}) -> bool: # note typing + return type(x) == {p_type} +""" + exec(func_str, globals()) + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", f"Y = f(X)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + widening_np_floating_types = {"np.float32", "np.float64"} + for p_type in widening_np_floating_types: with self.subTest(p_type): func_str = f""" def f(x: {p_type}) -> bool: # note typing @@ -428,6 +440,18 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f((float)X)"]) self.assertEqual(1, t.to_string(cols="Y").count("true")) + int_to_floating_types = {"np.float32", "np.float64"} + for p_type in int_to_floating_types: + with self.subTest(p_type): + func_str = f""" +def f(x: {p_type}) -> bool: # note typing + return type(x) == {p_type} +""" + exec(func_str, globals()) + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", f"Y = f(X)"]) + self.assertRegex(str(cm.exception), "f: Expect") + def test_np_typehints_mismatch(self): def f(x: float) -> bool: return True diff --git a/py/server/tests/test_udf_return_java_values.py b/py/server/tests/test_udf_return_java_values.py index 7c1f55c5827..129105d5698 100644 --- a/py/server/tests/test_udf_return_java_values.py +++ b/py/server/tests/test_udf_return_java_values.py @@ -287,7 +287,7 @@ def test_np_ufunc(self): nbsin = numba.vectorize([numba.float64(numba.float64)])(np.sin) # this is the workaround that utilizes vectorization and type inference - @numba.vectorize([numba.float64(numba.float64)], nopython=True) + @numba.vectorize([numba.float64(numba.int64)], nopython=True) def nbsin(x): return np.sin(x) t3 = empty_table(10).update(["X3 = nbsin(i)"]) diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index d7532647640..2685fa843d5 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -269,7 +269,7 @@ def sinc2(x): self.assertEqual(t.columns[1].data_type, dtypes.PyObject) def test_optional_annotations(self): - def pyfunc(p1: np.int32, p2: np.int32, p3: Optional[np.int32]) -> Optional[int]: + def pyfunc(p1: np.int32, p2: np.int64, p3: Optional[np.int32]) -> Optional[int]: total = p1 + p2 + p3 return None if total % 3 == 0 else total From 5124581328f7d2ffc512f1c4027c092ae13ea837 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Tue, 19 Mar 2024 13:27:14 -0600 Subject: [PATCH 04/10] Refactor parsing code of UDF signatures --- .../table/impl/lang/QueryLanguageParser.java | 6 - .../engine/util/PyCallableWrapperJpyImpl.java | 25 +--- py/server/deephaven/_udf.py | 113 ++++++++++++++++-- py/server/deephaven/dtypes.py | 69 +---------- ...est_udf_numpy_args.py => test_udf_args.py} | 41 +++++-- .../tests/test_udf_return_java_values.py | 3 +- py/server/tests/test_vectorization.py | 2 +- 7 files changed, 145 insertions(+), 114 deletions(-) rename py/server/tests/{test_udf_numpy_args.py => test_udf_args.py} (92%) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 88faecd3d46..56b7d731d1d 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -2817,12 +2817,6 @@ private void prepareVectorizationArgs( } else { throw new IllegalStateException("Vectorizability check failed: " + n); } - - // TODO related to core#709, but should be covered by PyCallableWrapper.verifyArguments, needs to verify - // if (!isSafelyCoerceable(argTypes[i], paramTypes.get(i))) { - // throw new PythonCallVectorizationFailure("Python vectorized function argument type mismatch: " + n + " " - // + argTypes[i].getSimpleName() + " -> " + paramTypes.get(i).getSimpleName()); - // } } } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index aadce5e6422..5cbede20417 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -201,22 +201,6 @@ public void parseSignature() { throw new IllegalStateException("Signature should always be available."); } - // List> paramTypes = new ArrayList<>(); - // for (char numpyTypeChar : signatureString.toCharArray()) { - // if (numpyTypeChar != '-') { - // Class paramType = numpyType2JavaClass.get(numpyTypeChar); - // if (paramType == null) { - // throw new IllegalStateException( - // "Parameters of vectorized functions should always be of integral, floating point, boolean, String, or Object - // type: " - // + numpyTypeChar + " of " + signatureString); - // } - // paramTypes.add(paramType); - // } else { - // break; - // } - // } - // this.paramTypes = paramTypes; String pyEncodedParamsStr = signatureString.split("->")[0]; if (!pyEncodedParamsStr.isEmpty()) { String[] pyEncodedParams = pyEncodedParamsStr.split(","); @@ -271,14 +255,13 @@ private boolean isSafelyCastable(Set> types, Class type) { public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); - // if (argTypes.length > parameters.size()) { - // throw new IllegalArgumentException( - // callableName + ": " + "Expected " + parameters.size() + " or fewer arguments, got " + argTypes.length); - // } for (int i = 0; i < argTypes.length; i++) { Set> types = parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes(); - // Object is a catch-all type, so we don't need to check for it + + // to prevent the unpacking of an array column when calling a Python function, we prefix the column accessor + // with a cast to generic Object type, until we can find a way to convey that info, we'll just skip the + // check for Object type input if (argTypes[i] == Object.class) { continue; } diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 71f684f1ef9..99708161a09 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -5,9 +5,13 @@ import inspect import re import sys +import typing from dataclasses import dataclass, field +from datetime import datetime from functools import wraps -from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set +from typing import Callable, List, Any, Union, Tuple, _GenericAlias, Set, Optional, Sequence + +import pandas as pd from deephaven._dep import soft_dependency @@ -17,9 +21,9 @@ import numpy as np from deephaven import DHError, dtypes -from deephaven.dtypes import _np_ndarray_component_type, _np_dtype_char, _NUMPY_INT_TYPE_CODES, \ - _NUMPY_FLOATING_TYPE_CODES, _component_np_dtype_char, _J_ARRAY_NP_TYPE_MAP, _PRIMITIVE_DTYPE_NULL_MAP, _scalar, \ - _BUILDABLE_ARRAY_DTYPE_MAP +from deephaven.dtypes import _NUMPY_INT_TYPE_CODES, _NUMPY_FLOATING_TYPE_CODES, _J_ARRAY_NP_TYPE_MAP, \ + _PRIMITIVE_DTYPE_NULL_MAP, _scalar, \ + _BUILDABLE_ARRAY_DTYPE_MAP, DType from deephaven.jcompat import _j_array_to_numpy_array from deephaven.time import to_np_datetime64 @@ -27,7 +31,6 @@ test_vectorization = False vectorized_count = 0 - _SUPPORTED_NP_TYPE_CODES = {"b", "h", "H", "i", "l", "f", "d", "?", "U", "M", "O"} @@ -80,7 +83,7 @@ def _encode_param_type(t: type) -> str: return "N" # find the component type if it is numpy ndarray - component_type = _np_ndarray_component_type(t) + component_type = _component_np_dtype_char(t) if component_type: t = component_type @@ -92,6 +95,92 @@ def _encode_param_type(t: type) -> str: return tc +def _np_dtype_char(t: Union[type, str]) -> str: + """Returns the numpy dtype character code for the given type.""" + try: + np_dtype = np.dtype(t if t else "object") + if np_dtype.kind == "O": + if t in (datetime, pd.Timestamp): + return "M" + except TypeError: + np_dtype = np.dtype("object") + + return np_dtype.char + + +def _component_np_dtype_char(t: type) -> Optional[str]: + """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or + numpy ndarray, otherwise return None. """ + component_type = None + + if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: + import types + if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence): + component_type = t.__args__[0] + + if not component_type: + if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): + component_type = t.__args__[0] + # if the component type is a DType, get its numpy type + if isinstance(component_type, DType): + component_type = component_type.np_type + + if not component_type: + if t == bytes or t == bytearray: + return "b" + + if not component_type: + component_type = _np_ndarray_component_type(t) + + if component_type: + return _np_dtype_char(component_type) + else: + return None + + +def _np_ndarray_component_type(t: type) -> Optional[type]: + """Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None.""" + + # Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64]) + # is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used, + # the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the + # component type + component_type = None + if sys.version_info.major == 3 and sys.version_info.minor == 8: + if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray: + component_type = t.__args__[1].__args__[0] + # Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a + # specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias. + # when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which + # the 1st argument is the component type + # when np.ndarray is used, the 1st argument is the component type + if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: + import types + if isinstance(t, types.GenericAlias) and t.__origin__ == np.ndarray: + nargs = len(t.__args__) + if nargs == 1: + component_type = t.__args__[0] + elif nargs == 2: # for npt.NDArray[np.int64], etc. + a0 = t.__args__[0] + a1 = t.__args__[1] + if a0 == typing.Any and isinstance(a1, types.GenericAlias): + component_type = a1.__args__[0] + return component_type + + +def _is_union_type(t: type) -> bool: + """Return True if the type is a Union type""" + if sys.version_info.major == 3 and sys.version_info.minor >= 10: + import types + if isinstance(t, types.UnionType): + return True + + if isinstance(t, _GenericAlias) and t.__origin__ == Union: + return True + + return False + + def _parse_param(name: str, annotation: Any) -> _ParsedParam: """ Parse a parameter annotation in a function's signature """ p_param = _ParsedParam(name) @@ -99,7 +188,7 @@ def _parse_param(name: str, annotation: Any) -> _ParsedParam: if annotation is inspect._empty: p_param.encoded_types.add("O") p_param.none_allowed = True - elif isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union: + elif _is_union_type(annotation): for t in annotation.__args__: _parse_type_no_nested(annotation, p_param, t) else: @@ -149,7 +238,7 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation: t = annotation pra.orig_type = t - if isinstance(annotation, _GenericAlias) and annotation.__origin__ == Union and len(annotation.__args__) == 2: + if _is_union_type(annotation) and len(annotation.__args__) == 2: # if the annotation is a Union of two types, we'll use the non-None type if annotation.__args__[1] == type(None): # noqa: E721 t = annotation.__args__[0] @@ -170,7 +259,8 @@ def _parse_return_annotation(annotation: Any) -> _ParsedReturnAnnotation: if numba: - def _parse_numba_signature(fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature: + def _parse_numba_signature( + fn: Union[numba.np.ufunc.gufunc.GUFunc, numba.np.ufunc.dufunc.DUFunc]) -> _ParsedSignature: """ Parse a numba function's signature""" sigs = fn.types # in the format of ll->l, ff->f,dd->d,OO->O, etc. if sigs: @@ -261,7 +351,8 @@ def _parse_signature(fn: Callable) -> _ParsedSignature: t = eval(p.annotation, fn.__globals__) if isinstance(p.annotation, str) else p.annotation p_sig.params.append(_parse_param(n, t)) - t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation, str) else sig.return_annotation + t = eval(sig.return_annotation, fn.__globals__) if isinstance(sig.return_annotation, + str) else sig.return_annotation p_sig.ret_annotation = _parse_return_annotation(t) return p_sig @@ -476,4 +567,4 @@ def wrapper(*args): global vectorized_count vectorized_count += 1 - return wrapper \ No newline at end of file + return wrapper diff --git a/py/server/deephaven/dtypes.py b/py/server/deephaven/dtypes.py index 7da00a0fe37..5aa6c8acca8 100644 --- a/py/server/deephaven/dtypes.py +++ b/py/server/deephaven/dtypes.py @@ -9,13 +9,10 @@ from __future__ import annotations import datetime -import sys -import typing -from typing import Any, Sequence, Callable, Dict, Type, Union, _GenericAlias, Optional +from typing import Any, Sequence, Callable, Dict, Type, Union, Optional import jpy import numpy as np -import numpy._typing as npt import pandas as pd from deephaven import DHError @@ -304,7 +301,7 @@ def array(dtype: DType, seq: Optional[Sequence], remap: Callable[[Any], Any] = N raise DHError(e, f"failed to create a Java {dtype.j_name} array.") from e -def from_jtype(j_class: Any) -> DType: +def from_jtype(j_class: Any) -> Optional[DType]: """ looks up a DType that matches the java type, if not found, creates a DType for it. """ if not j_class: return None @@ -391,65 +388,3 @@ def _scalar(x: Any, dtype: DType) -> Any: return x except: return x - - -def _np_dtype_char(t: Union[type, str]) -> str: - """Returns the numpy dtype character code for the given type.""" - try: - np_dtype = np.dtype(t if t else "object") - if np_dtype.kind == "O": - if t in (datetime.datetime, pd.Timestamp): - return "M" - except TypeError: - np_dtype = np.dtype("object") - - return np_dtype.char - - -def _component_np_dtype_char(t: type) -> Optional[str]: - """Returns the numpy dtype character code for the given type's component type if the type is a Sequence type or - numpy ndarray, otherwise return None. """ - component_type = None - if isinstance(t, _GenericAlias) and issubclass(t.__origin__, Sequence): - component_type = t.__args__[0] - # if the component type is a DType, get its numpy type - if isinstance(component_type, DType): - component_type = component_type.np_type - - if not component_type: - component_type = _np_ndarray_component_type(t) - - if component_type: - return _np_dtype_char(component_type) - else: - return None - - -def _np_ndarray_component_type(t: type) -> Optional[type]: - """Returns the numpy ndarray component type if the type is a numpy ndarray, otherwise return None.""" - - # Py3.8: npt.NDArray can be used in Py 3.8 as a generic alias, but a specific alias (e.g. npt.NDArray[np.int64]) - # is an instance of a private class of np, yet we don't have a choice but to use it. And when npt.NDArray is used, - # the 1st argument is typing.Any, the 2nd argument is another generic alias of which the 1st argument is the - # component type - component_type = None - if sys.version_info.major == 3 and sys.version_info.minor == 8: - if isinstance(t, np._typing._generic_alias._GenericAlias) and t.__origin__ == np.ndarray: - component_type = t.__args__[1].__args__[0] - # Py3.9+, np.ndarray as a generic alias is only supported in Python 3.9+, also npt.NDArray is still available but a - # specific alias (e.g. npt.NDArray[np.int64]) now is an instance of typing.GenericAlias. - # when npt.NDArray is used, the 1st argument is typing.Any, the 2nd argument is another generic alias of which - # the 1st argument is the component type - # when np.ndarray is used, the 1st argument is the component type - if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: - import types - if isinstance(t, types.GenericAlias) and (issubclass(t.__origin__, Sequence) or t.__origin__ == np.ndarray): # novermin - nargs = len(t.__args__) - if nargs == 1: - component_type = t.__args__[0] - elif nargs == 2: # for npt.NDArray[np.int64], etc. - a0 = t.__args__[0] - a1 = t.__args__[1] - if a0 == typing.Any and isinstance(a1, types.GenericAlias): - component_type = a1.__args__[0] - return component_type diff --git a/py/server/tests/test_udf_numpy_args.py b/py/server/tests/test_udf_args.py similarity index 92% rename from py/server/tests/test_udf_numpy_args.py rename to py/server/tests/test_udf_args.py index fbb913520ec..8e58050a9d6 100644 --- a/py/server/tests/test_udf_numpy_args.py +++ b/py/server/tests/test_udf_args.py @@ -2,7 +2,7 @@ # Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending # import typing -from typing import Optional, Union, Any +from typing import Optional, Union, Any, Sequence import unittest import numpy as np @@ -452,13 +452,40 @@ def f(x: {p_type}) -> bool: # note typing t = empty_table(1).update(["X = i", f"Y = f(X)"]) self.assertRegex(str(cm.exception), "f: Expect") - def test_np_typehints_mismatch(self): - def f(x: float) -> bool: - return True + def test_sequence_args(self): + with self.subTest("Sequence"): + def f(x: Sequence[int]) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = ii"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) + + with self.subTest("bytes"): + def f(x: bytes) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) + + with self.subTest("bytearray"): + def f(x: bytearray) -> bool: + return True + + with self.assertRaises(DHError) as cm: + t = empty_table(1).update(["X = i", "Y = f(ii)"]) + self.assertRegex(str(cm.exception), "f: Expect") + + t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) + self.assertEqual(t.columns[2].data_type, dtypes.bool_) - with self.assertRaises(DHError) as cm: - t = empty_table(1).update(["X = i", "Y = f(ii)"]) - self.assertRegex(str(cm.exception), "f: Expect") if __name__ == "__main__": unittest.main() diff --git a/py/server/tests/test_udf_return_java_values.py b/py/server/tests/test_udf_return_java_values.py index 129105d5698..d42b6f8465c 100644 --- a/py/server/tests/test_udf_return_java_values.py +++ b/py/server/tests/test_udf_return_java_values.py @@ -54,7 +54,8 @@ def test_array_return(self): "np.str_": dtypes.string_array, "np.uint16": dtypes.char_array, } - container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] + # container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] + container_types = ["list"] for component_type, dh_dtype in component_types.items(): for container_type in container_types: with self.subTest(component_type=component_type, container_type=container_type): diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index 2685fa843d5..7632be171e7 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -15,7 +15,7 @@ from deephaven._udf import _dh_vectorize as dh_vectorize from tests.testbase import BaseTestCase -from tests.test_udf_numpy_args import _J_TYPE_NULL_MAP, _J_TYPE_NP_DTYPE_MAP, _J_TYPE_J_ARRAY_TYPE_MAP +from tests.test_udf_args import _J_TYPE_NULL_MAP, _J_TYPE_NP_DTYPE_MAP, _J_TYPE_J_ARRAY_TYPE_MAP class VectorizationTestCase(BaseTestCase): From 6de6edf97fe7c8d555d67b492e2b8c69132d50fa Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Fri, 22 Mar 2024 08:19:41 -0600 Subject: [PATCH 05/10] Refactoring interface --- .../table/impl/lang/QueryLanguageParser.java | 23 +++--- .../impl/select/AbstractConditionFilter.java | 5 +- .../table/impl/select/DhFormulaColumn.java | 3 +- .../engine/util/PyCallableWrapper.java | 79 +++++++++---------- .../engine/util/PyCallableWrapperJpyImpl.java | 34 +++----- .../impl/lang/PyCallableWrapperDummyImpl.java | 14 +--- py/server/tests/test_vectorization.py | 1 + 7 files changed, 70 insertions(+), 89 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 56b7d731d1d..0329488adda 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -1985,8 +1985,8 @@ public static boolean isLosslessWideningPrimitiveConversion(Class original, C if (target == short.class) return true; case ShortPrimitive: - case CharPrimitive: - if (target == int.class) + case CharPrimitive: // char is unsigned, so it's a lossless conversion to int + if (target == int.class) // this covers all the smaller integer types return true; case IntPrimitive: if (target == long.class) @@ -2552,11 +2552,11 @@ private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class } final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS); final String pyMethodName = pyCallableDetails.pythonMethodName; - final Object paramValueRaw = queryScopeVariables.get(pyMethodName); - if (!(paramValueRaw instanceof PyCallableWrapper)) { + final Object methodVar = queryScopeVariables.get(pyMethodName); + if (!(methodVar instanceof PyCallableWrapper)) { return; } - final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw; + final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) methodVar; pyCallableWrapper.parseSignature(); pyCallableWrapper.verifyArguments(argTypes); } @@ -2608,7 +2608,7 @@ private Optional> pyCallableReturnType(@NotNull MethodCallExpr n) { } final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw; pyCallableWrapper.parseSignature(); - return Optional.ofNullable(pyCallableWrapper.getReturnType()); + return Optional.ofNullable(pyCallableWrapper.getSignature().getReturnType()); } @NotNull @@ -2739,7 +2739,8 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, pyCallableWrapper.parseSignature(); if (!pyCallableWrapper.isVectorizableReturnType()) { throw new PythonCallVectorizationFailure( - "Python function return type is not supported: " + pyCallableWrapper.getReturnType()); + "Python function return type is not supported: " + + pyCallableWrapper.getSignature().getReturnType()); } // Python vectorized functions(numba, DH) return arrays of primitive/Object types. This will break the generated @@ -2782,10 +2783,10 @@ private void checkVectorizability(@NotNull final MethodCallExpr n, } } - if (pyCallableWrapper.getNumParameters() != expressions.length) { + if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) { // note vectorization doesn't handle Python variadic arguments throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " " - + pyCallableWrapper.getNumParameters() + " vs. " + expressions.length); + + pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length); } } @@ -2794,9 +2795,9 @@ private void prepareVectorizationArgs( Expression[] expressions, Class[] argTypes, PyCallableWrapper pyCallableWrapper) { - if (pyCallableWrapper.getNumParameters() != expressions.length) { + if (pyCallableWrapper.getSignature().getParameters().size() != expressions.length) { throw new PythonCallVectorizationFailure("Python function argument count mismatch: " + n + " " - + pyCallableWrapper.getNumParameters() + " vs. " + expressions.length); + + pyCallableWrapper.getSignature().getParameters().size() + " vs. " + expressions.length); } pyCallableWrapper.initializeChunkArguments(); diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java index 95098cdb5b4..481d25065ee 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/AbstractConditionFilter.java @@ -267,7 +267,7 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result, final PyCallableWrapperJpyImpl pyCallableWrapper = cws[0]; if (pyCallableWrapper.isVectorizable()) { - checkReturnType(result, pyCallableWrapper.getReturnType()); + checkReturnType(result, pyCallableWrapper.getSignature().getReturnType()); for (String variable : result.getVariablesUsed()) { if (variable.equals("i")) { @@ -284,7 +284,8 @@ private void checkAndInitializeVectorization(QueryLanguageParser.Result result, ArgumentsChunked argumentsChunked = pyCallableWrapper.buildArgumentsChunked(usedColumns); PyObject vectorized = pyCallableWrapper.vectorizedCallable(); DeephavenCompatibleFunction dcf = DeephavenCompatibleFunction.create(vectorized, - pyCallableWrapper.getReturnType(), usedColumns.toArray(new String[0]), argumentsChunked, true); + pyCallableWrapper.getSignature().getReturnType(), usedColumns.toArray(new String[0]), + argumentsChunked, true); setFilter(new ConditionFilter.ChunkFilter( dcf.toFilterKernel(), dcf.getColumnNames().toArray(new String[0]), diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java index 6d788018888..e1babe4bec3 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/select/DhFormulaColumn.java @@ -238,7 +238,8 @@ private void checkAndInitializeVectorization(Map> co PyObject vectorized = pyCallableWrapper.vectorizedCallable(); formulaColumnPython = FormulaColumnPython.create(this.columnName, DeephavenCompatibleFunction.create(vectorized, - pyCallableWrapper.getReturnType(), this.analyzedFormula.sourceDescriptor.sources, + pyCallableWrapper.getSignature().getReturnType(), + this.analyzedFormula.sourceDescriptor.sources, argumentsChunked, true)); formulaColumnPython.initDef(columnDefinitionMap); diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java index d1baf80903c..bcbcfb5462a 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapper.java @@ -20,10 +20,6 @@ public interface PyCallableWrapper { Object call(Object... args); - List getParameters(); - - int getNumParameters(); - boolean isVectorized(); boolean isVectorizable(); @@ -34,10 +30,47 @@ public interface PyCallableWrapper { void addChunkArgument(ChunkArgument chunkArgument); - Class getReturnType(); + Signature getSignature(); void verifyArguments(Class[] argTypes); + class Parameter { + private final String name; + private final Set> possibleTypes; + + + public Parameter(String name, Set> possibleTypes) { + this.name = name; + this.possibleTypes = possibleTypes; + } + + public Set> getPossibleTypes() { + return possibleTypes; + } + + public String getName() { + return name; + } + } + + class Signature { + private final List parameters; + private final Class returnType; + + public Signature(List parameters, Class returnType) { + this.parameters = parameters; + this.returnType = returnType; + } + + public List getParameters() { + return parameters; + } + + public Class getReturnType() { + return returnType; + } + } + abstract class ChunkArgument { private final Class type; @@ -94,40 +127,4 @@ public Object getValue() { boolean isVectorizableReturnType(); - class Signature { - private final List parameters; - private final Class returnType; - - public Signature(List parameters, Class returnType) { - this.parameters = parameters; - this.returnType = returnType; - } - - public List getParameters() { - return parameters; - } - - public Class getReturnType() { - return returnType; - } - } - class Parameter { - private final String name; - private final Set> possibleTypes; - - - public Parameter(String name, Set> possibleTypes) { - this.name = name; - this.possibleTypes = possibleTypes; - } - - public Set> getPossibleTypes() { - return possibleTypes; - } - - public String getName() { - return name; - } - } - } diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index 5cbede20417..a420c0eed27 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -3,7 +3,6 @@ // package io.deephaven.engine.util; -import io.deephaven.engine.table.impl.lang.QueryLanguageParser; import io.deephaven.engine.table.impl.select.python.ArgumentsChunked; import io.deephaven.internal.log.LoggerFactory; import io.deephaven.io.logger.Logger; @@ -58,8 +57,6 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaArrayClass.put('O', Object[].class); } - - /** * Ensure that the class initializer runs. */ @@ -85,14 +82,12 @@ public static void init() {} @Override public boolean isVectorizableReturnType() { parseSignature(); - return vectorizableReturnTypes.contains(returnType); + return vectorizableReturnTypes.contains(signature.getReturnType()); } private final PyObject pyCallable; - private String signatureString = null; - private List parameters = new ArrayList<>(); - private Class returnType; + private Signature signature; private boolean vectorizable = false; private boolean vectorized = false; private Collection chunkArguments; @@ -125,7 +120,7 @@ public ArgumentsChunked buildArgumentsChunked(List columnNames) { ((ColumnChunkArgument) arg).setSourceChunkIndex(chunkSourceIndex); } } - return new ArgumentsChunked(chunkArguments, returnType, numbaVectorized); + return new ArgumentsChunked(chunkArguments, signature.getReturnType(), numbaVectorized); } /** @@ -202,6 +197,7 @@ public void parseSignature() { } String pyEncodedParamsStr = signatureString.split("->")[0]; + List parameters = new ArrayList<>(); if (!pyEncodedParamsStr.isEmpty()) { String[] pyEncodedParams = pyEncodedParamsStr.split(","); for (int i = 0; i < pyEncodedParams.length; i++) { @@ -228,15 +224,18 @@ public void parseSignature() { } } - returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); + Class returnType = pyUdfDecoratedCallable.getAttribute("return_type", null); if (returnType == null) { throw new IllegalStateException( "Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type"); } if (returnType == boolean.class) { - this.returnType = Boolean.class; + returnType = Boolean.class; } + + signature = new Signature(parameters, returnType); + } private boolean isSafelyCastable(Set> types, Class type) { @@ -254,6 +253,7 @@ private boolean isSafelyCastable(Set> types, Class type) { public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); + List parameters = signature.getParameters(); for (int i = 0; i < argTypes.length; i++) { Set> types = @@ -293,16 +293,6 @@ public Object call(Object... args) { return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args)); } - @Override - public List getParameters() { - return parameters; - } - - @Override - public int getNumParameters() { - return parameters.size(); - } - @Override public boolean isVectorized() { return vectorized; @@ -329,8 +319,8 @@ public void addChunkArgument(ChunkArgument chunkArgument) { } @Override - public Class getReturnType() { - return returnType; + public Signature getSignature() { + return signature; } } diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java index a0fb1278949..70ea8a57f56 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java @@ -45,16 +45,6 @@ public Object call(Object... args) { throw new UnsupportedOperationException(); } - @Override - public List getParameters() { - return new ArrayList<>(); - } - - @Override - public int getNumParameters() { - return 0; - } - public List> getParamTypes() { return parameterTypes; } @@ -81,8 +71,8 @@ public void initializeChunkArguments() {} public void addChunkArgument(ChunkArgument ignored) {} @Override - public Class getReturnType() { - return Object.class; + public Signature getSignature() { + return null; } @Override diff --git a/py/server/tests/test_vectorization.py b/py/server/tests/test_vectorization.py index 7632be171e7..627ca44eb94 100644 --- a/py/server/tests/test_vectorization.py +++ b/py/server/tests/test_vectorization.py @@ -253,6 +253,7 @@ def my_sum(*args): source = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols]) result = source.update(f"X = my_sum({','.join(cols)})") self.assertEqual(len(cols) + 1, len(result.columns)) + self.assertEqual(_udf.vectorized_count, 0) def test_enclosed_by_parentheses(self): def sinc(x) -> np.double: From b6d8cfca31d0900e4b3828357881e8047a6835ba Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Fri, 22 Mar 2024 11:59:32 -0600 Subject: [PATCH 06/10] Code clean up/add more comments/test cases --- .../table/impl/lang/QueryLanguageParser.java | 32 -------------- .../engine/util/PyCallableWrapperJpyImpl.java | 43 ++++++++++++++----- py/server/deephaven/_udf.py | 9 +--- py/server/tests/test_udf_args.py | 9 ++++ .../tests/test_udf_return_java_values.py | 12 +++++- 5 files changed, 53 insertions(+), 52 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 0329488adda..83109869bad 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -1968,38 +1968,6 @@ public static boolean isWideningPrimitiveConversion(Class original, Class return false; } - public static boolean isLosslessWideningPrimitiveConversion(Class original, Class target) { - if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() - || original.equals(void.class) || target.equals(void.class)) { - throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); - } - - if (original.equals(target)) { - return true; - } - - LanguageParserPrimitiveType originalEnum = LanguageParserPrimitiveType.getPrimitiveType(original); - - switch (originalEnum) { - case BytePrimitive: - if (target == short.class) - return true; - case ShortPrimitive: - case CharPrimitive: // char is unsigned, so it's a lossless conversion to int - if (target == int.class) // this covers all the smaller integer types - return true; - case IntPrimitive: - if (target == long.class) - return true; - break; - case FloatPrimitive: - if (target == double.class) - return true; - break; - } - - return false; - } private enum LanguageParserPrimitiveType { // Including "Enum" (or really, any differentiating string) in these names is important. They're used diff --git a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java index a420c0eed27..84c75748aed 100644 --- a/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java +++ b/engine/table/src/main/java/io/deephaven/engine/util/PyCallableWrapperJpyImpl.java @@ -13,7 +13,6 @@ import java.util.*; import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.NULL_CLASS; -import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isLosslessWideningPrimitiveConversion; import static io.deephaven.util.type.TypeUtils.getUnboxedType; /** @@ -51,7 +50,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper { numpyType2JavaArrayClass.put('l', long[].class); numpyType2JavaArrayClass.put('f', float[].class); numpyType2JavaArrayClass.put('d', double[].class); - numpyType2JavaArrayClass.put('?', boolean[].class); + numpyType2JavaArrayClass.put('?', Boolean[].class); numpyType2JavaArrayClass.put('U', String[].class); numpyType2JavaArrayClass.put('M', Instant[].class); numpyType2JavaArrayClass.put('O', Object[].class); @@ -200,8 +199,8 @@ public void parseSignature() { List parameters = new ArrayList<>(); if (!pyEncodedParamsStr.isEmpty()) { String[] pyEncodedParams = pyEncodedParamsStr.split(","); - for (int i = 0; i < pyEncodedParams.length; i++) { - String[] paramDetail = pyEncodedParams[i].split(":"); + for (String pyEncodedParam : pyEncodedParams) { + String[] paramDetail = pyEncodedParam.split(":"); String paramName = paramDetail[0]; String paramTypeCodes = paramDetail[1]; Set> possibleTypes = new HashSet<>(); @@ -211,9 +210,6 @@ public void parseSignature() { // skip the array type code ti++; possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti))); - if (paramTypeCodes.charAt(ti) == '?') { - possibleTypes.add(Boolean[].class); - } } else if (typeCode == 'N') { possibleTypes.add(NULL_CLASS); } else { @@ -250,14 +246,40 @@ private boolean isSafelyCastable(Set> types, Class type) { return false; } + public static boolean isLosslessWideningPrimitiveConversion(Class original, Class target) { + if (original == null || !original.isPrimitive() || target == null || !target.isPrimitive() + || original.equals(void.class) || target.equals(void.class)) { + throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!"); + } + + if (original.equals(target)) { + return true; + } + + if (original.equals(byte.class)) { + return target == short.class || target == int.class || target == long.class; + } else if (original.equals(short.class) || original.equals(char.class)) { // char is unsigned, so it's a + // lossless conversion to int + return target == int.class || target == long.class; + } else if (original.equals(int.class)) { + return target == long.class; + } else if (original.equals(float.class)) { + return target == double.class; + } + + return false; + } public void verifyArguments(Class[] argTypes) { String callableName = pyCallable.getAttribute("__name__").toString(); List parameters = signature.getParameters(); for (int i = 0; i < argTypes.length; i++) { + // if there are more arguments than parameters, we'll need to consider the last parameter as a varargs + // parameter. This is not ideal. We should consider a better way to handle this, i.e. a way to convey that + // the function is variadic. Set> types = - parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes(); + parameters.get(Math.min(i, parameters.size() - 1)).getPossibleTypes(); // to prevent the unpacking of an array column when calling a Python function, we prefix the column accessor // with a cast to generic Object type, until we can find a way to convey that info, we'll just skip the @@ -267,15 +289,14 @@ public void verifyArguments(Class[] argTypes) { } Class t = getUnboxedType(argTypes[i]) == null ? argTypes[i] : getUnboxedType(argTypes[i]); - if (!types.contains(t) && !types.contains(Object.class) - && !isSafelyCastable(types, t)) { + if (!types.contains(t) && !types.contains(Object.class) && !isSafelyCastable(types, t)) { throw new IllegalArgumentException( callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of " + parameters.get(i).getPossibleTypes() + ", got " + (argTypes[i].equals(NULL_CLASS) ? "null" : argTypes[i])); } } - }; + } // In vectorized mode, we want to call the vectorized function directly. public PyObject vectorizedCallable() { diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index 99708161a09..e1c5a5a3971 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -113,7 +113,7 @@ def _component_np_dtype_char(t: type) -> Optional[str]: numpy ndarray, otherwise return None. """ component_type = None - if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: + if sys.version_info > (3, 8): import types if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence): component_type = t.__args__[0] @@ -175,10 +175,7 @@ def _is_union_type(t: type) -> bool: if isinstance(t, types.UnionType): return True - if isinstance(t, _GenericAlias) and t.__origin__ == Union: - return True - - return False + return isinstance(t, _GenericAlias) and t.__origin__ == Union def _parse_param(name: str, annotation: Any) -> _ParsedParam: @@ -480,7 +477,6 @@ def _py_udf(fn: Callable): # build a signature string for vectorization by removing NoneType, array char '[', and comma from the encoded types # since vectorization only supports UDFs with a single signature and enforces an exact match, any non-compliant # signature (e.g. Union with more than 1 non-NoneType) will be rejected by the vectorizer. - sig_str_vectorization = re.sub(r"[\[N,]", "", p_sig.encoded) return_array = p_sig.ret_annotation.has_array ret_dtype = dtypes.from_np_dtype(np.dtype(p_sig.ret_annotation.encoded_type[-1])) @@ -505,7 +501,6 @@ def wrapper(*args, **kwargs): j_class = real_ret_dtype.qst_type.clazz() wrapper.return_type = j_class - # wrapper.signature = sig_str_vectorization wrapper.signature = p_sig.encoded return wrapper diff --git a/py/server/tests/test_udf_args.py b/py/server/tests/test_udf_args.py index 8e58050a9d6..9e1c4a1cfc5 100644 --- a/py/server/tests/test_udf_args.py +++ b/py/server/tests/test_udf_args.py @@ -486,6 +486,15 @@ def f(x: bytearray) -> bool: t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"]) self.assertEqual(t.columns[2].data_type, dtypes.bool_) + def test_non_common_cases(self): + def f1(x: int) -> float: + ... + + def f2(x: float) -> int: + ... + + t = empty_table(1).update("X = f2(f1(ii))") + self.assertEqual(t.columns[0].data_type, dtypes.int_) if __name__ == "__main__": unittest.main() diff --git a/py/server/tests/test_udf_return_java_values.py b/py/server/tests/test_udf_return_java_values.py index d42b6f8465c..6d897195e21 100644 --- a/py/server/tests/test_udf_return_java_values.py +++ b/py/server/tests/test_udf_return_java_values.py @@ -54,8 +54,7 @@ def test_array_return(self): "np.str_": dtypes.string_array, "np.uint16": dtypes.char_array, } - # container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] - container_types = ["list"] + container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"] for component_type, dh_dtype in component_types.items(): for container_type in container_types: with self.subTest(component_type=component_type, container_type=container_type): @@ -68,6 +67,15 @@ def test_array_return(self): t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") self.assertEqual(t.columns[2].data_type, dh_dtype) + container_types = ["bytes", "bytearray"] + for container_type in container_types: + with self.subTest(container_type=container_type): + func_decl_str = f"""def fn(col) -> {container_type}:""" + func_body_str = f""" return {container_type}(col)""" + exec("\n".join([func_decl_str, func_body_str]), globals()) + t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)") + self.assertEqual(t.columns[2].data_type, dtypes.byte_array) + def test_scalar_return_class_method_not_supported(self): for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items(): with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype): From 41d12f55414aa859eb6ecaf3c59677ed33f5fd11 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Fri, 22 Mar 2024 13:03:25 -0600 Subject: [PATCH 07/10] Fix a test failure --- .../engine/table/impl/lang/PyCallableWrapperDummyImpl.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java index 70ea8a57f56..6943a880769 100644 --- a/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java +++ b/engine/table/src/test/java/io/deephaven/engine/table/impl/lang/PyCallableWrapperDummyImpl.java @@ -72,7 +72,7 @@ public void addChunkArgument(ChunkArgument ignored) {} @Override public Signature getSignature() { - return null; + return new Signature(new ArrayList<>(), Void.class); } @Override From f69f7c4c5c38c061072ab49c028d5294c666d12b Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Fri, 22 Mar 2024 13:43:58 -0600 Subject: [PATCH 08/10] Skip vermin check for explictly ver checked code --- py/server/deephaven/_udf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index e1c5a5a3971..c34e5d2614c 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -115,7 +115,7 @@ def _component_np_dtype_char(t: type) -> Optional[str]: if sys.version_info > (3, 8): import types - if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence): + if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence): # novermin component_type = t.__args__[0] if not component_type: @@ -172,7 +172,7 @@ def _is_union_type(t: type) -> bool: """Return True if the type is a Union type""" if sys.version_info.major == 3 and sys.version_info.minor >= 10: import types - if isinstance(t, types.UnionType): + if isinstance(t, types.UnionType): # novermin return True return isinstance(t, _GenericAlias) and t.__origin__ == Union From b276280baa8bc8029bae78c6716578342a10bf41 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Sun, 24 Mar 2024 19:19:12 -0600 Subject: [PATCH 09/10] More novermin check --- py/server/deephaven/_udf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/py/server/deephaven/_udf.py b/py/server/deephaven/_udf.py index c34e5d2614c..fb8b77704b7 100644 --- a/py/server/deephaven/_udf.py +++ b/py/server/deephaven/_udf.py @@ -156,14 +156,14 @@ def _np_ndarray_component_type(t: type) -> Optional[type]: # when np.ndarray is used, the 1st argument is the component type if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8: import types - if isinstance(t, types.GenericAlias) and t.__origin__ == np.ndarray: + if isinstance(t, types.GenericAlias) and t.__origin__ == np.ndarray: # novermin nargs = len(t.__args__) if nargs == 1: component_type = t.__args__[0] elif nargs == 2: # for npt.NDArray[np.int64], etc. a0 = t.__args__[0] a1 = t.__args__[1] - if a0 == typing.Any and isinstance(a1, types.GenericAlias): + if a0 == typing.Any and isinstance(a1, types.GenericAlias): # novermin component_type = a1.__args__[0] return component_type From 31a33a738306749e57edbb421399d3b76e9d4362 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Sun, 24 Mar 2024 20:42:46 -0600 Subject: [PATCH 10/10] Make comment more clear and new test case --- .../table/impl/lang/QueryLanguageParser.java | 5 +++-- py/server/tests/test_udf_args.py | 21 ++++++++++++++++++- 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 83109869bad..ae037b986d6 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -2511,8 +2511,9 @@ private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class final String invokedMethodName = n.getNameAsString(); if (GET_ATTRIBUTE_METHOD_NAME.equals(invokedMethodName)) { - // Only PyCallableWrapper.getAttribute()/PyCallableWrapper.call() may be invoked from the query language. - // UDF type checks are not currently supported for getAttribute() calls. + // Currently Python UDF handling is only supported for top module level function(callable) calls. + // The getAttribute() calls which is needed to support Python method calls, which is beyond the scope of + // current implementation. So we are skipping the argument verification for getAttribute() calls. return; } if (!n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) { diff --git a/py/server/tests/test_udf_args.py b/py/server/tests/test_udf_args.py index 9e1c4a1cfc5..0f8f875ffa8 100644 --- a/py/server/tests/test_udf_args.py +++ b/py/server/tests/test_udf_args.py @@ -8,7 +8,8 @@ import numpy as np import numpy.typing as npt -from deephaven import empty_table, DHError, dtypes +from deephaven import empty_table, DHError, dtypes, new_table +from deephaven.column import int_col from deephaven.dtypes import double_array, int32_array, long_array, int16_array, char_array, int8_array, \ float32_array from tests.testbase import BaseTestCase @@ -496,5 +497,23 @@ def f2(x: float) -> int: t = empty_table(1).update("X = f2(f1(ii))") self.assertEqual(t.columns[0].data_type, dtypes.int_) + def test_varargs(self): + cols = ["A", "B", "C", "D"] + + def my_sum(p1: np.int32, *args: np.int64) -> int: + return sum(args) + + t = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols]) + result = t.update(f"X = my_sum({','.join(cols)})") + self.assertEqual(result.columns[4].data_type, dtypes.int64) + + def my_sum_error(p1: np.int32, *args: np.int16) -> int: + return sum(args) + with self.assertRaises(DHError) as cm: + t.update(f"X = my_sum_error({','.join(cols)})") + self.assertRegex(str(cm.exception), "my_sum_error: Expected argument .* got int") + + + if __name__ == "__main__": unittest.main()