
Commit

fixes and improves
Anton Ilyushenkov committed Dec 12, 2023
1 parent 29e16ac commit 5187d54
Showing 4 changed files with 190 additions and 17 deletions.
160 changes: 151 additions & 9 deletions dev/test_pubsub_reshard.py
@@ -1,13 +1,16 @@
import argparse
import asyncio
import logging
import random
import signal
from collections import deque
from typing import Counter, Deque, Dict, Mapping, Optional
from itertools import cycle
from typing import Counter, Deque, Dict, Mapping, Optional, Sequence

from aioredis_cluster import RedisCluster, create_redis_cluster
from aioredis_cluster import Cluster, ClusterNode, RedisCluster, create_redis_cluster
from aioredis_cluster.aioredis import Channel
from aioredis_cluster.compat.asyncio import timeout
from aioredis_cluster.crc import key_slot

logger = logging.getLogger(__name__)

@@ -18,8 +21,13 @@ async def tick_log(
global_counters: Counter[str],
) -> None:
count = 0
last = False
while True:
await asyncio.sleep(tick)
try:
await asyncio.sleep(tick)
except asyncio.CancelledError:
last = True

count += 1
logger.info("tick %d", count)
logger.info("tick %d: %r", count, global_counters)
@@ -31,6 +39,17 @@ async def tick_log(
for routine_id, counters in routines:
logger.info("tick %d: %s: %r", count, routine_id, counters)

if last:
break


def get_channel_name(routine_id: int) -> str:
return f"ch:{routine_id}:{{shard}}"


class ChannelNotClosedError(Exception):
pass


async def subscribe_routine(
*,
@@ -40,15 +59,15 @@ async def subscribe_routine(
global_counters: Counter[str],
):
await asyncio.sleep(0.5)
ch_name = f"ch:{routine_id}:{{shard}}"
ch_name = get_channel_name(routine_id)
while True:
counters["routine:subscribes"] += 1
prev_ch: Optional[Channel] = None
try:
pool = await redis.keys_master(ch_name)
global_counters["subscribe_in_fly"] += 1
try:
ch: Channel = (await pool.subscribe(ch_name))[0]
ch: Channel = (await pool.ssubscribe(ch_name))[0]
if prev_ch is not None and ch is prev_ch:
logger.error("%s: Previous Channel is current: %r", routine_id, ch)
prev_ch = ch
@@ -67,10 +86,13 @@
else:
counters["msg:received"] += 1
global_counters["msg:received"] += 1
await pool.unsubscribe(ch_name)
await pool.sunsubscribe(ch_name)
finally:
global_counters["subscribe_in_fly"] -= 1
assert ch.is_active is False

if not ch._queue.closed:
raise ChannelNotClosedError()

except asyncio.CancelledError:
break
except Exception as e:
@@ -81,6 +103,108 @@
logger.error("%s: Channel exception: %r", routine_id, e)


async def publish_routine(
*,
redis: RedisCluster,
routines: Sequence[int],
global_counters: Counter[str],
):
routines_cycle = cycle(routines)
for routine_id in routines_cycle:
await asyncio.sleep(random.uniform(0.001, 0.1))
channel_name = get_channel_name(routine_id)
await redis.spublish(channel_name, f"msg:{routine_id}")
global_counters["published_msgs"] += 1


async def cluster_move_slot(
*,
slot: int,
node_src: RedisCluster,
node_src_info: ClusterNode,
node_dest: RedisCluster,
node_dest_info: ClusterNode,
) -> None:
await node_dest.cluster_setslot(slot, "IMPORTING", node_src_info.node_id)
await node_src.cluster_setslot(slot, "MIGRATING", node_dest_info.node_id)
while True:
keys = await node_src.cluster_get_keys_in_slots(slot, 100)
if not keys:
break
await node_src.migrate_keys(
node_dest_info.addr.host,
node_dest_info.addr.port,
keys,
0,
5000,
)
await node_dest.cluster_setslot(slot, "NODE", node_dest_info.node_id)
await node_src.cluster_setslot(slot, "NODE", node_dest_info.node_id)


async def reshard_routine(
*,
redis: RedisCluster,
global_counters: Counter[str],
):
cluster: Cluster = redis.connection
cluster_state = await cluster.get_cluster_state()

moving_slot = key_slot(b"shard")
master1 = await cluster.keys_master(b"shard")

# search another master
count = 0
while True:
count += 1
key = f"key:{count}"
master2 = await cluster.keys_master(key)
if master2.address != master1.address:
break

master1_info = cluster_state.addr_node(master1.address)
master2_info = cluster_state.addr_node(master2.address)
logger.info("Master1 info - %r", master1_info)
logger.info("Master2 info - %r", master2_info)

node_src = master1
node_src_info = master1_info
node_dest = master2
node_dest_info = master2_info

while True:
await asyncio.sleep(random.uniform(2.0, 5.0))
global_counters["reshards"] += 1
logger.info(
"Start moving slot %s from %s to %s", moving_slot, node_src.address, node_dest.address
)
move_slot_task = asyncio.create_task(
cluster_move_slot(
slot=moving_slot,
node_src=node_src,
node_src_info=node_src_info,
node_dest=node_dest,
node_dest_info=node_dest_info,
)
)
try:
await asyncio.shield(move_slot_task)
except asyncio.CancelledError:
await move_slot_task
raise
except Exception as e:
logger.exception("Unexpected error on reshard: %r", e)
break

# swap nodes
node_src, node_src_info, node_dest, node_dest_info = (
node_dest,
node_dest_info,
node_src,
node_src_info,
)


async def async_main(args: argparse.Namespace) -> None:
loop = asyncio.get_event_loop()
node_addr: str = args.node
@@ -98,6 +222,15 @@ async def async_main(args: argparse.Namespace) -> None:
follow_cluster=True,
)
routine_tasks: Deque[asyncio.Task] = deque()

reshard_routine_task = asyncio.create_task(
reshard_routine(
redis=redis,
global_counters=global_counters,
)
)
routine_tasks.append(reshard_routine_task)

try:
for routine_id in range(1, args.routines + 1):
counters = routines_counters[routine_id] = Counter()
@@ -111,6 +244,15 @@
)
routine_tasks.append(routine_task)

publish_routine_task = asyncio.create_task(
publish_routine(
redis=redis,
routines=list(range(1, args.routines + 1)),
global_counters=global_counters,
)
)
routine_tasks.append(publish_routine_task)

logger.info("Routines %d", len(routine_tasks))

await asyncio.sleep(args.wait)
@@ -147,7 +289,7 @@ def main() -> None:
else:
uvloop.install()

logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)

loop = asyncio.get_event_loop()
main_task = loop.create_task(async_main(args))
@@ -161,7 +303,7 @@ def main() -> None:
finally:
if not main_task.done() and not main_task.cancelled():
main_task.cancel()
loop.run_until_complete(asyncio.wait([main_task]))
loop.run_until_complete(main_task)
loop.close()


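The new routines exercise sharded Pub/Sub end to end. Below is a minimal sketch of that round-trip outside the test harness, assuming a local cluster reachable at redis://localhost:7000; the address, channel name, and single-message flow are illustrative only, not part of this commit.

import asyncio

from aioredis_cluster import create_redis_cluster


async def main() -> None:
    # assumed startup node of a local test cluster
    redis = await create_redis_cluster(["redis://localhost:7000"])
    try:
        ch_name = "ch:1:{shard}"
        # keys_master() returns a client pinned to the master that owns the channel's slot
        pool = await redis.keys_master(ch_name)
        ch = (await pool.ssubscribe(ch_name))[0]

        await redis.spublish(ch_name, "hello")
        print(await ch.get())

        await pool.sunsubscribe(ch_name)
    finally:
        redis.close()
        await redis.wait_closed()


asyncio.run(main())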
3 changes: 3 additions & 0 deletions src/aioredis_cluster/cluster_state.py
@@ -212,6 +212,9 @@ def random_node(self) -> ClusterNode:
def has_addr(self, addr: Address) -> bool:
return addr in self._data.nodes

def addr_node(self, addr: Address) -> ClusterNode:
return self._data.nodes[addr]

def master_replicas(self, addr: Address) -> List[ClusterNode]:
try:
return list(self._data.replicas[addr])
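A short usage sketch for the new ClusterState.addr_node() helper, mirroring how reshard_routine in the test script resolves node metadata; the function name and key are placeholders.

from aioredis_cluster import RedisCluster


async def show_key_owner(redis: RedisCluster, key: str) -> None:
    cluster = redis.connection
    state = await cluster.get_cluster_state()
    master = await cluster.keys_master(key)
    # addr_node() is a plain dict lookup, so it raises KeyError for an
    # address that is not part of the currently known cluster state
    node_info = state.addr_node(master.address)
    print(node_info.node_id, node_info.addr)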
19 changes: 17 additions & 2 deletions src/aioredis_cluster/commands/cluster.py
@@ -105,10 +105,25 @@ def cluster_set_config_epoch(self, config_epoch: int):
fut = self.execute(b"CLUSTER", b"SET-CONFIG-EPOCH", config_epoch)
return wait_ok(fut)

def cluster_setslot(self, slot: int, command, node_id: str = None):
def cluster_setslot(self, slot: int, subcommand: str, node_id: str = None):
"""Bind a hash slot to specified node."""

raise NotImplementedError()
subcommand = subcommand.upper()
if subcommand in {"NODE", "IMPORTING", "MIGRATING"}:
if not node_id:
raise ValueError(f"For subcommand {subcommand} node_id must be provided")
elif subcommand in {"STABLE"}:
if node_id:
raise ValueError("For subcommand STABLE node_id is not required")
else:
raise ValueError(f"Unknown subcommand {subcommand}")

extra = []
if node_id:
extra.append(node_id)

fut = self.execute(b"CLUSTER", b"SETSLOT", slot, subcommand, *extra)
return fut

def cluster_slaves(self, node_id: str):
"""List slave nodes of the specified master node."""
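For reference, a sketch of the slot handover that the reworked cluster_setslot() supports, mirroring cluster_move_slot from the test script above; the client handles and node ids are placeholders and the key-migration step is elided.

from aioredis_cluster import RedisCluster


async def hand_over_slot(
    slot: int,
    src_node: RedisCluster,
    src_id: str,
    dest_node: RedisCluster,
    dest_id: str,
) -> None:
    # IMPORTING / MIGRATING / NODE each require the peer's node id...
    await dest_node.cluster_setslot(slot, "IMPORTING", src_id)
    await src_node.cluster_setslot(slot, "MIGRATING", dest_id)
    # ...MIGRATE the slot's remaining keys here, as cluster_move_slot does...
    await dest_node.cluster_setslot(slot, "NODE", dest_id)
    await src_node.cluster_setslot(slot, "NODE", dest_id)
    # "STABLE" is the one subcommand that must be called without a node id;
    # it cancels an in-progress IMPORTING/MIGRATING state instead.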
25 changes: 19 additions & 6 deletions src/aioredis_cluster/connection.py
@@ -528,11 +528,18 @@ async def _read_data(self) -> None:

if self._server_in_pubsub:
if isinstance(obj, MovedError):
logger.warning(
"Received MOVED in PubSub mode. Unsubscribe all channels from %d slot",
obj.info.slot_id,
)
self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id)
if self._pubsub_store.have_slot_channels(obj.info.slot_id):
logger.warning(
"Received MOVED in PubSub mode from %s to %s:%s. "
"Unsubscribe all channels from %d slot",
self.address,
obj.info.host,
obj.info.port,
obj.info.slot_id,
)
self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id)
elif isinstance(obj, RedisError):
raise obj
else:
Expand All @@ -542,7 +549,13 @@ async def _read_data(self) -> None:
if isinstance(obj, MovedError):
if self._pubsub_store.have_slot_channels(obj.info.slot_id):
logger.warning(
"Received MOVED. Unsubscribe all channels from %d slot",
"Received MOVED from %s to %s:%s. "
"Unsubscribe all channels from %d slot",
self.address,
obj.info.host,
obj.info.port,
obj.info.slot_id,
)
self._pubsub_store.slot_channels_unsubscribe(obj.info.slot_id)
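The slot-based unsubscribe above relies on all of the test's channel names sharing the {shard} hash tag, so they live in a single slot; a quick illustration with key_slot:

from aioredis_cluster.crc import key_slot

# only the text inside {...} is hashed, so every generated channel
# maps to the slot of the bare tag
assert key_slot(b"ch:1:{shard}") == key_slot(b"shard")
assert key_slot(b"ch:2:{shard}") == key_slot(b"shard")

# a MOVED for that slot in PubSub mode drops every channel
# registered under this one slot id
print(key_slot(b"shard"))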
