-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat: Support Multijoin in the Python client #6020
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -10,7 +10,7 @@ | |
import os | ||
from random import random | ||
import threading | ||
from typing import Any, Dict, Iterable, List, Union, Tuple, NewType | ||
from typing import Any, Dict, Iterable, List, Union, Tuple, NewType, Sequence | ||
from uuid import uuid4 | ||
|
||
import grpc | ||
|
@@ -26,7 +26,8 @@ | |
from pydeephaven._input_table_service import InputTableService | ||
from pydeephaven._plugin_obj_service import PluginObjService | ||
from pydeephaven._session_service import SessionService | ||
from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp | ||
from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp, \ | ||
MultijoinTablesOp | ||
from pydeephaven._table_service import TableService | ||
from pydeephaven._utils import to_list | ||
from pydeephaven.dherror import DHError | ||
|
@@ -105,6 +106,7 @@ def random_ticket(cls) -> SharedTicket: | |
bytes_ = uuid4().int.to_bytes(16, byteorder='little', signed=False) | ||
return cls(ticket_bytes=b'h' + bytes_) | ||
|
||
|
||
_BidiRpc = NewType("_BidiRpc", grpc.StreamStreamMultiCallable) | ||
|
||
_NotBidiRpc = NewType( | ||
|
@@ -114,6 +116,7 @@ def random_ticket(cls) -> SharedTicket: | |
grpc.UnaryStreamMultiCallable, | ||
grpc.StreamUnaryMultiCallable]) | ||
|
||
|
||
class Session: | ||
"""A Session object represents a connection to the Deephaven data server. It contains a number of convenience | ||
methods for asking the server to create tables, import Arrow data into tables, merge tables, run Python scripts, and | ||
|
@@ -426,7 +429,7 @@ def _connect(self): | |
# started together don't align retries. | ||
skew = random() | ||
# Backoff schedule for retries after consecutive failures to refresh auth token | ||
self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ] | ||
self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10] | ||
|
||
if self._refresh_backoff[0] > self._timeout_seconds: | ||
raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.') | ||
|
@@ -734,3 +737,60 @@ def plugin_client(self, exportable_obj: ticket_pb2.TypedTicket) -> PluginClient: | |
|
||
Part of the experimental plugin API.""" | ||
return PluginClient(self, exportable_obj) | ||
|
||
def multi_join(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]], | ||
on: Union[str, Sequence[str]] = None) -> Table: | ||
""" The multi_join method creates a new table by performing a multi-table natural join on the input tables. The | ||
result consists of the set of distinct keys from the input tables natural joined to each input table. Input | ||
tables need not have a matching row for each key, but they may not have multiple matching rows for a given key. | ||
|
||
Args: | ||
input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the | ||
tables and columns to include in the join. | ||
on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression | ||
that matches every input table, i.e. "col_a = col_b" to rename output column names. Note: When | ||
MultiJoinInput objects are supplied, this parameter must be omitted. | ||
|
||
Returns: | ||
Table: the result of the multi-table natural join operation. | ||
|
||
Raises: | ||
DHError | ||
""" | ||
if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)): | ||
tables = to_list(input) | ||
multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables] | ||
elif isinstance(input, MultiJoinInput) or ( | ||
isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)): | ||
if on is not None: | ||
raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.") | ||
multi_join_inputs = to_list(input) | ||
else: | ||
raise DHError( | ||
message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.") | ||
|
||
table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs) | ||
return self.table_service.grpc_table_op(None, table_op, table_class=Table) | ||
|
||
|
||
class MultiJoinInput: | ||
chipkent marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""A MultiJoinInput represents the input tables, key columns and additional columns to be used in the multi-table | ||
natural join. | ||
""" | ||
table: Table | ||
on: Union[str, Sequence[str]] | ||
joins: Union[str, Sequence[str]] = None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not present on the server There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Server side is a wrapper around a Java MJI. I'm ambivalent about whether these are user-facing, but they should not be mutable. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is still an incompatibility in the API. |
||
|
||
def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str]] = None): | ||
"""Initializes a MultiJoinInput object. | ||
|
||
Args: | ||
table (Table): the right table to include in the join | ||
on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression that | ||
matches every input table, i.e. "col_a = col_b" to rename output column names. | ||
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result | ||
table, can be renaming expressions, i.e. "new_col = col"; default is None | ||
""" | ||
self.table = table | ||
self.on = to_list(on) | ||
self.joins = to_list(joins) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
# | ||
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending | ||
# | ||
import unittest | ||
|
||
from pyarrow import csv | ||
|
||
from pydeephaven import DHError | ||
from pydeephaven.session import MultiJoinInput | ||
from tests.testbase import BaseTestCase | ||
|
||
|
||
class MultiJoinTestCase(BaseTestCase): | ||
def setUp(self): | ||
super().setUp() | ||
pa_table = csv.read_csv(self.csv_file) | ||
self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"]) | ||
self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns( | ||
["c1", "d1", "e1"]) | ||
self.ticking_tableA = self.session.time_table("PT00:00:00.001").update( | ||
["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"]) | ||
self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns( | ||
["c1", "d1", "e1"]) | ||
|
||
def tearDown(self) -> None: | ||
self.static_tableA = None | ||
self.static_tableB = None | ||
self.ticking_tableA = None | ||
self.ticking_tableB = None | ||
super().tearDown() | ||
|
||
def test_static_simple(self): | ||
# Test with multiple input tables | ||
mj_table = self.session.multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"]) | ||
|
||
# Output table is static | ||
self.assertFalse(mj_table.is_refreshing) | ||
# Output table has same # rows as sources | ||
self.assertEqual(mj_table.size, self.static_tableA.size) | ||
self.assertEqual(mj_table.size, self.static_tableB.size) | ||
|
||
# Test with a single input table | ||
mj_table = self.session.multi_join(self.static_tableA, ["a", "b"]) | ||
|
||
# Output table is static | ||
self.assertFalse(mj_table.is_refreshing) | ||
# Output table has same # rows as sources | ||
self.assertEqual(mj_table.size, self.static_tableA.size) | ||
|
||
def test_ticking_simple(self): | ||
# Test with multiple input tables | ||
mj_table = self.session.multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"]) | ||
|
||
# Output table is refreshing | ||
self.assertTrue(mj_table.is_refreshing) | ||
|
||
# Test with a single input table | ||
mj_table = self.session.multi_join(input=self.ticking_tableA, on=["a", "b"]) | ||
|
||
# Output table is refreshing | ||
self.assertTrue(mj_table.is_refreshing) | ||
|
||
def test_static(self): | ||
# Test with multiple input | ||
mj_input = [ | ||
MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), | ||
MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"]) | ||
] | ||
mj_table = self.session.multi_join(mj_input) | ||
|
||
# Output table is static | ||
self.assertFalse(mj_table.is_refreshing) | ||
# Output table has same # rows as sources | ||
self.assertEqual(mj_table.size, self.static_tableA.size) | ||
self.assertEqual(mj_table.size, self.static_tableB.size) | ||
|
||
# Test with a single input | ||
mj_table = self.session.multi_join(MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1")) | ||
|
||
# Output table is static | ||
self.assertFalse(mj_table.is_refreshing) | ||
# Output table has same # rows as sources | ||
self.assertEqual(mj_table.size, self.static_tableA.size) | ||
|
||
def test_ticking(self): | ||
# Test with multiple input | ||
mj_input = [ | ||
MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), | ||
MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) | ||
] | ||
mj_table = self.session.multi_join(mj_input) | ||
|
||
# Output table is refreshing | ||
self.assertTrue(mj_table.is_refreshing) | ||
|
||
# Test with a single input | ||
mj_table = self.session.multi_join(input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1")) | ||
|
||
# Output table is refreshing | ||
self.assertTrue(mj_table.is_refreshing) | ||
|
||
def test_errors(self): | ||
# Assert the exception is raised when providing MultiJoinInput and the on parameter is not None (omitted). | ||
mj_input = [ | ||
MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]), | ||
MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"]) | ||
] | ||
with self.assertRaises(DHError): | ||
mj_table = self.session.multi_join(mj_input, on=["key1=a", "key2=b"]) | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Users have very very strongly wanted the server and client APIs to be the same. This has a number of deviations compared to the server that I can see at a glance.
multi_join
is a method on session instead of just being a function.multi_join
is in a different submodule (session
instead oftable
).Table
instead ofMultiJoinTable
.How similar can the APIs be?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Other clients have taken the approach of returning Table for now, instead of MultiJoinTable. MJT is more of a future-looking API, and given that its potential (adding more joins to an existing MJ and thus getting out a new Table) hasn't been realized yet there's no real benefit to building an object type plugin for MJT.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have moved the new method to table.py as a function and have also relocated the
MultiJoinInput
class there as well.@chipkent I totally understand the desire to have the interface be the same for the two APIs, but the client isn't limited to working with just one server and making a 'session' level function a method on the Session object has the benefit of clear isolation/readability etc. So this change has a trade-off which I am quite ambivalent about.