From 565beff31f636a49e55ceac2aecb0a7e52f59d4d Mon Sep 17 00:00:00 2001 From: Callum Forrester Date: Fri, 14 Jul 2023 10:16:06 +0100 Subject: [PATCH] Improve HTTP and ZeroMQ adapters (#137) Closes #111 and #15 Changes: * Add synchronised start/stop mechanism to HTTP adapter * Write suite of tests using this mechanism * Make HTTP adapter support interrupts (it did not appear to already) * Remove `include_json` parameter from HTTP endpoints as `aiohttp` can work that out for itself * Create new ZeroMQ adapter specifically for pushing. The previous one had issues and also implemented more functionality than was needed. Multiple ZMQ adapters for the different ZMQ socket modes (PUSH, PUBLISH, DEALER etc.) seems like a better way to go. * Synchronize ZeroMQ socket binding with a lock to avoid conflicts * Write suite of tests for ZeroMQ push adapter --- docs/user/reference/api.rst | 6 +- examples/configs/http-and-zeromq-devices.yaml | 7 + examples/configs/http-device.yaml | 11 +- examples/devices/http_device.py | 81 ++--- examples/devices/zeromq_push_device.py | 49 +++ pyproject.toml | 1 + src/tickit/adapters/httpadapter.py | 66 ++++- .../interpreters/endpoints/http_endpoint.py | 54 ++-- src/tickit/adapters/zeromq/__init__.py | 0 src/tickit/adapters/zeromq/push_adapter.py | 120 ++++++++ src/tickit/adapters/zmqadapter.py | 96 ------ .../endpoints/test_http_endpoint.py | 24 +- tests/adapters/test_httpadapter.py | 278 +++++++++++++++--- tests/adapters/test_zmqadapter.py | 145 --------- tests/adapters/zeromq/__init__.py | 0 tests/adapters/zeromq/test_push_adapter.py | 155 ++++++++++ 16 files changed, 697 insertions(+), 396 deletions(-) create mode 100644 examples/configs/http-and-zeromq-devices.yaml create mode 100644 examples/devices/zeromq_push_device.py create mode 100644 src/tickit/adapters/zeromq/__init__.py create mode 100644 src/tickit/adapters/zeromq/push_adapter.py delete mode 100644 src/tickit/adapters/zmqadapter.py delete mode 100644 tests/adapters/test_zmqadapter.py create mode 100644 tests/adapters/zeromq/__init__.py create mode 100644 tests/adapters/zeromq/test_push_adapter.py diff --git a/docs/user/reference/api.rst b/docs/user/reference/api.rst index 62f53d880..f2992e673 100644 --- a/docs/user/reference/api.rst +++ b/docs/user/reference/api.rst @@ -231,11 +231,11 @@ This is the internal API reference for tickit ------------------------------- - .. automodule:: tickit.adapters.zmqadapter + .. automodule:: tickit.adapters.zeromq.push_adapter :members: - ``tickit.adapters.zmqadapter`` - ------------------------------ + ``tickit.adapters.zeromq.push_adapter`` + --------------------------------------- .. automodule:: tickit.adapters.epicsadapter diff --git a/examples/configs/http-and-zeromq-devices.yaml b/examples/configs/http-and-zeromq-devices.yaml new file mode 100644 index 000000000..f54a83e3e --- /dev/null +++ b/examples/configs/http-and-zeromq-devices.yaml @@ -0,0 +1,7 @@ +- examples.devices.http_device.ExampleHttpDevice: + name: http-device + inputs: {} +- examples.devices.zeromq_push_device.ExampleZeroMqPusher: + name: zeromq-pusher + inputs: + updates: http-device:updates diff --git a/examples/configs/http-device.yaml b/examples/configs/http-device.yaml index 75c592a9f..eee6ba563 100644 --- a/examples/configs/http-device.yaml +++ b/examples/configs/http-device.yaml @@ -1,10 +1,3 @@ -- tickit.devices.source.Source: - name: source - inputs: {} - value: False -- examples.devices.http_device.ExampleHTTP: +- examples.devices.http_device.ExampleHttpDevice: name: http-device - inputs: - foo: source:value - foo: False - bar: 10 + inputs: {} diff --git a/examples/devices/http_device.py b/examples/devices/http_device.py index a10ee7d1f..e41548e0a 100644 --- a/examples/devices/http_device.py +++ b/examples/devices/http_device.py @@ -1,63 +1,21 @@ from dataclasses import dataclass -from typing import Optional from aiohttp import web -from tickit.adapters.httpadapter import HTTPAdapter -from tickit.adapters.interpreters.endpoints.http_endpoint import HTTPEndpoint +from tickit.adapters.httpadapter import HttpAdapter +from tickit.adapters.interpreters.endpoints.http_endpoint import HttpEndpoint from tickit.core.components.component import Component, ComponentConfig from tickit.core.components.device_simulation import DeviceSimulation -from tickit.core.device import Device, DeviceUpdate -from tickit.core.typedefs import SimTime -from tickit.utils.compat.typing_compat import TypedDict +from tickit.devices.iobox import IoBoxDevice -class ExampleHTTPDevice(Device): - """A device class for an example HTTP device. +class IoBoxHttpAdapter(HttpAdapter): + """An adapter for an IoBox that allows reads and writes via REST calls""" - ... - """ + device: IoBoxDevice - Inputs: TypedDict = TypedDict("Inputs", {"foo": bool}) - - Outputs: TypedDict = TypedDict("Outputs", {"bar": float}) - - def __init__( - self, - foo: bool = False, - bar: Optional[int] = 10, - ) -> None: - """An example HTTP device constructor which configures the ... . - - Args: - foo (bool): A flag to indicate something. Defauls to False. - bar (int, optional): A number to represent something. Defaults to 3600. - """ - self.foo = foo - self.bar = bar - - def update(self, time: SimTime, inputs: Inputs) -> DeviceUpdate[Outputs]: - """Generic update function to update the values of the ExampleHTTPDevice. - - Args: - time (SimTime): The simulation time in nanoseconds. - inputs (Inputs): A TypedDict of the inputs to the ExampleHTTPDevice. - - Returns: - DeviceUpdate[Outputs]: - The produced update event which contains the value of the device - variables. - """ - pass - - -class ExampleHTTPAdapter(HTTPAdapter): - """An Eiger adapter which parses the commands sent to the HTTP server.""" - - device: ExampleHTTPDevice - - @HTTPEndpoint.put("/command/foo/") - async def foo(self, request: web.Request) -> web.Response: + @HttpEndpoint.put("/memory/{address}", interrupt=True) + async def write_to_address(self, request: web.Request) -> web.Response: """A HTTP endpoint for sending a command to the example HTTP device. Args: @@ -66,10 +24,13 @@ async def foo(self, request: web.Request) -> web.Response: Returns: web.Response: [description] """ - return web.Response(text=str("put data")) + address = request.match_info["address"] + new_value = (await request.json())["value"] + self.device.write(address, new_value) + return web.json_response({address: new_value}) - @HTTPEndpoint.get("/info/bar/{data}") - async def bar(self, request: web.Request) -> web.Response: + @HttpEndpoint.get("/memory/{address}") + async def read_from_address(self, request: web.Request) -> web.Response: """A HTTP endpoint for requesting data from the example HTTP device. Args: @@ -78,19 +39,21 @@ async def bar(self, request: web.Request) -> web.Response: Returns: web.Response: [description] """ - return web.Response(text=f"Your data: {request.match_info['data']}") + address = request.match_info["address"] + value = self.device.read(address) + return web.json_response({address: value}) @dataclass -class ExampleHTTP(ComponentConfig): +class ExampleHttpDevice(ComponentConfig): """Example HTTP device.""" - foo: bool = False - bar: Optional[int] = 10 + host: str = "localhost" + port: int = 8080 def __call__(self) -> Component: # noqa: D102 return DeviceSimulation( name=self.name, - device=ExampleHTTPDevice(foo=self.foo, bar=self.bar), - adapters=[ExampleHTTPAdapter()], + device=IoBoxDevice(), + adapters=[IoBoxHttpAdapter(self.host, self.port)], ) diff --git a/examples/devices/zeromq_push_device.py b/examples/devices/zeromq_push_device.py new file mode 100644 index 000000000..7f97a9485 --- /dev/null +++ b/examples/devices/zeromq_push_device.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass, field +from typing import Optional, Set + +from tickit.adapters.zeromq.push_adapter import ( + SocketFactory, + ZeroMqPushAdapter, + create_zmq_push_socket, +) +from tickit.core.components.component import Component, ComponentConfig +from tickit.core.components.device_simulation import DeviceSimulation +from tickit.devices.iobox import IoBoxDevice + + +class IoBoxZeroMqAdapter(ZeroMqPushAdapter): + """An Eiger adapter which parses the commands sent to the HTTP server.""" + + device: IoBoxDevice[str, int] + _addresses_to_publish: Set[str] + + def __init__( + self, + host: str = "127.0.0.1", + port: int = 5555, + socket_factory: Optional[SocketFactory] = create_zmq_push_socket, + addresses_to_publish: Optional[Set[str]] = None, + ) -> None: + super().__init__(host, port, socket_factory) + self._addresses_to_publish = addresses_to_publish or set() + + def after_update(self): + for address in self._addresses_to_publish: + value = self.device.read(address) + self.send_message([{address: value}]) + + +@dataclass +class ExampleZeroMqPusher(ComponentConfig): + """Device that can publish writes to its memory over a zeromq socket.""" + + host: str = "127.0.0.1" + port: int = 5555 + addresses_to_publish: Set[str] = field(default_factory=lambda: {"foo", "bar"}) + + def __call__(self) -> Component: # noqa: D102 + return DeviceSimulation( + name=self.name, + device=IoBoxDevice(), + adapters=[IoBoxZeroMqAdapter()], + ) diff --git a/pyproject.toml b/pyproject.toml index a91cd1323..ba653aca6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,6 +18,7 @@ dependencies = [ "aiozmq", "apischema==0.16.1", "immutables", + "pydantic", "pyyaml", "pyzmq", "softioc", diff --git a/src/tickit/adapters/httpadapter.py b/src/tickit/adapters/httpadapter.py index c23056f03..c10ba71e3 100644 --- a/src/tickit/adapters/httpadapter.py +++ b/src/tickit/adapters/httpadapter.py @@ -2,12 +2,12 @@ import logging from dataclasses import dataclass from inspect import getmembers -from typing import Iterable +from typing import Awaitable, Callable, Iterable, Optional from aiohttp import web from aiohttp.web_routedef import RouteDef -from tickit.adapters.interpreters.endpoints.http_endpoint import HTTPEndpoint +from tickit.adapters.interpreters.endpoints.http_endpoint import HttpEndpoint from tickit.core.adapter import Adapter, RaiseInterrupt from tickit.core.device import Device @@ -15,7 +15,7 @@ @dataclass -class HTTPAdapter(Adapter): +class HttpAdapter(Adapter): """An adapter implementation which delegates to a server and sets up endpoints. An adapter implementation which delegates the hosting of an http requests to a @@ -25,20 +25,47 @@ class HTTPAdapter(Adapter): host: str = "localhost" port: int = 8080 + _stopped: Optional[asyncio.Event] = None + _ready: Optional[asyncio.Event] = None + async def run_forever( self, device: Device, raise_interrupt: RaiseInterrupt ) -> None: """Runs the server continuously.""" await super().run_forever(device, raise_interrupt) + self._ensure_stopped_event().clear() await self._start_server() - + self._ensure_ready_event().set() try: - await asyncio.Event().wait() - finally: - # TODO: This doesn't work yet due to asyncio's own exception handler + await self._ensure_stopped_event().wait() + except asyncio.CancelledError: + await self.stop() + + async def wait_until_ready(self, timeout: float = 1.0) -> None: + while self._ready is None: + await asyncio.sleep(0.1) + await asyncio.wait_for(self._ready.wait(), timeout=timeout) + + async def stop(self) -> None: + stopped = self._ensure_stopped_event() + if not stopped.is_set(): + await self.site.stop() await self.app.shutdown() await self.app.cleanup() + self._ensure_stopped_event().set() + if self._ready is not None: + self._ready.clear() + + def _ensure_stopped_event(self) -> asyncio.Event: + if self._stopped is None: + self._stopped = asyncio.Event() + return self._stopped + + def _ensure_ready_event(self) -> asyncio.Event: + if self._ready is None: + self._ready = asyncio.Event() + return self._ready async def _start_server(self): LOGGER.debug(f"Starting HTTP server... {self}") @@ -46,8 +73,8 @@ async def _start_server(self): self.app.add_routes(list(self.endpoints())) runner = web.AppRunner(self.app) await runner.setup() - site = web.TCPSite(runner, host=self.host, port=self.port) - await site.start() + self.site = web.TCPSite(runner, host=self.host, port=self.port) + await self.site.start() def endpoints(self) -> Iterable[RouteDef]: """Returns list of endpoints. @@ -56,12 +83,27 @@ def endpoints(self) -> Iterable[RouteDef]: then yields them. Returns: - Iterable[HTTPEndpoint]: The list of defined endpoints + Iterable[HttpEndpoint]: The list of defined endpoints Yields: - Iterator[Iterable[HTTPEndpoint]]: The iterator of the defined endpoints + Iterator[Iterable[HttpEndpoint]]: The iterator of the defined endpoints """ for _, func in getmembers(self): endpoint = getattr(func, "__endpoint__", None) # type: ignore - if endpoint is not None and isinstance(endpoint, HTTPEndpoint): + if endpoint is not None and isinstance(endpoint, HttpEndpoint): + if endpoint.interrupt: + func = _with_posthoc_task(func, self.raise_interrupt) yield endpoint.define(func) + + +def _with_posthoc_task( + func: Callable[[web.Request], Awaitable[web.Response]], + afterwards: Callable[[], Awaitable[None]], +) -> Callable[[web.Request], Awaitable[web.Response]]: + # @functools.wraps + async def wrapped(request: web.Request) -> web.Response: + response = await func(request) + await afterwards() + return response + + return wrapped diff --git a/src/tickit/adapters/interpreters/endpoints/http_endpoint.py b/src/tickit/adapters/interpreters/endpoints/http_endpoint.py index 56577ee28..8a411abe8 100644 --- a/src/tickit/adapters/interpreters/endpoints/http_endpoint.py +++ b/src/tickit/adapters/interpreters/endpoints/http_endpoint.py @@ -7,34 +7,33 @@ @dataclass(frozen=True) -class HTTPEndpoint(Generic[AnyStr]): - """A decorator to register a device adapter method as a HTTP Endpoint. +class HttpEndpoint(Generic[AnyStr]): + """A decorator intended for use with HttpAdapter. + + Routes an HTTP endpoint to the decorated method. Args: - url (str): The URL that will point to a specific endpoint. - method (str): The method to use when using this endpoint. - name (str): The name of the route. - include_json (bool): A flag to indicate whether the route should include json. - interrupt (bool): A flag indicating whether calling of the method should - raise an adapter interrupt. Defaults to False. + path: The URL that will point to a specific endpoint. + method: The method to use when using this endpoint. + interrupt: If True, every time this endpoint is called the adapter's device + will be interrupted. Defaults to False. Returns: Callable: A decorator which registers the adapter method as an endpoint. """ - url: str + path: str method: str - include_json: bool = False interrupt: bool = False # Type signature can become more specific if support is dropped for # Python 3.7, see https://github.com/python/mypy/issues/708 def __call__(self, func: Callable) -> Callable: - """A decorator which registers the adapter method as an endpoint. + """Decorate a function for HTTP routing. Args: - func (Callable): The adapter method to be registered as an endpoint. + func: The adapter method to be registered as an endpoint. Returns: Callable: The registered adapter endpoint. @@ -57,26 +56,31 @@ def define( Returns: RouteDef: The route definition for the endpoint. """ - return RouteDef(self.method, self.url, func, {}) + return RouteDef(self.method, self.path, func, {}) + + @classmethod + def get(cls, url: str, interrupt: bool = False) -> "HttpEndpoint": + """Shortcut to making a GET endpoint. + + Returns: + cls: The class of HttpEndpoint with the "GET" request method. + """ + return cls(url, "GET", interrupt) @classmethod - def get( - cls, url: str, include_json: bool = False, interrupt: bool = False - ) -> "HTTPEndpoint": - """Method for the HTTPEndpoint that sets the request method to "GET". + def put(cls, url: str, interrupt: bool = False) -> "HttpEndpoint": + """Shortcut to making a PUT endpoint. Returns: - cls: The class of HTTPEndpoint with the "GET" request method. + cls: The class of HttpEndpoint with the "PUT" request method. """ - return cls(url, "GET", include_json, interrupt) + return cls(url, "PUT", interrupt) @classmethod - def put( - cls, url: str, include_json: bool = False, interrupt: bool = False - ) -> "HTTPEndpoint": - """Method for the HTTPEndpoint that sets the request method to "PUT". + def post(cls, url: str, interrupt: bool = False) -> "HttpEndpoint": + """Shortcut to making a POST endpoint. Returns: - cls: The class of HTTPEndpoint with the "PUT" request method. + cls: The class of HttpEndpoint with the "POST" request method. """ - return cls(url, "PUT", include_json, interrupt) + return cls(url, "POST", interrupt) diff --git a/src/tickit/adapters/zeromq/__init__.py b/src/tickit/adapters/zeromq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/tickit/adapters/zeromq/push_adapter.py b/src/tickit/adapters/zeromq/push_adapter.py new file mode 100644 index 000000000..b1cb35469 --- /dev/null +++ b/src/tickit/adapters/zeromq/push_adapter.py @@ -0,0 +1,120 @@ +import asyncio +import json +import logging +from typing import ( + Any, + Iterable, + Mapping, + Optional, + Protocol, + Sequence, + Union, + runtime_checkable, +) + +import aiozmq +import zmq +from pydantic.v1 import BaseModel + +from tickit.core.adapter import Adapter, RaiseInterrupt +from tickit.core.device import Device + +LOGGER = logging.getLogger(__name__) + + +_MessagePart = Union[bytes, zmq.Frame, memoryview] +_SerializableMessagePart = Union[ + _MessagePart, + str, + Mapping[str, Any], + BaseModel, +] + +_ZeroMqInternalMessage = Sequence[_MessagePart] +ZeroMqMessage = Sequence[_SerializableMessagePart] +# SocketFactory = Callable[[], Awaitable[aiozmq.ZmqStream]] + + +@runtime_checkable +class SocketFactory(Protocol): + async def __call__(self, __host: str, __port: int) -> aiozmq.ZmqStream: + ... + + +async def create_zmq_push_socket(host: str, port: int) -> aiozmq.ZmqStream: + addr = f"tcp://{host}:{port}" + return await aiozmq.create_zmq_stream(zmq.PUSH, connect=addr, bind=addr) + + +class ZeroMqPushAdapter(Adapter): + """An adapter for a ZeroMQ data stream.""" + + _host: str + _port: int + _socket: Optional[aiozmq.ZmqStream] + _socket_factory: SocketFactory + _socket_lock: asyncio.Lock + + def __init__( + self, + host: str = "127.0.0.1", + port: int = 5555, + socket_factory: SocketFactory = create_zmq_push_socket, + ) -> None: + """Initialize with default values.""" + super().__init__() + self._host = host + self._port = port + self._socket = None + self._socket_factory = socket_factory + self._socket_lock = asyncio.Lock() + + async def run_forever( + self, + device: Device, + raise_interrupt: RaiseInterrupt, + ) -> None: + """Runs the ZeroMQ adapter continuously.""" + await super().run_forever(device, raise_interrupt) + + try: + await self._ensure_socket() + except asyncio.CancelledError: + if self._socket is not None: + self._socket.close() + await self._socket.drain() + + def send_message_sequence_soon( + self, + messages: Iterable[ZeroMqMessage], + ) -> None: + async def send_message_sequence() -> None: + for message in messages: + await self.send_message(message) + + asyncio.create_task(send_message_sequence()) + + async def send_message(self, message: ZeroMqMessage) -> None: + socket = await self._ensure_socket() + serialized = self._serialize(message) + socket.write(serialized) + await socket.drain() + + async def _ensure_socket(self) -> aiozmq.ZmqStream: + async with self._socket_lock: + if self._socket is None: + self._socket = await self._socket_factory(self._host, self._port) + return self._socket + + def _serialize(self, message: ZeroMqMessage) -> _ZeroMqInternalMessage: + return list(map(self._serialize_part, message)) + + def _serialize_part(self, part: _SerializableMessagePart) -> _MessagePart: + if isinstance(part, BaseModel): + return self._serialize_part(part.dict()) + elif isinstance(part, dict) or isinstance(part, str): + return self._serialize_part(json.dumps(part).encode("utf_8")) + elif isinstance(part, bytes): + return part + else: + raise TypeError(f"Message: {part} is not serializable") diff --git a/src/tickit/adapters/zmqadapter.py b/src/tickit/adapters/zmqadapter.py deleted file mode 100644 index 126c6de84..000000000 --- a/src/tickit/adapters/zmqadapter.py +++ /dev/null @@ -1,96 +0,0 @@ -import asyncio -import logging -from typing import Any - -import aiozmq -import zmq - -from tickit.core.adapter import Adapter, RaiseInterrupt -from tickit.core.device import Device - -LOGGER = logging.getLogger(__name__) - - -class ZeroMQAdapter(Adapter): - """An adapter for a ZeroMQ data stream.""" - - zmq_host: str - zmq_port: int - running: bool - - _router: aiozmq.ZmqStream - _dealer: aiozmq.ZmqStream - - def __init__( - self, zmq_host: str = "127.0.0.1", zmq_port: int = 5555, running: bool = False - ) -> None: - """Initialize with default values.""" - super().__init__() - self.zmq_host = zmq_host - self.zmq_port = zmq_port - self.running = running - - async def start_stream(self) -> None: - """Start the ZeroMQ stream.""" - LOGGER.debug("Starting stream...") - self._router = await aiozmq.create_zmq_stream( - zmq.ROUTER, bind=f"tcp://{self.zmq_host}:{self.zmq_port}" - ) - - addr = list(self._router.transport.bindings())[0] - self._dealer = await aiozmq.create_zmq_stream(zmq.DEALER, connect=addr) - - self._router.transport.setsockopt(zmq.LINGER, 0) - self._dealer.transport.setsockopt(zmq.LINGER, 0) - - async def close_stream(self) -> None: - """Close the ZeroMQ stream.""" - self._dealer.close() - self._router.close() - - self.running = False - - def send_message(self, message: Any) -> None: - """Send a message down the ZeroMQ stream. - - Sets up an asyncio task to put the message on the message queue, before - being processed. - - Args: - message (str): The message to send down the ZeroMQ stream. - """ - asyncio.create_task(self._message_queue.put(message)) - - async def run_forever( - self, device: Device, raise_interrupt: RaiseInterrupt - ) -> None: - """Runs the ZeroMQ adapter continuously.""" - await super().run_forever(device, raise_interrupt) - self._message_queue: asyncio.Queue = asyncio.Queue() - await self.start_stream() - self.running = True - await self._process_message_queue() - - def check_if_running(self): - """Returns the running state of the adapter.""" - return self.running - - async def _process_message_queue(self) -> None: - running = True - while running: - message = await self._message_queue.get() - await self._process_message(message) - running = self.check_if_running() - - async def _process_message(self, message: str) -> None: - if message is not None: - LOGGER.debug(f"Data from ZMQ stream: {message!r}") - - msg = (b"Data", str(message).encode("utf-8")) - self._dealer.write(msg) - data = await self._router.read() - self._router.write(data) - answer = await self._dealer.read() - LOGGER.debug(f"Received {answer!r}") - else: - LOGGER.debug("No message") diff --git a/tests/adapters/interpreters/endpoints/test_http_endpoint.py b/tests/adapters/interpreters/endpoints/test_http_endpoint.py index 7b83db073..07acbeff1 100644 --- a/tests/adapters/interpreters/endpoints/test_http_endpoint.py +++ b/tests/adapters/interpreters/endpoints/test_http_endpoint.py @@ -1,30 +1,30 @@ import pytest from aiohttp import web -from tickit.adapters.interpreters.endpoints.http_endpoint import HTTPEndpoint +from tickit.adapters.interpreters.endpoints.http_endpoint import HttpEndpoint @pytest.fixture -def http_endpoint(url: str, method: str, include_json: bool, interrupt: bool): - return HTTPEndpoint(url, method, include_json, interrupt) +def http_endpoint(url: str, method: str, interrupt: bool): + return HttpEndpoint(url, method, interrupt) def test_http_endpoint_registers_get_endpoint(): class TestAdapter: - @HTTPEndpoint.get("test", False, False) + @HttpEndpoint.get("test", False) def test_endpoint(): pass - assert isinstance(TestAdapter.test_endpoint.__endpoint__, HTTPEndpoint) + assert isinstance(TestAdapter.test_endpoint.__endpoint__, HttpEndpoint) def test_http_endpoint_registers_put_endpoint(): class TestAdapter: - @HTTPEndpoint.put("test", False, False) + @HttpEndpoint.put("test", False) def test_endpoint(): pass - assert isinstance(TestAdapter.test_endpoint.__endpoint__, HTTPEndpoint) + assert isinstance(TestAdapter.test_endpoint.__endpoint__, HttpEndpoint) def fake_endpoint(request: web.Request): @@ -32,36 +32,34 @@ def fake_endpoint(request: web.Request): @pytest.mark.parametrize( - ["url", "method", "include_json", "interrupt", "expected"], + ["url", "method", "interrupt", "expected"], [ ( r"TestUrl", r"GET", False, - False, web.RouteDef("GET", "TestUrl", fake_endpoint, {}), ) ], ) def test_http_get_endpoint_define_returns_get_routedef( - http_endpoint: HTTPEndpoint, expected: web.RouteDef + http_endpoint: HttpEndpoint, expected: web.RouteDef ): assert expected == http_endpoint.define(fake_endpoint) @pytest.mark.parametrize( - ["url", "method", "include_json", "interrupt", "expected"], + ["url", "method", "interrupt", "expected"], [ ( r"TestUrl", r"PUT", False, - False, web.RouteDef("PUT", "TestUrl", fake_endpoint, {}), ) ], ) def test_http_put_endpoint_define_returns_put_routedef( - http_endpoint: HTTPEndpoint, expected: web.RouteDef + http_endpoint: HttpEndpoint, expected: web.RouteDef ): assert expected == http_endpoint.define(fake_endpoint) diff --git a/tests/adapters/test_httpadapter.py b/tests/adapters/test_httpadapter.py index 82d9428b9..311f7a766 100644 --- a/tests/adapters/test_httpadapter.py +++ b/tests/adapters/test_httpadapter.py @@ -1,15 +1,19 @@ -from typing import Iterable +import asyncio +import aiohttp import pytest +import pytest_asyncio from aiohttp import web from mock import Mock -from mock.mock import create_autospec, patch +from mock.mock import create_autospec -from tickit.adapters.httpadapter import HTTPAdapter -from tickit.adapters.interpreters.endpoints.http_endpoint import HTTPEndpoint +from tickit.adapters.httpadapter import HttpAdapter +from tickit.adapters.interpreters.endpoints.http_endpoint import HttpEndpoint +from tickit.core.adapter import RaiseInterrupt from tickit.core.device import Device ISSUE_LINK = "https://github.com/dls-controls/tickit/issues/111" +REQUEST_TIMEOUT = 0.5 @pytest.fixture @@ -18,59 +22,265 @@ def mock_device() -> Device: @pytest.fixture -def mock_raise_interrupt(): +def mock_raise_interrupt() -> RaiseInterrupt: async def raise_interrupt(): return False return Mock(raise_interrupt) -class MockAdapter(HTTPAdapter): +class ExampleAdapter(HttpAdapter): device: Device - @HTTPEndpoint.get("/mock_endpoint") - async def mock_endpoint(self, request: web.Request) -> web.Response: - return web.Response(text="test") + @HttpEndpoint.get("/foo") + async def get_foo(self, request: web.Request) -> web.Response: + return web.json_response({"value": "foo"}) + + @HttpEndpoint.put("/foo") + async def put_foo(self, request: web.Request) -> web.Response: + value = (await request.json())["value"] + return web.json_response({"value": value}) + + @HttpEndpoint.post("/foo") + async def post_foo(self, request: web.Request) -> web.Response: + value = (await request.json())["value"] + return web.json_response({"value": value}) + + @HttpEndpoint.get("/bar/{name}") + async def get_bar(self, request: web.Request) -> web.Response: + name = request.match_info["name"] + return web.json_response({"entity": name, "value": "bar"}) + + @HttpEndpoint.put("/bar/{name}") + async def put_bar(self, request: web.Request) -> web.Response: + name = request.match_info["name"] + value = (await request.json())["value"] + return web.json_response({"entity": name, "value": value}) + + @HttpEndpoint.post("/bar/{name}") + async def post_bar(self, request: web.Request) -> web.Response: + name = request.match_info["name"] + value = (await request.json())["value"] + return web.json_response({"entity": name, "value": value}) + + @HttpEndpoint.get("/baz") + async def get_baz(self, request: web.Request) -> web.Response: + return web.Response(status=403) + + @HttpEndpoint.get("/error") + async def cause_error(self, request: web.Request) -> web.Response: + raise Exception("An error has occurred") + + @HttpEndpoint.put("/interrupt/{name}", interrupt=True) + async def put_interrupt(self, request: web.Request) -> web.Response: + name = request.match_info["name"] + value = (await request.json())["value"] + return web.json_response({"entity": name, "value": value}) @pytest.fixture -def http_adapter() -> HTTPAdapter: - http_adapter = HTTPAdapter() +def adapter() -> HttpAdapter: + http_adapter = ExampleAdapter() return http_adapter -@pytest.mark.skip(ISSUE_LINK) -def test_http_adapter_constructor(): - HTTPAdapter() +@pytest_asyncio.fixture +async def adapter_task( + adapter: HttpAdapter, + mock_raise_interrupt: RaiseInterrupt, + mock_device: Device, + event_loop: asyncio.BaseEventLoop, +): + task = event_loop.create_task( + adapter.run_forever(mock_device, mock_raise_interrupt) + ) + await adapter.wait_until_ready() + yield task + await adapter.stop() + await asyncio.wait_for(task, timeout=10.0) + assert task.done() -@pytest.fixture -def patch_asyncio_event_wait() -> Iterable[Mock]: - with patch( - "tickit.core.components.device_simulation.asyncio.Event.wait", autospec=True - ) as mock: - yield mock +@pytest_asyncio.fixture +async def adapter_url(adapter_task: asyncio.Task, adapter: HttpAdapter): + yield f"http://localhost:{adapter.port}" @pytest.mark.asyncio -@pytest.mark.skip(ISSUE_LINK) -async def test_http_adapter_run_forever_method( - http_adapter: HTTPAdapter, - mock_device: Device, - mock_raise_interrupt: Mock, - patch_asyncio_event_wait: Mock, +async def test_shuts_down_server_on_cancel( + adapter: HttpAdapter, + adapter_task: asyncio.Task, + adapter_url: str, ): - await http_adapter.run_forever(mock_device, mock_raise_interrupt) - await http_adapter.shutdown() + # Verify server is up + await assert_server_is_up(adapter_url) + + # Cancel task + adapter_task.cancel() + try: + await adapter_task + except asyncio.CancelledError: + pass - patch_asyncio_event_wait.assert_awaited_once() + # Verify server is now down + await assert_server_is_down(adapter_url) + + +@pytest.mark.asyncio +async def test_stop_is_idempotent( + adapter: HttpAdapter, + adapter_task: asyncio.Task, + adapter_url: str, + mock_raise_interrupt: RaiseInterrupt, + mock_device: Device, +) -> None: + # First ensure the server is working, then stop it and + # ensure it is no longer working + await assert_server_is_up(adapter_url) + await adapter.stop() + await adapter_task + assert adapter_task.done() + await assert_server_is_down(adapter_url) + + for i in range(2): + # Then start it again and check it is working + new_task = asyncio.create_task( + adapter.run_forever( + mock_device, + mock_raise_interrupt, + ) + ) + await adapter.wait_until_ready() + await assert_server_is_up(adapter_url) + + # Finally stop it one more time and check it is stopped + await adapter.stop() + await new_task + assert new_task.done() + await assert_server_is_down(adapter_url) + + +async def assert_server_is_up(adapter_url: str) -> None: + url = f"{adapter_url}/foo" + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=REQUEST_TIMEOUT) as response: + assert response.status == 200 + + +async def assert_server_is_down(adapter_url: str) -> None: + url = f"{adapter_url}/foo" + with pytest.raises(aiohttp.ClientConnectionError): + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=REQUEST_TIMEOUT) as response: + assert response.status == 200 + + +@pytest.mark.asyncio +async def test_get(adapter_url: str): + url = f"{adapter_url}/foo" + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=REQUEST_TIMEOUT) as response: + assert response.status == 200 + assert (await response.json()) == {"value": "foo"} @pytest.mark.asyncio -@pytest.mark.skip(ISSUE_LINK) -async def test_http_adapter_endpoints(): - adapter = MockAdapter() +@pytest.mark.parametrize("name", ["a", "b"]) +async def test_get_by_name(adapter_url: str, name: str): + url = f"{adapter_url}/bar/{name}" + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=REQUEST_TIMEOUT) as response: + assert response.status == 200 + assert (await response.json()) == {"entity": name, "value": "bar"} - resp = await list(adapter.endpoints())[0].handler(None) - assert resp.text == "test" +@pytest.mark.asyncio +async def test_error_code(adapter_url: str): + url = f"{adapter_url}/baz" + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=REQUEST_TIMEOUT) as response: + assert response.status == 403 + + +@pytest.mark.asyncio +async def test_internal_error(adapter_url: str): + url = f"{adapter_url}/error" + async with aiohttp.ClientSession() as session: + async with session.get(url, timeout=REQUEST_TIMEOUT) as response: + assert response.status == 500 + + +@pytest.mark.asyncio +async def test_put(adapter_url: str): + url = f"{adapter_url}/foo" + async with aiohttp.ClientSession() as session: + async with session.put( + url, json={"value": "bar"}, timeout=REQUEST_TIMEOUT + ) as response: + assert response.status == 200 + assert (await response.json()) == {"value": "bar"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name", ["a", "b"]) +async def test_put_by_name(adapter_url: str, name: str): + url = f"{adapter_url}/bar/{name}" + async with aiohttp.ClientSession() as session: + async with session.put( + url, json={"value": "foo"}, timeout=REQUEST_TIMEOUT + ) as response: + assert response.status == 200 + assert (await response.json()) == {"entity": name, "value": "foo"} + + +@pytest.mark.asyncio +async def test_post(adapter_url: str): + url = f"{adapter_url}/foo" + async with aiohttp.ClientSession() as session: + async with session.post( + url, json={"value": "bar"}, timeout=REQUEST_TIMEOUT + ) as response: + assert response.status == 200 + assert (await response.json()) == {"value": "bar"} + + +@pytest.mark.asyncio +@pytest.mark.parametrize("name", ["a", "b"]) +async def test_post_by_name(adapter_url: str, name: str): + url = f"{adapter_url}/bar/{name}" + async with aiohttp.ClientSession() as session: + async with session.post( + url, json={"value": "foo"}, timeout=REQUEST_TIMEOUT + ) as response: + assert response.status == 200 + assert (await response.json()) == {"entity": name, "value": "foo"} + + +@pytest.mark.asyncio +async def test_put_to_non_interrupting_endpoint_does_not_interrupt( + mock_raise_interrupt: Mock, + adapter_url: str, +): + url = f"{adapter_url}/foo" + async with aiohttp.ClientSession() as session: + async with session.put( + url, json={"value": "bar"}, timeout=REQUEST_TIMEOUT + ) as response: + assert response.status == 200 + assert (await response.json()) == {"value": "bar"} + mock_raise_interrupt.assert_not_called() + + +@pytest.mark.asyncio +async def test_put_to_interrupt( + mock_raise_interrupt: Mock, + adapter_url: str, +): + url = f"{adapter_url}/interrupt/a" + async with aiohttp.ClientSession() as session: + async with session.put( + url, json={"value": "foo"}, timeout=REQUEST_TIMEOUT + ) as response: + assert response.status == 200 + assert (await response.json()) == {"entity": "a", "value": "foo"} + mock_raise_interrupt.assert_called_once() diff --git a/tests/adapters/test_zmqadapter.py b/tests/adapters/test_zmqadapter.py deleted file mode 100644 index ef801ef24..000000000 --- a/tests/adapters/test_zmqadapter.py +++ /dev/null @@ -1,145 +0,0 @@ -import asyncio -import logging - -import aiozmq -import pytest -from mock import Mock -from mock.mock import AsyncMock, create_autospec - -from tickit.adapters.zmqadapter import ZeroMQAdapter -from tickit.core.device import Device - - -@pytest.fixture -def mock_device() -> Device: - return create_autospec(Device) - - -@pytest.fixture -def mock_raise_interrupt(): - async def raise_interrupt(): - return False - - return Mock(raise_interrupt) - - -@pytest.fixture -def mock_process_message_queue() -> AsyncMock: - async def _process_message_queue(): - return True - - return AsyncMock(_process_message_queue) - - -@pytest.fixture -def zeromq_adapter() -> ZeroMQAdapter: - zmq_adapter = ZeroMQAdapter() - zmq_adapter._dealer = AsyncMock() - zmq_adapter._router = AsyncMock() - zmq_adapter._message_queue = Mock(asyncio.Queue) - return zmq_adapter - - -def test_zeromq_adapter_constructor(): - ZeroMQAdapter() - - -@pytest.mark.asyncio -async def test_zeromq_adapter_start_stream(zeromq_adapter: ZeroMQAdapter): - await zeromq_adapter.start_stream() - - assert isinstance(zeromq_adapter._router, aiozmq.stream.ZmqStream) - assert isinstance(zeromq_adapter._dealer, aiozmq.stream.ZmqStream) - - await zeromq_adapter.close_stream() - assert zeromq_adapter.running is False - - -@pytest.mark.asyncio -async def test_zeromq_adapter_close_stream(zeromq_adapter: ZeroMQAdapter): - await zeromq_adapter.start_stream() - - await zeromq_adapter.close_stream() - await asyncio.sleep(0.1) - - assert zeromq_adapter.running is False - assert None is zeromq_adapter._router._transport - assert None is zeromq_adapter._dealer._transport - - -@pytest.mark.asyncio -async def test_zeromq_adapter_after_update(zeromq_adapter): - zeromq_adapter.after_update() - - -@pytest.mark.asyncio -async def test_zeromq_adapter_send_message(zeromq_adapter): - mock_message = AsyncMock() - - zeromq_adapter.send_message(mock_message) - task = asyncio.current_task() - asyncio.gather(task) - zeromq_adapter._message_queue.put.assert_called_once() - - -@pytest.mark.asyncio -async def test_zeromq_adapter_run_forever_method( - zeromq_adapter, - mock_device: Device, - mock_process_message_queue: AsyncMock, - mock_raise_interrupt: Mock, -): - zeromq_adapter._process_message_queue = mock_process_message_queue - - await zeromq_adapter.run_forever(mock_device, mock_raise_interrupt) - - zeromq_adapter._process_message_queue.assert_called_once() - - await zeromq_adapter.close_stream() - assert zeromq_adapter.running is False - - -@pytest.mark.asyncio -async def test_zeromq_adapter_check_if_running(zeromq_adapter): - assert zeromq_adapter.check_if_running() is False - - -@pytest.mark.asyncio -async def test_zeromq_adapter_process_message_queue(zeromq_adapter): - zeromq_adapter._process_message = AsyncMock() - zeromq_adapter.check_if_running = Mock(return_value=False) - - await zeromq_adapter._process_message_queue() - - zeromq_adapter._process_message.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_zeromq_adapter_process_message(zeromq_adapter): - mock_message = "test" - - zeromq_adapter._dealer.write = Mock() - zeromq_adapter._router.write = Mock() - zeromq_adapter._dealer.read.return_value = ("Data", "test") - zeromq_adapter._router.read.return_value = ("Data", "test") - - await zeromq_adapter._process_message(mock_message) - - zeromq_adapter._dealer.read.assert_awaited_once() - zeromq_adapter._router.read.assert_awaited_once() - - -@pytest.mark.asyncio -async def test_zeromq_adapter_process_message_no_message(zeromq_adapter, caplog): - mock_message = None - - zeromq_adapter._dealer.read.return_value = ("Data", None) - zeromq_adapter._router.read.return_value = ("Data", None) - - with caplog.at_level(logging.DEBUG): - await zeromq_adapter._process_message(mock_message) - - assert len(caplog.records) == 1 - - zeromq_adapter._dealer.read.assert_not_awaited() - zeromq_adapter._router.read.assert_not_awaited() diff --git a/tests/adapters/zeromq/__init__.py b/tests/adapters/zeromq/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/adapters/zeromq/test_push_adapter.py b/tests/adapters/zeromq/test_push_adapter.py new file mode 100644 index 000000000..0bfbb2019 --- /dev/null +++ b/tests/adapters/zeromq/test_push_adapter.py @@ -0,0 +1,155 @@ +import asyncio +from typing import Sequence + +import aiozmq +import pytest +import pytest_asyncio +from mock import MagicMock, Mock +from mock.mock import AsyncMock, create_autospec +from pydantic.v1 import BaseModel + +from tickit.adapters.zeromq.push_adapter import ( + SocketFactory, + ZeroMqMessage, + ZeroMqPushAdapter, +) +from tickit.core.adapter import RaiseInterrupt +from tickit.core.device import Device + +_HOST = "test.host" +_PORT = 12345 + + +@pytest.fixture +def mock_device() -> Device: + return create_autospec(Device) + + +@pytest.fixture +def mock_raise_interrupt() -> RaiseInterrupt: + async def raise_interrupt(): + return False + + return Mock(raise_interrupt) + + +@pytest.fixture +def mock_socket() -> aiozmq.ZmqStream: + return MagicMock(aiozmq.ZmqStream) + + +@pytest.fixture +def socket_created() -> asyncio.Event: + return asyncio.Event() + + +@pytest.fixture +def mock_socket_factory( + mock_socket: aiozmq.ZmqStream, socket_created: asyncio.Event +) -> SocketFactory: + def make_socket(host: str, port: int) -> aiozmq.ZmqStream: + socket_created.set() + return mock_socket + + factory = AsyncMock() + factory.side_effect = make_socket + return factory + + +@pytest.fixture +def zeromq_adapter(mock_socket_factory: aiozmq.ZmqStream) -> ZeroMqPushAdapter: + return ZeroMqPushAdapter( + host=_HOST, + port=_PORT, + socket_factory=mock_socket_factory, + ) + + +@pytest_asyncio.fixture +async def running_zeromq_adapter( + zeromq_adapter: ZeroMqPushAdapter, + socket_created: asyncio.Event, + mock_device: Device, + mock_raise_interrupt: RaiseInterrupt, +) -> ZeroMqPushAdapter: + asyncio.create_task(zeromq_adapter.run_forever(mock_device, mock_raise_interrupt)) + await asyncio.wait_for(socket_created.wait(), timeout=2.0) + return zeromq_adapter + + +@pytest.mark.asyncio +async def test_socket_not_created_until_run_forever( + zeromq_adapter: ZeroMqPushAdapter, + mock_socket_factory: AsyncMock, + socket_created: asyncio.Event, + mock_device: Device, + mock_raise_interrupt: RaiseInterrupt, +) -> None: + mock_socket_factory.assert_not_called() + asyncio.create_task(zeromq_adapter.run_forever(mock_device, mock_raise_interrupt)) + await asyncio.wait_for(socket_created.wait(), timeout=2.0) + mock_socket_factory.assert_called_once_with(_HOST, _PORT) + + +class SimpleMessage(BaseModel): + foo: int + bar: str + + +class SubMessage(BaseModel): + baz: bool + + +class NestedMessage(BaseModel): + foo: int + bar: SubMessage + + +MESSGAGES = [ + ([b"foo"], [b"foo"]), + (["foo"], [b'"foo"']), + ([b"foo", b"bar"], [b"foo", b"bar"]), + ([b"foo", "bar"], [b"foo", b'"bar"']), + ([{"foo": 1, "bar": "baz"}], [b'{"foo": 1, "bar": "baz"}']), + ([{"foo": 1, "bar": {"baz": False}}], [b'{"foo": 1, "bar": {"baz": false}}']), + ([SimpleMessage(foo=1, bar="baz")], [b'{"foo": 1, "bar": "baz"}']), + ( + [NestedMessage(foo=1, bar=SubMessage(baz=False))], + [b'{"foo": 1, "bar": {"baz": false}}'], + ), +] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("message,serialized_message", MESSGAGES) +async def test_serializes_and_sends_message( + running_zeromq_adapter: ZeroMqPushAdapter, + mock_socket: MagicMock, + message: ZeroMqMessage, + serialized_message: Sequence[bytes], +) -> None: + await running_zeromq_adapter.send_message(message) + mock_socket.write.assert_called_once_with(serialized_message) + + +@pytest.mark.asyncio +async def test_socket_cleaned_up_on_cancel( + mock_device: Device, + mock_raise_interrupt: RaiseInterrupt, +) -> None: + adapter_a = ZeroMqPushAdapter() + adapter_b = ZeroMqPushAdapter() + for adapter in (adapter_a, adapter_b): + task = asyncio.create_task( + adapter.run_forever( + mock_device, + mock_raise_interrupt, + ) + ) + await adapter.send_message([b"test"]) + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + assert task.done()