Skip to content

Commit

Permalink
Optimize gateway transport (#1898)
Browse files Browse the repository at this point in the history
  • Loading branch information
beagold authored Apr 28, 2024
1 parent 1f50bb1 commit 2a69b60
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 91 deletions.
3 changes: 3 additions & 0 deletions changes/1898.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
Optimize gateway transport
- Merge the cold path for zlib compression into the main path to avoid an additional function call
- Handle data as `bytes` rather than `str` to make good use of speedups (similar to `RESTClient`)
2 changes: 1 addition & 1 deletion hikari/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class GatewayTransportError(GatewayError):
"""An exception thrown if an issue occurs at the transport layer."""

def __str__(self) -> str:
return f"Gateway transport error: {self.reason!r}"
return f"Gateway transport error: {self.reason}"


@attrs.define(auto_exc=True, repr=False, slots=False)
Expand Down
47 changes: 24 additions & 23 deletions hikari/impl/shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,9 @@
_CUSTOM_STATUS_NAME = "Custom Status"


def _log_filterer(token: str) -> typing.Callable[[str], str]:
def filterer(entry: str) -> str:
return entry.replace(token, "**REDACTED TOKEN**")
def _log_filterer(token: bytes) -> typing.Callable[[bytes], bytes]:
def filterer(entry: bytes) -> bytes:
return entry.replace(token, b"**REDACTED TOKEN**")

return filterer

Expand Down Expand Up @@ -153,7 +153,7 @@ def __init__(
transport_compression: bool,
exit_stack: contextlib.AsyncExitStack,
logger: logging.Logger,
log_filterer: typing.Callable[[str], str],
log_filterer: typing.Callable[[bytes], bytes],
dumps: data_binding.JSONEncoder,
loads: data_binding.JSONDecoder,
) -> None:
Expand Down Expand Up @@ -203,7 +203,7 @@ async def receive_json(self) -> typing.Any:
async def send_json(self, data: data_binding.JSONObject) -> None:
pl = self._dumps(data)
if self._logger.isEnabledFor(ux.TRACE):
filtered = self._log_filterer(pl.decode("utf-8"))
filtered = self._log_filterer(pl)
self._logger.log(ux.TRACE, "sending payload with size %s\n %s", len(pl), filtered)

await self._ws.send_bytes(pl)
Expand Down Expand Up @@ -232,39 +232,40 @@ def _handle_other_message(self, message: aiohttp.WSMessage, /) -> typing.NoRetur
reason = f"{message.data!r} [extra={message.extra!r}, type={message.type}]"
raise errors.GatewayTransportError(reason) from self._ws.exception()

async def _receive_and_check_text(self) -> str:
async def _receive_and_check_text(self) -> bytes:
message = await self._ws.receive()

if message.type == aiohttp.WSMsgType.TEXT:
assert isinstance(message.data, str)
return message.data
return message.data.encode()

self._handle_other_message(message)

async def _receive_and_check_zlib(self) -> str:
async def _receive_and_check_zlib(self) -> bytes:
message = await self._ws.receive()

if message.type == aiohttp.WSMsgType.BINARY:
if message.data.endswith(_ZLIB_SUFFIX):
return self._zlib.decompress(message.data).decode("utf-8")

return await self._receive_and_check_complete_zlib_package(message.data)
# Hot and fast path: we already have the full message
# in a single frame
return self._zlib.decompress(message.data)

self._handle_other_message(message)
# Cold and slow path: we need to keep receiving frames to complete
# the whole message. Only then do we create a buffer
buff = bytearray(message.data)

async def _receive_and_check_complete_zlib_package(self, initial_data: bytes, /) -> str:
buff = bytearray(initial_data)
while not buff.endswith(_ZLIB_SUFFIX):
message = await self._ws.receive()

while not buff.endswith(_ZLIB_SUFFIX):
message = await self._ws.receive()
if message.type == aiohttp.WSMsgType.BINARY:
buff.extend(message.data)
continue

if message.type == aiohttp.WSMsgType.BINARY:
buff.extend(message.data)
continue
self._handle_other_message(message)

self._handle_other_message(message)
return self._zlib.decompress(buff)

return self._zlib.decompress(buff).decode("utf-8")
self._handle_other_message(message)

@classmethod
async def connect(
Expand All @@ -273,7 +274,7 @@ async def connect(
http_settings: config.HTTPSettings,
logger: logging.Logger,
proxy_settings: config.ProxySettings,
log_filterer: typing.Callable[[str], str],
log_filterer: typing.Callable[[bytes], bytes],
dumps: data_binding.JSONEncoder,
loads: data_binding.JSONDecoder,
transport_compression: bool,
Expand Down Expand Up @@ -810,7 +811,7 @@ async def _connect(self) -> typing.Tuple[asyncio.Task[None], ...]:

self._ws = await _GatewayTransport.connect(
http_settings=self._http_settings,
log_filterer=_log_filterer(self._token),
log_filterer=_log_filterer(self._token.encode()),
logger=self._logger,
proxy_settings=self._proxy_settings,
transport_compression=self._transport_compression,
Expand Down
103 changes: 36 additions & 67 deletions tests/hikari/impl/test_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,11 @@


def test_log_filterer():
filterer = shard._log_filterer("TOKEN")
filterer = shard._log_filterer(b"TOKEN")

returned = filterer("this log contains the TOKEN and it should get removed and the TOKEN here too")
returned = filterer(b"this log contains the TOKEN and it should get removed and the TOKEN here too")
assert returned == (
"this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too"
b"this log contains the **REDACTED TOKEN** and it should get removed and the **REDACTED TOKEN** here too"
)


Expand Down Expand Up @@ -275,100 +275,69 @@ def test__handle_other_message_when_message_type_is_unknown(self, transport_impl
assert exc_info.value.__cause__ is exception

@pytest.mark.asyncio
async def test__receive_and_check_text_when_message_type_is_TEXT(self, transport_impl):
async def test__receive_and_check_text(self, transport_impl):
transport_impl._ws.receive = mock.AsyncMock(
return_value=StubResponse(type=aiohttp.WSMsgType.TEXT, data="some text")
)

assert await transport_impl._receive_and_check_text() == "some text"
assert await transport_impl._receive_and_check_text() == b"some text"

transport_impl._ws.receive.assert_awaited_once_with()

@pytest.mark.asyncio
async def test__receive_and_check_text_when_message_type_is_unknown(self, transport_impl):
mock_exception = errors.GatewayError("aye")
transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.BINARY))

with mock.patch.object(
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
) as handle_other_message:
with pytest.raises(errors.GatewayError) as exc_info:
await transport_impl._receive_and_check_text()
with pytest.raises(
errors.GatewayTransportError,
match="Gateway transport error: Unexpected message type received BINARY, expected TEXT",
):
await transport_impl._receive_and_check_text()

assert exc_info.value is mock_exception
transport_impl._ws.receive.assert_awaited_once_with()
handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value)

@pytest.mark.asyncio
async def test__receive_and_check_zlib_when_message_type_is_BINARY(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data")
transport_impl._ws.receive = mock.AsyncMock(return_value=response)
async def test__receive_and_check_zlib_when_payload_split_across_frames(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\xc9W(\xcf/\xcaIQ\x04\x00\x00")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff")
transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3])

with mock.patch.object(
shard._GatewayTransport, "_receive_and_check_complete_zlib_package"
) as receive_and_check_complete_zlib_package:
assert (
await transport_impl._receive_and_check_zlib() is receive_and_check_complete_zlib_package.return_value
)
assert await transport_impl._receive_and_check_zlib() == b"Hello world!"

transport_impl._ws.receive.assert_awaited_once_with()
receive_and_check_complete_zlib_package.assert_awaited_once_with(b"some initial data")
assert transport_impl._ws.receive.call_count == 3

@pytest.mark.asyncio
async def test__receive_and_check_zlib_when_message_type_is_BINARY_and_the_full_payload(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"some initial data\x00\x00\xff\xff")
async def test__receive_and_check_zlib_when_full_payload_in_one_frame(self, transport_impl):
response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xdaJLD\x07\x00\x00\x00\x00\xff\xff")
transport_impl._ws.receive = mock.AsyncMock(return_value=response)
transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"aaaaaaaaaaaaaaaaaa"))

assert await transport_impl._receive_and_check_zlib() == "aaaaaaaaaaaaaaaaaa"
assert await transport_impl._receive_and_check_zlib() == b"aaaaaaaaaaaaaaaaaa"

transport_impl._ws.receive.assert_awaited_once_with()
transport_impl._zlib.decompress.assert_called_once_with(response.data)

@pytest.mark.asyncio
async def test__receive_and_check_zlib_when_message_type_is_unknown(self, transport_impl):
mock_exception = errors.GatewayError("aye")
transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.TEXT))

with mock.patch.object(
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
) as handle_other_message:
with pytest.raises(errors.GatewayError) as exc_info:
await transport_impl._receive_and_check_zlib()

assert exc_info.value is mock_exception
transport_impl._ws.receive.assert_awaited_once_with()
handle_other_message.assert_called_once_with(transport_impl._ws.receive.return_value)

@pytest.mark.asyncio
async def test__receive_and_check_complete_zlib_package_for_unexpected_message_type(self, transport_impl):
mock_exception = errors.GatewayError("aye")
response = StubResponse(type=aiohttp.WSMsgType.TEXT)
transport_impl._ws.receive = mock.AsyncMock(return_value=response)

with mock.patch.object(
shard._GatewayTransport, "_handle_other_message", side_effect=mock_exception
) as handle_other_message:
with pytest.raises(errors.GatewayError) as exc_info:
await transport_impl._receive_and_check_complete_zlib_package(b"some")

assert exc_info.value is mock_exception
transport_impl._ws.receive.assert_awaited_with()
handle_other_message.assert_called_once_with(response)
with pytest.raises(
errors.GatewayTransportError,
match="Gateway transport error: Unexpected message type received TEXT, expected BINARY",
):
await transport_impl._receive_and_check_zlib()

@pytest.mark.asyncio
async def test__receive_and_check_complete_zlib_package(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"more")
response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"data")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\x00\xff\xff")
async def test__receive_and_check_zlib_when_issue_during_reception_of_multiple_frames(self, transport_impl):
response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9")
response2 = StubResponse(type=aiohttp.WSMsgType.ERROR, data="Something broke!")
response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff")
transport_impl._ws.receive = mock.AsyncMock(side_effect=[response1, response2, response3])
transport_impl._zlib = mock.Mock(decompress=mock.Mock(return_value=b"decoded utf-8 encoded bytes"))

assert await transport_impl._receive_and_check_complete_zlib_package(b"some") == "decoded utf-8 encoded bytes"
transport_impl._ws.exception = mock.Mock(return_value=None)

assert transport_impl._ws.receive.call_count == 3
transport_impl._ws.receive.assert_has_awaits([mock.call(), mock.call(), mock.call()])
transport_impl._zlib.decompress.assert_called_once_with(bytearray(b"somemoredata\x00\x00\xff\xff"))
with pytest.raises(
errors.GatewayTransportError, match=r"Gateway transport error: 'Something broke!' \[extra=None, type=258\]"
):
await transport_impl._receive_and_check_zlib()

@pytest.mark.parametrize("transport_compression", [True, False])
@pytest.mark.asyncio
Expand Down Expand Up @@ -1002,7 +971,7 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy
with stack:
assert await client._connect() == (heartbeat_task, poll_events_task)

log_filterer.assert_called_once_with("sometoken")
log_filterer.assert_called_once_with(b"sometoken")
gateway_transport_connect.assert_called_once_with(
http_settings=http_settings,
log_filterer=log_filterer.return_value,
Expand Down Expand Up @@ -1087,7 +1056,7 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set
with stack:
assert await client._connect() == (heartbeat_task, poll_events_task)

log_filterer.assert_called_once_with("sometoken")
log_filterer.assert_called_once_with(b"sometoken")
gateway_transport_connect.assert_called_once_with(
http_settings=http_settings,
log_filterer=log_filterer.return_value,
Expand Down

0 comments on commit 2a69b60

Please sign in to comment.