From c301ea680df70bdbfdaed9732fd90a2dc8899694 Mon Sep 17 00:00:00 2001 From: arman-ddl <122062464+arman-ddl@users.noreply.github.com> Date: Thu, 21 Mar 2024 09:16:15 -0700 Subject: [PATCH] Fix input table jpy restriction (#5260) * Fix input table jpy restriction * Remove unnecessary jpy type * Wrap java object with the most specific wrapper * Fix docstring format issue * Apply suggestions from code review Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> * Respond to review comments * Tiny coding formatting issue * Update py/server/deephaven/_wrapper.py Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> * Add one more test case * Empty Commit to force rerun of CI * Respond to latest review comments --------- Co-authored-by: jianfengmao Co-authored-by: Jianfeng Mao <4297243+jmao-denver@users.noreply.github.com> Co-authored-by: Chip Kent <5250374+chipkent@users.noreply.github.com> --- py/server/deephaven/_wrapper.py | 51 +++++++++++++++++++-------- py/server/deephaven/table_factory.py | 5 +-- py/server/tests/test_table_factory.py | 32 +++++++++++++++-- 3 files changed, 70 insertions(+), 18 deletions(-) diff --git a/py/server/deephaven/_wrapper.py b/py/server/deephaven/_wrapper.py index cc11e7e835f..59df56be97c 100644 --- a/py/server/deephaven/_wrapper.py +++ b/py/server/deephaven/_wrapper.py @@ -13,7 +13,7 @@ import pkgutil import sys from abc import ABC, abstractmethod -from typing import Set, Union, Optional, Any +from typing import Set, Union, Optional, Any, List import jpy @@ -102,8 +102,8 @@ def _is_direct_initialisable(cls) -> bool: return False -def _lookup_wrapped_class(j_obj: jpy.JType) -> Optional[type]: - """ Returns the wrapper class for the specified Java object. """ +def _lookup_wrapped_class(j_obj: jpy.JType) -> List[JObjectWrapper]: + """ Returns the wrapper classes for the specified Java object. """ # load every module in the deephaven package so that all the wrapper classes are loaded and available to wrap # the Java objects returned by calling resolve() global _has_all_wrappers_imported @@ -111,12 +111,7 @@ def _lookup_wrapped_class(j_obj: jpy.JType) -> Optional[type]: _recursive_import(__package__.partition(".")[0]) _has_all_wrappers_imported = True - for wc in _di_wrapper_classes: - j_clz = wc.j_object_type - if j_clz.jclass.isInstance(j_obj): - return wc - - return None + return [wc for wc in _di_wrapper_classes if wc.j_object_type.jclass.isInstance(j_obj)] def javaify(obj: Any) -> Optional[jpy.JType]: @@ -162,15 +157,43 @@ def pythonify(j_obj: Any) -> Optional[Any]: return wrap_j_object(j_obj) -def wrap_j_object(j_obj: jpy.JType) -> Union[JObjectWrapper, jpy.JType]: - """ Wraps the specified Java object as an instance of a custom wrapper class if one is available, otherwise returns - the raw Java object. """ +def _wrap_with_subclass(j_obj: jpy.JType, cls: type) -> Optional[JObjectWrapper]: + """ Returns a wrapper instance for the specified Java object by trying the entire subclasses' hierarchy. The + function employs a Depth First Search strategy to try the most specific subclass first. If no matching wrapper class is found, + returns None. + + The premises for this function are as follows: + - The subclasses all share the same class attribute `j_object_type` (guaranteed by subclassing JObjectWrapper) + - The subclasses are all direct initialisable (guaranteed by subclassing JObjectWrapper) + - The subclasses are all distinct from each other and check for their uniqueness in the initializer (__init__), e.g. + InputTable checks for the presence of the INPUT_TABLE_ATTRIBUTE attribute on the Java object. + """ + for subclass in cls.__subclasses__(): + try: + if (wrapper := _wrap_with_subclass(j_obj, subclass)) is not None: + return wrapper + return subclass(j_obj) + except: + continue + return None + + +def wrap_j_object(j_obj: jpy.JType) -> Optional[Union[JObjectWrapper, jpy.JType]]: + """ Wraps the specified Java object as an instance of the most specific custom wrapper class if one is available, + otherwise returns the raw Java object. """ if j_obj is None: return None - wc = _lookup_wrapped_class(j_obj) + wcs = _lookup_wrapped_class(j_obj) + for wc in wcs: + try: + if (wrapper:= _wrap_with_subclass(j_obj, wc)) is not None: + return wrapper + return wc(j_obj) + except: + continue - return wc(j_obj) if wc else j_obj + return j_obj def unwrap(obj: Any) -> Union[jpy.JType, Any]: diff --git a/py/server/deephaven/table_factory.py b/py/server/deephaven/table_factory.py index c9689b84c77..7efff4b534c 100644 --- a/py/server/deephaven/table_factory.py +++ b/py/server/deephaven/table_factory.py @@ -23,13 +23,13 @@ _JTableFactory = jpy.get_type("io.deephaven.engine.table.TableFactory") _JTableTools = jpy.get_type("io.deephaven.engine.util.TableTools") _JDynamicTableWriter = jpy.get_type("io.deephaven.engine.table.impl.util.DynamicTableWriter") -_JBaseArrayBackedInputTable = jpy.get_type("io.deephaven.engine.table.impl.util.BaseArrayBackedInputTable") _JAppendOnlyArrayBackedInputTable = jpy.get_type( "io.deephaven.engine.table.impl.util.AppendOnlyArrayBackedInputTable") _JKeyedArrayBackedInputTable = jpy.get_type("io.deephaven.engine.table.impl.util.KeyedArrayBackedInputTable") _JTableDefinition = jpy.get_type("io.deephaven.engine.table.TableDefinition") _JTable = jpy.get_type("io.deephaven.engine.table.Table") _J_INPUT_TABLE_ATTRIBUTE = _JTable.INPUT_TABLE_ATTRIBUTE +_J_InputTableUpdater = jpy.get_type("io.deephaven.engine.util.input.InputTableUpdater") _JRingTableTools = jpy.get_type("io.deephaven.engine.table.impl.sources.ring.RingTableTools") _JSupplier = jpy.get_type('java.util.function.Supplier') _JFunctionGeneratedTableFactory = jpy.get_type("io.deephaven.engine.table.impl.util.FunctionGeneratedTableFactory") @@ -235,13 +235,14 @@ class InputTable(Table): Users should always create InputTables through factory methods rather than directly from the constructor. """ - j_object_type = _JBaseArrayBackedInputTable def __init__(self, j_table: jpy.JType): super().__init__(j_table) self.j_input_table = self.j_table.getAttribute(_J_INPUT_TABLE_ATTRIBUTE) if not self.j_input_table: raise DHError("the provided table input is not suitable for input tables.") + if not _J_InputTableUpdater.jclass.isInstance(self.j_input_table): + raise DHError("the provided table's InputTable attribute type is not of InputTableUpdater type.") def add(self, table: Table) -> None: """Synchronously writes rows from the provided table to this input table. If this is a keyed input table, added rows with keys diff --git a/py/server/tests/test_table_factory.py b/py/server/tests/test_table_factory.py index 0e4aa080ad1..3b1cfe55062 100644 --- a/py/server/tests/test_table_factory.py +++ b/py/server/tests/test_table_factory.py @@ -9,11 +9,11 @@ import numpy as np from deephaven import DHError, read_csv, time_table, empty_table, merge, merge_sorted, dtypes, new_table, \ - input_table, time + input_table, time, _wrapper from deephaven.column import byte_col, char_col, short_col, bool_col, int_col, long_col, float_col, double_col, \ string_col, datetime_col, pyobj_col, jobj_col from deephaven.constants import NULL_DOUBLE, NULL_FLOAT, NULL_LONG, NULL_INT, NULL_SHORT, NULL_BYTE -from deephaven.table_factory import DynamicTableWriter, ring_table +from deephaven.table_factory import DynamicTableWriter, InputTable, ring_table from tests.testbase import BaseTestCase from deephaven.table import Table from deephaven.stream import blink_to_append_only, stream_to_append_only @@ -21,6 +21,7 @@ JArrayList = jpy.get_type("java.util.ArrayList") _JBlinkTableTools = jpy.get_type("io.deephaven.engine.table.impl.BlinkTableTools") _JDateTimeUtils = jpy.get_type("io.deephaven.time.DateTimeUtils") +_JTable = jpy.get_type("io.deephaven.engine.table.Table") @dataclass @@ -372,6 +373,16 @@ def test_input_table(self): keyed_input_table.delete(t.select(["String", "Double"])) self.assertEqual(keyed_input_table.size, 0) + with self.subTest("custom input table creation"): + place_holder_input_table = empty_table(1).update_view(["Key=`A`", "Value=10"]).with_attributes({_JTable.INPUT_TABLE_ATTRIBUTE: "Placeholder IT"}).j_table + + with self.assertRaises(DHError) as cm: + InputTable(place_holder_input_table) + self.assertIn("not of InputTableUpdater type", str(cm.exception)) + + self.assertTrue(isinstance(_wrapper.wrap_j_object(place_holder_input_table), Table)) + + def test_ring_table(self): cols = [ bool_col(name="Boolean", data=[True, False]), @@ -450,5 +461,22 @@ def test_input_table_empty_data(self): it.delete(t) self.assertEqual(it.size, 0) + def test_j_input_wrapping(self): + cols = [ + bool_col(name="Boolean", data=[True, False]), + string_col(name="String", data=["foo", "bar"]), + ] + t = new_table(cols=cols) + col_defs = {c.name: c.data_type for c in t.columns} + append_only_input_table = input_table(col_defs=col_defs) + + it = _wrapper.wrap_j_object(append_only_input_table.j_table) + self.assertTrue(isinstance(it, InputTable)) + + t = _wrapper.wrap_j_object(t.j_object) + self.assertFalse(isinstance(t, InputTable)) + self.assertTrue(isinstance(t, Table)) + + if __name__ == '__main__': unittest.main()