diff --git a/pandablocks/asyncio.py b/pandablocks/asyncio.py index bc1f63fa2..bb0db8e1f 100644 --- a/pandablocks/asyncio.py +++ b/pandablocks/asyncio.py @@ -1,8 +1,8 @@ import asyncio import logging from asyncio.streams import StreamReader, StreamWriter -from collections import deque -from typing import AsyncGenerator, Deque, Dict, Optional +from contextlib import suppress +from typing import AsyncGenerator, Dict, Iterable, Optional from .commands import Command, T from .connections import ControlConnection, DataConnection @@ -143,35 +143,43 @@ async def data( `asyncio.TimeoutError` """ - data_stream = _StreamHelper() - await data_stream.connect(self._host, 8889), - + stream = _StreamHelper() connection = DataConnection() - data: Deque[Data] = deque() - reader = data_stream.reader - # Should we flush every FrameData? - flush_every_frame = flush_period is None + queue: asyncio.Queue[Iterable[Data]] = asyncio.Queue() - async def queue_flushed_data(): - data.extend(connection.flush()) + def raise_timeouterror(): + raise asyncio.TimeoutError(f"No data received for {frame_timeout}s") + yield async def periodic_flush(): - if not flush_every_frame: + if flush_period is not None: while True: # Every flush_period seconds flush and queue data - await asyncio.gather( - asyncio.sleep(flush_period), queue_flushed_data() - ) + await asyncio.sleep(flush_period) + queue.put_nowait(connection.flush()) - flush_task = asyncio.create_task(periodic_flush()) + async def read_from_stream(): + reader = stream.reader + # Should we flush every FrameData? + flush_every_frame = flush_period is None + while True: + try: + recv = await asyncio.wait_for(reader.read(4096), frame_timeout) + except asyncio.TimeoutError: + queue.put_nowait(raise_timeouterror()) + break + else: + queue.put_nowait(connection.receive_bytes(recv, flush_every_frame)) + + await stream.connect(self._host, 8889) + await stream.write_and_drain(connection.connect(scaled)) + fut = asyncio.gather(periodic_flush(), read_from_stream()) try: - await data_stream.write_and_drain(connection.connect(scaled)) while True: - received = await asyncio.wait_for(reader.read(4096), frame_timeout) - for d in connection.receive_bytes(received, flush_every_frame): - data.append(d) - while data: - yield data.popleft() + for data in await queue.get(): + yield data finally: - flush_task.cancel() - await data_stream.close() + fut.cancel() + await stream.close() + with suppress(asyncio.CancelledError): + await fut diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py index ff52f006c..1d49d3685 100644 --- a/tests/test_asyncio.py +++ b/tests/test_asyncio.py @@ -31,17 +31,33 @@ async def test_asyncio_bad_put_raises(dummy_server_async): @pytest.mark.asyncio -async def test_asyncio_data(dummy_server_async, fast_dump, fast_dump_expected): +@pytest.mark.parametrize("disarmed", [True, False]) +@pytest.mark.parametrize("flush_period", [0.1, None]) +async def test_asyncio_data( + dummy_server_async, fast_dump, fast_dump_expected, disarmed, flush_period +): + if not disarmed: + # simulate getting the data without the END marker as if arm was not pressed + fast_dump = map(lambda x: x.split(b"END")[0], fast_dump) + fast_dump_expected = list(fast_dump_expected)[:-1] dummy_server_async.data = fast_dump events = [] async with AsyncioClient("localhost") as client: - async for data in client.data(frame_timeout=1): + async for data in client.data(frame_timeout=1, flush_period=flush_period): events.append(data) - if len(events) == 9: + if len(events) == len(fast_dump_expected): break assert fast_dump_expected == events +async def test_asyncio_data_timeout(dummy_server_async, fast_dump): + dummy_server_async.data = fast_dump + async with AsyncioClient("localhost") as client: + with pytest.raises(asyncio.TimeoutError, match="No data received for 0.1s"): + async for data in client.data(frame_timeout=0.1): + "This goes forever, when it runs out of data we will get our timeout" + + @pytest.mark.asyncio async def test_asyncio_connects(dummy_server_async: DummyServer): async with AsyncioClient("localhost") as client: