diff --git a/runtimepy/net/stream/json.py b/runtimepy/net/stream/json.py index 8ae7b6a4..4b06e68f 100644 --- a/runtimepy/net/stream/json.py +++ b/runtimepy/net/stream/json.py @@ -6,7 +6,7 @@ # built-in from copy import copy -from json import dumps, loads +from json import JSONDecodeError, dumps, loads from typing import Any, Awaitable, Callable, Dict, Tuple, Type, TypeVar, Union # third-party @@ -14,6 +14,7 @@ # internal from runtimepy.net.stream.string import StringMessageConnection +from runtimepy.net.udp import UdpConnection JsonMessage = Dict[str, Any] @@ -33,6 +34,7 @@ TypedHandler = Callable[[JsonMessage, T], Awaitable[None]] DEFAULT_LOOPBACK = {"a": 1, "b": 2, "c": 3} +DEFAULT_TIMEOUT = 3 async def loopback_handler(outbox: JsonMessage, inbox: JsonMessage) -> None: @@ -41,6 +43,19 @@ async def loopback_handler(outbox: JsonMessage, inbox: JsonMessage) -> None: outbox.update(inbox) +async def event_wait(event: asyncio.Event, timeout: float) -> bool: + """Wait for an event to be set within a timeout.""" + + result = True + + try: + await asyncio.wait_for(event.wait(), timeout) + except asyncio.TimeoutError: + result = False + + return result + + class JsonMessageConnection(StringMessageConnection): """A connection interface for JSON messaging.""" @@ -107,7 +122,10 @@ def send_json( self.send_message_str(dumps(data, separators=(",", ":")), addr=addr) async def wait_json( - self, data: Union[JsonMessage, JsonCodec], addr: Tuple[str, int] = None + self, + data: Union[JsonMessage, JsonCodec], + addr: Tuple[str, int] = None, + timeout: float = DEFAULT_TIMEOUT, ) -> JsonMessage: """Send a JSON message and wait for a response.""" @@ -128,21 +146,30 @@ async def wait_json( # Send message and await response. self.send_json(data, addr=addr) - await got_response.wait() + + assert await event_wait( + got_response, timeout + ), f"No response received in {timeout} seconds!" # Return the result. result = self.id_responses[ident] del self.id_responses[ident] + return result - async def loopback(self, data: JsonMessage = None) -> bool: + async def loopback( + self, + data: JsonMessage = None, + addr: Tuple[str, int] = None, + timeout: float = DEFAULT_TIMEOUT, + ) -> bool: """Perform a simple loopback test on this connection.""" if data is None: data = DEFAULT_LOOPBACK message = {"loopback": data} - response = await self.wait_json(message) + response = await self.wait_json(message, addr=addr, timeout=timeout) status = response == message self.logger.info( @@ -153,6 +180,22 @@ async def loopback(self, data: JsonMessage = None) -> bool: return status + async def async_init(self) -> bool: + """A runtime initialization routine (executes during 'process').""" + + # Check loopback if it makes sense to. + result = await super().async_init() + + # Only not-connected UDP connections can't do this. + if ( + result + and hasattr("self", "remote_address") + or not isinstance(self, UdpConnection) + ): + result = await self.loopback() + + return result + async def process_json( self, data: JsonMessage, addr: Tuple[str, int] = None ) -> bool: @@ -217,11 +260,15 @@ async def process_message( """Process a string message.""" result = True - decoded = loads(data) - if decoded and isinstance(decoded, dict): - result = await self.process_json(decoded, addr=addr) - else: - self.logger.error("Ignoring message '%s'.", data) + try: + decoded = loads(data) + + if decoded and isinstance(decoded, dict): + result = await self.process_json(decoded, addr=addr) + else: + self.logger.error("Ignoring message '%s'.", data) + except JSONDecodeError as exc: + self.logger.exception("Couldn't decode '%s': %s", data, exc) return result diff --git a/tests/net/stream/__init__.py b/tests/net/stream/__init__.py index af9657fd..26efb61b 100644 --- a/tests/net/stream/__init__.py +++ b/tests/net/stream/__init__.py @@ -5,10 +5,13 @@ # built-in import asyncio +# third-party +from vcorelib.dict.codec import BasicDictCodec + # module under test from runtimepy.net.arbiter.info import AppInfo from runtimepy.net.stream import StringMessageConnection -from runtimepy.net.stream.json import JsonMessageConnection +from runtimepy.net.stream.json import JsonMessage, JsonMessageConnection async def stream_test(app: AppInfo) -> int: @@ -32,13 +35,19 @@ async def json_client_test(client: JsonMessageConnection) -> int: """Test a single JSON client.""" client.send_json({}) - await client.wait_json({}) assert await client.wait_json({"unknown": 0, "command": 1}) == { "keys_ignored": ["command", "unknown"] } + codec = BasicDictCodec.create({"a": 1, "b": 2, "c": 3}) + client.send_json(codec) + assert await client.wait_json(codec) == {"keys_ignored": ["a", "b", "c"]} + + # Should trigger decode error. + client.send_message_str("{hello") + # Test loopback. assert await client.loopback() assert await client.loopback(data={"a": 1, "b": 2, "c": 3}) @@ -49,14 +58,40 @@ async def json_client_test(client: JsonMessageConnection) -> int: async def json_test(app: AppInfo) -> int: """Test JSON clients in parallel.""" + # Add typed handler for UDP server connection. + udp_server = app.single( + pattern="udp_json_server", kind=JsonMessageConnection + ) + + async def typed_handler( + response: JsonMessage, data: BasicDictCodec + ) -> None: + """An example handler.""" + + response["it_worked"] = True + response.update(data.asdict()) + + # Test handler. + udp_server.typed_handler("test", BasicDictCodec, typed_handler) + + udp_client = app.single( + pattern="udp_json_client", kind=JsonMessageConnection + ) + + result = await udp_client.wait_json({"test": {"a": 1, "b": 2, "c": 3}}) + result = result["test"] + assert "it_worked" in result + assert result["it_worked"] is True, result + assert result["a"] == 1, result + assert result["b"] == 2, result + assert result["c"] == 3, result + return sum( await asyncio.gather( *[ json_client_test(client) for client in [ - app.single( - pattern="udp_json_client", kind=JsonMessageConnection - ), + udp_client, app.single(pattern="tcp_json", kind=JsonMessageConnection), app.single( pattern="websocket_json", kind=JsonMessageConnection diff --git a/tests/net/stream/test_stream.py b/tests/net/stream/test_stream.py new file mode 100644 index 00000000..8667f098 --- /dev/null +++ b/tests/net/stream/test_stream.py @@ -0,0 +1,20 @@ +""" +Test the 'net.stream' module. +""" + +# built-in +import asyncio + +# third-party +from pytest import mark + +# module under test +from runtimepy.net.stream.json import event_wait + + +@mark.asyncio +async def test_event_wait_basic(): + """Test the event wait can time out.""" + + event = asyncio.Event() + assert not await event_wait(event, 0.0)