Skip to content

Commit

Permalink
Handle parameterised generics (#598)
Browse files Browse the repository at this point in the history
Closes #597

---------

Co-authored-by: Callum Forrester <callum.forrester@diamond.ac.uk>
  • Loading branch information
DiamondJoseph and callumforrester committed Aug 29, 2024
1 parent 1d80837 commit ad3f686
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 9 deletions.
13 changes: 8 additions & 5 deletions src/blueapi/service/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from multiprocessing.pool import Pool as PoolClass
from typing import Any, ParamSpec, TypeVar

from pydantic import TypeAdapter

from blueapi.config import ApplicationConfig
from blueapi.service.interface import setup, teardown
from blueapi.service.model import EnvironmentResponse
Expand Down Expand Up @@ -143,13 +145,14 @@ def _rpc(
mod.__dict__.get(function_name, None), function_name
)
value = func(*args, **kwargs)
if expected_type is None or isinstance(value, expected_type):
return _valid_return(value, expected_type)


def _valid_return(value: Any, expected_type: type[T] | None = None) -> T:
if expected_type is None:
return value
else:
raise TypeError(
f"{function_name} returned value of type {type(value)}"
+ f" which is incompatible with expected {expected_type}"
)
return TypeAdapter(expected_type).validate_python(value)


def _validate_function(func: Any, function_name: str) -> Callable:
Expand Down
2 changes: 1 addition & 1 deletion tests/core/fake_device_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def _mock_with_name(name: str) -> MagicMock:


def wrong_return_type() -> int:
return "0" # type: ignore
return "str" # type: ignore


fetchable_non_callable = NonCallableMock()
Expand Down
86 changes: 83 additions & 3 deletions tests/service/test_runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from typing import Any, Generic, TypeVar
from unittest import mock
from unittest.mock import MagicMock, patch

import pytest
from ophyd import Callable
from pydantic import BaseModel, ValidationError

from blueapi.service import interface
from blueapi.service.model import EnvironmentResponse
Expand Down Expand Up @@ -156,8 +159,85 @@ def test_clear_message_for_wrong_return(started_runner: WorkerDispatcher):
from tests.core.fake_device_module import wrong_return_type

with pytest.raises(
TypeError,
match="wrong_return_type returned value of type <class 'str'>"
+ " which is incompatible with expected <class 'int'>",
ValidationError,
match="1 validation error for int",
):
started_runner.run(wrong_return_type)


T = TypeVar("T")


class SimpleModel(BaseModel):
a: int
b: str


class NestedModel(BaseModel):
nested: SimpleModel
c: bool


class GenericModel(BaseModel, Generic[T]):
a: T
b: str


def return_int() -> int:
return 1


def return_str() -> str:
return "hello"


def return_list() -> list[int]:
return [1, 2, 3]


def return_dict() -> dict[str, int]:
return {
"test": 1,
"other_test": 2,
}


def return_simple_model() -> SimpleModel:
return SimpleModel(a=1, b="hi")


def return_nested_model() -> NestedModel:
return NestedModel(nested=return_simple_model(), c=False)


def return_unbound_generic_model() -> GenericModel:
return GenericModel(a="foo", b="bar")


def return_bound_generic_model() -> GenericModel[int]:
return GenericModel(a=1, b="hi")


def return_explicitly_bound_generic_model() -> GenericModel[int]:
return GenericModel[int](a=1, b="hi")


@pytest.mark.parametrize(
"rpc_function",
[
return_int,
return_str,
return_list,
return_dict,
return_simple_model,
return_nested_model,
return_unbound_generic_model,
# https://github.com/pydantic/pydantic/issues/6870 return_bound_generic_model,
return_explicitly_bound_generic_model,
],
)
def test_accepts_return_type(
started_runner: WorkerDispatcher,
rpc_function: Callable[[], Any],
):
started_runner.run(rpc_function)

0 comments on commit ad3f686

Please sign in to comment.