Skip to content

Commit

Permalink
Create very simple RPC so the subprocess loads functions (#584)
Browse files Browse the repository at this point in the history
Fixes #590

---------

Co-authored-by: Ware, Joseph (DLSLtd,RAL,LSCI) <joseph.ware@diamond.ac.uk>
Co-authored-by: DiamondJoseph <53935796+DiamondJoseph@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 14, 2024
1 parent 5351741 commit fd41e57
Show file tree
Hide file tree
Showing 5 changed files with 182 additions and 42 deletions.
24 changes: 11 additions & 13 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def get_plans(runner: WorkerDispatcher = Depends(_runner)):
)
def get_plan_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about a plan by its (unique) name."""
return runner.run(interface.get_plan, [name])
return runner.run(interface.get_plan, name)


@app.get("/devices", response_model=DeviceResponse)
Expand All @@ -132,7 +132,7 @@ def get_devices(runner: WorkerDispatcher = Depends(_runner)):
)
def get_device_by_name(name: str, runner: WorkerDispatcher = Depends(_runner)):
"""Retrieve information about a devices by its (unique) name."""
return runner.run(interface.get_device, [name])
return runner.run(interface.get_device, name)


example_task = Task(name="count", params={"detectors": ["x"]})
Expand All @@ -151,8 +151,8 @@ def submit_task(
):
"""Submit a task to the worker."""
try:
plan_model = runner.run(interface.get_plan, [task.name])
task_id: str = runner.run(interface.submit_task, [task])
plan_model = runner.run(interface.get_plan, task.name)
task_id: str = runner.run(interface.submit_task, task)
response.headers["Location"] = f"{request.url}/{task_id}"
return TaskResponse(task_id=task_id)
except ValidationError as e:
Expand All @@ -176,7 +176,7 @@ def delete_submitted_task(
task_id: str,
runner: WorkerDispatcher = Depends(_runner),
) -> TaskResponse:
return TaskResponse(task_id=runner.run(interface.clear_task, [task_id]))
return TaskResponse(task_id=runner.run(interface.clear_task, task_id))


def validate_task_status(v: str) -> TaskStatusEnum:
Expand Down Expand Up @@ -205,7 +205,7 @@ def get_tasks(
detail="Invalid status query parameter",
) from e

tasks = runner.run(interface.get_tasks_by_status, [desired_status])
tasks = runner.run(interface.get_tasks_by_status, desired_status)
else:
tasks = runner.run(interface.get_tasks)
return TasksListResponse(tasks=tasks)
Expand All @@ -227,7 +227,7 @@ def set_active_task(
raise HTTPException(
status_code=status.HTTP_409_CONFLICT, detail="Worker already active"
)
runner.run(interface.begin_task, [task])
runner.run(interface.begin_task, task)
return task


Expand All @@ -240,7 +240,7 @@ def get_task(
runner: WorkerDispatcher = Depends(_runner),
) -> TrackableTask:
"""Retrieve a task"""
task = runner.run(interface.get_task_by_id, [task_id])
task = runner.run(interface.get_task_by_id, task_id)
if task is None:
raise KeyError
return task
Expand Down Expand Up @@ -313,17 +313,15 @@ def set_state(
and new_state in _ALLOWED_TRANSITIONS[current_state]
):
if new_state == WorkerState.PAUSED:
runner.run(interface.pause_worker, [state_change_request.defer])
runner.run(interface.pause_worker, state_change_request.defer)
elif new_state == WorkerState.RUNNING:
runner.run(interface.resume_worker)
elif new_state in {WorkerState.ABORTING, WorkerState.STOPPING}:
try:
runner.run(
interface.cancel_active_task,
[
state_change_request.new_state is WorkerState.ABORTING,
state_change_request.reason,
],
state_change_request.new_state is WorkerState.ABORTING,
state_change_request.reason,
)
except TransitionError:
response.status_code = status.HTTP_400_BAD_REQUEST
Expand Down
82 changes: 67 additions & 15 deletions src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
import inspect
import logging
import signal
from collections.abc import Callable, Iterable
from collections.abc import Callable
from importlib import import_module
from multiprocessing import Pool, set_start_method
from multiprocessing.pool import Pool as PoolClass
from typing import Any
from typing import Any, ParamSpec, TypeVar

from blueapi.config import ApplicationConfig
from blueapi.service.interface import (
setup,
teardown,
)
from blueapi.service.interface import setup, teardown
from blueapi.service.model import EnvironmentResponse

# The default multiprocessing start method is fork
set_start_method("spawn", force=True)

LOGGER = logging.getLogger(__name__)

P = ParamSpec("P")
T = TypeVar("T")


def _init_worker():
# Replace sigint to allow subprocess to be terminated
Expand Down Expand Up @@ -56,7 +58,7 @@ def start(self):
try:
if self._use_subprocess:
self._subprocess = Pool(initializer=_init_worker, processes=1)
self.run(setup, [self._config])
self.run(setup, self._config)
self._state = EnvironmentResponse(initialized=True)
except Exception as e:
self._state = EnvironmentResponse(
Expand All @@ -82,21 +84,39 @@ def stop(self):
)
LOGGER.exception(e)

def run(self, function: Callable, arguments: Iterable | None = None) -> Any:
arguments = arguments or []
def run(self, function: Callable[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
if self._use_subprocess:
return self._run_in_subprocess(function, arguments)
return self._run_in_subprocess(function, *args, **kwargs)
else:
return function(*arguments)
return function(*args, **kwargs)

def _run_in_subprocess(
self,
function: Callable,
arguments: Iterable,
) -> Any:
function: Callable[P, T],
*args: P.args,
**kwargs: P.kwargs,
) -> T:
if self._subprocess is None:
raise InvalidRunnerStateError("Subprocess runner has not been started")
return self._subprocess.apply(function, arguments)
if not (hasattr(function, "__name__") and hasattr(function, "__module__")):
raise RpcError(f"{function} is anonymous, cannot be run in subprocess")
if not callable(function):
raise RpcError(f"{function} is not Callable, cannot be run in subprocess")
try:
return_type = inspect.signature(function).return_annotation
except TypeError:
return_type = None

return self._subprocess.apply(
_rpc,
(
function.__module__,
function.__name__,
return_type,
*args,
),
kwargs,
)

@property
def state(self) -> EnvironmentResponse:
Expand All @@ -106,3 +126,35 @@ def state(self) -> EnvironmentResponse:
class InvalidRunnerStateError(Exception):
def __init__(self, message):
super().__init__(message)


class RpcError(Exception): ...


def _rpc(
module_name: str,
function_name: str,
expected_type: type[T] | None,
*args: Any,
**kwargs: Any,
) -> T:
mod = import_module(module_name)
func: Callable[P, T] = _validate_function(
mod.__dict__.get(function_name, None), function_name
)
value = func(*args, **kwargs)
if expected_type is None or isinstance(value, expected_type):
return value
else:
raise TypeError(
f"{function_name} returned value of type {type(value)}"
+ f" which is incompatible with expected {expected_type}"
)


def _validate_function(func: Any, function_name: str) -> Callable:
if func is None:
raise RpcError(f"{function_name}: No such function in subprocess API")
elif not callable(func):
raise RpcError(f"{function_name}: Object in subprocess is not a function")
return func
16 changes: 15 additions & 1 deletion tests/core/fake_device_module.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from unittest.mock import MagicMock
from unittest.mock import MagicMock, NonCallableMock

from ophyd import EpicsMotor

Expand Down Expand Up @@ -26,6 +26,20 @@ def fake_motor_bundle_a(


def _mock_with_name(name: str) -> MagicMock:
# mock.name must return str, cannot MagicMock(name=name)
mock = MagicMock()
mock.name = name
return mock


def wrong_return_type() -> int:
return "0" # type: ignore


fetchable_non_callable = NonCallableMock()
fetchable_callable = MagicMock(return_value="string")

fetchable_non_callable.__name__ = "fetchable_non_callable"
fetchable_non_callable.__module__ = fake_motor_bundle_a.__module__
fetchable_callable.__name__ = "fetchable_callable"
fetchable_callable.__module__ = fake_motor_bundle_a.__module__
2 changes: 1 addition & 1 deletion tests/messaging/test_stomptemplate.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def server(ctx: MessageContext, message: message_type) -> None: # type: ignore
reply = template.send_and_receive(test_queue, message, message_type).result(
timeout=_TIMEOUT
)
if type(message) == np.ndarray:
if type(message) is np.ndarray:
message = message.tolist()
assert reply == message

Expand Down
100 changes: 88 additions & 12 deletions tests/service/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,31 @@

from blueapi.service import interface
from blueapi.service.model import EnvironmentResponse
from blueapi.service.runner import InvalidRunnerStateError, WorkerDispatcher
from blueapi.service.runner import (
InvalidRunnerStateError,
RpcError,
WorkerDispatcher,
)


def test_initialize():
runner = WorkerDispatcher()
@pytest.fixture
def local_runner():
return WorkerDispatcher(use_subprocess=False)


@pytest.fixture
def runner():
return WorkerDispatcher()


@pytest.fixture
def started_runner(runner: WorkerDispatcher):
runner.start()
yield runner
runner.stop()


def test_initialize(runner: WorkerDispatcher):
assert not runner.state.initialized
runner.start()
assert runner.state.initialized
Expand All @@ -19,23 +39,20 @@ def test_initialize():
assert not runner.state.initialized


def test_reload():
runner = WorkerDispatcher()
def test_reload(runner: WorkerDispatcher):
runner.start()
assert runner.state.initialized
runner.reload()
assert runner.state.initialized
runner.stop()


def test_raises_if_used_before_started():
runner = WorkerDispatcher()
def test_raises_if_used_before_started(runner: WorkerDispatcher):
with pytest.raises(InvalidRunnerStateError):
assert runner.run(interface.get_plans) is None
runner.run(interface.get_plans)


def test_error_on_runner_setup():
runner = WorkerDispatcher(use_subprocess=False)
def test_error_on_runner_setup(local_runner: WorkerDispatcher):
expected_state = EnvironmentResponse(
initialized=False,
error_message="Intentional start_worker exception",
Expand All @@ -48,8 +65,8 @@ def test_error_on_runner_setup():
# Calling reload here instead of start also indirectly
# tests that stop() doesn't raise if there is no error message
# and the runner is not yet initialised
runner.reload()
state = runner.state
local_runner.reload()
state = local_runner.state
assert state == expected_state


Expand Down Expand Up @@ -85,3 +102,62 @@ def test_can_reload_after_an_error(pool_mock: MagicMock):
runner.reload()

assert runner.state == EnvironmentResponse(initialized=True, error_message=None)


def test_function_not_findable_on_subprocess(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import fake_motor_y

# Valid target on main but not sub process
# Change in this process not reflected in subprocess
fake_motor_y.__name__ = "not_exported"

with pytest.raises(
RpcError, match="not_exported: No such function in subprocess API"
):
started_runner.run(fake_motor_y)


def test_non_callable_excepts_in_main_process(started_runner: WorkerDispatcher):
# Not a valid target on main or sub process
from tests.core.fake_device_module import fetchable_non_callable

with pytest.raises(
RpcError,
match="<NonCallableMock id='[0-9]+'> is not Callable, "
+ "cannot be run in subprocess",
):
started_runner.run(fetchable_non_callable)


def test_non_callable_excepts_in_sub_process(started_runner: WorkerDispatcher):
# Valid target on main but finds non-callable in sub process
from tests.core.fake_device_module import fetchable_callable, fetchable_non_callable

fetchable_callable.__name__ = fetchable_non_callable.__name__

with pytest.raises(
RpcError,
match="fetchable_non_callable: Object in subprocess is not a function",
):
started_runner.run(fetchable_callable)


def test_clear_message_for_anonymous_function(started_runner: WorkerDispatcher):
non_fetchable_callable = MagicMock()

with pytest.raises(
RpcError,
match="<MagicMock id='[0-9]+'> is anonymous, cannot be run in subprocess",
):
started_runner.run(non_fetchable_callable)


def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import wrong_return_type

with pytest.raises(
TypeError,
match="wrong_return_type returned value of type <class 'str'>"
+ " which is incompatible with expected <class 'int'>",
):
started_runner.run(wrong_return_type)

0 comments on commit fd41e57

Please sign in to comment.