diff --git a/py/client/pydeephaven/_table_ops.py b/py/client/pydeephaven/_table_ops.py
index cbb7830c546..ee367e2bb89 100644
--- a/py/client/pydeephaven/_table_ops.py
+++ b/py/client/pydeephaven/_table_ops.py
@@ -687,3 +687,26 @@ def make_grpc_request(self, result_id, source_id) -> Any:
     def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
         return table_pb2.BatchTableRequest.Operation(
             meta_table=self.make_grpc_request(result_id=result_id, source_id=source_id))
+
+
+class MultijoinTablesOp(TableOp):
+    def __init__(self, multi_join_inputs: List["MultiJoinInput"]):
+        self.multi_join_inputs = multi_join_inputs
+
+    @classmethod
+    def get_stub_func(cls, table_service_stub: table_pb2_grpc.TableServiceStub) -> Any:
+        return table_service_stub.MultiJoinTables
+
+    def make_grpc_request(self, result_id, source_id) -> Any:
+        pb_inputs = []
+        for mji in self.multi_join_inputs:
+            source_id = table_pb2.TableReference(ticket=mji.table.ticket)
+            columns_to_match = mji.on
+            columns_to_add = mji.joins
+            pb_inputs.append(table_pb2.MultiJoinInput(source_id=source_id, columns_to_match=columns_to_match,
+                                                      columns_to_add=columns_to_add))
+        return table_pb2.MultiJoinTablesRequest(result_id=result_id, multi_join_inputs=pb_inputs)
+
+    def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
+        return table_pb2.BatchTableRequest.Operation(
+            multi_join=self.make_grpc_request(result_id=result_id, source_id=source_id))
\ No newline at end of file
diff --git a/py/client/pydeephaven/_table_service.py b/py/client/pydeephaven/_table_service.py
index 749f43f5a0f..9ceb9ed5b0a 100644
--- a/py/client/pydeephaven/_table_service.py
+++ b/py/client/pydeephaven/_table_service.py
@@ -1,7 +1,7 @@
 #
 # Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
 #
-from typing import Union, List
+from typing import Union, List, Optional
 
 from pydeephaven._batch_assembler import BatchOpAssembler
 from pydeephaven._table_ops import TableOp
@@ -37,7 +37,7 @@ def batch(self, ops: List[TableOp]) -> Table:
         except Exception as e:
             raise DHError("failed to finish the table batch operation.") from e
 
-    def grpc_table_op(self, table: Table, op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
+    def grpc_table_op(self, table: Optional[Table], op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
         """Makes a single gRPC Table operation call and returns a new Table."""
         try:
             result_id = self.session.make_ticket()
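A note on the two client-internal changes above: a multi-table join has no single source table, so MultijoinTablesOp builds its MultiJoinTablesRequest entirely from the per-table MultiJoinInput entries, and the table parameter of grpc_table_op becomes Optional[Table] so the op can be submitted with a None source. A minimal sketch of that internal flow, assuming `session` is an open pydeephaven Session and `t1`/`t2` are tables created from it (the names and the "Key" column are illustrative, not part of the diff):

    from pydeephaven._table_ops import MultijoinTablesOp
    from pydeephaven.table import MultiJoinInput

    # One MultiJoinInput per joined table; each carries its own key and output columns.
    inputs = [MultiJoinInput(table=t1, on="Key"), MultiJoinInput(table=t2, on="Key")]
    op = MultijoinTablesOp(multi_join_inputs=inputs)
    # No single source table exists for a multi-join, hence the None source argument;
    # the public multi_join() helper added in table.py below does exactly this.
    result = session.table_service.grpc_table_op(None, op)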
diff --git a/py/client/pydeephaven/session.py b/py/client/pydeephaven/session.py
index 6581306b167..1aa8300b8f1 100644
--- a/py/client/pydeephaven/session.py
+++ b/py/client/pydeephaven/session.py
@@ -105,6 +105,7 @@ def random_ticket(cls) -> SharedTicket:
         bytes_ = uuid4().int.to_bytes(16, byteorder='little', signed=False)
         return cls(ticket_bytes=b'h' + bytes_)
 
+
 _BidiRpc = NewType("_BidiRpc", grpc.StreamStreamMultiCallable)
 
 _NotBidiRpc = NewType(
@@ -114,6 +115,7 @@ def random_ticket(cls) -> SharedTicket:
         grpc.UnaryStreamMultiCallable,
         grpc.StreamUnaryMultiCallable])
 
+
 class Session:
     """A Session object represents a connection to the Deephaven data server. It contains a number of convenience
     methods for asking the server to create tables, import Arrow data into tables, merge tables, run Python scripts, and
@@ -426,7 +428,7 @@ def _connect(self):
         # started together don't align retries.
         skew = random()
         # Backoff schedule for retries after consecutive failures to refresh auth token
-        self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ]
+        self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10]
 
         if self._refresh_backoff[0] > self._timeout_seconds:
             raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.')
diff --git a/py/client/pydeephaven/table.py b/py/client/pydeephaven/table.py
index 6339a7199c3..70df8c0406a 100644
--- a/py/client/pydeephaven/table.py
+++ b/py/client/pydeephaven/table.py
@@ -6,12 +6,12 @@
 
 from __future__ import annotations
 
-from typing import List, Union
+from typing import List, Union, Sequence
 
 import pyarrow as pa
 
 from pydeephaven._utils import to_list
-from pydeephaven._table_ops import MetaTableOp, SortDirection
+from pydeephaven._table_ops import MetaTableOp, SortDirection, MultijoinTablesOp
 from pydeephaven.agg import Aggregation
 from pydeephaven.dherror import DHError
 from pydeephaven._table_interface import TableInterface
@@ -805,3 +805,67 @@ def delete(self, table: Table) -> None:
             self.session.input_table_service.delete(self, table)
         except Exception as e:
             raise DHError("delete data in the InputTable failed.") from e
+
+
+class MultiJoinInput:
+    """A MultiJoinInput represents one input table of the multi-table natural join, together with its key
+    column(s) and the additional column(s) to add to the result.
+    """
+    table: Table
+    on: Union[str, Sequence[str]]
+    joins: Union[str, Sequence[str]] = None
+
+    def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str]] = None):
+        """Initializes a MultiJoinInput object.
+
+        Args:
+            table (Table): the right table to include in the join
+            on (Union[str, Sequence[str]]): the column(s) to match, can be a common column name or an equality
+                expression, e.g. "col_a = col_b", where the left-hand name is used for the key column in the result
+            joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
+                table, can be renaming expressions, e.g. "new_col = col"; default is None
+        """
+        self.table = table
+        self.on = to_list(on)
+        self.joins = to_list(joins)
+
+
+def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]],
+               on: Union[str, Sequence[str]] = None) -> Table:
+    """Creates a new table by performing a multi-table natural join on the input tables. The result consists of
+    the set of distinct keys from the input tables natural joined to each input table. Input tables need not have
+    a matching row for each key, but they may not have multiple matching rows for a given key.
+
+    Args:
+        input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the
+            tables and columns to include in the join.
+        on (Union[str, Sequence[str]], optional): the column(s) to match across every input table; each can be a
+            common column name or an equality expression, e.g. "col_a = col_b", where the left-hand name is used for
+            the key column in the result. Note: when MultiJoinInput objects are supplied, this parameter must be omitted.
+
+    Returns:
+        Table: the result of the multi-table natural join operation.
+
+    Raises:
+        DHError
+    """
+    if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)):
+        tables = to_list(input)
+        session = tables[0].session
+        if not all([t.session == session for t in tables]):
+            raise DHError(message="all tables must be from the same session.")
+        multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables]
+    elif isinstance(input, MultiJoinInput) or (
+            isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)):
+        if on is not None:
+            raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.")
+        multi_join_inputs = to_list(input)
+        session = multi_join_inputs[0].table.session
+        if not all([mji.table.session == session for mji in multi_join_inputs]):
+            raise DHError(message="all tables must be from the same session.")
+    else:
+        raise DHError(
+            message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.")
+
+    table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs)
+    return session.table_service.grpc_table_op(None, table_op, table_class=Table)
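For orientation, a usage sketch of the new public client API follows (not part of the diff; the table contents, column names, and default localhost connection are illustrative assumptions, and a running Deephaven server is required):

    from pydeephaven import Session
    from pydeephaven.table import MultiJoinInput, multi_join

    session = Session()  # assumes a server on the default localhost:10000
    t1 = session.empty_table(5).update(["a = i", "b = i % 2", "c1 = i * 10"])
    t2 = t1.update(["c2 = c1 + 1"]).drop_columns(["c1"])

    # Form 1: plain tables plus a shared `on`; every table is keyed the same way.
    simple = multi_join(input=[t1, t2], on=["a", "b"])

    # Form 2: MultiJoinInput objects with per-table key renames and output columns;
    # the `on` argument must be omitted in this form.
    detailed = multi_join([
        MultiJoinInput(table=t1, on=["key1=a", "key2=b"], joins=["c1"]),
        MultiJoinInput(table=t2, on=["key1=a", "key2=b"], joins=["c2"]),
    ])

The result has one row per distinct key found across the inputs, with each input's joins columns natural-joined in; a key missing from an input leaves that input's columns null, and an input with more than one row for a given key is an error.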
diff --git a/py/client/tests/test_multijoin.py b/py/client/tests/test_multijoin.py
new file mode 100644
index 00000000000..ceece494540
--- /dev/null
+++ b/py/client/tests/test_multijoin.py
@@ -0,0 +1,126 @@
+#
+# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
+#
+import unittest
+
+from pyarrow import csv
+
+from pydeephaven import DHError, Session
+from pydeephaven.table import MultiJoinInput, multi_join
+from tests.testbase import BaseTestCase
+
+
+class MultiJoinTestCase(BaseTestCase):
+    def setUp(self):
+        super().setUp()
+        pa_table = csv.read_csv(self.csv_file)
+        self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"])
+        self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
+            ["c1", "d1", "e1"])
+        self.ticking_tableA = self.session.time_table("PT00:00:00.001").update(
+            ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])
+        self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
+            ["c1", "d1", "e1"])
+
+    def tearDown(self) -> None:
+        self.static_tableA = None
+        self.static_tableB = None
+        self.ticking_tableA = None
+        self.ticking_tableB = None
+        super().tearDown()
+
+    def test_static_simple(self):
+        # Test with multiple input tables
+        mj_table = multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"])
+
+        # Output table is static
+        self.assertFalse(mj_table.is_refreshing)
+        # Output table has same # rows as sources
+        self.assertEqual(mj_table.size, self.static_tableA.size)
+        self.assertEqual(mj_table.size, self.static_tableB.size)
+
+        # Test with a single input table
+        mj_table = multi_join(self.static_tableA, ["a", "b"])
+
+        # Output table is static
+        self.assertFalse(mj_table.is_refreshing)
+        # Output table has same # rows as sources
+        self.assertEqual(mj_table.size, self.static_tableA.size)
+
+    def test_ticking_simple(self):
+        # Test with multiple input tables
+        mj_table = multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"])
+
+        # Output table is refreshing
+        self.assertTrue(mj_table.is_refreshing)
+
+        # Test with a single input table
+        mj_table = multi_join(input=self.ticking_tableA, on=["a", "b"])
+
+        # Output table is refreshing
+        self.assertTrue(mj_table.is_refreshing)
+
+    def test_static(self):
+        # Test with multiple inputs
+        mj_input = [
+            MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
+            MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"])
+        ]
+        mj_table = multi_join(mj_input)
+
+        # Output table is static
+        self.assertFalse(mj_table.is_refreshing)
+        # Output table has same # rows as sources
+        self.assertEqual(mj_table.size, self.static_tableA.size)
+        self.assertEqual(mj_table.size, self.static_tableB.size)
+
+        # Test with a single input
+        mj_table = multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1"))
+
+        # Output table is static
+        self.assertFalse(mj_table.is_refreshing)
+        # Output table has same # rows as sources
+        self.assertEqual(mj_table.size, self.static_tableA.size)
+
+    def test_ticking(self):
+        # Test with multiple inputs
+        mj_input = [
+            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
+            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
+        ]
+        mj_table = multi_join(mj_input)
+
+        # Output table is refreshing
+        self.assertTrue(mj_table.is_refreshing)
+
+        # Test with a single input
+        mj_table = multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1"))
+
+        # Output table is refreshing
+        self.assertTrue(mj_table.is_refreshing)
+
+    def test_errors(self):
+        # Assert that an exception is raised when MultiJoinInput objects are provided and on is not omitted.
+        mj_input = [
+            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
+            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
+        ]
+        with self.assertRaises(DHError):
+            mj_table = multi_join(mj_input, on=["key1=a", "key2=b"])
+
+        session = Session()
+        t = session.time_table("PT00:00:00.001").update(
+            ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])
+
+        # Assert that an exception is raised when the to-be-joined tables are not from the same session.
+        mj_input = [
+            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
+            MultiJoinInput(table=t, on=["key1=a", "key2=b"], joins=["d2"])
+        ]
+        with self.assertRaises(DHError) as cm:
+            mj_table = multi_join(mj_input)
+        self.assertIn("all tables must be from the same session", str(cm.exception))
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py
index abb22e1031c..8a65cdada81 100644
--- a/py/server/deephaven/table.py
+++ b/py/server/deephaven/table.py
@@ -3782,7 +3782,7 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str
             table (Table): the right table to include in the join
             on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equal expression,
                 i.e. "col_a = col_b" for different column names
-            joins (Union[str, Sequence[str]], optional): the column(s) to be added from the this table to the result
+            joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
                 table, can be renaming expressions, i.e. "new_col = col"; default is None
 
         Raises:
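A note on exercising the new tests: the module follows the existing py/client test layout, where tests.testbase.BaseTestCase supplies self.session and self.csv_file, so it needs a reachable Deephaven server plus whatever environment the rest of the client suite already assumes. A typical local invocation might look like the following (the working directory and flags are assumptions, not prescribed by the diff):

    cd py/client
    python -m unittest tests.test_multijoin -v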