Skip to content

Commit

Permalink
auto convert Python return val to Java val (#4579)
Browse files Browse the repository at this point in the history
* Auto conv Py UDF returns and numba guvectorized

* Add numba guvectorize tests

* Harden the code for bad type annotations

* More meaningful names

* Clean up test code

* Apply suggestions from code review

Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com>

* Add support for datetime data

* Update py/server/deephaven/dtypes.py

Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com>

* Respond to review comments and array testcases

* Respond to reivew comments

* Respond to new review comments

* new test case for complex(unsupported) typehint

---------

Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com>
  • Loading branch information
jmao-denver and chipkent committed Oct 6, 2023
1 parent 16cb2f4 commit 77587f4
Show file tree
Hide file tree
Showing 11 changed files with 623 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.github.javaparser.ast.comments.BlockComment;
import com.github.javaparser.ast.comments.JavadocComment;
import com.github.javaparser.ast.comments.LineComment;
import com.github.javaparser.ast.comments.Comment;
import com.github.javaparser.ast.expr.ArrayAccessExpr;
import com.github.javaparser.ast.expr.ArrayCreationExpr;
import com.github.javaparser.ast.expr.ArrayInitializerExpr;
Expand Down Expand Up @@ -123,6 +122,7 @@
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.lang.reflect.TypeVariable;
import java.time.Instant;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -1214,7 +1214,7 @@ private Expression[] convertParameters(final Executable executable,
// as a single argument
if (ObjectVector.class.isAssignableFrom(expressionTypes[ei])) {
expressions[ei] = new CastExpr(
new ClassOrInterfaceType("java.lang.Object"),
StaticJavaParser.parseClassOrInterfaceType("java.lang.Object"),
expressions[ei]);
expressionTypes[ei] = Object.class;
} else {
Expand Down Expand Up @@ -2023,7 +2023,7 @@ public Class<?> visit(ConditionalExpr n, VisitArgs printer) {
if (classA == boolean.class && classB == Boolean.class) {
// a little hacky, but this handles the null case where it unboxes. very weird stuff
final Expression uncastExpr = n.getThenExpr();
final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr);
final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr);
n.setThenExpr(castExpr);
// fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr)
uncastExpr.setParentNode(castExpr);
Expand All @@ -2032,7 +2032,7 @@ public Class<?> visit(ConditionalExpr n, VisitArgs printer) {
if (classA == Boolean.class && classB == boolean.class) {
// a little hacky, but this handles the null case where it unboxes. very weird stuff
final Expression uncastExpr = n.getElseExpr();
final CastExpr castExpr = new CastExpr(new ClassOrInterfaceType(null, "Boolean"), uncastExpr);
final CastExpr castExpr = new CastExpr(StaticJavaParser.parseClassOrInterfaceType("Boolean"), uncastExpr);
n.setElseExpr(castExpr);
// fix parent in uncastExpr (it is cleared when it is replaced with the CastExpr)
uncastExpr.setParentNode(castExpr);
Expand Down Expand Up @@ -2159,7 +2159,8 @@ public Class<?> visit(FieldAccessExpr n, VisitArgs printer) {
printer.append(", " + clsName + ".class");

final ClassExpr targetType =
new ClassExpr(new ClassOrInterfaceType(null, printer.pythonCastContext.getSimpleName()));
new ClassExpr(
StaticJavaParser.parseClassOrInterfaceType(printer.pythonCastContext.getSimpleName()));
getAttributeArgs.add(targetType);

// Let's advertise to the caller the cast context type
Expand Down Expand Up @@ -2337,6 +2338,35 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
callMethodCall.setData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS,
new PyCallableDetails(null, methodName));

if (PyCallableWrapper.class.isAssignableFrom(method.getDeclaringClass())) {
final Optional<Class<?>> optionalRetType = pyCallableReturnType(callMethodCall);
if (optionalRetType.isPresent()) {
Class<?> retType = optionalRetType.get();
final Optional<CastExpr> optionalCastExpr =
makeCastExpressionForPyCallable(retType, callMethodCall);
if (optionalCastExpr.isPresent()) {
final CastExpr castExpr = optionalCastExpr.get();
replaceChildExpression(
n.getParentNode().orElseThrow(),
n,
castExpr);

callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).setCasted(true);
try {
return castExpr.accept(this, printer);
} catch (Exception e) {
// exceptions could be thrown by {@link #tryVectorizePythonCallable}
replaceChildExpression(
castExpr.getParentNode().orElseThrow(),
castExpr,
callMethodCall);
callMethodCall.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)
.setCasted(false);
return callMethodCall.accept(this, printer);
}
}
}
}
replaceChildExpression(
n.getParentNode().orElseThrow(),
n,
Expand Down Expand Up @@ -2376,7 +2406,7 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {

final ObjectCreationExpr newPyCallableExpr = new ObjectCreationExpr(
null,
new ClassOrInterfaceType(null, pyCallableWrapperImplName),
StaticJavaParser.parseClassOrInterfaceType(pyCallableWrapperImplName),
NodeList.nodeList(getAttributeCall));

final MethodCallExpr callMethodCall = new MethodCallExpr(
Expand Down Expand Up @@ -2430,6 +2460,60 @@ public Class<?> visit(MethodCallExpr n, VisitArgs printer) {
typeArguments);
}

private Optional<CastExpr> makeCastExpressionForPyCallable(Class<?> retType, MethodCallExpr callMethodCall) {
if (retType.isPrimitive()) {
return Optional.of(new CastExpr(
new PrimitiveType(PrimitiveType.Primitive
.valueOf(retType.getSimpleName().toUpperCase())),
callMethodCall));
} else if (retType.getComponentType() != null) {
final Class<?> componentType = retType.getComponentType();
if (componentType.isPrimitive()) {
ArrayType arrayType;
if (componentType == boolean.class) {
arrayType = new ArrayType(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"));
} else {
arrayType = new ArrayType(new PrimitiveType(PrimitiveType.Primitive
.valueOf(retType.getComponentType().getSimpleName().toUpperCase())));
}
return Optional.of(new CastExpr(arrayType, callMethodCall));
} else if (retType.getComponentType() == String.class || retType.getComponentType() == Boolean.class
|| retType.getComponentType() == Instant.class) {
ArrayType arrayType =
new ArrayType(
StaticJavaParser.parseClassOrInterfaceType(retType.getComponentType().getSimpleName()));
return Optional.of(new CastExpr(arrayType, callMethodCall));
}
} else if (retType == Boolean.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.Boolean"), callMethodCall));
} else if (retType == String.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.lang.String"), callMethodCall));
} else if (retType == Instant.class) {
return Optional
.of(new CastExpr(StaticJavaParser.parseClassOrInterfaceType("java.time.Instant"), callMethodCall));
}

return Optional.empty();
}

private Optional<Class<?>> pyCallableReturnType(@NotNull MethodCallExpr n) {
final PyCallableDetails pyCallableDetails = n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS);
final String pyMethodName = pyCallableDetails.pythonMethodName;
final QueryScope queryScope = ExecutionContext.getContext().getQueryScope();
final Object paramValueRaw = queryScope.readParamValue(pyMethodName, null);
if (paramValueRaw == null) {
return Optional.empty();
}
if (!(paramValueRaw instanceof PyCallableWrapper)) {
return Optional.empty();
}
final PyCallableWrapper pyCallableWrapper = (PyCallableWrapper) paramValueRaw;
pyCallableWrapper.parseSignature();
return Optional.ofNullable(pyCallableWrapper.getReturnType());
}

@NotNull
private static Expression[] getExpressionsArray(final NodeList<Expression> exprNodeList) {
return exprNodeList == null ? new Expression[0]
Expand Down Expand Up @@ -2563,13 +2647,27 @@ private void checkVectorizability(@NotNull final MethodCallExpr n,
// expression evaluation code that expects singular values. This check makes sure that numba/dh vectorized
// functions must be used alone as the entire expression after removing the enclosing parentheses.
Node n1 = n;
boolean autoCastChecked = false;
while (n1.hasParentNode()) {
n1 = n1.getParentNode().orElseThrow();
Class<?> cls = n1.getClass();

if (cls == CastExpr.class) {
throw new PythonCallVectorizationFailure(
"The return values of Python vectorized function can't be cast: " + n1);
if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()) {
autoCastChecked = true;
} else {
throw new PythonCallVectorizationFailure(
"The return values of Python vectorized functions can't be cast: " + n1);
}
} else if (cls == MethodCallExpr.class) {
String methodName = ((MethodCallExpr) n1).getNameAsString();
if (!autoCastChecked && n.getData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS).isCasted()
&& methodName.endsWith("Cast")) {
autoCastChecked = true;
} else {
throw new PythonCallVectorizationFailure(
"Python vectorized function can't be used in another expression: " + n1);
}
} else if (cls != EnclosedExpr.class && cls != WrapperNode.class) {
throw new PythonCallVectorizationFailure(
"Python vectorized function can't be used in another expression: " + n1);
Expand Down Expand Up @@ -3233,6 +3331,17 @@ private static class PyCallableDetails {
@NotNull
private final String pythonMethodName;

@NotNull
private boolean isCasted = false;

public boolean isCasted() {
return isCasted;
}

public void setCasted(boolean casted) {
isCasted = casted;
}

private PyCallableDetails(@Nullable String pythonScopeExpr, @NotNull String pythonMethodName) {
this.pythonScopeExpr = pythonScopeExpr;
this.pythonMethodName = pythonMethodName;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import org.jpy.PyModule;
import org.jpy.PyObject;

import java.time.Instant;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
Expand All @@ -20,6 +21,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private static final Logger log = LoggerFactory.getLogger(PyCallableWrapperJpyImpl.class);

private static final PyObject NUMBA_VECTORIZED_FUNC_TYPE = getNumbaVectorizedFuncType();
private static final PyObject NUMBA_GUVECTORIZED_FUNC_TYPE = getNumbaGUVectorizedFuncType();

private static final PyModule dh_table_module = PyModule.importModule("deephaven.table");

Expand All @@ -34,6 +36,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
numpyType2JavaClass.put('b', byte.class);
numpyType2JavaClass.put('?', boolean.class);
numpyType2JavaClass.put('U', String.class);
numpyType2JavaClass.put('M', Instant.class);
numpyType2JavaClass.put('O', Object.class);
}

Expand All @@ -47,6 +50,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
private Collection<ChunkArgument> chunkArguments;
private boolean numbaVectorized;
private PyObject unwrapped;
private PyObject pyUdfDecoratedCallable;

public PyCallableWrapperJpyImpl(PyObject pyCallable) {
this.pyCallable = pyCallable;
Expand Down Expand Up @@ -91,23 +95,37 @@ private static PyObject getNumbaVectorizedFuncType() {
}
}

private static PyObject getNumbaGUVectorizedFuncType() {
try {
return PyModule.importModule("numba.np.ufunc.gufunc").getAttribute("GUFunc");
} catch (Exception e) {
if (log.isDebugEnabled()) {
log.debug("Numba isn't installed in the Python environment.");
}
return null;
}
}

private void prepareSignature() {
if (pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE)) {
boolean isNumbaVectorized = pyCallable.getType().equals(NUMBA_VECTORIZED_FUNC_TYPE);
boolean isNumbaGUVectorized = pyCallable.equals(NUMBA_GUVECTORIZED_FUNC_TYPE);
if (isNumbaGUVectorized || isNumbaVectorized) {
List<PyObject> params = pyCallable.getAttribute("types").asList();
if (params.isEmpty()) {
throw new IllegalArgumentException(
"numba vectorized function must have an explicit signature: " + pyCallable);
"numba vectorized/guvectorized function must have an explicit signature: " + pyCallable);
}
// numba allows a vectorized function to have multiple signatures
if (params.size() > 1) {
throw new UnsupportedOperationException(
pyCallable
+ " has multiple signatures; this is not currently supported for numba vectorized functions");
+ " has multiple signatures; this is not currently supported for numba vectorized/guvectorized functions");
}
signature = params.get(0).getStringValue();
unwrapped = null;
numbaVectorized = true;
vectorized = true;
unwrapped = pyCallable;
// since vectorization doesn't support array type parameters, don't flag numba guvectorized as vectorized
numbaVectorized = isNumbaVectorized;
vectorized = isNumbaVectorized;
} else if (pyCallable.hasAttribute("dh_vectorized")) {
signature = pyCallable.getAttribute("signature").toString();
unwrapped = pyCallable.getAttribute("callable");
Expand All @@ -119,6 +137,7 @@ private void prepareSignature() {
numbaVectorized = false;
vectorized = false;
}
pyUdfDecoratedCallable = dh_table_module.call("_py_udf", unwrapped);
}

@Override
Expand All @@ -135,14 +154,6 @@ public void parseSignature() {
throw new IllegalStateException("Signature should always be available.");
}

char numpyTypeCode = signature.charAt(signature.length() - 1);
Class<?> returnType = numpyType2JavaClass.get(numpyTypeCode);
if (returnType == null) {
throw new IllegalStateException(
"Vectorized functions should always have an integral, floating point, boolean, String, or Object return type: "
+ numpyTypeCode);
}

List<Class<?>> paramTypes = new ArrayList<>();
for (char numpyTypeChar : signature.toCharArray()) {
if (numpyTypeChar != '-') {
Expand All @@ -159,25 +170,31 @@ public void parseSignature() {
}

this.paramTypes = paramTypes;
if (returnType == Object.class) {
this.returnType = PyObject.class;
} else if (returnType == boolean.class) {

returnType = pyUdfDecoratedCallable.getAttribute("return_type", null);
if (returnType == null) {
throw new IllegalStateException(
"Python functions should always have an integral, floating point, boolean, String, arrays, or Object return type");
}

if (returnType == boolean.class) {
this.returnType = Boolean.class;
} else {
this.returnType = returnType;
}
}

// In vectorized mode, we want to call the vectorized function directly.
public PyObject vectorizedCallable() {
if (numbaVectorized) {
if (numbaVectorized || vectorized) {
return pyCallable;
} else {
return dh_table_module.call("dh_vectorize", unwrapped);
}
}

// In non-vectorized mode, we want to call the udf decorated function or the original function.
@Override
public Object call(Object... args) {
PyObject pyCallable = this.pyUdfDecoratedCallable != null ? this.pyUdfDecoratedCallable : this.pyCallable;
return PythonScopeJpyImpl.convert(pyCallable.callMethod("__call__", args));
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,6 @@ public void addChunkArgument(ChunkArgument ignored) {}

@Override
public Class<?> getReturnType() {
throw new UnsupportedOperationException();
return Object.class;
}
}
28 changes: 3 additions & 25 deletions py/server/deephaven/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,11 @@
from typing import Sequence, Any

import jpy
import numpy as np
import pandas as pd

import deephaven.dtypes as dtypes
from deephaven import DHError, time
from deephaven import DHError
from deephaven.dtypes import DType
from deephaven.time import to_j_instant
from deephaven.dtypes import _instant_array

_JColumnHeader = jpy.get_type("io.deephaven.qst.column.header.ColumnHeader")
_JColumn = jpy.get_type("io.deephaven.qst.column.Column")
Expand Down Expand Up @@ -206,27 +204,7 @@ def datetime_col(name: str, data: Sequence) -> InputColumn:
Returns:
a new input column
"""

# try to convert to numpy array of datetime64 if not already, so that we can call translateArrayLongToInstant on
# it to reduce the number of round trips to the JVM
if not isinstance(data, np.ndarray):
try:
data = np.array([pd.Timestamp(dt).to_numpy() for dt in data], dtype=np.datetime64)
except Exception as e:
...

if isinstance(data, np.ndarray) and data.dtype.kind in ('M', 'i', 'U'):
if data.dtype.kind == 'M':
longs = jpy.array('long', data.astype('datetime64[ns]').astype('int64'))
elif data.dtype.kind == 'i':
longs = jpy.array('long', data.astype('int64'))
else: # data.dtype.kind == 'U'
longs = jpy.array('long', [pd.Timestamp(str(dt)).to_numpy().astype('int64') for dt in data])
data = _JPrimitiveArrayConversionUtility.translateArrayLongToInstant(longs)

if not isinstance(data, dtypes.instant_array.j_type):
data = [to_j_instant(d) for d in data]

data = _instant_array(data)
return InputColumn(name=name, data_type=dtypes.Instant, input_data=data)


Expand Down
Loading

0 comments on commit 77587f4

Please sign in to comment.