Skip to content

Commit

Permalink
Code clean up/add more comments/test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Mar 22, 2024
1 parent 835a4a1 commit 1f9167e
Show file tree
Hide file tree
Showing 5 changed files with 53 additions and 52 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1968,38 +1968,6 @@ public static boolean isWideningPrimitiveConversion(Class<?> original, Class<?>
return false;
}

/**
 * Reports whether a value of primitive type {@code original} can be converted to primitive type {@code target}
 * through one of the widening conversions this method treats as lossless.
 * <p>
 * Identical types are always accepted. Otherwise the accepted conversions are: {@code byte} to {@code short},
 * {@code int} or {@code long}; {@code short} or {@code char} to {@code int} or {@code long}; {@code int} to
 * {@code long}; and {@code float} to {@code double}. Every other pairing yields {@code false}.
 *
 * @param original the primitive type being converted from
 * @param target the primitive type being converted to
 * @return {@code true} if the conversion is one of the accepted lossless widenings (or the types are equal)
 * @throws IllegalArgumentException if either argument is {@code null}, is not a primitive type, or is
 *         {@code void}
 */
public static boolean isLosslessWideningPrimitiveConversion(Class<?> original, Class<?> target) {
    final boolean originalUsable = original != null && original.isPrimitive() && !original.equals(void.class);
    final boolean targetUsable = target != null && target.isPrimitive() && !target.equals(void.class);
    if (!originalUsable || !targetUsable) {
        throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!");
    }

    if (original.equals(target)) {
        return true;
    }

    // Each case spells out its full set of accepted targets instead of relying on case fall-through.
    switch (LanguageParserPrimitiveType.getPrimitiveType(original)) {
        case BytePrimitive:
            return target == short.class || target == int.class || target == long.class;
        case ShortPrimitive:
        case CharPrimitive:
            // char is unsigned, so widening it to int (or long) cannot lose information
            return target == int.class || target == long.class;
        case IntPrimitive:
            return target == long.class;
        case FloatPrimitive:
            return target == double.class;
        default:
            return false;
    }
}

private enum LanguageParserPrimitiveType {
// Including "Enum" (or really, any differentiating string) in these names is important. They're used
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import java.util.*;

import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.NULL_CLASS;
import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isLosslessWideningPrimitiveConversion;
import static io.deephaven.util.type.TypeUtils.getUnboxedType;

/**
Expand Down Expand Up @@ -51,7 +50,7 @@ public class PyCallableWrapperJpyImpl implements PyCallableWrapper {
numpyType2JavaArrayClass.put('l', long[].class);
numpyType2JavaArrayClass.put('f', float[].class);
numpyType2JavaArrayClass.put('d', double[].class);
numpyType2JavaArrayClass.put('?', boolean[].class);
numpyType2JavaArrayClass.put('?', Boolean[].class);
numpyType2JavaArrayClass.put('U', String[].class);
numpyType2JavaArrayClass.put('M', Instant[].class);
numpyType2JavaArrayClass.put('O', Object[].class);
Expand Down Expand Up @@ -200,8 +199,8 @@ public void parseSignature() {
List<Parameter> parameters = new ArrayList<>();
if (!pyEncodedParamsStr.isEmpty()) {
String[] pyEncodedParams = pyEncodedParamsStr.split(",");
for (int i = 0; i < pyEncodedParams.length; i++) {
String[] paramDetail = pyEncodedParams[i].split(":");
for (String pyEncodedParam : pyEncodedParams) {
String[] paramDetail = pyEncodedParam.split(":");
String paramName = paramDetail[0];
String paramTypeCodes = paramDetail[1];
Set<Class<?>> possibleTypes = new HashSet<>();
Expand All @@ -211,9 +210,6 @@ public void parseSignature() {
// skip the array type code
ti++;
possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti)));
if (paramTypeCodes.charAt(ti) == '?') {
possibleTypes.add(Boolean[].class);
}
} else if (typeCode == 'N') {
possibleTypes.add(NULL_CLASS);
} else {
Expand Down Expand Up @@ -250,14 +246,40 @@ private boolean isSafelyCastable(Set<Class<?>> types, Class<?> type) {
return false;
}

/**
 * Reports whether a value of primitive type {@code original} can be converted to primitive type {@code target}
 * through one of the widening conversions this method treats as lossless.
 * <p>
 * Identical types are always accepted. Otherwise the accepted conversions are: {@code byte} to {@code short},
 * {@code int} or {@code long}; {@code short} or {@code char} to {@code int} or {@code long}; {@code int} to
 * {@code long}; and {@code float} to {@code double}. Every other pairing yields {@code false}.
 *
 * @param original the primitive type being converted from
 * @param target the primitive type being converted to
 * @return {@code true} if the conversion is one of the accepted lossless widenings (or the types are equal)
 * @throws IllegalArgumentException if either argument is {@code null}, is not a primitive type, or is
 *         {@code void}
 */
public static boolean isLosslessWideningPrimitiveConversion(Class<?> original, Class<?> target) {
    final boolean originalUsable = original != null && original.isPrimitive() && !original.equals(void.class);
    final boolean targetUsable = target != null && target.isPrimitive() && !target.equals(void.class);
    if (!originalUsable || !targetUsable) {
        throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!");
    }

    if (original.equals(target)) {
        return true;
    }

    final boolean toLong = target == long.class;
    final boolean toIntOrLong = toLong || target == int.class;

    if (original.equals(byte.class)) {
        return toIntOrLong || target == short.class;
    }
    // char is unsigned, so widening it to int (or long) cannot lose information
    if (original.equals(short.class) || original.equals(char.class)) {
        return toIntOrLong;
    }
    if (original.equals(int.class)) {
        return toLong;
    }
    return original.equals(float.class) && target == double.class;
}

public void verifyArguments(Class<?>[] argTypes) {
String callableName = pyCallable.getAttribute("__name__").toString();
List<Parameter> parameters = signature.getParameters();

for (int i = 0; i < argTypes.length; i++) {
// if there are more arguments than parameters, we'll need to consider the last parameter as a varargs
// parameter. This is not ideal. We should consider a better way to handle this, i.e. a way to convey that
// the function is variadic.
Set<Class<?>> types =
parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes();
parameters.get(Math.min(i, parameters.size() - 1)).getPossibleTypes();

// to prevent the unpacking of an array column when calling a Python function, we prefix the column accessor
// with a cast to generic Object type, until we can find a way to convey that info, we'll just skip the
Expand All @@ -267,15 +289,14 @@ public void verifyArguments(Class<?>[] argTypes) {
}

Class<?> t = getUnboxedType(argTypes[i]) == null ? argTypes[i] : getUnboxedType(argTypes[i]);
if (!types.contains(t) && !types.contains(Object.class)
&& !isSafelyCastable(types, t)) {
if (!types.contains(t) && !types.contains(Object.class) && !isSafelyCastable(types, t)) {
throw new IllegalArgumentException(
callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of "
+ parameters.get(i).getPossibleTypes() + ", got "
+ (argTypes[i].equals(NULL_CLASS) ? "null" : argTypes[i]));
}
}
};
}

// In vectorized mode, we want to call the vectorized function directly.
public PyObject vectorizedCallable() {
Expand Down
9 changes: 2 additions & 7 deletions py/server/deephaven/_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def _component_np_dtype_char(t: type) -> Optional[str]:
numpy ndarray, otherwise return None. """
component_type = None

if not component_type and sys.version_info.major == 3 and sys.version_info.minor > 8:
if sys.version_info > (3, 8):
import types
if isinstance(t, types.GenericAlias) and issubclass(t.__origin__, Sequence):
component_type = t.__args__[0]
Expand Down Expand Up @@ -175,10 +175,7 @@ def _is_union_type(t: type) -> bool:
if isinstance(t, types.UnionType):
return True

if isinstance(t, _GenericAlias) and t.__origin__ == Union:
return True

return False
return isinstance(t, _GenericAlias) and t.__origin__ == Union


def _parse_param(name: str, annotation: Any) -> _ParsedParam:
Expand Down Expand Up @@ -480,7 +477,6 @@ def _py_udf(fn: Callable):
# Build a signature string for vectorization by removing NoneType, the array char '[', and commas from the
# encoded types. Since vectorization only supports UDFs with a single signature and enforces an exact match, any
# non-compliant signature (e.g. a Union with more than one non-NoneType member) will be rejected by the vectorizer.
sig_str_vectorization = re.sub(r"[\[N,]", "", p_sig.encoded)
return_array = p_sig.ret_annotation.has_array
ret_dtype = dtypes.from_np_dtype(np.dtype(p_sig.ret_annotation.encoded_type[-1]))

Expand All @@ -505,7 +501,6 @@ def wrapper(*args, **kwargs):
j_class = real_ret_dtype.qst_type.clazz()

wrapper.return_type = j_class
# wrapper.signature = sig_str_vectorization
wrapper.signature = p_sig.encoded

return wrapper
Expand Down
9 changes: 9 additions & 0 deletions py/server/tests/test_udf_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,15 @@ def f(x: bytearray) -> bool:
t = empty_table(1).update(["X = i", "Y = (byte)(ii % 128)"]).group_by("X").update(["Z = f(Y.toArray())"])
self.assertEqual(t.columns[2].data_type, dtypes.bool_)

# Checks the result column type when one annotated UDF's return value is passed directly to another UDF
# inside a single query expression.
def test_non_common_cases(self):
# f1 is a stub (body is `...`) annotated as taking an int and returning a float.
def f1(x: int) -> float:
...

# f2 is a stub (body is `...`) annotated as taking a float and returning an int.
def f2(x: float) -> int:
...

# f1 is applied to `ii` and its result is fed straight into f2 to produce column X.
t = empty_table(1).update("X = f2(f1(ii))")
# The first (and only updated) column is expected to have the int_ data type, matching f2's return annotation.
self.assertEqual(t.columns[0].data_type, dtypes.int_)

if __name__ == "__main__":
unittest.main()
12 changes: 10 additions & 2 deletions py/server/tests/test_udf_return_java_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,7 @@ def test_array_return(self):
"np.str_": dtypes.string_array,
"np.uint16": dtypes.char_array,
}
# container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"]
container_types = ["list"]
container_types = ["List", "Tuple", "list", "tuple", "Sequence", "np.ndarray"]
for component_type, dh_dtype in component_types.items():
for container_type in container_types:
with self.subTest(component_type=component_type, container_type=container_type):
Expand All @@ -68,6 +67,15 @@ def test_array_return(self):
t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)")
self.assertEqual(t.columns[2].data_type, dh_dtype)

container_types = ["bytes", "bytearray"]
for container_type in container_types:
with self.subTest(container_type=container_type):
func_decl_str = f"""def fn(col) -> {container_type}:"""
func_body_str = f""" return {container_type}(col)"""
exec("\n".join([func_decl_str, func_body_str]), globals())
t = empty_table(10).update(["X = i % 3", "Y = i"]).group_by("X").update(f"Z= fn(Y + 1)")
self.assertEqual(t.columns[2].data_type, dtypes.byte_array)

def test_scalar_return_class_method_not_supported(self):
for dh_dtype, np_dtype in _J_TYPE_NP_DTYPE_MAP.items():
with self.subTest(dh_dtype=dh_dtype, np_dtype=np_dtype):
Expand Down

0 comments on commit 1f9167e

Please sign in to comment.