Skip to content

Commit

Permalink
fix: dependency injection, typing annotations, prefetch_count (#10)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaRhAeu authored Aug 22, 2024
1 parent 336c21f commit fb77450
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 42 deletions.
2 changes: 1 addition & 1 deletion eventiq/__about__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.1.1"
__version__ = "1.1.2"
29 changes: 18 additions & 11 deletions eventiq/backends/nats.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import asyncio
from abc import ABC
from datetime import timedelta, timezone
from typing import TYPE_CHECKING, Annotated, Any, Callable

import anyio
Expand All @@ -16,7 +16,7 @@
from eventiq.exceptions import BrokerError
from eventiq.results import ResultBackend
from eventiq.settings import UrlBrokerSettings
from eventiq.utils import to_float, utc_now
from eventiq.utils import to_float

if TYPE_CHECKING:
from collections.abc import Awaitable
Expand Down Expand Up @@ -44,6 +44,7 @@ class JetStreamSettings(NatsSettings):

class AbstractNatsBroker(UrlBroker[NatsMsg, R], ABC):
""":param auto_flush: auto flush messages on publish
:param auto_flush: auto flush on publish
:param kwargs: options for base class
"""

Expand Down Expand Up @@ -99,6 +100,7 @@ async def connect(self) -> None:

async def disconnect(self) -> None:
if self.client.is_connected:
await self.client.drain()
await self.client.close()

async def flush(self) -> None:
Expand Down Expand Up @@ -144,11 +146,12 @@ async def publish(
body: bytes,
*,
headers: dict[str, str],
reply: str = "",
flush: bool = False,
**kwargs: Any,
) -> None:
reply = kwargs.get("reply", "")
await self.client.publish(topic, body, headers=headers, reply=reply)
if self._auto_flush or kwargs.get("flush"):
if self._auto_flush or flush:
await self.flush()


Expand All @@ -157,9 +160,8 @@ class JetStreamBroker(
ResultBackend[NatsMsg, api.PubAck],
):
"""NatsBroker with JetStream enabled
:param prefetch_count: default number of messages to prefetch
:param fetch_timeout: timeout for subscription pull
:param jetstream_options: additional options passed to nc.jetstream(...)
:param kv_options: options for nats KV initialization.
:param kwargs: all other options for base classes NatsBroker, Broker.
"""

Expand Down Expand Up @@ -233,7 +235,6 @@ async def sender(
if key in consumer.options:
config_kwargs[key] = consumer.options[key]
config = ConsumerConfig(**config_kwargs)
batch = consumer.options.get("batch", consumer.concurrency * 2)
fetch_timeout = consumer.options.get("fetch_timeout", 10)
heartbeat = consumer.options.get("heartbeat", 0.1)
subscription = await self.js.pull_subscribe(
Expand All @@ -245,23 +246,29 @@ async def sender(
async with send_stream:
while True:
try:
batch = consumer.concurrency - len(
send_stream._state.buffer # noqa: SLF001
)
if batch == 0:
await asyncio.sleep(0.1)
continue
self.logger.debug("Fetching %d messages", batch)
messages = await subscription.fetch(
batch=batch,
timeout=fetch_timeout,
heartbeat=heartbeat,
)
for message in messages:
await send_stream.send(message)
except FetchTimeoutError: # noqa: PERF203
except FetchTimeoutError:
pass
finally:
if consumer.dynamic:
await subscription.unsubscribe()
self.logger.info("Sender finished for %s", consumer.name)
self.logger.info("Stopped sender for consumer: %s", consumer.name)

def should_nack(self, raw_message: NatsMsg) -> bool:
date = raw_message.metadata.timestamp.replace(tzinfo=timezone.utc)
return date < (utc_now() - timedelta(seconds=self.validate_error_delay))
return raw_message.metadata.num_delivered <= 3

def get_num_delivered(self, raw_message: NatsMsg) -> int | None:
return raw_message.metadata.num_delivered
2 changes: 1 addition & 1 deletion eventiq/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ async def wrapped(
annotation, default = v
if annotation in state:
kwargs[k] = state[annotation]
elif default is Parameter.empty:
elif k not in kwargs and default is Parameter.empty:
err = f"Missing dependency {k}: {annotation}"
raise DependencyError(err)

Expand Down
33 changes: 13 additions & 20 deletions eventiq/middlewares/retries.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@
from datetime import timedelta
from typing import TYPE_CHECKING, Any, Callable, Generic, NamedTuple

from typing_extensions import ParamSpec

from eventiq.exceptions import Fail, Retry, Skip
from eventiq.logging import LoggerMixin
from eventiq.middleware import CloudEventType, Middleware
Expand All @@ -15,8 +13,6 @@
from eventiq.types import RetryStrategy


P = ParamSpec("P")

DelayGenerator = Callable[[CloudEventType, Exception], int]


Expand All @@ -27,7 +23,7 @@ class MessageStatus(NamedTuple):

def expo(factor: int = 1) -> DelayGenerator:
def _expo(message: CloudEvent, _: Exception) -> int:
return factor * message.age.seconds
return factor * int(message.age.total_seconds())

return _expo

Expand All @@ -39,24 +35,19 @@ def _constant(*_: Any) -> int:
return _constant


class BaseRetryStrategy(Generic[P, CloudEventType], LoggerMixin):
class BaseRetryStrategy(Generic[CloudEventType], LoggerMixin):
def __init__(
self,
*,
throws: tuple[type[Exception], ...] = (),
delay_generator: Callable[P, DelayGenerator] | None = None,
min_delay: int = 2,
delay_generator: DelayGenerator = expo(),
min_delay: int = 1,
log_exceptions: bool = True,
*args: P.args,
**kwargs: P.kwargs,
) -> None:
if Fail not in throws:
throws = (*throws, Fail)
self.throws = throws
self.min_delay = min_delay
self.log_exceptions = log_exceptions
self.delay_generator = (
delay_generator(*args, **kwargs) if delay_generator else expo()
)
self.delay_generator = delay_generator

def retry(self, message: CloudEventType, exc: Exception) -> None:
delay = getattr(exc, "delay", None)
Expand Down Expand Up @@ -92,16 +83,17 @@ def maybe_retry(
self.fail(message, exc)


class MaxAge(BaseRetryStrategy[P, CloudEventType]):
class MaxAge(BaseRetryStrategy[CloudEventType]):
def __init__(
self,
*,
max_age: timedelta | dict[str, Any] = timedelta(hours=6),
**extra: Any,
) -> None:
super().__init__(**extra)
if isinstance(max_age, Mapping):
max_age = timedelta(**max_age)
self.max_age: timedelta = max_age
self.max_age = max_age

def maybe_retry(
self,
Expand All @@ -115,8 +107,8 @@ def maybe_retry(
self.fail(message, exc)


class MaxRetries(BaseRetryStrategy[P, CloudEventType]):
def __init__(self, max_retries: int = 3, **extra: Any) -> None:
class MaxRetries(BaseRetryStrategy[CloudEventType]):
def __init__(self, *, max_retries: int = 3, **extra: Any) -> None:
super().__init__(**extra)
self.max_retries = max_retries

Expand All @@ -138,9 +130,10 @@ def maybe_retry(
self.fail(message, exc)


class RetryWhen(BaseRetryStrategy[P, CloudEventType]):
class RetryWhen(BaseRetryStrategy[CloudEventType]):
def __init__(
self,
*,
retry_when: Callable[[CloudEventType, Exception], bool],
**extra: Any,
) -> None:
Expand Down
12 changes: 6 additions & 6 deletions eventiq/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,18 +234,18 @@ async def start_consumers(self, tg: TaskGroup) -> None:
consumer.maybe_set_publisher(self.publish)
await self.dispatch_before("consumer_start", consumer=consumer)
send_stream, receive_stream = create_memory_object_stream[Any](
consumer.concurrency * 2,
consumer.concurrency,
)

tg.start_soon(self.broker.sender, self.name, consumer, send_stream)

for i in range(1, consumer.concurrency + 1):
for i in range(consumer.concurrency):
self.logger.info("Starting consumer %s task %s", consumer.name, i)
tg.start_soon(
self.receiver,
consumer,
receive_stream.clone(),
name=f"{consumer.name}:{i}",
name=f"{consumer.name}:{i+1}",
)
await self.dispatch_after("consumer_start", consumer=consumer)

Expand Down Expand Up @@ -465,8 +465,8 @@ async def _handle_message_finalization(
@asynccontextmanager
async def subscription(
self,
event_type: type[CloudEvent],
auto_ack: bool = False,
event_type: type[CloudEvent] = CloudEvent,
topic: str | None = None,
**options: Any,
) -> AsyncIterator[
MemoryObjectReceiveStream[tuple[CloudEvent, Callable[[], None]]]
Expand All @@ -483,8 +483,8 @@ async def subscription(
options["dynamic"] = True
consumer = ChannelConsumer(
channel=consumer_send,
auto_ack=auto_ack,
event_type=event_type,
topic=topic,
**options,
)

Expand Down
3 changes: 1 addition & 2 deletions eventiq/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

from collections.abc import Awaitable
from contextlib import AbstractAsyncContextManager
from datetime import timedelta
from typing import (
Expand Down Expand Up @@ -51,7 +50,7 @@
P = ParamSpec("P")

MessageHandler = Union[
type["GenericConsumer"], Callable[Concatenate[CloudEventType, P], Awaitable[Any]]
type["GenericConsumer"], Callable[Concatenate[CloudEventType, P], Any]
]


Expand Down
2 changes: 1 addition & 1 deletion examples/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ async def after_broker_connect(self):
@service.subscribe(topic="test.topic", concurrency=2)
async def example_run(message: CloudEvent):
print("Received Message", message.id, "with data:", message.data)
await asyncio.sleep(10)
await asyncio.sleep(5)

0 comments on commit fb77450

Please sign in to comment.