From 5457b54a68f6408983d6f9580fa1982022cfe849 Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Tue, 17 Dec 2024 15:44:45 +0100 Subject: [PATCH 1/2] mypy: make quixstreams.core.* pass type checks --- pyproject.toml | 4 +- quixstreams/context.py | 9 +- quixstreams/core/stream/functions/apply.py | 61 +++++-- quixstreams/core/stream/functions/base.py | 3 + quixstreams/core/stream/functions/filter.py | 22 ++- .../core/stream/functions/transform.py | 20 ++- quixstreams/core/stream/functions/update.py | 10 +- quixstreams/core/stream/stream.py | 108 +++++++++-- quixstreams/dataframe/base.py | 9 +- quixstreams/dataframe/dataframe.py | 168 ++++++++++++------ quixstreams/dataframe/registry.py | 2 +- .../test_dataframe/test_dataframe.py | 2 +- 12 files changed, 322 insertions(+), 96 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d4047f41d..6d170dfc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,8 +123,8 @@ ignore_errors = true [[tool.mypy.overrides]] module = [ - "quixstreams.core.*", - "quixstreams.dataframe.*", + "quixstreams.dataframe.series.*", + "quixstreams.dataframe.windows.*", "quixstreams.rowproducer.*" ] ignore_errors = true diff --git a/quixstreams/context.py b/quixstreams/context.py index 2de9d3be9..18286794f 100644 --- a/quixstreams/context.py +++ b/quixstreams/context.py @@ -50,7 +50,7 @@ def alter_context(value): _current_message_context.set(context) -def message_context() -> Optional[MessageContext]: +def message_context() -> MessageContext: """ Get a MessageContext for the current message, which houses most of the message metadata, like: @@ -75,6 +75,11 @@ def message_context() -> Optional[MessageContext]: :return: instance of `MessageContext` """ try: - return _current_message_context.get() + ctx = _current_message_context.get() except LookupError: raise MessageContextNotSetError("Message context is not set") + + if ctx is None: + raise MessageContextNotSetError("Message context is not set") + + return ctx diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py index 254caca2b..bdf493953 100644 --- a/quixstreams/core/stream/functions/apply.py +++ b/quixstreams/core/stream/functions/apply.py @@ -1,7 +1,13 @@ -from typing import Any +from typing import Any, Literal, Union, overload from .base import StreamFunction -from .types import ApplyCallback, ApplyWithMetadataCallback, VoidExecutor +from .types import ( + ApplyCallback, + ApplyExpandedCallback, + ApplyWithMetadataCallback, + ApplyWithMetadataExpandedCallback, + VoidExecutor, +) __all__ = ("ApplyFunction", "ApplyWithMetadataFunction") @@ -14,22 +20,34 @@ class ApplyFunction(StreamFunction): and its result will always be passed downstream. """ + @overload + def __init__(self, func: ApplyCallback, expand: Literal[False] = False) -> None: ... + + @overload + def __init__(self, func: ApplyExpandedCallback, expand: Literal[True]) -> None: ... + def __init__( self, - func: ApplyCallback, + func: Union[ApplyCallback, ApplyExpandedCallback], expand: bool = False, ): super().__init__(func) + + self.func: Union[ApplyCallback, ApplyExpandedCallback] self.expand = expand def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) + func = self.func if self.expand: def wrapper( - value: Any, key: Any, timestamp: int, headers: Any, func=self.func - ): + value: Any, + key: Any, + timestamp: int, + headers: Any, + ) -> None: # Execute a function on a single value and wrap results into a list # to expand them downstream result = func(value) @@ -39,8 +57,11 @@ def wrapper( else: def wrapper( - value: Any, key: Any, timestamp: int, headers: Any, func=self.func - ): + value: Any, + key: Any, + timestamp: int, + headers: Any, + ) -> None: # Execute a function on a single value and return its result result = func(value) child_executor(result, key, timestamp, headers) @@ -57,20 +78,37 @@ class ApplyWithMetadataFunction(StreamFunction): and its result will always be passed downstream. """ + @overload + def __init__( + self, func: ApplyWithMetadataCallback, expand: Literal[False] = False + ) -> None: ... + + @overload + def __init__( + self, func: ApplyWithMetadataExpandedCallback, expand: Literal[True] + ) -> None: ... + def __init__( self, - func: ApplyWithMetadataCallback, + func: Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback], expand: bool = False, ): super().__init__(func) + + self.func: Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback] self.expand = expand def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) + func = self.func + if self.expand: def wrapper( - value: Any, key: Any, timestamp: int, headers: Any, func=self.func + value: Any, + key: Any, + timestamp: int, + headers: Any, ): # Execute a function on a single value and wrap results into a list # to expand them downstream @@ -81,7 +119,10 @@ def wrapper( else: def wrapper( - value: Any, key: Any, timestamp: int, headers: Any, func=self.func + value: Any, + key: Any, + timestamp: int, + headers: Any, ): # Execute a function on a single value and return its result result = func(value, key, timestamp, headers) diff --git a/quixstreams/core/stream/functions/base.py b/quixstreams/core/stream/functions/base.py index 032055207..24438ef13 100644 --- a/quixstreams/core/stream/functions/base.py +++ b/quixstreams/core/stream/functions/base.py @@ -38,6 +38,9 @@ def _resolve_branching(self, *child_executors: VoidExecutor) -> VoidExecutor: If there's only one executor - copying is not neccessary, and the executor is returned as is. """ + if not child_executors: + raise TypeError("At least one executor is required") + if len(child_executors) > 1: def wrapper( diff --git a/quixstreams/core/stream/functions/filter.py b/quixstreams/core/stream/functions/filter.py index bd7a96a3a..e291880c7 100644 --- a/quixstreams/core/stream/functions/filter.py +++ b/quixstreams/core/stream/functions/filter.py @@ -17,11 +17,18 @@ class FilterFunction(StreamFunction): def __init__(self, func: FilterCallback): super().__init__(func) + self.func: FilterCallback def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) - - def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func): + func = self.func + + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + ): # Filter a single value if func(value): child_executor(value, key, timestamp, headers) @@ -42,11 +49,18 @@ class FilterWithMetadataFunction(StreamFunction): def __init__(self, func: FilterWithMetadataCallback): super().__init__(func) + self.func: FilterWithMetadataCallback def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) - - def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func): + func = self.func + + def wrapper( + value: Any, + key: Any, + timestamp: int, + headers: Any, + ): # Filter a single value if func(value, key, timestamp, headers): child_executor(value, key, timestamp, headers) diff --git a/quixstreams/core/stream/functions/transform.py b/quixstreams/core/stream/functions/transform.py index dae7872dc..219662b6b 100644 --- a/quixstreams/core/stream/functions/transform.py +++ b/quixstreams/core/stream/functions/transform.py @@ -1,4 +1,4 @@ -from typing import Any, Union +from typing import Any, Literal, Union, cast, overload from .base import StreamFunction from .types import TransformCallback, TransformExpandedCallback, VoidExecutor @@ -21,38 +21,50 @@ class TransformFunction(StreamFunction): The result of the callback will always be passed downstream. """ + @overload + def __init__( + self, func: TransformCallback, expand: Literal[False] = False + ) -> None: ... + + @overload + def __init__( + self, func: TransformExpandedCallback, expand: Literal[True] + ) -> None: ... + def __init__( self, func: Union[TransformCallback, TransformExpandedCallback], expand: bool = False, ): super().__init__(func) + + self.func: Union[TransformCallback, TransformExpandedCallback] self.expand = expand def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) if self.expand: + expanded_func = cast(TransformExpandedCallback, self.func) def wrapper( value: Any, key: Any, timestamp: int, headers: Any, - func: TransformExpandedCallback = self.func, ): - result = func(value, key, timestamp, headers) + result = expanded_func(value, key, timestamp, headers) for new_value, new_key, new_timestamp, new_headers in result: child_executor(new_value, new_key, new_timestamp, new_headers) else: + func = cast(TransformCallback, self.func) def wrapper( value: Any, key: Any, timestamp: int, headers: Any, - func: TransformCallback = self.func, ): # Execute a function on a single value and return its result new_value, new_key, new_timestamp, new_headers = func( diff --git a/quixstreams/core/stream/functions/update.py b/quixstreams/core/stream/functions/update.py index 373e634e6..b2d9a19bc 100644 --- a/quixstreams/core/stream/functions/update.py +++ b/quixstreams/core/stream/functions/update.py @@ -20,10 +20,13 @@ class UpdateFunction(StreamFunction): def __init__(self, func: UpdateCallback): super().__init__(func) + self.func: UpdateCallback + def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) + func = self.func - def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func): + def wrapper(value: Any, key: Any, timestamp: int, headers: Any): # Update a single value and forward it func(value) child_executor(value, key, timestamp, headers) @@ -45,10 +48,13 @@ class UpdateWithMetadataFunction(StreamFunction): def __init__(self, func: UpdateWithMetadataCallback): super().__init__(func) + self.func: UpdateWithMetadataCallback + def get_executor(self, *child_executors: VoidExecutor) -> VoidExecutor: child_executor = self._resolve_branching(*child_executors) + func = self.func - def wrapper(value: Any, key: Any, timestamp: int, headers: Any, func=self.func): + def wrapper(value: Any, key: Any, timestamp: int, headers: Any): # Update a single value and forward it func(value, key, timestamp, headers) child_executor(value, key, timestamp, headers) diff --git a/quixstreams/core/stream/stream.py b/quixstreams/core/stream/stream.py index e7708fd36..947fac6ad 100644 --- a/quixstreams/core/stream/stream.py +++ b/quixstreams/core/stream/stream.py @@ -3,7 +3,18 @@ import functools import itertools from time import monotonic_ns -from typing import Any, Callable, List, Optional, Union +from typing import ( + Any, + Callable, + Deque, + List, + Literal, + Optional, + Set, + Union, + cast, + overload, +) from typing_extensions import Self @@ -84,7 +95,7 @@ def __init__( self.func = func if func is not None else ApplyFunction(lambda value: value) self.parent = parent - self.children = set() + self.children: Set[Self] = set() self.generated = monotonic_ns() self.pruned = False @@ -101,6 +112,14 @@ def __repr__(self) -> str: ) return f"<{self.__class__.__name__} [{len(tree_funcs)}]: {funcs_repr}>" + @overload + def add_filter(self, func: FilterCallback, *, metadata: Literal[False]): + pass + + @overload + def add_filter(self, func: FilterWithMetadataCallback, *, metadata: Literal[True]): + pass + def add_filter( self, func: Union[FilterCallback, FilterWithMetadataCallback], @@ -121,11 +140,49 @@ def add_filter( :return: a new `Stream` derived from the current one """ if metadata: - filter_func = FilterWithMetadataFunction(func) + filter_func: StreamFunction = FilterWithMetadataFunction( + cast(FilterWithMetadataCallback, func) + ) else: - filter_func = FilterFunction(func) + filter_func = FilterFunction(cast(FilterCallback, func)) return self._add(filter_func) + @overload + def add_apply( + self, func: ApplyCallback, *, expand: Literal[False], metadata: Literal[False] + ): + pass + + @overload + def add_apply( + self, + func: ApplyExpandedCallback, + *, + expand: Literal[True], + metadata: Literal[False], + ): + pass + + @overload + def add_apply( + self, + func: ApplyWithMetadataCallback, + *, + expand: Literal[False], + metadata: Literal[True], + ): + pass + + @overload + def add_apply( + self, + func: ApplyWithMetadataExpandedCallback, + *, + expand: Literal[True], + metadata: Literal[True], + ): + pass + def add_apply( self, func: Union[ @@ -154,11 +211,19 @@ def add_apply( :return: a new `Stream` derived from the current one """ if metadata: - apply_func = ApplyWithMetadataFunction(func, expand=expand) + apply_func: StreamFunction = ApplyWithMetadataFunction(func, expand=expand) # type: ignore[call-overload] else: - apply_func = ApplyFunction(func, expand=expand) + apply_func = ApplyFunction(func, expand=expand) # type: ignore[call-overload] return self._add(apply_func) + @overload + def add_update(self, func: UpdateCallback, *, metadata: Literal[False]): + pass + + @overload + def add_update(self, func: UpdateWithMetadataCallback, *, metadata: Literal[True]): + pass + def add_update( self, func: Union[UpdateCallback, UpdateWithMetadataCallback], @@ -178,11 +243,21 @@ def add_update( :return: a new Stream derived from the current one """ if metadata: - update_func = UpdateWithMetadataFunction(func) + update_func: StreamFunction = UpdateWithMetadataFunction( + cast(UpdateWithMetadataCallback, func) + ) else: - update_func = UpdateFunction(func) + update_func = UpdateFunction(cast(UpdateCallback, func)) return self._add(update_func) + @overload + def add_transform(self, func: TransformCallback, *, expand: Literal[False]): + pass + + @overload + def add_transform(self, func: TransformExpandedCallback, *, expand: Literal[True]): + pass + def add_transform( self, func: Union[TransformCallback, TransformExpandedCallback], @@ -205,10 +280,9 @@ def add_transform( Default - `False`. :return: a new Stream derived from the current one """ + return self._add(TransformFunction(func, expand=expand)) # type: ignore[call-overload] - return self._add(TransformFunction(func, expand=expand)) - - def diff(self, other: "Stream") -> Self: + def diff(self, other: Self) -> Self: """ Takes the difference between Streams `self` and `other` based on their last common parent, and returns a new, independent `Stream` that includes only @@ -267,6 +341,10 @@ def diff(self, other: "Stream") -> Self: node.parent = parent parent = node self._prune(diff[0]) + + if parent is None: + raise InvalidOperation("No common parent found") + return parent def root_path(self, allow_splits=True) -> List[Self]: @@ -369,7 +447,7 @@ def compose_returning(self) -> ReturningExecutor: # after executing the Stream. # The composed stream must have only the "apply" functions, # which always return a single. - buffer = collections.deque(maxlen=1) + buffer: Deque[tuple[Any, Any, int, Any]] = collections.deque(maxlen=1) composed = self.compose( allow_filters=False, allow_expands=False, @@ -394,12 +472,12 @@ def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: def _compose( self, tree: List[Self], - composed: List[Callable[[Any, Any, int, Any], None]], + composed: Union[VoidExecutor, List[VoidExecutor]], allow_filters: bool, allow_updates: bool, allow_expands: bool, allow_transforms: bool, - ) -> VoidExecutor: + ) -> Union[VoidExecutor, List[VoidExecutor]]: functions = [node.func for node in tree] # Iterate over a reversed list of functions @@ -461,7 +539,7 @@ def _prune(self, other: Self): if other.pruned: raise ValueError("Node has already been pruned") other.pruned = True - node = self + node: Optional[Self] = self while node: if other in node.children: node.children.remove(other) diff --git a/quixstreams/dataframe/base.py b/quixstreams/dataframe/base.py index 4fcca213c..ede71e663 100644 --- a/quixstreams/dataframe/base.py +++ b/quixstreams/dataframe/base.py @@ -11,9 +11,14 @@ class BaseStreaming: def stream(self) -> Stream: ... @abc.abstractmethod - def compose(self, *args, **kwargs) -> VoidExecutor: ... + def compose(self, *args, **kwargs) -> dict[str, VoidExecutor]: ... @abc.abstractmethod def test( - self, value: Any, key: Any, timestamp: int, ctx: Optional[MessageContext] = None + self, + value: Any, + key: Any, + timestamp: int, + headers: Optional[Any] = None, + ctx: Optional[MessageContext] = None, ) -> Any: ... diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index cf4fad9b3..115570c07 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -26,7 +26,9 @@ ) from quixstreams.core.stream import ( ApplyCallback, + ApplyExpandedCallback, ApplyWithMetadataCallback, + ApplyWithMetadataExpandedCallback, FilterCallback, FilterWithMetadataCallback, Stream, @@ -140,16 +142,14 @@ def stream(self) -> Stream: def topic(self) -> Topic: return self._topic - @overload - def apply(self, func: ApplyCallback, *, expand: bool = ...) -> Self: ... - @overload def apply( self, - func: ApplyWithMetadataCallback, + func: Union[ApplyCallback, ApplyExpandedCallback], *, - metadata: Literal[True], - expand: bool = ..., + stateful: Literal[False] = False, + expand: Union[Literal[False], Literal[True]] = False, + metadata: Literal[False] = False, ) -> Self: ... @overload @@ -158,7 +158,18 @@ def apply( func: ApplyCallbackStateful, *, stateful: Literal[True], - expand: bool = ..., + expand: Union[Literal[False], Literal[True]] = False, + metadata: Literal[False] = False, + ) -> Self: ... + + @overload + def apply( + self, + func: Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback], + *, + stateful: Literal[False] = False, + expand: Union[Literal[False], Literal[True]] = False, + metadata: Literal[True], ) -> Self: ... @overload @@ -167,16 +178,18 @@ def apply( func: ApplyWithMetadataCallbackStateful, *, stateful: Literal[True], + expand: Union[Literal[False], Literal[True]] = False, metadata: Literal[True], - expand: bool = ..., ) -> Self: ... def apply( self, func: Union[ ApplyCallback, + ApplyExpandedCallback, ApplyCallbackStateful, ApplyWithMetadataCallback, + ApplyWithMetadataExpandedCallback, ApplyWithMetadataCallbackStateful, ], *, @@ -220,35 +233,51 @@ def func(d: dict, state: State): if stateful: self._register_store() # Force the callback to accept metadata - with_metadata_func = ( - cast(ApplyWithMetadataCallbackStateful, func) - if metadata - else _as_metadata_func(cast(ApplyCallbackStateful, func)) - ) + if metadata: + with_metadata_func = cast(ApplyWithMetadataCallbackStateful, func) + else: + with_metadata_func = _as_metadata_func( + cast(ApplyCallbackStateful, func) + ) + stateful_func = _as_stateful( func=with_metadata_func, processing_context=self._processing_context, ) - stream = self.stream.add_apply(stateful_func, expand=expand, metadata=True) + stream = self.stream.add_apply(stateful_func, expand=expand, metadata=True) # type: ignore[call-overload] else: stream = self.stream.add_apply( cast(Union[ApplyCallback, ApplyWithMetadataCallback], func), expand=expand, metadata=metadata, - ) + ) # type: ignore[call-overload] return self.__dataframe_clone__(stream=stream) @overload - def update(self, func: UpdateCallback) -> Self: ... + def update( + self, + func: UpdateCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[False] = False, + ) -> Self: ... @overload def update( - self, func: UpdateWithMetadataCallback, *, metadata: Literal[True] + self, + func: UpdateCallbackStateful, + *, + stateful: Literal[True], + metadata: Literal[False] = False, ) -> Self: ... @overload def update( - self, func: UpdateCallbackStateful, *, stateful: Literal[True] + self, + func: UpdateWithMetadataCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[True], ) -> Self: ... @overload @@ -312,18 +341,18 @@ def func(values: list, state: State): if stateful: self._register_store() # Force the callback to accept metadata - with_metadata_func = ( - func - if metadata - else _as_metadata_func(cast(UpdateCallbackStateful, func)) - ) + if metadata: + with_metadata_func = cast(UpdateWithMetadataCallbackStateful, func) + else: + with_metadata_func = _as_metadata_func( + cast(UpdateCallbackStateful, func) + ) + stateful_func = _as_stateful( - func=cast(UpdateWithMetadataCallbackStateful, with_metadata_func), + func=with_metadata_func, processing_context=self._processing_context, ) - return self._add_update( - cast(UpdateWithMetadataCallback, stateful_func), metadata=True - ) + return self._add_update(stateful_func, metadata=True) else: return self._add_update( cast(Union[UpdateCallback, UpdateWithMetadataCallback], func), @@ -331,16 +360,30 @@ def func(values: list, state: State): ) @overload - def filter(self, func: FilterCallback) -> Self: ... + def filter( + self, + func: FilterCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[False] = False, + ) -> Self: ... @overload def filter( - self, func: FilterWithMetadataCallback, *, metadata: Literal[True] + self, + func: FilterCallbackStateful, + *, + stateful: Literal[True], + metadata: Literal[False] = False, ) -> Self: ... @overload def filter( - self, func: FilterCallbackStateful, *, stateful: Literal[True] + self, + func: FilterWithMetadataCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[True], ) -> Self: ... @overload @@ -399,18 +442,20 @@ def func(d: dict, state: State): if stateful: self._register_store() # Force the callback to accept metadata - with_metadata_func = ( - func - if metadata - else _as_metadata_func(cast(FilterCallbackStateful, func)) - ) + if metadata: + with_metadata_func = cast(FilterWithMetadataCallbackStateful, func) + else: + with_metadata_func = _as_metadata_func( + cast(FilterCallbackStateful, func) + ) + stateful_func = _as_stateful( - func=cast(FilterWithMetadataCallbackStateful, with_metadata_func), + func=with_metadata_func, processing_context=self._processing_context, ) stream = self.stream.add_filter(stateful_func, metadata=True) else: - stream = self.stream.add_filter( + stream = self.stream.add_filter( # type: ignore[call-overload] cast(Union[FilterCallback, FilterWithMetadataCallback], func), metadata=metadata, ) @@ -492,13 +537,18 @@ def func(d: dict, state: State): """ if not key: raise ValueError('Parameter "key" cannot be empty') - if callable(key) and not name: + + operation = name + if not operation and isinstance(key, str): + operation = key + + if not operation: raise ValueError( 'group_by requires "name" parameter when "key" is a function' ) groupby_topic = self._topic_manager.repartition_topic( - operation=name or key, + operation=operation, topic_name=self._topic.name, key_serializer=key_serializer, value_serializer=value_serializer, @@ -619,7 +669,7 @@ def _set_timestamp_callback( new_timestamp = func(value, key, timestamp, headers) return value, key, new_timestamp, headers - stream = self.stream.add_transform(func=_set_timestamp_callback) + stream = self.stream.add_transform(_set_timestamp_callback, expand=False) return self.__dataframe_clone__(stream=stream) def set_headers( @@ -670,7 +720,7 @@ def _set_headers_callback( new_headers = func(value, key, timestamp, headers) return value, key, timestamp, new_headers - stream = self.stream.add_transform(func=_set_headers_callback) + stream = self.stream.add_transform(func=_set_headers_callback, expand=False) return self.__dataframe_clone__(stream=stream) def print(self, pretty: bool = True, metadata: bool = False) -> Self: @@ -707,7 +757,9 @@ def print(self, pretty: bool = True, metadata: bool = False) -> Self: """ print_args = ["value", "key", "timestamp", "headers"] if pretty: - printer = functools.partial(pprint.pprint, indent=2, sort_dicts=False) + printer: Callable[[Any], None] = functools.partial( + pprint.pprint, indent=2, sort_dicts=False + ) else: printer = print return self._add_update( @@ -1136,7 +1188,7 @@ def _add_update( func: Union[UpdateCallback, UpdateWithMetadataCallback], metadata: bool = False, ): - self._stream = self._stream.add_update(func, metadata=metadata) + self._stream = self._stream.add_update(func, metadata=metadata) # type: ignore[call-overload] return self def _register_store(self): @@ -1272,6 +1324,24 @@ def _drop(value: Dict, columns: List[str], ignore_missing: bool = False): raise +@overload +def _as_metadata_func( + func: ApplyCallbackStateful, +) -> ApplyWithMetadataCallbackStateful: ... + + +@overload +def _as_metadata_func( + func: FilterCallbackStateful, +) -> FilterWithMetadataCallbackStateful: ... + + +@overload +def _as_metadata_func( + func: UpdateCallbackStateful, +) -> UpdateWithMetadataCallbackStateful: ... + + def _as_metadata_func( func: Union[ApplyCallbackStateful, FilterCallbackStateful, UpdateCallbackStateful], ) -> Union[ @@ -1289,17 +1359,9 @@ def wrapper( def _as_stateful( - func: Union[ - ApplyWithMetadataCallbackStateful, - FilterWithMetadataCallbackStateful, - UpdateWithMetadataCallbackStateful, - ], + func: ApplyWithMetadataCallbackStateful, processing_context: ProcessingContext, -) -> Union[ - ApplyWithMetadataCallback, - FilterWithMetadataCallback, - UpdateWithMetadataCallback, -]: +) -> ApplyWithMetadataCallback: @functools.wraps(func) def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: ctx = message_context() diff --git a/quixstreams/dataframe/registry.py b/quixstreams/dataframe/registry.py index e6bbe0207..421c1f53e 100644 --- a/quixstreams/dataframe/registry.py +++ b/quixstreams/dataframe/registry.py @@ -21,7 +21,7 @@ class DataframeRegistry: `SDF`s are registered by storing their topic and current Stream. """ - def __init__(self): + def __init__(self) -> None: self._registry: Dict[str, Stream] = {} self._topics: List[Topic] = [] # {repartition_topic_name: source_topic_name} diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index 95a740732..9cc885959 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -1610,7 +1610,7 @@ def test_group_by_invalid_key_func(self, dataframe_factory, topic_manager_factor topic = topic_manager.topic(str(uuid.uuid4())) sdf = dataframe_factory(topic, topic_manager=topic_manager) - with pytest.raises(TypeError): + with pytest.raises(ValueError): sdf.group_by({"um": "what is this"}) def test_group_by_limit_exceeded(self, dataframe_factory, topic_manager_factory): From 965939d6bdde63d5bdb06727814a0a48b3cfe7e6 Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Thu, 19 Dec 2024 11:18:43 +0100 Subject: [PATCH 2/2] cleanup callable type hint with typevar --- quixstreams/dataframe/dataframe.py | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 115570c07..65e7f9b32 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -13,6 +13,7 @@ Literal, Optional, Tuple, + TypeVar, Union, cast, overload, @@ -1324,31 +1325,12 @@ def _drop(value: Dict, columns: List[str], ignore_missing: bool = False): raise -@overload -def _as_metadata_func( - func: ApplyCallbackStateful, -) -> ApplyWithMetadataCallbackStateful: ... - - -@overload -def _as_metadata_func( - func: FilterCallbackStateful, -) -> FilterWithMetadataCallbackStateful: ... - - -@overload -def _as_metadata_func( - func: UpdateCallbackStateful, -) -> UpdateWithMetadataCallbackStateful: ... +T = TypeVar("T") def _as_metadata_func( - func: Union[ApplyCallbackStateful, FilterCallbackStateful, UpdateCallbackStateful], -) -> Union[ - ApplyWithMetadataCallbackStateful, - FilterWithMetadataCallbackStateful, - UpdateWithMetadataCallbackStateful, -]: + func: Callable[[Any, State], T], +) -> Callable[[Any, Any, int, Any, State], T]: @functools.wraps(func) def wrapper( value: Any, _key: Any, _timestamp: int, _headers: Any, state: State @@ -1359,9 +1341,9 @@ def wrapper( def _as_stateful( - func: ApplyWithMetadataCallbackStateful, + func: Callable[[Any, Any, int, Any, State], T], processing_context: ProcessingContext, -) -> ApplyWithMetadataCallback: +) -> Callable[[Any, Any, int, Any], T]: @functools.wraps(func) def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: ctx = message_context()