Skip to content

Commit

Permalink
Make comment more clear and new test case
Browse files Browse the repository at this point in the history
  • Loading branch information
jmao-denver committed Mar 25, 2024
1 parent b276280 commit 31a33a7
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2511,8 +2511,9 @@ private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class
final String invokedMethodName = n.getNameAsString();

if (GET_ATTRIBUTE_METHOD_NAME.equals(invokedMethodName)) {
// Only PyCallableWrapper.getAttribute()/PyCallableWrapper.call() may be invoked from the query language.
// UDF type checks are not currently supported for getAttribute() calls.
// Currently Python UDF handling is only supported for top module level function(callable) calls.
// The getAttribute() calls which is needed to support Python method calls, which is beyond the scope of
// current implementation. So we are skipping the argument verification for getAttribute() calls.
return;
}
if (!n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) {
Expand Down
21 changes: 20 additions & 1 deletion py/server/tests/test_udf_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import numpy as np
import numpy.typing as npt

from deephaven import empty_table, DHError, dtypes
from deephaven import empty_table, DHError, dtypes, new_table
from deephaven.column import int_col
from deephaven.dtypes import double_array, int32_array, long_array, int16_array, char_array, int8_array, \
float32_array
from tests.testbase import BaseTestCase
Expand Down Expand Up @@ -496,5 +497,23 @@ def f2(x: float) -> int:
t = empty_table(1).update("X = f2(f1(ii))")
self.assertEqual(t.columns[0].data_type, dtypes.int_)

def test_varargs(self):
cols = ["A", "B", "C", "D"]

def my_sum(p1: np.int32, *args: np.int64) -> int:
return sum(args)

t = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols])
result = t.update(f"X = my_sum({','.join(cols)})")
self.assertEqual(result.columns[4].data_type, dtypes.int64)

def my_sum_error(p1: np.int32, *args: np.int16) -> int:
return sum(args)
with self.assertRaises(DHError) as cm:
t.update(f"X = my_sum_error({','.join(cols)})")
self.assertRegex(str(cm.exception), "my_sum_error: Expected argument .* got int")



if __name__ == "__main__":
unittest.main()

0 comments on commit 31a33a7

Please sign in to comment.