diff --git a/docs/user/reference/openapi.yaml b/docs/user/reference/openapi.yaml index b918defb5..fa51219c6 100644 --- a/docs/user/reference/openapi.yaml +++ b/docs/user/reference/openapi.yaml @@ -33,6 +33,18 @@ components: - devices title: DeviceResponse type: object + EnvironmentResponse: + additionalProperties: false + description: State of internal environment. + properties: + initialized: + description: blueapi context initialized + title: Initialized + type: boolean + required: + - initialized + title: EnvironmentResponse + type: object HTTPValidationError: properties: detail: @@ -197,7 +209,7 @@ components: type: object info: title: BlueAPI Control - version: 0.0.4 + version: 0.0.5 openapi: 3.0.2 paths: /devices: @@ -237,6 +249,26 @@ paths: $ref: '#/components/schemas/HTTPValidationError' description: Validation Error summary: Get Device By Name + /environment: + delete: + operationId: delete_environment_environment_delete + responses: + '200': + content: + application/json: + schema: {} + description: Successful Response + summary: Delete Environment + get: + operationId: get_environment_environment_get + responses: + '200': + content: + application/json: + schema: + $ref: '#/components/schemas/EnvironmentResponse' + description: Successful Response + summary: Get Environment /plans: get: description: Retrieve information about all available plans. diff --git a/src/blueapi/cli/rest.py b/src/blueapi/cli/rest.py index 48dcc1e7a..7c4eeef81 100644 --- a/src/blueapi/cli/rest.py +++ b/src/blueapi/cli/rest.py @@ -72,7 +72,7 @@ def create_task(self, task: RunPlan) -> TaskResponse: data=task.dict(), ) - def delete_task(self, task_id: str) -> TaskResponse: + def clear_pending_task(self, task_id: str) -> TaskResponse: return self._request_and_deserialize( f"/tasks/{task_id}", TaskResponse, method="DELETE" ) diff --git a/src/blueapi/service/handler.py b/src/blueapi/service/handler.py index d7534f0e3..bab1eba34 100644 --- a/src/blueapi/service/handler.py +++ b/src/blueapi/service/handler.py @@ -28,6 +28,7 @@ class Handler(BlueskyHandler): _worker: Worker _config: ApplicationConfig _messaging_template: MessagingTemplate + _initialized: bool = False def __init__( self, @@ -64,6 +65,7 @@ def start(self) -> None: ) self._messaging_template.connect() + self._initialized = True def _publish_event_streams( self, streams_to_destinations: Mapping[EventStream, str] @@ -79,6 +81,7 @@ def _publish_event_stream(self, stream: EventStream, destination: str) -> None: ) def stop(self) -> None: + self._initialized = False self._worker.stop() if self._messaging_template.is_connected(): self._messaging_template.disconnect() @@ -134,6 +137,10 @@ def pending_tasks(self) -> List[TrackableTask]: def get_pending_task(self, task_id: str) -> Optional[TrackableTask]: return self._worker.get_pending_task(task_id) + @property + def initialized(self) -> bool: + return self._initialized + HANDLER: Optional[Handler] = None @@ -145,7 +152,6 @@ def setup_handler( provider = None plan_wrappers = [] - if config: visit_service_client: VisitServiceClientBase if config.env.data_writing.visit_service_url is not None: diff --git a/src/blueapi/service/handler_base.py b/src/blueapi/service/handler_base.py index faecda6e6..8ca15f1a8 100644 --- a/src/blueapi/service/handler_base.py +++ b/src/blueapi/service/handler_base.py @@ -83,3 +83,21 @@ def pending_tasks(self) -> List[TrackableTask]: def get_pending_task(self, task_id: str) -> Optional[TrackableTask]: """Returns a task matching the task ID supplied, if the worker knows of it""" + + @abstractmethod + def start(self): + """Start the handler""" + + @abstractmethod + def stop(self): + """Stop the handler""" + + @property + @abstractmethod + def initialized(self) -> bool: + """Handler initialization state""" + + +class HandlerNotStartedError(Exception): + def __init__(self, message): + super().__init__(message) diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 7440c124b..d02dfc2f1 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -1,7 +1,16 @@ from contextlib import asynccontextmanager -from typing import Dict, Set - -from fastapi import Body, Depends, FastAPI, HTTPException, Request, Response, status +from typing import Dict, Optional, Set + +from fastapi import ( + BackgroundTasks, + Body, + Depends, + FastAPI, + HTTPException, + Request, + Response, + status, +) from pydantic import ValidationError from starlette.responses import JSONResponse from super_state_machine.errors import TransitionError @@ -9,19 +18,44 @@ from blueapi.config import ApplicationConfig from blueapi.worker import RunPlan, TrackableTask, WorkerState -from .handler import get_handler, setup_handler, teardown_handler from .handler_base import BlueskyHandler from .model import ( DeviceModel, DeviceResponse, + EnvironmentResponse, PlanModel, PlanResponse, StateChangeRequest, TaskResponse, WorkerTask, ) +from .subprocess_handler import SubprocessHandler + +REST_API_VERSION = "0.0.5" + +HANDLER: Optional[BlueskyHandler] = None + + +def get_handler() -> BlueskyHandler: + if HANDLER is None: + raise ValueError() + return HANDLER + + +def setup_handler(config: Optional[ApplicationConfig] = None): + global HANDLER + handler = SubprocessHandler(config) + handler.start() + + HANDLER = handler -REST_API_VERSION = "0.0.4" + +def teardown_handler(): + global HANDLER + if HANDLER is None: + return + HANDLER.stop() + HANDLER = None @asynccontextmanager @@ -49,6 +83,26 @@ async def on_key_error_404(_: Request, __: KeyError): ) +@app.get("/environment", response_model=EnvironmentResponse) +def get_environment( + handler: BlueskyHandler = Depends(get_handler), +) -> EnvironmentResponse: + return EnvironmentResponse(initialized=handler.initialized) + + +@app.delete("/environment") +async def delete_environment( + background_tasks: BackgroundTasks, + handler: BlueskyHandler = Depends(get_handler), +): + def restart_handler(handler: BlueskyHandler): + handler.stop() + handler.start() + + if handler.initialized: + background_tasks.add_task(restart_handler, handler) + + @app.get("/plans", response_model=PlanResponse) def get_plans(handler: BlueskyHandler = Depends(get_handler)): """Retrieve information about all available plans.""" diff --git a/src/blueapi/service/model.py b/src/blueapi/service/model.py index aa34ee98f..5626498d0 100644 --- a/src/blueapi/service/model.py +++ b/src/blueapi/service/model.py @@ -128,3 +128,11 @@ class StateChangeRequest(BlueapiBaseModel): description="The reason for the current run to be aborted", default=None, ) + + +class EnvironmentResponse(BlueapiBaseModel): + """ + State of internal environment. + """ + + initialized: bool = Field(description="blueapi context initialized") diff --git a/src/blueapi/service/subprocess_handler.py b/src/blueapi/service/subprocess_handler.py new file mode 100644 index 000000000..151d7cd9a --- /dev/null +++ b/src/blueapi/service/subprocess_handler.py @@ -0,0 +1,175 @@ +import logging +import signal +from multiprocessing import Pool, set_start_method +from multiprocessing.pool import Pool as PoolClass +from typing import Callable, Iterable, List, Optional + +from blueapi.config import ApplicationConfig +from blueapi.service.handler import get_handler, setup_handler, teardown_handler +from blueapi.service.handler_base import BlueskyHandler, HandlerNotStartedError +from blueapi.service.model import DeviceModel, PlanModel, WorkerTask +from blueapi.worker.event import WorkerState +from blueapi.worker.task import RunPlan +from blueapi.worker.worker import TrackableTask + +set_start_method("spawn", force=True) +LOGGER = logging.getLogger(__name__) + + +def _init_worker(): + # Replace sigint to allow subprocess to be terminated + signal.signal(signal.SIGINT, signal.SIG_IGN) + + +class SubprocessHandler(BlueskyHandler): + _config: ApplicationConfig + _subprocess: Optional[PoolClass] + _initialized: bool = False + + def __init__( + self, + config: Optional[ApplicationConfig] = None, + ) -> None: + self._config = config or ApplicationConfig() + self._subprocess = None + + def start(self): + if self._subprocess is None: + self._subprocess = Pool(initializer=_init_worker, processes=1) + self._subprocess.apply( + logging.basicConfig, kwds={"level": self._config.logging.level} + ) + self._subprocess.apply(setup_handler, [self._config]) + self._initialized = True + + def stop(self): + if self._subprocess is not None: + self._initialized = False + self._subprocess.apply(teardown_handler) + self._subprocess.close() + self._subprocess.join() + self._subprocess = None + + def reload_context(self): + self.stop() + self.start() + LOGGER.info("Context reloaded") + + def _run_in_subprocess( + self, function: Callable, arguments: Optional[Iterable] = None + ): + if arguments is None: + arguments = [] + if self._subprocess is None: + raise HandlerNotStartedError("Subprocess handler has not been started") + return self._subprocess.apply(function, arguments) + + @property + def plans(self) -> List[PlanModel]: + return self._run_in_subprocess(plans) + + def get_plan(self, name: str) -> PlanModel: + return self._run_in_subprocess(get_plan, [name]) + + @property + def devices(self) -> List[DeviceModel]: + return self._run_in_subprocess(devices) + + def get_device(self, name: str) -> DeviceModel: + return self._run_in_subprocess(get_device, [name]) + + def submit_task(self, task: RunPlan) -> str: + return self._run_in_subprocess(submit_task, [task]) + + def clear_pending_task(self, task_id: str) -> str: + return self._run_in_subprocess(clear_pending_task, [task_id]) + + def begin_task(self, task: WorkerTask) -> WorkerTask: + return self._run_in_subprocess(begin_task, [task]) + + @property + def active_task(self) -> Optional[TrackableTask]: + return self._run_in_subprocess(active_task) + + @property + def state(self) -> WorkerState: + return self._run_in_subprocess(state) + + def pause_worker(self, defer: Optional[bool]) -> None: + return self._run_in_subprocess(pause_worker, [defer]) + + def resume_worker(self) -> None: + return self._run_in_subprocess(resume_worker) + + def cancel_active_task(self, failure: bool, reason: Optional[str]) -> None: + return self._run_in_subprocess(cancel_active_task, [failure, reason]) + + @property + def pending_tasks(self) -> List[TrackableTask]: + return self._run_in_subprocess(pending_tasks) + + def get_pending_task(self, task_id: str) -> Optional[TrackableTask]: + return self._run_in_subprocess(get_pending_task, [task_id]) + + @property + def initialized(self) -> bool: + return self._initialized + + +# Free functions (passed to subprocess) for each of the methods required by Handler + + +def plans() -> List[PlanModel]: + return get_handler().plans + + +def get_plan(name: str): + return get_handler().get_plan(name) + + +def devices() -> List[DeviceModel]: + return get_handler().devices + + +def get_device(name: str) -> DeviceModel: + return get_handler().get_device(name) + + +def submit_task(task: RunPlan) -> str: + return get_handler().submit_task(task) + + +def clear_pending_task(task_id: str) -> str: + return get_handler().clear_pending_task(task_id) + + +def begin_task(task: WorkerTask) -> WorkerTask: + return get_handler().begin_task(task) + + +def active_task() -> Optional[TrackableTask]: + return get_handler().active_task + + +def state() -> WorkerState: + return get_handler().state + + +def pause_worker(defer: Optional[bool]) -> None: + return get_handler().pause_worker(defer) + + +def resume_worker() -> None: + return get_handler().resume_worker() + + +def cancel_active_task(failure: bool, reason: Optional[str]) -> None: + return get_handler().cancel_active_task(failure, reason) + + +def pending_tasks() -> List[TrackableTask]: + return get_handler().pending_tasks + + +def get_pending_task(task_id: str) -> Optional[TrackableTask]: + return get_handler().get_pending_task(task_id) diff --git a/tests/conftest.py b/tests/conftest.py index 362a91d01..c4dec97fa 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -9,8 +9,8 @@ from bluesky.run_engine import RunEngineStateMachine, TransitionError from fastapi.testclient import TestClient -from blueapi.service.handler import Handler, get_handler -from blueapi.service.main import app +from blueapi.service.handler import Handler +from blueapi.service.main import app, get_handler from src.blueapi.core import BlueskyContext diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index 84ce72a47..c72103302 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -11,6 +11,7 @@ from blueapi.core.bluesky_types import Plan from blueapi.service.handler import Handler +from blueapi.service.main import get_handler, setup_handler, teardown_handler from blueapi.worker.task import RunPlan from src.blueapi.worker import WorkerState @@ -485,3 +486,24 @@ def test_current_complete_returns_400( "/worker/state", json={"new_state": WorkerState.ABORTING.name, "reason": "foo"} ) assert response.status_code is status.HTTP_400_BAD_REQUEST + + +def test_get_environment(handler: Handler, client: TestClient) -> None: + assert client.get("/environment").json() == {"initialized": False} + + +def test_delete_environment(handler: Handler, client: TestClient) -> None: + handler._initialized = True + assert client.delete("/environment").status_code is status.HTTP_200_OK + + +def test_teardown_handler(): + setup_handler() + assert get_handler() is not None + teardown_handler() + with pytest.raises(ValueError): + get_handler() + + +def test_teardown_handler_does_not_raise(): + assert teardown_handler() is None diff --git a/tests/service/test_subprocess_handler.py b/tests/service/test_subprocess_handler.py new file mode 100644 index 000000000..15fdcdab4 --- /dev/null +++ b/tests/service/test_subprocess_handler.py @@ -0,0 +1,169 @@ +from typing import List, Optional + +import pytest +from mock import MagicMock, patch + +from blueapi.service.handler_base import BlueskyHandler, HandlerNotStartedError +from blueapi.service.model import DeviceModel, PlanModel, WorkerTask +from blueapi.service.subprocess_handler import SubprocessHandler +from blueapi.worker.event import WorkerState +from blueapi.worker.task import RunPlan +from blueapi.worker.worker import TrackableTask + + +@pytest.fixture(scope="module") +def sp_handler(): + sp_handler = SubprocessHandler() + sp_handler.start() + yield sp_handler + sp_handler.stop() + + +def test_initialize(): + sp_handler = SubprocessHandler() + assert not sp_handler.initialized + sp_handler.start() + assert sp_handler.initialized + # Run a single call to the handler for coverage of dispatch to subprocess + assert sp_handler.pending_tasks == [] + sp_handler.stop() + assert not sp_handler.initialized + + +def test_reload(): + sp_handler = SubprocessHandler() + sp_handler.start() + assert sp_handler.initialized + sp_handler.reload_context() + assert sp_handler.initialized + sp_handler.stop() + + +def test_raises_if_not_started(): + sp_handler = SubprocessHandler() + with pytest.raises(HandlerNotStartedError): + sp_handler.state + + +class DummyHandler(BlueskyHandler): + @property + def plans(self) -> List[PlanModel]: + return [PlanModel(name="plan1"), PlanModel(name="plan2")] + + def get_plan(self, name: str) -> PlanModel: + return PlanModel(name="plan1") + + @property + def devices(self) -> List[DeviceModel]: + return [ + DeviceModel(name="device1", protocols=[]), + DeviceModel(name="device2", protocols=[]), + ] + + def get_device(self, name: str) -> DeviceModel: + return DeviceModel(name="device1", protocols=[]) + + def submit_task(self, task: RunPlan) -> str: + return "0" + + def clear_pending_task(self, task_id: str) -> str: + return "1" + + def begin_task(self, task: WorkerTask) -> WorkerTask: + return WorkerTask(task_id=task.task_id) + + @property + def active_task(self) -> Optional[TrackableTask]: + return None + + @property + def state(self) -> WorkerState: + return WorkerState.IDLE + + def pause_worker(self, defer: Optional[bool]) -> None: ... + + def resume_worker(self) -> None: ... + + def cancel_active_task(self, failure: bool, reason: Optional[str]) -> None: ... + + @property + def pending_tasks(self) -> List[TrackableTask]: + return [ + TrackableTask( + task_id="abc", task=RunPlan(name="sleep", params={"time": 0.0}) + ) + ] + + def get_pending_task(self, task_id: str) -> Optional[TrackableTask]: + return None + + def start(self): ... + + def stop(self): ... + + # Initialized is a special case as it is not delegated + # Tested by test_initialize + @property + def initialized(self) -> bool: + raise Exception("Not implemented") + + +@patch("blueapi.service.subprocess_handler.get_handler") +def test_method_routing(get_handler_mock: MagicMock): + + # Mock get_handler to prevent using a real internal handler + dummy_handler = DummyHandler() + get_handler_mock.return_value = dummy_handler + + # For above to work, prevent use of subprocess + def run_in_same_process(func, args=None): + if args is None: + args = [] + return func(*args) + + sp_handler = SubprocessHandler() + sp_handler._run_in_subprocess = MagicMock( # type: ignore + side_effect=run_in_same_process + ) + + # Verify each method is routed correctly + + assert sp_handler.plans == dummy_handler.plans + + assert sp_handler.get_plan("name") == dummy_handler.get_plan("name") + + assert sp_handler.devices == dummy_handler.devices + + assert sp_handler.get_device("name") == dummy_handler.get_device("name") + + assert sp_handler.submit_task( + RunPlan(name="sleep", params={"time": 0.0}) + ) == dummy_handler.submit_task(RunPlan(name="sleep", params={"time": 0.0})) + + assert sp_handler.clear_pending_task("task_id") == dummy_handler.clear_pending_task( + "task_id" + ) + + assert sp_handler.begin_task(WorkerTask(task_id="foo")) == dummy_handler.begin_task( + WorkerTask(task_id="foo") + ) + + assert sp_handler.active_task == dummy_handler.active_task + + assert sp_handler.state == dummy_handler.state + + sp_handler.pause_worker(True) + + sp_handler.resume_worker() + + sp_handler.cancel_active_task(True, "reason") + + assert sp_handler.pending_tasks == dummy_handler.pending_tasks + + assert sp_handler.get_pending_task("task_id") == dummy_handler.get_pending_task( + "task_id" + ) + + assert sp_handler.start() == dummy_handler.start() + + assert sp_handler.stop() == dummy_handler.stop()