From 037238ac22a065e0a07045d0d63b6bf4ee4da574 Mon Sep 17 00:00:00 2001 From: CircuitSacul Date: Thu, 8 Sep 2022 06:30:21 +0000 Subject: [PATCH] be more lenient with the destination for events/commands (#50) --- hikari_clusters/brain.py | 2 +- hikari_clusters/info_classes.py | 3 +++ hikari_clusters/ipc_client.py | 45 +++++++++++++++++++-------------- 3 files changed, 30 insertions(+), 20 deletions(-) diff --git a/hikari_clusters/brain.py b/hikari_clusters/brain.py index 84a4d13..a8a3f83 100644 --- a/hikari_clusters/brain.py +++ b/hikari_clusters/brain.py @@ -221,7 +221,7 @@ async def _main_loop(self) -> None: server_uid, shards = to_launch await self.ipc.send_command( - [server_uid], + server_uid, "launch_cluster", {"shard_ids": shards, "shard_count": self.total_shards}, ) diff --git a/hikari_clusters/info_classes.py b/hikari_clusters/info_classes.py index 29b3673..038f91f 100644 --- a/hikari_clusters/info_classes.py +++ b/hikari_clusters/info_classes.py @@ -49,6 +49,9 @@ def fromdict(data: dict[str, Any]) -> BaseInfo: cls = BaseInfo._info_classes[data.pop("_info_class_id")] return cls(**data) + def __int__(self) -> int: + return self.uid + @dataclass class ServerInfo(BaseInfo): diff --git a/hikari_clusters/ipc_client.py b/hikari_clusters/ipc_client.py index 3fb9ed2..baf391b 100644 --- a/hikari_clusters/ipc_client.py +++ b/hikari_clusters/ipc_client.py @@ -27,7 +27,7 @@ import logging import pathlib import ssl -from typing import Any, Iterable, TypeVar, cast +from typing import Any, Iterable, TypeVar, Union, cast from websockets.exceptions import ConnectionClosed, ConnectionClosedOK from websockets.legacy import client @@ -46,6 +46,12 @@ __all__ = ("IpcClient",) +_TO = Union[Iterable[Union[BaseInfo, int]], BaseInfo, int] + + +def _parse_to(to: _TO) -> Iterable[int]: + return map(int, to) if isinstance(to, Iterable) else [int(to)] + class IpcClient(IpcBase): """A connection to a :class:`~ipc_server.IpcServer`. @@ -186,23 +192,21 @@ def _stop(*args: Any, **kwargs: Any) -> None: self.tasks.create_task(self._start()).add_done_callback(_stop) - async def send_not_found_response( - self, to: Iterable[int], callback: int - ) -> None: + async def send_not_found_response(self, to: _TO, callback: int) -> None: """Respond to a command saying that the command was not found. Parameters ---------- - to : Iterable[int] + to : Iterable[int | BaseInfo] The clients to send the response to. callback : int The command callback (:attr:`~payload.Command.callback`) """ - await self._send(to, payload.ResponseNotFound(callback)) + await self._send(_parse_to(to), payload.ResponseNotFound(callback)) async def send_ok_response( - self, to: Iterable[int], callback: int, data: payload.DATA = None + self, to: _TO, callback: int, data: payload.DATA = None ) -> None: """Respond that the command *function* finished without any problems. @@ -210,7 +214,7 @@ async def send_ok_response( Parameters ---------- - to : Iterable[int] + to : Iterable[int | BaseInfo] The clients to send the response to. callback : int The command callback (:attr:`~payload.Command.callback`) @@ -218,16 +222,14 @@ async def send_ok_response( The data to send with the response, by default None """ - await self._send(to, payload.ResponseOk(callback, data)) + await self._send(_parse_to(to), payload.ResponseOk(callback, data)) - async def send_tb_response( - self, to: Iterable[int], callback: int, tb: str - ) -> None: + async def send_tb_response(self, to: _TO, callback: int, tb: str) -> None: """Respond that the command function raised an exception. Parameters ---------- - to : Iterable[int] + to : Iterable[int | BaseInfo] The clients to send the response to. callback : int The command callback (:attr:`~payload.Command.callback`) @@ -235,10 +237,12 @@ async def send_tb_response( The exception traceback. """ - await self._send(to, payload.ResponseTraceback(callback, tb)) + await self._send( + _parse_to(to), payload.ResponseTraceback(callback, tb) + ) async def send_event( - self, to: Iterable[int], name: str, data: payload.DATA = None + self, to: _TO, name: str, data: payload.DATA = None ) -> None: """Dispatch an event. @@ -254,11 +258,11 @@ async def send_event( The data to send with the event, by default None """ - await self._send(to, payload.Event(name, data)) + await self._send(_parse_to(to), payload.Event(name, data)) async def send_command( self, - to: Iterable[int], + to: _TO, name: str, data: payload.DATA = None, timeout: float = 3.0, @@ -267,7 +271,7 @@ async def send_command( Parameters ---------- - to : Iterable[int] + to : Iterable[int | BaseInfo] The clients to send the command to. name : str The name of the command. @@ -283,6 +287,7 @@ async def send_command( if any. """ + to = _parse_to(to) with self.callbacks.callback(to) as cb: await self._send(to, payload.Command(name, cb.key, data)) await cb.wait(timeout) @@ -363,7 +368,9 @@ async def _send( self, to: Iterable[int], pl_data: payload.PAYLOAD_DATA ) -> None: assert self.uid is not None - pl = payload.Payload(pl_data.opcode, self.uid, list(to), pl_data) + pl = payload.Payload( + pl_data.opcode, self.uid, list(map(int, to)), pl_data + ) await self._raw_send(json.dumps(pl.serialize())) async def _raw_send(self, msg: str) -> None: