Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More FlightEndpoint attributes #2

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ from libcpp cimport bool as c_bool
from pyarrow.lib cimport *
from pyarrow.lib import (ArrowCancelled, ArrowException, ArrowInvalid,
SignalStopHandler)
from pyarrow.lib import as_buffer, frombytes, tobytes
from pyarrow.lib import as_buffer, frombytes, timestamp, tobytes
from pyarrow.includes.chrono cimport duration_cast, microseconds
from pyarrow.includes.libarrow_flight cimport *
from pyarrow.ipc import _get_legacy_format_default, _ReadPandasMixin
import pyarrow.lib as lib
Expand Down Expand Up @@ -704,7 +705,7 @@ cdef class FlightEndpoint(_Weakrefable):
cdef:
CFlightEndpoint endpoint

def __init__(self, ticket, locations):
def __init__(self, ticket, locations, expiration_time=None, app_metadata=""):
"""Create a FlightEndpoint from a ticket and list of locations.

Parameters
Expand All @@ -713,6 +714,12 @@ cdef class FlightEndpoint(_Weakrefable):
the ticket needed to access this flight
locations : list of string URIs
locations where this flight is available
expiration_time : TimestampScalar optional, default None
Expiration time of this stream. If present, clients may assume
they can retry DoGet requests. Otherwise, clients should avoid
retrying DoGet requests.
app_metadata : bytes or str optional, default ""
Application-defined opaque metadata.

Raises
------
Expand All @@ -736,6 +743,12 @@ cdef class FlightEndpoint(_Weakrefable):
CLocation.Parse(tobytes(location)).Value(&c_location))
self.endpoint.locations.push_back(c_location)

if expiration_time is not None:
self.endpoint.expiration_time = time_point(
microseconds(expiration_time.cast(timestamp("us")).value))

self.endpoint.app_metadata = tobytes(app_metadata)

@property
def ticket(self):
"""Get the ticket in this endpoint."""
Expand All @@ -746,6 +759,24 @@ cdef class FlightEndpoint(_Weakrefable):
return [Location.wrap(location)
for location in self.endpoint.locations]

@property
def expiration_time(self):
cdef:
int64_t time_since_epoch
const char* UTC = "UTC"
shared_ptr[CTimestampType] time_type = make_shared[CTimestampType](TimeUnit.TimeUnit_MICRO, UTC)
shared_ptr[CTimestampScalar] shared
if self.endpoint.expiration_time.has_value():
time_since_epoch = duration_cast[microseconds](
self.endpoint.expiration_time.value().time_since_epoch()).count()
shared = make_shared[CTimestampScalar](time_since_epoch, time_type)
return Scalar.wrap(<shared_ptr[CScalar]> shared)
return None

@property
def app_metadata(self):
return self.endpoint.app_metadata

def serialize(self):
"""Get the wire-format representation of this type.

Expand All @@ -770,7 +801,9 @@ cdef class FlightEndpoint(_Weakrefable):

def __repr__(self):
return (f"<pyarrow.flight.FlightEndpoint ticket={self.ticket!r} "
f"locations={self.locations!r}>")
f"locations={self.locations!r} "
f"expiration_time={self.expiration_time} "
f"app_metadata='{self.app_metadata.hex()}'>")

def __eq__(self, FlightEndpoint other):
return self.endpoint == other.endpoint
Expand Down
37 changes: 37 additions & 0 deletions python/pyarrow/includes/chrono.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# distutils: language = c++

from libc.stdint cimport *


cdef extern from "<chrono>" namespace "std::chrono":
cdef cppclass duration:
Copy link

@adamreeve adamreeve Jul 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm curious why the template parameter isn't needed here, but I guess it's inferred from the constructor parameter type and/or count return type?

duration(int64_t count)
const int64_t count()

cdef cppclass microseconds(duration):
microseconds(int64_t count)

T duration_cast[T](duration d)


cdef extern from "<chrono>" namespace "std::chrono::system_clock":
cdef cppclass time_point:
time_point(const duration& d)
const duration time_since_epoch()
1 change: 1 addition & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@ cdef extern from "arrow/api.h" namespace "arrow" nogil:

cdef cppclass CTimestampType" arrow::TimestampType"(CFixedWidthType):
CTimestampType(TimeUnit unit)
CTimestampType(TimeUnit unit, const c_string& timezone)
TimeUnit unit()
const c_string& timezone()

Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from pyarrow.includes.common cimport *
from pyarrow.includes.libarrow cimport *
from pyarrow.includes.chrono cimport time_point


cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
Expand Down Expand Up @@ -134,6 +135,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:

CTicket ticket
vector[CLocation] locations
optional[time_point] expiration_time
c_string app_metadata

bint operator==(CFlightEndpoint)
CResult[c_string] SerializeToString()
Expand Down
35 changes: 30 additions & 5 deletions python/pyarrow/tests/test_flight.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,12 +251,14 @@ def get_flight_info(self, context, descriptor):
flight.FlightEndpoint(
b'',
[flight.Location.for_grpc_tcp('localhost', 5005)],
pa.scalar("2023-04-05T12:34:56.789").cast(pa.timestamp("ms")),
"endpoint app metadata"
),
],
1,
42,
True,
"test app metadata"
"info app metadata"
)

def get_schema(self, context, descriptor):
Expand Down Expand Up @@ -877,7 +879,9 @@ def test_repr():
descriptor_repr = "<pyarrow.flight.FlightDescriptor cmd=b'foo'>"
endpoint_repr = ("<pyarrow.flight.FlightEndpoint "
"ticket=<pyarrow.flight.Ticket ticket=b'foo'> "
"locations=[]>")
"locations=[] "
"expiration_time=2023-04-05 12:34:56+00:00 "
"app_metadata='656e64706f696e7420617070206d65746164617461'>")
info_repr = (
"<pyarrow.flight.FlightInfo "
"schema= "
Expand All @@ -896,7 +900,11 @@ def test_repr():
assert repr(flight.ActionType("foo", "bar")) == action_type_repr
assert repr(flight.BasicAuth("user", "pass")) == basic_auth_repr
assert repr(flight.FlightDescriptor.for_command("foo")) == descriptor_repr
assert repr(flight.FlightEndpoint(b"foo", [])) == endpoint_repr
endpoint = flight.FlightEndpoint(
b"foo", [], pa.scalar("2023-04-05T12:34:56").cast(pa.timestamp("s")),
b"endpoint app metadata"
)
assert repr(endpoint) == endpoint_repr
info = flight.FlightInfo(
pa.schema([]), flight.FlightDescriptor.for_path(), [],
1, 42, True, b"test app metadata"
Expand Down Expand Up @@ -933,6 +941,14 @@ def test_eq():
flight.FlightEndpoint(
b"foo", [flight.Location("grpc+tls://localhost:1234")])
),
lambda: (
flight.FlightEndpoint(
b"foo", [], pa.scalar("2023-04-05T12:34:56").cast(pa.timestamp("s"))),
flight.FlightEndpoint(
b"foo", [],
pa.scalar("2023-04-05T12:34:56.789").cast(pa.timestamp("ms")))),
lambda: (flight.FlightEndpoint(b"foo", [], app_metadata=b''),
flight.FlightEndpoint(b"foo", [], app_metadata=b'meta')),
lambda: (
flight.FlightInfo(
pa.schema([]),
Expand Down Expand Up @@ -1127,11 +1143,16 @@ def test_flight_get_info():
assert info.total_records == 1
assert info.total_bytes == 42
assert info.ordered
assert info.app_metadata == b"test app metadata"
assert info.app_metadata == b"info app metadata"
assert info.schema == pa.schema([('a', pa.int32())])
assert len(info.endpoints) == 2
assert len(info.endpoints[0].locations) == 1
assert info.endpoints[0].expiration_time is None
assert info.endpoints[0].app_metadata == b""
assert info.endpoints[0].locations[0] == flight.Location('grpc://test')
assert info.endpoints[1].expiration_time == \
pa.scalar("2023-04-05T12:34:56.789+00:00").cast(pa.timestamp("us", "UTC"))
assert info.endpoints[1].app_metadata == b"endpoint app metadata"
assert info.endpoints[1].locations[0] == \
flight.Location.for_grpc_tcp('localhost', 5005)

Expand Down Expand Up @@ -1742,6 +1763,8 @@ def test_roundtrip_types():
flight.FlightEndpoint(
b'',
[flight.Location.for_grpc_tcp('localhost', 5005)],
pa.scalar("2023-04-05T12:34:56").cast(pa.timestamp("ms")),
b'endpoint app metadata'
),
],
1,
Expand All @@ -1760,7 +1783,9 @@ def test_roundtrip_types():

endpoint = flight.FlightEndpoint(
ticket,
['grpc://test', flight.Location.for_grpc_tcp('localhost', 5005)]
['grpc://test', flight.Location.for_grpc_tcp('localhost', 5005)],
pa.scalar("2023-04-05T12:34:56").cast(pa.timestamp("s")),
b'endpoint app metadata'
)
assert endpoint == flight.FlightEndpoint.deserialize(endpoint.serialize())

Expand Down
Loading