From 5f00b6b163c63279931b98c93e3c6be95f2f3d6b Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Wed, 11 Dec 2024 16:26:42 +0100 Subject: [PATCH] mypy: fix part of quixstreams.dataframe --- pyproject.toml | 3 +- quixstreams/context.py | 9 +- quixstreams/core/stream/functions/apply.py | 4 +- quixstreams/core/stream/stream.py | 8 +- quixstreams/dataframe/base.py | 9 +- quixstreams/dataframe/dataframe.py | 194 ++++++++++++------ quixstreams/dataframe/registry.py | 2 +- .../test_dataframe/test_dataframe.py | 8 +- 8 files changed, 166 insertions(+), 71 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7a2c32e4b..6d170dfc1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,7 +123,8 @@ ignore_errors = true [[tool.mypy.overrides]] module = [ - "quixstreams.dataframe.*", + "quixstreams.dataframe.series.*", + "quixstreams.dataframe.windows.*", "quixstreams.rowproducer.*" ] ignore_errors = true diff --git a/quixstreams/context.py b/quixstreams/context.py index 2de9d3be9..18286794f 100644 --- a/quixstreams/context.py +++ b/quixstreams/context.py @@ -50,7 +50,7 @@ def alter_context(value): _current_message_context.set(context) -def message_context() -> Optional[MessageContext]: +def message_context() -> MessageContext: """ Get a MessageContext for the current message, which houses most of the message metadata, like: @@ -75,6 +75,11 @@ def message_context() -> Optional[MessageContext]: :return: instance of `MessageContext` """ try: - return _current_message_context.get() + ctx = _current_message_context.get() except LookupError: raise MessageContextNotSetError("Message context is not set") + + if ctx is None: + raise MessageContextNotSetError("Message context is not set") + + return ctx diff --git a/quixstreams/core/stream/functions/apply.py b/quixstreams/core/stream/functions/apply.py index 9609d6cdd..bdf493953 100644 --- a/quixstreams/core/stream/functions/apply.py +++ b/quixstreams/core/stream/functions/apply.py @@ -80,7 +80,7 @@ class ApplyWithMetadataFunction(StreamFunction): @overload def __init__( - self, func: ApplyWithMetadataCallback, expand: Literal[False] + self, func: ApplyWithMetadataCallback, expand: Literal[False] = False ) -> None: ... @overload @@ -90,7 +90,7 @@ def __init__( def __init__( self, - func: ApplyWithMetadataCallback, + func: Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback], expand: bool = False, ): super().__init__(func) diff --git a/quixstreams/core/stream/stream.py b/quixstreams/core/stream/stream.py index 8e43ad1e5..72eaa65ef 100644 --- a/quixstreams/core/stream/stream.py +++ b/quixstreams/core/stream/stream.py @@ -90,7 +90,7 @@ def __repr__(self) -> str: ) return f"<{self.__class__.__name__} [{len(tree_funcs)}]: {funcs_repr}>" - def diff(self, other: Self) -> Optional[Self]: + def diff(self, other: Self) -> Self: """ Takes the difference between Streams `self` and `other` based on their last common parent, and returns a new, independent `Stream` that includes only @@ -102,7 +102,7 @@ def diff(self, other: Self) -> Optional[Self]: the `other` Stream, and the resulting diff is empty. :param other: a `Stream` to take a diff from. - :raises ValueError: if Streams don't have a common parent, + :raises InvalidOperation: if Streams don't have a common parent, if the diff is empty, or pruning failed. :return: a new independent `Stream` instance whose root begins at the diff """ @@ -149,6 +149,10 @@ def diff(self, other: Self) -> Optional[Self]: node.parent = parent parent = node self._prune(diff[0]) + + if parent is None: + raise InvalidOperation("No common parent found") + return parent def root_path(self, allow_splits=True) -> List[Self]: diff --git a/quixstreams/dataframe/base.py b/quixstreams/dataframe/base.py index 4fcca213c..ede71e663 100644 --- a/quixstreams/dataframe/base.py +++ b/quixstreams/dataframe/base.py @@ -11,9 +11,14 @@ class BaseStreaming: def stream(self) -> Stream: ... @abc.abstractmethod - def compose(self, *args, **kwargs) -> VoidExecutor: ... + def compose(self, *args, **kwargs) -> dict[str, VoidExecutor]: ... @abc.abstractmethod def test( - self, value: Any, key: Any, timestamp: int, ctx: Optional[MessageContext] = None + self, + value: Any, + key: Any, + timestamp: int, + headers: Optional[Any] = None, + ctx: Optional[MessageContext] = None, ) -> Any: ... diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py index 01978896e..bfc5e4d6a 100644 --- a/quixstreams/dataframe/dataframe.py +++ b/quixstreams/dataframe/dataframe.py @@ -26,8 +26,10 @@ ) from quixstreams.core.stream import ( ApplyCallback, + ApplyExpandedCallback, ApplyFunction, ApplyWithMetadataCallback, + ApplyWithMetadataExpandedCallback, ApplyWithMetadataFunction, FilterCallback, FilterFunction, @@ -147,17 +149,15 @@ def stream(self) -> Stream: def topic(self) -> Topic: return self._topic - @overload - def apply(self, func: ApplyCallback, *, expand: bool = ...) -> Self: ... - @overload def apply( self, - func: ApplyWithMetadataCallback, + func: Union[ApplyCallback, ApplyExpandedCallback], *, - metadata: Literal[True], - expand: bool = ..., - ) -> Self: ... + stateful: Literal[False] = False, + expand: Union[Literal[False], Literal[True]] = False, + metadata: Literal[False] = False, + ): ... @overload def apply( @@ -165,8 +165,19 @@ def apply( func: ApplyCallbackStateful, *, stateful: Literal[True], - expand: bool = ..., - ) -> Self: ... + expand: Union[Literal[False], Literal[True]] = False, + metadata: Literal[False] = False, + ): ... + + @overload + def apply( + self, + func: Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback], + *, + stateful: Literal[False] = False, + expand: Union[Literal[False], Literal[True]] = False, + metadata: Literal[True], + ): ... @overload def apply( @@ -174,21 +185,23 @@ def apply( func: ApplyWithMetadataCallbackStateful, *, stateful: Literal[True], + expand: Union[Literal[False], Literal[True]] = False, metadata: Literal[True], - expand: bool = ..., - ) -> Self: ... + ): ... def apply( self, func: Union[ ApplyCallback, + ApplyExpandedCallback, ApplyCallbackStateful, ApplyWithMetadataCallback, + ApplyWithMetadataExpandedCallback, ApplyWithMetadataCallbackStateful, ], *, stateful: bool = False, - expand: bool = False, + expand: Union[Literal[True], Literal[False]] = False, metadata: bool = False, ) -> Self: """ @@ -226,12 +239,15 @@ def func(d: dict, state: State): """ if stateful: self._register_store() + # Force the callback to accept metadata - with_metadata_func = ( - cast(ApplyWithMetadataCallbackStateful, func) - if metadata - else _as_metadata_func(cast(ApplyCallbackStateful, func)) - ) + if metadata: + with_metadata_func = cast(ApplyWithMetadataCallbackStateful, func) + else: + with_metadata_func = _as_metadata_func( + cast(ApplyCallbackStateful, func) + ) + stateful_func = _as_stateful( func=with_metadata_func, processing_context=self._processing_context, @@ -241,23 +257,42 @@ def func(d: dict, state: State): ApplyWithMetadataFunction(stateful_func, expand=expand) ) elif metadata: + func = cast( + Union[ApplyWithMetadataCallback, ApplyWithMetadataExpandedCallback], + func, + ) stream = self.stream.add(ApplyWithMetadataFunction(func, expand=expand)) else: + func = cast(Union[ApplyCallback, ApplyExpandedCallback], func) stream = self.stream.add(ApplyFunction(func, expand=expand)) return self.__dataframe_clone__(stream=stream) @overload - def update(self, func: UpdateCallback) -> Self: ... + def update( + self, + func: UpdateCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[False] = False, + ): ... @overload def update( - self, func: UpdateWithMetadataCallback, *, metadata: Literal[True] - ) -> Self: ... + self, + func: UpdateCallbackStateful, + *, + stateful: Literal[True], + metadata: Literal[False] = False, + ): ... @overload def update( - self, func: UpdateCallbackStateful, *, stateful: Literal[True] - ) -> Self: ... + self, + func: UpdateWithMetadataCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[True], + ): ... @overload def update( @@ -266,7 +301,7 @@ def update( *, stateful: Literal[True], metadata: Literal[True], - ) -> Self: ... + ): ... def update( self, @@ -278,7 +313,7 @@ def update( ], *, stateful: bool = False, - metadata: bool = False, + metadata: Union[Literal[True], Literal[False]] = False, ) -> Self: """ Apply a function to mutate value in-place or to perform a side effect @@ -319,37 +354,53 @@ def func(values: list, state: State): """ if stateful: self._register_store() + # Force the callback to accept metadata - with_metadata_func = ( - func - if metadata - else _as_metadata_func(cast(UpdateCallbackStateful, func)) - ) + if metadata: + with_metadata_func = cast(UpdateWithMetadataCallbackStateful, func) + else: + with_metadata_func = _as_metadata_func( + cast(UpdateCallbackStateful, func) + ) + stateful_func = _as_stateful( - func=cast(UpdateWithMetadataCallbackStateful, with_metadata_func), + func=with_metadata_func, processing_context=self._processing_context, ) - return self._add_update( - cast(UpdateWithMetadataCallback, stateful_func), metadata=True - ) + return self._add_update(stateful_func, metadata=True) + elif metadata: + func = cast(UpdateWithMetadataCallback, func) + return self._add_update(func, metadata=metadata) else: - return self._add_update( - cast(Union[UpdateCallback, UpdateWithMetadataCallback], func), - metadata=metadata, - ) + func = cast(UpdateCallback, func) + return self._add_update(func, metadata=metadata) @overload - def filter(self, func: FilterCallback) -> Self: ... + def filter( + self, + func: FilterCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[False] = False, + ): ... @overload def filter( - self, func: FilterWithMetadataCallback, *, metadata: Literal[True] - ) -> Self: ... + self, + func: FilterCallbackStateful, + *, + stateful: Literal[True], + metadata: Literal[False] = False, + ): ... @overload def filter( - self, func: FilterCallbackStateful, *, stateful: Literal[True] - ) -> Self: ... + self, + func: FilterWithMetadataCallback, + *, + stateful: Literal[False] = False, + metadata: Literal[True], + ): ... @overload def filter( @@ -358,7 +409,7 @@ def filter( *, stateful: Literal[True], metadata: Literal[True], - ) -> Self: ... + ): ... def filter( self, @@ -406,17 +457,26 @@ def func(d: dict, state: State): if stateful: self._register_store() + # Force the callback to accept metadata - with_metadata_func = func if metadata else _as_metadata_func(func) + if metadata: + with_metadata_func = cast(FilterWithMetadataCallbackStateful, func) + else: + with_metadata_func = _as_metadata_func( + cast(FilterCallbackStateful, func) + ) + stateful_func = _as_stateful( - func=cast(FilterWithMetadataCallbackStateful, with_metadata_func), + func=with_metadata_func, processing_context=self._processing_context, ) stream = self.stream.add(FilterWithMetadataFunction(stateful_func)) elif metadata: + func = cast(FilterWithMetadataCallback, func) stream = self.stream.add(FilterWithMetadataFunction(func)) else: + func = cast(FilterCallback, func) stream = self.stream.add(FilterFunction(func)) return self.__dataframe_clone__(stream=stream) @@ -495,15 +555,21 @@ def func(d: dict, state: State): :return: a clone with this operation added (assign to keep its effect). """ + if not key: raise ValueError('Parameter "key" cannot be empty') - if callable(key) and not name: + + operation = name + if not operation and isinstance(key, str): + operation = key + + if not operation: raise ValueError( 'group_by requires "name" parameter when "key" is a function' ) groupby_topic = self._topic_manager.repartition_topic( - operation=name or key, + operation=operation, topic_name=self._topic.name, key_serializer=key_serializer, value_serializer=value_serializer, @@ -678,7 +744,11 @@ def _set_headers_callback( stream = self.stream.add(TransformFunction(_set_headers_callback)) return self.__dataframe_clone__(stream=stream) - def print(self, pretty: bool = True, metadata: bool = False) -> Self: + def print( + self, + pretty: bool = True, + metadata: Union[Literal[True], Literal[False]] = False, + ) -> Self: """ Print out the current message value (and optionally, the message metadata) to stdout (console) (like the built-in `print` function). @@ -712,9 +782,12 @@ def print(self, pretty: bool = True, metadata: bool = False) -> Self: """ print_args = ["value", "key", "timestamp", "headers"] if pretty: - printer = functools.partial(pprint.pprint, indent=2, sort_dicts=False) + printer: Callable[[dict[str, str]], None] = functools.partial( + pprint.pprint, indent=2, sort_dicts=False + ) else: printer = print + return self._add_update( lambda *args: printer({print_args[i]: args[i] for i in range(len(args))}), metadata=metadata, @@ -1136,9 +1209,17 @@ def _produce( ) self._producer.produce_row(row=row, topic=topic, key=key, timestamp=timestamp) + @overload + def _add_update(self, func: UpdateCallback, metadata: Literal[False] = False): ... + + @overload + def _add_update( + self, func: UpdateWithMetadataCallback, metadata: Literal[True] + ): ... + def _add_update( self, - func: Union[UpdateCallback, UpdateWithMetadataCallback], + func, metadata: bool = False, ): if metadata: @@ -1194,6 +1275,7 @@ def __setitem__(self, item_key: Any, item: Union[Self, object]): # Update an item key with a result of another sdf.apply() diff = self.stream.diff(item.stream) other_sdf_composed = diff.compose_returning() + self._add_update( lambda value, key, timestamp, headers: operator.setitem( value, @@ -1315,17 +1397,9 @@ def wrapper( def _as_stateful( - func: Union[ - ApplyWithMetadataCallbackStateful, - FilterWithMetadataCallbackStateful, - UpdateWithMetadataCallbackStateful, - ], + func: ApplyWithMetadataCallbackStateful, processing_context: ProcessingContext, -) -> Union[ - ApplyWithMetadataCallback, - FilterWithMetadataCallback, - UpdateWithMetadataCallback, -]: +) -> ApplyWithMetadataCallback: @functools.wraps(func) def wrapper(value: Any, key: Any, timestamp: int, headers: Any) -> Any: ctx = message_context() diff --git a/quixstreams/dataframe/registry.py b/quixstreams/dataframe/registry.py index e6bbe0207..421c1f53e 100644 --- a/quixstreams/dataframe/registry.py +++ b/quixstreams/dataframe/registry.py @@ -21,7 +21,7 @@ class DataframeRegistry: `SDF`s are registered by storing their topic and current Stream. """ - def __init__(self): + def __init__(self) -> None: self._registry: Dict[str, Stream] = {} self._topics: List[Topic] = [] # {repartition_topic_name: source_topic_name} diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index 95a740732..73c0e58aa 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -1451,9 +1451,15 @@ def test_group_by_column_with_name( sdf = dataframe_factory( topic, topic_manager=topic_manager, producer=producer, registry=sdf_registry ) + + print(sdf.topic.name) + sdf = sdf.group_by(col, name=op_name) sdf[col] = col_update + print(sdf.topic.name) + print(topic_manager.repartition_topic(op_name, topic.name).name) + groupby_topic = sdf.topic assert sdf_registry.consumer_topics == [topic, sdf.topic] assert ( @@ -1610,7 +1616,7 @@ def test_group_by_invalid_key_func(self, dataframe_factory, topic_manager_factor topic = topic_manager.topic(str(uuid.uuid4())) sdf = dataframe_factory(topic, topic_manager=topic_manager) - with pytest.raises(TypeError): + with pytest.raises(ValueError): sdf.group_by({"um": "what is this"}) def test_group_by_limit_exceeded(self, dataframe_factory, topic_manager_factory):