From 2ebce477af13559b6263df540a2304a0e85905ea Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Wed, 4 Sep 2024 11:26:20 -0600 Subject: [PATCH 1/2] Add multi_join() in the Session class --- py/client/pydeephaven/_table_ops.py | 23 +++++ py/client/pydeephaven/_table_service.py | 4 +- py/client/pydeephaven/session.py | 66 +++++++++++++- py/client/tests/test_multijoin.py | 113 ++++++++++++++++++++++++ py/server/deephaven/table.py | 2 +- 5 files changed, 202 insertions(+), 6 deletions(-) create mode 100644 py/client/tests/test_multijoin.py diff --git a/py/client/pydeephaven/_table_ops.py b/py/client/pydeephaven/_table_ops.py index cbb7830c546..ee367e2bb89 100644 --- a/py/client/pydeephaven/_table_ops.py +++ b/py/client/pydeephaven/_table_ops.py @@ -687,3 +687,26 @@ def make_grpc_request(self, result_id, source_id) -> Any: def make_grpc_request_for_batch(self, result_id, source_id) -> Any: return table_pb2.BatchTableRequest.Operation( meta_table=self.make_grpc_request(result_id=result_id, source_id=source_id)) + + +class MultijoinTablesOp(TableOp): + def __init__(self, multi_join_inputs: List["MultiJoinInput"]): + self.multi_join_inputs = multi_join_inputs + + @classmethod + def get_stub_func(cls, table_service_stub: table_pb2_grpc.TableServiceStub) -> Any: + return table_service_stub.MultiJoinTables + + def make_grpc_request(self, result_id, source_id) -> Any: + pb_inputs = [] + for mji in self.multi_join_inputs: + source_id = table_pb2.TableReference(ticket=mji.table.ticket) + columns_to_match = mji.on + columns_to_add = mji.joins + pb_inputs.append(table_pb2.MultiJoinInput(source_id=source_id, columns_to_match=columns_to_match, + columns_to_add=columns_to_add)) + return table_pb2.MultiJoinTablesRequest(result_id=result_id, multi_join_inputs=pb_inputs) + + def make_grpc_request_for_batch(self, result_id, source_id) -> Any: + return table_pb2.BatchTableRequest.Operation( + multi_join=self.make_grpc_request(result_id=result_id, source_id=source_id)) \ No newline at end of file diff --git a/py/client/pydeephaven/_table_service.py b/py/client/pydeephaven/_table_service.py index 749f43f5a0f..9ceb9ed5b0a 100644 --- a/py/client/pydeephaven/_table_service.py +++ b/py/client/pydeephaven/_table_service.py @@ -1,7 +1,7 @@ # # Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending # -from typing import Union, List +from typing import Union, List, Optional from pydeephaven._batch_assembler import BatchOpAssembler from pydeephaven._table_ops import TableOp @@ -37,7 +37,7 @@ def batch(self, ops: List[TableOp]) -> Table: except Exception as e: raise DHError("failed to finish the table batch operation.") from e - def grpc_table_op(self, table: Table, op: TableOp, table_class: type = Table) -> Union[Table, InputTable]: + def grpc_table_op(self, table: Optional[Table], op: TableOp, table_class: type = Table) -> Union[Table, InputTable]: """Makes a single gRPC Table operation call and returns a new Table.""" try: result_id = self.session.make_ticket() diff --git a/py/client/pydeephaven/session.py b/py/client/pydeephaven/session.py index 6581306b167..29448a9998c 100644 --- a/py/client/pydeephaven/session.py +++ b/py/client/pydeephaven/session.py @@ -10,7 +10,7 @@ import os from random import random import threading -from typing import Any, Dict, Iterable, List, Union, Tuple, NewType +from typing import Any, Dict, Iterable, List, Union, Tuple, NewType, Sequence from uuid import uuid4 import grpc @@ -26,7 +26,8 @@ from pydeephaven._input_table_service import InputTableService from pydeephaven._plugin_obj_service import PluginObjService from pydeephaven._session_service import SessionService -from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp +from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp, \ + MultijoinTablesOp from pydeephaven._table_service import TableService from pydeephaven._utils import to_list from pydeephaven.dherror import DHError @@ -105,6 +106,7 @@ def random_ticket(cls) -> SharedTicket: bytes_ = uuid4().int.to_bytes(16, byteorder='little', signed=False) return cls(ticket_bytes=b'h' + bytes_) + _BidiRpc = NewType("_BidiRpc", grpc.StreamStreamMultiCallable) _NotBidiRpc = NewType( @@ -114,6 +116,7 @@ def random_ticket(cls) -> SharedTicket: grpc.UnaryStreamMultiCallable, grpc.StreamUnaryMultiCallable]) + class Session: """A Session object represents a connection to the Deephaven data server. It contains a number of convenience methods for asking the server to create tables, import Arrow data into tables, merge tables, run Python scripts, and @@ -426,7 +429,7 @@ def _connect(self): # started together don't align retries. skew = random() # Backoff schedule for retries after consecutive failures to refresh auth token - self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ] + self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10] if self._refresh_backoff[0] > self._timeout_seconds: raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.') @@ -734,3 +737,60 @@ def plugin_client(self, exportable_obj: ticket_pb2.TypedTicket) -> PluginClient: Part of the experimental plugin API.""" return PluginClient(self, exportable_obj) + + def multi_join(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]], + on: Union[str, Sequence[str]] = None) -> Table: + """ The multi_join method creates a new table by performing a multi-table natural join on the input tables. The + result consists of the set of distinct keys from the input tables natural joined to each input table. Input + tables need not have a matching row for each key, but they may not have multiple matching rows for a given key. + + Args: + input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the + tables and columns to include in the join. + on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression + that matches every input table, i.e. "col_a = col_b" to rename output column names. Note: When + MultiJoinInput objects are supplied, this parameter must be omitted. + + Returns: + Table: the result of the multi-table natural join operation. + + Raises: + DHError + """ + if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)): + tables = to_list(input) + multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables] + elif isinstance(input, MultiJoinInput) or ( + isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)): + if on is not None: + raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.") + multi_join_inputs = to_list(input) + else: + raise DHError( + message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.") + + table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs) + return self.table_service.grpc_table_op(None, table_op, table_class=Table) + + +class MultiJoinInput: + """A MultiJoinInput represents the input tables, key columns and additional columns to be used in the multi-table + natural join. + """ + table: Table + on: Union[str, Sequence[str]] + joins: Union[str, Sequence[str]] = None + + def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str]] = None): + """Initializes a MultiJoinInput object. + + Args: + table (Table): the right table to include in the join + on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression that + matches every input table, i.e. "col_a = col_b" to rename output column names. + joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result + table, can be renaming expressions, i.e. "new_col = col"; default is None + """ + self.table = table + self.on = to_list(on) + self.joins = to_list(joins) diff --git a/py/client/tests/test_multijoin.py b/py/client/tests/test_multijoin.py new file mode 100644 index 00000000000..8d9f3a09784 --- /dev/null +++ b/py/client/tests/test_multijoin.py @@ -0,0 +1,113 @@ +# +# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending +# +import unittest + +from pyarrow import csv + +from pydeephaven import DHError +from pydeephaven.session import MultiJoinInput +from tests.testbase import BaseTestCase + + +class MultiJoinTestCase(BaseTestCase): + def setUp(self): + super().setUp() + pa_table = csv.read_csv(self.csv_file) + self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"]) + self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns( + ["c1", "d1", "e1"]) + self.ticking_tableA = self.session.time_table("PT00:00:00.001").update( + ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"]) + self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns( + ["c1", "d1", "e1"]) + + def tearDown(self) -> None: + self.static_tableA = None + self.static_tableB = None + self.ticking_tableA = None + self.ticking_tableB = None + super().tearDown() + + def test_static_simple(self): + # Test with multiple input tables + mj_table = self.session.multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"]) + + # Output table is static + self.assertFalse(mj_table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.size, self.static_tableA.size) + self.assertEqual(mj_table.size, self.static_tableB.size) + + # Test with a single input table + mj_table = self.session.multi_join(self.static_tableA, ["a", "b"]) + + # Output table is static + self.assertFalse(mj_table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.size, self.static_tableA.size) + + def test_ticking_simple(self): + # Test with multiple input tables + mj_table = self.session.multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"]) + + # Output table is refreshing + self.assertTrue(mj_table.is_refreshing) + + # Test with a single input table + mj_table = self.session.multi_join(input=self.ticking_tableA, on=["a", "b"]) + + # Output table is refreshing + self.assertTrue(mj_table.is_refreshing) + + def test_static(self): + # Test with multiple input + mj_input = [ + MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"]) + ] + mj_table = self.session.multi_join(mj_input) + + # Output table is static + self.assertFalse(mj_table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.size, self.static_tableA.size) + self.assertEqual(mj_table.size, self.static_tableB.size) + + # Test with a single input + mj_table = self.session.multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1")) + + # Output table is static + self.assertFalse(mj_table.is_refreshing) + # Output table has same # rows as sources + self.assertEqual(mj_table.size, self.static_tableA.size) + + def test_ticking(self): + # Test with multiple input + mj_input = [ + MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) + ] + mj_table = self.session.multi_join(mj_input) + + # Output table is refreshing + self.assertTrue(mj_table.is_refreshing) + + # Test with a single input + mj_table = self.session.multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1")) + + # Output table is refreshing + self.assertTrue(mj_table.is_refreshing) + + def test_errors(self): + # Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted). + mj_input = [ + MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) + ] + with self.assertRaises(DHError): + mj_table = self.session.multi_join(mj_input, on=["key1=a", "key2=b"]) + + +if __name__ == '__main__': + unittest.main() diff --git a/py/server/deephaven/table.py b/py/server/deephaven/table.py index abb22e1031c..8a65cdada81 100644 --- a/py/server/deephaven/table.py +++ b/py/server/deephaven/table.py @@ -3782,7 +3782,7 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str table (Table): the right table to include in the join on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equal expression, i.e. "col_a = col_b" for different column names - joins (Union[str, Sequence[str]], optional): the column(s) to be added from the this table to the result + joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result table, can be renaming expressions, i.e. "new_col = col"; default is None Raises: From 90654c392696b916c2982af5c5b319d9a5f7cf94 Mon Sep 17 00:00:00 2001 From: jianfengmao Date: Wed, 18 Sep 2024 15:31:53 -0600 Subject: [PATCH 2/2] Move the new interface to table.py --- py/client/pydeephaven/session.py | 62 +--------------------------- py/client/pydeephaven/table.py | 68 ++++++++++++++++++++++++++++++- py/client/tests/test_multijoin.py | 35 +++++++++++----- 3 files changed, 92 insertions(+), 73 deletions(-) diff --git a/py/client/pydeephaven/session.py b/py/client/pydeephaven/session.py index 29448a9998c..1aa8300b8f1 100644 --- a/py/client/pydeephaven/session.py +++ b/py/client/pydeephaven/session.py @@ -10,7 +10,7 @@ import os from random import random import threading -from typing import Any, Dict, Iterable, List, Union, Tuple, NewType, Sequence +from typing import Any, Dict, Iterable, List, Union, Tuple, NewType from uuid import uuid4 import grpc @@ -26,8 +26,7 @@ from pydeephaven._input_table_service import InputTableService from pydeephaven._plugin_obj_service import PluginObjService from pydeephaven._session_service import SessionService -from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp, \ - MultijoinTablesOp +from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp from pydeephaven._table_service import TableService from pydeephaven._utils import to_list from pydeephaven.dherror import DHError @@ -737,60 +736,3 @@ def plugin_client(self, exportable_obj: ticket_pb2.TypedTicket) -> PluginClient: Part of the experimental plugin API.""" return PluginClient(self, exportable_obj) - - def multi_join(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]], - on: Union[str, Sequence[str]] = None) -> Table: - """ The multi_join method creates a new table by performing a multi-table natural join on the input tables. The - result consists of the set of distinct keys from the input tables natural joined to each input table. Input - tables need not have a matching row for each key, but they may not have multiple matching rows for a given key. - - Args: - input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the - tables and columns to include in the join. - on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression - that matches every input table, i.e. "col_a = col_b" to rename output column names. Note: When - MultiJoinInput objects are supplied, this parameter must be omitted. - - Returns: - Table: the result of the multi-table natural join operation. - - Raises: - DHError - """ - if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)): - tables = to_list(input) - multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables] - elif isinstance(input, MultiJoinInput) or ( - isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)): - if on is not None: - raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.") - multi_join_inputs = to_list(input) - else: - raise DHError( - message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.") - - table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs) - return self.table_service.grpc_table_op(None, table_op, table_class=Table) - - -class MultiJoinInput: - """A MultiJoinInput represents the input tables, key columns and additional columns to be used in the multi-table - natural join. - """ - table: Table - on: Union[str, Sequence[str]] - joins: Union[str, Sequence[str]] = None - - def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str]] = None): - """Initializes a MultiJoinInput object. - - Args: - table (Table): the right table to include in the join - on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression that - matches every input table, i.e. "col_a = col_b" to rename output column names. - joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result - table, can be renaming expressions, i.e. "new_col = col"; default is None - """ - self.table = table - self.on = to_list(on) - self.joins = to_list(joins) diff --git a/py/client/pydeephaven/table.py b/py/client/pydeephaven/table.py index 6339a7199c3..70df8c0406a 100644 --- a/py/client/pydeephaven/table.py +++ b/py/client/pydeephaven/table.py @@ -6,12 +6,12 @@ from __future__ import annotations -from typing import List, Union +from typing import List, Union, Sequence import pyarrow as pa from pydeephaven._utils import to_list -from pydeephaven._table_ops import MetaTableOp, SortDirection +from pydeephaven._table_ops import MetaTableOp, SortDirection, MultijoinTablesOp from pydeephaven.agg import Aggregation from pydeephaven.dherror import DHError from pydeephaven._table_interface import TableInterface @@ -805,3 +805,67 @@ def delete(self, table: Table) -> None: self.session.input_table_service.delete(self, table) except Exception as e: raise DHError("delete data in the InputTable failed.") from e + + +class MultiJoinInput: + """A MultiJoinInput represents the input tables, key columns and additional columns to be used in the multi-table + natural join. + """ + table: Table + on: Union[str, Sequence[str]] + joins: Union[str, Sequence[str]] = None + + def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str]] = None): + """Initializes a MultiJoinInput object. + + Args: + table (Table): the right table to include in the join + on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression that + matches every input table, i.e. "col_a = col_b" to rename output column names. + joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result + table, can be renaming expressions, i.e. "new_col = col"; default is None + """ + self.table = table + self.on = to_list(on) + self.joins = to_list(joins) + + +def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]], + on: Union[str, Sequence[str]] = None) -> Table: + """ The multi_join method creates a new table by performing a multi-table natural join on the input tables. The + result consists of the set of distinct keys from the input tables natural joined to each input table. Input + tables need not have a matching row for each key, but they may not have multiple matching rows for a given key. + + Args: + input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the + tables and columns to include in the join. + on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression + that matches every input table, i.e. "col_a = col_b" to rename output column names. Note: When + MultiJoinInput objects are supplied, this parameter must be omitted. + + Returns: + Table: the result of the multi-table natural join operation. + + Raises: + DHError + """ + if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)): + tables = to_list(input) + session = tables[0].session + if not all([t.session == session for t in tables]): + raise DHError(message="all tables must be from the same session.") + multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables] + elif isinstance(input, MultiJoinInput) or ( + isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)): + if on is not None: + raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.") + multi_join_inputs = to_list(input) + session = multi_join_inputs[0].table.session + if not all([mji.table.session == session for mji in multi_join_inputs]): + raise DHError(message="all tables must be from the same session.") + else: + raise DHError( + message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.") + + table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs) + return session.table_service.grpc_table_op(None, table_op, table_class=Table) diff --git a/py/client/tests/test_multijoin.py b/py/client/tests/test_multijoin.py index 8d9f3a09784..ceece494540 100644 --- a/py/client/tests/test_multijoin.py +++ b/py/client/tests/test_multijoin.py @@ -5,8 +5,8 @@ from pyarrow import csv -from pydeephaven import DHError -from pydeephaven.session import MultiJoinInput +from pydeephaven import DHError, Session +from pydeephaven.table import MultiJoinInput, multi_join from tests.testbase import BaseTestCase @@ -31,7 +31,7 @@ def tearDown(self) -> None: def test_static_simple(self): # Test with multiple input tables - mj_table = self.session.multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"]) + mj_table = multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"]) # Output table is static self.assertFalse(mj_table.is_refreshing) @@ -40,7 +40,7 @@ def test_static_simple(self): self.assertEqual(mj_table.size, self.static_tableB.size) # Test with a single input table - mj_table = self.session.multi_join(self.static_tableA, ["a", "b"]) + mj_table = multi_join(self.static_tableA, ["a", "b"]) # Output table is static self.assertFalse(mj_table.is_refreshing) @@ -49,13 +49,13 @@ def test_static_simple(self): def test_ticking_simple(self): # Test with multiple input tables - mj_table = self.session.multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"]) + mj_table = multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"]) # Output table is refreshing self.assertTrue(mj_table.is_refreshing) # Test with a single input table - mj_table = self.session.multi_join(input=self.ticking_tableA, on=["a", "b"]) + mj_table = multi_join(input=self.ticking_tableA, on=["a", "b"]) # Output table is refreshing self.assertTrue(mj_table.is_refreshing) @@ -66,7 +66,7 @@ def test_static(self): MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"]) ] - mj_table = self.session.multi_join(mj_input) + mj_table = multi_join(mj_input) # Output table is static self.assertFalse(mj_table.is_refreshing) @@ -75,7 +75,7 @@ def test_static(self): self.assertEqual(mj_table.size, self.static_tableB.size) # Test with a single input - mj_table = self.session.multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1")) + mj_table = multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1")) # Output table is static self.assertFalse(mj_table.is_refreshing) @@ -88,13 +88,13 @@ def test_ticking(self): MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) ] - mj_table = self.session.multi_join(mj_input) + mj_table = multi_join(mj_input) # Output table is refreshing self.assertTrue(mj_table.is_refreshing) # Test with a single input - mj_table = self.session.multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1")) + mj_table = multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1")) # Output table is refreshing self.assertTrue(mj_table.is_refreshing) @@ -106,7 +106,20 @@ def test_errors(self): MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) ] with self.assertRaises(DHError): - mj_table = self.session.multi_join(mj_input, on=["key1=a", "key2=b"]) + mj_table = multi_join(mj_input, on=["key1=a", "key2=b"]) + + session = Session() + t = session.time_table("PT00:00:00.001").update( + ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"]) + + # Assert the exception is raised when to-be-joined tables are not from the same session. + mj_input = [ + MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), + MultiJoinInput(table=t, on=["key1=a", "key2=b"], joins=["d2"]) + ] + with self.assertRaises(DHError) as cm: + mj_table = multi_join(mj_input) + self.assertIn("all tables must be from the same session", str(cm.exception)) if __name__ == '__main__':