From fb77450b7de1cab9a246031f35e45f27a14da168 Mon Sep 17 00:00:00 2001 From: RaRhAeu <37556570+RaRhAeu@users.noreply.github.com> Date: Thu, 22 Aug 2024 10:01:14 +0200 Subject: [PATCH] fix: dependency injection, typing annotations, prefetch_count (#10) --- eventiq/__about__.py | 2 +- eventiq/backends/nats.py | 29 ++++++++++++++++++----------- eventiq/dependencies.py | 2 +- eventiq/middlewares/retries.py | 33 +++++++++++++-------------------- eventiq/service.py | 12 ++++++------ eventiq/types.py | 3 +-- examples/base.py | 2 +- 7 files changed, 41 insertions(+), 42 deletions(-) diff --git a/eventiq/__about__.py b/eventiq/__about__.py index a82b376..72f26f5 100644 --- a/eventiq/__about__.py +++ b/eventiq/__about__.py @@ -1 +1 @@ -__version__ = "1.1.1" +__version__ = "1.1.2" diff --git a/eventiq/backends/nats.py b/eventiq/backends/nats.py index cd619fa..7f611f5 100644 --- a/eventiq/backends/nats.py +++ b/eventiq/backends/nats.py @@ -1,7 +1,7 @@ from __future__ import annotations +import asyncio from abc import ABC -from datetime import timedelta, timezone from typing import TYPE_CHECKING, Annotated, Any, Callable import anyio @@ -16,7 +16,7 @@ from eventiq.exceptions import BrokerError from eventiq.results import ResultBackend from eventiq.settings import UrlBrokerSettings -from eventiq.utils import to_float, utc_now +from eventiq.utils import to_float if TYPE_CHECKING: from collections.abc import Awaitable @@ -44,6 +44,7 @@ class JetStreamSettings(NatsSettings): class AbstractNatsBroker(UrlBroker[NatsMsg, R], ABC): """:param auto_flush: auto flush messages on publish + :param auto_flush: auto flush on publish :param kwargs: options for base class """ @@ -99,6 +100,7 @@ async def connect(self) -> None: async def disconnect(self) -> None: if self.client.is_connected: + await self.client.drain() await self.client.close() async def flush(self) -> None: @@ -144,11 +146,12 @@ async def publish( body: bytes, *, headers: dict[str, str], + reply: str = "", + flush: bool = False, **kwargs: Any, ) -> None: - reply = kwargs.get("reply", "") await self.client.publish(topic, body, headers=headers, reply=reply) - if self._auto_flush or kwargs.get("flush"): + if self._auto_flush or flush: await self.flush() @@ -157,9 +160,8 @@ class JetStreamBroker( ResultBackend[NatsMsg, api.PubAck], ): """NatsBroker with JetStream enabled - :param prefetch_count: default number of messages to prefetch - :param fetch_timeout: timeout for subscription pull :param jetstream_options: additional options passed to nc.jetstream(...) + :param kv_options: options for nats KV initialization. :param kwargs: all other options for base classes NatsBroker, Broker. """ @@ -233,7 +235,6 @@ async def sender( if key in consumer.options: config_kwargs[key] = consumer.options[key] config = ConsumerConfig(**config_kwargs) - batch = consumer.options.get("batch", consumer.concurrency * 2) fetch_timeout = consumer.options.get("fetch_timeout", 10) heartbeat = consumer.options.get("heartbeat", 0.1) subscription = await self.js.pull_subscribe( @@ -245,6 +246,13 @@ async def sender( async with send_stream: while True: try: + batch = consumer.concurrency - len( + send_stream._state.buffer # noqa: SLF001 + ) + if batch == 0: + await asyncio.sleep(0.1) + continue + self.logger.debug("Fetching %d messages", batch) messages = await subscription.fetch( batch=batch, timeout=fetch_timeout, @@ -252,16 +260,15 @@ async def sender( ) for message in messages: await send_stream.send(message) - except FetchTimeoutError: # noqa: PERF203 + except FetchTimeoutError: pass finally: if consumer.dynamic: await subscription.unsubscribe() - self.logger.info("Sender finished for %s", consumer.name) + self.logger.info("Stopped sender for consumer: %s", consumer.name) def should_nack(self, raw_message: NatsMsg) -> bool: - date = raw_message.metadata.timestamp.replace(tzinfo=timezone.utc) - return date < (utc_now() - timedelta(seconds=self.validate_error_delay)) + return raw_message.metadata.num_delivered <= 3 def get_num_delivered(self, raw_message: NatsMsg) -> int | None: return raw_message.metadata.num_delivered diff --git a/eventiq/dependencies.py b/eventiq/dependencies.py index 4d922aa..e01381a 100644 --- a/eventiq/dependencies.py +++ b/eventiq/dependencies.py @@ -45,7 +45,7 @@ async def wrapped( annotation, default = v if annotation in state: kwargs[k] = state[annotation] - elif default is Parameter.empty: + elif k not in kwargs and default is Parameter.empty: err = f"Missing dependency {k}: {annotation}" raise DependencyError(err) diff --git a/eventiq/middlewares/retries.py b/eventiq/middlewares/retries.py index 729af82..6f1585d 100644 --- a/eventiq/middlewares/retries.py +++ b/eventiq/middlewares/retries.py @@ -4,8 +4,6 @@ from datetime import timedelta from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple -from typing_extensions import ParamSpec - from eventiq.exceptions import Fail, Retry, Skip from eventiq.logging import LoggerMixin from eventiq.middleware import CloudEventType, Middleware @@ -15,8 +13,6 @@ from eventiq.types import RetryStrategy -P = ParamSpec("P") - DelayGenerator = Callable[[CloudEventType, Exception], int] @@ -27,7 +23,7 @@ class MessageStatus(NamedTuple): def expo(factor: int = 1) -> DelayGenerator: def _expo(message: CloudEvent, _: Exception) -> int: - return factor * message.age.seconds + return factor * int(message.age.total_seconds()) return _expo @@ -39,24 +35,19 @@ def _constant(*_: Any) -> int: return _constant -class BaseRetryStrategy(Generic[P, CloudEventType], LoggerMixin): +class BaseRetryStrategy(Generic[CloudEventType], LoggerMixin): def __init__( self, + *, throws: tuple[type[Exception], ...] = (), - delay_generator: Callable[P, DelayGenerator] | None = None, - min_delay: int = 2, + delay_generator: DelayGenerator = expo(), + min_delay: int = 1, log_exceptions: bool = True, - *args: P.args, - **kwargs: P.kwargs, ) -> None: - if Fail not in throws: - throws = (*throws, Fail) self.throws = throws self.min_delay = min_delay self.log_exceptions = log_exceptions - self.delay_generator = ( - delay_generator(*args, **kwargs) if delay_generator else expo() - ) + self.delay_generator = delay_generator def retry(self, message: CloudEventType, exc: Exception) -> None: delay = getattr(exc, "delay", None) @@ -92,16 +83,17 @@ def maybe_retry( self.fail(message, exc) -class MaxAge(BaseRetryStrategy[P, CloudEventType]): +class MaxAge(BaseRetryStrategy[CloudEventType]): def __init__( self, + *, max_age: timedelta | dict[str, Any] = timedelta(hours=6), **extra: Any, ) -> None: super().__init__(**extra) if isinstance(max_age, Mapping): max_age = timedelta(**max_age) - self.max_age: timedelta = max_age + self.max_age = max_age def maybe_retry( self, @@ -115,8 +107,8 @@ def maybe_retry( self.fail(message, exc) -class MaxRetries(BaseRetryStrategy[P, CloudEventType]): - def __init__(self, max_retries: int = 3, **extra: Any) -> None: +class MaxRetries(BaseRetryStrategy[CloudEventType]): + def __init__(self, *, max_retries: int = 3, **extra: Any) -> None: super().__init__(**extra) self.max_retries = max_retries @@ -138,9 +130,10 @@ def maybe_retry( self.fail(message, exc) -class RetryWhen(BaseRetryStrategy[P, CloudEventType]): +class RetryWhen(BaseRetryStrategy[CloudEventType]): def __init__( self, + *, retry_when: Callable[[CloudEventType, Exception], bool], **extra: Any, ) -> None: diff --git a/eventiq/service.py b/eventiq/service.py index cac18eb..f7c2751 100644 --- a/eventiq/service.py +++ b/eventiq/service.py @@ -234,18 +234,18 @@ async def start_consumers(self, tg: TaskGroup) -> None: consumer.maybe_set_publisher(self.publish) await self.dispatch_before("consumer_start", consumer=consumer) send_stream, receive_stream = create_memory_object_stream[Any]( - consumer.concurrency * 2, + consumer.concurrency, ) tg.start_soon(self.broker.sender, self.name, consumer, send_stream) - for i in range(1, consumer.concurrency + 1): + for i in range(consumer.concurrency): self.logger.info("Starting consumer %s task %s", consumer.name, i) tg.start_soon( self.receiver, consumer, receive_stream.clone(), - name=f"{consumer.name}:{i}", + name=f"{consumer.name}:{i+1}", ) await self.dispatch_after("consumer_start", consumer=consumer) @@ -465,8 +465,8 @@ async def _handle_message_finalization( @asynccontextmanager async def subscription( self, - event_type: type[CloudEvent], - auto_ack: bool = False, + event_type: type[CloudEvent] = CloudEvent, + topic: str | None = None, **options: Any, ) -> AsyncIterator[ MemoryObjectReceiveStream[tuple[CloudEvent, Callable[[], None]]] @@ -483,8 +483,8 @@ async def subscription( options["dynamic"] = True consumer = ChannelConsumer( channel=consumer_send, - auto_ack=auto_ack, event_type=event_type, + topic=topic, **options, ) diff --git a/eventiq/types.py b/eventiq/types.py index 0ef07eb..1688a70 100644 --- a/eventiq/types.py +++ b/eventiq/types.py @@ -1,6 +1,5 @@ from __future__ import annotations -from collections.abc import Awaitable from contextlib import AbstractAsyncContextManager from datetime import timedelta from typing import ( @@ -51,7 +50,7 @@ P = ParamSpec("P") MessageHandler = Union[ - type["GenericConsumer"], Callable[Concatenate[CloudEventType, P], Awaitable[Any]] + type["GenericConsumer"], Callable[Concatenate[CloudEventType, P], Any] ] diff --git a/examples/base.py b/examples/base.py index 610c7aa..e2be08c 100644 --- a/examples/base.py +++ b/examples/base.py @@ -25,4 +25,4 @@ async def after_broker_connect(self): @service.subscribe(topic="test.topic", concurrency=2) async def example_run(message: CloudEvent): print("Received Message", message.id, "with data:", message.data) - await asyncio.sleep(10) + await asyncio.sleep(5)