Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Multijoin in the Python client #6020

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions py/client/pydeephaven/_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,26 @@ def make_grpc_request(self, result_id, source_id) -> Any:
def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
    """Wrap this op's standalone gRPC request as the ``meta_table`` operation of a batch request."""
    request = self.make_grpc_request(result_id=result_id, source_id=source_id)
    return table_pb2.BatchTableRequest.Operation(meta_table=request)


class MultijoinTablesOp(TableOp):
    """Table operation that builds a MultiJoinTables request from a list of MultiJoinInput objects."""

    def __init__(self, multi_join_inputs: List["MultiJoinInput"]):
        self.multi_join_inputs = multi_join_inputs

    @classmethod
    def get_stub_func(cls, table_service_stub: table_pb2_grpc.TableServiceStub) -> Any:
        """Return the MultiJoinTables RPC of the given table service stub."""
        return table_service_stub.MultiJoinTables

    def make_grpc_request(self, result_id, source_id) -> Any:
        """Build the MultiJoinTablesRequest for this operation.

        The ``source_id`` argument is not used here: every input carries its own table, and the table reference
        of each protobuf input is created from that table's ticket.
        """
        pb_inputs = [
            table_pb2.MultiJoinInput(
                source_id=table_pb2.TableReference(ticket=mji.table.ticket),
                columns_to_match=mji.on,
                columns_to_add=mji.joins,
            )
            for mji in self.multi_join_inputs
        ]
        return table_pb2.MultiJoinTablesRequest(result_id=result_id, multi_join_inputs=pb_inputs)

    def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
        """Wrap the standalone request as the ``multi_join`` operation of a batch request."""
        request = self.make_grpc_request(result_id=result_id, source_id=source_id)
        return table_pb2.BatchTableRequest.Operation(multi_join=request)
4 changes: 2 additions & 2 deletions py/client/pydeephaven/_table_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
from typing import Union, List
from typing import Union, List, Optional

from pydeephaven._batch_assembler import BatchOpAssembler
from pydeephaven._table_ops import TableOp
Expand Down Expand Up @@ -37,7 +37,7 @@ def batch(self, ops: List[TableOp]) -> Table:
except Exception as e:
raise DHError("failed to finish the table batch operation.") from e

def grpc_table_op(self, table: Table, op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
def grpc_table_op(self, table: Optional[Table], op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
"""Makes a single gRPC Table operation call and returns a new Table."""
try:
result_id = self.session.make_ticket()
Expand Down
4 changes: 3 additions & 1 deletion py/client/pydeephaven/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ def random_ticket(cls) -> SharedTicket:
bytes_ = uuid4().int.to_bytes(16, byteorder='little', signed=False)
return cls(ticket_bytes=b'h' + bytes_)


_BidiRpc = NewType("_BidiRpc", grpc.StreamStreamMultiCallable)

_NotBidiRpc = NewType(
Expand All @@ -114,6 +115,7 @@ def random_ticket(cls) -> SharedTicket:
grpc.UnaryStreamMultiCallable,
grpc.StreamUnaryMultiCallable])


class Session:
"""A Session object represents a connection to the Deephaven data server. It contains a number of convenience
methods for asking the server to create tables, import Arrow data into tables, merge tables, run Python scripts, and
Expand Down Expand Up @@ -426,7 +428,7 @@ def _connect(self):
# started together don't align retries.
skew = random()
# Backoff schedule for retries after consecutive failures to refresh auth token
self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ]
self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10]

if self._refresh_backoff[0] > self._timeout_seconds:
raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.')
Expand Down
68 changes: 66 additions & 2 deletions py/client/pydeephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@

from __future__ import annotations

from typing import List, Union
from typing import List, Union, Sequence

import pyarrow as pa
from pydeephaven._utils import to_list

from pydeephaven._table_ops import MetaTableOp, SortDirection
from pydeephaven._table_ops import MetaTableOp, SortDirection, MultijoinTablesOp
from pydeephaven.agg import Aggregation
from pydeephaven.dherror import DHError
from pydeephaven._table_interface import TableInterface
Expand Down Expand Up @@ -805,3 +805,67 @@ def delete(self, table: Table) -> None:
self.session.input_table_service.delete(self, table)
except Exception as e:
raise DHError("delete data in the InputTable failed.") from e


class MultiJoinInput:
    """A MultiJoinInput represents one input table, together with its key columns and the additional columns to
    be used in the multi-table natural join.
    """
    table: Table
    on: Union[str, Sequence[str]]
    # None is the default, so it is spelled out in the annotation rather than relying on implicit Optional.
    joins: Union[str, Sequence[str], None] = None

    def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str], None] = None):
        """Initializes a MultiJoinInput object.

        Args:
            table (Table): the right table to include in the join
            on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression
                that matches every input table, e.g. "col_a = col_b" to rename output column names.
            joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
                table, can be renaming expressions, e.g. "new_col = col"; default is None
        """
        self.table = table
        # Both column arguments go through to_list so callers may pass either a single string or a sequence
        # (presumably normalized to a list -- confirm against pydeephaven._utils.to_list).
        self.on = to_list(on)
        self.joins = to_list(joins)


def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]],
               on: Union[str, Sequence[str], None] = None) -> Table:
    """ The multi_join method creates a new table by performing a multi-table natural join on the input tables. The
    result consists of the set of distinct keys from the input tables natural joined to each input table. Input
    tables need not have a matching row for each key, but they may not have multiple matching rows for a given key.

    Args:
        input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying
            the tables and columns to include in the join; must not be empty.
        on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression
            that matches every input table, e.g. "col_a = col_b" to rename output column names. Note: When
            MultiJoinInput objects are supplied, this parameter must be omitted.

    Returns:
        Table: the result of the multi-table natural join operation.

    Raises:
        DHError
    """
    # An empty sequence (or empty string) vacuously satisfies the all(...) type checks below and would then fail
    # when element 0 is accessed to determine the session; reject it up front with the documented error type.
    if isinstance(input, Sequence) and len(input) == 0:
        raise DHError(message="input must not be empty.")

    if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)):
        tables = to_list(input)
        session = tables[0].session
        if not all(t.session == session for t in tables):
            raise DHError(message="all tables must be from the same session.")
        # Plain tables share the same key columns, given by the 'on' argument of this function.
        multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables]
    elif isinstance(input, MultiJoinInput) or (
            isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)):
        # Each MultiJoinInput already carries its own key columns, so a separate 'on' would be ambiguous.
        if on is not None:
            raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.")
        multi_join_inputs = to_list(input)
        session = multi_join_inputs[0].table.session
        if not all(mji.table.session == session for mji in multi_join_inputs):
            raise DHError(message="all tables must be from the same session.")
    else:
        raise DHError(
            message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.")

    # NOTE(review): a reviewer pointed out that the server-side (deephaven) multi_join returns a MultiJoinTable,
    # which does not inherit from Table and requires calling table() to obtain a usable table, whereas this client
    # function returns a Table directly. Without a client-side MultiJoinTable the two APIs are incompatible;
    # confirm the intended return type before relying on it.
    table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs)
    # There is no single source table for this operation, hence None is passed as the table argument.
    return session.table_service.grpc_table_op(None, table_op, table_class=Table)
126 changes: 126 additions & 0 deletions py/client/tests/test_multijoin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
import unittest

from pyarrow import csv

from pydeephaven import DHError, Session
from pydeephaven.table import MultiJoinInput, multi_join
from tests.testbase import BaseTestCase


class MultiJoinTestCase(BaseTestCase):
    """Tests for the client-side multi_join function and the MultiJoinInput class."""

    def setUp(self):
        super().setUp()
        pa_table = csv.read_csv(self.csv_file)
        # The B tables reuse the key columns (a, b) of the A tables but carry differently named value columns.
        self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"])
        self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
            ["c1", "d1", "e1"])
        self.ticking_tableA = self.session.time_table("PT00:00:00.001").update(
            ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])
        self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
            ["c1", "d1", "e1"])

    def tearDown(self) -> None:
        # Drop the table references before the base class tears down.
        self.static_tableA = None
        self.static_tableB = None
        self.ticking_tableA = None
        self.ticking_tableB = None
        super().tearDown()

    def test_static_simple(self):
        # Test with multiple input tables
        mj_table = multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"])

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)
        self.assertEqual(mj_table.size, self.static_tableB.size)

        # Test with a single input table
        mj_table = multi_join(self.static_tableA, ["a", "b"])

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)

    def test_ticking_simple(self):
        # Test with multiple input tables
        mj_table = multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"])

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

        # Test with a single input table
        mj_table = multi_join(input=self.ticking_tableA, on=["a", "b"])

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

    def test_static(self):
        # Test with multiple input
        mj_input = [
            MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        mj_table = multi_join(mj_input)

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)
        self.assertEqual(mj_table.size, self.static_tableB.size)

        # Test with a single input
        mj_table = multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1"))

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)

    def test_ticking(self):
        # Test with multiple input
        mj_input = [
            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        mj_table = multi_join(mj_input)

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

        # Test with a single input
        mj_table = multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1"))

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

    def test_errors(self):
        # Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted).
        mj_input = [
            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        with self.assertRaises(DHError) as cm:
            multi_join(mj_input, on=["key1=a", "key2=b"])
        self.assertIn("on parameter is not permitted", str(cm.exception))

        # A second, independent session is needed for the cross-session check; close it even if an assertion
        # fails so the test does not leak a server connection.
        session = Session()
        try:
            t = session.time_table("PT00:00:00.001").update(
                ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])

            # Assert the exception is raised when to-be-joined tables are not from the same session.
            mj_input = [
                MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
                MultiJoinInput(table=t, on=["key1=a", "key2=b"], joins=["d2"])
            ]
            with self.assertRaises(DHError) as cm:
                multi_join(mj_input)
            self.assertIn("all tables must be from the same session", str(cm.exception))
        finally:
            session.close()


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3782,7 +3782,7 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str
table (Table): the right table to include in the join
on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equal expression,
i.e. "col_a = col_b" for different column names
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the this table to the result
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
table, can be renaming expressions, i.e. "new_col = col"; default is None

Raises:
Expand Down
Loading