diff --git a/dev/test_pubsub_reshard.py b/dev/test_pubsub_reshard.py
index 43c7c02..6b11e8a 100644
--- a/dev/test_pubsub_reshard.py
+++ b/dev/test_pubsub_reshard.py
@@ -1,13 +1,16 @@
 import argparse
 import asyncio
 import logging
+import random
 import signal
 from collections import deque
-from typing import Counter, Deque, Dict, Mapping, Optional
+from itertools import cycle
+from typing import Counter, Deque, Dict, Mapping, Optional, Sequence
 
-from aioredis_cluster import RedisCluster, create_redis_cluster
+from aioredis_cluster import Cluster, ClusterNode, RedisCluster, create_redis_cluster
 from aioredis_cluster.aioredis import Channel
 from aioredis_cluster.compat.asyncio import timeout
+from aioredis_cluster.crc import key_slot
 
 
 logger = logging.getLogger(__name__)
@@ -18,8 +21,13 @@ async def tick_log(
     global_counters: Counter[str],
 ) -> None:
     count = 0
+    last = False
     while True:
-        await asyncio.sleep(tick)
+        try:
+            await asyncio.sleep(tick)
+        except asyncio.CancelledError:
+            last = True
+
         count += 1
         logger.info("tick %d", count)
         logger.info("tick %d: %r", count, global_counters)
@@ -31,6 +39,17 @@
         for routine_id, counters in routines:
             logger.info("tick %d: %s: %r", count, routine_id, counters)
 
+        if last:
+            break
+
+
+def get_channel_name(routine_id: int) -> str:
+    return f"ch:{routine_id}:{{shard}}"
+
+
+class ChannelNotClosedError(Exception):
+    pass
+
 
 async def subscribe_routine(
     *,
@@ -40,7 +59,7 @@
     global_counters: Counter[str],
 ):
     await asyncio.sleep(0.5)
-    ch_name = f"ch:{routine_id}:{{shard}}"
+    ch_name = get_channel_name(routine_id)
     while True:
         counters["routine:subscribes"] += 1
         prev_ch: Optional[Channel] = None
@@ -48,7 +67,7 @@
             pool = await redis.keys_master(ch_name)
             global_counters["subscribe_in_fly"] += 1
             try:
-                ch: Channel = (await pool.subscribe(ch_name))[0]
+                ch: Channel = (await pool.ssubscribe(ch_name))[0]
                 if prev_ch is not None and ch is prev_ch:
                     logger.error("%s: Previous Channel is current: %r", routine_id, ch)
                 prev_ch = ch
@@ -67,10 +86,13 @@
                 else:
                     counters["msg:received"] += 1
                     global_counters["msg:received"] += 1
-                await pool.unsubscribe(ch_name)
+                await pool.sunsubscribe(ch_name)
             finally:
                 global_counters["subscribe_in_fly"] -= 1
-            assert ch.is_active is False
+
+            if not ch._queue.closed:
+                raise ChannelNotClosedError()
+
         except asyncio.CancelledError:
             break
         except Exception as e:
@@ -81,6 +103,108 @@
             logger.error("%s: Channel exception: %r", routine_id, e)
 
+
+async def publish_routine(
+    *,
+    redis: RedisCluster,
+    routines: Sequence[int],
+    global_counters: Counter[str],
+):
+    routines_cycle = cycle(routines)
+    for routine_id in routines_cycle:
+        await asyncio.sleep(random.uniform(0.001, 0.1))
+        channel_name = get_channel_name(routine_id)
+        await redis.spublish(channel_name, f"msg:{routine_id}")
+        global_counters["published_msgs"] += 1
+
+
+async def cluster_move_slot(
+    *,
+    slot: int,
+    node_src: RedisCluster,
+    node_src_info: ClusterNode,
+    node_dest: RedisCluster,
+    node_dest_info: ClusterNode,
+) -> None:
+    await node_dest.cluster_setslot(slot, "IMPORTING", node_src_info.node_id)
+    await node_src.cluster_setslot(slot, "MIGRATING", node_dest_info.node_id)
+    while True:
+        keys = await node_src.cluster_get_keys_in_slots(slot, 100)
+        if not keys:
+            break
+        await node_src.migrate_keys(
+            node_dest_info.addr.host,
+            node_dest_info.addr.port,
+            keys,
+            0,
+            5000,
+        )
+    await node_dest.cluster_setslot(slot, "NODE", node_dest_info.node_id)
+    await node_src.cluster_setslot(slot, "NODE", node_dest_info.node_id)
+
+
+async def reshard_routine(
+    *,
+    redis: RedisCluster,
+    global_counters: Counter[str],
+):
+    cluster: Cluster = redis.connection
+    cluster_state = await cluster.get_cluster_state()
+
+    moving_slot = key_slot(b"shard")
+    master1 = await cluster.keys_master(b"shard")
+
+    # search another master
+    count = 0
+    while True:
+        count += 1
+        key = f"key:{count}"
+        master2 = await cluster.keys_master(key)
+        if master2.address != master1.address:
+            break
+
+    master1_info = cluster_state.addr_node(master1.address)
+    master2_info = cluster_state.addr_node(master2.address)
+    logger.info("Master1 info - %r", master1_info)
+    logger.info("Master2 info - %r", master2_info)
+
+    node_src = master1
+    node_src_info = master1_info
+    node_dest = master2
+    node_dest_info = master2_info
+
+    while True:
+        await asyncio.sleep(random.uniform(2.0, 5.0))
+        global_counters["reshards"] += 1
+        logger.info(
+            "Start moving slot %s from %s to %s", moving_slot, node_src.address, node_dest.address
+        )
+        move_slot_task = asyncio.create_task(
+            cluster_move_slot(
+                slot=moving_slot,
+                node_src=node_src,
+                node_src_info=node_src_info,
+                node_dest=node_dest,
+                node_dest_info=node_dest_info,
+            )
+        )
+        try:
+            await asyncio.shield(move_slot_task)
+        except asyncio.CancelledError:
+            await move_slot_task
+            raise
+        except Exception as e:
+            logger.exception("Unexpected error on reshard: %r", e)
+            break
+
+        # swap nodes
+        node_src, node_src_info, node_dest, node_dest_info = (
+            node_dest,
+            node_dest_info,
+            node_src,
+            node_src_info,
+        )
+
 
 async def async_main(args: argparse.Namespace) -> None:
     loop = asyncio.get_event_loop()
     node_addr: str = args.node
@@ -98,6 +222,15 @@ async def async_main(args: argparse.Namespace) -> None:
         follow_cluster=True,
     )
     routine_tasks: Deque[asyncio.Task] = deque()
+
+    reshard_routine_task = asyncio.create_task(
+        reshard_routine(
+            redis=redis,
+            global_counters=global_counters,
+        )
+    )
+    routine_tasks.append(reshard_routine_task)
+
     try:
         for routine_id in range(1, args.routines + 1):
             counters = routines_counters[routine_id] = Counter()
@@ -111,6 +244,15 @@
             )
             routine_tasks.append(routine_task)
 
+        publish_routine_task = asyncio.create_task(
+            publish_routine(
+                redis=redis,
+                routines=list(range(1, args.routines + 1)),
+                global_counters=global_counters,
+            )
+        )
+        routine_tasks.append(publish_routine_task)
+
         logger.info("Routines %d", len(routine_tasks))
 
         await asyncio.sleep(args.wait)
@@ -147,7 +289,7 @@ def main() -> None:
     else:
         uvloop.install()
 
-    logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.DEBUG)
 
     loop = asyncio.get_event_loop()
     main_task = loop.create_task(async_main(args))
@@ -161,7 +303,7 @@
     finally:
         if not main_task.done() and not main_task.cancelled():
             main_task.cancel()
-            loop.run_until_complete(asyncio.wait([main_task]))
+            loop.run_until_complete(main_task)
 
         loop.close()
 
diff --git a/src/aioredis_cluster/cluster_state.py b/src/aioredis_cluster/cluster_state.py
index 4c34964..ab20eb6 100644
--- a/src/aioredis_cluster/cluster_state.py
+++ b/src/aioredis_cluster/cluster_state.py
@@ -212,6 +212,9 @@ def random_node(self) -> ClusterNode:
     def has_addr(self, addr: Address) -> bool:
         return addr in self._data.nodes
 
+    def addr_node(self, addr: Address) -> ClusterNode:
+        return self._data.nodes[addr]
+
     def master_replicas(self, addr: Address) -> List[ClusterNode]:
         try:
             return list(self._data.replicas[addr])
diff --git a/src/aioredis_cluster/commands/cluster.py b/src/aioredis_cluster/commands/cluster.py
index 7674888..8dcb656 100644
--- a/src/aioredis_cluster/commands/cluster.py
+++ b/src/aioredis_cluster/commands/cluster.py
@@ -105,10 +105,25 @@ def cluster_set_config_epoch(self, config_epoch: int):
         fut = self.execute(b"CLUSTER", b"SET-CONFIG-EPOCH", config_epoch)
         return wait_ok(fut)
 
-    def cluster_setslot(self, slot: int, command, node_id: str = None):
+    def cluster_setslot(self, slot: int, subcommand: str, node_id: str = None):
         """Bind a hash slot to specified node."""
-        raise NotImplementedError()
+        subcommand = subcommand.upper()
+        if subcommand in {"NODE", "IMPORTING", "MIGRATING"}:
+            if not node_id:
+                raise ValueError(f"For subcommand {subcommand} node_id must be provided")
+        elif subcommand in {"STABLE"}:
+            if node_id:
+                raise ValueError("For subcommand STABLE node_id must not be provided")
+        else:
+            raise ValueError(f"Unknown subcommand {subcommand}")
+
+        extra = []
+        if node_id:
+            extra.append(node_id)
+
+        fut = self.execute(b"CLUSTER", b"SETSLOT", slot, subcommand, *extra)
+        return fut
 
     def cluster_slaves(self, node_id: str):
         """List slave nodes of the specified master node."""
diff --git a/src/aioredis_cluster/connection.py b/src/aioredis_cluster/connection.py
index ab2917d..3c79a96 100644
--- a/src/aioredis_cluster/connection.py
+++ b/src/aioredis_cluster/connection.py
@@ -528,11 +528,18 @@ async def _read_data(self) -> None:
                 if self._server_in_pubsub:
                     if isinstance(obj, MovedError):
-                        logger.warning(
-                            "Received MOVED in PubSub mode. Unsubscribe all channels from %d slot",
-                            obj.info.slot_id,
-                        )
-                        self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id)
+                        if self._pubsub_store.have_slot_channels(obj.info.slot_id):
+                            logger.warning(
+                                (
+                                    "Received MOVED in PubSub mode from %s to %s:%s. "
+                                    "Unsubscribe all channels from %d slot"
+                                ),
+                                self.address,
+                                obj.info.host,
+                                obj.info.port,
+                                obj.info.slot_id,
+                            )
+                            self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id)
                     elif isinstance(obj, RedisError):
                         raise obj
                     else:
@@ -542,7 +549,13 @@
                     if isinstance(obj, MovedError):
                         if self._pubsub_store.have_slot_channels(obj.info.slot_id):
                             logger.warning(
-                                "Received MOVED. Unsubscribe all channels from %d slot",
+                                (
+                                    "Received MOVED from %s to %s:%s. "
+                                    "Unsubscribe all channels from %d slot"
+                                ),
+                                self.address,
+                                obj.info.host,
+                                obj.info.port,
                                 obj.info.slot_id,
                             )
                             self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id)
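
Usage sketch (illustrative, not part of the patch): the snippet below strings together the sharded Pub/Sub commands that the test script above exercises. keys_master() resolves the master that owns the channel's hash slot, then ssubscribe/spublish/sunsubscribe move one message through it. It assumes a cluster reachable at redis://localhost:7000; the channel name and payload are placeholders.

import asyncio

from aioredis_cluster import create_redis_cluster


async def demo() -> None:
    # Assumed startup node; any reachable cluster node works here.
    redis = await create_redis_cluster(["redis://localhost:7000"])
    try:
        channel_name = "ch:demo:{shard}"

        # SSUBSCRIBE must be sent to the master that owns the channel's hash slot.
        pool = await redis.keys_master(channel_name)
        ch = (await pool.ssubscribe(channel_name))[0]

        # SPUBLISH is routed by the same hash slot as the channel name.
        await redis.spublish(channel_name, "hello")

        # Read one message from the sharded channel.
        msg = await ch.get()
        print(msg)

        await pool.sunsubscribe(channel_name)
    finally:
        redis.close()
        await redis.wait_closed()


if __name__ == "__main__":
    asyncio.run(demo())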