Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: Support Multijoin in the Python client #6020

Merged
merged 6 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions py/client/pydeephaven/_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,26 @@ def make_grpc_request(self, result_id, source_id) -> Any:
def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
return table_pb2.BatchTableRequest.Operation(
meta_table=self.make_grpc_request(result_id=result_id, source_id=source_id))


class MultijoinTablesOp(TableOp):
def __init__(self, multi_join_inputs: List["MultiJoinInput"]):
self.multi_join_inputs = multi_join_inputs

@classmethod
def get_stub_func(cls, table_service_stub: table_pb2_grpc.TableServiceStub) -> Any:
return table_service_stub.MultiJoinTables

def make_grpc_request(self, result_id, source_id) -> Any:
pb_inputs = []
for mji in self.multi_join_inputs:
source_id = table_pb2.TableReference(ticket=mji.table.ticket.pb_ticket)
columns_to_match = mji.on
columns_to_add = mji.joins
pb_inputs.append(table_pb2.MultiJoinInput(source_id=source_id, columns_to_match=columns_to_match,
columns_to_add=columns_to_add))
return table_pb2.MultiJoinTablesRequest(result_id=result_id, multi_join_inputs=pb_inputs)

def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
return table_pb2.BatchTableRequest.Operation(
multi_join=self.make_grpc_request(result_id=result_id, source_id=source_id))
4 changes: 2 additions & 2 deletions py/client/pydeephaven/_table_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
from typing import Union, List
from typing import Union, List, Optional

from pydeephaven._batch_assembler import BatchOpAssembler
from pydeephaven._table_ops import TableOp
Expand Down Expand Up @@ -38,7 +38,7 @@ def batch(self, ops: List[TableOp]) -> Table:
except Exception as e:
raise DHError("failed to finish the table batch operation.") from e

def grpc_table_op(self, table: Table, op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
def grpc_table_op(self, table: Optional[Table], op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
"""Makes a single gRPC Table operation call and returns a new Table."""
try:
export_ticket = self.session.make_export_ticket()
Expand Down
2 changes: 1 addition & 1 deletion py/client/pydeephaven/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ def _connect(self):
# started together don't align retries.
skew = random()
# Backoff schedule for retries after consecutive failures to refresh auth token
self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ]
self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10]

if self._refresh_backoff[0] > self._timeout_seconds:
raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.')
Expand Down
82 changes: 80 additions & 2 deletions py/client/pydeephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@

from __future__ import annotations

from typing import List, Union
from typing import List, Union, Sequence

import pyarrow as pa

from pydeephaven._utils import to_list

from pydeephaven._table_ops import MetaTableOp, SortDirection
from pydeephaven._table_ops import MetaTableOp, SortDirection, MultijoinTablesOp
from pydeephaven.agg import Aggregation
from pydeephaven.dherror import DHError
from pydeephaven._table_interface import TableInterface
Expand Down Expand Up @@ -804,3 +804,81 @@ def delete(self, table: Table) -> None:
self.session.input_table_service.delete(self, table)
except Exception as e:
raise DHError("delete data in the InputTable failed.") from e


class MultiJoinTable:
"""A MultiJoinTable is an object that contains the result of a multi-table natural join. To retrieve the underlying
result Table, use the :attr:`.table` property. """

def __init__(self, table: Table):
self._table = table

@property
def table(self) -> Table:
jmao-denver marked this conversation as resolved.
Show resolved Hide resolved
"""Returns the Table containing the multi-table natural join output. """
return self._table


class MultiJoinInput:
"""A MultiJoinInput represents the input tables, key columns and additional columns to be used in the multi-table
natural join.
"""
table: Table
on: Union[str, Sequence[str]]
joins: Union[str, Sequence[str]] = None

def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str]] = None):
"""Initializes a MultiJoinInput object.

Args:
table (Table): the right table to include in the join
on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression that
matches every input table, i.e. "col_a = col_b" to rename output column names.
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
table, can be renaming expressions, i.e. "new_col = col"; default is None
"""
self.table = table
self.on = to_list(on)
self.joins = to_list(joins)


def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]],
on: Union[str, Sequence[str]] = None) -> MultiJoinTable:
""" The multi_join method creates a new table by performing a multi-table natural join on the input tables. The
result consists of the set of distinct keys from the input tables natural joined to each input table. Input
tables need not have a matching row for each key, but they may not have multiple matching rows for a given key.

Args:
input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the
tables and columns to include in the join.
on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression
that matches every input table, i.e. "col_a = col_b" to rename output column names. Note: When
MultiJoinInput objects are supplied, this parameter must be omitted.

Returns:
MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the
:attr:`~MultiJoinTable.table` property.

Raises:
DHError
"""
if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)):
tables = to_list(input)
session = tables[0].session
if not all([t.session == session for t in tables]):
raise DHError(message="all tables must be from the same session.")
multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables]
elif isinstance(input, MultiJoinInput) or (
isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)):
if on is not None:
raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.")
multi_join_inputs = to_list(input)
session = multi_join_inputs[0].table.session
if not all([mji.table.session == session for mji in multi_join_inputs]):
raise DHError(message="all tables must be from the same session.")
else:
raise DHError(
message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.")

table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs)
return MultiJoinTable(table=session.table_service.grpc_table_op(None, table_op, table_class=Table))
127 changes: 127 additions & 0 deletions py/client/tests/test_multijoin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
import unittest

from pyarrow import csv

from pydeephaven import DHError, Session
from pydeephaven.table import MultiJoinInput, multi_join
from tests.testbase import BaseTestCase


class MultiJoinTestCase(BaseTestCase):
def setUp(self):
super().setUp()
pa_table = csv.read_csv(self.csv_file)
self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"])
self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
["c1", "d1", "e1"])
self.ticking_tableA = self.session.time_table("PT00:00:00.001").update(
["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])
self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
["c1", "d1", "e1"])

def tearDown(self) -> None:
self.static_tableA = None
self.static_tableB = None
self.ticking_tableA = None
self.ticking_tableB = None
super().tearDown()

def test_static_simple(self):
# Test with multiple input tables
mj_table = multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"])

# Output table is static
self.assertFalse(mj_table.table.is_refreshing)
# Output table has same # rows as sources
self.assertEqual(mj_table.table.size, self.static_tableA.size)
self.assertEqual(mj_table.table.size, self.static_tableB.size)

# Test with a single input table
mj_table = multi_join(self.static_tableA, ["a", "b"])

# Output table is static
self.assertFalse(mj_table.table.is_refreshing)
# Output table has same # rows as sources
self.assertEqual(mj_table.table.size, self.static_tableA.size)

def test_ticking_simple(self):
# Test with multiple input tables
mj_table = multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"])

# Output table is refreshing
self.assertTrue(mj_table.table.is_refreshing)

# Test with a single input table
mj_table = multi_join(input=self.ticking_tableA, on=["a", "b"])

# Output table is refreshing
self.assertTrue(mj_table.table.is_refreshing)

def test_static(self):
# Test with multiple input
mj_input = [
MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"])
]
mj_table = multi_join(mj_input)

# Output table is static
self.assertFalse(mj_table.table.is_refreshing)
# Output table has same # rows as sources
self.assertEqual(mj_table.table.size, self.static_tableA.size)
self.assertEqual(mj_table.table.size, self.static_tableB.size)

# Test with a single input
mj_table = multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1"))

# Output table is static
self.assertFalse(mj_table.table.is_refreshing)
# Output table has same # rows as sources
self.assertEqual(mj_table.table.size, self.static_tableA.size)

def test_ticking(self):
# Test with multiple input
mj_input = [
MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
]
mj_table = multi_join(mj_input)

# Output table is refreshing
self.assertTrue(mj_table.table.is_refreshing)

# Test with a single input
mj_table = multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1"))

# Output table is refreshing
self.assertTrue(mj_table.table.is_refreshing)

def test_errors(self):
# Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted).
mj_input = [
MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
]
with self.assertRaises(DHError) as cm:
mj_table = multi_join(mj_input, on=["key1=a", "key2=b"])
self.assertIn("on parameter is not permitted", str(cm.exception))

session = Session()
t = session.time_table("PT00:00:00.001").update(
["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])

# Assert the exception is raised when to-be-joined tables are not from the same session.
mj_input = [
MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
MultiJoinInput(table=t, on=["key1=a", "key2=b"], joins=["d2"])
]
with self.assertRaises(DHError) as cm:
mj_table = multi_join(mj_input)
self.assertIn("all tables must be from the same session", str(cm.exception))


if __name__ == '__main__':
unittest.main()
7 changes: 4 additions & 3 deletions py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3786,7 +3786,7 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str
table (Table): the right table to include in the join
on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equal expression,
i.e. "col_a = col_b" for different column names
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the this table to the result
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
table, can be renaming expressions, i.e. "new_col = col"; default is None

Raises:
Expand All @@ -3803,13 +3803,14 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str

class MultiJoinTable(JObjectWrapper):
"""A MultiJoinTable is an object that contains the result of a multi-table natural join. To retrieve the underlying
result Table, use the table() method. """
result Table, use the :attr:`.table` property. """
j_object_type = _JMultiJoinTable

@property
def j_object(self) -> jpy.JType:
return self.j_multijointable

@property
def table(self) -> Table:
"""Returns the Table containing the multi-table natural join output. """
return Table(j_table=self.j_multijointable.table())
Expand Down Expand Up @@ -3866,7 +3867,7 @@ def multi_join(input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[Mul

Returns:
MultiJoinTable: the result of the multi-table natural join operation. To access the underlying Table, use the
table() method.
:attr:`~MultiJoinTable.table` property.
"""
return MultiJoinTable(input, on)

Expand Down
Loading