Skip to content

Commit

Permalink
restrict exit in PubSub mode for RedisConnection
Browse files Browse the repository at this point in the history
  • Loading branch information
Anton Ilyushenkov committed Dec 6, 2023
1 parent a5cdd7b commit 5a3e467
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 14 deletions.
48 changes: 36 additions & 12 deletions src/aioredis_cluster/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def __init__(self) -> None:
self._sharded: coerced_keys_dict[AbcChannel] = coerced_keys_dict()
self._sharded_to_slot: Dict[bytes, int] = {}
self._slot_to_sharded: Dict[int, Set[bytes]] = {}
self._pending_unsubscribe: Set[Tuple[PubSubType, bytes]] = set()

@property
def channels(self) -> Mapping[str, AbcChannel]:
Expand All @@ -111,6 +112,14 @@ def sharded(self) -> Mapping[str, AbcChannel]:
"""Returns read-only sharded channels dict."""
return MappingProxyType(self._sharded)

def channel_pending_unsubscribe(
self,
*,
channel_type: PubSubType,
channel_name: bytes,
) -> None:
self._pending_unsubscribe.add((channel_type, channel_name))

def channel_subscribe(
self,
*,
Expand Down Expand Up @@ -183,6 +192,16 @@ def channels_num(self) -> int:
def sharded_channels_num(self) -> int:
return len(self._sharded)

def has_channel(self, channel_type: PubSubType, channel_name: bytes) -> bool:
ret = False
if channel_type is PubSubType.CHANNEL:
ret = channel_name in self._channels
elif channel_type is PubSubType.PATTERN:
ret = channel_name in self._patterns
elif channel_type is PubSubType.SHARDED:
ret = channel_name in self._sharded
return ret

def get_channel(self, channel_type: PubSubType, channel_name: bytes) -> AbcChannel:
if channel_type is PubSubType.CHANNEL:
channel = self._channels[channel_name]
Expand Down Expand Up @@ -785,19 +804,24 @@ def _process_pubsub(self, obj: Any) -> Any:
channel_type=channel_type,
channel_name=channel_name,
)
if self._pubsub_channels_store.channels_total == 0:
self._server_in_pubsub = False
self._client_in_pubsub = False
elif kind in {b"message", b"smessage"}:
(channel_name,) = args
channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind]
channel = self._pubsub_channels_store.get_channel(channel_type, channel_name)
channel.put_nowait(data)
elif kind == b"pmessage":
(pattern, channel_name) = args
elif kind in {b"message", b"smessage", b"pmessage"}:
if kind == b"pmessage":
(pattern, channel_name) = args
else:
(channel_name,) = args
pattern = channel_name

channel_type = PUBSUB_RESP_KIND_TO_TYPE[kind]
channel = self._pubsub_channels_store.get_channel(channel_type, pattern)
channel.put_nowait((channel_name, data))
if self._pubsub_channels_store.has_channel(channel_type, pattern):
channel = self._pubsub_channels_store.get_channel(channel_type, pattern)
if channel_type is PubSubType.PATTERN:
channel.put_nowait((channel_name, data))
else:
channel.put_nowait(data)
else:
logger.warning(
"No channel %r with type %s for received message", pattern, channel_type
)
elif kind == b"pong":
if not self._waiters:
logger.error("No PubSub PONG waiters for received data %r", data)
Expand Down
71 changes: 69 additions & 2 deletions tests/unit_tests/aioredis_cluster/test_connection_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import logging
from asyncio.queues import Queue
from unittest import mock

import pytest

from aioredis_cluster.aioredis import ChannelClosedError
from aioredis_cluster.aioredis.stream import StreamReader
from aioredis_cluster.command_info.commands import PUBSUB_SUBSCRIBE_COMMANDS
from aioredis_cluster.connection import RedisConnection
Expand Down Expand Up @@ -121,8 +123,8 @@ async def test_execute__simple_unsubscribe(add_async_finalizer):
assert result_channel == [[b"unsubscribe", b"chan", 1]]
assert result_pattern == [[b"punsubscribe", b"chan", 0]]
assert result_sharded == [[b"sunsubscribe", b"chan", 0]]
assert redis._client_in_pubsub is False
assert redis._server_in_pubsub is False
assert redis._client_in_pubsub is True
assert redis._server_in_pubsub is True


@pytest.mark.parametrize(
Expand Down Expand Up @@ -434,3 +436,68 @@ async def test_subscribe_and_receive_messages(add_async_finalizer):
assert channel_msg == b"channel_msg"
assert pattern_msg == (b"chan:foo", b"pattern_msg")
assert sharded_msg == b"sharded_msg"

assert redis._reader_task is not None
assert redis._reader_task.done() is False


async def test_receive_message_after_unsubscribe(caplog, add_async_finalizer):
reader = get_mocked_reader()
writer = get_mocked_writer()
redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379")
add_async_finalizer(lambda: close_connection(redis))

with caplog.at_level(logging.WARNING):
reader.queue.put_nowait([b"ssubscribe", b"chan:{shard}", 1])
await redis.execute_pubsub("SSUBSCRIBE", "chan:{shard}")
sharded = redis.sharded_pubsub_channels["chan:{shard}"]
await redis.execute_pubsub("SUNSUBSCRIBE", "chan:{shard}")

reader.queue.put_nowait([b"smessage", b"chan:{shard}", b"sharded_msg"])

await moment()

assert sharded.is_active is False
assert sharded._queue.qsize() == 0
with pytest.raises(ChannelClosedError):
await sharded.get()

no_channel_record = ""
for record in caplog.records:
assert "No waiter for received reply" not in record.message, record.message
if "No channel" in record.message and "for received message" in record.message:
no_channel_record = record.message

assert no_channel_record != ""

assert redis._reader_task is not None
assert redis._reader_task.done() is False


async def test_subscribe_and_immediately_unsubscribe(caplog, add_async_finalizer):
reader = get_mocked_reader()
writer = get_mocked_writer()
redis = RedisConnection(reader=reader, writer=writer, address="localhost:6379")
add_async_finalizer(lambda: close_connection(redis))

reader.queue.put_nowait([b"ssubscribe", b"chan1:{shard}", 1])
reader.queue.put_nowait([b"ssubscribe", b"chan2:{shard}", 2])

await redis.execute_pubsub("SSUBSCRIBE", "chan1:{shard}")
await redis.execute_pubsub("SSUBSCRIBE", "chan2:{shard}")

with caplog.at_level(logging.ERROR):
await redis.execute_pubsub("SUNSUBSCRIBE", "chan2:{shard}")
await redis.execute_pubsub("SUNSUBSCRIBE", "chan1:{shard}")
reader.queue.put_nowait([b"sunsubscribe", b"chan2:{shard}", 1])
reader.queue.put_nowait([b"sunsubscribe", b"chan1:{shard}", 0])

await moment(2)

for record in caplog.records:
assert "No waiter for received reply" not in record.message, record.message

assert redis.in_pubsub == 0

assert redis._reader_task is not None
assert redis._reader_task.done() is False

0 comments on commit 5a3e467

Please sign in to comment.