Skip to content

Commit

Permalink
Merge pull request #56 from PandABlocks/asyncio-flushing
Browse files Browse the repository at this point in the history
Fix bug flushing Data frames from asyncio client
  • Loading branch information
coretl authored Sep 4, 2023
2 parents f99c713 + 4ebc66a commit 3e086c9
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 27 deletions.
56 changes: 32 additions & 24 deletions pandablocks/asyncio.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import asyncio
import logging
from asyncio.streams import StreamReader, StreamWriter
from collections import deque
from typing import AsyncGenerator, Deque, Dict, Optional
from contextlib import suppress
from typing import AsyncGenerator, Dict, Iterable, Optional

from .commands import Command, T
from .connections import ControlConnection, DataConnection
Expand Down Expand Up @@ -143,35 +143,43 @@ async def data(
`asyncio.TimeoutError`
"""

data_stream = _StreamHelper()
await data_stream.connect(self._host, 8889),

stream = _StreamHelper()
connection = DataConnection()
data: Deque[Data] = deque()
reader = data_stream.reader
# Should we flush every FrameData?
flush_every_frame = flush_period is None
queue: asyncio.Queue[Iterable[Data]] = asyncio.Queue()

async def queue_flushed_data():
data.extend(connection.flush())
def raise_timeouterror():
raise asyncio.TimeoutError(f"No data received for {frame_timeout}s")
yield

async def periodic_flush():
if not flush_every_frame:
if flush_period is not None:
while True:
# Every flush_period seconds flush and queue data
await asyncio.gather(
asyncio.sleep(flush_period), queue_flushed_data()
)
await asyncio.sleep(flush_period)
queue.put_nowait(connection.flush())

flush_task = asyncio.create_task(periodic_flush())
async def read_from_stream():
reader = stream.reader
# Should we flush every FrameData?
flush_every_frame = flush_period is None
while True:
try:
recv = await asyncio.wait_for(reader.read(4096), frame_timeout)
except asyncio.TimeoutError:
queue.put_nowait(raise_timeouterror())
break
else:
queue.put_nowait(connection.receive_bytes(recv, flush_every_frame))

await stream.connect(self._host, 8889)
await stream.write_and_drain(connection.connect(scaled))
fut = asyncio.gather(periodic_flush(), read_from_stream())
try:
await data_stream.write_and_drain(connection.connect(scaled))
while True:
received = await asyncio.wait_for(reader.read(4096), frame_timeout)
for d in connection.receive_bytes(received, flush_every_frame):
data.append(d)
while data:
yield data.popleft()
for data in await queue.get():
yield data
finally:
flush_task.cancel()
await data_stream.close()
fut.cancel()
await stream.close()
with suppress(asyncio.CancelledError):
await fut
22 changes: 19 additions & 3 deletions tests/test_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,33 @@ async def test_asyncio_bad_put_raises(dummy_server_async):


@pytest.mark.asyncio
async def test_asyncio_data(dummy_server_async, fast_dump, fast_dump_expected):
@pytest.mark.parametrize("disarmed", [True, False])
@pytest.mark.parametrize("flush_period", [0.1, None])
async def test_asyncio_data(
dummy_server_async, fast_dump, fast_dump_expected, disarmed, flush_period
):
if not disarmed:
# simulate getting the data without the END marker as if arm was not pressed
fast_dump = map(lambda x: x.split(b"END")[0], fast_dump)
fast_dump_expected = list(fast_dump_expected)[:-1]
dummy_server_async.data = fast_dump
events = []
async with AsyncioClient("localhost") as client:
async for data in client.data(frame_timeout=1):
async for data in client.data(frame_timeout=1, flush_period=flush_period):
events.append(data)
if len(events) == 9:
if len(events) == len(fast_dump_expected):
break
assert fast_dump_expected == events


async def test_asyncio_data_timeout(dummy_server_async, fast_dump):
dummy_server_async.data = fast_dump
async with AsyncioClient("localhost") as client:
with pytest.raises(asyncio.TimeoutError, match="No data received for 0.1s"):
async for data in client.data(frame_timeout=0.1):
"This goes forever, when it runs out of data we will get our timeout"


@pytest.mark.asyncio
async def test_asyncio_connects(dummy_server_async: DummyServer):
async with AsyncioClient("localhost") as client:
Expand Down

0 comments on commit 3e086c9

Please sign in to comment.