diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py index 111a946..01d278e 100644 --- a/src/aioredis_cluster/connection.py +++ b/src/aioredis_cluster/connection.py @@ -95,6 +95,7 @@ def __init__(self) -> None: self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict() self._sharded_to_slot: Dict[bytes, int] = {} self._slot_to_sharded: Dict[int, Set[bytes]] = {} + self._pending_unsubscribe: Set[Tuple[PubSubType, bytes]] = set() @property def channels(self) -> Mapping[str, AbcChannel]: @@ -111,6 +112,14 @@ def sharded(self) -> Mapping[str, AbcChannel]: """Returns read-only sharded channels dict.""" return MappingProxyType(self._sharded) + def channel_pending_unsubscribe( + self, + *, + channel_type: PubSubType, + channel_name: bytes, + ) -> None: + self._pending_unsubscribe.add((channel_type, channel_name)) + def channel_subscribe( self, *, @@ -183,6 +192,16 @@ def channels_num(self) -> int: def sharded_channels_num(self) -> int: return len(self._sharded) + def has_channel(self, channel_type: PubSubType, channel_name: bytes) -> bool: + ret = False + if channel_type is PubSubType.CHANNEL: + ret = channel_name in self._channels + elif channel_type is PubSubType.PATTERN: + ret = channel_name in self._patterns + elif channel_type is PubSubType.SHARDED: + ret = channel_name in self._sharded + return ret + def get_channel(self, channel_type: PubSubType, channel_name: bytes) -> AbcChannel: if channel_type is PubSubType.CHANNEL: channel = self._channels[channel_name] @@ -785,19 +804,24 @@ def _process_pubsub(self, obj: Any) -> Any: channel_type=channel_type, channel_name=channel_name, ) - if self._pubsub_channels_store.channels_total == 0: - self._server_in_pubsub = False - self._client_in_pubsub = False - elif kind in {b"message", b"smessage"}: - (channel_name,) = args - channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] - channel = self._pubsub_channels_store.get_channel(channel_type, channel_name) - channel.put_nowait(data) - elif kind == b"pmessage": - (pattern, channel_name) = args + elif kind in {b"message", b"smessage", b"pmessage"}: + if kind == b"pmessage": + (pattern, channel_name) = args + else: + (channel_name,) = args + pattern = channel_name + channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind] - channel = self._pubsub_channels_store.get_channel(channel_type, pattern) - channel.put_nowait((channel_name, data)) + if self._pubsub_channels_store.has_channel(channel_type, pattern): + channel = self._pubsub_channels_store.get_channel(channel_type, pattern) + if channel_type is PubSubType.PATTERN: + channel.put_nowait((channel_name, data)) + else: + channel.put_nowait(data) + else: + logger.warning( + "No channel %r with type %s for received message", pattern, channel_type + ) elif kind == b"pong": if not self._waiters: logger.error("No PubSub PONG waiters for received data %r", data) diff --git a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py index 5f51a44..1340237 100644 --- a/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py +++ b/tests/unit_tests/aioredis_cluster/test_connection_pubsub.py @@ -1,9 +1,11 @@ import asyncio +import logging from asyncio.queues import Queue from unittest import mock import pytest +from aioredis_cluster.aioredis import ChannelClosedError from aioredis_cluster.aioredis.stream import StreamReader from aioredis_cluster.command_info.commands import PUBSUB_SUBSCRIBE_COMMANDS from aioredis_cluster.connection import RedisConnection @@ -121,8 +123,8 @@ async def test_execute__simple_unsubscribe(add_async_finalizer): assert result_channel == [[b"unsubscribe", b"chan", 1]] assert result_pattern == [[b"punsubscribe", b"chan", 0]] assert result_sharded == [[b"sunsubscribe", b"chan", 0]] - assert redis._client_in_pubsub is False - assert redis._server_in_pubsub is False + assert redis._client_in_pubsub is True + assert redis._server_in_pubsub is True @pytest.mark.parametrize( @@ -434,3 +436,68 @@ async def test_subscribe_and_receive_messages(add_async_finalizer): assert channel_msg == b"channel_msg" assert pattern_msg == (b"chan:foo", b"pattern_msg") assert sharded_msg == b"sharded_msg" + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_receive_message_after_unsubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + with caplog.at_level(logging.WARNING): + reader.queue.put_nowait([b"ssubscribe", b"chan:{shard}", 1]) + await redis.execute_pubsub("SSUBSCRIBE", "chan:{shard}") + sharded = redis.sharded_pubsub_channels["chan:{shard}"] + await redis.execute_pubsub("SUNSUBSCRIBE", "chan:{shard}") + + reader.queue.put_nowait([b"smessage", b"chan:{shard}", b"sharded_msg"]) + + await moment() + + assert sharded.is_active is False + assert sharded._queue.qsize() == 0 + with pytest.raises(ChannelClosedError): + await sharded.get() + + no_channel_record = "" + for record in caplog.records: + assert "No waiter for received reply" not in record.message, record.message + if "No channel" in record.message and "for received message" in record.message: + no_channel_record = record.message + + assert no_channel_record != "" + + assert redis._reader_task is not None + assert redis._reader_task.done() is False + + +async def test_subscribe_and_immediately_unsubscribe(caplog, add_async_finalizer): + reader = get_mocked_reader() + writer = get_mocked_writer() + redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379") + add_async_finalizer(lambda: close_connection(redis)) + + reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1]) + reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2]) + + await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}") + await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}") + + with caplog.at_level(logging.ERROR): + await redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}") + await redis.execute_pubsub("SUNSUBSCRIBE", "chan1:{shard}") + reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 1]) + reader.queue.put_nowait([b"sunsubscribe", b"chan1:{shard}", 0]) + + await moment(2) + + for record in caplog.records: + assert "No waiter for received reply" not in record.message, record.message + + assert redis.in_pubsub == 0 + + assert redis._reader_task is not None + assert redis._reader_task.done() is False