From ec13eac69479d35ae9a2fd5afce2933d0679edd8 Mon Sep 17 00:00:00 2001 From: DiamondJoseph <53935796+DiamondJoseph@users.noreply.github.com> Date: Fri, 6 Sep 2024 14:26:30 +0100 Subject: [PATCH] Update to Dodal with Pydantic2 support, remove deprecated Pydantic usage (#625) --- dev-requirements.txt | 65 +++++++++++++++++++++--------- docs/how-to/write-plans.md | 2 +- pyproject.toml | 8 ++-- src/blueapi/cli/cli.py | 4 +- src/blueapi/cli/format.py | 6 +-- src/blueapi/client/rest.py | 8 ++-- src/blueapi/config.py | 14 +++---- src/blueapi/core/context.py | 14 +++++-- src/blueapi/utils/serialization.py | 2 +- tests/core/test_context.py | 17 ++++---- tests/service/test_rest_api.py | 20 ++++----- tests/test_cli.py | 14 +++---- 12 files changed, 104 insertions(+), 70 deletions(-) diff --git a/dev-requirements.txt b/dev-requirements.txt index 317ce9015..dba27878b 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -10,6 +10,7 @@ anyio==4.4.0 appdirs==1.4.4 asciitree==0.3.3 asttokens==2.4.1 +async-timeout==4.0.3 attrs==24.2.0 babel==2.16.0 beautifulsoup4==4.12.3 @@ -19,9 +20,10 @@ bluesky-kafka==0.10.0 bluesky-live==0.0.8 bluesky-stomp==0.1.0 boltons==24.0.0 +bump-pydantic==0.8.0 cachetools==5.5.0 caproto==1.1.1 -certifi==2024.7.4 +certifi==2024.8.30 cfgv==3.4.0 charset-normalizer==3.3.2 click==8.1.7 @@ -29,36 +31,40 @@ cloudpickle==3.0.0 colorama==0.4.6 colorlog==6.8.2 comm==0.2.2 -confluent-kafka==2.5.0 +compress-pickle==2.1.0 +confluent-kafka==2.5.3 contourpy==1.3.0 copier==9.3.1 coverage==7.6.1 cycler==0.12.1 -dask==2024.8.1 +dask==2024.8.2 databroker==1.2.5 dataclasses-json==0.6.7 decorator==5.1.1 -deepmerge==1.1.1 +deepmerge==2.0 distlib==0.3.8 dls-bluesky-core==0.0.4 -dls-dodal==1.29.4 +dls-dodal==1.31.0 dnspython==2.6.1 docopt==0.6.2 doct==1.1.0 docutils==0.21.2 dunamai==1.22.0 +email_validator==2.2.0 entrypoints==0.4 epicscorelibs==7.0.7.99.0.2 -event-model==1.20.0 -executing==2.0.1 -fastapi==0.112.2 +event-model==1.21.0 +exceptiongroup==1.2.2 +executing==2.1.0 +fastapi==0.113.0 +fastapi-cli==0.0.5 fasteners==0.19 filelock==3.15.4 flexcache==0.3 flexparser==0.3.1 fonttools==4.53.1 frozenlist==1.4.1 -fsspec==2024.6.1 +fsspec==2024.9.0 funcy==2.0 gitdb==4.0.11 GitPython==3.1.43 @@ -68,6 +74,7 @@ h5py==3.11.0 HeapDict==1.0.1 historydict==1.2.6 httpcore==1.0.5 +httptools==0.6.1 httpx==0.27.2 humanize==4.10.0 identify==2.6.0 @@ -80,15 +87,19 @@ iniconfig==2.0.0 intake==0.6.4 ipython==8.18.0 ipywidgets==8.1.5 +itsdangerous==2.2.0 jedi==0.19.1 Jinja2==3.1.4 jinja2-ansible-filters==1.3.2 jsonschema==4.23.0 jsonschema-specifications==2023.12.1 jupyterlab_widgets==3.0.13 -kiwisolver==1.4.5 +kiwisolver==1.4.7 ldap3==2.9.1 +libcst==1.4.0 +livereload==2.7.0 locket==1.0.0 +lz4==4.3.3 markdown-it-py==3.0.0 MarkupSafe==2.1.5 marshmallow==3.22.0 @@ -113,7 +124,7 @@ numcodecs==0.13.0 numpy==1.26.4 opencv-python-headless==4.10.0.84 ophyd==1.9.0 -ophyd-async==0.3.4 +ophyd-async==0.5.2 orjson==3.10.7 p4p==4.1.12 packaging==24.1 @@ -143,9 +154,11 @@ pvxslibs==1.3.1 py==1.11.0 pyasn1==0.6.0 pycryptodome==3.20.0 -pydantic==2.8.2 +pydantic==2.9.0 +pydantic-extra-types==2.9.0 pydantic-settings==2.4.0 -pydantic_core==2.20.1 +pydantic_core==2.23.2 +pydantic_numpy==5.0.2 pydata-sphinx-theme==0.15.4 pyepics==3.5.7 Pygments==2.18.0 @@ -157,18 +170,25 @@ pytest-asyncio==0.24.0 pytest-cov==5.0.0 python-dateutil==2.9.0.post0 python-dotenv==1.0.1 +python-multipart==0.0.9 pytz==2024.1 PyYAML==6.0.2 +pyyaml-include==2.1 questionary==2.0.1 redis==5.0.8 redis-json-dict==0.2.0 referencing==0.35.1 requests==2.32.3 responses==0.25.3 +rich==13.7.1 rpds-py==0.20.0 -ruff==0.6.2 +ruamel.yaml==0.18.6 +ruamel.yaml.clib==0.2.8 +ruff==0.6.4 scanspec==0.7.2 +semver==3.0.2 setuptools-dso==2.11 +shellingham==1.5.4 six==1.16.0 slicerator==1.1.0 smmap==5.0.1 @@ -176,7 +196,7 @@ sniffio==1.3.1 snowballstemmer==2.2.0 soupsieve==2.6 Sphinx==8.0.2 -sphinx-autobuild==2024.4.16 +sphinx-autobuild==2024.9.3 sphinx-click==6.0.0 sphinx-copybutton==0.5.2 sphinx_design==0.6.1 @@ -190,28 +210,33 @@ sphinxcontrib-openapi==0.8.4 sphinxcontrib-qthelp==2.0.0 sphinxcontrib-serializinghtml==2.0.0 stack-data==0.6.3 -starlette==0.38.2 +starlette==0.38.4 stomp-py==8.1.2 suitcase-mongo==0.6.0 suitcase-msgpack==0.3.0 suitcase-utils==0.5.4 super-state-machine==2.0.2 -tifffile==2024.8.28 +tifffile==2024.8.30 +tomli==2.0.1 toolz==0.12.1 +tornado==6.4.1 tox==3.28.0 tox-direct==0.4 tqdm==4.66.5 traitlets==5.14.3 +typer==0.12.4 types-mock==5.1.0.20240425 types-PyYAML==6.0.12.20240808 -types-requests==2.32.0.20240712 +types-requests==2.32.0.20240905 types-urllib3==1.26.25.14 typing-inspect==0.9.0 typing_extensions==4.12.2 tzdata==2024.1 tzlocal==5.2 +ujson==5.10.0 urllib3==2.2.2 uvicorn==0.30.6 +uvloop==0.19.0 virtualenv==20.26.3 watchfiles==0.24.0 wcwidth==0.2.13 @@ -220,8 +245,8 @@ websockets==13.0.1 widgetsnbextension==4.0.13 workflows==2.27 xarray==2024.7.0 -yarl==1.9.4 -zarr==2.18.2 +yarl==1.9.11 +zarr==2.18.3 zict==2.2.0 zipp==3.20.1 zocalo==1.1.0 diff --git a/docs/how-to/write-plans.md b/docs/how-to/write-plans.md index d71798672..eb6d6f503 100644 --- a/docs/how-to/write-plans.md +++ b/docs/how-to/write-plans.md @@ -36,7 +36,7 @@ For example, if a plan is written to drive a specific implementation of Movable, When added to the blueapi context, `PlanGenerator`\ s are formalised into their schema- `a Pydantic BaseModel `__ with the expected argument types and their defaults. -Therefore, `PlanGenerator`\ s must only take as arguments `those types which are valid Pydantic fields `__ or Device types which implement `BLUESKY_PROTOCOLS` defined in dodal, which are fetched from the context at runtime. +Therefore, `PlanGenerator`\ s must only take as arguments `those types which are valid Pydantic fields `__ or Device types which implement `BLUESKY_PROTOCOLS` defined in dodal, which are fetched from the context at runtime. Allowed argument types for Pydantic BaseModels include the primitives, types that extend `BaseModel` and `dict`\ s, `list`\ s and other `sequence`\ s of supported types. Blueapi will deserialise these types from JSON, so `dict`\ s should use `str` keys. diff --git a/pyproject.toml b/pyproject.toml index 545ce0ae4..717c16b61 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,11 @@ dependencies = [ "fastapi>=0.112.0", "uvicorn", "requests", - "dls-bluesky-core", #requires ophyd-async - "dls-dodal>=1.24.0", - "super-state-machine", # See GH issue 553 + "dls-bluesky-core", #requires ophyd-async + "dls-dodal>=1.31.0", + "super-state-machine", # See GH issue 553 "GitPython", - "bluesky-stomp>=0.1.0" + "bluesky-stomp>=0.1.0", ] dynamic = ["version"] license.file = "LICENSE" diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index 3802e4d13..6455a9814 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -162,7 +162,7 @@ def on_event( event: WorkerEvent | ProgressEvent | DataEvent, context: MessageContext, ) -> None: - converted = json.dumps(event.dict(), indent=2) + converted = json.dumps(event.model_dump(), indent=2) print(converted) print( @@ -218,7 +218,7 @@ def on_event(event: AnyEvent) -> None: pprint("task could not run") return - pprint(resp.dict()) + pprint(resp.model_dump()) if resp.task_status is not None and not resp.task_status.task_failed: print("Plan Succeeded") diff --git a/src/blueapi/cli/format.py b/src/blueapi/cli/format.py index f546f0c17..3baaafa99 100644 --- a/src/blueapi/cli/format.py +++ b/src/blueapi/cli/format.py @@ -57,11 +57,11 @@ def display_json(obj: Any, stream: Stream): print = partial(builtins.print, file=stream) match obj: case PlanResponse(plans=plans): - print(json.dumps([p.dict() for p in plans], indent=2)) + print(json.dumps([p.model_dump() for p in plans], indent=2)) case DeviceResponse(devices=devices): - print(json.dumps([d.dict() for d in devices], indent=2)) + print(json.dumps([d.model_dump() for d in devices], indent=2)) case BaseModel(): - print(json.dumps(obj.dict(), indent=2)) + print(json.dumps(obj.model_dump(), indent=2)) case _: print(json.dumps(obj)) diff --git a/src/blueapi/client/rest.py b/src/blueapi/client/rest.py index 81ece17d6..572d75974 100644 --- a/src/blueapi/client/rest.py +++ b/src/blueapi/client/rest.py @@ -2,7 +2,7 @@ from typing import Any, Literal, TypeVar import requests -from pydantic import parse_obj_as +from pydantic import TypeAdapter from blueapi.config import RestConfig from blueapi.service.model import ( @@ -78,7 +78,7 @@ def create_task(self, task: Task) -> TaskResponse: "/tasks", TaskResponse, method="POST", - data=task.dict(), + data=task.model_dump(), ) def clear_task(self, task_id: str) -> TaskResponse: @@ -91,7 +91,7 @@ def update_worker_task(self, task: WorkerTask) -> WorkerTask: "/worker/task", WorkerTask, method="PUT", - data=task.dict(), + data=task.model_dump(), ) def cancel_current_task( @@ -130,7 +130,7 @@ def _request_and_deserialize( exception = get_exception(response) if exception is not None: raise exception - deserialized = parse_obj_as(target_type, response.json()) + deserialized = TypeAdapter(target_type).validate_python(response.json()) return deserialized def _url(self, suffix: str) -> str: diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 60f929761..98d0ef1e9 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -5,7 +5,7 @@ from typing import Any, Generic, Literal, TypeVar import yaml -from pydantic import BaseModel, Field, ValidationError, parse_obj_as, validator +from pydantic import BaseModel, Field, TypeAdapter, ValidationError, field_validator from blueapi.utils import BlueapiBaseModel, InvalidConfigError @@ -34,7 +34,8 @@ class BasicAuthentication(BaseModel): username: str = "guest" passcode: str = "guest" - @validator("username", "passcode") + @field_validator("username", "passcode") + @classmethod def get_from_env(cls, v: str): if v.startswith("${") and v.endswith("}"): return os.environ[v.removeprefix("${").removesuffix("}").upper()] @@ -129,12 +130,9 @@ class ConfigLoader(Generic[C]): of default values, dictionaries, YAML/JSON files etc. """ - _schema: type[C] - _values: dict[str, Any] - def __init__(self, schema: type[C]) -> None: - self._schema = schema - self._values = {} + self._adapter = TypeAdapter(schema) + self._values: dict[str, Any] = {} def use_values(self, values: Mapping[str, Any]) -> None: """ @@ -184,7 +182,7 @@ def load(self) -> C: """ try: - return parse_obj_as(self._schema, self._values) + return self._adapter.validate_python(self._values) except ValidationError as exc: raise InvalidConfigError( "Something is wrong with the configuration file: \n" diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index 53e6e0d48..429ef4d77 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -4,7 +4,15 @@ from importlib import import_module from inspect import Parameter, signature from types import ModuleType, UnionType -from typing import Any, Generic, TypeVar, Union, get_args, get_origin, get_type_hints +from typing import ( + Any, + Generic, + TypeVar, + Union, + get_args, + get_origin, + get_type_hints, +) from bluesky.run_engine import RunEngine from dodal.utils import make_all_devices @@ -217,7 +225,7 @@ def __get_pydantic_json_schema__( def _type_spec_for_function( self, func: Callable[..., Any] - ) -> dict[str, tuple[type, Any]]: + ) -> dict[str, tuple[type, FieldInfo]]: """ Parse a function signature and build map of field types and default values that can be used to deserialise arguments from external sources. @@ -234,7 +242,7 @@ def _type_spec_for_function( """ args = signature(func).parameters types = get_type_hints(func) - new_args = {} + new_args: dict[str, tuple[type, FieldInfo]] = {} for name, para in args.items(): arg_type = types.get(name, Parameter.empty) if arg_type is Parameter.empty: diff --git a/src/blueapi/utils/serialization.py b/src/blueapi/utils/serialization.py index 5298b407e..0be58c815 100644 --- a/src/blueapi/utils/serialization.py +++ b/src/blueapi/utils/serialization.py @@ -19,7 +19,7 @@ def serialize(obj: Any) -> Any: if isinstance(obj, BaseModel): # Serialize by alias so that our camelCase models leave the service # with camelCase field names - return obj.dict(by_alias=True) + return obj.model_dump(by_alias=True) elif hasattr(obj, "__pydantic_model__"): return serialize(obj.__pydantic_model__) else: diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 2557415bf..d43e9c974 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -7,7 +7,7 @@ from bluesky.protocols import Descriptor, Movable, Readable, Reading, SyncOrAsync from dls_bluesky_core.core import MsgGenerator, PlanGenerator, inject from ophyd.sim import SynAxis, SynGauss -from pydantic import ValidationError, parse_obj_as +from pydantic import TypeAdapter, ValidationError from pytest import LogCaptureFixture from blueapi.config import EnvironmentConfig, Source, SourceKind @@ -366,13 +366,14 @@ def test_str_default( spec = empty_context._type_spec_for_function(has_default_reference) assert spec["m"][0] is movable_ref - assert spec["m"][1].default_factory() == SIM_MOTOR_NAME + assert (df := spec["m"][1].default_factory) and df() == SIM_MOTOR_NAME assert has_default_reference.__name__ in empty_context.plans model = empty_context.plans[has_default_reference.__name__].model - assert parse_obj_as(model, {}).m is sim_motor # type: ignore + adapter = TypeAdapter(model) + assert adapter.validate_python({}).m is sim_motor # type: ignore empty_context.device(alt_motor) - assert parse_obj_as(model, {"m": ALT_MOTOR_NAME}).m is alt_motor # type: ignore + assert adapter.validate_python({"m": ALT_MOTOR_NAME}).m is alt_motor # type: ignore def test_nested_str_default( @@ -384,13 +385,15 @@ def test_nested_str_default( spec = empty_context._type_spec_for_function(has_default_nested_reference) assert spec["m"][0] == list[movable_ref] # type: ignore - assert spec["m"][1].default_factory() == [SIM_MOTOR_NAME] + assert (df := spec["m"][1].default_factory) and df() == [SIM_MOTOR_NAME] assert has_default_nested_reference.__name__ in empty_context.plans model = empty_context.plans[has_default_nested_reference.__name__].model - assert parse_obj_as(model, {}).m == [sim_motor] # type: ignore + adapter = TypeAdapter(model) + + assert adapter.validate_python({}).m == [sim_motor] # type: ignore empty_context.device(alt_motor) - assert parse_obj_as(model, {"m": [ALT_MOTOR_NAME]}).m == [alt_motor] # type: ignore + assert adapter.validate_python({"m": [ALT_MOTOR_NAME]}).m == [alt_motor] # type: ignore def test_plan_models_not_auto_camelcased(empty_context: BlueskyContext) -> None: diff --git a/tests/service/test_rest_api.py b/tests/service/test_rest_api.py index 86b22b9c7..3342006d2 100644 --- a/tests/service/test_rest_api.py +++ b/tests/service/test_rest_api.py @@ -152,7 +152,7 @@ def test_create_task( submit_task_mock.return_value = task_id - response = client.post("/tasks", json=task.dict()) + response = client.post("/tasks", json=task.model_dump()) submit_task_mock.assert_called_once_with(task) assert response.json() == {"task_id": task_id} @@ -311,7 +311,7 @@ def test_set_active_task( task_id = str(uuid.uuid4()) task = WorkerTask(task_id=task_id) - response = client.put("/worker/task", json=task.dict()) + response = client.put("/worker/task", json=task.model_dump()) assert response.status_code == status.HTTP_200_OK assert response.json() == {"task_id": f"{task_id}"} @@ -332,7 +332,7 @@ def test_set_active_task_active_task_complete( is_pending=False, ) - response = client.put("/worker/task", json=task.dict()) + response = client.put("/worker/task", json=task.model_dump()) assert response.status_code == status.HTTP_200_OK assert response.json() == {"task_id": f"{task_id}"} @@ -353,7 +353,7 @@ def test_set_active_task_worker_already_running( is_pending=False, ) - response = client.put("/worker/task", json=task.dict()) + response = client.put("/worker/task", json=task.model_dump()) assert response.status_code == status.HTTP_409_CONFLICT assert response.json() == {"detail": "Worker already active"} @@ -430,7 +430,7 @@ def test_set_state_running_to_paused( get_worker_state_mock.side_effect = [current_state, final_state] response = client.put( - "/worker/state", json=StateChangeRequest(new_state=final_state).dict() + "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) pause_worker_mock.assert_called_once_with(False) @@ -448,7 +448,7 @@ def test_set_state_paused_to_running( get_worker_state_mock.side_effect = [current_state, final_state] response = client.put( - "/worker/state", json=StateChangeRequest(new_state=final_state).dict() + "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) resume_worker_mock.assert_called_once() @@ -468,7 +468,7 @@ def test_set_state_running_to_aborting( get_worker_state_mock.side_effect = [current_state, final_state] response = client.put( - "/worker/state", json=StateChangeRequest(new_state=final_state).dict() + "/worker/state", json=StateChangeRequest(new_state=final_state).model_dump() ) cancel_active_task_mock.assert_called_once_with(True, None) @@ -490,7 +490,7 @@ def test_set_state_running_to_stopping_including_reason( response = client.put( "/worker/state", - json=StateChangeRequest(new_state=final_state, reason=reason).dict(), + json=StateChangeRequest(new_state=final_state, reason=reason).model_dump(), ) cancel_active_task_mock.assert_called_once_with(False, reason) @@ -514,7 +514,7 @@ def test_set_state_transition_error( response = client.put( "/worker/state", - json=StateChangeRequest(new_state=final_state).dict(), + json=StateChangeRequest(new_state=final_state).model_dump(), ) assert response.status_code == status.HTTP_400_BAD_REQUEST @@ -533,7 +533,7 @@ def test_set_state_invalid_transition( response = client.put( "/worker/state", - json=StateChangeRequest(new_state=requested_state).dict(), + json=StateChangeRequest(new_state=requested_state).model_dump(), ) assert response.status_code == status.HTTP_400_BAD_REQUEST diff --git a/tests/test_cli.py b/tests/test_cli.py index 8c75cced7..0c6724ed8 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -84,7 +84,7 @@ def test_get_plans(runner: CliRunner): response = responses.add( responses.GET, "http://localhost:8000/plans", - json=PlanResponse(plans=[PlanModel.from_plan(plan)]).dict(), + json=PlanResponse(plans=[PlanModel.from_plan(plan)]).model_dump(), status=200, ) @@ -100,7 +100,7 @@ def test_get_devices(runner: CliRunner): response = responses.add( responses.GET, "http://localhost:8000/devices", - json=DeviceResponse(devices=[DeviceModel.from_device(device)]).dict(), + json=DeviceResponse(devices=[DeviceModel.from_device(device)]).model_dump(), status=200, ) @@ -179,7 +179,7 @@ def test_get_env( responses.add( responses.GET, "http://localhost:8000/environment", - json=EnvironmentResponse(initialized=True).dict(), + json=EnvironmentResponse(initialized=True).model_dump(), status=200, ) @@ -196,7 +196,7 @@ def test_reset_env_client_behavior( responses.add( responses.DELETE, "http://localhost:8000/environment", - json=EnvironmentResponse(initialized=False).dict(), + json=EnvironmentResponse(initialized=False).model_dump(), status=200, ) @@ -206,7 +206,7 @@ def test_reset_env_client_behavior( responses.add( responses.GET, "http://localhost:8000/environment", - json=EnvironmentResponse(initialized=state).dict(), + json=EnvironmentResponse(initialized=state).model_dump(), status=200, ) @@ -241,13 +241,13 @@ def test_env_timeout(mock_sleep: Mock, runner: CliRunner): responses.DELETE, "http://localhost:8000/environment", status=200, - json=EnvironmentResponse(initialized=False).dict(), + json=EnvironmentResponse(initialized=False).model_dump(), ) # Add responses for each polling attempt, all indicating not initialized responses.add( responses.GET, "http://localhost:8000/environment", - json=EnvironmentResponse(initialized=False).dict(), + json=EnvironmentResponse(initialized=False).model_dump(), status=200, )