Skip to content

Commit

Permalink
Fix bugs and old test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Mar 18, 2024
1 parent f4419e6 commit 122e000
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ public final class QueryLanguageParser extends GenericVisitorAdapter<Class<?>, Q
private final Map<String, Class<?>> staticImportLookupCache = new HashMap<>();

// We need some class to represent null. We know for certain that this one won't be used...
private static final Class<?> NULL_CLASS = QueryLanguageParser.class;
public static final Class<?> NULL_CLASS = QueryLanguageParser.class;

/**
* The result of the QueryLanguageParser for the expression passed to the constructor.
Expand Down Expand Up @@ -1968,6 +1968,39 @@ public static boolean isWideningPrimitiveConversion(Class<?> original, Class<?>
return false;
}

/**
 * Tests whether a value of primitive type {@code original} may be converted to primitive type {@code target}
 * under the restricted set of widening conversions accepted by this method.
 *
 * <p>
 * Accepted conversions are:
 * <ul>
 * <li>any type to itself</li>
 * <li>{@code byte} to {@code short}, {@code int} or {@code long}</li>
 * <li>{@code short} or {@code char} to {@code int} or {@code long}</li>
 * <li>{@code int} to {@code long}</li>
 * <li>{@code float} to {@code double}</li>
 * </ul>
 * Every other pairing returns {@code false}; in particular, no integral type is accepted as convertible to
 * {@code float} or {@code double}.
 *
 * @param original the primitive type being converted from
 * @param target the primitive type being converted to
 * @return {@code true} if the conversion is one of those listed above, {@code false} otherwise
 * @throws IllegalArgumentException if either argument is {@code null}, is not a primitive type, or is
 *         {@code void.class}
 */
public static boolean isLosslessWideningPrimitiveConversion(Class<?> original, Class<?> target) {
    final boolean originalUsable = original != null && original.isPrimitive() && !original.equals(void.class);
    final boolean targetUsable = target != null && target.isPrimitive() && !target.equals(void.class);
    if (!originalUsable || !targetUsable) {
        throw new IllegalArgumentException("Arguments must be a primitive type (excluding void)!");
    }

    // Identity is always acceptable.
    if (original.equals(target)) {
        return true;
    }

    // Each source type lists its accepted targets explicitly rather than relying on case fall-through.
    switch (LanguageParserPrimitiveType.getPrimitiveType(original)) {
        case BytePrimitive:
            return target == short.class || target == int.class || target == long.class;
        case ShortPrimitive:
        case CharPrimitive:
            return target == int.class || target == long.class;
        case IntPrimitive:
            return target == long.class;
        case FloatPrimitive:
            return target == double.class;
        default:
            // Remaining source types have no accepted target other than themselves.
            return false;
    }
}

private enum LanguageParserPrimitiveType {
// Including "Enum" (or really, any differentiating string) in these names is important. They're used
// in a switch() statement, which apparently does not support qualified names. And we can't use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//
package io.deephaven.engine.util;

import io.deephaven.engine.table.impl.lang.QueryLanguageParser;
import io.deephaven.engine.table.impl.select.python.ArgumentsChunked;
import io.deephaven.internal.log.LoggerFactory;
import io.deephaven.io.logger.Logger;
Expand All @@ -12,7 +13,9 @@
import java.time.Instant;
import java.util.*;

import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isWideningPrimitiveConversion;
import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.NULL_CLASS;
import static io.deephaven.engine.table.impl.lang.QueryLanguageParser.isLosslessWideningPrimitiveConversion;
import static io.deephaven.util.type.TypeUtils.getUnboxedType;

/**
* When given a pyObject that is a callable, we stick it inside the callable wrapper, which implements a call() varargs
Expand Down Expand Up @@ -228,6 +231,11 @@ public void parseSignature() {
// skip the array type code
ti++;
possibleTypes.add(numpyType2JavaArrayClass.get(paramTypeCodes.charAt(ti)));
if (paramTypeCodes.charAt(ti) == '?') {
possibleTypes.add(Boolean[].class);
}
} else if (typeCode == 'N') {
possibleTypes.add(NULL_CLASS);
} else {
possibleTypes.add(numpyType2JavaClass.get(typeCode));
}
Expand All @@ -252,7 +260,7 @@ private boolean isSafelyCastable(Set<Class<?>> types, Class<?> type) {
if (t.isAssignableFrom(type)) {
return true;
}
if (t.isPrimitive() && type.isPrimitive() && isWideningPrimitiveConversion(type, t)) {
if (t.isPrimitive() && type.isPrimitive() && isLosslessWideningPrimitiveConversion(type, t)) {
return true;
}
}
Expand All @@ -263,18 +271,25 @@ private boolean isSafelyCastable(Set<Class<?>> types, Class<?> type) {
public void verifyArguments(Class<?>[] argTypes) {
String callableName = pyCallable.getAttribute("__name__").toString();

if (argTypes.length != parameters.size()) {
throw new IllegalArgumentException(
callableName + ": " + "Expected " + parameters.size() + " arguments, got " + argTypes.length);
}
// if (argTypes.length > parameters.size()) {
// throw new IllegalArgumentException(
// callableName + ": " + "Expected " + parameters.size() + " or fewer arguments, got " + argTypes.length);
// }
for (int i = 0; i < argTypes.length; i++) {
Set<Class<?>> types = parameters.get(i).getPossibleTypes();
if (!types.contains(argTypes[i]) && !types.contains(Object.class)
&& !isSafelyCastable(types, argTypes[i])) {
Set<Class<?>> types =
parameters.get(i > parameters.size() - 1 ? parameters.size() - 1 : i).getPossibleTypes();
// Object is a catch-all type, so we don't need to check for it
if (argTypes[i] == Object.class) {
continue;
}

Class<?> t = getUnboxedType(argTypes[i]) == null ? argTypes[i] : getUnboxedType(argTypes[i]);
if (!types.contains(t) && !types.contains(Object.class)
&& !isSafelyCastable(types, t)) {
throw new IllegalArgumentException(
callableName + ": " + "Expected argument (" + parameters.get(i).getName() + ") to be one of "
+ parameters.get(i).getPossibleTypes() + ", got "
+ argTypes[i]);
+ (argTypes[i].equals(NULL_CLASS) ? "null" : argTypes[i]));
}
}
};
Expand Down
50 changes: 37 additions & 13 deletions py/server/tests/test_udf_numpy_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ def f11(p1: Union[float, np.float32]) -> bool:
with self.assertRaises(DHError) as cm:
t = empty_table(10).update(["X1 = f11(i)"])

def f2(p1: Union[np.int16, np.float64]) -> Union[Optional[bool]]:
def f2(p1: Union[np.int32, np.float64]) -> Union[Optional[bool]]:
return bool(p1)

t = empty_table(10).update(["X1 = f2(i)"])
Expand All @@ -231,7 +231,7 @@ def f21(p1: Union[np.int16, np.float64]) -> Union[Optional[bool], int]:
with self.assertRaises(DHError) as cm:
t = empty_table(10).update(["X1 = f21(i)"])

def f3(p1: Union[np.int16, np.float64], p2=None) -> bool:
def f3(p1: Union[np.int32, np.float64], p2=None) -> bool:
return bool(p1)

t = empty_table(10).update(["X1 = f3(i)"])
Expand All @@ -244,7 +244,7 @@ def f4(p1: Union[np.int16, np.float64], p2=None) -> bool:
self.assertEqual(t.columns[0].data_type, dtypes.bool_)
with self.assertRaises(DHError) as cm:
t = empty_table(10).update(["X1 = f4(now())"])
self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*")
self.assertRegex(str(cm.exception), "f4: Expected .* got .*Instant")

def f41(p1: Union[np.int16, np.float64, Union[Any]], p2=None) -> bool:
return bool(p1)
Expand All @@ -266,7 +266,7 @@ def f5(col1, col2: np.ndarray[np.int32]) -> bool:
t = t.update(["X1 = f5(X, Y)"])
with self.assertRaises(DHError) as cm:
t = t.update(["X1 = f5(X, null)"])
self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*")
self.assertRegex(str(cm.exception), "f5: Expected .* got null")

def f51(col1, col2: Optional[np.ndarray[np.int32]]) -> bool:
return np.nanmean(col2) == np.mean(col2)
Expand All @@ -287,7 +287,7 @@ def f6(*args: np.int32, col2: np.ndarray[np.int32]) -> bool:

with self.assertRaises(DHError) as cm:
t1 = t.update(["X1 = f6(X, Y=null)"])
self.assertIn("not compatible with annotation", str(cm.exception))
self.assertIn("f6: Expected argument (col2) to be one of [class [I], got boolean", str(cm.exception))

def test_str_bool_datetime_array(self):
with self.subTest("str"):
Expand All @@ -299,7 +299,7 @@ def f1(p1: np.ndarray[str], p2=None) -> bool:
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
with self.assertRaises(DHError) as cm:
t2 = t.update(["X1 = f1(null, Y )"])
self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*")
self.assertRegex(str(cm.exception), "f1: Expected .* got null")

def f11(p1: Union[np.ndarray[str], None], p2=None) -> bool:
return bool(len(p1)) if p1 is not None else False
Expand All @@ -315,7 +315,7 @@ def f2(p1: np.ndarray[np.datetime64], p2=None) -> bool:
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
with self.assertRaises(DHError) as cm:
t2 = t.update(["X1 = f2(null, Y )"])
self.assertRegex(str(cm.exception), "Argument .* is not compatible with annotation*")
self.assertRegex(str(cm.exception), "f2: Expected .* got null")

def f21(p1: Union[np.ndarray[np.datetime64], None], p2=None) -> bool:
return bool(len(p1)) if p1 is not None else False
Expand All @@ -337,7 +337,7 @@ def f3(p1: np.ndarray[np.bool_], p2=None) -> bool:
self.assertEqual(t1.columns[2].data_type, dtypes.bool_)
with self.assertRaises(DHError) as cm:
t2 = t.update(["X1 = f3(null, Y )"])
self.assertRegex(str(cm.exception), "Argument 'p1': None is not compatible with annotation")
self.assertRegex(str(cm.exception), "f3: Expected .* got null")

def f31(p1: Optional[np.ndarray[bool]], p2=None) -> bool:
return bool(len(p1)) if p1 is not None else False
Expand Down Expand Up @@ -405,9 +405,9 @@ def f(x: {p_type}) -> bool: # note typing
t = empty_table(1).update(["X = i", f"Y = f(({p_type})X)"])
self.assertEqual(1, t.to_string(cols="Y").count("true"))


np_int_types = {"np.int8", "np.int16", "np.int32", "np.int64"}
for p_type in np_int_types:
def test_np_typehints(self):
widening_np_int_types = {"np.int32", "np.int64"}
for p_type in widening_np_int_types:
with self.subTest(p_type):
func_str = f"""
def f(x: {p_type}) -> bool: # note typing
Expand All @@ -417,8 +417,20 @@ def f(x: {p_type}) -> bool: # note typing
t = empty_table(1).update(["X = i", f"Y = f(X)"])
self.assertEqual(1, t.to_string(cols="Y").count("true"))

np_floating_types = {"np.float32", "np.float64"}
for p_type in np_floating_types:
narrowing_np_int_types = {"np.int8", "np.int16"}
for p_type in narrowing_np_int_types:
with self.subTest(p_type):
func_str = f"""
def f(x: {p_type}) -> bool: # note typing
return type(x) == {p_type}
"""
exec(func_str, globals())
with self.assertRaises(DHError) as cm:
t = empty_table(1).update(["X = i", f"Y = f(X)"])
self.assertRegex(str(cm.exception), "f: Expect")

widening_np_floating_types = {"np.float32", "np.float64"}
for p_type in widening_np_floating_types:
with self.subTest(p_type):
func_str = f"""
def f(x: {p_type}) -> bool: # note typing
Expand All @@ -428,6 +440,18 @@ def f(x: {p_type}) -> bool: # note typing
t = empty_table(1).update(["X = i", f"Y = f((float)X)"])
self.assertEqual(1, t.to_string(cols="Y").count("true"))

int_to_floating_types = {"np.float32", "np.float64"}
for p_type in int_to_floating_types:
with self.subTest(p_type):
func_str = f"""
def f(x: {p_type}) -> bool: # note typing
return type(x) == {p_type}
"""
exec(func_str, globals())
with self.assertRaises(DHError) as cm:
t = empty_table(1).update(["X = i", f"Y = f(X)"])
self.assertRegex(str(cm.exception), "f: Expect")

def test_np_typehints_mismatch(self):
def f(x: float) -> bool:
return True
Expand Down
2 changes: 1 addition & 1 deletion py/server/tests/test_udf_return_java_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def test_np_ufunc(self):
nbsin = numba.vectorize([numba.float64(numba.float64)])(np.sin)

# this is the workaround that utilizes vectorization and type inference
@numba.vectorize([numba.float64(numba.float64)], nopython=True)
@numba.vectorize([numba.float64(numba.int64)], nopython=True)
def nbsin(x):
return np.sin(x)
t3 = empty_table(10).update(["X3 = nbsin(i)"])
Expand Down
2 changes: 1 addition & 1 deletion py/server/tests/test_vectorization.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def sinc2(x):
self.assertEqual(t.columns[1].data_type, dtypes.PyObject)

def test_optional_annotations(self):
def pyfunc(p1: np.int32, p2: np.int32, p3: Optional[np.int32]) -> Optional[int]:
def pyfunc(p1: np.int32, p2: np.int64, p3: Optional[np.int32]) -> Optional[int]:
total = p1 + p2 + p3
return None if total % 3 == 0 else total

Expand Down

0 comments on commit 122e000

Please sign in to comment.