diff --git a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java index 83109869bad..ae037b986d6 100644 --- a/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java +++ b/engine/table/src/main/java/io/deephaven/engine/table/impl/lang/QueryLanguageParser.java @@ -2511,8 +2511,9 @@ private void verifyPyCallableArguments(@NotNull MethodCallExpr n, @NotNull Class final String invokedMethodName = n.getNameAsString(); if (GET_ATTRIBUTE_METHOD_NAME.equals(invokedMethodName)) { - // Only PyCallableWrapper.getAttribute()/PyCallableWrapper.call() may be invoked from the query language. - // UDF type checks are not currently supported for getAttribute() calls. + // Currently Python UDF handling is only supported for top module level function(callable) calls. + // The getAttribute() calls which is needed to support Python method calls, which is beyond the scope of + // current implementation. So we are skipping the argument verification for getAttribute() calls. return; } if (!n.containsData(QueryLanguageParserDataKeys.PY_CALLABLE_DETAILS)) { diff --git a/py/server/tests/test_udf_args.py b/py/server/tests/test_udf_args.py index 9e1c4a1cfc5..0f8f875ffa8 100644 --- a/py/server/tests/test_udf_args.py +++ b/py/server/tests/test_udf_args.py @@ -8,7 +8,8 @@ import numpy as np import numpy.typing as npt -from deephaven import empty_table, DHError, dtypes +from deephaven import empty_table, DHError, dtypes, new_table +from deephaven.column import int_col from deephaven.dtypes import double_array, int32_array, long_array, int16_array, char_array, int8_array, \ float32_array from tests.testbase import BaseTestCase @@ -496,5 +497,23 @@ def f2(x: float) -> int: t = empty_table(1).update("X = f2(f1(ii))") self.assertEqual(t.columns[0].data_type, dtypes.int_) + def test_varargs(self): + cols = ["A", "B", "C", "D"] + + def my_sum(p1: np.int32, *args: np.int64) -> int: + return sum(args) + + t = new_table([int_col(c, [0, 1, 2, 3, 4, 5, 6]) for c in cols]) + result = t.update(f"X = my_sum({','.join(cols)})") + self.assertEqual(result.columns[4].data_type, dtypes.int64) + + def my_sum_error(p1: np.int32, *args: np.int16) -> int: + return sum(args) + with self.assertRaises(DHError) as cm: + t.update(f"X = my_sum_error({','.join(cols)})") + self.assertRegex(str(cm.exception), "my_sum_error: Expected argument .* got int") + + + if __name__ == "__main__": unittest.main()