From 2a69b60350772952492d19c6ff20d7fd7704dd3e Mon Sep 17 00:00:00 2001 From: beagold <86345081+beagold@users.noreply.github.com> Date: Sun, 28 Apr 2024 10:24:44 +0200 Subject: [PATCH] Optimize gateway transport (#1898) --- changes/1898.feature.md | 3 + hikari/errors.py | 2 +- hikari/impl/shard.py | 47 ++++++++------- tests/hikari/impl/test_shard.py | 103 +++++++++++--------------------- 4 files changed, 64 insertions(+), 91 deletions(-) create mode 100644 changes/1898.feature.md diff --git a/changes/1898.feature.md b/changes/1898.feature.md new file mode 100644 index 0000000000..d38c1e2aac --- /dev/null +++ b/changes/1898.feature.md @@ -0,0 +1,3 @@ +Optimize gateway transport +- Merge cold path for zlib compression into main path to avoid additional call +- Handle data in `bytes`, rather than in `str` to make good use of speedups (similar to `RESTClient`) diff --git a/hikari/errors.py b/hikari/errors.py index 085f5a296f..0b5a949441 100644 --- a/hikari/errors.py +++ b/hikari/errors.py @@ -182,7 +182,7 @@ class GatewayTransportError(GatewayError): """An exception thrown if an issue occurs at the transport layer.""" def __str__(self) -> str: - return f"Gateway transport error: {self.reason!r}" + return f"Gateway transport error: {self.reason}" @attrs.define(auto_exc=True, repr=False, slots=False) diff --git a/hikari/impl/shard.py b/hikari/impl/shard.py index df2694873e..3218ceedfc 100644 --- a/hikari/impl/shard.py +++ b/hikari/impl/shard.py @@ -117,9 +117,9 @@ _CUSTOM_STATUS_NAME = "Custom Status" -def _log_filterer(token: str) -> typing.Callable[[str], str]: - def filterer(entry: str) -> str: - return entry.replace(token, "**REDACTED TOKEN**") +def _log_filterer(token: bytes) -> typing.Callable[[bytes], bytes]: + def filterer(entry: bytes) -> bytes: + return entry.replace(token, b"**REDACTED TOKEN**") return filterer @@ -153,7 +153,7 @@ def __init__( transport_compression: bool, exit_stack: contextlib.AsyncExitStack, logger: logging.Logger, - log_filterer: typing.Callable[[str], str], + log_filterer: typing.Callable[[bytes], bytes], dumps: data_binding.JSONEncoder, loads: data_binding.JSONDecoder, ) -> None: @@ -203,7 +203,7 @@ async def receive_json(self) -> typing.Any: async def send_json(self, data: data_binding.JSONObject) -> None: pl = self._dumps(data) if self._logger.isEnabledFor(ux.TRACE): - filtered = self._log_filterer(pl.decode("utf-8")) + filtered = self._log_filterer(pl) self._logger.log(ux.TRACE, "sending payload with size %s\n %s", len(pl), filtered) await self._ws.send_bytes(pl) @@ -232,39 +232,40 @@ def _handle_other_message(self, message: aiohttp.WSMessage, /) -> typing.NoRetur reason = f"{message.data!r} [extra={message.extra!r}, type={message.type}]" raise errors.GatewayTransportError(reason) from self._ws.exception() - async def _receive_and_check_text(self) -> str: + async def _receive_and_check_text(self) -> bytes: message = await self._ws.receive() if message.type == aiohttp.WSMsgType.TEXT: assert isinstance(message.data, str) - return message.data + return message.data.encode() self._handle_other_message(message) - async def _receive_and_check_zlib(self) -> str: + async def _receive_and_check_zlib(self) -> bytes: message = await self._ws.receive() if message.type == aiohttp.WSMsgType.BINARY: if message.data.endswith(_ZLIB_SUFFIX): - return self._zlib.decompress(message.data).decode("utf-8") - - return await self._receive_and_check_complete_zlib_package(message.data) + # Hot and fast path: we already have the full message + # in a single frame + return self._zlib.decompress(message.data) - self._handle_other_message(message) + # Cold and slow path: we need to keep receiving frames to complete + # the whole message. Only then do we create a buffer + buff = bytearray(message.data) - async def _receive_and_check_complete_zlib_package(self, initial_data: bytes, /) -> str: - buff = bytearray(initial_data) + while not buff.endswith(_ZLIB_SUFFIX): + message = await self._ws.receive() - while not buff.endswith(_ZLIB_SUFFIX): - message = await self._ws.receive() + if message.type == aiohttp.WSMsgType.BINARY: + buff.extend(message.data) + continue - if message.type == aiohttp.WSMsgType.BINARY: - buff.extend(message.data) - continue + self._handle_other_message(message) - self._handle_other_message(message) + return self._zlib.decompress(buff) - return self._zlib.decompress(buff).decode("utf-8") + self._handle_other_message(message) @classmethod async def connect( @@ -273,7 +274,7 @@ async def connect( http_settings: config.HTTPSettings, logger: logging.Logger, proxy_settings: config.ProxySettings, - log_filterer: typing.Callable[[str], str], + log_filterer: typing.Callable[[bytes], bytes], dumps: data_binding.JSONEncoder, loads: data_binding.JSONDecoder, transport_compression: bool, @@ -810,7 +811,7 @@ async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]: self._ws = await _GatewayTransport.connect( http_settings=self._http_settings, - log_filterer=_log_filterer(self._token), + log_filterer=_log_filterer(self._token.encode()), logger=self._logger, proxy_settings=self._proxy_settings, transport_compression=self._transport_compression, diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index ee3bd0eda4..b8420f91d9 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -44,11 +44,11 @@ def test_log_filterer(): - filterer = shard._log_filterer("TOKEN") + filterer = shard._log_filterer(b"TOKEN") - returned = filterer("this log contains the TOKEN and it should get removed and the TOKEN here too") + returned = filterer(b"this log contains the TOKEN and it should get removed and the TOKEN here too") assert returned == ( - "this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too" + b"this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too" ) @@ -275,100 +275,69 @@ def test__handle_other_message_when_message_type_is_unknown(self, transport_impl assert exc_info.value.__cause__ is exception @pytest.mark.asyncio - async def test__receive_and_check_text_when_message_type_is_TEXT(self, transport_impl): + async def test__receive_and_check_text(self, transport_impl): transport_impl._ws.receive = mock.AsyncMock( return_value=StubResponse(type=aiohttp.WSMsgType.TEXT, data="some text") ) - assert await transport_impl._receive_and_check_text() == "some text" + assert await transport_impl._receive_and_check_text() == b"some text" transport_impl._ws.receive.assert_awaited_once_with() @pytest.mark.asyncio async def test__receive_and_check_text_when_message_type_is_unknown(self, transport_impl): - mock_exception = errors.GatewayError("aye") transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.BINARY)) - with mock.patch.object( - shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception - ) as handle_other_message: - with pytest.raises(errors.GatewayError) as exc_info: - await transport_impl._receive_and_check_text() + with pytest.raises( + errors.GatewayTransportError, + match="Gateway transport error: Unexpected message type received BINARY, expected TEXT", + ): + await transport_impl._receive_and_check_text() - assert exc_info.value is mock_exception transport_impl._ws.receive.assert_awaited_once_with() - handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value) @pytest.mark.asyncio - async def test__receive_and_check_zlib_when_message_type_is_BINARY(self, transport_impl): - response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data") - transport_impl._ws.receive = mock.AsyncMock(return_value=response) + async def test__receive_and_check_zlib_when_payload_split_across_frames(self, transport_impl): + response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9") + response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\xc9W(\xcf/\xcaIQ\x04\x00\x00") + response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff") + transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3]) - with mock.patch.object( - shard._GatewayTransport, "_receive_and_check_complete_zlib_package" - ) as receive_and_check_complete_zlib_package: - assert ( - await transport_impl._receive_and_check_zlib() is receive_and_check_complete_zlib_package.return_value - ) + assert await transport_impl._receive_and_check_zlib() == b"Hello world!" - transport_impl._ws.receive.assert_awaited_once_with() - receive_and_check_complete_zlib_package.assert_awaited_once_with(b"some initial data") + assert transport_impl._ws.receive.call_count == 3 @pytest.mark.asyncio - async def test__receive_and_check_zlib_when_message_type_is_BINARY_and_the_full_payload(self, transport_impl): - response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data\x00\x00\xff\xff") + async def test__receive_and_check_zlib_when_full_payload_in_one_frame(self, transport_impl): + response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xdaJLD\x07\x00\x00\x00\x00\xff\xff") transport_impl._ws.receive = mock.AsyncMock(return_value=response) - transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"aaaaaaaaaaaaaaaaaa")) - assert await transport_impl._receive_and_check_zlib() == "aaaaaaaaaaaaaaaaaa" + assert await transport_impl._receive_and_check_zlib() == b"aaaaaaaaaaaaaaaaaa" transport_impl._ws.receive.assert_awaited_once_with() - transport_impl._zlib.decompress.assert_called_once_with(response.data) @pytest.mark.asyncio async def test__receive_and_check_zlib_when_message_type_is_unknown(self, transport_impl): - mock_exception = errors.GatewayError("aye") transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.TEXT)) - with mock.patch.object( - shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception - ) as handle_other_message: - with pytest.raises(errors.GatewayError) as exc_info: - await transport_impl._receive_and_check_zlib() - - assert exc_info.value is mock_exception - transport_impl._ws.receive.assert_awaited_once_with() - handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value) - - @pytest.mark.asyncio - async def test__receive_and_check_complete_zlib_package_for_unexpected_message_type(self, transport_impl): - mock_exception = errors.GatewayError("aye") - response = StubResponse(type=aiohttp.WSMsgType.TEXT) - transport_impl._ws.receive = mock.AsyncMock(return_value=response) - - with mock.patch.object( - shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception - ) as handle_other_message: - with pytest.raises(errors.GatewayError) as exc_info: - await transport_impl._receive_and_check_complete_zlib_package(b"some") - - assert exc_info.value is mock_exception - transport_impl._ws.receive.assert_awaited_with() - handle_other_message.assert_called_once_with(response) + with pytest.raises( + errors.GatewayTransportError, + match="Gateway transport error: Unexpected message type received TEXT, expected BINARY", + ): + await transport_impl._receive_and_check_zlib() @pytest.mark.asyncio - async def test__receive_and_check_complete_zlib_package(self, transport_impl): - response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"more") - response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data") - response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff") + async def test__receive_and_check_zlib_when_issue_during_reception_of_multiple_frames(self, transport_impl): + response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9") + response2 = StubResponse(type=aiohttp.WSMsgType.ERROR, data="Something broke!") + response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff") transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3]) - transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"decoded utf-8 encoded bytes")) - - assert await transport_impl._receive_and_check_complete_zlib_package(b"some") == "decoded utf-8 encoded bytes" + transport_impl._ws.exception = mock.Mock(return_value=None) - assert transport_impl._ws.receive.call_count == 3 - transport_impl._ws.receive.assert_has_awaits([mock.call(), mock.call(), mock.call()]) - transport_impl._zlib.decompress.assert_called_once_with(bytearray(b"somemoredata\x00\x00\xff\xff")) + with pytest.raises( + errors.GatewayTransportError, match=r"Gateway transport error: 'Something broke!' \[extra=None, type=258\]" + ): + await transport_impl._receive_and_check_zlib() @pytest.mark.parametrize("transport_compression", [True, False]) @pytest.mark.asyncio @@ -1002,7 +971,7 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy with stack: assert await client._connect() == (heartbeat_task, poll_events_task) - log_filterer.assert_called_once_with("sometoken") + log_filterer.assert_called_once_with(b"sometoken") gateway_transport_connect.assert_called_once_with( http_settings=http_settings, log_filterer=log_filterer.return_value, @@ -1087,7 +1056,7 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set with stack: assert await client._connect() == (heartbeat_task, poll_events_task) - log_filterer.assert_called_once_with("sometoken") + log_filterer.assert_called_once_with(b"sometoken") gateway_transport_connect.assert_called_once_with( http_settings=http_settings, log_filterer=log_filterer.return_value,