Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Multijoin in the Python client #6020

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions py/client/pydeephaven/_table_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -687,3 +687,26 @@ def make_grpc_request(self, result_id, source_id) -> Any:
def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
return table_pb2.BatchTableRequest.Operation(
meta_table=self.make_grpc_request(result_id=result_id, source_id=source_id))


class MultijoinTablesOp(TableOp):
    """Table operation that requests a multi-table join over a list of MultiJoinInput objects."""

    def __init__(self, multi_join_inputs: List["MultiJoinInput"]):
        """Stores the inputs; each one supplies its table plus the `on` and `joins` column lists."""
        self.multi_join_inputs = multi_join_inputs

    @classmethod
    def get_stub_func(cls, table_service_stub: table_pb2_grpc.TableServiceStub) -> Any:
        """Returns the MultiJoinTables RPC of the given table service stub."""
        return table_service_stub.MultiJoinTables

    def make_grpc_request(self, result_id, source_id) -> Any:
        """Builds the MultiJoinTablesRequest for this operation.

        The `source_id` argument is not used: every input carries its own table, whose ticket is wrapped in a
        TableReference.
        """
        pb_inputs = [
            table_pb2.MultiJoinInput(
                source_id=table_pb2.TableReference(ticket=mji.table.ticket),
                columns_to_match=mji.on,
                columns_to_add=mji.joins,
            )
            for mji in self.multi_join_inputs
        ]
        return table_pb2.MultiJoinTablesRequest(result_id=result_id, multi_join_inputs=pb_inputs)

    def make_grpc_request_for_batch(self, result_id, source_id) -> Any:
        """Wraps the single-call request in a batch operation under the `multi_join` field."""
        request = self.make_grpc_request(result_id=result_id, source_id=source_id)
        return table_pb2.BatchTableRequest.Operation(multi_join=request)
4 changes: 2 additions & 2 deletions py/client/pydeephaven/_table_service.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
from typing import Union, List
from typing import Union, List, Optional

from pydeephaven._batch_assembler import BatchOpAssembler
from pydeephaven._table_ops import TableOp
Expand Down Expand Up @@ -37,7 +37,7 @@ def batch(self, ops: List[TableOp]) -> Table:
except Exception as e:
raise DHError("failed to finish the table batch operation.") from e

def grpc_table_op(self, table: Table, op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
def grpc_table_op(self, table: Optional[Table], op: TableOp, table_class: type = Table) -> Union[Table, InputTable]:
"""Makes a single gRPC Table operation call and returns a new Table."""
try:
result_id = self.session.make_ticket()
Expand Down
66 changes: 63 additions & 3 deletions py/client/pydeephaven/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import os
from random import random
import threading
from typing import Any, Dict, Iterable, List, Union, Tuple, NewType
from typing import Any, Dict, Iterable, List, Union, Tuple, NewType, Sequence
from uuid import uuid4

import grpc
Expand All @@ -26,7 +26,8 @@
from pydeephaven._input_table_service import InputTableService
from pydeephaven._plugin_obj_service import PluginObjService
from pydeephaven._session_service import SessionService
from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp
from pydeephaven._table_ops import TimeTableOp, EmptyTableOp, MergeTablesOp, FetchTableOp, CreateInputTableOp, \
MultijoinTablesOp
from pydeephaven._table_service import TableService
from pydeephaven._utils import to_list
from pydeephaven.dherror import DHError
Expand Down Expand Up @@ -105,6 +106,7 @@ def random_ticket(cls) -> SharedTicket:
bytes_ = uuid4().int.to_bytes(16, byteorder='little', signed=False)
return cls(ticket_bytes=b'h' + bytes_)


_BidiRpc = NewType("_BidiRpc", grpc.StreamStreamMultiCallable)

_NotBidiRpc = NewType(
Expand All @@ -114,6 +116,7 @@ def random_ticket(cls) -> SharedTicket:
grpc.UnaryStreamMultiCallable,
grpc.StreamUnaryMultiCallable])


class Session:
"""A Session object represents a connection to the Deephaven data server. It contains a number of convenience
methods for asking the server to create tables, import Arrow data into tables, merge tables, run Python scripts, and
Expand Down Expand Up @@ -426,7 +429,7 @@ def _connect(self):
# started together don't align retries.
skew = random()
# Backoff schedule for retries after consecutive failures to refresh auth token
self._refresh_backoff = [ skew + 0.1, skew + 1, skew + 10 ]
self._refresh_backoff = [skew + 0.1, skew + 1, skew + 10]

if self._refresh_backoff[0] > self._timeout_seconds:
raise DHError(f'server configuration http.session.durationMs={session_duration} is too small.')
Expand Down Expand Up @@ -734,3 +737,60 @@ def plugin_client(self, exportable_obj: ticket_pb2.TypedTicket) -> PluginClient:

Part of the experimental plugin API."""
return PluginClient(self, exportable_obj)

def multi_join(self, input: Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Users have very strongly wanted the server and client APIs to be the same. At a glance, I can see a number of deviations from the server API here.

  1. multi_join is a method on session instead of just being a function.
  2. multi_join is in a different submodule (session instead of table).
  3. The return type is Table instead of MultiJoinTable.

How similar can the APIs be?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Other clients have taken the approach of returning Table for now, instead of MultiJoinTable. MJT is more of a forward-looking API, and given that its potential (adding more joins to an existing MJ and thus getting out a new Table) hasn't been realized yet, there's no real benefit to building an object-type plugin for MJT.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have moved the new method to table.py as a function and have relocated the MultiJoinInput class there as well.

@chipkent I totally understand the desire to have the interface be the same for the two APIs, but the client isn't limited to working with just one server and making a 'session' level function a method on the Session object has the benefit of clear isolation/readability etc. So this change has a trade-off which I am quite ambivalent about.

on: Union[str, Sequence[str]] = None) -> Table:
""" The multi_join method creates a new table by performing a multi-table natural join on the input tables. The
result consists of the set of distinct keys from the input tables natural joined to each input table. Input
tables need not have a matching row for each key, but they may not have multiple matching rows for a given key.

Args:
input (Union[Table, Sequence[Table], MultiJoinInput, Sequence[MultiJoinInput]]): the input objects specifying the
tables and columns to include in the join.
on (Union[str, Sequence[str]], optional): the column(s) to match, can be a common name or an equality expression
that matches every input table, i.e. "col_a = col_b" to rename output column names. Note: When
MultiJoinInput objects are supplied, this parameter must be omitted.

Returns:
Table: the result of the multi-table natural join operation.

Raises:
DHError
"""
if isinstance(input, Table) or (isinstance(input, Sequence) and all(isinstance(t, Table) for t in input)):
tables = to_list(input)
multi_join_inputs = [MultiJoinInput(table=t, on=on) for t in tables]
elif isinstance(input, MultiJoinInput) or (
isinstance(input, Sequence) and all(isinstance(ji, MultiJoinInput) for ji in input)):
if on is not None:
raise DHError(message="on parameter is not permitted when MultiJoinInput objects are provided.")
multi_join_inputs = to_list(input)
else:
raise DHError(
message="input must be a Table, a sequence of Tables, a MultiJoinInput, or a sequence of MultiJoinInputs.")

table_op = MultijoinTablesOp(multi_join_inputs=multi_join_inputs)
return self.table_service.grpc_table_op(None, table_op, table_class=Table)


class MultiJoinInput:
chipkent marked this conversation as resolved.
Show resolved Hide resolved
"""A MultiJoinInput represents the input tables, key columns and additional columns to be used in the multi-table
natural join.
"""
table: Table
on: Union[str, Sequence[str]]
joins: Union[str, Sequence[str]] = None
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not present on the server MultiJoinInput. Should they be made private? Should they be added to the server side?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Server side is a wrapper around a Java MJI. I'm ambivalent about whether these are user-facing, but they should not be mutable.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is still an incompatibility in the API.


def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str, Sequence[str], None] = None):
    """Initializes a MultiJoinInput object.

    Args:
        table (Table): the right table to include in the join
        on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equality expression
            that matches every input table, e.g. "col_a = col_b" to rename output column names.
        joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
            table, can be renaming expressions, e.g. "new_col = col"; default is None
    """
    self.table = table
    # Both column arguments are passed through to_list before being stored.
    # NOTE(review): presumably to_list(None) yields an empty list - confirm against pydeephaven._utils.
    self.on = to_list(on)
    self.joins = to_list(joins)
113 changes: 113 additions & 0 deletions py/client/tests/test_multijoin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#
# Copyright (c) 2016-2024 Deephaven Data Labs and Patent Pending
#
import unittest

from pyarrow import csv

from pydeephaven import DHError
from pydeephaven.session import MultiJoinInput
from tests.testbase import BaseTestCase


class MultiJoinTestCase(BaseTestCase):
    """Exercises Session.multi_join with plain tables and with MultiJoinInput objects, static and ticking."""

    def setUp(self):
        """Builds two static and two ticking tables that share the key columns `a` and `b`."""
        super().setUp()
        pa_table = csv.read_csv(self.csv_file)
        self.static_tableA = self.session.import_table(pa_table).select(["a", "b", "c1=c", "d1=d", "e1=e"])
        self.static_tableB = self.static_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
            ["c1", "d1", "e1"])
        self.ticking_tableA = self.session.time_table("PT00:00:00.001").update(
            ["a = i", "b = i*i % 13", "c1 = i * 13 % 23", "d1 = a + b", "e1 = a - b"]).drop_columns(["Timestamp"])
        self.ticking_tableB = self.ticking_tableA.update(["c2=c1+1", "d2=d1+2", "e2=e1+3"]).drop_columns(
            ["c1", "d1", "e1"])

    def tearDown(self) -> None:
        """Drops the table references before the base class tears the session down."""
        self.static_tableA = None
        self.static_tableB = None
        self.ticking_tableA = None
        self.ticking_tableB = None
        super().tearDown()

    def test_static_simple(self):
        """Joins static tables passed directly, with the key columns given via `on`."""
        # Test with multiple input tables
        mj_table = self.session.multi_join(input=[self.static_tableA, self.static_tableB], on=["a", "b"])

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)
        self.assertEqual(mj_table.size, self.static_tableB.size)

        # Test with a single input table
        mj_table = self.session.multi_join(self.static_tableA, ["a", "b"])

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)

    def test_ticking_simple(self):
        """Joins ticking tables passed directly; the result must be refreshing."""
        # Test with multiple input tables
        mj_table = self.session.multi_join(input=[self.ticking_tableA, self.ticking_tableB], on=["a", "b"])

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

        # Test with a single input table
        mj_table = self.session.multi_join(input=self.ticking_tableA, on=["a", "b"])

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

    def test_static(self):
        """Joins static tables described by MultiJoinInput objects."""
        # Test with multiple input
        mj_input = [
            MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.static_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        mj_table = self.session.multi_join(mj_input)

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)
        self.assertEqual(mj_table.size, self.static_tableB.size)

        # Test with a single input
        mj_table = self.session.multi_join(
            MultiJoinInput(table=self.static_tableA, on=["key1=a", "key2=b"], joins="c1"))

        # Output table is static
        self.assertFalse(mj_table.is_refreshing)
        # Output table has same # rows as sources
        self.assertEqual(mj_table.size, self.static_tableA.size)

    def test_ticking(self):
        """Joins ticking tables described by MultiJoinInput objects; the result must be refreshing."""
        # Test with multiple input
        mj_input = [
            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        mj_table = self.session.multi_join(mj_input)

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

        # Test with a single input
        mj_table = self.session.multi_join(
            input=MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins="c1"))

        # Output table is refreshing
        self.assertTrue(mj_table.is_refreshing)

    def test_errors(self):
        """Checks the client-side argument validation of multi_join."""
        # Supplying `on` together with MultiJoinInput objects is rejected.
        mj_input = [
            MultiJoinInput(table=self.ticking_tableA, on=["key1=a", "key2=b"], joins=["c1", "e1"]),
            MultiJoinInput(table=self.ticking_tableB, on=["key1=a", "key2=b"], joins=["d2"])
        ]
        with self.assertRaises(DHError):
            self.session.multi_join(mj_input, on=["key1=a", "key2=b"])

        # An input that is neither a Table nor a MultiJoinInput (nor a sequence of either) is rejected.
        with self.assertRaises(DHError):
            self.session.multi_join(input=42, on=["a", "b"])


if __name__ == '__main__':
unittest.main()
2 changes: 1 addition & 1 deletion py/server/deephaven/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -3782,7 +3782,7 @@ def __init__(self, table: Table, on: Union[str, Sequence[str]], joins: Union[str
table (Table): the right table to include in the join
on (Union[str, Sequence[str]]): the column(s) to match, can be a common name or an equal expression,
i.e. "col_a = col_b" for different column names
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the this table to the result
joins (Union[str, Sequence[str]], optional): the column(s) to be added from the table to the result
table, can be renaming expressions, i.e. "new_col = col"; default is None

Raises:
Expand Down