diff --git a/dev/test_pubsub_reshard.py b/dev/test_pubsub_reshard.py new file mode 100644 index 0000000..6b11e8a --- /dev/null +++ b/dev/test_pubsub_reshard.py @@ -0,0 +1,311 @@ +import argparse +import asyncio +import logging +import random +import signal +from collections import deque +from itertools import cycle +from typing import Counter, Deque, Dict, Mapping, Optional, Sequence + +from aioredis_cluster import Cluster, ClusterNode, RedisCluster, create_redis_cluster +from aioredis_cluster.aioredis import Channel +from aioredis_cluster.compat.asyncio import timeout +from aioredis_cluster.crc import key_slot + +logger = logging.getLogger(__name__) + + +async def tick_log( + tick: float, + routines_counters: Mapping[int, Counter[str]], + global_counters: Counter[str], +) -> None: + count = 0 + last = False + while True: + try: + await asyncio.sleep(tick) + except asyncio.CancelledError: + last = True + + count += 1 + logger.info("tick %d", count) + logger.info("tick %d: %r", count, global_counters) + routines = sorted( + routines_counters.items(), + key=lambda item: item[0], + reverse=False, + ) + for routine_id, counters in routines: + logger.info("tick %d: %s: %r", count, routine_id, counters) + + if last: + break + + +def get_channel_name(routine_id: int) -> str: + return f"ch:{routine_id}:{{shard}}" + + +class ChannelNotClosedError(Exception): + pass + + +async def subscribe_routine( + *, + redis: RedisCluster, + routine_id: int, + counters: Counter[str], + global_counters: Counter[str], +): + await asyncio.sleep(0.5) + ch_name = get_channel_name(routine_id) + while True: + counters["routine:subscribes"] += 1 + prev_ch: Optional[Channel] = None + try: + pool = await redis.keys_master(ch_name) + global_counters["subscribe_in_fly"] += 1 + try: + ch: Channel = (await pool.ssubscribe(ch_name))[0] + if prev_ch is not None and ch is prev_ch: + logger.error("%s: Previous Channel is current: %r", routine_id, ch) + prev_ch = ch + # logger.info('Wait channel %s', ch_name) + try: + async with timeout(1.0): + res = await ch.get() + except asyncio.TimeoutError: + counters["timeouts"] += 1 + global_counters["timeouts"] += 1 + # logger.warning("%s: ch.get() is timed out", routine_id) + else: + if res is None: + counters["msg:received:None"] += 1 + global_counters["msg:received:None"] += 1 + else: + counters["msg:received"] += 1 + global_counters["msg:received"] += 1 + await pool.sunsubscribe(ch_name) + finally: + global_counters["subscribe_in_fly"] -= 1 + + if not ch._queue.closed: + raise ChannelNotClosedError() + + except asyncio.CancelledError: + break + except Exception as e: + counters["errors"] += 1 + counters[f"errors:{type(e).__name__}"] += 1 + global_counters["errors"] += 1 + global_counters[f"errors:{type(e).__name__}"] += 1 + logger.error("%s: Channel exception: %r", routine_id, e) + + +async def publish_routine( + *, + redis: RedisCluster, + routines: Sequence[int], + global_counters: Counter[str], +): + routines_cycle = cycle(routines) + for routine_id in routines_cycle: + await asyncio.sleep(random.uniform(0.001, 0.1)) + channel_name = get_channel_name(routine_id) + await redis.spublish(channel_name, f"msg:{routine_id}") + global_counters["published_msgs"] += 1 + + +async def cluster_move_slot( + *, + slot: int, + node_src: RedisCluster, + node_src_info: ClusterNode, + node_dest: RedisCluster, + node_dest_info: ClusterNode, +) -> None: + await node_dest.cluster_setslot(slot, "IMPORTING", node_src_info.node_id) + await node_src.cluster_setslot(slot, "MIGRATING", node_dest_info.node_id) + while True: + keys = await node_src.cluster_get_keys_in_slots(slot, 100) + if not keys: + break + await node_src.migrate_keys( + node_dest_info.addr.host, + node_dest_info.addr.port, + keys, + 0, + 5000, + ) + await node_dest.cluster_setslot(slot, "NODE", node_dest_info.node_id) + await node_src.cluster_setslot(slot, "NODE", node_dest_info.node_id) + + +async def reshard_routine( + *, + redis: RedisCluster, + global_counters: Counter[str], +): + cluster: Cluster = redis.connection + cluster_state = await cluster.get_cluster_state() + + moving_slot = key_slot(b"shard") + master1 = await cluster.keys_master(b"shard") + + # search another master + count = 0 + while True: + count += 1 + key = f"key:{count}" + master2 = await cluster.keys_master(key) + if master2.address != master1.address: + break + + master1_info = cluster_state.addr_node(master1.address) + master2_info = cluster_state.addr_node(master2.address) + logger.info("Master1 info - %r", master1_info) + logger.info("Master2 info - %r", master2_info) + + node_src = master1 + node_src_info = master1_info + node_dest = master2 + node_dest_info = master2_info + + while True: + await asyncio.sleep(random.uniform(2.0, 5.0)) + global_counters["reshards"] += 1 + logger.info( + "Start moving slot %s from %s to %s", moving_slot, node_src.address, node_dest.address + ) + move_slot_task = asyncio.create_task( + cluster_move_slot( + slot=moving_slot, + node_src=node_src, + node_src_info=node_src_info, + node_dest=node_dest, + node_dest_info=node_dest_info, + ) + ) + try: + await asyncio.shield(move_slot_task) + except asyncio.CancelledError: + await move_slot_task + raise + except Exception as e: + logger.exception("Unexpected error on reshard: %r", e) + break + + # swap nodes + node_src, node_src_info, node_dest, node_dest_info = ( + node_dest, + node_dest_info, + node_src, + node_src_info, + ) + + +async def async_main(args: argparse.Namespace) -> None: + loop = asyncio.get_event_loop() + node_addr: str = args.node + + global_counters: Counter[str] = Counter() + routines_counters: Dict[int, Counter[str]] = {} + + tick_task = loop.create_task(tick_log(5.0, routines_counters, global_counters)) + + redis = await create_redis_cluster( + [node_addr], + pool_minsize=1, + pool_maxsize=1, + connect_timeout=1.0, + follow_cluster=True, + ) + routine_tasks: Deque[asyncio.Task] = deque() + + reshard_routine_task = asyncio.create_task( + reshard_routine( + redis=redis, + global_counters=global_counters, + ) + ) + routine_tasks.append(reshard_routine_task) + + try: + for routine_id in range(1, args.routines + 1): + counters = routines_counters[routine_id] = Counter() + routine_task = asyncio.create_task( + subscribe_routine( + redis=redis, + routine_id=routine_id, + counters=counters, + global_counters=global_counters, + ) + ) + routine_tasks.append(routine_task) + + publish_routine_task = asyncio.create_task( + publish_routine( + redis=redis, + routines=list(range(1, args.routines + 1)), + global_counters=global_counters, + ) + ) + routine_tasks.append(publish_routine_task) + + logger.info("Routines %d", len(routine_tasks)) + + await asyncio.sleep(args.wait) + finally: + logger.info("Cancel %d routines", len(routine_tasks)) + for rt in routine_tasks: + if not rt.done(): + rt.cancel() + await asyncio.wait(routine_tasks) + + logger.info("Close redis client") + redis.close() + await redis.wait_closed() + + logger.info("Cancel tick task") + tick_task.cancel() + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("node") + parser.add_argument("--wait", type=float, default=30) + parser.add_argument("--routines", type=int, default=10) + args = parser.parse_args() + if args.routines < 1: + parser.error("--routines must be positive int") + if args.wait < 0: + parser.error("--wait must be positive float or int") + + try: + import uvloop + except ImportError: + pass + else: + uvloop.install() + + logging.basicConfig(level=logging.DEBUG) + + loop = asyncio.get_event_loop() + main_task = loop.create_task(async_main(args)) + + main_task.add_done_callback(lambda f: loop.stop()) + loop.add_signal_handler(signal.SIGINT, lambda: loop.stop()) + loop.add_signal_handler(signal.SIGTERM, lambda: loop.stop()) + + try: + loop.run_forever() + finally: + if not main_task.done() and not main_task.cancelled(): + main_task.cancel() + loop.run_until_complete(main_task) + loop.close() + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index e6eefcf..26c239c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,7 @@ target-version = ["py38"] exclude = '.pyi$' [tool.pytest.ini_options] -addopts = "--cov-report=term --cov-report=html -v" +addopts = "-v --cov --cov-report=term --cov-report=html" asyncio_mode = "auto" [tool.coverage.run] diff --git a/src/aioredis_cluster/_aioredis/parser.py b/src/aioredis_cluster/_aioredis/parser.py index 0d443ea..5fe075f 100644 --- a/src/aioredis_cluster/_aioredis/parser.py +++ b/src/aioredis_cluster/_aioredis/parser.py @@ -53,7 +53,12 @@ def getmaxbuf(self) -> int: class Parser: - def __init__(self, protocolError: Callable, replyError: Callable, encoding: Optional[str]): + def __init__( + self, + protocolError: Callable, + replyError: Callable, + encoding: Optional[str], + ): self.buf: bytearray = bytearray() self.pos: int = 0 self.protocolError: Callable = protocolError diff --git a/src/aioredis_cluster/_aioredis/pubsub.py b/src/aioredis_cluster/_aioredis/pubsub.py index 9da6413..47e3fdb 100644 --- a/src/aioredis_cluster/_aioredis/pubsub.py +++ b/src/aioredis_cluster/_aioredis/pubsub.py @@ -4,6 +4,7 @@ import sys import types import warnings +from typing import Optional from .abc import AbcChannel from .errors import ChannelClosedError @@ -37,17 +38,17 @@ def __repr__(self): ) @property - def name(self): + def name(self) -> bytes: """Encoded channel name/pattern.""" return self._name @property - def is_pattern(self): + def is_pattern(self) -> bool: """Set to True if channel is subscribed to pattern.""" return self._is_pattern @property - def is_active(self): + def is_active(self) -> bool: """Returns True until there are messages in channel or connection is subscribed to it. @@ -87,11 +88,11 @@ async def get(self, *, encoding=None, decoder=None): return dest_channel, msg return msg - async def get_json(self, encoding="utf-8"): + async def get_json(self, encoding: str = "utf-8"): """Shortcut to get JSON messages.""" return await self.get(encoding=encoding, decoder=json.loads) - def iter(self, *, encoding=None, decoder=None): + def iter(self, *, encoding: Optional[str] = None, decoder=None): """Same as get method but its native coroutine. Usage example: @@ -103,7 +104,7 @@ def iter(self, *, encoding=None, decoder=None): self, is_active=lambda ch: ch.is_active, encoding=encoding, decoder=decoder ) - async def wait_message(self): + async def wait_message(self) -> bool: """Waits for message to become available in channel or channel is closed (unsubscribed). @@ -121,10 +122,10 @@ async def wait_message(self): # internal methods - def put_nowait(self, data): + def put_nowait(self, data) -> None: self._queue.put(data) - def close(self, exc=None): + def close(self, exc: Optional[BaseException] = None) -> None: """Marks channel as inactive. Internal method, will be called from connection diff --git a/src/aioredis_cluster/_aioredis/util.py b/src/aioredis_cluster/_aioredis/util.py index 386d84b..d60f017 100644 --- a/src/aioredis_cluster/_aioredis/util.py +++ b/src/aioredis_cluster/_aioredis/util.py @@ -1,5 +1,6 @@ import asyncio import sys +from typing import Any, Dict, Generic, TypeVar from urllib.parse import parse_qsl, urlparse from .log import logger @@ -10,6 +11,9 @@ _NOTSET = object() +T = TypeVar("T") + + IS_PY38 = sys.version_info >= (3, 8) # NOTE: never put here anything else; @@ -75,13 +79,13 @@ async def wait_make_dict(fut): return dict(zip(it, it)) -class coerced_keys_dict(dict): - def __getitem__(self, other): +class coerced_keys_dict(Generic[T], Dict[Any, T], dict): + def __getitem__(self, other) -> T: if not isinstance(other, bytes): other = _converters[type(other)](other) return dict.__getitem__(self, other) - def __contains__(self, other): + def __contains__(self, other) -> bool: if not isinstance(other, bytes): other = _converters[type(other)](other) return dict.__contains__(self, other) @@ -108,7 +112,7 @@ async def __anext__(self): return ret -def _set_result(fut, result, *info): +def _set_result(fut: asyncio.Future, result: Any, *info) -> None: if fut.done(): logger.debug("Waiter future is already done %r %r", fut, info) assert fut.cancelled(), ("waiting future is in wrong state", fut, result, info) @@ -116,7 +120,7 @@ def _set_result(fut, result, *info): fut.set_result(result) -def _set_exception(fut, exception): +def _set_exception(fut: asyncio.Future, exception: BaseException) -> None: if fut.done(): logger.debug("Waiter future is already done %r", fut) assert fut.cancelled(), ("waiting future is in wrong state", fut, exception) diff --git a/src/aioredis_cluster/aioredis/__init__.py b/src/aioredis_cluster/aioredis/__init__.py index 8f74d12..392e326 100644 --- a/src/aioredis_cluster/aioredis/__init__.py +++ b/src/aioredis_cluster/aioredis/__init__.py @@ -1,67 +1,31 @@ -try: - from aioredis.commands import ( - GeoMember, - GeoPoint, - Redis, - create_redis, - create_redis_pool, - ) -except ImportError: - from .._aioredis.commands import ( - GeoMember, - GeoPoint, - Redis, - create_redis, - create_redis_pool, - ) +from .commands import GeoMember, GeoPoint, Redis, create_redis, create_redis_pool try: from aioredis.connection import RedisConnection except ImportError: from .._aioredis.connection import RedisConnection -try: - from aioredis.errors import ( - AuthError, - BusyGroupError, - ChannelClosedError, - ConnectionClosedError, - ConnectionForcedCloseError, - MasterNotFoundError, - MasterReplyError, - MaxClientsError, - MultiExecError, - PipelineError, - PoolClosedError, - ProtocolError, - ReadOnlyError, - RedisError, - ReplyError, - SlaveNotFoundError, - SlaveReplyError, - WatchVariableError, - ) -except ImportError: - from .._aioredis.errors import ( - AuthError, - BusyGroupError, - ChannelClosedError, - ConnectionClosedError, - ConnectionForcedCloseError, - MasterNotFoundError, - MasterReplyError, - MaxClientsError, - MultiExecError, - PipelineError, - PoolClosedError, - ProtocolError, - ReadOnlyError, - RedisError, - ReplyError, - SlaveNotFoundError, - SlaveReplyError, - WatchVariableError, - ) +from .errors import ( + AuthError, + BusyGroupError, + ChannelClosedError, + ConnectionClosedError, + ConnectionForcedCloseError, + MasterNotFoundError, + MasterReplyError, + MaxClientsError, + MultiExecError, + PipelineError, + PoolClosedError, + ProtocolError, + ReadOnlyError, + RedisError, + ReplyError, + SlaveNotFoundError, + SlaveReplyError, + WatchVariableError, +) + try: from aioredis.pool import ConnectionsPool except ImportError: diff --git a/src/aioredis_cluster/aioredis/commands/__init__.py b/src/aioredis_cluster/aioredis/commands/__init__.py index 5fc0342..b4eee07 100644 --- a/src/aioredis_cluster/aioredis/commands/__init__.py +++ b/src/aioredis_cluster/aioredis/commands/__init__.py @@ -1,3 +1,9 @@ +import asyncio +from typing import Any, Optional, Tuple, Union + +from aioredis_cluster.aioredis.connection import create_connection +from aioredis_cluster.aioredis.pool import create_pool + try: from aioredis.commands import ( ContextRedis, @@ -5,8 +11,7 @@ GeoPoint, MultiExec, Pipeline, - create_redis, - create_redis_pool, + Redis, ) except ImportError: from aioredis_cluster._aioredis.commands import ( @@ -15,8 +20,7 @@ GeoPoint, MultiExec, Pipeline, - create_redis, - create_redis_pool, + Redis, ) @@ -32,3 +36,69 @@ "GeoPoint", "GeoMember", ) + + +async def create_redis( + address: Union[Tuple[str, int], str], + *, + db: Optional[int] = None, + password: Optional[str] = None, + ssl: Optional[Any] = None, + encoding: Optional[str] = None, + commands_factory: Redis = Redis, + parser=None, + timeout: Optional[float] = None, + connection_cls=None, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> Redis: + """Creates high-level Redis interface. + + This function is a coroutine. + """ + conn = await create_connection( + address, + db=db, + password=password, + ssl=ssl, + encoding=encoding, + parser=parser, + timeout=timeout, + connection_cls=connection_cls, + ) + return commands_factory(conn) + + +async def create_redis_pool( + address: Union[Tuple[str, int], str], + *, + db: Optional[int] = None, + password: Optional[str] = None, + ssl: Optional[Any] = None, + encoding: Optional[str] = None, + commands_factory: Redis = Redis, + minsize: int = 1, + maxsize: int = 10, + parser=None, + timeout: Optional[float] = None, + pool_cls=None, + connection_cls=None, + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> Redis: + """Creates high-level Redis interface. + + This function is a coroutine. + """ + pool = await create_pool( + address, + db=db, + password=password, + ssl=ssl, + encoding=encoding, + minsize=minsize, + maxsize=maxsize, + parser=parser, + create_connection_timeout=timeout, + pool_cls=pool_cls, + connection_cls=connection_cls, + ) + return commands_factory(pool) diff --git a/src/aioredis_cluster/aioredis/connection.py b/src/aioredis_cluster/aioredis/connection.py index 4cb9ed1..46bbb98 100644 --- a/src/aioredis_cluster/aioredis/connection.py +++ b/src/aioredis_cluster/aioredis/connection.py @@ -7,12 +7,9 @@ from aioredis_cluster.compat.asyncio import timeout as atimeout from .abc import AbcConnection +from .stream import open_connection, open_unix_connection from .util import parse_url -try: - from aioredis.stream import open_connection, open_unix_connection -except ImportError: - from .._aioredis.stream import open_connection, open_unix_connection try: from aioredis.connection import MAX_CHUNK_SIZE, RedisConnection except ImportError: diff --git a/src/aioredis_cluster/aioredis/parser.py b/src/aioredis_cluster/aioredis/parser.py new file mode 100644 index 0000000..155c338 --- /dev/null +++ b/src/aioredis_cluster/aioredis/parser.py @@ -0,0 +1,9 @@ +try: + from aioredis.parser import PyReader, Reader +except ImportError: + from .._aioredis.parser import PyReader, Reader + +__all__ = ( + "Reader", + "PyReader", +) diff --git a/src/aioredis_cluster/aioredis/stream.py b/src/aioredis_cluster/aioredis/stream.py new file mode 100644 index 0000000..608ad92 --- /dev/null +++ b/src/aioredis_cluster/aioredis/stream.py @@ -0,0 +1,10 @@ +try: + from aioredis.stream import StreamReader, open_connection, open_unix_connection +except ImportError: + from .._aioredis.stream import StreamReader, open_connection, open_unix_connection + +__all__ = ( + "open_connection", + "open_unix_connection", + "StreamReader", +) diff --git a/src/aioredis_cluster/cluster.py b/src/aioredis_cluster/cluster.py index 03e1e67..93cb73e 100644 --- a/src/aioredis_cluster/cluster.py +++ b/src/aioredis_cluster/cluster.py @@ -29,7 +29,7 @@ from aioredis_cluster.command_info import CommandInfo, extract_keys from aioredis_cluster.commands import RedisCluster from aioredis_cluster.compat.asyncio import timeout as atimeout -from aioredis_cluster.crc import key_slot +from aioredis_cluster.crc import CrossSlotError, determine_slot from aioredis_cluster.errors import ( AskError, ClusterClosedError, @@ -362,7 +362,10 @@ def in_pubsub(self) -> int: Can be tested as bool indicating Pub/Sub mode state. """ - return sum(p.in_pubsub for p in self._pooler.pools()) + for pool in self._pooler.pools(): + if pool.in_pubsub: + return 1 + return 0 @property def pubsub_channels(self) -> Mapping[str, AbcChannel]: @@ -418,12 +421,10 @@ async def authorize(pool) -> None: await self._pooler.batch_op(authorize) def determine_slot(self, first_key: bytes, *keys: bytes) -> int: - slot: int = key_slot(first_key) - for k in keys: - if slot != key_slot(k): - raise RedisClusterError("all keys must map to the same key slot") - - return slot + try: + return determine_slot(first_key, *keys) + except CrossSlotError as e: + raise RedisClusterError(str(e)) from None async def all_masters(self) -> List[Redis]: ctx = self._make_exec_context((b"PING",), {}) diff --git a/src/aioredis_cluster/cluster_state.py b/src/aioredis_cluster/cluster_state.py index 4c34964..ab20eb6 100644 --- a/src/aioredis_cluster/cluster_state.py +++ b/src/aioredis_cluster/cluster_state.py @@ -212,6 +212,9 @@ def random_node(self) -> ClusterNode: def has_addr(self, addr: Address) -> bool: return addr in self._data.nodes + def addr_node(self, addr: Address) -> ClusterNode: + return self._data.nodes[addr] + def master_replicas(self, addr: Address) -> List[ClusterNode]: try: return list(self._data.replicas[addr]) diff --git a/src/aioredis_cluster/command_info/commands.py b/src/aioredis_cluster/command_info/commands.py index 87f61fb..491d241 100644 --- a/src/aioredis_cluster/command_info/commands.py +++ b/src/aioredis_cluster/command_info/commands.py @@ -1,5 +1,6 @@ # redis command output (v5.0.8) -from typing import FrozenSet, Iterable, MutableSet, Union +import enum +from typing import Dict, FrozenSet, Iterable, Mapping, MutableSet, Union, cast __all__ = ( "COMMANDS", @@ -10,9 +11,14 @@ "XREAD_COMMAND", "XREADGROUP_COMMAND", "PUBSUB_COMMANDS", + "PATTERN_PUBSUB_COMMANDS", "SHARDED_PUBSUB_COMMANDS", "PUBSUB_FAMILY_COMMANDS", "PING_COMMANDS", + "PUBSUB_SUBSCRIBE_COMMANDS", + "PUBSUB_COMMAND_TO_TYPE", + "PUBSUB_RESP_KIND_TO_TYPE", + "PubSubType", ) @@ -270,11 +276,17 @@ def _gen_commands_set(commands: Iterable[str]) -> FrozenSet[Union[bytes, str]]: XREADGROUP_COMMAND = "XREADGROUP" -PUBSUB_COMMANDS = _gen_commands_set( +class PubSubType(enum.Enum): + CHANNEL = enum.auto() + PATTERN = enum.auto() + SHARDED = enum.auto() + + +PUBSUB_COMMANDS = _gen_commands_set({"SUBSCRIBE", "UNSUBSCRIBE"}) + +PATTERN_PUBSUB_COMMANDS = _gen_commands_set( { - "SUBSCRIBE", "PSUBSCRIBE", - "UNSUBSCRIBE", "PUNSUBSCRIBE", } ) @@ -286,6 +298,35 @@ def _gen_commands_set(commands: Iterable[str]) -> FrozenSet[Union[bytes, str]]: } ) -PUBSUB_FAMILY_COMMANDS = PUBSUB_COMMANDS | SHARDED_PUBSUB_COMMANDS +PUBSUB_SUBSCRIBE_COMMANDS = _gen_commands_set( + { + "SUBSCRIBE", + "PSUBSCRIBE", + "SSUBSCRIBE", + } +) + +PUBSUB_FAMILY_COMMANDS = PUBSUB_COMMANDS | PATTERN_PUBSUB_COMMANDS | SHARDED_PUBSUB_COMMANDS PING_COMMANDS = _gen_commands_set({"PING"}) + +PUBSUB_RESP_KIND_TO_TYPE: Mapping[bytes, PubSubType] = { + b"message": PubSubType.CHANNEL, + b"subscribe": PubSubType.CHANNEL, + b"unsubscribe": PubSubType.CHANNEL, + b"pmessage": PubSubType.PATTERN, + b"psubscribe": PubSubType.PATTERN, + b"punsubscribe": PubSubType.PATTERN, + b"smessage": PubSubType.SHARDED, + b"ssubscribe": PubSubType.SHARDED, + b"sunsubscribe": PubSubType.SHARDED, +} + +_pubsub_command_to_type: Dict[Union[str, bytes], PubSubType] = {} +for cmd in PUBSUB_COMMANDS: + _pubsub_command_to_type[cmd] = PubSubType.CHANNEL +for cmd in PATTERN_PUBSUB_COMMANDS: + _pubsub_command_to_type[cmd] = PubSubType.PATTERN +for cmd in SHARDED_PUBSUB_COMMANDS: + _pubsub_command_to_type[cmd] = PubSubType.SHARDED +PUBSUB_COMMAND_TO_TYPE = cast(Mapping[Union[str, bytes], PubSubType], _pubsub_command_to_type) diff --git a/src/aioredis_cluster/commands/cluster.py b/src/aioredis_cluster/commands/cluster.py index 7674888..8dcb656 100644 --- a/src/aioredis_cluster/commands/cluster.py +++ b/src/aioredis_cluster/commands/cluster.py @@ -105,10 +105,25 @@ def cluster_set_config_epoch(self, config_epoch: int): fut = self.execute(b"CLUSTER", b"SET-CONFIG-EPOCH", config_epoch) return wait_ok(fut) - def cluster_setslot(self, slot: int, command, node_id: str = None): + def cluster_setslot(self, slot: int, subcommand: str, node_id: str = None): """Bind a hash slot to specified node.""" - raise NotImplementedError() + subcommand = subcommand.upper() + if subcommand in {"NODE", "IMPORTING", "MIGRATING"}: + if not node_id: + raise ValueError(f"For subcommand {subcommand} node_id must be provided") + elif subcommand in {"STABLE"}: + if node_id: + raise ValueError("For subcommand STABLE node_id is not required") + else: + raise ValueError(f"Unknown subcommand {subcommand}") + + extra = [] + if node_id: + extra.append(node_id) + + fut = self.execute(b"CLUSTER", b"SETSLOT", slot, subcommand, *extra) + return fut def cluster_slaves(self, node_id: str): """List slave nodes of the specified master node.""" diff --git a/src/aioredis_cluster/commands/sharded_pubsub.py b/src/aioredis_cluster/commands/sharded_pubsub.py index eb8ccaf..3df0b84 100644 --- a/src/aioredis_cluster/commands/sharded_pubsub.py +++ b/src/aioredis_cluster/commands/sharded_pubsub.py @@ -1,8 +1,8 @@ -from typing import Awaitable, Callable +from typing import Awaitable, Callable, Mapping from aioredis_cluster._aioredis.commands.pubsub import wait_return_channels from aioredis_cluster._aioredis.util import wait_make_dict -from aioredis_cluster.abc import AbcConnection +from aioredis_cluster.abc import AbcChannel, AbcConnection class ShardedPubSubCommandsMixin: @@ -54,7 +54,7 @@ def pubsub_shardnumsub(self, *channels): return wait_make_dict(self.execute(b"PUBSUB", b"SHARDNUMSUB", *channels)) @property - def sharded_pubsub_channels(self): + def sharded_pubsub_channels(self) -> Mapping[str, AbcChannel]: """Returns read-only channels dict. See :attr:`~aioredis.RedisConnection.pubsub_channels` diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index 342039e..3c79a96 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -1,26 +1,76 @@ import asyncio import logging +import warnings +from collections import deque +from contextlib import contextmanager from functools import partial -from types import MappingProxyType -from typing import Iterable, List, Mapping +from typing import ( + Any, + Callable, + Deque, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Protocol, + Tuple, + Union, +) -from aioredis_cluster._aioredis.util import coerced_keys_dict, wait_ok +from aioredis_cluster._aioredis.util import _set_exception, _set_result, decode, wait_ok from aioredis_cluster.abc import AbcChannel, AbcConnection -from aioredis_cluster.aioredis import Channel, ConnectionClosedError -from aioredis_cluster.aioredis import RedisConnection as BaseConnection +from aioredis_cluster.aioredis import ( + Channel, + ConnectionClosedError, + ConnectionForcedCloseError, + MaxClientsError, + ProtocolError, + ReadOnlyError, + ReplyError, + WatchVariableError, +) +from aioredis_cluster.aioredis.parser import Reader +from aioredis_cluster.aioredis.stream import StreamReader from aioredis_cluster.aioredis.util import _NOTSET from aioredis_cluster.command_info.commands import ( PING_COMMANDS, + PUBSUB_COMMAND_TO_TYPE, PUBSUB_FAMILY_COMMANDS, - SHARDED_PUBSUB_COMMANDS, + PUBSUB_RESP_KIND_TO_TYPE, + PUBSUB_SUBSCRIBE_COMMANDS, + PubSubType, ) -from aioredis_cluster.errors import RedisError +from aioredis_cluster.crc import CrossSlotError, determine_slot +from aioredis_cluster.errors import MovedError, RedisError +from aioredis_cluster.pubsub import PubSubStore from aioredis_cluster.typedef import PClosableConnection -from aioredis_cluster.util import encode_command +from aioredis_cluster.util import encode_command, ensure_bytes logger = logging.getLogger(__name__) +TExecuteCallback = Callable[[Any], Any] +TExecuteErrCallback = Callable[[Exception], Exception] + + +class ExecuteWaiter(NamedTuple): + fut: asyncio.Future + enc: Optional[str] + cb: Optional[TExecuteCallback] + err_cb: Optional[TExecuteErrCallback] + + +class PParserFactory(Protocol): + def __call__( + self, + protocolError: Callable = ProtocolError, + replyError: Callable = ReplyError, + encoding: Optional[str] = None, + ) -> Reader: + ... + + async def close_connections(conns: Iterable[PClosableConnection]) -> None: close_waiters = set() for conn in conns: @@ -30,17 +80,52 @@ async def close_connections(conns: Iterable[PClosableConnection]) -> None: await asyncio.wait(close_waiters) -class RedisConnection(BaseConnection, AbcConnection): - _in_pubsub: int - - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - +class RedisConnection(AbcConnection): + def __init__( + self, + reader: StreamReader, + writer: asyncio.StreamWriter, + *, + address: Union[Tuple[str, int], str], + encoding: Optional[str] = None, + parser: Optional[PParserFactory] = None, + loop: Optional[asyncio.AbstractEventLoop] = None, + ): + if loop is not None: + warnings.warn("The loop argument is deprecated", DeprecationWarning) + if parser is None: + parser = Reader + assert callable(parser), ("Parser argument is not callable", parser) + self._reader = reader + self._writer = writer + self._address = address + self._waiters: Deque[ExecuteWaiter] = deque() + self._reader.set_parser(parser(protocolError=ProtocolError, replyError=ReplyError)) + self._close_msg = "" + self._db = 0 + self._closing = False + self._closed = False + self._close_state = asyncio.Event() + self._in_transaction: Optional[Deque[Tuple[Optional[str], Optional[Callable]]]] = None + self._transaction_error: Optional[Exception] = None # XXX: never used? + self._pubsub_store = PubSubStore() + # client side PubSub mode flag + self._client_in_pubsub = False + # confirmed PubSub from Redis server via first subscribe reply + self._server_in_pubsub = False + + self._encoding = encoding + self._pipeline_buffer: Optional[bytearray] = None self._readonly = False - self._sharded_pubsub_channels = coerced_keys_dict() self._loop = asyncio.get_running_loop() self._last_use_generation = 0 + self._reader_task: Optional[asyncio.Task] = self._loop.create_task(self._read_data()) + self._reader_task.add_done_callback(self._on_reader_task_done) + + def __repr__(self): + return f"<{type(self).__name__} address:{self.address} db:{self.db}>" + @property def readonly(self) -> bool: return self._readonly @@ -57,14 +142,30 @@ async def set_readonly(self, value: bool) -> None: async def auth_with_username(self, username: str, password: str) -> bool: """Authenticate to server with username and password.""" - fut = self.execute("AUTH", username, password) + fut = self.execute(b"AUTH", username, password) return await wait_ok(fut) + @property + def pubsub_channels(self) -> Mapping[str, AbcChannel]: + """Returns read-only channels dict.""" + return self._pubsub_store.channels + + @property + def pubsub_patterns(self) -> Mapping[str, AbcChannel]: + """Returns read-only patterns dict.""" + return self._pubsub_store.patterns + @property def sharded_pubsub_channels(self) -> Mapping[str, AbcChannel]: - return MappingProxyType(self._sharded_pubsub_channels) + """Returns read-only sharded channels dict.""" + return self._pubsub_store.sharded + + async def auth(self, password: str) -> bool: + """Authenticate to server.""" + fut = self.execute(b"AUTH", password) + return await wait_ok(fut) - def execute(self, command, *args, encoding=_NOTSET): + def execute(self, command, *args, encoding=_NOTSET) -> Any: """Executes redis command and returns Future waiting for the answer. Raises: @@ -87,10 +188,14 @@ def execute(self, command, *args, encoding=_NOTSET): if command in PUBSUB_FAMILY_COMMANDS: raise ValueError(f"PUB/SUB command {command!r} is prohibited for use with .execute()") + if encoding is _NOTSET: + encoding = self._encoding + is_ping = command in PING_COMMANDS - if self._in_pubsub and not is_ping: - raise RedisError("Connection in SUBSCRIBE mode") + if not is_ping and (self._client_in_pubsub or self._server_in_pubsub): + raise RedisError("Connection in PubSub mode") + cb: Optional[TExecuteCallback] = None if command in ("SELECT", b"SELECT"): cb = partial(self._set_db, args=args) elif command in ("MULTI", b"MULTI"): @@ -100,58 +205,134 @@ def execute(self, command, *args, encoding=_NOTSET): encoding = None elif command in ("DISCARD", b"DISCARD"): cb = partial(self._end_transaction, discard=True) - else: - cb = None - if encoding is _NOTSET: - encoding = self._encoding - fut = self._loop.create_future() if self._pipeline_buffer is None: self._writer.write(encode_command(command, *args)) else: encode_command(command, *args, buf=self._pipeline_buffer) - self._waiters.append((fut, encoding, cb)) + + fut = self._loop.create_future() + self._waiters.append( + ExecuteWaiter( + fut=fut, + enc=encoding, + cb=cb, + err_cb=None, + ) + ) + return fut - def execute_pubsub(self, command, *channels): - """Executes redis (p)subscribe/(p)unsubscribe commands. + def execute_pubsub(self, command, *channels: Union[bytes, str, AbcChannel]): + """Executes redis (p|s)subscribe/(p|s)unsubscribe commands. Returns asyncio.gather coroutine waiting for all channels/patterns to receive answers. """ command = command.upper().strip() - assert command in PUBSUB_FAMILY_COMMANDS, ("Pub/Sub command expected", command) + if command not in PUBSUB_FAMILY_COMMANDS: + raise ValueError(f"Pub/Sub command expected, not {command!r}") if self._reader is None or self._reader.at_eof(): raise ConnectionClosedError("Connection closed or corrupted") if None in set(channels): raise TypeError("args must not contain None") - if not len(channels): - raise TypeError("No channels/patterns supplied") - if command in SHARDED_PUBSUB_COMMANDS: - is_pattern = False + channel_type = PUBSUB_COMMAND_TO_TYPE[command] + is_subscribe_command = command in PUBSUB_SUBSCRIBE_COMMANDS + is_pattern = channel_type is PubSubType.PATTERN + key_slot = -1 + reply_kind = ensure_bytes(command.lower()) + + channels_obj: List[AbcChannel] + if len(channels) == 0: + if is_subscribe_command: + raise ValueError("No channels to (un)subscribe") + elif channel_type is PubSubType.PATTERN: + channels_obj = list(self._pubsub_store.patterns.values()) + elif channel_type is PubSubType.SHARDED: + channels_obj = list(self._pubsub_store.sharded.values()) + else: + channels_obj = list(self._pubsub_store.channels.values()) else: - is_pattern = len(command) in (10, 12) - - mkchannel = partial(Channel, is_pattern=is_pattern) - channels_obj: List[AbcChannel] = [ - ch if isinstance(ch, AbcChannel) else mkchannel(ch) for ch in channels - ] - if not all(ch.is_pattern == is_pattern for ch in channels_obj): - raise ValueError("Not all channels {} match command {}".format(channels, command)) + mkchannel = partial(Channel, is_pattern=is_pattern) + channels_obj = [] + for channel_name_or_obj in channels: + if isinstance(channel_name_or_obj, AbcChannel): + ch = channel_name_or_obj + else: + ch = mkchannel(channel_name_or_obj) + # FIXME: processing duplicate channels totally broken in aioredis + # if ch.name in channels_obj: + # raise ValueError(f"Found channel duplicates in {channels!r}") + if ch.is_pattern != is_pattern: + raise ValueError(f"Not all channels {channels!r} match command {command!r}") + channels_obj.append(ch) + + if channel_type is PubSubType.SHARDED: + try: + key_slot = determine_slot(*(ensure_bytes(ch.name) for ch in channels_obj)) + except CrossSlotError: + raise ValueError( + f"Not all channels shared one key slot in cluster {channels!r}" + ) from None cmd = encode_command(command, *(ch.name for ch in channels_obj)) - res = [] - for ch in channels_obj: - fut = self._loop.create_future() - res.append(fut) - cb = partial(self._update_pubsub, ch=ch) - self._waiters.append((fut, None, cb)) + res: List[Any] = [] + + if is_subscribe_command: + for ch in channels_obj: + channel_name = ensure_bytes(ch.name) + self._pubsub_store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=ch, + key_slot=key_slot, + ) + if channel_type is PubSubType.SHARDED: + channels_num = self._pubsub_store.sharded_channels_num + else: + channels_num = self._pubsub_store.channels_num + res.append([reply_kind, channel_name, channels_num]) + + # otherwise unsubscribe command + else: + for ch in channels_obj: + channel_name = ensure_bytes(ch.name) + self._pubsub_store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=False, + ) + if channel_type is PubSubType.SHARDED: + channels_num = self._pubsub_store.sharded_channels_num + else: + channels_num = self._pubsub_store.channels_num + res.append([reply_kind, channel_name, channels_num]) + if self._pipeline_buffer is None: self._writer.write(cmd) else: self._pipeline_buffer.extend(cmd) - return asyncio.gather(*res) + + if not self._client_in_pubsub and not self._server_in_pubsub: + if is_subscribe_command: + # entering to PubSub mode on client side + self._client_in_pubsub = True + + fut = self._loop.create_future() + self._waiters.append( + ExecuteWaiter( + fut=fut, + enc=None, + cb=self._get_execute_pubsub_callback(command, res), + err_cb=self._get_execute_pubsub_err_callback(), + ) + ) + else: + fut = self._loop.create_future() + fut.set_result(res) + + return fut def get_last_use_generation(self) -> int: return self._last_use_generation @@ -159,77 +340,425 @@ def get_last_use_generation(self) -> int: def set_last_use_generation(self, gen: int): self._last_use_generation = gen - def _process_pubsub(self, obj, *, process_waiters: bool = True): - """Processes pubsub messages.""" - - if isinstance(obj, RedisError): - # case for new pubsub command for example: - # new ssubscribe to old node of slot - return self._process_data(obj) - - kind, *args, data = obj - if kind in (b"subscribe", b"unsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"unsubscribe": - ch = self._pubsub_channels.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind in (b"ssubscribe", b"sunsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"sunsubscribe": - ch = self._sharded_pubsub_channels.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind in (b"psubscribe", b"punsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"punsubscribe": - ch = self._pubsub_patterns.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind == b"message": - (chan,) = args - self._pubsub_channels[chan].put_nowait(data) - elif kind == b"smessage": - (chan,) = args - self._sharded_pubsub_channels[chan].put_nowait(data) - elif kind == b"pmessage": - pattern, chan = args - self._pubsub_patterns[pattern].put_nowait((chan, data)) + def close(self) -> None: + """Close connection.""" + self._do_close(ConnectionForcedCloseError()) + + @property + def closed(self) -> bool: + """True if connection is closed.""" + closed = self._closing or self._closed + if not closed and self._reader and self._reader.at_eof(): + self._closing = closed = True + self._loop.call_soon(self._do_close, None) + return closed + + async def wait_closed(self) -> None: + """Coroutine waiting until connection is closed.""" + await self._close_state.wait() + + @property + def db(self) -> int: + """Currently selected db index.""" + return self._db + + @property + def encoding(self) -> Optional[str]: + """Current set codec or None.""" + return self._encoding + + @property + def address(self) -> Union[Tuple[str, int], str]: + """Redis server address, either host-port tuple or str.""" + return self._address + + @property + def in_transaction(self) -> bool: + """Set to True when MULTI command was issued.""" + return self._in_transaction is not None + + @property + def in_pubsub(self) -> int: + """Indicates that connection is in PUB/SUB mode. + + This implementation NOT provides the number of subscribed channels + and provides only boolean flag + """ + return int(self._client_in_pubsub) + + async def select(self, db: int) -> bool: + """Change the selected database for the current connection.""" + if not isinstance(db, int): + raise TypeError("DB must be of int type, not {!r}".format(db)) + if db < 0: + raise ValueError("DB must be greater or equal 0, got {!r}".format(db)) + fut = self.execute(b"SELECT", db) + return await wait_ok(fut) + + def _set_db(self, ok, args): + assert ok in {b"OK", "OK"}, ("Unexpected result of SELECT", ok) + self._db = args[0] + return ok + + def _start_transaction(self, ok): + if self._in_transaction is not None: + raise RuntimeError("Connection is already in transaction") + self._in_transaction = deque() + self._transaction_error = None + return ok + + def _end_transaction(self, obj: Any, discard: bool) -> Any: + if self._in_transaction is None: + raise RuntimeError("Connection is not in transaction") + self._transaction_error = None + recall, self._in_transaction = self._in_transaction, None + recall.popleft() # ignore first (its _start_transaction) + if discard: + return obj + + if not (isinstance(obj, list) or (obj is None and not discard)): + raise RuntimeError(f"Unexpected MULTI/EXEC result: {obj!r}, {recall!r}") + + # TODO: need to be able to re-try transaction + if obj is None: + err = WatchVariableError("WATCH variable has changed") + obj = [err] * len(recall) + + if len(obj) != len(recall): + raise RuntimeError(f"Wrong number of result items in mutli-exec: {obj!r}, {recall!r}") + + res = [] + for o, (encoding, cb) in zip(obj, recall): + if not isinstance(o, RedisError): + try: + if encoding: + o = decode(o, encoding) + if cb: + o = cb(o) + except Exception as err: + res.append(err) + continue + res.append(o) + return res + + def _do_close(self, exc: Optional[BaseException]) -> None: + if self._closed: + return + self._closed = True + self._closing = False + self._writer.transport.close() + if self._reader_task is not None: + self._reader_task.cancel() + self._reader_task = None + del self._writer + del self._reader + self._pipeline_buffer = None + + if exc is not None: + self._close_msg = str(exc) + + while self._waiters: + waiter = self._waiters.popleft() + logger.debug("Cancelling waiter %r", waiter) + if exc is None: + _set_exception(waiter.fut, ConnectionForcedCloseError()) + else: + _set_exception(waiter.fut, exc) + + self._pubsub_store.close(exc) + + def _on_reader_task_done(self, task: asyncio.Task) -> None: + if not task.cancelled() and task.exception(): + logger.error( + "Reader task unexpectedly done with expection: %r", + task.exception(), + exc_info=task.exception(), + ) + # prevent RedisConnection stuck in half-closed state + self._reader_task = None + self._do_close(ConnectionForcedCloseError()) + self._close_state.set() + + def _is_pubsub_resp(self, obj: Any) -> bool: + if not isinstance(obj, (tuple, list)): + return False + if len(obj) == 0: + return False + return obj[0] in PUBSUB_RESP_KIND_TO_TYPE + + async def _read_data(self) -> None: + """Response reader task.""" + last_error = ConnectionClosedError("Connection has been closed by server") + while not self._reader.at_eof(): + try: + obj = await self._reader.readobj() + except asyncio.CancelledError: + # NOTE: reader can get cancelled from `close()` method only. + last_error = RuntimeError("this is unexpected") + break + except ProtocolError as exc: + # ProtocolError is fatal + # so connection must be closed + if self._in_transaction is not None: + self._transaction_error = exc + last_error = exc + break + except Exception as exc: + # NOTE: for QUIT command connection error can be received + # before response + last_error = exc + break + else: + if (obj == b"" or obj is None) and self._reader.at_eof(): + logger.debug("Connection has been closed by server, response: %r", obj) + last_error = ConnectionClosedError("Reader at end of file") + break + + if isinstance(obj, MaxClientsError): + last_error = obj + break + + if self._loop.get_debug(): + logger.debug( + "Received reply (client_in_pubsub:%s, server_in_pubsub:%s): %r", + self._client_in_pubsub, + self._server_in_pubsub, + obj, + ) + + if self._server_in_pubsub: + if isinstance(obj, MovedError): + if self._pubsub_store.have_slot_channels(obj.info.slot_id): + logger.warning( + ( + "Received MOVED in PubSub mode from %s to %s:%s. " + "Unsubscribe all channels from %d slot", + ), + self.address, + obj.info.host, + obj.info.port, + obj.info.slot_id, + ) + self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id) + elif isinstance(obj, RedisError): + raise obj + else: + self._process_pubsub(obj) + else: + if isinstance(obj, RedisError): + if isinstance(obj, MovedError): + if self._pubsub_store.have_slot_channels(obj.info.slot_id): + logger.warning( + ( + "Received MOVED from %s to %s:%s. " + "Unsubscribe all channels from %d slot", + ), + self.address, + obj.info.host, + obj.info.port, + obj.info.slot_id, + ) + self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id) + elif isinstance(obj, ReplyError): + if obj.args[0].startswith("READONLY"): + obj = ReadOnlyError(obj.args[0]) + self._wakeup_waiter_with_exc(obj) + else: + self._wakeup_waiter_with_result(obj) + + self._closing = True + self._loop.call_soon(self._do_close, last_error) + + def _wakeup_waiter_with_exc(self, exc: Exception) -> None: + """Processes command errors.""" + + if not self._waiters: + logger.error("No waiter for process error: %r", exc) + return + + waiter = self._waiters.popleft() + + if waiter.err_cb is not None: + try: + exc = waiter.err_cb(exc) + except Exception as cb_exc: + logger.exception("Waiter error callback failed with exception: %r", cb_exc) + exc = cb_exc + + _set_exception(waiter.fut, exc) + if self._in_transaction is not None: + self._transaction_error = exc + + def _wakeup_waiter_with_result(self, result: Any) -> None: + """Processes command results.""" + + if self._loop.get_debug(): # pragma: no cover + logger.debug("Wakeup first waiter for reply: %r", result) + + if not self._waiters: + logger.error("No waiter for received reply: %r, %r", type(result), result) + return + + waiter = self._waiters.popleft() + self._resolve_waiter_with_result(waiter, result) + + def _resolve_waiter_with_result(self, waiter: ExecuteWaiter, result: Any) -> None: + if waiter.enc is not None: + try: + decoded_result = decode(result, waiter.enc) + except Exception as exc: + _set_exception(waiter.fut, exc) + return + else: + decoded_result = result + + del result + + if waiter.cb is not None: + try: + converted_result = waiter.cb(decoded_result) + except Exception as exc: + _set_exception(waiter.fut, exc) + return + else: + converted_result = decoded_result + + del decoded_result + + _set_result(waiter.fut, converted_result) + if self._in_transaction is not None: + self._in_transaction.append((waiter.enc, waiter.cb)) + + def _process_pubsub(self, obj: Any) -> Any: + """Processes pubsub messages. + + This method calls directly on `_read_data` routine + and used as callback in `execute_pubsub` for first PubSub mode initial reply + """ + + if self._loop.get_debug(): # pragma: no cover + logger.debug( + "Process PubSub reply (client_in_pubsub:%s, server_in_pubsub:%s): %r", + self._client_in_pubsub, + self._server_in_pubsub, + obj, + ) + + if isinstance(obj, bytes): + # process simple bytes as PING reply + kind = b"pong" + data = obj + else: + kind, *args, data = obj + + channel_name: bytes + + if kind in {b"subscribe", b"psubscribe", b"ssubscribe"}: + logger.debug("PubSub %s event received: %r", kind, obj) + (channel_name,) = args + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] + if self._client_in_pubsub and not self._server_in_pubsub: + self._server_in_pubsub = True + # confirm PubSub mode in client side based on server reply and reset pending flag + self._pubsub_store.confirm_subscribe(channel_type, channel_name) + elif kind in {b"unsubscribe", b"punsubscribe", b"sunsubscribe"}: + logger.debug("PubSub %s event received: %r", kind, obj) + (channel_name,) = args + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] + self._pubsub_store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + elif kind in {b"message", b"smessage", b"pmessage"}: + if kind == b"pmessage": + (pattern, channel_name) = args + else: + (channel_name,) = args + pattern = channel_name + + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] + if self._pubsub_store.has_channel(channel_type, pattern): + channel = self._pubsub_store.get_channel(channel_type, pattern) + if channel_type is PubSubType.PATTERN: + channel.put_nowait((channel_name, data)) + else: + channel.put_nowait(data) + else: + logger.warning( + "No channel %r with type %s for received message", pattern, channel_type + ) elif kind == b"pong": - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(data or b"PONG") + if not self._waiters: + logger.error("No PubSub PONG waiters for received data %r", data) + else: + # in PubSub mode only PING waiters in this deque + # see in execute() method `is_ping` condition + waiter = self._waiters.popleft() + self._resolve_waiter_with_result(waiter, data or b"PONG") else: logger.warning("Unknown pubsub message received %r", obj) - def _update_pubsub(self, obj, *, ch: AbcChannel): - kind, *pattern, channel, subscriptions = obj - self._in_pubsub, was_in_pubsub = subscriptions, self._in_pubsub - if kind == b"subscribe" and channel not in self._pubsub_channels: - self._pubsub_channels[channel] = ch - elif kind == b"psubscribe" and channel not in self._pubsub_patterns: - self._pubsub_patterns[channel] = ch - elif kind == b"ssubscribe" and channel not in self._sharded_pubsub_channels: - self._sharded_pubsub_channels[channel] = ch - if not was_in_pubsub: - self._process_pubsub(obj, process_waiters=False) return obj - def _do_close(self, exc): - super()._do_close(exc) + @contextmanager + def _buffered(self): + # XXX: we must ensure that no await happens + # as long as we buffer commands. + # Probably we can set some error-raising callback on enter + # and remove it on exit + # if some await happens in between -> throw an error. + # This is creepy solution, 'cause some one might want to await + # on some other source except redis. + # So we must only raise error we someone tries to await + # pending aioredis future + # One of solutions is to return coroutine instead of a future + # in `execute` method. + # In a coroutine we can check if buffering is enabled and raise error. + + # TODO: describe in docs difference in pipeline mode for + # conn.execute vs pipeline.execute() + if self._pipeline_buffer is None: + self._pipeline_buffer = bytearray() + try: + yield self + buf = self._pipeline_buffer + self._writer.write(buf) + finally: + self._pipeline_buffer = None + else: + yield self + + def _get_execute_pubsub_callback( + self, command: Union[str, bytes], expect_replies: List[Any] + ) -> TExecuteCallback: + def callback(server_reply: Any) -> Any: + # this callback processing only first reply on (p|s)(un)subscribe commands - if self._closed: - return + server_reply = self._process_pubsub(server_reply) + + if list(server_reply) != expect_replies[0]: + if logger.isEnabledFor(logging.DEBUG): + logger.error( + "Unexpected server reply on PubSub on %r: %r, expected %r", + command, + server_reply, + expect_replies[0], + ) + + exc = RedisError(f"Unexpected server reply on PubSub {command!r}") + self._loop.call_soon(self._do_close, exc) + raise exc + + return expect_replies + + return callback + + def _get_execute_pubsub_err_callback(self) -> TExecuteErrCallback: + def callback(exc: Exception) -> Exception: + if isinstance(exc, ReplyError): + # return PubSub mode to closed state if any reply error received + self._client_in_pubsub = False + return exc - while self._sharded_pubsub_channels: - _, ch = self._sharded_pubsub_channels.popitem() - logger.debug("Closing sharded pubsub channel %r", ch) - ch.close(exc) + return callback diff --git a/src/aioredis_cluster/crc.py b/src/aioredis_cluster/crc.py index 7123576..4bc6e77 100644 --- a/src/aioredis_cluster/crc.py +++ b/src/aioredis_cluster/crc.py @@ -10,6 +10,8 @@ __all__ = ( "crc16", "key_slot", + "determine_slot", + "CrossSlotError", ) REDIS_CLUSTER_HASH_SLOTS = 16384 @@ -43,3 +45,16 @@ def py_key_slot(k: bytes, bucket: int = REDIS_CLUSTER_HASH_SLOTS) -> int: key_slot = cy_key_slot else: key_slot = py_key_slot + + +class CrossSlotError(Exception): + pass + + +def determine_slot(first_key: bytes, *keys: bytes) -> int: + slot: int = key_slot(first_key) + for k in keys: + if slot != key_slot(k): + raise CrossSlotError("all keys must map to the same key slot") + + return slot diff --git a/src/aioredis_cluster/pool.py b/src/aioredis_cluster/pool.py index 2d962ed..6bde084 100644 --- a/src/aioredis_cluster/pool.py +++ b/src/aioredis_cluster/pool.py @@ -3,17 +3,18 @@ import logging import random import types -from typing import Deque, Dict, List, Mapping, Optional, Set, Tuple, Type, Union +from typing import Deque, List, Mapping, Optional, Set, Tuple, Type, Union from aioredis_cluster._aioredis.pool import ( _AsyncConnectionContextManager, _ConnectionContextManager, ) from aioredis_cluster._aioredis.util import CloseEvent -from aioredis_cluster.abc import AbcConnection, AbcPool -from aioredis_cluster.aioredis import Channel, PoolClosedError, create_connection +from aioredis_cluster.abc import AbcChannel, AbcConnection, AbcPool +from aioredis_cluster.aioredis import PoolClosedError, create_connection from aioredis_cluster.command_info.commands import ( BLOCKING_COMMANDS, + PATTERN_PUBSUB_COMMANDS, PUBSUB_COMMANDS, PUBSUB_FAMILY_COMMANDS, SHARDED_PUBSUB_COMMANDS, @@ -226,7 +227,7 @@ async def execute_pubsub(self, command: TBytesOrStr, *channels): self._check_closed() - if command in PUBSUB_COMMANDS: + if command in PUBSUB_COMMANDS or command in PATTERN_PUBSUB_COMMANDS: conn = await self._get_pubsub_connection() elif command in SHARDED_PUBSUB_COMMANDS: conn = await self._get_sharded_pubsub_connection() @@ -246,7 +247,7 @@ def get_connection(self, command: TBytesOrStr, args=()): command = command.upper().strip() ret_conn: Optional[AbcConnection] = None - if command in PUBSUB_COMMANDS: + if command in PUBSUB_COMMANDS or command in PATTERN_PUBSUB_COMMANDS: if self._pubsub_conn and not self._pubsub_conn.closed: ret_conn = self._pubsub_conn elif command in SHARDED_PUBSUB_COMMANDS: @@ -267,18 +268,23 @@ def get_connection(self, command: TBytesOrStr, args=()): return ret_conn, self._address - async def select(self, db): - """For cluster implementation this method is unavailable""" + async def select(self, db: int) -> bool: + """Changes db index for all free connections. + + All previously acquired connections will be closed when released. + + For cluster implementation this method is unavailable + """ raise NotImplementedError("Feature is blocked in cluster mode") - async def auth(self, password) -> None: + async def auth(self, password: str) -> None: self._password = password async with self._cond: for conn in tuple(self._pool): conn.set_last_use_generation(self._idle_connections_collect_gen) await conn.auth(password) - async def auth_with_username(self, username, password) -> None: + async def auth_with_username(self, username: str, password: str) -> None: self._username = username self._password = password async with self._cond: @@ -299,29 +305,30 @@ async def set_readonly(self, value: bool) -> None: @property def in_pubsub(self) -> int: - in_pubsub = 0 - if self._pubsub_conn and not self._pubsub_conn.closed: - in_pubsub += self._pubsub_conn.in_pubsub - if self._sharded_pubsub_conn and not self._sharded_pubsub_conn.closed: - in_pubsub += self._sharded_pubsub_conn.in_pubsub - return in_pubsub + if self._pubsub_conn and not self._pubsub_conn.closed and self._pubsub_conn.in_pubsub: + return 1 + if ( + self._sharded_pubsub_conn + and not self._sharded_pubsub_conn.closed + and self._sharded_pubsub_conn.in_pubsub + ): + return 1 + return 0 @property - def pubsub_channels(self) -> Mapping[str, Channel]: - channels: Dict[str, Channel] = {} + def pubsub_channels(self) -> Mapping[str, AbcChannel]: if self._pubsub_conn and not self._pubsub_conn.closed: - channels.update(self._pubsub_conn.pubsub_channels) - return types.MappingProxyType(channels) + return self._pubsub_conn.pubsub_channels + return types.MappingProxyType({}) @property - def sharded_pubsub_channels(self) -> Mapping[str, Channel]: - channels: Dict[str, Channel] = {} + def sharded_pubsub_channels(self) -> Mapping[str, AbcChannel]: if self._sharded_pubsub_conn and not self._sharded_pubsub_conn.closed: - channels.update(self._sharded_pubsub_conn.sharded_pubsub_channels) - return types.MappingProxyType(channels) + return self._sharded_pubsub_conn.sharded_pubsub_channels + return types.MappingProxyType({}) @property - def pubsub_patterns(self): + def pubsub_patterns(self) -> Mapping[str, AbcChannel]: if self._pubsub_conn and not self._pubsub_conn.closed: return self._pubsub_conn.pubsub_patterns return types.MappingProxyType({}) @@ -375,7 +382,7 @@ def release(self, conn: AbcConnection) -> None: logger.warning("Connection %r is in transaction, closing it.", conn) conn.close() elif conn.in_pubsub: - logger.warning("Connection %r is in subscribe mode, closing it.", conn) + logger.warning("Connection %r is in PubSub mode, closing it.", conn) conn.close() elif conn._waiters: logger.warning("Connection %r has pending commands, closing it.", conn) diff --git a/src/aioredis_cluster/pubsub.py b/src/aioredis_cluster/pubsub.py new file mode 100644 index 0000000..fe69cfa --- /dev/null +++ b/src/aioredis_cluster/pubsub.py @@ -0,0 +1,191 @@ +import logging +from types import MappingProxyType +from typing import Dict, Mapping, Optional, Set, Tuple + +from aioredis_cluster._aioredis.util import coerced_keys_dict +from aioredis_cluster.abc import AbcChannel +from aioredis_cluster.command_info.commands import PubSubType +from aioredis_cluster.crc import key_slot as calc_key_slot + +logger = logging.getLogger(__name__) + + +class PubSubStore: + def __init__(self) -> None: + self._channels: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._patterns: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict() + self._slot_to_sharded: Dict[int, Set[bytes]] = {} + self._unconfirmed_subscribes: Dict[Tuple[PubSubType, bytes], int] = {} + + @property + def channels(self) -> Mapping[str, AbcChannel]: + """Returns read-only channels dict.""" + return MappingProxyType(self._channels) + + @property + def patterns(self) -> Mapping[str, AbcChannel]: + """Returns read-only patterns dict.""" + return MappingProxyType(self._patterns) + + @property + def sharded(self) -> Mapping[str, AbcChannel]: + """Returns read-only sharded channels dict.""" + return MappingProxyType(self._sharded) + + def channel_subscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + channel: AbcChannel, + key_slot: int, + ) -> None: + if channel_type is PubSubType.CHANNEL: + if channel_name not in self._channels: + self._channels[channel_name] = channel + elif channel_type is PubSubType.PATTERN: + if channel_name not in self._patterns: + self._patterns[channel_name] = channel + elif channel_type is PubSubType.SHARDED: + if key_slot < 0: + raise ValueError("key_slot cannot be negative for sharded channel") + + if channel_name not in self._sharded: + self._sharded[channel_name] = channel + if key_slot not in self._slot_to_sharded: + self._slot_to_sharded[key_slot] = {channel_name} + else: + self._slot_to_sharded[key_slot].add(channel_name) + else: # pragma: no cover + raise RuntimeError(f"Unexpected channel_type {channel_type}") + + unconfirmed_key = (channel_type, channel_name) + if unconfirmed_key not in self._unconfirmed_subscribes: + self._unconfirmed_subscribes[unconfirmed_key] = 1 + else: + self._unconfirmed_subscribes[unconfirmed_key] += 1 + + def confirm_subscribe(self, channel_type: PubSubType, channel_name: bytes) -> None: + unconfirmed_key = (channel_type, channel_name) + val = self._unconfirmed_subscribes.get(unconfirmed_key) + if val is not None: + val -= 1 + if val < 0: + logger.error("Too much PubSub subscribe confirms for %r", unconfirmed_key) + val = 0 + + if val == 0: + del self._unconfirmed_subscribes[unconfirmed_key] + if channel_type is PubSubType.SHARDED: + # this is counterpart of channel_unsubscribe() + # we must remove key_slot -> channel_name entry + # if not channels objects exists + if channel_name not in self._sharded: + self._remove_channel_from_slot_map(channel_name) + else: + self._unconfirmed_subscribes[unconfirmed_key] = val + else: + logger.error("Unexpected PubSub subscribe confirm for %r", unconfirmed_key) + + def channel_unsubscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + by_reply: bool, + ) -> None: + have_unconfirmed_subs = (channel_type, channel_name) in self._unconfirmed_subscribes + # if receive (p|s)unsubscribe reply from Redis + # and make several sequently subscribe->unsubscribe->subscribe commands + # - we must ignore all server unsubscribe replies until all subscribes is confirmed + if by_reply and have_unconfirmed_subs: + return + + channel: Optional[AbcChannel] = None + if channel_type is PubSubType.CHANNEL: + channel = self._channels.pop(channel_name, None) + elif channel_type is PubSubType.PATTERN: + channel = self._patterns.pop(channel_name, None) + elif channel_type is PubSubType.SHARDED: + channel = self._sharded.pop(channel_name, None) + # we must remove key_slot -> channel_name entry + # only if all subscription is confirmed + if not have_unconfirmed_subs: + self._remove_channel_from_slot_map(channel_name) + + if channel is not None: + channel.close() + + def slot_channels_unsubscribe(self, key_slot: int) -> None: + channel_names = self._slot_to_sharded.pop(key_slot, None) + if channel_names is None: + return + + while channel_names: + channel_name = channel_names.pop() + self._unconfirmed_subscribes.pop((PubSubType.SHARDED, channel_name), None) + channel = self._sharded.pop(channel_name, None) + if channel is not None: + channel.close() + + def have_slot_channels(self, key_slot: int) -> bool: + return key_slot in self._slot_to_sharded + + @property + def channels_total(self) -> int: + return len(self._channels) + len(self._patterns) + len(self._sharded) + + @property + def channels_num(self) -> int: + return len(self._channels) + len(self._patterns) + + @property + def sharded_channels_num(self) -> int: + return len(self._sharded) + + def has_channel(self, channel_type: PubSubType, channel_name: bytes) -> bool: + ret = False + if channel_type is PubSubType.CHANNEL: + ret = channel_name in self._channels + elif channel_type is PubSubType.PATTERN: + ret = channel_name in self._patterns + elif channel_type is PubSubType.SHARDED: + ret = channel_name in self._sharded + return ret + + def get_channel(self, channel_type: PubSubType, channel_name: bytes) -> AbcChannel: + if channel_type is PubSubType.CHANNEL: + channel = self._channels[channel_name] + elif channel_type is PubSubType.PATTERN: + channel = self._patterns[channel_name] + elif channel_type is PubSubType.SHARDED: + channel = self._sharded[channel_name] + else: # pragma: no cover + raise RuntimeError(f"Unexpected channel type {channel_type!r}") + return channel + + def close(self, exc: Optional[BaseException]) -> None: + while self._channels: + _, ch = self._channels.popitem() + logger.debug("Closing pubsub channel %r", ch) + ch.close(exc) + while self._patterns: + _, ch = self._patterns.popitem() + logger.debug("Closing pubsub pattern %r", ch) + ch.close(exc) + while self._sharded: + _, ch = self._sharded.popitem() + logger.debug("Closing sharded pubsub channel %r", ch) + ch.close(exc) + + self._slot_to_sharded.clear() + self._unconfirmed_subscribes.clear() + + def _remove_channel_from_slot_map(self, channel_name: bytes) -> None: + key_slot = calc_key_slot(channel_name) + if key_slot in self._slot_to_sharded: + channels_set = self._slot_to_sharded[key_slot] + channels_set.discard(channel_name) + if len(channels_set) == 0: + del self._slot_to_sharded[key_slot] diff --git a/tests/aioredis_tests/conftest.py b/tests/aioredis_tests/conftest.py index b4a750e..2daef05 100644 --- a/tests/aioredis_tests/conftest.py +++ b/tests/aioredis_tests/conftest.py @@ -9,6 +9,8 @@ import tempfile import time from collections import namedtuple +from functools import partial +from typing import List from urllib.parse import urlencode, urlunparse import pytest @@ -16,7 +18,18 @@ from aioredis_cluster import _aioredis as aioredis from aioredis_cluster._aioredis import sentinel as aioredis_sentinel +from aioredis_cluster.aioredis import create_redis as custom_create_redis from aioredis_cluster.compat.asyncio import timeout as atimeout +from aioredis_cluster.connection import RedisConnection + +try: + import aioredis as _origin_aioredis + + (_origin_aioredis,) + aioredis_installed = True +except ImportError: + aioredis_installed = False + TCPAddress = namedtuple("TCPAddress", "host port") @@ -59,8 +72,25 @@ async def f(*args, **kw): return f +create_redis_fixture_params: List = [ + aioredis.create_redis, + aioredis.create_redis_pool, +] +create_redis_fixture_ids: List = [ + "single", + "pool", +] + +if aioredis_installed is False: + create_redis_fixture_params.append(partial(custom_create_redis, connection_cls=RedisConnection)) + create_redis_fixture_ids.append( + "single_from_cluster", + ) + + @pytest_asyncio.fixture( - params=[aioredis.create_redis, aioredis.create_redis_pool], ids=["single", "pool"] + params=create_redis_fixture_params, + ids=create_redis_fixture_ids, ) def create_redis(_closable, request): """Wrapper around aioredis.create_redis.""" diff --git a/tests/aioredis_tests/pubsub_commands_test.py b/tests/aioredis_tests/pubsub_commands_test.py index 4d96f05..fd54989 100644 --- a/tests/aioredis_tests/pubsub_commands_test.py +++ b/tests/aioredis_tests/pubsub_commands_test.py @@ -4,6 +4,7 @@ from _testutils import redis_version from aioredis_cluster import _aioredis as aioredis +from aioredis_cluster.connection import RedisConnection as ClusterConnection async def _reader(channel, output, waiter, conn): @@ -49,7 +50,12 @@ async def test_publish_json(create_connection, redis, server): async def test_subscribe(redis): res = await redis.subscribe("chan:1", "chan:2") - assert redis.in_pubsub == 2 + + assert len(redis.connection.pubsub_channels) == 2 + if isinstance(redis.connection, ClusterConnection): + assert redis.in_pubsub == 1 + else: + assert redis.in_pubsub == 2 ch1 = redis.channels["chan:1"] ch2 = redis.channels["chan:2"] @@ -62,6 +68,39 @@ async def test_subscribe(redis): assert res == [[b"unsubscribe", b"chan:1", 1], [b"unsubscribe", b"chan:2", 0]] +async def test_subscribe__multiple_times(redis): + res1 = await redis.subscribe("chan:1") + assert redis.in_pubsub == 1 + res2 = await redis.subscribe("chan:1") + assert redis.in_pubsub == 1 + res3 = await redis.psubscribe("chan:1") + + assert len(redis.connection.pubsub_channels) == 1 + assert len(redis.connection.pubsub_patterns) == 1 + if isinstance(redis.connection, ClusterConnection): + assert redis.in_pubsub == 1 + else: + assert redis.in_pubsub == 2 + # res4 = await redis.connection.execute_pubsub("SSUBSCRIBE", "chan:1") + # assert redis.in_pubsub == 3 + + ch1 = redis.channels["chan:1"] + ch3 = redis.patterns["chan:1"] + + assert res1 == [ch1] + assert res2 == [ch1] + assert res3 == [ch3] + + res = await redis.unsubscribe("chan:1") + assert res == [[b"unsubscribe", b"chan:1", 1]] + + res = await redis.punsubscribe("chan:1") + assert res == [[b"punsubscribe", b"chan:1", 0]] + + res = await redis.unsubscribe("chan:1") + assert res == [[b"unsubscribe", b"chan:1", 0]] + + @pytest.mark.parametrize( "create_redis", [ @@ -90,7 +129,10 @@ async def test_subscribe_empty_pool(create_redis, server, _closable): async def test_psubscribe(redis, create_redis, server): sub = redis res = await sub.psubscribe("patt:*", "chan:*") - assert sub.in_pubsub == 2 + if isinstance(redis.connection, ClusterConnection): + assert sub.in_pubsub == 1 + else: + assert sub.in_pubsub == 2 pat1 = sub.patterns["patt:*"] pat2 = sub.patterns["chan:*"] @@ -122,7 +164,12 @@ async def test_psubscribe_empty_pool(create_redis, server, _closable): _closable(pub) await sub.connection.clear() res = await sub.psubscribe("patt:*", "chan:*") - assert sub.in_pubsub == 2 + + assert len(sub.connection.pubsub_patterns) == 2 + if isinstance(sub.connection, ClusterConnection): + assert sub.in_pubsub == 1 + else: + assert sub.in_pubsub == 2 pat1 = sub.patterns["patt:*"] pat2 = sub.patterns["chan:*"] diff --git a/tests/system_tests/test_connection_pubsub.py b/tests/system_tests/test_connection_pubsub.py index 30e6e12..3ad5a5a 100644 --- a/tests/system_tests/test_connection_pubsub.py +++ b/tests/system_tests/test_connection_pubsub.py @@ -1,20 +1,87 @@ +import asyncio from string import ascii_letters +from typing import Awaitable, Callable import pytest -from aioredis_cluster.errors import MovedError +from aioredis_cluster import Cluster +from aioredis_cluster.compat.asyncio import timeout @pytest.mark.redis_version(gte="7.0.0") @pytest.mark.timeout(2) -async def test_moved_with_pubsub(cluster): - c = await cluster() - redis = await c.keys_master("a") +async def test_moved_with_pubsub(cluster: Callable[[], Awaitable[Cluster]]): + client = await cluster() + redis = await client.keys_master("a") await redis.ssubscribe("a") - with pytest.raises(MovedError): - for b in ascii_letters: - await redis.ssubscribe(b) + assert "a" in redis.sharded_pubsub_channels + assert b"a" in redis.sharded_pubsub_channels - redis.close() - await redis.wait_closed() + channels_dump = {} + for char in ascii_letters: + await redis.ssubscribe(char) + # Channel objects creates immediately and close and removed after received MOVED reply + channels_dump[char] = redis.sharded_pubsub_channels[char] + + await asyncio.sleep(0.01) + + # check number of created Channel objects + assert len(channels_dump) == len(ascii_letters) + + # check of closed Channels after MovedError reply received + num_of_closed = 0 + for channel in channels_dump.values(): + if not channel.is_active: + num_of_closed += 1 + + assert num_of_closed > 0 + assert len(redis.sharded_pubsub_channels) < len(channels_dump) + + client.close() + await client.wait_closed() + + +@pytest.mark.redis_version(gte="7.0.0") +@pytest.mark.timeout(2) +async def test_immediately_resubscribe(cluster: Callable[[], Awaitable[Cluster]]): + client = await cluster() + redis = await client.keys_master("chan") + for i in range(10): + await redis.ssubscribe("chan") + await redis.sunsubscribe("chan") + chan = (await redis.ssubscribe("chan"))[0] + + chan_get_task = asyncio.ensure_future(chan.get()) + + await client.execute("spublish", "chan", "msg1") + await client.execute("spublish", "chan", "msg2") + await client.execute("spublish", "chan", "msg3") + # wait 50ms until Redis message delivery + await asyncio.sleep(0.05) + + assert chan_get_task.done() is True + assert chan_get_task.result() == b"msg1" + + # message must be already exists in internal queue + async with timeout(0): + msg2 = await chan.get() + assert msg2 == b"msg2" + + async with timeout(0): + msg3 = await chan.get() + assert msg3 == b"msg3" + + # no more messages + with pytest.raises(asyncio.TimeoutError): + async with timeout(0): + await chan.get() + + assert chan.is_active is True + + await redis.sunsubscribe("chan") + + assert chan.is_active is False + + client.close() + await client.wait_closed() diff --git a/tests/system_tests/test_redis_cluster.py b/tests/system_tests/test_redis_cluster.py index 3a97fd4..1942b8c 100644 --- a/tests/system_tests/test_redis_cluster.py +++ b/tests/system_tests/test_redis_cluster.py @@ -156,7 +156,7 @@ async def test_sharded_pubsub(redis_cluster): ch2: Channel = channels[0] assert len(cl1.sharded_pubsub_channels) == 2 - assert cl1.in_pubsub == 2 + assert cl1.in_pubsub == 1 assert len(cl1.channels) == 0 assert len(cl1.patterns) == 0 @@ -173,7 +173,7 @@ async def test_sharded_pubsub(redis_cluster): await cl1.sunsubscribe("channel2") assert len(cl1.sharded_pubsub_channels) == 0 - assert cl1.in_pubsub == 0 + assert cl1.in_pubsub == 1 @pytest.mark.redis_version(gte="7.0.0") @@ -186,7 +186,7 @@ async def test_sharded_pubsub__multiple_subscribe(redis_cluster): ch3: Channel = (await cl1.ssubscribe("channel:{shard_key}:3"))[0] assert len(cl1.sharded_pubsub_channels) == 3 - assert cl1.in_pubsub == 3 + assert cl1.in_pubsub == 1 shard_pool = await cl1.keys_master("{shard_key}") assert len(shard_pool.sharded_pubsub_channels) == 3 diff --git a/tests/unit_tests/aioredis_cluster/conftest.py b/tests/unit_tests/aioredis_cluster/conftest.py index 2a4b639..b2c4887 100644 --- a/tests/unit_tests/aioredis_cluster/conftest.py +++ b/tests/unit_tests/aioredis_cluster/conftest.py @@ -1,9 +1,30 @@ -# import asyncio +import pytest +import pytest_asyncio -# import pytest +@pytest.fixture +def add_finalizer(): + finalizers = [] -# def pytest_collection_modifyitems(items): -# for item in items: -# if not item.get_closest_marker("asyncio") and asyncio.iscoroutinefunction(item.function): -# item.add_marker(pytest.mark.asyncio) + def adder(finalizer): + finalizers.append(finalizer) + + try: + yield adder + finally: + for finalizer in finalizers: + finalizer() + + +@pytest_asyncio.fixture +async def add_async_finalizer(): + finalizers = [] + + def adder(finalizer): + finalizers.append(finalizer) + + try: + yield adder + finally: + for finalizer in finalizers: + await finalizer() diff --git a/tests/unit_tests/aioredis_cluster/test_connection.py b/tests/unit_tests/aioredis_cluster/test_connection.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index f94f695..14e016e 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -1,21 +1,36 @@ import asyncio +import logging from asyncio.queues import Queue from unittest import mock import pytest +from aioredis_cluster.aioredis import ChannelClosedError +from aioredis_cluster.aioredis.stream import StreamReader +from aioredis_cluster.command_info.commands import PUBSUB_SUBSCRIBE_COMMANDS +from aioredis_cluster.compat.asyncio import timeout from aioredis_cluster.connection import RedisConnection -from aioredis_cluster.errors import MovedError +from aioredis_cluster.errors import MovedError, RedisError +pytestmark = [pytest.mark.timeout(1)] -class Reader: + +async def moment(times: int = 1) -> None: + for _ in range(times): + await asyncio.sleep(0) + + +class MockedReader(StreamReader): def __init__(self) -> None: - self.queue = Queue() + self.queue: Queue = Queue() self.eof = False def set_parser(self, *args): pass + def feed_data(self, data): + pass + async def readobj(self): result = await self.queue.get() self.queue.task_done() @@ -25,24 +40,574 @@ def at_eof(self) -> bool: return self.eof and self.queue.empty() -async def test_moved_with_pubsub(): - reader = Reader() +def get_mocked_reader(): + return MockedReader() + + +def get_mocked_writer(): writer = mock.AsyncMock() writer.write = mock.Mock() writer.transport = mock.NonCallableMock() + return writer + + +async def close_connection(conn: RedisConnection) -> None: + conn.close() + await conn.wait_closed() + + +async def execute(redis: RedisConnection, *args, **kwargs): + return await redis.execute(*args, **kwargs) + + +async def execute_pubsub(redis: RedisConnection, *args, **kwargs): + return await redis.execute_pubsub(*args, **kwargs) + + +async def test_execute__simple_subscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"subscribe", b"chan", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan", 1]) + reader.queue.put_nowait([b"psubscribe", b"chan", 2]) + + result_channel = await redis.execute_pubsub("SUBSCRIBE", "chan") + result_sharded = await redis.execute_pubsub("SSUBSCRIBE", "chan") + result_pattern = await redis.execute_pubsub("PSUBSCRIBE", "chan") + + assert result_channel == [[b"subscribe", b"chan", 1]] + assert result_pattern == [[b"psubscribe", b"chan", 2]] + assert result_sharded == [[b"ssubscribe", b"chan", 1]] + assert redis.in_pubsub == 1 + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is True + assert len(redis._waiters) == 0 + + assert "chan" in redis.pubsub_channels + assert "chan" in redis.pubsub_patterns + assert "chan" in redis.sharded_pubsub_channels + + assert redis.pubsub_channels["chan"] is not redis.pubsub_patterns["chan"] + assert redis.pubsub_channels["chan"] is not redis.sharded_pubsub_channels["chan"] + assert redis.pubsub_patterns["chan"] is not redis.sharded_pubsub_channels["chan"] + + +async def test_execute__simple_unsubscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan", 1]) + reader.queue.put_nowait([b"psubscribe", b"chan", 1]) + reader.queue.put_nowait([b"subscribe", b"chan", 2]) + + await redis.execute_pubsub("SSUBSCRIBE", "chan") + await redis.execute_pubsub("PSUBSCRIBE", "chan") + await redis.execute_pubsub("SUBSCRIBE", "chan") + + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 1 + assert len(redis.pubsub_patterns) == 1 + assert len(redis.sharded_pubsub_channels) == 1 + + reader.queue.put_nowait([b"unsubscribe", b"chan", 1]) + reader.queue.put_nowait([b"punsubscribe", b"chan", 0]) + reader.queue.put_nowait([b"sunsubscribe", b"chan", 0]) + result_channel = await redis.execute_pubsub("UNSUBSCRIBE", "chan") + result_pattern = await redis.execute_pubsub("PUNSUBSCRIBE", "chan") + result_sharded = await redis.execute_pubsub("SUNSUBSCRIBE", "chan") + + await moment() + + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 0 + assert len(redis.pubsub_patterns) == 0 + assert len(redis.sharded_pubsub_channels) == 0 + assert result_channel == [[b"unsubscribe", b"chan", 1]] + assert result_pattern == [[b"punsubscribe", b"chan", 0]] + assert result_sharded == [[b"sunsubscribe", b"chan", 0]] + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is True + + +@pytest.mark.parametrize( + "command", + [ + "SUBSCRIBE", + "PSUBSCRIBE", + "SSUBSCRIBE", + "UNSUBSCRIBE", + "PUNSUBSCRIBE", + "SUNSUBSCRIBE", + ], +) +async def test_execute__first_command(add_async_finalizer, command: str): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + is_subscribe_command = command in PUBSUB_SUBSCRIBE_COMMANDS + kind = command.encode().lower() + + subs_num = 1 if is_subscribe_command else 0 + reader.queue.put_nowait((kind, b"chan", subs_num)) + + await redis.execute_pubsub(command, "chan") + + if is_subscribe_command: + assert redis.in_pubsub == 1 + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is True + else: + assert redis.in_pubsub == 0 + assert redis._client_in_pubsub is False + assert redis._server_in_pubsub is False + + +async def test_execute__half_open_pubsub_mode(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + get_task = asyncio.ensure_future(execute(redis, "GET", "foo", encoding="utf-8")) + ping1_task = asyncio.ensure_future(execute(redis, "PING", "ping_reply1")) + subs_task = asyncio.ensure_future(execute_pubsub(redis, "SSUBSCRIBE", "chan")) + # SET not send and execute() must raise RedisError exception + set_task = asyncio.ensure_future(execute(redis, "SET", "foo", "val2")) + ping2_task = asyncio.ensure_future(execute(redis, "PING", "ping_reply2", encoding="utf-8")) + + # need extra loop for asyncio.ensure_future starts a tasks + await moment() + + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is False + + reader.queue.put_nowait(b"val1") + reader.queue.put_nowait(b"ping_reply1") + reader.queue.put_nowait([b"ssubscribe", b"chan", 1]) + + # This is incorrect. Redis must return error with restrict this command in PubSub mode + # and client must prevent send SET command in half-open PubSub mode + # reader.queue.put_nowait(b"OK") + + reader.queue.put_nowait([b"pong", b"ping_reply2"]) + + # make 2 extra loops + await moment(2) + + assert get_task.done() is True + assert ping1_task.done() is True + assert subs_task.done() is True + assert set_task.done() is True + assert ping2_task.done() is True + + assert get_task.result() == "val1" + assert ping1_task.result() == b"ping_reply1" + assert subs_task.result() == [[b"ssubscribe", b"chan", 1]] + with pytest.raises(RedisError, match="Connection in PubSub mode"): + assert set_task.result() + assert ping2_task.result() == "ping_reply2" + + assert redis.in_pubsub == 1 + + +async def test_execute__unsubscribe_without_subscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"sunsubscribe", b"chan", 0]) + await redis.execute_pubsub("SUNSUBSCRIBE", "chan") + reader.queue.put_nowait((b"punsubscribe", b"chan", 0)) + await redis.execute_pubsub("PUNSUBSCRIBE", "chan") + reader.queue.put_nowait((b"unsubscribe", b"chan", 0)) + await redis.execute_pubsub("UNSUBSCRIBE", "chan") + + assert redis.in_pubsub == 0 + assert redis._client_in_pubsub is False + assert redis._server_in_pubsub is False + + +async def test__redis_push_unsubscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + sub_task = asyncio.ensure_future(redis.execute_pubsub("SUBSCRIBE", "chan:1", "chan:2")) + psub_task = asyncio.ensure_future(redis.execute_pubsub("PSUBSCRIBE", "chan:3", "chan:4")) + ssub_task = asyncio.ensure_future( + redis.execute_pubsub("SSUBSCRIBE", "chan:5:{shard}", "chan:6:{shard}") + ) + await moment() + + # push replies + reader.queue.put_nowait([b"subscribe", b"chan:1", 1]) + reader.queue.put_nowait([b"subscribe", b"chan:2", 2]) + reader.queue.put_nowait([b"psubscribe", b"chan:3", 3]) + reader.queue.put_nowait([b"psubscribe", b"chan:4", 4]) + reader.queue.put_nowait([b"ssubscribe", b"chan:5:{shard}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan:6:{shard}", 2]) + await moment() + + assert sub_task.result() + assert psub_task.result() + assert ssub_task.result() + + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 2 + assert len(redis.pubsub_patterns) == 2 + assert len(redis.sharded_pubsub_channels) == 2 + + reader.queue.put_nowait([b"unsubscribe", b"chan:1", 3]) + reader.queue.put_nowait([b"unsubscribe", b"chan:2", 2]) + reader.queue.put_nowait([b"punsubscribe", b"chan:3", 1]) + reader.queue.put_nowait([b"punsubscribe", b"chan:4", 0]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:5:{shard}", 1]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:6:{shard}", 0]) + # some extra channel + reader.queue.put_nowait([b"unsubscribe", b"chan:7", 3]) + + await moment() + + assert redis.in_pubsub == 1 + assert len(redis.pubsub_channels) == 0 + assert len(redis.pubsub_patterns) == 0 + assert len(redis.sharded_pubsub_channels) == 0 + + assert redis._reader_task.done() is False + + assert len(redis._pubsub_store._unconfirmed_subscribes) == 0 + + +async def test_moved_with_pubsub(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait((b"ssubscribe", b"chan1", 1)) + + # key slot for chan1 - 2323 + await redis.execute_pubsub("SSUBSCRIBE", "chan1") + + assert len(redis.sharded_pubsub_channels) == 1 + assert "chan1" in redis.sharded_pubsub_channels + + # key slot chan2:{shard1} - 10271 + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard1}", 11]) + reader.queue.put_nowait([b"ssubscribe", b"chan3:{shard1}", 11]) + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard1}") + await redis.execute_pubsub("SSUBSCRIBE", "chan3:{shard1}") + + assert len(redis.sharded_pubsub_channels) == 3 + assert "chan2:{shard1}" in redis.sharded_pubsub_channels + assert "chan3:{shard1}" in redis.sharded_pubsub_channels + + reader.queue.put_nowait(MovedError("MOVED 2323 127.0.0.1:6379")) + await moment() + + assert len(redis.sharded_pubsub_channels) == 2 + assert "chan1" not in redis.sharded_pubsub_channels + + reader.queue.put_nowait(MovedError("MOVED 10271 127.0.0.1:6379")) + await moment() + + assert len(redis.sharded_pubsub_channels) == 0 + + assert redis._reader_task.done() is False, redis._reader_task.exception() + + +async def test_execute__unexpectable_unsubscribe_and_moved(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard1}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard1}", 2]) + + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard1}") + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard1}") + + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard1}", 1]) + reader.queue.put_nowait(MovedError("MOVED 10271 127.0.0.1:6379")) + await moment() + + assert redis.in_pubsub == 1 + assert redis._reader_task.done() is False + + +async def test_execute__ssubscribe_with_first_moved(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait(MovedError("MOVED 10271 127.0.0.1:6379")) + + with pytest.raises(MovedError, match="MOVED 10271 127.0.0.1:6379"): + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard1}") + + assert redis.in_pubsub == 0 + assert redis._client_in_pubsub is False + assert redis._server_in_pubsub is False + assert redis._reader_task.done() is False + + +async def test_execute__client_unsubscribe_with_server_unsubscribe(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan:1", 1]) + sub_result1 = await redis.execute_pubsub("SSUBSCRIBE", "chan:1") + + reader.queue.put_nowait([b"ssubscribe", b"chan:2", 2]) + sub_result2 = await redis.execute_pubsub("SSUBSCRIBE", "chan:2") + + reader.queue.put_nowait([b"ssubscribe", b"chan:3", 3]) + sub_result3 = await redis.execute_pubsub("SSUBSCRIBE", "chan:3") + + assert sub_result1 == [[b"ssubscribe", b"chan:1", 1]] + assert sub_result2 == [[b"ssubscribe", b"chan:2", 2]] + assert sub_result3 == [[b"ssubscribe", b"chan:3", 3]] + assert redis.in_pubsub == 1 + assert len(redis.sharded_pubsub_channels) == 3 + + reader.queue.put_nowait([b"sunsubscribe", b"chan:1", 2]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:3", 1]) + reader.queue.put_nowait([b"sunsubscribe", b"chan:2", 0]) + reader.queue.put_nowait(MovedError("MOVED 1 127.0.0.1:6379")) + await moment() + + reader.queue.put_nowait([b"sunsubscribe", b"chan:3", 0]) + unsub_result3 = await redis.execute_pubsub("SUNSUBSCRIBE", "chan:3") + + assert unsub_result3 == [[b"sunsubscribe", b"chan:3", 0]] + + await moment() + + assert redis.in_pubsub == 1 + assert len(redis.sharded_pubsub_channels) == 0 + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_execute__ping(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + redis._in_pubsub = 1 + add_async_finalizer(lambda: close_connection(redis)) + + ping1_task = asyncio.ensure_future(redis.execute("PING")) + subs_task = asyncio.ensure_future(redis.execute_pubsub("SUBSCRIBE", "chan")) + await moment() + ping2_task = asyncio.ensure_future(redis.execute("PING")) + ping3_task = asyncio.ensure_future(redis.execute("PING", "my_message")) + reader.queue.put_nowait(b"PONG") + reader.queue.put_nowait((b"subscribe", b"chan", 1)) + reader.queue.put_nowait(b"PONG") + reader.queue.put_nowait((b"pong", "my_message")) + await moment(2) + + assert redis.in_pubsub == 1 + + assert ping1_task.done() is True + assert subs_task.done() is True + assert ping2_task.done() is True + assert ping3_task.done() is True + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_subscribe_and_receive_messages(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"subscribe", b"chan", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan:{shard}", 1]) + reader.queue.put_nowait([b"psubscribe", b"chan:*", 2]) + + await redis.execute_pubsub("SUBSCRIBE", "chan") + await redis.execute_pubsub("SSUBSCRIBE", "chan:{shard}") + await redis.execute_pubsub("PSUBSCRIBE", "chan:*") + + channel = redis.pubsub_channels["chan"] + pattern = redis.pubsub_patterns["chan:*"] + sharded = redis.sharded_pubsub_channels["chan:{shard}"] + + reader.queue.put_nowait([b"smessage", b"chan:{shard}", b"sharded_msg"]) + reader.queue.put_nowait([b"pmessage", b"chan:*", b"chan:foo", b"pattern_msg"]) + reader.queue.put_nowait([b"message", b"chan", b"channel_msg"]) + + await moment() + + channel_msg = await channel.get() + pattern_msg = await pattern.get() + sharded_msg = await sharded.get() + + assert channel_msg == b"channel_msg" + assert pattern_msg == (b"chan:foo", b"pattern_msg") + assert sharded_msg == b"sharded_msg" + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_receive_message_after_unsubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + with caplog.at_level(logging.WARNING): + reader.queue.put_nowait([b"ssubscribe", b"chan:{shard}", 1]) + await redis.execute_pubsub("SSUBSCRIBE", "chan:{shard}") + sharded = redis.sharded_pubsub_channels["chan:{shard}"] + await redis.execute_pubsub("SUNSUBSCRIBE", "chan:{shard}") + + reader.queue.put_nowait([b"smessage", b"chan:{shard}", b"sharded_msg"]) + + await moment() + + assert sharded.is_active is False + assert sharded._queue.qsize() == 0 + with pytest.raises(ChannelClosedError): + await sharded.get() + + no_channel_record = "" + for record in caplog.records: + assert "No waiter for received reply" not in record.message, record.message + if "No channel" in record.message and "for received message" in record.message: + no_channel_record = record.message + + assert no_channel_record != "" + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_subscribe_and_immediately_unsubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2]) + + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}") + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}") + + with caplog.at_level(logging.ERROR): + await redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}") + await redis.execute_pubsub("SUNSUBSCRIBE", "chan1:{shard}") + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 1]) + reader.queue.put_nowait([b"sunsubscribe", b"chan1:{shard}", 0]) + + await moment(2) + + for record in caplog.records: + assert "No waiter for received reply" not in record.message, record.message + + assert redis.in_pubsub == 1 + assert len(redis.sharded_pubsub_channels) == 0 + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_immediately_resubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + # switch connection to pubsub mode + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}") + + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2]) + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}") + + unsub_task = asyncio.ensure_future(redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}")) + sub_task = asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}")) + await moment() + + assert unsub_task.done() is True + assert sub_task.done() is True + + ch = redis.sharded_pubsub_channels["chan2:{shard}"] + + ch_get_task = asyncio.ensure_future(ch.get()) + # start task + await moment() + + # redis send sequence of replies + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2]) + + # consume replies + await moment() + # wait done callback for ch.get() + await moment() + + assert ch_get_task.done() is False + + +async def test_resubscribe_with_message_received(add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + # switch connection to pubsub mode + asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}")) + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + + await moment() + + asyncio.ensure_future(redis.execute_pubsub("SUNSUBSCRIBE", "chan1:{shard}")) + + await moment() + + resub_task = asyncio.ensure_future(redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}")) + await moment() + + reader.queue.put_nowait([b"smessage", b"chan1:{shard}", b"msg1"]) + reader.queue.put_nowait([b"sunsubscribe", b"chan1:{shard}", 0]) + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + reader.queue.put_nowait([b"smessage", b"chan1:{shard}", b"msg2"]) + + # push loop cycle for received events + await moment() + + channel_name = resub_task.result()[0][1] + ch = redis.sharded_pubsub_channels[channel_name] - s = redis.execute_pubsub("SSUBSCRIBE", "a") - reader.queue.put_nowait((b"ssubscribe", b"a", 10)) - await s + async with timeout(0): + msg1 = await ch.get() + assert msg1 == b"msg1" - s = redis.execute_pubsub("SSUBSCRIBE", "b") - await reader.queue.put(MovedError("1 1 127.0.0.1:6379")) - with pytest.raises(MovedError): - await asyncio.wait_for(s, timeout=1) - assert not redis._reader_task.done(), redis._reader_task.exception() + async with timeout(0): + msg2 = await ch.get() + assert msg2 == b"msg2" - reader.queue.put_nowait((b"smessage", b"a", b"123")) - assert not redis._reader_task.done() - redis.close() - await redis.wait_closed() + # no more messages + with pytest.raises(asyncio.TimeoutError): + async with timeout(0.001): + await ch.get() diff --git a/tests/unit_tests/aioredis_cluster/test_pooler.py b/tests/unit_tests/aioredis_cluster/test_pooler.py index 2ed672c..0a6000a 100644 --- a/tests/unit_tests/aioredis_cluster/test_pooler.py +++ b/tests/unit_tests/aioredis_cluster/test_pooler.py @@ -3,7 +3,7 @@ import pytest -from aioredis_cluster.abc import AbcPool +from aioredis_cluster.abc import AbcConnection, AbcPool from aioredis_cluster.pooler import Pooler from aioredis_cluster.structs import Address @@ -15,11 +15,12 @@ def create_pool_mock(): return mocked -async def test_ensure_pool__identical_address(): +async def test_ensure_pool__identical_address(add_async_finalizer): mocked_create_pool = mock.AsyncMock( return_value=create_pool_mock(), ) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) result = await pooler.ensure_pool(Address("localhost", 1234)) @@ -32,11 +33,16 @@ async def test_ensure_pool__identical_address(): assert mocked_create_pool.call_count == 1 -async def test_ensure_pool__multiple(): - pools = [object(), object(), object()] +async def test_ensure_pool__multiple(add_async_finalizer): + pools = [ + mock.AsyncMock(AbcConnection), + mock.AsyncMock(AbcConnection), + mock.AsyncMock(AbcConnection), + ] mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) result1 = await pooler.ensure_pool(Address("localhost", 1234)) result2 = await pooler.ensure_pool(Address("localhost", 4321)) @@ -55,7 +61,7 @@ async def test_ensure_pool__multiple(): ) -async def test_ensure_pool__only_one(): +async def test_ensure_pool__only_one(add_async_finalizer): event_loop = asyncio.get_running_loop() pools = { ("h1", 1): create_pool_mock(), @@ -71,6 +77,7 @@ async def create_pool_se(addr): mocked_create_pool = mock.AsyncMock(side_effect=create_pool_se) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) tasks = [] for i in range(10): @@ -88,11 +95,12 @@ async def create_pool_se(addr): assert mocked_create_pool.call_count == 2 -async def test_ensure_pool__error(): - pools = [RuntimeError(), object()] +async def test_ensure_pool__error(add_async_finalizer): + pools = [RuntimeError(), mock.AsyncMock(AbcConnection)] mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) addr = Address("localhost", 1234) with pytest.raises(RuntimeError): @@ -117,7 +125,7 @@ async def test_close__empty_pooler(): assert pooler.closed is True -async def test_close__with_pools(mocker): +async def test_close__with_pools(mocker, add_async_finalizer): addrs_pools = [ (Address("h1", 1), create_pool_mock()), (Address("h2", 2), create_pool_mock()), @@ -127,6 +135,7 @@ async def test_close__with_pools(mocker): mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) result1 = await pooler.ensure_pool(addrs[0]) result2 = await pooler.ensure_pool(addrs[1]) @@ -143,7 +152,7 @@ async def test_close__with_pools(mocker): result2.wait_closed.assert_called_once() -async def test_reap_pools(mocker): +async def test_reap_pools(mocker, add_async_finalizer): addrs_pools = [ (Address("h1", 1), create_pool_mock()), (Address("h2", 2), create_pool_mock()), @@ -153,6 +162,7 @@ async def test_reap_pools(mocker): mocked_create_pool = mock.AsyncMock(side_effect=pools) pooler = Pooler(mocked_create_pool, reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) # create pools await pooler.ensure_pool(addrs[0]) @@ -176,8 +186,9 @@ async def test_reap_pools(mocker): assert len(pooler._nodes) == 0 -async def test_reaper(mocker): +async def test_reaper(mocker, add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=0) + add_async_finalizer(lambda: pooler.close()) assert pooler._reap_calls == 0 @@ -194,8 +205,9 @@ async def test_reaper(mocker): assert pooler._reaper_task.cancelled() is True -async def test_add_pubsub_channel__no_addr(): +async def test_add_pubsub_channel__no_addr(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr = Address("h1", 1234) result = pooler.add_pubsub_channel(addr, b"channel", is_pattern=False) @@ -203,8 +215,9 @@ async def test_add_pubsub_channel__no_addr(): assert result is False -async def test_add_pubsub_channel(): +async def test_add_pubsub_channel(add_async_finalizer): pooler = Pooler(mock.AsyncMock(return_value=create_pool_mock()), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) @@ -237,15 +250,17 @@ async def test_add_pubsub_channel(): assert (b"ch3", False) in collected_channels -async def test_remove_pubsub_channel__no_addr(): +async def test_remove_pubsub_channel__no_addr(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) result = pooler.remove_pubsub_channel(b"channel", is_pattern=False) assert result is False -async def test_remove_pubsub_channel(): +async def test_remove_pubsub_channel(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) @@ -275,8 +290,9 @@ async def test_remove_pubsub_channel(): assert len(pooler._pubsub_channels) == 0 -async def test_get_pubsub_addr(): +async def test_get_pubsub_addr(add_async_finalizer): pooler = Pooler(mock.AsyncMock(), reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) @@ -297,11 +313,12 @@ async def test_get_pubsub_addr(): assert result4 == addr2 -async def test_ensure_pool__create_pubsub_addr_set(): +async def test_ensure_pool__create_pubsub_addr_set(add_async_finalizer): addr1 = Address("h1", 1234) addr2 = Address("h2", 1234) pooler = Pooler(mock.AsyncMock(return_value=create_pool_mock())) + add_async_finalizer(lambda: pooler.close()) assert len(pooler._pubsub_addrs) == 0 @@ -320,9 +337,10 @@ async def test_ensure_pool__create_pubsub_addr_set(): assert len(pooler._pubsub_addrs[addr1]) == 1 -async def test_reap_pools__cleanup_channels(): +async def test_reap_pools__cleanup_channels(add_async_finalizer): pool_factory = mock.AsyncMock(return_value=mock.Mock(AbcPool)) pooler = Pooler(pool_factory, reap_frequency=-1) + add_async_finalizer(lambda: pooler.close()) addr1 = Address("h1", 1) addr2 = Address("h2", 2) @@ -346,12 +364,14 @@ async def test_reap_pools__cleanup_channels(): assert len(pooler._pubsub_channels) == 0 -async def test_close_only(): +async def test_close_only(add_async_finalizer): pool1 = create_pool_mock() pool2 = create_pool_mock() pool3 = create_pool_mock() mocked_create_pool = mock.AsyncMock(side_effect=[pool1, pool2, pool3]) pooler = Pooler(mocked_create_pool) + add_async_finalizer(lambda: pooler.close()) + addr1 = Address("h1", 1) addr2 = Address("h2", 2) diff --git a/tests/unit_tests/aioredis_cluster/test_pubsub.py b/tests/unit_tests/aioredis_cluster/test_pubsub.py new file mode 100644 index 0000000..3a7ee0a --- /dev/null +++ b/tests/unit_tests/aioredis_cluster/test_pubsub.py @@ -0,0 +1,407 @@ +from unittest import mock + +import pytest + +from aioredis_cluster.abc import AbcChannel +from aioredis_cluster.command_info.commands import PubSubType +from aioredis_cluster.crc import key_slot +from aioredis_cluster.pubsub import PubSubStore + + +def make_channel_mock(): + return mock.NonCallableMock(AbcChannel) + + +def test_channel_subscribe__one_channel(): + store = PubSubStore() + chan = make_channel_mock() + pchan = make_channel_mock() + schan = make_channel_mock() + + store.channel_subscribe( + channel_type=PubSubType.CHANNEL, + channel_name=b"chan", + channel=chan, + key_slot=-1, + ) + + assert store.channels_num == 1 + assert store.channels_total == 1 + assert store.sharded_channels_num == 0 + assert len(store.channels) == 1 + assert len(store.patterns) == 0 + assert len(store.sharded) == 0 + assert store.channels["chan"] is chan + + store.channel_subscribe( + channel_type=PubSubType.PATTERN, + channel_name=b"pchan", + channel=pchan, + key_slot=-1, + ) + + assert store.channels_num == 2 + assert store.channels_total == 2 + assert store.sharded_channels_num == 0 + assert len(store.channels) == 1 + assert len(store.patterns) == 1 + assert len(store.sharded) == 0 + assert store.patterns["pchan"] is pchan + + schan_key_slot = key_slot(b"schan") + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=b"schan", + channel=schan, + key_slot=schan_key_slot, + ) + + assert store.channels_num == 2 + assert store.channels_total == 3 + assert store.sharded_channels_num == 1 + assert len(store.channels) == 1 + assert len(store.patterns) == 1 + assert len(store.sharded) == 1 + assert store.sharded["schan"] is schan + + assert schan_key_slot in store._slot_to_sharded + assert store._slot_to_sharded[schan_key_slot] == {b"schan"} + + +def test_close__empty(): + store = PubSubStore() + store.close(None) + + +def test_close__with_channels(): + store = PubSubStore() + + chan = make_channel_mock() + pchan = make_channel_mock() + schan = make_channel_mock() + + store.channel_subscribe( + channel_type=PubSubType.CHANNEL, + channel_name=b"chan", + channel=chan, + key_slot=-1, + ) + store.channel_subscribe( + channel_type=PubSubType.PATTERN, + channel_name=b"pchan", + channel=pchan, + key_slot=-1, + ) + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=b"schan", + channel=schan, + key_slot=key_slot(b"schan"), + ) + + close_exc = Exception() + store.close(close_exc) + + assert store.channels_total == 0 + assert len(store._slot_to_sharded) == 0 + assert len(store._unconfirmed_subscribes) == 0 + + chan.close.assert_called_once_with(close_exc) + pchan.close.assert_called_once_with(close_exc) + schan.close.assert_called_once_with(close_exc) + + +def test_confirm_subscribe__no_confirms(caplog): + store = PubSubStore() + + with caplog.at_level("ERROR", "aioredis_cluster.pubsub"): + store.confirm_subscribe(PubSubType.CHANNEL, b"chan") + + assert len(caplog.records) == 1 + assert "Unexpected PubSub subscribe confirm for" in caplog.records[0].message + + +def test_confirm_subscribe__simple_confirm(caplog): + store = PubSubStore() + + chan = make_channel_mock() + pchan = make_channel_mock() + schan = make_channel_mock() + + store.channel_subscribe( + channel_type=PubSubType.CHANNEL, + channel_name=b"chan", + channel=chan, + key_slot=-1, + ) + store.channel_subscribe( + channel_type=PubSubType.PATTERN, + channel_name=b"pchan", + channel=pchan, + key_slot=-1, + ) + schan_key_slot = key_slot(b"schan") + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=b"schan", + channel=schan, + key_slot=schan_key_slot, + ) + + assert len(store._unconfirmed_subscribes) == 3 + assert schan_key_slot in store._slot_to_sharded + + with caplog.at_level("ERROR", "aioredis_cluster.pubsub"): + store.confirm_subscribe(PubSubType.SHARDED, b"schan") + + assert len(store._unconfirmed_subscribes) == 2 + assert schan_key_slot in store._slot_to_sharded + assert len(caplog.records) == 0 + + with caplog.at_level("ERROR", "aioredis_cluster.pubsub"): + store.confirm_subscribe(PubSubType.PATTERN, b"pchan") + store.confirm_subscribe(PubSubType.CHANNEL, b"chan") + + assert len(store._unconfirmed_subscribes) == 0 + assert len(caplog.records) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_confirm_subscribe__resub_confirms(channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + # calls in execute_pubsub() for (P|S)SUBSCRIBE commands + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + + # calls in execute_pubsub() for (P|S)UNSUBSCRIBE commands + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=False, + ) + + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 2 + assert store.channels_total == 1 + + # call process_pubsub on receive (p|s)subscribe kind events + store.confirm_subscribe(channel_type, channel_name) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 1 + if channel_type is PubSubType.SHARDED: + assert channel_key_slot in store._slot_to_sharded + + # call process_pubsub on receive (p|s)unsubscribe kind events + # this call do nothing because have unconfirmed subscription calls + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + # second server reply for subscribe command + store.confirm_subscribe(channel_type, channel_name) + assert len(store._unconfirmed_subscribes) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_confirm_subscribe__resub_and_unsub(channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + for i in range(2): + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=False, + ) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 2 + assert store.channels_total == 0 + + # first subscribe reply + store.confirm_subscribe(channel_type, channel_name) + + assert len(store._unconfirmed_subscribes) == 1 + assert store._unconfirmed_subscribes[(channel_type, channel_name)] == 1 + if channel_type is PubSubType.SHARDED: + assert channel_key_slot in store._slot_to_sharded + + # first unsubscribe reply + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + # second subscribe reply + store.confirm_subscribe(channel_type, channel_name) + assert len(store._unconfirmed_subscribes) == 0 + if channel_type is PubSubType.SHARDED: + assert len(store._slot_to_sharded) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_confirm_subscribe__unsub_server_push(channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + + # server push (p|s)unsubscribe kind event + # probably previous sub->unsub attempts + # we do nothing + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + assert store.channels_total == 1 + chan.close.assert_not_called() + + # subscribe reply + store.confirm_subscribe(channel_type, channel_name) + + # server push (p|s)unsubscribe kind event + # maybe is cluster reshard or node is shutting down + # we must close all confirmed channels + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + assert store.channels_total == 0 + chan.close.assert_called_once_with() + if channel_type is PubSubType.SHARDED: + assert len(store._slot_to_sharded) == 0 + + +@pytest.mark.parametrize("channel_type", list(PubSubType)) +def test_channel_unsubscribe__subscribe_confirm_and_unsubscibe(caplog, channel_type): + store = PubSubStore() + chan = make_channel_mock() + channel_name = b"chan" + channel_key_slot = key_slot(channel_name) + + store.channel_subscribe( + channel_type=channel_type, + channel_name=channel_name, + channel=chan, + # this is key slot for b"chan" + key_slot=channel_key_slot, + ) + # subscribe reply + store.confirm_subscribe(channel_type, channel_name) + + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=False, + ) + + assert store.channels_total == 0 + assert store.have_slot_channels(channel_key_slot) is False + assert len(store._slot_to_sharded) == 0 + chan.close.assert_called_once_with() + + with caplog.at_level("WARNING", "aioredis_cluster.pubsub"): + # server reply for unsubscribe + store.channel_unsubscribe( + channel_type=channel_type, + channel_name=channel_name, + by_reply=True, + ) + + assert len(caplog.records) == 0 + + +def test_slot_channels_unsubscribe__empty(): + store = PubSubStore() + store.slot_channels_unsubscribe(1234) + + +def test_slot_channels_unsubscribe__with_unconfirmed_subscribes(): + # this is unrealistic case but we need check this + store = PubSubStore() + chan1 = make_channel_mock() + channel1_name = b"chan1" + channel1_key_slot = key_slot(channel1_name) + chan2 = make_channel_mock() + channel2_name = b"chan2" + channel2_key_slot = key_slot(channel2_name) + chan3 = make_channel_mock() + channel3_name = b"chan3:{chan1}" + channel3_key_slot = key_slot(channel3_name) + + assert channel1_key_slot == channel3_key_slot + + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=channel1_name, + channel=chan1, + key_slot=channel1_key_slot, + ) + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=channel2_name, + channel=chan2, + key_slot=channel2_key_slot, + ) + store.channel_subscribe( + channel_type=PubSubType.SHARDED, + channel_name=channel3_name, + channel=chan3, + key_slot=channel3_key_slot, + ) + + assert store.channels_total == 3 + assert store.have_slot_channels(channel1_key_slot) is True + assert store.have_slot_channels(channel2_key_slot) is True + assert store.have_slot_channels(0) is False + + store.slot_channels_unsubscribe(channel3_key_slot) + + assert store.channels_total == 1 + assert len(store._unconfirmed_subscribes) == 1 + assert (PubSubType.SHARDED, channel2_name) in store._unconfirmed_subscribes + chan1.close.assert_called_once_with() + chan2.close.assert_not_called() + chan3.close.assert_called_once_with()