From 4e97cb2fdda844282333c48919ea6cf656ac1f7e Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 29 Nov 2023 18:50:07 +0300 Subject: [PATCH 01/17] Rework and fix PubSub race conditions - add aioredis_cluster.aioredis.stream module - rework PubSub command execution flow for prevent race conditions on spontaneously server channels unsubscribe push - make fully dedicated RedisConnection implementation for cluster - add key slot handling for sharded PubSub channels - fix and improve legacy aioredis tests - more tests --- pyproject.toml | 2 +- src/aioredis_cluster/_aioredis/parser.py | 7 +- src/aioredis_cluster/_aioredis/util.py | 14 +- src/aioredis_cluster/aioredis/__init__.py | 80 +- .../aioredis/commands/__init__.py | 78 +- src/aioredis_cluster/aioredis/connection.py | 5 +- src/aioredis_cluster/aioredis/parser.py | 9 + src/aioredis_cluster/aioredis/stream.py | 10 + src/aioredis_cluster/cluster.py | 12 +- src/aioredis_cluster/command_info/commands.py | 51 +- src/aioredis_cluster/connection.py | 814 +++++++++++++++--- src/aioredis_cluster/crc.py | 15 + src/aioredis_cluster/pool.py | 18 +- tests/aioredis_tests/conftest.py | 32 +- tests/aioredis_tests/pubsub_commands_test.py | 27 + tests/unit_tests/aioredis_cluster/conftest.py | 33 +- .../aioredis_cluster/test_connection.py | 0 .../test_connection_pubsub.py | 395 ++++++++- 18 files changed, 1373 insertions(+), 229 deletions(-) create mode 100644 src/aioredis_cluster/aioredis/parser.py create mode 100644 src/aioredis_cluster/aioredis/stream.py create mode 100644 tests/unit_tests/aioredis_cluster/test_connection.py diff --git a/pyproject.toml b/pyproject.toml index e6eefcf..26c239c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ target-version = ["py38"] exclude = '.pyi$' [tool.pytest.ini_options] -addopts = "--cov-report=term --cov-report=html -v" +addopts = "-v --cov --cov-report=term --cov-report=html" asyncio_mode = "auto" [tool.coverage.run] diff --git a/src/aioredis_cluster/_aioredis/parser.py b/src/aioredis_cluster/_aioredis/parser.py index 0d443ea..5fe075f 100644 --- a/src/aioredis_cluster/_aioredis/parser.py +++ b/src/aioredis_cluster/_aioredis/parser.py @@ -53,7 +53,12 @@ def getmaxbuf(self) -> int: class Parser: - def __init__(self, protocolError: Callable, replyError: Callable, encoding: Optional[str]): + def __init__( + self, + protocolError: Callable, + replyError: Callable, + encoding: Optional[str], + ): self.buf: bytearray = bytearray() self.pos: int = 0 self.protocolError: Callable = protocolError diff --git a/src/aioredis_cluster/_aioredis/util.py b/src/aioredis_cluster/_aioredis/util.py index 386d84b..d60f017 100644 --- a/src/aioredis_cluster/_aioredis/util.py +++ b/src/aioredis_cluster/_aioredis/util.py @@ -1,5 +1,6 @@ import asyncio import sys +from typing import Any, Dict, Generic, TypeVar from urllib.parse import parse_qsl, urlparse from .log import logger @@ -10,6 +11,9 @@ _NOTSET = object() +T = TypeVar("T") + + IS_PY38 = sys.version_info >= (3, 8) # NOTE: never put here anything else; @@ -75,13 +79,13 @@ async def wait_make_dict(fut): return dict(zip(it, it)) -class coerced_keys_dict(dict): - def __getitem__(self, other): +class coerced_keys_dict(Generic[T], Dict[Any, T], dict): + def __getitem__(self, other) -> T: if not isinstance(other, bytes): other = _converters[type(other)](other) return dict.__getitem__(self, other) - def __contains__(self, other): + def __contains__(self, other) -> bool: if not isinstance(other, bytes): other = _converters[type(other)](other) return dict.__contains__(self, other) @@ -108,7 +112,7 @@ async def __anext__(self): return ret -def _set_result(fut, result, *info): +def _set_result(fut: asyncio.Future, result: Any, *info) -> None: if fut.done(): logger.debug("Waiter future is already done %r %r", fut, info) assert fut.cancelled(), ("waiting future is in wrong state", fut, result, info) @@ -116,7 +120,7 @@ def _set_result(fut, result, *info): fut.set_result(result) -def _set_exception(fut, exception): +def _set_exception(fut: asyncio.Future, exception: BaseException) -> None: if fut.done(): logger.debug("Waiter future is already done %r", fut) assert fut.cancelled(), ("waiting future is in wrong state", fut, exception) diff --git a/src/aioredis_cluster/aioredis/__init__.py b/src/aioredis_cluster/aioredis/__init__.py index 8f74d12..392e326 100644 --- a/src/aioredis_cluster/aioredis/__init__.py +++ b/src/aioredis_cluster/aioredis/__init__.py @@ -1,67 +1,31 @@ -try: - from aioredis.commands import ( - GeoMember, - GeoPoint, - Redis, - create_redis, - create_redis_pool, - ) -except ImportError: - from .._aioredis.commands import ( - GeoMember, - GeoPoint, - Redis, - create_redis, - create_redis_pool, - ) +from .commands import GeoMember, GeoPoint, Redis, create_redis, create_redis_pool try: from aioredis.connection import RedisConnection except ImportError: from .._aioredis.connection import RedisConnection -try: - from aioredis.errors import ( - AuthError, - BusyGroupError, - ChannelClosedError, - ConnectionClosedError, - ConnectionForcedCloseError, - MasterNotFoundError, - MasterReplyError, - MaxClientsError, - MultiExecError, - PipelineError, - PoolClosedError, - ProtocolError, - ReadOnlyError, - RedisError, - ReplyError, - SlaveNotFoundError, - SlaveReplyError, - WatchVariableError, - ) -except ImportError: - from .._aioredis.errors import ( - AuthError, - BusyGroupError, - ChannelClosedError, - ConnectionClosedError, - ConnectionForcedCloseError, - MasterNotFoundError, - MasterReplyError, - MaxClientsError, - MultiExecError, - PipelineError, - PoolClosedError, - ProtocolError, - ReadOnlyError, - RedisError, - ReplyError, - SlaveNotFoundError, - SlaveReplyError, - WatchVariableError, - ) +from .errors import ( + AuthError, + BusyGroupError, + ChannelClosedError, + ConnectionClosedError, + ConnectionForcedCloseError, + MasterNotFoundError, + MasterReplyError, + MaxClientsError, + MultiExecError, + PipelineError, + PoolClosedError, + ProtocolError, + ReadOnlyError, + RedisError, + ReplyError, + SlaveNotFoundError, + SlaveReplyError, + WatchVariableError, +) + try: from aioredis.pool import ConnectionsPool except ImportError: diff --git a/src/aioredis_cluster/aioredis/commands/__init__.py b/src/aioredis_cluster/aioredis/commands/__init__.py index 5fc0342..b4eee07 100644 --- a/src/aioredis_cluster/aioredis/commands/__init__.py +++ b/src/aioredis_cluster/aioredis/commands/__init__.py @@ -1,3 +1,9 @@ +import asyncio +from typing import Any, Optional, Tuple, Union + +from aioredis_cluster.aioredis.connection import create_connection +from aioredis_cluster.aioredis.pool import create_pool + try: from aioredis.commands import ( ContextRedis, @@ -5,8 +11,7 @@ GeoPoint, MultiExec, Pipeline, - create_redis, - create_redis_pool, + Redis, ) except ImportError: from aioredis_cluster._aioredis.commands import ( @@ -15,8 +20,7 @@ GeoPoint, MultiExec, Pipeline, - create_redis, - create_redis_pool, + Redis, ) @@ -32,3 +36,69 @@ "GeoPoint", "GeoMember", ) + + +async def create_redis( + address: Union[Tuple[str, int], str], + *, + db: Optional[int] = None, + password: Optional[str] = None, + ssl: Optional[Any] = None, + encoding: Optional[str] = None, + commands_factory: Redis = Redis, + parser=None, + timeout: Optional[float] = None, + connection_cls=None, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> Redis: + """Creates high-level Redis interface. + + This function is a coroutine. + """ + conn = await create_connection( + address, + db=db, + password=password, + ssl=ssl, + encoding=encoding, + parser=parser, + timeout=timeout, + connection_cls=connection_cls, + ) + return commands_factory(conn) + + +async def create_redis_pool( + address: Union[Tuple[str, int], str], + *, + db: Optional[int] = None, + password: Optional[str] = None, + ssl: Optional[Any] = None, + encoding: Optional[str] = None, + commands_factory: Redis = Redis, + minsize: int = 1, + maxsize: int = 10, + parser=None, + timeout: Optional[float] = None, + pool_cls=None, + connection_cls=None, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> Redis: + """Creates high-level Redis interface. + + This function is a coroutine. + """ + pool = await create_pool( + address, + db=db, + password=password, + ssl=ssl, + encoding=encoding, + minsize=minsize, + maxsize=maxsize, + parser=parser, + create_connection_timeout=timeout, + pool_cls=pool_cls, + connection_cls=connection_cls, + ) + return commands_factory(pool) diff --git a/src/aioredis_cluster/aioredis/connection.py b/src/aioredis_cluster/aioredis/connection.py index 4cb9ed1..46bbb98 100644 --- a/src/aioredis_cluster/aioredis/connection.py +++ b/src/aioredis_cluster/aioredis/connection.py @@ -7,12 +7,9 @@ from aioredis_cluster.compat.asyncio import timeout as atimeout from .abc import AbcConnection +from .stream import open_connection, open_unix_connection from .util import parse_url -try: - from aioredis.stream import open_connection, open_unix_connection -except ImportError: - from .._aioredis.stream import open_connection, open_unix_connection try: from aioredis.connection import MAX_CHUNK_SIZE, RedisConnection except ImportError: diff --git a/src/aioredis_cluster/aioredis/parser.py b/src/aioredis_cluster/aioredis/parser.py new file mode 100644 index 0000000..155c338 --- /dev/null +++ b/src/aioredis_cluster/aioredis/parser.py @@ -0,0 +1,9 @@ +try: + from aioredis.parser import PyReader, Reader +except ImportError: + from .._aioredis.parser import PyReader, Reader + +__all__ = ( + "Reader", + "PyReader", +) diff --git a/src/aioredis_cluster/aioredis/stream.py b/src/aioredis_cluster/aioredis/stream.py new file mode 100644 index 0000000..608ad92 --- /dev/null +++ b/src/aioredis_cluster/aioredis/stream.py @@ -0,0 +1,10 @@ +try: + from aioredis.stream import StreamReader, open_connection, open_unix_connection +except ImportError: + from .._aioredis.stream import StreamReader, open_connection, open_unix_connection + +__all__ = ( + "open_connection", + "open_unix_connection", + "StreamReader", +) diff --git a/src/aioredis_cluster/cluster.py b/src/aioredis_cluster/cluster.py index 03e1e67..429d39f 100644 --- a/src/aioredis_cluster/cluster.py +++ b/src/aioredis_cluster/cluster.py @@ -29,7 +29,7 @@ from aioredis_cluster.command_info import CommandInfo, extract_keys from aioredis_cluster.commands import RedisCluster from aioredis_cluster.compat.asyncio import timeout as atimeout -from aioredis_cluster.crc import key_slot +from aioredis_cluster.crc import CrossSlotKeysError, determine_slot from aioredis_cluster.errors import ( AskError, ClusterClosedError, @@ -418,12 +418,10 @@ async def authorize(pool) -> None: await self._pooler.batch_op(authorize) def determine_slot(self, first_key: bytes, *keys: bytes) -> int: - slot: int = key_slot(first_key) - for k in keys: - if slot != key_slot(k): - raise RedisClusterError("all keys must map to the same key slot") - - return slot + try: + return determine_slot(first_key, *keys) + except CrossSlotKeysError: + raise RedisClusterError(str(CrossSlotKeysError)) from None async def all_masters(self) -> List[Redis]: ctx = self._make_exec_context((b"PING",), {}) diff --git a/src/aioredis_cluster/command_info/commands.py b/src/aioredis_cluster/command_info/commands.py index 87f61fb..491d241 100644 --- a/src/aioredis_cluster/command_info/commands.py +++ b/src/aioredis_cluster/command_info/commands.py @@ -1,5 +1,6 @@ # redis command output (v5.0.8) -from typing import FrozenSet, Iterable, MutableSet, Union +import enum +from typing import Dict, FrozenSet, Iterable, Mapping, MutableSet, Union, cast __all__ = ( "COMMANDS", @@ -10,9 +11,14 @@ "XREAD_COMMAND", "XREADGROUP_COMMAND", "PUBSUB_COMMANDS", + "PATTERN_PUBSUB_COMMANDS", "SHARDED_PUBSUB_COMMANDS", "PUBSUB_FAMILY_COMMANDS", "PING_COMMANDS", + "PUBSUB_SUBSCRIBE_COMMANDS", + "PUBSUB_COMMAND_TO_TYPE", + "PUBSUB_RESP_KIND_TO_TYPE", + "PubSubType", ) @@ -270,11 +276,17 @@ def _gen_commands_set(commands: Iterable[str]) -> FrozenSet[Union[bytes, str]]: XREADGROUP_COMMAND = "XREADGROUP" -PUBSUB_COMMANDS = _gen_commands_set( +class PubSubType(enum.Enum): + CHANNEL = enum.auto() + PATTERN = enum.auto() + SHARDED = enum.auto() + + +PUBSUB_COMMANDS = _gen_commands_set({"SUBSCRIBE", "UNSUBSCRIBE"}) + +PATTERN_PUBSUB_COMMANDS = _gen_commands_set( { - "SUBSCRIBE", "PSUBSCRIBE", - "UNSUBSCRIBE", "PUNSUBSCRIBE", } ) @@ -286,6 +298,35 @@ def _gen_commands_set(commands: Iterable[str]) -> FrozenSet[Union[bytes, str]]: } ) -PUBSUB_FAMILY_COMMANDS = PUBSUB_COMMANDS | SHARDED_PUBSUB_COMMANDS +PUBSUB_SUBSCRIBE_COMMANDS = _gen_commands_set( + { + "SUBSCRIBE", + "PSUBSCRIBE", + "SSUBSCRIBE", + } +) + +PUBSUB_FAMILY_COMMANDS = PUBSUB_COMMANDS | PATTERN_PUBSUB_COMMANDS | SHARDED_PUBSUB_COMMANDS PING_COMMANDS = _gen_commands_set({"PING"}) + +PUBSUB_RESP_KIND_TO_TYPE: Mapping[bytes, PubSubType] = { + b"message": PubSubType.CHANNEL, + b"subscribe": PubSubType.CHANNEL, + b"unsubscribe": PubSubType.CHANNEL, + b"pmessage": PubSubType.PATTERN, + b"psubscribe": PubSubType.PATTERN, + b"punsubscribe": PubSubType.PATTERN, + b"smessage": PubSubType.SHARDED, + b"ssubscribe": PubSubType.SHARDED, + b"sunsubscribe": PubSubType.SHARDED, +} + +_pubsub_command_to_type: Dict[Union[str, bytes], PubSubType] = {} +for cmd in PUBSUB_COMMANDS: + _pubsub_command_to_type[cmd] = PubSubType.CHANNEL +for cmd in PATTERN_PUBSUB_COMMANDS: + _pubsub_command_to_type[cmd] = PubSubType.PATTERN +for cmd in SHARDED_PUBSUB_COMMANDS: + _pubsub_command_to_type[cmd] = PubSubType.SHARDED +PUBSUB_COMMAND_TO_TYPE = cast(Mapping[Union[str, bytes], PubSubType], _pubsub_command_to_type) diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index 342039e..c66cd8f 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -1,26 +1,82 @@ import asyncio import logging +import warnings +from collections import deque +from contextlib import contextmanager from functools import partial from types import MappingProxyType -from typing import Iterable, List, Mapping +from typing import ( + Any, + Callable, + Deque, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Protocol, + Set, + Tuple, + Union, +) -from aioredis_cluster._aioredis.util import coerced_keys_dict, wait_ok +from aioredis_cluster._aioredis.util import ( + _set_exception, + _set_result, + coerced_keys_dict, + decode, + wait_ok, +) from aioredis_cluster.abc import AbcChannel, AbcConnection -from aioredis_cluster.aioredis import Channel, ConnectionClosedError -from aioredis_cluster.aioredis import RedisConnection as BaseConnection +from aioredis_cluster.aioredis import ( + Channel, + ConnectionClosedError, + ConnectionForcedCloseError, + MaxClientsError, + ProtocolError, + ReadOnlyError, + ReplyError, + WatchVariableError, +) +from aioredis_cluster.aioredis.parser import Reader +from aioredis_cluster.aioredis.stream import StreamReader from aioredis_cluster.aioredis.util import _NOTSET from aioredis_cluster.command_info.commands import ( PING_COMMANDS, + PUBSUB_COMMAND_TO_TYPE, PUBSUB_FAMILY_COMMANDS, - SHARDED_PUBSUB_COMMANDS, + PUBSUB_RESP_KIND_TO_TYPE, + PUBSUB_SUBSCRIBE_COMMANDS, + PubSubType, ) -from aioredis_cluster.errors import RedisError +from aioredis_cluster.crc import CrossSlotKeysError, determine_slot +from aioredis_cluster.errors import MovedError, RedisError from aioredis_cluster.typedef import PClosableConnection -from aioredis_cluster.util import encode_command +from aioredis_cluster.util import encode_command, ensure_bytes logger = logging.getLogger(__name__) +TExecuteCallback = Callable[[Any], Any] + + +class ExecuteWaiter(NamedTuple): + fut: asyncio.Future + enc: Optional[str] + cb: Optional[TExecuteCallback] + + +class PParserFactory(Protocol): + def __call__( + self, + protocolError: Callable = ProtocolError, + replyError: Callable = ReplyError, + encoding: Optional[str] = None, + ) -> Reader: + ... + + async def close_connections(conns: Iterable[PClosableConnection]) -> None: close_waiters = set() for conn in conns: @@ -30,17 +86,173 @@ async def close_connections(conns: Iterable[PClosableConnection]) -> None: await asyncio.wait(close_waiters) -class RedisConnection(BaseConnection, AbcConnection): - _in_pubsub: int +class PubSub: + def __init__(self) -> None: + self._channels: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._patterns: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._sharded_to_slot: Dict[bytes, int] = {} + self._slot_to_sharded: Dict[int, Set[bytes]] = {} + + @property + def channels(self) -> Mapping[str, AbcChannel]: + """Returns read-only channels dict.""" + return MappingProxyType(self._channels) - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) + @property + def patterns(self) -> Mapping[str, AbcChannel]: + """Returns read-only patterns dict.""" + return MappingProxyType(self._patterns) + @property + def sharded(self) -> Mapping[str, AbcChannel]: + """Returns read-only sharded channels dict.""" + return MappingProxyType(self._sharded) + + def channel_subscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + channel: AbcChannel, + key_slot: int, + ) -> None: + if channel_type is PubSubType.CHANNEL: + if channel_name not in self._channels: + self._channels[channel_name] = channel + elif channel_type is PubSubType.PATTERN: + if channel_name not in self._patterns: + self._patterns[channel_name] = channel + elif channel_type is PubSubType.SHARDED: + if channel_name not in self._sharded: + self._sharded[channel_name] = channel + self._sharded_to_slot[channel_name] = key_slot + if key_slot not in self._slot_to_sharded: + self._slot_to_sharded[key_slot] = set((channel_name,)) + else: + self._slot_to_sharded[key_slot].add(channel_name) + + def channel_unsubscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + ) -> None: + channel: Optional[AbcChannel] = None + if channel_type is PubSubType.CHANNEL: + channel = self._channels.pop(channel_name, None) + elif channel_type is PubSubType.PATTERN: + channel = self._patterns.pop(channel_name, None) + elif channel_type is PubSubType.SHARDED: + channel = self._sharded.pop(channel_name, None) + key_slot = self._sharded_to_slot.pop(channel_name, None) + if key_slot is not None: + key_slot_channels = self._slot_to_sharded[key_slot] + key_slot_channels.discard(channel_name) + if len(key_slot_channels) == 0: + del self._slot_to_sharded[key_slot] + + if channel is not None: + channel.close() + + def slot_channels_unsubscribe(self, key_slot: int) -> None: + channel_names = self._slot_to_sharded.pop(key_slot, None) + if channel_names is None: + return + + while channel_names: + channel_name = channel_names.pop() + del self._sharded_to_slot[channel_name] + channel = self._sharded.pop(channel_name) + channel.close() + + @property + def channels_total(self) -> int: + return len(self._channels) + len(self._patterns) + len(self._sharded) + + @property + def channels_num(self) -> int: + return len(self._channels) + len(self._patterns) + + @property + def sharded_channels_num(self) -> int: + return len(self._sharded) + + def get_channel(self, channel_type: PubSubType, channel_name: bytes) -> AbcChannel: + if channel_type is PubSubType.CHANNEL: + channel = self._channels[channel_name] + elif channel_type is PubSubType.PATTERN: + channel = self._patterns[channel_name] + elif channel_type is PubSubType.SHARDED: + channel = self._sharded[channel_name] + else: # pragma: no cover + raise RuntimeError(f"Unexpected channel type {channel_type!r}") + return channel + + def close(self, exc: Optional[BaseException]) -> None: + while self._channels: + _, ch = self._channels.popitem() + logger.debug("Closing pubsub channel %r", ch) + ch.close(exc) + while self._patterns: + _, ch = self._patterns.popitem() + logger.debug("Closing pubsub pattern %r", ch) + ch.close(exc) + while self._sharded: + _, ch = self._sharded.popitem() + logger.debug("Closing sharded pubsub channel %r", ch) + ch.close(exc) + + self._slot_to_sharded.clear() + self._sharded_to_slot.clear() + + +class RedisConnection(AbcConnection): + def __init__( + self, + reader: StreamReader, + writer: asyncio.StreamWriter, + *, + address: Union[Tuple[str, int], str], + encoding: Optional[str] = None, + parser: Optional[PParserFactory] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + if loop is not None: + warnings.warn("The loop argument is deprecated", DeprecationWarning) + if parser is None: + parser = Reader + assert callable(parser), ("Parser argument is not callable", parser) + self._reader = reader + self._writer = writer + self._address = address + self._waiters: Deque[ExecuteWaiter] = deque() + self._reader.set_parser(parser(protocolError=ProtocolError, replyError=ReplyError)) + self._close_msg = "" + self._db = 0 + self._closing = False + self._closed = False + self._close_state = asyncio.Event() + self._in_transaction: Optional[Deque[Tuple[Optional[str], Optional[Callable]]]] = None + self._transaction_error: Optional[Exception] = None # XXX: never used? + self._pubsub_channels_store = PubSub() + # client side PubSub mode flag + self._client_in_pubsub = False + # confirmed PubSub from Redis server via first subscribe reply + self._server_in_pubsub = False + + self._encoding = encoding + self._pipeline_buffer: Optional[bytearray] = None self._readonly = False - self._sharded_pubsub_channels = coerced_keys_dict() self._loop = asyncio.get_running_loop() self._last_use_generation = 0 + self._reader_task: Optional[asyncio.Task] = self._loop.create_task(self._read_data()) + self._reader_task.add_done_callback(self._on_reader_task_done) + + def __repr__(self): + return f"<{type(self).__name__} address:{self.address} db:{self.db}>" + @property def readonly(self) -> bool: return self._readonly @@ -57,14 +269,30 @@ async def set_readonly(self, value: bool) -> None: async def auth_with_username(self, username: str, password: str) -> bool: """Authenticate to server with username and password.""" - fut = self.execute("AUTH", username, password) + fut = self.execute(b"AUTH", username, password) return await wait_ok(fut) + @property + def pubsub_channels(self) -> Mapping[str, AbcChannel]: + """Returns read-only channels dict.""" + return self._pubsub_channels_store.channels + + @property + def pubsub_patterns(self) -> Mapping[str, AbcChannel]: + """Returns read-only patterns dict.""" + return self._pubsub_channels_store.patterns + @property def sharded_pubsub_channels(self) -> Mapping[str, AbcChannel]: - return MappingProxyType(self._sharded_pubsub_channels) + """Returns read-only sharded channels dict.""" + return self._pubsub_channels_store.sharded + + async def auth(self, password: str) -> bool: + """Authenticate to server.""" + fut = self.execute(b"AUTH", password) + return await wait_ok(fut) - def execute(self, command, *args, encoding=_NOTSET): + async def execute(self, command, *args, encoding=_NOTSET) -> Any: """Executes redis command and returns Future waiting for the answer. Raises: @@ -87,10 +315,14 @@ def execute(self, command, *args, encoding=_NOTSET): if command in PUBSUB_FAMILY_COMMANDS: raise ValueError(f"PUB/SUB command {command!r} is prohibited for use with .execute()") + if encoding is _NOTSET: + encoding = self._encoding + is_ping = command in PING_COMMANDS - if self._in_pubsub and not is_ping: - raise RedisError("Connection in SUBSCRIBE mode") + if not is_ping and (self._client_in_pubsub or self._server_in_pubsub): + raise RedisError("Connection in PubSub mode") + cb: Optional[TExecuteCallback] = None if command in ("SELECT", b"SELECT"): cb = partial(self._set_db, args=args) elif command in ("MULTI", b"MULTI"): @@ -100,58 +332,137 @@ def execute(self, command, *args, encoding=_NOTSET): encoding = None elif command in ("DISCARD", b"DISCARD"): cb = partial(self._end_transaction, discard=True) - else: - cb = None - if encoding is _NOTSET: - encoding = self._encoding - fut = self._loop.create_future() if self._pipeline_buffer is None: self._writer.write(encode_command(command, *args)) else: encode_command(command, *args, buf=self._pipeline_buffer) - self._waiters.append((fut, encoding, cb)) - return fut - def execute_pubsub(self, command, *channels): - """Executes redis (p)subscribe/(p)unsubscribe commands. + fut = self._loop.create_future() + self._waiters.append( + ExecuteWaiter( + fut=fut, + enc=encoding, + cb=cb, + ) + ) + + return await fut + + async def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel]): + """Executes redis (p|s)subscribe/(p|s)unsubscribe commands. Returns asyncio.gather coroutine waiting for all channels/patterns to receive answers. """ command = command.upper().strip() - assert command in PUBSUB_FAMILY_COMMANDS, ("Pub/Sub command expected", command) + if command not in PUBSUB_FAMILY_COMMANDS: + raise ValueError(f"Pub/Sub command expected, not {command!r}") if self._reader is None or self._reader.at_eof(): raise ConnectionClosedError("Connection closed or corrupted") if None in set(channels): raise TypeError("args must not contain None") - if not len(channels): - raise TypeError("No channels/patterns supplied") - if command in SHARDED_PUBSUB_COMMANDS: - is_pattern = False + channel_type = PUBSUB_COMMAND_TO_TYPE[command] + is_subscribe_command = command in PUBSUB_SUBSCRIBE_COMMANDS + is_pattern = channel_type is PubSubType.PATTERN + key_slot = -1 + reply_kind = ensure_bytes(command.lower()) + + channels_obj: Dict[str, AbcChannel] + if len(channels) == 0: + if is_subscribe_command: + raise ValueError("No channels to (un)subscribe") + elif channel_type is PubSubType.PATTERN: + channels_obj = dict(self._pubsub_channels_store.patterns) + elif channel_type is PubSubType.SHARDED: + channels_obj = dict(self._pubsub_channels_store.sharded) + else: + channels_obj = dict(self._pubsub_channels_store.channels) else: - is_pattern = len(command) in (10, 12) - - mkchannel = partial(Channel, is_pattern=is_pattern) - channels_obj: List[AbcChannel] = [ - ch if isinstance(ch, AbcChannel) else mkchannel(ch) for ch in channels - ] - if not all(ch.is_pattern == is_pattern for ch in channels_obj): - raise ValueError("Not all channels {} match command {}".format(channels, command)) + mkchannel = partial(Channel, is_pattern=is_pattern) + channels_obj = {} + for channel_name_or_obj in channels: + if not isinstance(channel_name_or_obj, AbcChannel): + ch = mkchannel(channel_name_or_obj) + if ch.name in channels_obj: + raise ValueError(f"Found channel duplicates in {channels!r}") + if ch.is_pattern != is_pattern: + raise ValueError(f"Not all channels {channels!r} match command {command!r}") + channels_obj[ch.name] = ch + + if channel_type is PubSubType.SHARDED: + try: + key_slot = determine_slot(*(ensure_bytes(name) for name in channels_obj.keys())) + except CrossSlotKeysError: + raise ValueError( + f"Not all channels shared one key slot in cluster {channels!r}" + ) from None + + cmd = encode_command(command, *(name for name in channels_obj.keys())) + res: List[Any] = [] + + if is_subscribe_command: + for ch in channels_obj.values(): + channel_name = ensure_bytes(ch.name) + self._pubsub_channels_store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=ch, + key_slot=key_slot, + ) + if channel_type is PubSubType.SHARDED: + channels_num = self._pubsub_channels_store.sharded_channels_num + else: + channels_num = self._pubsub_channels_store.channels_num + res.append([reply_kind, channel_name, channels_num]) + + # otherwise unsubscribe command + else: + for ch in channels_obj.values(): + channel_name = ensure_bytes(ch.name) + self._pubsub_channels_store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + ) + if channel_type is PubSubType.SHARDED: + channels_num = self._pubsub_channels_store.sharded_channels_num + else: + channels_num = self._pubsub_channels_store.channels_num + res.append([reply_kind, channel_name, channels_num]) - cmd = encode_command(command, *(ch.name for ch in channels_obj)) - res = [] - for ch in channels_obj: - fut = self._loop.create_future() - res.append(fut) - cb = partial(self._update_pubsub, ch=ch) - self._waiters.append((fut, None, cb)) if self._pipeline_buffer is None: self._writer.write(cmd) else: self._pipeline_buffer.extend(cmd) - return asyncio.gather(*res) + + if not self._client_in_pubsub and not self._server_in_pubsub: + if is_subscribe_command: + # entering to PubSub mode on client side + self._client_in_pubsub = True + + fut = self._loop.create_future() + self._waiters.append( + ExecuteWaiter( + fut=fut, + enc=None, + cb=self._process_pubsub, + ) + ) + server_reply = list(await fut) + if server_reply != res[0]: + if logger.isEnabledFor(logging.DEBUG): + logger.error( + "Unexpected server reply on PubSub on %r: %r, expected %r", + command, + server_reply, + res[0], + ) + exc = RedisError(f"Unexpected server reply on PubSub {command!r}") + self._do_close(exc) + raise exc + + return res def get_last_use_generation(self) -> int: return self._last_use_generation @@ -159,77 +470,352 @@ def get_last_use_generation(self) -> int: def set_last_use_generation(self, gen: int): self._last_use_generation = gen - def _process_pubsub(self, obj, *, process_waiters: bool = True): - """Processes pubsub messages.""" - - if isinstance(obj, RedisError): - # case for new pubsub command for example: - # new ssubscribe to old node of slot - return self._process_data(obj) - - kind, *args, data = obj - if kind in (b"subscribe", b"unsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"unsubscribe": - ch = self._pubsub_channels.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind in (b"ssubscribe", b"sunsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"sunsubscribe": - ch = self._sharded_pubsub_channels.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind in (b"psubscribe", b"punsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"punsubscribe": - ch = self._pubsub_patterns.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind == b"message": - (chan,) = args - self._pubsub_channels[chan].put_nowait(data) - elif kind == b"smessage": - (chan,) = args - self._sharded_pubsub_channels[chan].put_nowait(data) + def close(self) -> None: + """Close connection.""" + self._do_close(ConnectionForcedCloseError()) + + @property + def closed(self) -> bool: + """True if connection is closed.""" + closed = self._closing or self._closed + if not closed and self._reader and self._reader.at_eof(): + self._closing = closed = True + self._loop.call_soon(self._do_close, None) + return closed + + async def wait_closed(self) -> None: + """Coroutine waiting until connection is closed.""" + await self._close_state.wait() + + @property + def db(self) -> int: + """Currently selected db index.""" + return self._db + + @property + def encoding(self) -> Optional[str]: + """Current set codec or None.""" + return self._encoding + + @property + def address(self) -> Union[Tuple[str, int], str]: + """Redis server address, either host-port tuple or str.""" + return self._address + + @property + def in_transaction(self) -> bool: + """Set to True when MULTI command was issued.""" + return self._in_transaction is not None + + @property + def in_pubsub(self) -> int: + """Indicates that connection is in PUB/SUB mode. + + Provides the number of subscribed channels. + """ + return self._pubsub_channels_store.channels_total + + async def select(self, db: int) -> bool: + """Change the selected database for the current connection.""" + if not isinstance(db, int): + raise TypeError("DB must be of int type, not {!r}".format(db)) + if db < 0: + raise ValueError("DB must be greater or equal 0, got {!r}".format(db)) + fut = self.execute(b"SELECT", db) + return await wait_ok(fut) + + def _set_db(self, ok, args): + assert ok in {b"OK", "OK"}, ("Unexpected result of SELECT", ok) + self._db = args[0] + return ok + + def _start_transaction(self, ok): + if self._in_transaction is not None: + raise RuntimeError("Connection is already in transaction") + self._in_transaction = deque() + self._transaction_error = None + return ok + + def _end_transaction(self, obj: Any, discard: bool) -> Any: + if self._in_transaction is None: + raise RuntimeError("Connection is not in transaction") + self._transaction_error = None + recall, self._in_transaction = self._in_transaction, None + recall.popleft() # ignore first (its _start_transaction) + if discard: + return obj + + if not (isinstance(obj, list) or (obj is None and not discard)): + raise RuntimeError(f"Unexpected MULTI/EXEC result: {obj!r}, {recall!r}") + + # TODO: need to be able to re-try transaction + if obj is None: + err = WatchVariableError("WATCH variable has changed") + obj = [err] * len(recall) + + if len(obj) != len(recall): + raise RuntimeError(f"Wrong number of result items in mutli-exec: {obj!r}, {recall!r}") + + res = [] + for o, (encoding, cb) in zip(obj, recall): + if not isinstance(o, RedisError): + try: + if encoding: + o = decode(o, encoding) + if cb: + o = cb(o) + except Exception as err: + res.append(err) + continue + res.append(o) + return res + + def _do_close(self, exc: Optional[BaseException]) -> None: + if self._closed: + return + self._closed = True + self._closing = False + self._writer.transport.close() + if self._reader_task is not None: + self._reader_task.cancel() + self._reader_task = None + del self._writer + del self._reader + self._pipeline_buffer = None + + if exc is not None: + self._close_msg = str(exc) + + while self._waiters: + waiter = self._waiters.popleft() + logger.debug("Cancelling waiter %r", waiter) + if exc is None: + _set_exception(waiter.fut, ConnectionForcedCloseError()) + else: + _set_exception(waiter.fut, exc) + + self._pubsub_channels_store.close(exc) + + def _on_reader_task_done(self, task: asyncio.Task) -> None: + if not task.cancelled() and task.exception(): + logger.error( + "Reader task unexpectedly done with expection: %r", + task.exception(), + exc_info=task.exception(), + ) + # prevent RedisConnection stuck in half-closed state + self._reader_task = None + self._do_close(ConnectionForcedCloseError()) + self._close_state.set() + + def _is_pubsub_resp(self, obj: Any) -> bool: + if not isinstance(obj, (tuple, list)): + return False + if len(obj) == 0: + return False + return obj[0] in PUBSUB_RESP_KIND_TO_TYPE + + async def _read_data(self) -> None: + """Response reader task.""" + last_error = ConnectionClosedError("Connection has been closed by server") + while not self._reader.at_eof(): + try: + obj = await self._reader.readobj() + except asyncio.CancelledError: + # NOTE: reader can get cancelled from `close()` method only. + last_error = RuntimeError("this is unexpected") + break + except ProtocolError as exc: + # ProtocolError is fatal + # so connection must be closed + if self._in_transaction is not None: + self._transaction_error = exc + last_error = exc + break + except Exception as exc: + # NOTE: for QUIT command connection error can be received + # before response + last_error = exc + break + else: + if (obj == b"" or obj is None) and self._reader.at_eof(): + logger.debug("Connection has been closed by server, response: %r", obj) + last_error = ConnectionClosedError("Reader at end of file") + break + + if isinstance(obj, MaxClientsError): + last_error = obj + break + + if self._loop.get_debug(): + logger.debug( + "Received reply (client_in_pubsub:%s, server_in_pubsub:%s): %r", + self._client_in_pubsub, + self._server_in_pubsub, + obj, + ) + + if self._server_in_pubsub: + if isinstance(obj, MovedError): + logger.warning( + "Received MOVED in PubSub mode. Unsubscribe all channels from %d slot", + obj.info.slot_id, + ) + self._pubsub_channels_store.slot_channels_unsubscribe(obj.info.slot_id) + elif isinstance(obj, RedisError): + raise obj + else: + self._process_pubsub(obj) + else: + if isinstance(obj, RedisError): + if isinstance(obj, ReplyError): + if obj.args[0].startswith("READONLY"): + obj = ReadOnlyError(obj.args[0]) + self._wakeup_waiter_with_exc(obj) + else: + self._wakeup_waiter_with_result(obj) + + self._closing = True + self._loop.call_soon(self._do_close, last_error) + + def _wakeup_waiter_with_exc(self, exc: Exception) -> None: + """Processes command errors.""" + + if not self._waiters: + logger.error("No waiter for process error: %r", exc) + return + + waiter = self._waiters.popleft() + _set_exception(waiter.fut, exc) + if self._in_transaction is not None: + self._transaction_error = exc + + def _wakeup_waiter_with_result(self, result: Any) -> None: + """Processes command results.""" + + if self._loop.get_debug(): + logger.debug("Wakeup first waiter for reply: %r", result) + + if not self._waiters: + logger.error("No waiter for received reply: %r, %r", type(result), result) + return + + waiter = self._waiters.popleft() + self._resolve_waiter_with_result(waiter, result) + + def _resolve_waiter_with_result(self, waiter: ExecuteWaiter, result: Any) -> None: + if waiter.enc is not None: + try: + decoded_result = decode(result, waiter.enc) + except Exception as exc: + _set_exception(waiter.fut, exc) + return + else: + decoded_result = result + + del result + + if waiter.cb is not None: + try: + converted_result = waiter.cb(decoded_result) + except Exception as exc: + _set_exception(waiter.fut, exc) + return + else: + converted_result = decoded_result + + del decoded_result + + _set_result(waiter.fut, converted_result) + if self._in_transaction is not None: + self._in_transaction.append((waiter.enc, waiter.cb)) + + def _process_pubsub(self, obj: Any) -> Any: + """Processes pubsub messages. + + This method calls directly on `_read_data` routine + and used as callback in `execute_pubsub` for first PubSub mode initial reply + """ + + if self._loop.get_debug(): + logger.debug( + "Process PubSub reply (client_in_pubsub:%s, server_in_pubsub:%s): %r", + self._client_in_pubsub, + self._server_in_pubsub, + obj, + ) + + if isinstance(obj, bytes): + # process simple bytes as PING reply + kind = b"pong" + data = obj + else: + kind, *args, data = obj + + channel_name: bytes + + if kind in {b"subscribe", b"psubscribe", b"ssubscribe"}: + logger.debug("PubSub subscribe confirmation received: %r", obj) + # confirm PubSub mode in client side based on server reply and reset pending flag + if self._client_in_pubsub and not self._server_in_pubsub: + self._server_in_pubsub = True + elif kind in {b"unsubscribe", b"punsubscribe", b"sunsubscribe"}: + (channel_name,) = args + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] + self._pubsub_channels_store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + ) + if self._pubsub_channels_store.channels_total == 0: + self._server_in_pubsub = False + self._client_in_pubsub = False + elif kind in {b"message", b"smessage"}: + (channel_name,) = args + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] + channel = self._pubsub_channels_store.get_channel(channel_type, channel_name) + channel.put_nowait(data) elif kind == b"pmessage": - pattern, chan = args - self._pubsub_patterns[pattern].put_nowait((chan, data)) + (pattern, channel_name) = args + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] + channel = self._pubsub_channels_store.get_channel(channel_type, pattern) + channel.put_nowait((channel_name, data)) elif kind == b"pong": - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(data or b"PONG") + if not self._waiters: + logger.error("No PubSub PONG waiters for received data %r", data) + else: + # in PubSub mode only PING waiters in this deque + # see in execute() method `is_ping` condition + waiter = self._waiters.popleft() + self._resolve_waiter_with_result(waiter, data or b"PONG") else: logger.warning("Unknown pubsub message received %r", obj) - def _update_pubsub(self, obj, *, ch: AbcChannel): - kind, *pattern, channel, subscriptions = obj - self._in_pubsub, was_in_pubsub = subscriptions, self._in_pubsub - if kind == b"subscribe" and channel not in self._pubsub_channels: - self._pubsub_channels[channel] = ch - elif kind == b"psubscribe" and channel not in self._pubsub_patterns: - self._pubsub_patterns[channel] = ch - elif kind == b"ssubscribe" and channel not in self._sharded_pubsub_channels: - self._sharded_pubsub_channels[channel] = ch - if not was_in_pubsub: - self._process_pubsub(obj, process_waiters=False) return obj - def _do_close(self, exc): - super()._do_close(exc) - - if self._closed: - return - - while self._sharded_pubsub_channels: - _, ch = self._sharded_pubsub_channels.popitem() - logger.debug("Closing sharded pubsub channel %r", ch) - ch.close(exc) + @contextmanager + def _buffered(self): + # XXX: we must ensure that no await happens + # as long as we buffer commands. + # Probably we can set some error-raising callback on enter + # and remove it on exit + # if some await happens in between -> throw an error. + # This is creepy solution, 'cause some one might want to await + # on some other source except redis. + # So we must only raise error we someone tries to await + # pending aioredis future + # One of solutions is to return coroutine instead of a future + # in `execute` method. + # In a coroutine we can check if buffering is enabled and raise error. + + # TODO: describe in docs difference in pipeline mode for + # conn.execute vs pipeline.execute() + if self._pipeline_buffer is None: + self._pipeline_buffer = bytearray() + try: + yield self + buf = self._pipeline_buffer + self._writer.write(buf) + finally: + self._pipeline_buffer = None + else: + yield self diff --git a/src/aioredis_cluster/crc.py b/src/aioredis_cluster/crc.py index 7123576..6c06f75 100644 --- a/src/aioredis_cluster/crc.py +++ b/src/aioredis_cluster/crc.py @@ -10,6 +10,8 @@ __all__ = ( "crc16", "key_slot", + "determine_slot", + "CrossSlotKeysError", ) REDIS_CLUSTER_HASH_SLOTS = 16384 @@ -43,3 +45,16 @@ def py_key_slot(k: bytes, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: key_slot = cy_key_slot else: key_slot = py_key_slot + + +class CrossSlotKeysError(Exception): + pass + + +def determine_slot(first_key: bytes, *keys: bytes) -> int: + slot: int = key_slot(first_key) + for k in keys: + if slot != key_slot(k): + raise CrossSlotKeysError("all keys must map to the same key slot") + + return slot diff --git a/src/aioredis_cluster/pool.py b/src/aioredis_cluster/pool.py index 2d962ed..09c8de5 100644 --- a/src/aioredis_cluster/pool.py +++ b/src/aioredis_cluster/pool.py @@ -14,6 +14,7 @@ from aioredis_cluster.aioredis import Channel, PoolClosedError, create_connection from aioredis_cluster.command_info.commands import ( BLOCKING_COMMANDS, + PATTERN_PUBSUB_COMMANDS, PUBSUB_COMMANDS, PUBSUB_FAMILY_COMMANDS, SHARDED_PUBSUB_COMMANDS, @@ -226,7 +227,7 @@ async def execute_pubsub(self, command: TBytesOrStr, *channels): self._check_closed() - if command in PUBSUB_COMMANDS: + if command in PUBSUB_COMMANDS or command in PATTERN_PUBSUB_COMMANDS: conn = await self._get_pubsub_connection() elif command in SHARDED_PUBSUB_COMMANDS: conn = await self._get_sharded_pubsub_connection() @@ -246,7 +247,7 @@ def get_connection(self, command: TBytesOrStr, args=()): command = command.upper().strip() ret_conn: Optional[AbcConnection] = None - if command in PUBSUB_COMMANDS: + if command in PUBSUB_COMMANDS or command in PATTERN_PUBSUB_COMMANDS: if self._pubsub_conn and not self._pubsub_conn.closed: ret_conn = self._pubsub_conn elif command in SHARDED_PUBSUB_COMMANDS: @@ -267,18 +268,23 @@ def get_connection(self, command: TBytesOrStr, args=()): return ret_conn, self._address - async def select(self, db): - """For cluster implementation this method is unavailable""" + async def select(self, db: int) -> bool: + """Changes db index for all free connections. + + All previously acquired connections will be closed when released. + + For cluster implementation this method is unavailable + """ raise NotImplementedError("Feature is blocked in cluster mode") - async def auth(self, password) -> None: + async def auth(self, password: str) -> None: self._password = password async with self._cond: for conn in tuple(self._pool): conn.set_last_use_generation(self._idle_connections_collect_gen) await conn.auth(password) - async def auth_with_username(self, username, password) -> None: + async def auth_with_username(self, username: str, password: str) -> None: self._username = username self._password = password async with self._cond: diff --git a/tests/aioredis_tests/conftest.py b/tests/aioredis_tests/conftest.py index b4a750e..2daef05 100644 --- a/tests/aioredis_tests/conftest.py +++ b/tests/aioredis_tests/conftest.py @@ -9,6 +9,8 @@ import tempfile import time from collections import namedtuple +from functools import partial +from typing import List from urllib.parse import urlencode, urlunparse import pytest @@ -16,7 +18,18 @@ from aioredis_cluster import _aioredis as aioredis from aioredis_cluster._aioredis import sentinel as aioredis_sentinel +from aioredis_cluster.aioredis import create_redis as custom_create_redis from aioredis_cluster.compat.asyncio import timeout as atimeout +from aioredis_cluster.connection import RedisConnection + +try: + import aioredis as _origin_aioredis + + (_origin_aioredis,) + aioredis_installed = True +except ImportError: + aioredis_installed = False + TCPAddress = namedtuple("TCPAddress", "host port") @@ -59,8 +72,25 @@ async def f(*args, **kw): return f +create_redis_fixture_params: List = [ + aioredis.create_redis, + aioredis.create_redis_pool, +] +create_redis_fixture_ids: List = [ + "single", + "pool", +] + +if aioredis_installed is False: + create_redis_fixture_params.append(partial(custom_create_redis, connection_cls=RedisConnection)) + create_redis_fixture_ids.append( + "single_from_cluster", + ) + + @pytest_asyncio.fixture( - params=[aioredis.create_redis, aioredis.create_redis_pool], ids=["single", "pool"] + params=create_redis_fixture_params, + ids=create_redis_fixture_ids, ) def create_redis(_closable, request): """Wrapper around aioredis.create_redis.""" diff --git a/tests/aioredis_tests/pubsub_commands_test.py b/tests/aioredis_tests/pubsub_commands_test.py index 4d96f05..1426db5 100644 --- a/tests/aioredis_tests/pubsub_commands_test.py +++ b/tests/aioredis_tests/pubsub_commands_test.py @@ -62,6 +62,33 @@ async def test_subscribe(redis): assert res == [[b"unsubscribe", b"chan:1", 1], [b"unsubscribe", b"chan:2", 0]] +async def test_subscribe__multiple_times(redis): + res1 = await redis.subscribe("chan:1") + assert redis.in_pubsub == 1 + res2 = await redis.subscribe("chan:1") + assert redis.in_pubsub == 1 + res3 = await redis.psubscribe("chan:1") + assert redis.in_pubsub == 2 + # res4 = await redis.connection.execute_pubsub("SSUBSCRIBE", "chan:1") + # assert redis.in_pubsub == 3 + + ch1 = redis.channels["chan:1"] + ch3 = redis.patterns["chan:1"] + + assert res1 == [ch1] + assert res2 == [ch1] + assert res3 == [ch3] + + res = await redis.unsubscribe("chan:1") + assert res == [[b"unsubscribe", b"chan:1", 1]] + + res = await redis.punsubscribe("chan:1") + assert res == [[b"punsubscribe", b"chan:1", 0]] + + res = await redis.unsubscribe("chan:1") + assert res == [[b"unsubscribe", b"chan:1", 0]] + + @pytest.mark.parametrize( "create_redis", [ diff --git a/tests/unit_tests/aioredis_cluster/conftest.py b/tests/unit_tests/aioredis_cluster/conftest.py index 2a4b639..b2c4887 100644 --- a/tests/unit_tests/aioredis_cluster/conftest.py +++ b/tests/unit_tests/aioredis_cluster/conftest.py @@ -1,9 +1,30 @@ -# import asyncio +import pytest +import pytest_asyncio -# import pytest +@pytest.fixture +def add_finalizer(): + finalizers = [] -# def pytest_collection_modifyitems(items): -# for item in items: -# if not item.get_closest_marker("asyncio") and asyncio.iscoroutinefunction(item.function): -# item.add_marker(pytest.mark.asyncio) + def adder(finalizer): + finalizers.append(finalizer) + + try: + yield adder + finally: + for finalizer in finalizers: + finalizer() + + +@pytest_asyncio.fixture +async def add_async_finalizer(): + finalizers = [] + + def adder(finalizer): + finalizers.append(finalizer) + + try: + yield adder + finally: + for finalizer in finalizers: + await finalizer() diff --git a/tests/unit_tests/aioredis_cluster/test_connection.py b/tests/unit_tests/aioredis_cluster/test_connection.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index f94f695..7c44e72 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -4,18 +4,28 @@ import pytest +from aioredis_cluster.aioredis.stream import StreamReader +from aioredis_cluster.command_info.commands import PUBSUB_SUBSCRIBE_COMMANDS from aioredis_cluster.connection import RedisConnection -from aioredis_cluster.errors import MovedError +from aioredis_cluster.errors import MovedError, RedisError -class Reader: +async def moment(times: int = 1) -> None: + for _ in range(times): + await asyncio.sleep(0) + + +class MockedReader(StreamReader): def __init__(self) -> None: - self.queue = Queue() + self.queue: Queue = Queue() self.eof = False def set_parser(self, *args): pass + def feed_data(self, data): + pass + async def readobj(self): result = await self.queue.get() self.queue.task_done() @@ -25,24 +35,375 @@ def at_eof(self) -> bool: return self.eof and self.queue.empty() -async def test_moved_with_pubsub(): - reader = Reader() +def get_mocked_reader(): + return MockedReader() + + +def get_mocked_writer(): writer = mock.AsyncMock() writer.write = mock.Mock() writer.transport = mock.NonCallableMock() + return writer + + +async def close_connection(conn: RedisConnection) -> None: + conn.close() + await conn.wait_closed() + + +async def test_execute__simple_subscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"subscribe", b"chan", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan", 1]) + reader.queue.put_nowait([b"psubscribe", b"chan", 2]) + + result_channel = await redis.execute_pubsub("SUBSCRIBE", "chan") + result_sharded = await redis.execute_pubsub("SSUBSCRIBE", "chan") + result_pattern = await redis.execute_pubsub("PSUBSCRIBE", "chan") + + assert result_channel == [[b"subscribe", b"chan", 1]] + assert result_pattern == [[b"psubscribe", b"chan", 2]] + assert result_sharded == [[b"ssubscribe", b"chan", 1]] + assert redis.in_pubsub == 3 + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is True + assert len(redis._waiters) == 0 + + assert "chan" in redis.pubsub_channels + assert "chan" in redis.pubsub_patterns + assert "chan" in redis.sharded_pubsub_channels + + assert redis.pubsub_channels["chan"] is not redis.pubsub_patterns["chan"] + assert redis.pubsub_channels["chan"] is not redis.sharded_pubsub_channels["chan"] + assert redis.pubsub_patterns["chan"] is not redis.sharded_pubsub_channels["chan"] + + +async def test_execute__simple_unsubscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan", 1]) + reader.queue.put_nowait([b"psubscribe", b"chan", 1]) + reader.queue.put_nowait([b"subscribe", b"chan", 2]) + + await redis.execute_pubsub("SSUBSCRIBE", "chan") + await redis.execute_pubsub("PSUBSCRIBE", "chan") + await redis.execute_pubsub("SUBSCRIBE", "chan") + + assert redis.in_pubsub == 3 + + reader.queue.put_nowait([b"unsubscribe", b"chan", 1]) + reader.queue.put_nowait([b"punsubscribe", b"chan", 0]) + reader.queue.put_nowait([b"sunsubscribe", b"chan", 0]) + result_channel = await redis.execute_pubsub("UNSUBSCRIBE", "chan") + result_pattern = await redis.execute_pubsub("PUNSUBSCRIBE", "chan") + result_sharded = await redis.execute_pubsub("SUNSUBSCRIBE", "chan") + + await moment() + + assert redis.in_pubsub == 0 + assert result_channel == [[b"unsubscribe", b"chan", 1]] + assert result_pattern == [[b"punsubscribe", b"chan", 0]] + assert result_sharded == [[b"sunsubscribe", b"chan", 0]] + assert redis._client_in_pubsub is False + assert redis._server_in_pubsub is False + + +@pytest.mark.parametrize( + "command", + [ + "SUBSCRIBE", + "PSUBSCRIBE", + "SSUBSCRIBE", + "UNSUBSCRIBE", + "PUNSUBSCRIBE", + "SUNSUBSCRIBE", + ], +) +async def test_execute__first_command(add_async_finalizer, command: str): + reader = get_mocked_reader() + writer = get_mocked_writer() redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + is_subscribe_command = command in PUBSUB_SUBSCRIBE_COMMANDS + kind = command.encode().lower() + + subs_num = 1 if is_subscribe_command else 0 + reader.queue.put_nowait((kind, b"chan", subs_num)) + + await redis.execute_pubsub(command, "chan") + + if is_subscribe_command: + assert redis.in_pubsub == 1 + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is True + else: + assert redis.in_pubsub == 0 + assert redis._client_in_pubsub is False + assert redis._server_in_pubsub is False + + +async def test_execute__half_open_pubsub_mode(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + get_task = asyncio.ensure_future(redis.execute("GET", "foo", encoding="utf-8")) + ping1_task = asyncio.ensure_future(redis.execute("PING", "ping_reply1")) + subs_task = asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan")) + # SET not send and execute() must raise RedisError exception + set_task = asyncio.ensure_future(redis.execute("SET", "foo", "val2")) + ping2_task = asyncio.ensure_future(redis.execute("PING", "ping_reply2", encoding="utf-8")) + + # need extra loop for asyncio.ensure_future starts a tasks + await moment() + + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is False + + reader.queue.put_nowait(b"val1") + reader.queue.put_nowait(b"ping_reply1") + reader.queue.put_nowait([b"ssubscribe", b"chan", 1]) + + # This is incorrect. Redis must return error with restrict this command in PubSub mode + # and client must prevent send SET command in half-open PubSub mode + # reader.queue.put_nowait(b"OK") + + reader.queue.put_nowait([b"pong", b"ping_reply2"]) + + # make 2 extra loops + await moment(2) + + assert get_task.done() is True + assert ping1_task.done() is True + assert subs_task.done() is True + assert set_task.done() is True + assert ping2_task.done() is True + + assert get_task.result() == "val1" + assert ping1_task.result() == b"ping_reply1" + assert subs_task.result() == [[b"ssubscribe", b"chan", 1]] + with pytest.raises(RedisError, match="Connection in PubSub mode"): + assert set_task.result() + assert ping2_task.result() == "ping_reply2" + + assert redis.in_pubsub == 1 + + +async def test_execute__unsubscribe_without_subscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"sunsubscribe", b"chan", 0]) + await redis.execute_pubsub("SUNSUBSCRIBE", "chan") + reader.queue.put_nowait((b"punsubscribe", b"chan", 0)) + await redis.execute_pubsub("PUNSUBSCRIBE", "chan") + reader.queue.put_nowait((b"unsubscribe", b"chan", 0)) + await redis.execute_pubsub("UNSUBSCRIBE", "chan") + + assert redis.in_pubsub == 0 + assert redis._client_in_pubsub is False + assert redis._server_in_pubsub is False + + +async def test__redis_push_unsubscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"subscribe", b"chan:1", 1]) + reader.queue.put_nowait([b"subscribe", b"chan:2", 2]) + reader.queue.put_nowait([b"psubscribe", b"chan:3", 3]) + reader.queue.put_nowait([b"psubscribe", b"chan:4", 4]) + reader.queue.put_nowait([b"ssubscribe", b"chan:5:{shard}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan:6:{shard}", 2]) + await redis.execute_pubsub("SUBSCRIBE", "chan:1", "chan:2") + await redis.execute_pubsub("PSUBSCRIBE", "chan:3", "chan:4") + await redis.execute_pubsub("SSUBSCRIBE", "chan:5:{shard}", "chan:6:{shard}") + + assert redis.in_pubsub == 6 + + reader.queue.put_nowait([b"unsubscribe", b"chan:1", 3]) + reader.queue.put_nowait([b"unsubscribe", b"chan:2", 2]) + reader.queue.put_nowait([b"punsubscribe", b"chan:3", 1]) + reader.queue.put_nowait([b"punsubscribe", b"chan:4", 0]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:5:{shard}", 1]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:6:{shard}", 0]) + # some extra channel + reader.queue.put_nowait([b"unsubscribe", b"chan:7", 3]) + + await moment() + + assert redis.in_pubsub == 0 + + assert redis._reader_task.done() is False + + +async def test_moved_with_pubsub(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait((b"ssubscribe", b"chan1", 1)) + + # key slot for chan1 - 2323 + await redis.execute_pubsub("SSUBSCRIBE", "chan1") + + assert len(redis.sharded_pubsub_channels) == 1 + assert "chan1" in redis.sharded_pubsub_channels + + # key slot chan2:{shard1} - 10271 + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard1}", 11]) + reader.queue.put_nowait([b"ssubscribe", b"chan3:{shard1}", 11]) + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard1}") + await redis.execute_pubsub("SSUBSCRIBE", "chan3:{shard1}") + + assert len(redis.sharded_pubsub_channels) == 3 + assert "chan2:{shard1}" in redis.sharded_pubsub_channels + assert "chan3:{shard1}" in redis.sharded_pubsub_channels + + reader.queue.put_nowait(MovedError("MOVED 2323 127.0.0.1:6379")) + await moment() + + assert len(redis.sharded_pubsub_channels) == 2 + assert "chan1" not in redis.sharded_pubsub_channels + + reader.queue.put_nowait(MovedError("MOVED 10271 127.0.0.1:6379")) + await moment() + + assert len(redis.sharded_pubsub_channels) == 0 + + assert redis._reader_task.done() is False, redis._reader_task.exception() + + +async def test_execute__unexpectable_unsubscribe_and_moved(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard1}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard1}", 2]) + + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard1}") + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard1}") + + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard1}", 1]) + reader.queue.put_nowait(MovedError("MOVED 10271 127.0.0.1:6379")) + await moment() + + assert redis.in_pubsub == 0 + assert redis._reader_task.done() is False + + +async def test_execute__client_unsubscribe_with_server_unsubscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan:1", 1]) + sub_result1 = await redis.execute_pubsub("SSUBSCRIBE", "chan:1") + + reader.queue.put_nowait([b"ssubscribe", b"chan:2", 2]) + sub_result2 = await redis.execute_pubsub("SSUBSCRIBE", "chan:2") + + reader.queue.put_nowait([b"ssubscribe", b"chan:3", 3]) + sub_result3 = await redis.execute_pubsub("SSUBSCRIBE", "chan:3") + + assert sub_result1 == [[b"ssubscribe", b"chan:1", 1]] + assert sub_result2 == [[b"ssubscribe", b"chan:2", 2]] + assert sub_result3 == [[b"ssubscribe", b"chan:3", 3]] + assert redis.in_pubsub == 3 + + reader.queue.put_nowait([b"sunsubscribe", b"chan:1", 2]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:3", 1]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:2", 0]) + reader.queue.put_nowait(MovedError("MOVED 1 127.0.0.1:6379")) + await moment() + + reader.queue.put_nowait([b"sunsubscribe", b"chan:3", 0]) + unsub_result3 = await redis.execute_pubsub("SUNSUBSCRIBE", "chan:3") + + assert unsub_result3 == [[b"sunsubscribe", b"chan:3", 0]] + + await moment() + + assert redis.in_pubsub == 0 + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_execute__ping(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + redis._in_pubsub = 1 + add_async_finalizer(lambda: close_connection(redis)) + + ping1_task = asyncio.ensure_future(redis.execute("PING")) + subs_task = asyncio.ensure_future(redis.execute_pubsub("SUBSCRIBE", "chan")) + await moment() + ping2_task = asyncio.ensure_future(redis.execute("PING")) + ping3_task = asyncio.ensure_future(redis.execute("PING", "my_message")) + reader.queue.put_nowait(b"PONG") + reader.queue.put_nowait((b"subscribe", b"chan", 1)) + reader.queue.put_nowait(b"PONG") + reader.queue.put_nowait((b"pong", "my_message")) + await moment(2) + + assert redis.in_pubsub == 1 + + assert ping1_task.done() is True + assert subs_task.done() is True + assert ping2_task.done() is True + assert ping3_task.done() is True + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_subscribe_and_receive_messages(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"subscribe", b"chan", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan:{shard}", 1]) + reader.queue.put_nowait([b"psubscribe", b"chan:*", 2]) + + await redis.execute_pubsub("SUBSCRIBE", "chan") + await redis.execute_pubsub("SSUBSCRIBE", "chan:{shard}") + await redis.execute_pubsub("PSUBSCRIBE", "chan:*") + + channel = redis.pubsub_channels["chan"] + pattern = redis.pubsub_patterns["chan:*"] + sharded = redis.sharded_pubsub_channels["chan:{shard}"] + + reader.queue.put_nowait([b"smessage", b"chan:{shard}", b"sharded_msg"]) + reader.queue.put_nowait([b"pmessage", b"chan:*", b"chan:foo", b"pattern_msg"]) + reader.queue.put_nowait([b"message", b"chan", b"channel_msg"]) - s = redis.execute_pubsub("SSUBSCRIBE", "a") - reader.queue.put_nowait((b"ssubscribe", b"a", 10)) - await s + await moment() - s = redis.execute_pubsub("SSUBSCRIBE", "b") - await reader.queue.put(MovedError("1 1 127.0.0.1:6379")) - with pytest.raises(MovedError): - await asyncio.wait_for(s, timeout=1) - assert not redis._reader_task.done(), redis._reader_task.exception() + channel_msg = await channel.get() + pattern_msg = await pattern.get() + sharded_msg = await sharded.get() - reader.queue.put_nowait((b"smessage", b"a", b"123")) - assert not redis._reader_task.done() - redis.close() - await redis.wait_closed() + assert channel_msg == b"channel_msg" + assert pattern_msg == (b"chan:foo", b"pattern_msg") + assert sharded_msg == b"sharded_msg" From 8eb8010d24be2523fae555a5b79d68dd01a6baad Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 6 Dec 2023 12:19:51 +0300 Subject: [PATCH 02/17] fix tests --- src/aioredis_cluster/cluster.py | 6 +++--- src/aioredis_cluster/connection.py | 4 ++-- src/aioredis_cluster/crc.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/aioredis_cluster/cluster.py b/src/aioredis_cluster/cluster.py index 429d39f..967d7dc 100644 --- a/src/aioredis_cluster/cluster.py +++ b/src/aioredis_cluster/cluster.py @@ -29,7 +29,7 @@ from aioredis_cluster.command_info import CommandInfo, extract_keys from aioredis_cluster.commands import RedisCluster from aioredis_cluster.compat.asyncio import timeout as atimeout -from aioredis_cluster.crc import CrossSlotKeysError, determine_slot +from aioredis_cluster.crc import CrossSlotError, determine_slot from aioredis_cluster.errors import ( AskError, ClusterClosedError, @@ -420,8 +420,8 @@ async def authorize(pool) -> None: def determine_slot(self, first_key: bytes, *keys: bytes) -> int: try: return determine_slot(first_key, *keys) - except CrossSlotKeysError: - raise RedisClusterError(str(CrossSlotKeysError)) from None + except CrossSlotError as e: + raise RedisClusterError(str(e)) from None async def all_masters(self) -> List[Redis]: ctx = self._make_exec_context((b"PING",), {}) diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index c66cd8f..fb1b5a1 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -50,7 +50,7 @@ PUBSUB_SUBSCRIBE_COMMANDS, PubSubType, ) -from aioredis_cluster.crc import CrossSlotKeysError, determine_slot +from aioredis_cluster.crc import CrossSlotError, determine_slot from aioredis_cluster.errors import MovedError, RedisError from aioredis_cluster.typedef import PClosableConnection from aioredis_cluster.util import encode_command, ensure_bytes @@ -394,7 +394,7 @@ async def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel] if channel_type is PubSubType.SHARDED: try: key_slot = determine_slot(*(ensure_bytes(name) for name in channels_obj.keys())) - except CrossSlotKeysError: + except CrossSlotError: raise ValueError( f"Not all channels shared one key slot in cluster {channels!r}" ) from None diff --git a/src/aioredis_cluster/crc.py b/src/aioredis_cluster/crc.py index 6c06f75..4bc6e77 100644 --- a/src/aioredis_cluster/crc.py +++ b/src/aioredis_cluster/crc.py @@ -11,7 +11,7 @@ "crc16", "key_slot", "determine_slot", - "CrossSlotKeysError", + "CrossSlotError", ) REDIS_CLUSTER_HASH_SLOTS = 16384 @@ -47,7 +47,7 @@ def py_key_slot(k: bytes, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: key_slot = py_key_slot -class CrossSlotKeysError(Exception): +class CrossSlotError(Exception): pass @@ -55,6 +55,6 @@ def determine_slot(first_key: bytes, *keys: bytes) -> int: slot: int = key_slot(first_key) for k in keys: if slot != key_slot(k): - raise CrossSlotKeysError("all keys must map to the same key slot") + raise CrossSlotError("all keys must map to the same key slot") return slot From e180356d9c75c7d3b2e8ff1837c470e6aeac6aff Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 6 Dec 2023 14:18:57 +0300 Subject: [PATCH 03/17] fix tests --- src/aioredis_cluster/_aioredis/pubsub.py | 17 ++++----- .../commands/sharded_pubsub.py | 6 ++-- src/aioredis_cluster/connection.py | 25 +++++++++++-- src/aioredis_cluster/pool.py | 22 ++++++------ tests/system_tests/test_connection_pubsub.py | 35 +++++++++++++++---- .../test_connection_pubsub.py | 19 ++++++++++ 6 files changed, 91 insertions(+), 33 deletions(-) diff --git a/src/aioredis_cluster/_aioredis/pubsub.py b/src/aioredis_cluster/_aioredis/pubsub.py index 9da6413..47e3fdb 100644 --- a/src/aioredis_cluster/_aioredis/pubsub.py +++ b/src/aioredis_cluster/_aioredis/pubsub.py @@ -4,6 +4,7 @@ import sys import types import warnings +from typing import Optional from .abc import AbcChannel from .errors import ChannelClosedError @@ -37,17 +38,17 @@ def __repr__(self): ) @property - def name(self): + def name(self) -> bytes: """Encoded channel name/pattern.""" return self._name @property - def is_pattern(self): + def is_pattern(self) -> bool: """Set to True if channel is subscribed to pattern.""" return self._is_pattern @property - def is_active(self): + def is_active(self) -> bool: """Returns True until there are messages in channel or connection is subscribed to it. @@ -87,11 +88,11 @@ async def get(self, *, encoding=None, decoder=None): return dest_channel, msg return msg - async def get_json(self, encoding="utf-8"): + async def get_json(self, encoding: str = "utf-8"): """Shortcut to get JSON messages.""" return await self.get(encoding=encoding, decoder=json.loads) - def iter(self, *, encoding=None, decoder=None): + def iter(self, *, encoding: Optional[str] = None, decoder=None): """Same as get method but its native coroutine. Usage example: @@ -103,7 +104,7 @@ def iter(self, *, encoding=None, decoder=None): self, is_active=lambda ch: ch.is_active, encoding=encoding, decoder=decoder ) - async def wait_message(self): + async def wait_message(self) -> bool: """Waits for message to become available in channel or channel is closed (unsubscribed). @@ -121,10 +122,10 @@ async def wait_message(self): # internal methods - def put_nowait(self, data): + def put_nowait(self, data) -> None: self._queue.put(data) - def close(self, exc=None): + def close(self, exc: Optional[BaseException] = None) -> None: """Marks channel as inactive. Internal method, will be called from connection diff --git a/src/aioredis_cluster/commands/sharded_pubsub.py b/src/aioredis_cluster/commands/sharded_pubsub.py index eb8ccaf..3df0b84 100644 --- a/src/aioredis_cluster/commands/sharded_pubsub.py +++ b/src/aioredis_cluster/commands/sharded_pubsub.py @@ -1,8 +1,8 @@ -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Mapping from aioredis_cluster._aioredis.commands.pubsub import wait_return_channels from aioredis_cluster._aioredis.util import wait_make_dict -from aioredis_cluster.abc import AbcConnection +from aioredis_cluster.abc import AbcChannel, AbcConnection class ShardedPubSubCommandsMixin: @@ -54,7 +54,7 @@ def pubsub_shardnumsub(self, *channels): return wait_make_dict(self.execute(b"PUBSUB", b"SHARDNUMSUB", *channels)) @property - def sharded_pubsub_channels(self): + def sharded_pubsub_channels(self) -> Mapping[str, AbcChannel]: """Returns read-only channels dict. See :attr:`~aioredis.RedisConnection.pubsub_channels` diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index fb1b5a1..e7b387a 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -166,6 +166,9 @@ def slot_channels_unsubscribe(self, key_slot: int) -> None: channel = self._sharded.pop(channel_name) channel.close() + def have_slot_channels(self, key_slot: int) -> bool: + return key_slot in self._slot_to_sharded + @property def channels_total(self) -> int: return len(self._channels) + len(self._patterns) + len(self._sharded) @@ -449,8 +452,15 @@ async def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel] cb=self._process_pubsub, ) ) - server_reply = list(await fut) - if server_reply != res[0]: + + try: + server_reply = await fut + except ReplyError: + # return PubSub mode to closed state if any reply error received + self._client_in_pubsub = False + raise + + if list(server_reply) != res[0]: if logger.isEnabledFor(logging.DEBUG): logger.error( "Unexpected server reply on PubSub on %r: %r, expected %r", @@ -668,7 +678,16 @@ async def _read_data(self) -> None: self._process_pubsub(obj) else: if isinstance(obj, RedisError): - if isinstance(obj, ReplyError): + if isinstance(obj, MovedError): + if self._pubsub_channels_store.have_slot_channels(obj.info.slot_id): + logger.warning( + "Received MOVED. Unsubscribe all channels from %d slot", + obj.info.slot_id, + ) + self._pubsub_channels_store.slot_channels_unsubscribe( + obj.info.slot_id + ) + elif isinstance(obj, ReplyError): if obj.args[0].startswith("READONLY"): obj = ReadOnlyError(obj.args[0]) self._wakeup_waiter_with_exc(obj) diff --git a/src/aioredis_cluster/pool.py b/src/aioredis_cluster/pool.py index 09c8de5..cdf33e8 100644 --- a/src/aioredis_cluster/pool.py +++ b/src/aioredis_cluster/pool.py @@ -3,15 +3,15 @@ import logging import random import types -from typing import Deque, Dict, List, Mapping, Optional, Set, Tuple, Type, Union +from typing import Deque, List, Mapping, Optional, Set, Tuple, Type, Union from aioredis_cluster._aioredis.pool import ( _AsyncConnectionContextManager, _ConnectionContextManager, ) from aioredis_cluster._aioredis.util import CloseEvent -from aioredis_cluster.abc import AbcConnection, AbcPool -from aioredis_cluster.aioredis import Channel, PoolClosedError, create_connection +from aioredis_cluster.abc import AbcChannel, AbcConnection, AbcPool +from aioredis_cluster.aioredis import PoolClosedError, create_connection from aioredis_cluster.command_info.commands import ( BLOCKING_COMMANDS, PATTERN_PUBSUB_COMMANDS, @@ -313,21 +313,19 @@ def in_pubsub(self) -> int: return in_pubsub @property - def pubsub_channels(self) -> Mapping[str, Channel]: - channels: Dict[str, Channel] = {} + def pubsub_channels(self) -> Mapping[str, AbcChannel]: if self._pubsub_conn and not self._pubsub_conn.closed: - channels.update(self._pubsub_conn.pubsub_channels) - return types.MappingProxyType(channels) + return self._pubsub_conn.pubsub_channels + return types.MappingProxyType({}) @property - def sharded_pubsub_channels(self) -> Mapping[str, Channel]: - channels: Dict[str, Channel] = {} + def sharded_pubsub_channels(self) -> Mapping[str, AbcChannel]: if self._sharded_pubsub_conn and not self._sharded_pubsub_conn.closed: - channels.update(self._sharded_pubsub_conn.sharded_pubsub_channels) - return types.MappingProxyType(channels) + return self._sharded_pubsub_conn.sharded_pubsub_channels + return types.MappingProxyType({}) @property - def pubsub_patterns(self): + def pubsub_patterns(self) -> Mapping[str, AbcChannel]: if self._pubsub_conn and not self._pubsub_conn.closed: return self._pubsub_conn.pubsub_patterns return types.MappingProxyType({}) diff --git a/tests/system_tests/test_connection_pubsub.py b/tests/system_tests/test_connection_pubsub.py index 30e6e12..737b7ee 100644 --- a/tests/system_tests/test_connection_pubsub.py +++ b/tests/system_tests/test_connection_pubsub.py @@ -1,20 +1,41 @@ +import asyncio from string import ascii_letters +from typing import Awaitable, Callable import pytest -from aioredis_cluster.errors import MovedError +from aioredis_cluster import Cluster @pytest.mark.redis_version(gte="7.0.0") @pytest.mark.timeout(2) -async def test_moved_with_pubsub(cluster): - c = await cluster() - redis = await c.keys_master("a") +async def test_moved_with_pubsub(cluster: Callable[[], Awaitable[Cluster]]): + client = await cluster() + redis = await client.keys_master("a") await redis.ssubscribe("a") - with pytest.raises(MovedError): - for b in ascii_letters: - await redis.ssubscribe(b) + assert "a" in redis.sharded_pubsub_channels + assert b"a" in redis.sharded_pubsub_channels + + channels_dump = {} + for char in ascii_letters: + await redis.ssubscribe(char) + # Channel objects creates immediately and close and removed after received MOVED reply + channels_dump[char] = redis.sharded_pubsub_channels[char] + + await asyncio.sleep(0.01) + + # check number of created Channel objects + assert len(channels_dump) == len(ascii_letters) + + # check of closed Channels after MovedError reply received + num_of_closed = 0 + for channel in channels_dump.values(): + if not channel.is_active: + num_of_closed += 1 + + assert num_of_closed > 0 + assert len(redis.sharded_pubsub_channels) < len(channels_dump) redis.close() await redis.wait_closed() diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index 7c44e72..365efcb 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -9,6 +9,8 @@ from aioredis_cluster.connection import RedisConnection from aioredis_cluster.errors import MovedError, RedisError +pytestmark = [pytest.mark.timeout(1)] + async def moment(times: int = 1) -> None: for _ in range(times): @@ -308,6 +310,23 @@ async def test_execute__unexpectable_unsubscribe_and_moved(add_async_finalizer): assert redis._reader_task.done() is False +async def test_execute__ssubscribe_with_first_moved(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait(MovedError("MOVED 10271 127.0.0.1:6379")) + + with pytest.raises(MovedError, match="MOVED 10271 127.0.0.1:6379"): + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard1}") + + assert redis.in_pubsub == 0 + assert redis._client_in_pubsub is False + assert redis._server_in_pubsub is False + assert redis._reader_task.done() is False + + async def test_execute__client_unsubscribe_with_server_unsubscribe(add_async_finalizer): reader = get_mocked_reader() writer = get_mocked_writer() From 5f21d478c6741d7e275e4d3e0681ed692d740bfd Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 6 Dec 2023 15:04:09 +0300 Subject: [PATCH 04/17] fix tests --- src/aioredis_cluster/connection.py | 107 ++++++++++++------ .../test_connection_pubsub.py | 18 ++- 2 files changed, 83 insertions(+), 42 deletions(-) diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index e7b387a..159eafb 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -59,12 +59,14 @@ TExecuteCallback = Callable[[Any], Any] +TExecuteErrCallback = Callable[[Exception], Exception] class ExecuteWaiter(NamedTuple): fut: asyncio.Future enc: Optional[str] cb: Optional[TExecuteCallback] + err_cb: Optional[TExecuteErrCallback] class PParserFactory(Protocol): @@ -295,7 +297,7 @@ async def auth(self, password: str) -> bool: fut = self.execute(b"AUTH", password) return await wait_ok(fut) - async def execute(self, command, *args, encoding=_NOTSET) -> Any: + def execute(self, command, *args, encoding=_NOTSET) -> Any: """Executes redis command and returns Future waiting for the answer. Raises: @@ -347,12 +349,13 @@ async def execute(self, command, *args, encoding=_NOTSET) -> Any: fut=fut, enc=encoding, cb=cb, + err_cb=None, ) ) - return await fut + return fut - async def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel]): + def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel]): """Executes redis (p|s)subscribe/(p|s)unsubscribe commands. Returns asyncio.gather coroutine waiting for all channels/patterns @@ -372,41 +375,44 @@ async def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel] key_slot = -1 reply_kind = ensure_bytes(command.lower()) - channels_obj: Dict[str, AbcChannel] + channels_obj: List[AbcChannel] if len(channels) == 0: if is_subscribe_command: raise ValueError("No channels to (un)subscribe") elif channel_type is PubSubType.PATTERN: - channels_obj = dict(self._pubsub_channels_store.patterns) + channels_obj = list(self._pubsub_channels_store.patterns.values()) elif channel_type is PubSubType.SHARDED: - channels_obj = dict(self._pubsub_channels_store.sharded) + channels_obj = list(self._pubsub_channels_store.sharded.values()) else: - channels_obj = dict(self._pubsub_channels_store.channels) + channels_obj = list(self._pubsub_channels_store.channels.values()) else: mkchannel = partial(Channel, is_pattern=is_pattern) - channels_obj = {} + channels_obj = [] for channel_name_or_obj in channels: - if not isinstance(channel_name_or_obj, AbcChannel): + if isinstance(channel_name_or_obj, AbcChannel): + ch = channel_name_or_obj + else: ch = mkchannel(channel_name_or_obj) - if ch.name in channels_obj: - raise ValueError(f"Found channel duplicates in {channels!r}") + # FIXME: processing duplicate channels totally broken in aioredis + # if ch.name in channels_obj: + # raise ValueError(f"Found channel duplicates in {channels!r}") if ch.is_pattern != is_pattern: raise ValueError(f"Not all channels {channels!r} match command {command!r}") - channels_obj[ch.name] = ch + channels_obj.append(ch) if channel_type is PubSubType.SHARDED: try: - key_slot = determine_slot(*(ensure_bytes(name) for name in channels_obj.keys())) + key_slot = determine_slot(*(ensure_bytes(ch.name) for ch in channels_obj)) except CrossSlotError: raise ValueError( f"Not all channels shared one key slot in cluster {channels!r}" ) from None - cmd = encode_command(command, *(name for name in channels_obj.keys())) + cmd = encode_command(command, *(ch.name for ch in channels_obj)) res: List[Any] = [] if is_subscribe_command: - for ch in channels_obj.values(): + for ch in channels_obj: channel_name = ensure_bytes(ch.name) self._pubsub_channels_store.channel_subscribe( channel_type=channel_type, @@ -422,7 +428,7 @@ async def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel] # otherwise unsubscribe command else: - for ch in channels_obj.values(): + for ch in channels_obj: channel_name = ensure_bytes(ch.name) self._pubsub_channels_store.channel_unsubscribe( channel_type=channel_type, @@ -449,30 +455,15 @@ async def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel] ExecuteWaiter( fut=fut, enc=None, - cb=self._process_pubsub, + cb=self._get_execute_pubsub_callback(command, res), + err_cb=self._get_execute_pubsub_err_callback(), ) ) + else: + fut = self._loop.create_future() + fut.set_result(res) - try: - server_reply = await fut - except ReplyError: - # return PubSub mode to closed state if any reply error received - self._client_in_pubsub = False - raise - - if list(server_reply) != res[0]: - if logger.isEnabledFor(logging.DEBUG): - logger.error( - "Unexpected server reply on PubSub on %r: %r, expected %r", - command, - server_reply, - res[0], - ) - exc = RedisError(f"Unexpected server reply on PubSub {command!r}") - self._do_close(exc) - raise exc - - return res + return fut def get_last_use_generation(self) -> int: return self._last_use_generation @@ -705,6 +696,14 @@ def _wakeup_waiter_with_exc(self, exc: Exception) -> None: return waiter = self._waiters.popleft() + + if waiter.err_cb is not None: + try: + exc = waiter.err_cb(exc) + except Exception as cb_exc: + logger.exception("Waiter error callback failed with exception: %r", cb_exc) + exc = cb_exc + _set_exception(waiter.fut, exc) if self._in_transaction is not None: self._transaction_error = exc @@ -838,3 +837,37 @@ def _buffered(self): self._pipeline_buffer = None else: yield self + + def _get_execute_pubsub_callback( + self, command: Union[str, bytes], expect_replies: List[Any] + ) -> TExecuteCallback: + def callback(server_reply: Any) -> Any: + # this callback processing only first reply on (p|s)(un)subscribe commands + + server_reply = self._process_pubsub(server_reply) + + if list(server_reply) != expect_replies[0]: + if logger.isEnabledFor(logging.DEBUG): + logger.error( + "Unexpected server reply on PubSub on %r: %r, expected %r", + command, + server_reply, + expect_replies[0], + ) + + exc = RedisError(f"Unexpected server reply on PubSub {command!r}") + self._loop.call_soon(self._do_close, exc) + raise exc + + return expect_replies + + return callback + + def _get_execute_pubsub_err_callback(self) -> TExecuteErrCallback: + def callback(exc: Exception) -> Exception: + if isinstance(exc, ReplyError): + # return PubSub mode to closed state if any reply error received + self._client_in_pubsub = False + return exc + + return callback diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index 365efcb..5f51a44 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -53,6 +53,14 @@ async def close_connection(conn: RedisConnection) -> None: await conn.wait_closed() +async def execute(redis: RedisConnection, *args, **kwargs): + return await redis.execute(*args, **kwargs) + + +async def execute_pubsub(redis: RedisConnection, *args, **kwargs): + return await redis.execute_pubsub(*args, **kwargs) + + async def test_execute__simple_subscribe(add_async_finalizer): reader = get_mocked_reader() writer = get_mocked_writer() @@ -158,12 +166,12 @@ async def test_execute__half_open_pubsub_mode(add_async_finalizer): redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") add_async_finalizer(lambda: close_connection(redis)) - get_task = asyncio.ensure_future(redis.execute("GET", "foo", encoding="utf-8")) - ping1_task = asyncio.ensure_future(redis.execute("PING", "ping_reply1")) - subs_task = asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan")) + get_task = asyncio.ensure_future(execute(redis, "GET", "foo", encoding="utf-8")) + ping1_task = asyncio.ensure_future(execute(redis, "PING", "ping_reply1")) + subs_task = asyncio.ensure_future(execute_pubsub(redis, "SSUBSCRIBE", "chan")) # SET not send and execute() must raise RedisError exception - set_task = asyncio.ensure_future(redis.execute("SET", "foo", "val2")) - ping2_task = asyncio.ensure_future(redis.execute("PING", "ping_reply2", encoding="utf-8")) + set_task = asyncio.ensure_future(execute(redis, "SET", "foo", "val2")) + ping2_task = asyncio.ensure_future(execute(redis, "PING", "ping_reply2", encoding="utf-8")) # need extra loop for asyncio.ensure_future starts a tasks await moment() From a5cdd7b7b2176dcca623fd68d1d5f1306be8e2f0 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 6 Dec 2023 16:30:26 +0300 Subject: [PATCH 05/17] add logging --- src/aioredis_cluster/connection.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index 159eafb..111a946 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -773,11 +773,12 @@ def _process_pubsub(self, obj: Any) -> Any: channel_name: bytes if kind in {b"subscribe", b"psubscribe", b"ssubscribe"}: - logger.debug("PubSub subscribe confirmation received: %r", obj) + logger.debug("PubSub %s event received: %r", kind, obj) # confirm PubSub mode in client side based on server reply and reset pending flag if self._client_in_pubsub and not self._server_in_pubsub: self._server_in_pubsub = True elif kind in {b"unsubscribe", b"punsubscribe", b"sunsubscribe"}: + logger.debug("PubSub %s event received: %r", kind, obj) (channel_name,) = args channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] self._pubsub_channels_store.channel_unsubscribe( From 5a3e4670ec05946428743b006e0367ac4ed2a1af Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 6 Dec 2023 17:48:06 +0300 Subject: [PATCH 06/17] restrict exit in PubSub mode for RedisConnection --- src/aioredis_cluster/connection.py | 48 +++++++++---- .../test_connection_pubsub.py | 71 ++++++++++++++++++- 2 files changed, 105 insertions(+), 14 deletions(-) diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index 111a946..01d278e 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -95,6 +95,7 @@ def __init__(self) -> None: self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict() self._sharded_to_slot: Dict[bytes, int] = {} self._slot_to_sharded: Dict[int, Set[bytes]] = {} + self._pending_unsubscribe: Set[Tuple[PubSubType, bytes]] = set() @property def channels(self) -> Mapping[str, AbcChannel]: @@ -111,6 +112,14 @@ def sharded(self) -> Mapping[str, AbcChannel]: """Returns read-only sharded channels dict.""" return MappingProxyType(self._sharded) + def channel_pending_unsubscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + ) -> None: + self._pending_unsubscribe.add((channel_type, channel_name)) + def channel_subscribe( self, *, @@ -183,6 +192,16 @@ def channels_num(self) -> int: def sharded_channels_num(self) -> int: return len(self._sharded) + def has_channel(self, channel_type: PubSubType, channel_name: bytes) -> bool: + ret = False + if channel_type is PubSubType.CHANNEL: + ret = channel_name in self._channels + elif channel_type is PubSubType.PATTERN: + ret = channel_name in self._patterns + elif channel_type is PubSubType.SHARDED: + ret = channel_name in self._sharded + return ret + def get_channel(self, channel_type: PubSubType, channel_name: bytes) -> AbcChannel: if channel_type is PubSubType.CHANNEL: channel = self._channels[channel_name] @@ -785,19 +804,24 @@ def _process_pubsub(self, obj: Any) -> Any: channel_type=channel_type, channel_name=channel_name, ) - if self._pubsub_channels_store.channels_total == 0: - self._server_in_pubsub = False - self._client_in_pubsub = False - elif kind in {b"message", b"smessage"}: - (channel_name,) = args - channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] - channel = self._pubsub_channels_store.get_channel(channel_type, channel_name) - channel.put_nowait(data) - elif kind == b"pmessage": - (pattern, channel_name) = args + elif kind in {b"message", b"smessage", b"pmessage"}: + if kind == b"pmessage": + (pattern, channel_name) = args + else: + (channel_name,) = args + pattern = channel_name + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] - channel = self._pubsub_channels_store.get_channel(channel_type, pattern) - channel.put_nowait((channel_name, data)) + if self._pubsub_channels_store.has_channel(channel_type, pattern): + channel = self._pubsub_channels_store.get_channel(channel_type, pattern) + if channel_type is PubSubType.PATTERN: + channel.put_nowait((channel_name, data)) + else: + channel.put_nowait(data) + else: + logger.warning( + "No channel %r with type %s for received message", pattern, channel_type + ) elif kind == b"pong": if not self._waiters: logger.error("No PubSub PONG waiters for received data %r", data) diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index 5f51a44..1340237 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -1,9 +1,11 @@ import asyncio +import logging from asyncio.queues import Queue from unittest import mock import pytest +from aioredis_cluster.aioredis import ChannelClosedError from aioredis_cluster.aioredis.stream import StreamReader from aioredis_cluster.command_info.commands import PUBSUB_SUBSCRIBE_COMMANDS from aioredis_cluster.connection import RedisConnection @@ -121,8 +123,8 @@ async def test_execute__simple_unsubscribe(add_async_finalizer): assert result_channel == [[b"unsubscribe", b"chan", 1]] assert result_pattern == [[b"punsubscribe", b"chan", 0]] assert result_sharded == [[b"sunsubscribe", b"chan", 0]] - assert redis._client_in_pubsub is False - assert redis._server_in_pubsub is False + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is True @pytest.mark.parametrize( @@ -434,3 +436,68 @@ async def test_subscribe_and_receive_messages(add_async_finalizer): assert channel_msg == b"channel_msg" assert pattern_msg == (b"chan:foo", b"pattern_msg") assert sharded_msg == b"sharded_msg" + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_receive_message_after_unsubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + with caplog.at_level(logging.WARNING): + reader.queue.put_nowait([b"ssubscribe", b"chan:{shard}", 1]) + await redis.execute_pubsub("SSUBSCRIBE", "chan:{shard}") + sharded = redis.sharded_pubsub_channels["chan:{shard}"] + await redis.execute_pubsub("SUNSUBSCRIBE", "chan:{shard}") + + reader.queue.put_nowait([b"smessage", b"chan:{shard}", b"sharded_msg"]) + + await moment() + + assert sharded.is_active is False + assert sharded._queue.qsize() == 0 + with pytest.raises(ChannelClosedError): + await sharded.get() + + no_channel_record = "" + for record in caplog.records: + assert "No waiter for received reply" not in record.message, record.message + if "No channel" in record.message and "for received message" in record.message: + no_channel_record = record.message + + assert no_channel_record != "" + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_subscribe_and_immediately_unsubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2]) + + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}") + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}") + + with caplog.at_level(logging.ERROR): + await redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}") + await redis.execute_pubsub("SUNSUBSCRIBE", "chan1:{shard}") + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 1]) + reader.queue.put_nowait([b"sunsubscribe", b"chan1:{shard}", 0]) + + await moment(2) + + for record in caplog.records: + assert "No waiter for received reply" not in record.message, record.message + + assert redis.in_pubsub == 0 + + assert redis._reader_task is not None + assert redis._reader_task.done() is False From 0db8ebf7e65fce1ccacb31a80cbf9c625635e12c Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 6 Dec 2023 20:20:49 +0300 Subject: [PATCH 07/17] in_pubsub property now indicates boolead flag instead numbers of created channels for Cluster, Pool, RedisConnections --- src/aioredis_cluster/cluster.py | 5 ++- src/aioredis_cluster/connection.py | 14 ++------ src/aioredis_cluster/pool.py | 17 ++++++---- tests/system_tests/test_redis_cluster.py | 6 ++-- .../test_connection_pubsub.py | 33 ++++++++++++++----- 5 files changed, 44 insertions(+), 31 deletions(-) diff --git a/src/aioredis_cluster/cluster.py b/src/aioredis_cluster/cluster.py index 967d7dc..93cb73e 100644 --- a/src/aioredis_cluster/cluster.py +++ b/src/aioredis_cluster/cluster.py @@ -362,7 +362,10 @@ def in_pubsub(self) -> int: Can be tested as bool indicating Pub/Sub mode state. """ - return sum(p.in_pubsub for p in self._pooler.pools()) + for pool in self._pooler.pools(): + if pool.in_pubsub: + return 1 + return 0 @property def pubsub_channels(self) -> Mapping[str, AbcChannel]: diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index 01d278e..bfd33f9 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -95,7 +95,6 @@ def __init__(self) -> None: self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict() self._sharded_to_slot: Dict[bytes, int] = {} self._slot_to_sharded: Dict[int, Set[bytes]] = {} - self._pending_unsubscribe: Set[Tuple[PubSubType, bytes]] = set() @property def channels(self) -> Mapping[str, AbcChannel]: @@ -112,14 +111,6 @@ def sharded(self) -> Mapping[str, AbcChannel]: """Returns read-only sharded channels dict.""" return MappingProxyType(self._sharded) - def channel_pending_unsubscribe( - self, - *, - channel_type: PubSubType, - channel_name: bytes, - ) -> None: - self._pending_unsubscribe.add((channel_type, channel_name)) - def channel_subscribe( self, *, @@ -531,9 +522,10 @@ def in_transaction(self) -> bool: def in_pubsub(self) -> int: """Indicates that connection is in PUB/SUB mode. - Provides the number of subscribed channels. + This implementation NOT provides the number of subscribed channels + and provides only boolean flag """ - return self._pubsub_channels_store.channels_total + return int(self._client_in_pubsub) async def select(self, db: int) -> bool: """Change the selected database for the current connection.""" diff --git a/src/aioredis_cluster/pool.py b/src/aioredis_cluster/pool.py index cdf33e8..6bde084 100644 --- a/src/aioredis_cluster/pool.py +++ b/src/aioredis_cluster/pool.py @@ -305,12 +305,15 @@ async def set_readonly(self, value: bool) -> None: @property def in_pubsub(self) -> int: - in_pubsub = 0 - if self._pubsub_conn and not self._pubsub_conn.closed: - in_pubsub += self._pubsub_conn.in_pubsub - if self._sharded_pubsub_conn and not self._sharded_pubsub_conn.closed: - in_pubsub += self._sharded_pubsub_conn.in_pubsub - return in_pubsub + if self._pubsub_conn and not self._pubsub_conn.closed and self._pubsub_conn.in_pubsub: + return 1 + if ( + self._sharded_pubsub_conn + and not self._sharded_pubsub_conn.closed + and self._sharded_pubsub_conn.in_pubsub + ): + return 1 + return 0 @property def pubsub_channels(self) -> Mapping[str, AbcChannel]: @@ -379,7 +382,7 @@ def release(self, conn: AbcConnection) -> None: logger.warning("Connection %r is in transaction, closing it.", conn) conn.close() elif conn.in_pubsub: - logger.warning("Connection %r is in subscribe mode, closing it.", conn) + logger.warning("Connection %r is in PubSub mode, closing it.", conn) conn.close() elif conn._waiters: logger.warning("Connection %r has pending commands, closing it.", conn) diff --git a/tests/system_tests/test_redis_cluster.py b/tests/system_tests/test_redis_cluster.py index 3a97fd4..1942b8c 100644 --- a/tests/system_tests/test_redis_cluster.py +++ b/tests/system_tests/test_redis_cluster.py @@ -156,7 +156,7 @@ async def test_sharded_pubsub(redis_cluster): ch2: Channel = channels[0] assert len(cl1.sharded_pubsub_channels) == 2 - assert cl1.in_pubsub == 2 + assert cl1.in_pubsub == 1 assert len(cl1.channels) == 0 assert len(cl1.patterns) == 0 @@ -173,7 +173,7 @@ async def test_sharded_pubsub(redis_cluster): await cl1.sunsubscribe("channel2") assert len(cl1.sharded_pubsub_channels) == 0 - assert cl1.in_pubsub == 0 + assert cl1.in_pubsub == 1 @pytest.mark.redis_version(gte="7.0.0") @@ -186,7 +186,7 @@ async def test_sharded_pubsub__multiple_subscribe(redis_cluster): ch3: Channel = (await cl1.ssubscribe("channel:{shard_key}:3"))[0] assert len(cl1.sharded_pubsub_channels) == 3 - assert cl1.in_pubsub == 3 + assert cl1.in_pubsub == 1 shard_pool = await cl1.keys_master("{shard_key}") assert len(shard_pool.sharded_pubsub_channels) == 3 diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index 1340237..ebdc32b 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -80,7 +80,7 @@ async def test_execute__simple_subscribe(add_async_finalizer): assert result_channel == [[b"subscribe", b"chan", 1]] assert result_pattern == [[b"psubscribe", b"chan", 2]] assert result_sharded == [[b"ssubscribe", b"chan", 1]] - assert redis.in_pubsub == 3 + assert redis.in_pubsub == 1 assert redis._client_in_pubsub is True assert redis._server_in_pubsub is True assert len(redis._waiters) == 0 @@ -108,7 +108,10 @@ async def test_execute__simple_unsubscribe(add_async_finalizer): await redis.execute_pubsub("PSUBSCRIBE", "chan") await redis.execute_pubsub("SUBSCRIBE", "chan") - assert redis.in_pubsub == 3 + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 1 + assert len(redis.pubsub_patterns) == 1 + assert len(redis.sharded_pubsub_channels) == 1 reader.queue.put_nowait([b"unsubscribe", b"chan", 1]) reader.queue.put_nowait([b"punsubscribe", b"chan", 0]) @@ -119,7 +122,10 @@ async def test_execute__simple_unsubscribe(add_async_finalizer): await moment() - assert redis.in_pubsub == 0 + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 0 + assert len(redis.pubsub_patterns) == 0 + assert len(redis.sharded_pubsub_channels) == 0 assert result_channel == [[b"unsubscribe", b"chan", 1]] assert result_pattern == [[b"punsubscribe", b"chan", 0]] assert result_sharded == [[b"sunsubscribe", b"chan", 0]] @@ -244,7 +250,10 @@ async def test__redis_push_unsubscribe(add_async_finalizer): await redis.execute_pubsub("PSUBSCRIBE", "chan:3", "chan:4") await redis.execute_pubsub("SSUBSCRIBE", "chan:5:{shard}", "chan:6:{shard}") - assert redis.in_pubsub == 6 + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 2 + assert len(redis.pubsub_patterns) == 2 + assert len(redis.sharded_pubsub_channels) == 2 reader.queue.put_nowait([b"unsubscribe", b"chan:1", 3]) reader.queue.put_nowait([b"unsubscribe", b"chan:2", 2]) @@ -257,7 +266,10 @@ async def test__redis_push_unsubscribe(add_async_finalizer): await moment() - assert redis.in_pubsub == 0 + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 0 + assert len(redis.pubsub_patterns) == 0 + assert len(redis.sharded_pubsub_channels) == 0 assert redis._reader_task.done() is False @@ -316,7 +328,7 @@ async def test_execute__unexpectable_unsubscribe_and_moved(add_async_finalizer): reader.queue.put_nowait(MovedError("MOVED 10271 127.0.0.1:6379")) await moment() - assert redis.in_pubsub == 0 + assert redis.in_pubsub == 1 assert redis._reader_task.done() is False @@ -355,7 +367,8 @@ async def test_execute__client_unsubscribe_with_server_unsubscribe(add_async_fin assert sub_result1 == [[b"ssubscribe", b"chan:1", 1]] assert sub_result2 == [[b"ssubscribe", b"chan:2", 2]] assert sub_result3 == [[b"ssubscribe", b"chan:3", 3]] - assert redis.in_pubsub == 3 + assert redis.in_pubsub == 1 + assert len(redis.sharded_pubsub_channels) == 3 reader.queue.put_nowait([b"sunsubscribe", b"chan:1", 2]) reader.queue.put_nowait([b"sunsubscribe", b"chan:3", 1]) @@ -370,7 +383,8 @@ async def test_execute__client_unsubscribe_with_server_unsubscribe(add_async_fin await moment() - assert redis.in_pubsub == 0 + assert redis.in_pubsub == 1 + assert len(redis.sharded_pubsub_channels) == 0 assert redis._reader_task is not None assert redis._reader_task.done() is False @@ -497,7 +511,8 @@ async def test_subscribe_and_immediately_unsubscribe(caplog, add_async_finalizer for record in caplog.records: assert "No waiter for received reply" not in record.message, record.message - assert redis.in_pubsub == 0 + assert redis.in_pubsub == 1 + assert len(redis.sharded_pubsub_channels) == 0 assert redis._reader_task is not None assert redis._reader_task.done() is False From 103b9ecdbcbacfef0955fa971094a3bcabc33ee5 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Wed, 6 Dec 2023 20:33:46 +0300 Subject: [PATCH 08/17] in_pubsub property now indicates boolead flag instead numbers of created channels for Cluster, Pool, RedisConnections --- tests/aioredis_tests/pubsub_commands_test.py | 28 +++++++++++++++++--- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/tests/aioredis_tests/pubsub_commands_test.py b/tests/aioredis_tests/pubsub_commands_test.py index 1426db5..fd54989 100644 --- a/tests/aioredis_tests/pubsub_commands_test.py +++ b/tests/aioredis_tests/pubsub_commands_test.py @@ -4,6 +4,7 @@ from _testutils import redis_version from aioredis_cluster import _aioredis as aioredis +from aioredis_cluster.connection import RedisConnection as ClusterConnection async def _reader(channel, output, waiter, conn): @@ -49,7 +50,12 @@ async def test_publish_json(create_connection, redis, server): async def test_subscribe(redis): res = await redis.subscribe("chan:1", "chan:2") - assert redis.in_pubsub == 2 + + assert len(redis.connection.pubsub_channels) == 2 + if isinstance(redis.connection, ClusterConnection): + assert redis.in_pubsub == 1 + else: + assert redis.in_pubsub == 2 ch1 = redis.channels["chan:1"] ch2 = redis.channels["chan:2"] @@ -68,7 +74,13 @@ async def test_subscribe__multiple_times(redis): res2 = await redis.subscribe("chan:1") assert redis.in_pubsub == 1 res3 = await redis.psubscribe("chan:1") - assert redis.in_pubsub == 2 + + assert len(redis.connection.pubsub_channels) == 1 + assert len(redis.connection.pubsub_patterns) == 1 + if isinstance(redis.connection, ClusterConnection): + assert redis.in_pubsub == 1 + else: + assert redis.in_pubsub == 2 # res4 = await redis.connection.execute_pubsub("SSUBSCRIBE", "chan:1") # assert redis.in_pubsub == 3 @@ -117,7 +129,10 @@ async def test_subscribe_empty_pool(create_redis, server, _closable): async def test_psubscribe(redis, create_redis, server): sub = redis res = await sub.psubscribe("patt:*", "chan:*") - assert sub.in_pubsub == 2 + if isinstance(redis.connection, ClusterConnection): + assert sub.in_pubsub == 1 + else: + assert sub.in_pubsub == 2 pat1 = sub.patterns["patt:*"] pat2 = sub.patterns["chan:*"] @@ -149,7 +164,12 @@ async def test_psubscribe_empty_pool(create_redis, server, _closable): _closable(pub) await sub.connection.clear() res = await sub.psubscribe("patt:*", "chan:*") - assert sub.in_pubsub == 2 + + assert len(sub.connection.pubsub_patterns) == 2 + if isinstance(sub.connection, ClusterConnection): + assert sub.in_pubsub == 1 + else: + assert sub.in_pubsub == 2 pat1 = sub.patterns["patt:*"] pat2 = sub.patterns["chan:*"] From 4ee1abf272ae9f0b6674018d63a15d22edb8f7f5 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Mon, 11 Dec 2023 21:31:25 +0300 Subject: [PATCH 09/17] add test script for pubsub reshard --- dev/test_pubsub_reshard.py | 174 +++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 dev/test_pubsub_reshard.py diff --git a/dev/test_pubsub_reshard.py b/dev/test_pubsub_reshard.py new file mode 100644 index 0000000..a561e53 --- /dev/null +++ b/dev/test_pubsub_reshard.py @@ -0,0 +1,174 @@ +import argparse +import asyncio +import logging +import random +import signal +from collections import deque +from typing import Counter, Deque, Dict, Mapping, Optional, Set + +try: + from aioredis import Channel, Redis +except ImportError: + from aioredis_cluster.aioredis import Redis, Channel + +from aioredis_cluster import RedisCluster, create_redis_cluster +import async_timeout + +logger = logging.getLogger(__name__) + + +async def tick_log( + tick: float, + routines_counters: Mapping[int, Counter[str]], + global_counters: Counter[str], +) -> None: + count = 0 + while True: + await asyncio.sleep(tick) + count += 1 + logger.info("tick %d", count) + logger.info("tick %d: %r", count, global_counters) + routines = sorted( + routines_counters.items(), + key=lambda item: item[0], + reverse=False, + ) + for routine_id, counters in routines: + logger.info("tick %d: %s: %r", count, routine_id, counters) + + +async def subscribe_routine( + *, + redis: RedisCluster, + routine_id: int, + counters: Counter[str], + global_counters: Counter[str], +): + await asyncio.sleep(0.5) + ch_name = f"ch:{routine_id}:{{shard}}" + while True: + counters["routine:subscribes"] += 1 + prev_ch: Optional[Channel] = None + try: + pool = await redis.keys_master(ch_name) + global_counters["subscribe_in_fly"] += 1 + try: + ch: Channel = (await pool.subscribe(ch_name))[0] + if prev_ch is not None and ch is prev_ch: + logger.error("%s: Previous Channel is current: %r", routine_id, ch) + prev_ch = ch + # logger.info('Wait channel %s', ch_name) + try: + async with async_timeout.timeout(1.0): + res = await ch.get() + except asyncio.TimeoutError: + counters["timeouts"] += 1 + global_counters["timeouts"] += 1 + # logger.warning("%s: ch.get() is timed out", routine_id) + else: + if res is None: + counters["msg:received:None"] += 1 + global_counters["msg:received:None"] += 1 + else: + counters["msg:received"] += 1 + global_counters["msg:received"] += 1 + await pool.unsubscribe(ch_name) + finally: + global_counters["subscribe_in_fly"] -= 1 + assert ch.is_active is False + except asyncio.CancelledError: + break + except Exception as e: + counters["errors"] += 1 + counters[f"errors:{type(e).__name__}"] += 1 + global_counters["errors"] += 1 + global_counters[f"errors:{type(e).__name__}"] += 1 + logger.error("%s: Channel exception: %r", routine_id, e) + + +async def async_main(args: argparse.Namespace) -> None: + loop = asyncio.get_event_loop() + node_addr: str = args.node + + global_counters: Counter[str] = Counter() + routines_counters: Dict[int, Counter[str]] = {} + + tick_task = loop.create_task(tick_log(5.0, routines_counters, global_counters)) + + redis = await create_redis_cluster( + [node_addr], + pool_minsize=1, + pool_maxsize=1, + connect_timeout=1.0, + follow_cluster=True, + ) + routine_tasks: Deque[asyncio.Task] = deque() + try: + for routine_id in range(1, args.routines + 1): + counters = routines_counters[routine_id] = Counter() + routine_task = asyncio.create_task( + subscribe_routine( + redis=redis, + routine_id=routine_id, + counters=counters, + global_counters=global_counters, + ) + ) + routine_tasks.append(routine_task) + + logger.info("Routines %d", len(routine_tasks)) + + await asyncio.sleep(args.wait) + finally: + logger.info("Cancel %d routines", len(routine_tasks)) + for rt in routine_tasks: + if not rt.done(): + rt.cancel() + await asyncio.wait(routine_tasks) + + logger.info("Close redis client") + redis.close() + await redis.wait_closed() + + logger.info("Cancel tick task") + tick_task.cancel() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("node") + parser.add_argument("--wait", type=float, default=30) + parser.add_argument("--routines", type=int, default=10) + args = parser.parse_args() + if args.routines < 1: + parser.error("--routines must be positive int") + if args.wait < 0: + parser.error("--wait must be positive float or int") + + try: + import uvloop + except ImportError: + pass + else: + uvloop.install() + + logging.basicConfig(level=logging.INFO) + + loop = asyncio.get_event_loop() + main_task = loop.create_task(async_main(args)) + + main_task.add_done_callback(lambda f: loop.stop()) + loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) + loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) + + try: + loop.run_forever() + finally: + if not main_task.done() and not main_task.cancelled(): + main_task.cancel() + loop.run_until_complete(asyncio.wait([main_task])) + loop.close() + + +if __name__ == "__main__": + main() From 6dabd0cd65039ea6e186205c915507ebb341209d Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Mon, 11 Dec 2023 21:31:37 +0300 Subject: [PATCH 10/17] add test script for pubsub reshard --- dev/test_pubsub_reshard.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dev/test_pubsub_reshard.py b/dev/test_pubsub_reshard.py index a561e53..65499d1 100644 --- a/dev/test_pubsub_reshard.py +++ b/dev/test_pubsub_reshard.py @@ -11,9 +11,10 @@ except ImportError: from aioredis_cluster.aioredis import Redis, Channel -from aioredis_cluster import RedisCluster, create_redis_cluster import async_timeout +from aioredis_cluster import RedisCluster, create_redis_cluster + logger = logging.getLogger(__name__) From 15ccd58b358a2261328936f33d5a49a4c46db935 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Mon, 11 Dec 2023 21:32:13 +0300 Subject: [PATCH 11/17] add test script for pubsub reshard --- dev/test_pubsub_reshard.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/dev/test_pubsub_reshard.py b/dev/test_pubsub_reshard.py index 65499d1..416eb85 100644 --- a/dev/test_pubsub_reshard.py +++ b/dev/test_pubsub_reshard.py @@ -1,19 +1,14 @@ import argparse import asyncio import logging -import random import signal from collections import deque -from typing import Counter, Deque, Dict, Mapping, Optional, Set - -try: - from aioredis import Channel, Redis -except ImportError: - from aioredis_cluster.aioredis import Redis, Channel +from typing import Counter, Deque, Dict, Mapping, Optional import async_timeout from aioredis_cluster import RedisCluster, create_redis_cluster +from aioredis_cluster.aioredis import Channel logger = logging.getLogger(__name__) From 9e5fea43c234f01845f6582bac0554cdce652710 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Tue, 12 Dec 2023 18:08:04 +0300 Subject: [PATCH 12/17] improve and fix resubscribe channels --- dev/test_pubsub_reshard.py | 5 +- src/aioredis_cluster/connection.py | 199 ++------- src/aioredis_cluster/pubsub.py | 191 ++++++++ .../test_connection_pubsub.py | 56 ++- .../aioredis_cluster/test_pubsub.py | 407 ++++++++++++++++++ 5 files changed, 683 insertions(+), 175 deletions(-) create mode 100644 src/aioredis_cluster/pubsub.py create mode 100644 tests/unit_tests/aioredis_cluster/test_pubsub.py diff --git a/dev/test_pubsub_reshard.py b/dev/test_pubsub_reshard.py index 416eb85..43c7c02 100644 --- a/dev/test_pubsub_reshard.py +++ b/dev/test_pubsub_reshard.py @@ -5,10 +5,9 @@ from collections import deque from typing import Counter, Deque, Dict, Mapping, Optional -import async_timeout - from aioredis_cluster import RedisCluster, create_redis_cluster from aioredis_cluster.aioredis import Channel +from aioredis_cluster.compat.asyncio import timeout logger = logging.getLogger(__name__) @@ -55,7 +54,7 @@ async def subscribe_routine( prev_ch = ch # logger.info('Wait channel %s', ch_name) try: - async with async_timeout.timeout(1.0): + async with timeout(1.0): res = await ch.get() except asyncio.TimeoutError: counters["timeouts"] += 1 diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index bfd33f9..ab2917d 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -4,30 +4,21 @@ from collections import deque from contextlib import contextmanager from functools import partial -from types import MappingProxyType from typing import ( Any, Callable, Deque, - Dict, Iterable, List, Mapping, NamedTuple, Optional, Protocol, - Set, Tuple, Union, ) -from aioredis_cluster._aioredis.util import ( - _set_exception, - _set_result, - coerced_keys_dict, - decode, - wait_ok, -) +from aioredis_cluster._aioredis.util import _set_exception, _set_result, decode, wait_ok from aioredis_cluster.abc import AbcChannel, AbcConnection from aioredis_cluster.aioredis import ( Channel, @@ -52,6 +43,7 @@ ) from aioredis_cluster.crc import CrossSlotError, determine_slot from aioredis_cluster.errors import MovedError, RedisError +from aioredis_cluster.pubsub import PubSubStore from aioredis_cluster.typedef import PClosableConnection from aioredis_cluster.util import encode_command, ensure_bytes @@ -88,140 +80,6 @@ async def close_connections(conns: Iterable[PClosableConnection]) -> None: await asyncio.wait(close_waiters) -class PubSub: - def __init__(self) -> None: - self._channels: coerced_keys_dict[AbcChannel] = coerced_keys_dict() - self._patterns: coerced_keys_dict[AbcChannel] = coerced_keys_dict() - self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict() - self._sharded_to_slot: Dict[bytes, int] = {} - self._slot_to_sharded: Dict[int, Set[bytes]] = {} - - @property - def channels(self) -> Mapping[str, AbcChannel]: - """Returns read-only channels dict.""" - return MappingProxyType(self._channels) - - @property - def patterns(self) -> Mapping[str, AbcChannel]: - """Returns read-only patterns dict.""" - return MappingProxyType(self._patterns) - - @property - def sharded(self) -> Mapping[str, AbcChannel]: - """Returns read-only sharded channels dict.""" - return MappingProxyType(self._sharded) - - def channel_subscribe( - self, - *, - channel_type: PubSubType, - channel_name: bytes, - channel: AbcChannel, - key_slot: int, - ) -> None: - if channel_type is PubSubType.CHANNEL: - if channel_name not in self._channels: - self._channels[channel_name] = channel - elif channel_type is PubSubType.PATTERN: - if channel_name not in self._patterns: - self._patterns[channel_name] = channel - elif channel_type is PubSubType.SHARDED: - if channel_name not in self._sharded: - self._sharded[channel_name] = channel - self._sharded_to_slot[channel_name] = key_slot - if key_slot not in self._slot_to_sharded: - self._slot_to_sharded[key_slot] = set((channel_name,)) - else: - self._slot_to_sharded[key_slot].add(channel_name) - - def channel_unsubscribe( - self, - *, - channel_type: PubSubType, - channel_name: bytes, - ) -> None: - channel: Optional[AbcChannel] = None - if channel_type is PubSubType.CHANNEL: - channel = self._channels.pop(channel_name, None) - elif channel_type is PubSubType.PATTERN: - channel = self._patterns.pop(channel_name, None) - elif channel_type is PubSubType.SHARDED: - channel = self._sharded.pop(channel_name, None) - key_slot = self._sharded_to_slot.pop(channel_name, None) - if key_slot is not None: - key_slot_channels = self._slot_to_sharded[key_slot] - key_slot_channels.discard(channel_name) - if len(key_slot_channels) == 0: - del self._slot_to_sharded[key_slot] - - if channel is not None: - channel.close() - - def slot_channels_unsubscribe(self, key_slot: int) -> None: - channel_names = self._slot_to_sharded.pop(key_slot, None) - if channel_names is None: - return - - while channel_names: - channel_name = channel_names.pop() - del self._sharded_to_slot[channel_name] - channel = self._sharded.pop(channel_name) - channel.close() - - def have_slot_channels(self, key_slot: int) -> bool: - return key_slot in self._slot_to_sharded - - @property - def channels_total(self) -> int: - return len(self._channels) + len(self._patterns) + len(self._sharded) - - @property - def channels_num(self) -> int: - return len(self._channels) + len(self._patterns) - - @property - def sharded_channels_num(self) -> int: - return len(self._sharded) - - def has_channel(self, channel_type: PubSubType, channel_name: bytes) -> bool: - ret = False - if channel_type is PubSubType.CHANNEL: - ret = channel_name in self._channels - elif channel_type is PubSubType.PATTERN: - ret = channel_name in self._patterns - elif channel_type is PubSubType.SHARDED: - ret = channel_name in self._sharded - return ret - - def get_channel(self, channel_type: PubSubType, channel_name: bytes) -> AbcChannel: - if channel_type is PubSubType.CHANNEL: - channel = self._channels[channel_name] - elif channel_type is PubSubType.PATTERN: - channel = self._patterns[channel_name] - elif channel_type is PubSubType.SHARDED: - channel = self._sharded[channel_name] - else: # pragma: no cover - raise RuntimeError(f"Unexpected channel type {channel_type!r}") - return channel - - def close(self, exc: Optional[BaseException]) -> None: - while self._channels: - _, ch = self._channels.popitem() - logger.debug("Closing pubsub channel %r", ch) - ch.close(exc) - while self._patterns: - _, ch = self._patterns.popitem() - logger.debug("Closing pubsub pattern %r", ch) - ch.close(exc) - while self._sharded: - _, ch = self._sharded.popitem() - logger.debug("Closing sharded pubsub channel %r", ch) - ch.close(exc) - - self._slot_to_sharded.clear() - self._sharded_to_slot.clear() - - class RedisConnection(AbcConnection): def __init__( self, @@ -250,7 +108,7 @@ def __init__( self._close_state = asyncio.Event() self._in_transaction: Optional[Deque[Tuple[Optional[str], Optional[Callable]]]] = None self._transaction_error: Optional[Exception] = None # XXX: never used? - self._pubsub_channels_store = PubSub() + self._pubsub_store = PubSubStore() # client side PubSub mode flag self._client_in_pubsub = False # confirmed PubSub from Redis server via first subscribe reply @@ -290,17 +148,17 @@ async def auth_with_username(self, username: str, password: str) -> bool: @property def pubsub_channels(self) -> Mapping[str, AbcChannel]: """Returns read-only channels dict.""" - return self._pubsub_channels_store.channels + return self._pubsub_store.channels @property def pubsub_patterns(self) -> Mapping[str, AbcChannel]: """Returns read-only patterns dict.""" - return self._pubsub_channels_store.patterns + return self._pubsub_store.patterns @property def sharded_pubsub_channels(self) -> Mapping[str, AbcChannel]: """Returns read-only sharded channels dict.""" - return self._pubsub_channels_store.sharded + return self._pubsub_store.sharded async def auth(self, password: str) -> bool: """Authenticate to server.""" @@ -390,11 +248,11 @@ def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel]): if is_subscribe_command: raise ValueError("No channels to (un)subscribe") elif channel_type is PubSubType.PATTERN: - channels_obj = list(self._pubsub_channels_store.patterns.values()) + channels_obj = list(self._pubsub_store.patterns.values()) elif channel_type is PubSubType.SHARDED: - channels_obj = list(self._pubsub_channels_store.sharded.values()) + channels_obj = list(self._pubsub_store.sharded.values()) else: - channels_obj = list(self._pubsub_channels_store.channels.values()) + channels_obj = list(self._pubsub_store.channels.values()) else: mkchannel = partial(Channel, is_pattern=is_pattern) channels_obj = [] @@ -424,30 +282,31 @@ def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel]): if is_subscribe_command: for ch in channels_obj: channel_name = ensure_bytes(ch.name) - self._pubsub_channels_store.channel_subscribe( + self._pubsub_store.channel_subscribe( channel_type=channel_type, channel_name=channel_name, channel=ch, key_slot=key_slot, ) if channel_type is PubSubType.SHARDED: - channels_num = self._pubsub_channels_store.sharded_channels_num + channels_num = self._pubsub_store.sharded_channels_num else: - channels_num = self._pubsub_channels_store.channels_num + channels_num = self._pubsub_store.channels_num res.append([reply_kind, channel_name, channels_num]) # otherwise unsubscribe command else: for ch in channels_obj: channel_name = ensure_bytes(ch.name) - self._pubsub_channels_store.channel_unsubscribe( + self._pubsub_store.channel_unsubscribe( channel_type=channel_type, channel_name=channel_name, + by_reply=False, ) if channel_type is PubSubType.SHARDED: - channels_num = self._pubsub_channels_store.sharded_channels_num + channels_num = self._pubsub_store.sharded_channels_num else: - channels_num = self._pubsub_channels_store.channels_num + channels_num = self._pubsub_store.channels_num res.append([reply_kind, channel_name, channels_num]) if self._pipeline_buffer is None: @@ -606,7 +465,7 @@ def _do_close(self, exc: Optional[BaseException]) -> None: else: _set_exception(waiter.fut, exc) - self._pubsub_channels_store.close(exc) + self._pubsub_store.close(exc) def _on_reader_task_done(self, task: asyncio.Task) -> None: if not task.cancelled() and task.exception(): @@ -673,7 +532,7 @@ async def _read_data(self) -> None: "Received MOVED in PubSub mode. Unsubscribe all channels from %d slot", obj.info.slot_id, ) - self._pubsub_channels_store.slot_channels_unsubscribe(obj.info.slot_id) + self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id) elif isinstance(obj, RedisError): raise obj else: @@ -681,14 +540,12 @@ async def _read_data(self) -> None: else: if isinstance(obj, RedisError): if isinstance(obj, MovedError): - if self._pubsub_channels_store.have_slot_channels(obj.info.slot_id): + if self._pubsub_store.have_slot_channels(obj.info.slot_id): logger.warning( "Received MOVED. Unsubscribe all channels from %d slot", obj.info.slot_id, ) - self._pubsub_channels_store.slot_channels_unsubscribe( - obj.info.slot_id - ) + self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id) elif isinstance(obj, ReplyError): if obj.args[0].startswith("READONLY"): obj = ReadOnlyError(obj.args[0]) @@ -722,7 +579,7 @@ def _wakeup_waiter_with_exc(self, exc: Exception) -> None: def _wakeup_waiter_with_result(self, result: Any) -> None: """Processes command results.""" - if self._loop.get_debug(): + if self._loop.get_debug(): # pragma: no cover logger.debug("Wakeup first waiter for reply: %r", result) if not self._waiters: @@ -766,7 +623,7 @@ def _process_pubsub(self, obj: Any) -> Any: and used as callback in `execute_pubsub` for first PubSub mode initial reply """ - if self._loop.get_debug(): + if self._loop.get_debug(): # pragma: no cover logger.debug( "Process PubSub reply (client_in_pubsub:%s, server_in_pubsub:%s): %r", self._client_in_pubsub, @@ -785,16 +642,20 @@ def _process_pubsub(self, obj: Any) -> Any: if kind in {b"subscribe", b"psubscribe", b"ssubscribe"}: logger.debug("PubSub %s event received: %r", kind, obj) - # confirm PubSub mode in client side based on server reply and reset pending flag + (channel_name,) = args + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] if self._client_in_pubsub and not self._server_in_pubsub: self._server_in_pubsub = True + # confirm PubSub mode in client side based on server reply and reset pending flag + self._pubsub_store.confirm_subscribe(channel_type, channel_name) elif kind in {b"unsubscribe", b"punsubscribe", b"sunsubscribe"}: logger.debug("PubSub %s event received: %r", kind, obj) (channel_name,) = args channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] - self._pubsub_channels_store.channel_unsubscribe( + self._pubsub_store.channel_unsubscribe( channel_type=channel_type, channel_name=channel_name, + by_reply=True, ) elif kind in {b"message", b"smessage", b"pmessage"}: if kind == b"pmessage": @@ -804,8 +665,8 @@ def _process_pubsub(self, obj: Any) -> Any: pattern = channel_name channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] - if self._pubsub_channels_store.has_channel(channel_type, pattern): - channel = self._pubsub_channels_store.get_channel(channel_type, pattern) + if self._pubsub_store.has_channel(channel_type, pattern): + channel = self._pubsub_store.get_channel(channel_type, pattern) if channel_type is PubSubType.PATTERN: channel.put_nowait((channel_name, data)) else: diff --git a/src/aioredis_cluster/pubsub.py b/src/aioredis_cluster/pubsub.py new file mode 100644 index 0000000..fe69cfa --- /dev/null +++ b/src/aioredis_cluster/pubsub.py @@ -0,0 +1,191 @@ +import logging +from types import MappingProxyType +from typing import Dict, Mapping, Optional, Set, Tuple + +from aioredis_cluster._aioredis.util import coerced_keys_dict +from aioredis_cluster.abc import AbcChannel +from aioredis_cluster.command_info.commands import PubSubType +from aioredis_cluster.crc import key_slot as calc_key_slot + +logger = logging.getLogger(__name__) + + +class PubSubStore: + def __init__(self) -> None: + self._channels: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._patterns: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._slot_to_sharded: Dict[int, Set[bytes]] = {} + self._unconfirmed_subscribes: Dict[Tuple[PubSubType, bytes], int] = {} + + @property + def channels(self) -> Mapping[str, AbcChannel]: + """Returns read-only channels dict.""" + return MappingProxyType(self._channels) + + @property + def patterns(self) -> Mapping[str, AbcChannel]: + """Returns read-only patterns dict.""" + return MappingProxyType(self._patterns) + + @property + def sharded(self) -> Mapping[str, AbcChannel]: + """Returns read-only sharded channels dict.""" + return MappingProxyType(self._sharded) + + def channel_subscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + channel: AbcChannel, + key_slot: int, + ) -> None: + if channel_type is PubSubType.CHANNEL: + if channel_name not in self._channels: + self._channels[channel_name] = channel + elif channel_type is PubSubType.PATTERN: + if channel_name not in self._patterns: + self._patterns[channel_name] = channel + elif channel_type is PubSubType.SHARDED: + if key_slot < 0: + raise ValueError("key_slot cannot be negative for sharded channel") + + if channel_name not in self._sharded: + self._sharded[channel_name] = channel + if key_slot not in self._slot_to_sharded: + self._slot_to_sharded[key_slot] = {channel_name} + else: + self._slot_to_sharded[key_slot].add(channel_name) + else: # pragma: no cover + raise RuntimeError(f"Unexpected channel_type {channel_type}") + + unconfirmed_key = (channel_type, channel_name) + if unconfirmed_key not in self._unconfirmed_subscribes: + self._unconfirmed_subscribes[unconfirmed_key] = 1 + else: + self._unconfirmed_subscribes[unconfirmed_key] += 1 + + def confirm_subscribe(self, channel_type: PubSubType, channel_name: bytes) -> None: + unconfirmed_key = (channel_type, channel_name) + val = self._unconfirmed_subscribes.get(unconfirmed_key) + if val is not None: + val -= 1 + if val < 0: + logger.error("Too much PubSub subscribe confirms for %r", unconfirmed_key) + val = 0 + + if val == 0: + del self._unconfirmed_subscribes[unconfirmed_key] + if channel_type is PubSubType.SHARDED: + # this is counterpart of channel_unsubscribe() + # we must remove key_slot -> channel_name entry + # if not channels objects exists + if channel_name not in self._sharded: + self._remove_channel_from_slot_map(channel_name) + else: + self._unconfirmed_subscribes[unconfirmed_key] = val + else: + logger.error("Unexpected PubSub subscribe confirm for %r", unconfirmed_key) + + def channel_unsubscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + by_reply: bool, + ) -> None: + have_unconfirmed_subs = (channel_type, channel_name) in self._unconfirmed_subscribes + # if receive (p|s)unsubscribe reply from Redis + # and make several sequently subscribe->unsubscribe->subscribe commands + # - we must ignore all server unsubscribe replies until all subscribes is confirmed + if by_reply and have_unconfirmed_subs: + return + + channel: Optional[AbcChannel] = None + if channel_type is PubSubType.CHANNEL: + channel = self._channels.pop(channel_name, None) + elif channel_type is PubSubType.PATTERN: + channel = self._patterns.pop(channel_name, None) + elif channel_type is PubSubType.SHARDED: + channel = self._sharded.pop(channel_name, None) + # we must remove key_slot -> channel_name entry + # only if all subscription is confirmed + if not have_unconfirmed_subs: + self._remove_channel_from_slot_map(channel_name) + + if channel is not None: + channel.close() + + def slot_channels_unsubscribe(self, key_slot: int) -> None: + channel_names = self._slot_to_sharded.pop(key_slot, None) + if channel_names is None: + return + + while channel_names: + channel_name = channel_names.pop() + self._unconfirmed_subscribes.pop((PubSubType.SHARDED, channel_name), None) + channel = self._sharded.pop(channel_name, None) + if channel is not None: + channel.close() + + def have_slot_channels(self, key_slot: int) -> bool: + return key_slot in self._slot_to_sharded + + @property + def channels_total(self) -> int: + return len(self._channels) + len(self._patterns) + len(self._sharded) + + @property + def channels_num(self) -> int: + return len(self._channels) + len(self._patterns) + + @property + def sharded_channels_num(self) -> int: + return len(self._sharded) + + def has_channel(self, channel_type: PubSubType, channel_name: bytes) -> bool: + ret = False + if channel_type is PubSubType.CHANNEL: + ret = channel_name in self._channels + elif channel_type is PubSubType.PATTERN: + ret = channel_name in self._patterns + elif channel_type is PubSubType.SHARDED: + ret = channel_name in self._sharded + return ret + + def get_channel(self, channel_type: PubSubType, channel_name: bytes) -> AbcChannel: + if channel_type is PubSubType.CHANNEL: + channel = self._channels[channel_name] + elif channel_type is PubSubType.PATTERN: + channel = self._patterns[channel_name] + elif channel_type is PubSubType.SHARDED: + channel = self._sharded[channel_name] + else: # pragma: no cover + raise RuntimeError(f"Unexpected channel type {channel_type!r}") + return channel + + def close(self, exc: Optional[BaseException]) -> None: + while self._channels: + _, ch = self._channels.popitem() + logger.debug("Closing pubsub channel %r", ch) + ch.close(exc) + while self._patterns: + _, ch = self._patterns.popitem() + logger.debug("Closing pubsub pattern %r", ch) + ch.close(exc) + while self._sharded: + _, ch = self._sharded.popitem() + logger.debug("Closing sharded pubsub channel %r", ch) + ch.close(exc) + + self._slot_to_sharded.clear() + self._unconfirmed_subscribes.clear() + + def _remove_channel_from_slot_map(self, channel_name: bytes) -> None: + key_slot = calc_key_slot(channel_name) + if key_slot in self._slot_to_sharded: + channels_set = self._slot_to_sharded[key_slot] + channels_set.discard(channel_name) + if len(channels_set) == 0: + del self._slot_to_sharded[key_slot] diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index ebdc32b..0477c61 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -240,15 +240,25 @@ async def test__redis_push_unsubscribe(add_async_finalizer): redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") add_async_finalizer(lambda: close_connection(redis)) + sub_task = asyncio.ensure_future(redis.execute_pubsub("SUBSCRIBE", "chan:1", "chan:2")) + psub_task = asyncio.ensure_future(redis.execute_pubsub("PSUBSCRIBE", "chan:3", "chan:4")) + ssub_task = asyncio.ensure_future( + redis.execute_pubsub("SSUBSCRIBE", "chan:5:{shard}", "chan:6:{shard}") + ) + await moment() + + # push replies reader.queue.put_nowait([b"subscribe", b"chan:1", 1]) reader.queue.put_nowait([b"subscribe", b"chan:2", 2]) reader.queue.put_nowait([b"psubscribe", b"chan:3", 3]) reader.queue.put_nowait([b"psubscribe", b"chan:4", 4]) reader.queue.put_nowait([b"ssubscribe", b"chan:5:{shard}", 1]) reader.queue.put_nowait([b"ssubscribe", b"chan:6:{shard}", 2]) - await redis.execute_pubsub("SUBSCRIBE", "chan:1", "chan:2") - await redis.execute_pubsub("PSUBSCRIBE", "chan:3", "chan:4") - await redis.execute_pubsub("SSUBSCRIBE", "chan:5:{shard}", "chan:6:{shard}") + await moment() + + assert sub_task.result() + assert psub_task.result() + assert ssub_task.result() assert redis.in_pubsub == 1 assert len(redis.pubsub_channels) == 2 @@ -273,6 +283,8 @@ async def test__redis_push_unsubscribe(add_async_finalizer): assert redis._reader_task.done() is False + assert len(redis._pubsub_store._unconfirmed_subscribes) == 0 + async def test_moved_with_pubsub(add_async_finalizer): reader = get_mocked_reader() @@ -516,3 +528,41 @@ async def test_subscribe_and_immediately_unsubscribe(caplog, add_async_finalizer assert redis._reader_task is not None assert redis._reader_task.done() is False + + +async def test_immediately_resubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + # switch connection to pubsub mode + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}") + + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2]) + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}") + + unsub_task = asyncio.ensure_future(redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}")) + sub_task = asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}")) + await moment() + + assert unsub_task.done() is True + assert sub_task.done() is True + + ch = redis.sharded_pubsub_channels["chan2:{shard}"] + + ch_get_task = asyncio.ensure_future(ch.get()) + # start task + await moment() + + # redis send sequence of replies + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2]) + + # consume replies + await moment() + # wait done callback for ch.get() + await moment() + + assert ch_get_task.done() is False diff --git a/tests/unit_tests/aioredis_cluster/test_pubsub.py b/tests/unit_tests/aioredis_cluster/test_pubsub.py new file mode 100644 index 0000000..3a7ee0a --- /dev/null +++ b/tests/unit_tests/aioredis_cluster/test_pubsub.py @@ -0,0 +1,407 @@ +from unittest import mock + +import pytest + +from aioredis_cluster.abc import AbcChannel +from aioredis_cluster.command_info.commands import PubSubType +from aioredis_cluster.crc import key_slot +from aioredis_cluster.pubsub import PubSubStore + + +def make_channel_mock(): + return mock.NonCallableMock(AbcChannel) + + +def test_channel_subscribe__one_channel(): + store = PubSubStore() + chan = make_channel_mock() + pchan = make_channel_mock() + schan = make_channel_mock() + + store.channel_subscribe( + channel_type=PubSubType.CHANNEL, + channel_name=b"chan", + channel=chan, + key_slot=-1, + ) + + assert store.channels_num == 1 + assert store.channels_total == 1 + assert store.sharded_channels_num == 0 + assert len(store.channels) == 1 + assert len(store.patterns) == 0 + assert len(store.sharded) == 0 + assert store.channels["chan"] is chan + + store.channel_subscribe( + channel_type=PubSubType.PATTERN, + channel_name=b"pchan", + channel=pchan, + key_slot=-1, + ) + + assert store.channels_num == 2 + assert store.channels_total == 2 + assert store.sharded_channels_num == 0 + assert len(store.channels) == 1 + assert len(store.patterns) == 1 + assert len(store.sharded) == 0 + assert store.patterns["pchan"] is pchan + + schan_key_slot = key_slot(b"schan") + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=b"schan", + channel=schan, + key_slot=schan_key_slot, + ) + + assert store.channels_num == 2 + assert store.channels_total == 3 + assert store.sharded_channels_num == 1 + assert len(store.channels) == 1 + assert len(store.patterns) == 1 + assert len(store.sharded) == 1 + assert store.sharded["schan"] is schan + + assert schan_key_slot in store._slot_to_sharded + assert store._slot_to_sharded[schan_key_slot] == {b"schan"} + + +def test_close__empty(): + store = PubSubStore() + store.close(None) + + +def test_close__with_channels(): + store = PubSubStore() + + chan = make_channel_mock() + pchan = make_channel_mock() + schan = make_channel_mock() + + store.channel_subscribe( + channel_type=PubSubType.CHANNEL, + channel_name=b"chan", + channel=chan, + key_slot=-1, + ) + store.channel_subscribe( + channel_type=PubSubType.PATTERN, + channel_name=b"pchan", + channel=pchan, + key_slot=-1, + ) + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=b"schan", + channel=schan, + key_slot=key_slot(b"schan"), + ) + + close_exc = Exception() + store.close(close_exc) + + assert store.channels_total == 0 + assert len(store._slot_to_sharded) == 0 + assert len(store._unconfirmed_subscribes) == 0 + + chan.close.assert_called_once_with(close_exc) + pchan.close.assert_called_once_with(close_exc) + schan.close.assert_called_once_with(close_exc) + + +def test_confirm_subscribe__no_confirms(caplog): + store = PubSubStore() + + with caplog.at_level("ERROR", "aioredis_cluster.pubsub"): + store.confirm_subscribe(PubSubType.CHANNEL, b"chan") + + assert len(caplog.records) == 1 + assert "Unexpected PubSub subscribe confirm for" in caplog.records[0].message + + +def test_confirm_subscribe__simple_confirm(caplog): + store = PubSubStore() + + chan = make_channel_mock() + pchan = make_channel_mock() + schan = make_channel_mock() + + store.channel_subscribe( + channel_type=PubSubType.CHANNEL, + channel_name=b"chan", + channel=chan, + key_slot=-1, + ) + store.channel_subscribe( + channel_type=PubSubType.PATTERN, + channel_name=b"pchan", + channel=pchan, + key_slot=-1, + ) + schan_key_slot = key_slot(b"schan") + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=b"schan", + channel=schan, + key_slot=schan_key_slot, + ) + + assert len(store._unconfirmed_subscribes) == 3 + assert schan_key_slot in store._slot_to_sharded + + with caplog.at_level("ERROR", "aioredis_cluster.pubsub"): + store.confirm_subscribe(PubSubType.SHARDED, b"schan") + + assert len(store._unconfirmed_subscribes) == 2 + assert schan_key_slot in store._slot_to_sharded + assert len(caplog.records) == 0 + + with caplog.at_level("ERROR", "aioredis_cluster.pubsub"): + store.confirm_subscribe(PubSubType.PATTERN, b"pchan") + store.confirm_subscribe(PubSubType.CHANNEL, b"chan") + + assert len(store._unconfirmed_subscribes) == 0 + assert len(caplog.records) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_confirm_subscribe__resub_confirms(channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + # calls in execute_pubsub() for (P|S)SUBSCRIBE commands + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + + # calls in execute_pubsub() for (P|S)UNSUBSCRIBE commands + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=False, + ) + + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 2 + assert store.channels_total == 1 + + # call process_pubsub on receive (p|s)subscribe kind events + store.confirm_subscribe(channel_type, channel_name) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 1 + if channel_type is PubSubType.SHARDED: + assert channel_key_slot in store._slot_to_sharded + + # call process_pubsub on receive (p|s)unsubscribe kind events + # this call do nothing because have unconfirmed subscription calls + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + # second server reply for subscribe command + store.confirm_subscribe(channel_type, channel_name) + assert len(store._unconfirmed_subscribes) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_confirm_subscribe__resub_and_unsub(channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + for i in range(2): + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=False, + ) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 2 + assert store.channels_total == 0 + + # first subscribe reply + store.confirm_subscribe(channel_type, channel_name) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 1 + if channel_type is PubSubType.SHARDED: + assert channel_key_slot in store._slot_to_sharded + + # first unsubscribe reply + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + # second subscribe reply + store.confirm_subscribe(channel_type, channel_name) + assert len(store._unconfirmed_subscribes) == 0 + if channel_type is PubSubType.SHARDED: + assert len(store._slot_to_sharded) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_confirm_subscribe__unsub_server_push(channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + + # server push (p|s)unsubscribe kind event + # probably previous sub->unsub attempts + # we do nothing + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + assert store.channels_total == 1 + chan.close.assert_not_called() + + # subscribe reply + store.confirm_subscribe(channel_type, channel_name) + + # server push (p|s)unsubscribe kind event + # maybe is cluster reshard or node is shutting down + # we must close all confirmed channels + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + assert store.channels_total == 0 + chan.close.assert_called_once_with() + if channel_type is PubSubType.SHARDED: + assert len(store._slot_to_sharded) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_channel_unsubscribe__subscribe_confirm_and_unsubscibe(caplog, channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + # subscribe reply + store.confirm_subscribe(channel_type, channel_name) + + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=False, + ) + + assert store.channels_total == 0 + assert store.have_slot_channels(channel_key_slot) is False + assert len(store._slot_to_sharded) == 0 + chan.close.assert_called_once_with() + + with caplog.at_level("WARNING", "aioredis_cluster.pubsub"): + # server reply for unsubscribe + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + assert len(caplog.records) == 0 + + +def test_slot_channels_unsubscribe__empty(): + store = PubSubStore() + store.slot_channels_unsubscribe(1234) + + +def test_slot_channels_unsubscribe__with_unconfirmed_subscribes(): + # this is unrealistic case but we need check this + store = PubSubStore() + chan1 = make_channel_mock() + channel1_name = b"chan1" + channel1_key_slot = key_slot(channel1_name) + chan2 = make_channel_mock() + channel2_name = b"chan2" + channel2_key_slot = key_slot(channel2_name) + chan3 = make_channel_mock() + channel3_name = b"chan3:{chan1}" + channel3_key_slot = key_slot(channel3_name) + + assert channel1_key_slot == channel3_key_slot + + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=channel1_name, + channel=chan1, + key_slot=channel1_key_slot, + ) + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=channel2_name, + channel=chan2, + key_slot=channel2_key_slot, + ) + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=channel3_name, + channel=chan3, + key_slot=channel3_key_slot, + ) + + assert store.channels_total == 3 + assert store.have_slot_channels(channel1_key_slot) is True + assert store.have_slot_channels(channel2_key_slot) is True + assert store.have_slot_channels(0) is False + + store.slot_channels_unsubscribe(channel3_key_slot) + + assert store.channels_total == 1 + assert len(store._unconfirmed_subscribes) == 1 + assert (PubSubType.SHARDED, channel2_name) in store._unconfirmed_subscribes + chan1.close.assert_called_once_with() + chan2.close.assert_not_called() + chan3.close.assert_called_once_with() From 1183172caf0b3d6a0469d711795c6c0b3e7be005 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Tue, 12 Dec 2023 18:14:48 +0300 Subject: [PATCH 13/17] improve and fix resubscribe channels --- .../test_connection_pubsub.py | 34 +++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index 0477c61..5bf38fe 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -566,3 +566,37 @@ async def test_immediately_resubscribe(caplog, add_async_finalizer): await moment() assert ch_get_task.done() is False + + +async def test_resubscribe_with_message_received(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + # switch connection to pubsub mode + asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}")) + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + + await moment() + + asyncio.ensure_future(redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}")) + + await moment() + + resub_task = asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}")) + + reader.queue.put_nowait([b"smessage", b"chan2:{shard}", b"msg1"]) + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 0]) + reader.queue.put_nowait([b"smessage", b"chan2:{shard}", b"msg2"]) + + await moment() + + channel_name = resub_task.result()[0] + ch = redis.sharded_pubsub_channels[channel_name] + + ch_get_task = asyncio.ensure_future(ch.get()) + # start task + await moment() + + assert ch_get_task.done() is False From d6f87fa91f8e549cc29a40162fed9b9bf216017b Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Tue, 12 Dec 2023 18:39:18 +0300 Subject: [PATCH 14/17] add more tests --- tests/system_tests/test_connection_pubsub.py | 50 ++++++++++++++++++- .../test_connection_pubsub.py | 29 +++++++---- 2 files changed, 68 insertions(+), 11 deletions(-) diff --git a/tests/system_tests/test_connection_pubsub.py b/tests/system_tests/test_connection_pubsub.py index 737b7ee..3ad5a5a 100644 --- a/tests/system_tests/test_connection_pubsub.py +++ b/tests/system_tests/test_connection_pubsub.py @@ -5,6 +5,7 @@ import pytest from aioredis_cluster import Cluster +from aioredis_cluster.compat.asyncio import timeout @pytest.mark.redis_version(gte="7.0.0") @@ -37,5 +38,50 @@ async def test_moved_with_pubsub(cluster: Callable[[], Awaitable[Cluster]]): assert num_of_closed > 0 assert len(redis.sharded_pubsub_channels) < len(channels_dump) - redis.close() - await redis.wait_closed() + client.close() + await client.wait_closed() + + +@pytest.mark.redis_version(gte="7.0.0") +@pytest.mark.timeout(2) +async def test_immediately_resubscribe(cluster: Callable[[], Awaitable[Cluster]]): + client = await cluster() + redis = await client.keys_master("chan") + for i in range(10): + await redis.ssubscribe("chan") + await redis.sunsubscribe("chan") + chan = (await redis.ssubscribe("chan"))[0] + + chan_get_task = asyncio.ensure_future(chan.get()) + + await client.execute("spublish", "chan", "msg1") + await client.execute("spublish", "chan", "msg2") + await client.execute("spublish", "chan", "msg3") + # wait 50ms until Redis message delivery + await asyncio.sleep(0.05) + + assert chan_get_task.done() is True + assert chan_get_task.result() == b"msg1" + + # message must be already exists in internal queue + async with timeout(0): + msg2 = await chan.get() + assert msg2 == b"msg2" + + async with timeout(0): + msg3 = await chan.get() + assert msg3 == b"msg3" + + # no more messages + with pytest.raises(asyncio.TimeoutError): + async with timeout(0): + await chan.get() + + assert chan.is_active is True + + await redis.sunsubscribe("chan") + + assert chan.is_active is False + + client.close() + await client.wait_closed() diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index 5bf38fe..14e016e 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -8,6 +8,7 @@ from aioredis_cluster.aioredis import ChannelClosedError from aioredis_cluster.aioredis.stream import StreamReader from aioredis_cluster.command_info.commands import PUBSUB_SUBSCRIBE_COMMANDS +from aioredis_cluster.compat.asyncio import timeout from aioredis_cluster.connection import RedisConnection from aioredis_cluster.errors import MovedError, RedisError @@ -580,23 +581,33 @@ async def test_resubscribe_with_message_received(add_async_finalizer): await moment() - asyncio.ensure_future(redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}")) + asyncio.ensure_future(redis.execute_pubsub("SUNSUBSCRIBE", "chan1:{shard}")) await moment() resub_task = asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}")) + await moment() - reader.queue.put_nowait([b"smessage", b"chan2:{shard}", b"msg1"]) - reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 0]) - reader.queue.put_nowait([b"smessage", b"chan2:{shard}", b"msg2"]) + reader.queue.put_nowait([b"smessage", b"chan1:{shard}", b"msg1"]) + reader.queue.put_nowait([b"sunsubscribe", b"chan1:{shard}", 0]) + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + reader.queue.put_nowait([b"smessage", b"chan1:{shard}", b"msg2"]) + # push loop cycle for received events await moment() - channel_name = resub_task.result()[0] + channel_name = resub_task.result()[0][1] ch = redis.sharded_pubsub_channels[channel_name] - ch_get_task = asyncio.ensure_future(ch.get()) - # start task - await moment() + async with timeout(0): + msg1 = await ch.get() + assert msg1 == b"msg1" - assert ch_get_task.done() is False + async with timeout(0): + msg2 = await ch.get() + assert msg2 == b"msg2" + + # no more messages + with pytest.raises(asyncio.TimeoutError): + async with timeout(0.001): + await ch.get() From 2a8164c4fb6672ea18d6a9218c4bc6b9c58ad311 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Tue, 12 Dec 2023 18:48:33 +0300 Subject: [PATCH 15/17] fix caplog tests --- .../aioredis_cluster/test_pooler.py | 46 +++++++++++++------ 1 file changed, 31 insertions(+), 15 deletions(-) diff --git a/tests/unit_tests/aioredis_cluster/test_pooler.py b/tests/unit_tests/aioredis_cluster/test_pooler.py index 2ed672c..8d3c2b5 100644 --- a/tests/unit_tests/aioredis_cluster/test_pooler.py +++ b/tests/unit_tests/aioredis_cluster/test_pooler.py @@ -15,11 +15,12 @@ def create_pool_mock(): return mocked -async def test_ensure_pool__identical_address(): +async def test_ensure_pool__identical_address(add_async_finalizer): mocked_create_pool = mock.AsyncMock( return_value=create_pool_mock(), ) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) result = await pooler.ensure_pool(Address("localhost", 1234)) @@ -32,11 +33,12 @@ async def test_ensure_pool__identical_address(): assert mocked_create_pool.call_count == 1 -async def test_ensure_pool__multiple(): +async def test_ensure_pool__multiple(add_async_finalizer): pools = [object(), object(), object()] mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) result1 = await pooler.ensure_pool(Address("localhost", 1234)) result2 = await pooler.ensure_pool(Address("localhost", 4321)) @@ -55,7 +57,7 @@ async def test_ensure_pool__multiple(): ) -async def test_ensure_pool__only_one(): +async def test_ensure_pool__only_one(add_async_finalizer): event_loop = asyncio.get_running_loop() pools = { ("h1", 1): create_pool_mock(), @@ -71,6 +73,7 @@ async def create_pool_se(addr): mocked_create_pool = mock.AsyncMock(side_effect=create_pool_se) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) tasks = [] for i in range(10): @@ -88,11 +91,12 @@ async def create_pool_se(addr): assert mocked_create_pool.call_count == 2 -async def test_ensure_pool__error(): +async def test_ensure_pool__error(add_async_finalizer): pools = [RuntimeError(), object()] mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) addr = Address("localhost", 1234) with pytest.raises(RuntimeError): @@ -117,7 +121,7 @@ async def test_close__empty_pooler(): assert pooler.closed is True -async def test_close__with_pools(mocker): +async def test_close__with_pools(mocker, add_async_finalizer): addrs_pools = [ (Address("h1", 1), create_pool_mock()), (Address("h2", 2), create_pool_mock()), @@ -127,6 +131,7 @@ async def test_close__with_pools(mocker): mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) result1 = await pooler.ensure_pool(addrs[0]) result2 = await pooler.ensure_pool(addrs[1]) @@ -143,7 +148,7 @@ async def test_close__with_pools(mocker): result2.wait_closed.assert_called_once() -async def test_reap_pools(mocker): +async def test_reap_pools(mocker, add_async_finalizer): addrs_pools = [ (Address("h1", 1), create_pool_mock()), (Address("h2", 2), create_pool_mock()), @@ -153,6 +158,7 @@ async def test_reap_pools(mocker): mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool, reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) # create pools await pooler.ensure_pool(addrs[0]) @@ -176,8 +182,9 @@ async def test_reap_pools(mocker): assert len(pooler._nodes) == 0 -async def test_reaper(mocker): +async def test_reaper(mocker, add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=0) + add_async_finalizer(lambda: pooler.close()) assert pooler._reap_calls == 0 @@ -194,8 +201,9 @@ async def test_reaper(mocker): assert pooler._reaper_task.cancelled() is True -async def test_add_pubsub_channel__no_addr(): +async def test_add_pubsub_channel__no_addr(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr = Address("h1", 1234) result = pooler.add_pubsub_channel(addr, b"channel", is_pattern=False) @@ -203,8 +211,9 @@ async def test_add_pubsub_channel__no_addr(): assert result is False -async def test_add_pubsub_channel(): +async def test_add_pubsub_channel(add_async_finalizer): pooler = Pooler(mock.AsyncMock(return_value=create_pool_mock()), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) @@ -237,15 +246,17 @@ async def test_add_pubsub_channel(): assert (b"ch3", False) in collected_channels -async def test_remove_pubsub_channel__no_addr(): +async def test_remove_pubsub_channel__no_addr(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) result = pooler.remove_pubsub_channel(b"channel", is_pattern=False) assert result is False -async def test_remove_pubsub_channel(): +async def test_remove_pubsub_channel(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) @@ -275,8 +286,9 @@ async def test_remove_pubsub_channel(): assert len(pooler._pubsub_channels) == 0 -async def test_get_pubsub_addr(): +async def test_get_pubsub_addr(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) @@ -297,11 +309,12 @@ async def test_get_pubsub_addr(): assert result4 == addr2 -async def test_ensure_pool__create_pubsub_addr_set(): +async def test_ensure_pool__create_pubsub_addr_set(add_async_finalizer): addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) pooler = Pooler(mock.AsyncMock(return_value=create_pool_mock())) + add_async_finalizer(lambda: pooler.close()) assert len(pooler._pubsub_addrs) == 0 @@ -320,9 +333,10 @@ async def test_ensure_pool__create_pubsub_addr_set(): assert len(pooler._pubsub_addrs[addr1]) == 1 -async def test_reap_pools__cleanup_channels(): +async def test_reap_pools__cleanup_channels(add_async_finalizer): pool_factory = mock.AsyncMock(return_value=mock.Mock(AbcPool)) pooler = Pooler(pool_factory, reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1) addr2 = Address("h2", 2) @@ -346,12 +360,14 @@ async def test_reap_pools__cleanup_channels(): assert len(pooler._pubsub_channels) == 0 -async def test_close_only(): +async def test_close_only(add_async_finalizer): pool1 = create_pool_mock() pool2 = create_pool_mock() pool3 = create_pool_mock() mocked_create_pool = mock.AsyncMock(side_effect=[pool1, pool2, pool3]) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) + addr1 = Address("h1", 1) addr2 = Address("h2", 2) From 29e16ac853313c5aab62f40d341583eb5bd9ab84 Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Tue, 12 Dec 2023 18:51:06 +0300 Subject: [PATCH 16/17] fix caplog tests --- tests/unit_tests/aioredis_cluster/test_pooler.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/unit_tests/aioredis_cluster/test_pooler.py b/tests/unit_tests/aioredis_cluster/test_pooler.py index 8d3c2b5..0a6000a 100644 --- a/tests/unit_tests/aioredis_cluster/test_pooler.py +++ b/tests/unit_tests/aioredis_cluster/test_pooler.py @@ -3,7 +3,7 @@ import pytest -from aioredis_cluster.abc import AbcPool +from aioredis_cluster.abc import AbcConnection, AbcPool from aioredis_cluster.pooler import Pooler from aioredis_cluster.structs import Address @@ -34,7 +34,11 @@ async def test_ensure_pool__identical_address(add_async_finalizer): async def test_ensure_pool__multiple(add_async_finalizer): - pools = [object(), object(), object()] + pools = [ + mock.AsyncMock(AbcConnection), + mock.AsyncMock(AbcConnection), + mock.AsyncMock(AbcConnection), + ] mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) @@ -92,7 +96,7 @@ async def create_pool_se(addr): async def test_ensure_pool__error(add_async_finalizer): - pools = [RuntimeError(), object()] + pools = [RuntimeError(), mock.AsyncMock(AbcConnection)] mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) From 5187d543ccbaabe28164448cc18c726cd446dc4b Mon Sep 17 00:00:00 2001 From: Anton Ilyushenkov Date: Tue, 12 Dec 2023 20:45:05 +0300 Subject: [PATCH 17/17] fixes and improves --- dev/test_pubsub_reshard.py | 160 +++++++++++++++++++++-- src/aioredis_cluster/cluster_state.py | 3 + src/aioredis_cluster/commands/cluster.py | 19 ++- src/aioredis_cluster/connection.py | 25 +++- 4 files changed, 190 insertions(+), 17 deletions(-) diff --git a/dev/test_pubsub_reshard.py b/dev/test_pubsub_reshard.py index 43c7c02..6b11e8a 100644 --- a/dev/test_pubsub_reshard.py +++ b/dev/test_pubsub_reshard.py @@ -1,13 +1,16 @@ import argparse import asyncio import logging +import random import signal from collections import deque -from typing import Counter, Deque, Dict, Mapping, Optional +from itertools import cycle +from typing import Counter, Deque, Dict, Mapping, Optional, Sequence -from aioredis_cluster import RedisCluster, create_redis_cluster +from aioredis_cluster import Cluster, ClusterNode, RedisCluster, create_redis_cluster from aioredis_cluster.aioredis import Channel from aioredis_cluster.compat.asyncio import timeout +from aioredis_cluster.crc import key_slot logger = logging.getLogger(__name__) @@ -18,8 +21,13 @@ async def tick_log( global_counters: Counter[str], ) -> None: count = 0 + last = False while True: - await asyncio.sleep(tick) + try: + await asyncio.sleep(tick) + except asyncio.CancelledError: + last = True + count += 1 logger.info("tick %d", count) logger.info("tick %d: %r", count, global_counters) @@ -31,6 +39,17 @@ async def tick_log( for routine_id, counters in routines: logger.info("tick %d: %s: %r", count, routine_id, counters) + if last: + break + + +def get_channel_name(routine_id: int) -> str: + return f"ch:{routine_id}:{{shard}}" + + +class ChannelNotClosedError(Exception): + pass + async def subscribe_routine( *, @@ -40,7 +59,7 @@ async def subscribe_routine( global_counters: Counter[str], ): await asyncio.sleep(0.5) - ch_name = f"ch:{routine_id}:{{shard}}" + ch_name = get_channel_name(routine_id) while True: counters["routine:subscribes"] += 1 prev_ch: Optional[Channel] = None @@ -48,7 +67,7 @@ async def subscribe_routine( pool = await redis.keys_master(ch_name) global_counters["subscribe_in_fly"] += 1 try: - ch: Channel = (await pool.subscribe(ch_name))[0] + ch: Channel = (await pool.ssubscribe(ch_name))[0] if prev_ch is not None and ch is prev_ch: logger.error("%s: Previous Channel is current: %r", routine_id, ch) prev_ch = ch @@ -67,10 +86,13 @@ async def subscribe_routine( else: counters["msg:received"] += 1 global_counters["msg:received"] += 1 - await pool.unsubscribe(ch_name) + await pool.sunsubscribe(ch_name) finally: global_counters["subscribe_in_fly"] -= 1 - assert ch.is_active is False + + if not ch._queue.closed: + raise ChannelNotClosedError() + except asyncio.CancelledError: break except Exception as e: @@ -81,6 +103,108 @@ async def subscribe_routine( logger.error("%s: Channel exception: %r", routine_id, e) +async def publish_routine( + *, + redis: RedisCluster, + routines: Sequence[int], + global_counters: Counter[str], +): + routines_cycle = cycle(routines) + for routine_id in routines_cycle: + await asyncio.sleep(random.uniform(0.001, 0.1)) + channel_name = get_channel_name(routine_id) + await redis.spublish(channel_name, f"msg:{routine_id}") + global_counters["published_msgs"] += 1 + + +async def cluster_move_slot( + *, + slot: int, + node_src: RedisCluster, + node_src_info: ClusterNode, + node_dest: RedisCluster, + node_dest_info: ClusterNode, +) -> None: + await node_dest.cluster_setslot(slot, "IMPORTING", node_src_info.node_id) + await node_src.cluster_setslot(slot, "MIGRATING", node_dest_info.node_id) + while True: + keys = await node_src.cluster_get_keys_in_slots(slot, 100) + if not keys: + break + await node_src.migrate_keys( + node_dest_info.addr.host, + node_dest_info.addr.port, + keys, + 0, + 5000, + ) + await node_dest.cluster_setslot(slot, "NODE", node_dest_info.node_id) + await node_src.cluster_setslot(slot, "NODE", node_dest_info.node_id) + + +async def reshard_routine( + *, + redis: RedisCluster, + global_counters: Counter[str], +): + cluster: Cluster = redis.connection + cluster_state = await cluster.get_cluster_state() + + moving_slot = key_slot(b"shard") + master1 = await cluster.keys_master(b"shard") + + # search another master + count = 0 + while True: + count += 1 + key = f"key:{count}" + master2 = await cluster.keys_master(key) + if master2.address != master1.address: + break + + master1_info = cluster_state.addr_node(master1.address) + master2_info = cluster_state.addr_node(master2.address) + logger.info("Master1 info - %r", master1_info) + logger.info("Master2 info - %r", master2_info) + + node_src = master1 + node_src_info = master1_info + node_dest = master2 + node_dest_info = master2_info + + while True: + await asyncio.sleep(random.uniform(2.0, 5.0)) + global_counters["reshards"] += 1 + logger.info( + "Start moving slot %s from %s to %s", moving_slot, node_src.address, node_dest.address + ) + move_slot_task = asyncio.create_task( + cluster_move_slot( + slot=moving_slot, + node_src=node_src, + node_src_info=node_src_info, + node_dest=node_dest, + node_dest_info=node_dest_info, + ) + ) + try: + await asyncio.shield(move_slot_task) + except asyncio.CancelledError: + await move_slot_task + raise + except Exception as e: + logger.exception("Unexpected error on reshard: %r", e) + break + + # swap nodes + node_src, node_src_info, node_dest, node_dest_info = ( + node_dest, + node_dest_info, + node_src, + node_src_info, + ) + + async def async_main(args: argparse.Namespace) -> None: loop = asyncio.get_event_loop() node_addr: str = args.node @@ -98,6 +222,15 @@ async def async_main(args: argparse.Namespace) -> None: follow_cluster=True, ) routine_tasks: Deque[asyncio.Task] = deque() + + reshard_routine_task = asyncio.create_task( + reshard_routine( + redis=redis, + global_counters=global_counters, + ) + ) + routine_tasks.append(reshard_routine_task) + try: for routine_id in range(1, args.routines + 1): counters = routines_counters[routine_id] = Counter() @@ -111,6 +244,15 @@ async def async_main(args: argparse.Namespace) -> None: ) routine_tasks.append(routine_task) + publish_routine_task = asyncio.create_task( + publish_routine( + redis=redis, + routines=list(range(1, args.routines + 1)), + global_counters=global_counters, + ) + ) + routine_tasks.append(publish_routine_task) + logger.info("Routines %d", len(routine_tasks)) await asyncio.sleep(args.wait) @@ -147,7 +289,7 @@ def main() -> None: else: uvloop.install() - logging.basicConfig(level=logging.INFO) + logging.basicConfig(level=logging.DEBUG) loop = asyncio.get_event_loop() main_task = loop.create_task(async_main(args)) @@ -161,7 +303,7 @@ def main() -> None: finally: if not main_task.done() and not main_task.cancelled(): main_task.cancel() - loop.run_until_complete(asyncio.wait([main_task])) + loop.run_until_complete(main_task) loop.close() diff --git a/src/aioredis_cluster/cluster_state.py b/src/aioredis_cluster/cluster_state.py index 4c34964..ab20eb6 100644 --- a/src/aioredis_cluster/cluster_state.py +++ b/src/aioredis_cluster/cluster_state.py @@ -212,6 +212,9 @@ def random_node(self) -> ClusterNode: def has_addr(self, addr: Address) -> bool: return addr in self._data.nodes + def addr_node(self, addr: Address) -> ClusterNode: + return self._data.nodes[addr] + def master_replicas(self, addr: Address) -> List[ClusterNode]: try: return list(self._data.replicas[addr]) diff --git a/src/aioredis_cluster/commands/cluster.py b/src/aioredis_cluster/commands/cluster.py index 7674888..8dcb656 100644 --- a/src/aioredis_cluster/commands/cluster.py +++ b/src/aioredis_cluster/commands/cluster.py @@ -105,10 +105,25 @@ def cluster_set_config_epoch(self, config_epoch: int): fut = self.execute(b"CLUSTER", b"SET-CONFIG-EPOCH", config_epoch) return wait_ok(fut) - def cluster_setslot(self, slot: int, command, node_id: str = None): + def cluster_setslot(self, slot: int, subcommand: str, node_id: str = None): """Bind a hash slot to specified node.""" - raise NotImplementedError() + subcommand = subcommand.upper() + if subcommand in {"NODE", "IMPORTING", "MIGRATING"}: + if not node_id: + raise ValueError(f"For subcommand {subcommand} node_id must be provided") + elif subcommand in {"STABLE"}: + if node_id: + raise ValueError("For subcommand STABLE node_id is not required") + else: + raise ValueError(f"Unknown subcommand {subcommand}") + + extra = [] + if node_id: + extra.append(node_id) + + fut = self.execute(b"CLUSTER", b"SETSLOT", slot, subcommand, *extra) + return fut def cluster_slaves(self, node_id: str): """List slave nodes of the specified master node.""" diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index ab2917d..3c79a96 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -528,11 +528,18 @@ async def _read_data(self) -> None: if self._server_in_pubsub: if isinstance(obj, MovedError): - logger.warning( - "Received MOVED in PubSub mode. Unsubscribe all channels from %d slot", - obj.info.slot_id, - ) - self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id) + if self._pubsub_store.have_slot_channels(obj.info.slot_id): + logger.warning( + ( + "Received MOVED in PubSub mode from %s to %s:%s. " + "Unsubscribe all channels from %d slot", + ), + self.address, + obj.info.host, + obj.info.port, + obj.info.slot_id, + ) + self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id) elif isinstance(obj, RedisError): raise obj else: @@ -542,7 +549,13 @@ async def _read_data(self) -> None: if isinstance(obj, MovedError): if self._pubsub_store.have_slot_channels(obj.info.slot_id): logger.warning( - "Received MOVED. Unsubscribe all channels from %d slot", + ( + "Received MOVED from %s to %s:%s. " + "Unsubscribe all channels from %d slot", + ), + self.address, + obj.info.host, + obj.info.port, obj.info.slot_id, ) self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id)