Skip to content

Commit

Permalink
feat: adds Web3SubscriptionsManager.get_subscription_data_nowait() (#95)
Browse files Browse the repository at this point in the history
* feat: adds Web3SubscriptionsManager.pop_subscription_data()

* feat: adds get_subscription_data_nowait() to Web3SubscriptionsManager

* fix: awaiting sync func

* refactor: leverage asyncio.timeout when calling websocket client.recv

* fix: one day I'll learn to lint before committing.

* fix: Python 3.10 does not support asyncio.timeout

* fix: get asyncio lock when calling _receive()

* chore: cleanup, remove pop_subscription_data()

* chore: logging and return from generator on end

* chore: cleanup debug statements
  • Loading branch information
mikeshultz authored Jul 17, 2024
1 parent fc636ce commit 34a12e0
Showing 1 changed file with 35 additions and 3 deletions.
38 changes: 35 additions & 3 deletions silverback/subscriptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import asyncio
import json
from enum import Enum
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional

from ape.logging import logger
from websockets import ConnectionClosedError
Expand Down Expand Up @@ -46,14 +46,23 @@ async def __anext__(self) -> str:
if not self.connection:
raise StopAsyncIteration

message = await self.connection.recv()
return await self._receive()

async def _receive(self, timeout: Optional[int] = None) -> str:
"""Receive (and wait if no timeout) for the next message from the
socket.
"""
if not self.connection:
raise ConnectionError("Connection not opened")

message = await asyncio.wait_for(self.connection.recv(), timeout)
# TODO: Handle retries when connection breaks

response = json.loads(message)
if response.get("method") == "eth_subscription":
sub_params: dict = response.get("params", {})
if not (sub_id := sub_params.get("subscription")) or not isinstance(sub_id, str):
logger.debug(f"Corrupted subscription data: {response}")
logger.warning(f"Corrupted subscription data: {response}")
return response

if sub_id not in self._subscriptions:
Expand Down Expand Up @@ -115,6 +124,9 @@ async def subscribe(self, type: SubscriptionType, **filter_params) -> str:
return sub_id

async def get_subscription_data(self, sub_id: str) -> AsyncGenerator[dict, None]:
"""Iterate items from the subscription queue. If nothing is in the
queue, await.
"""
while True:
if not (queue := self._subscriptions.get(sub_id)) or queue.empty():
async with self._ws_lock:
Expand All @@ -124,6 +136,26 @@ async def get_subscription_data(self, sub_id: str) -> AsyncGenerator[dict, None]
else:
yield await queue.get()

async def get_subscription_data_nowait(
self, sub_id: str, timeout: Optional[int] = 15
) -> AsyncGenerator[dict, None]:
"""Iterate items from the subscription queue. If nothing is in the
queue, return.
"""
while True:
if not (queue := self._subscriptions.get(sub_id)) or queue.empty():
async with self._ws_lock:
try:
await self._receive(timeout=timeout)
except TimeoutError:
logger.debug("Receive call timed out.")
return
else:
try:
yield queue.get_nowait()
except asyncio.QueueEmpty:
return

async def unsubscribe(self, sub_id: str) -> bool:
if sub_id not in self._subscriptions:
raise ValueError(f"Unknown sub_id '{sub_id}'")
Expand Down

0 comments on commit 34a12e0

Please sign in to comment.