diff --git a/pyproject.toml b/pyproject.toml index a84e0b62b..be104f7f6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -136,15 +136,19 @@ commands = src = ["src", "tests"] line-length = 88 lint.select = [ - "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b - "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 - "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e - "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f - "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w - "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i - "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up + "B", # flake8-bugbear - https://docs.astral.sh/ruff/rules/#flake8-bugbear-b + "C4", # flake8-comprehensions - https://docs.astral.sh/ruff/rules/#flake8-comprehensions-c4 + "TID252", # flake8-tidy-imports - https://docs.astral.sh/ruff/rules/relative-imports/ + "E", # pycodestyle errors - https://docs.astral.sh/ruff/rules/#error-e + "F", # pyflakes rules - https://docs.astral.sh/ruff/rules/#pyflakes-f + "W", # pycodestyle warnings - https://docs.astral.sh/ruff/rules/#warning-w + "I", # isort - https://docs.astral.sh/ruff/rules/#isort-i + "UP", # pyupgrade - https://docs.astral.sh/ruff/rules/#pyupgrade-up ] +[tool.ruff.lint.flake8-tidy-imports] +ban-relative-imports = "all" + [tool.ruff.lint.flake8-bugbear] extend-immutable-calls = [ "fastapi.Depends", diff --git a/src/blueapi/__init__.py b/src/blueapi/__init__.py index 26d23badb..38d66be25 100644 --- a/src/blueapi/__init__.py +++ b/src/blueapi/__init__.py @@ -1,3 +1,3 @@ -from ._version import __version__ +from blueapi._version import __version__ __all__ = ["__version__"] diff --git a/src/blueapi/__main__.py b/src/blueapi/__main__.py index ac1539475..92fb6c4fc 100644 --- a/src/blueapi/__main__.py +++ b/src/blueapi/__main__.py @@ -1,4 +1,4 @@ -from .cli.cli import main +from blueapi.cli.cli import main # test with: python -m blueapi if __name__ == "__main__": diff --git a/src/blueapi/cli/__init__.py b/src/blueapi/cli/__init__.py index ed32c05eb..5e0c19132 100644 --- a/src/blueapi/cli/__init__.py +++ b/src/blueapi/cli/__init__.py @@ -1,3 +1,3 @@ -from .cli import main +from blueapi.cli.cli import main __all__ = ["main"] diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c98800c20..a8cbe2433 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -15,6 +15,8 @@ from blueapi import __version__ from blueapi.cli.format import OutputFormat +from blueapi.cli.scratch import setup_scratch +from blueapi.cli.updates import CliEventRenderer from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueskyRemoteControlError @@ -22,9 +24,6 @@ from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent from blueapi.worker import ProgressEvent, Task, WorkerEvent -from .scratch import setup_scratch -from .updates import CliEventRenderer - @click.group(invoke_without_command=True) @click.version_option(version=__version__, prog_name="blueapi") diff --git a/src/blueapi/client/client.py b/src/blueapi/client/client.py index 2e805ea74..72f7c6eee 100644 --- a/src/blueapi/client/client.py +++ b/src/blueapi/client/client.py @@ -8,6 +8,13 @@ start_as_current_span, ) +from blueapi.client.event_bus import ( + AnyEvent, + BlueskyStreamingError, + EventBusClient, + OnAnyEvent, +) +from blueapi.client.rest import BlueapiRestClient, BlueskyRemoteControlError from blueapi.config import ApplicationConfig from blueapi.core.bluesky_types import DataEvent from blueapi.service.model import ( @@ -23,9 +30,6 @@ from blueapi.worker import Task, TrackableTask, WorkerEvent, WorkerState from blueapi.worker.event import ProgressEvent, TaskStatus -from .event_bus import AnyEvent, BlueskyStreamingError, EventBusClient, OnAnyEvent -from .rest import BlueapiRestClient, BlueskyRemoteControlError - TRACER = get_tracer("client") diff --git a/src/blueapi/core/__init__.py b/src/blueapi/core/__init__.py index 15e3b2602..4f1892641 100644 --- a/src/blueapi/core/__init__.py +++ b/src/blueapi/core/__init__.py @@ -1,7 +1,7 @@ from os import environ -from .bluesky_event_loop import configure_bluesky_event_loop -from .bluesky_types import ( +from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop +from blueapi.core.bluesky_types import ( BLUESKY_PROTOCOLS, DataEvent, Device, @@ -13,8 +13,8 @@ is_bluesky_compatible_device_type, is_bluesky_plan_generator, ) -from .context import BlueskyContext -from .event import EventPublisher, EventStream +from blueapi.core.context import BlueskyContext +from blueapi.core.event import EventPublisher, EventStream OTLP_EXPORT_ENABLED = environ.get("OTLP_EXPORT_ENABLED") == "true" diff --git a/src/blueapi/core/bluesky_types.py b/src/blueapi/core/bluesky_types.py index 1a2978213..5b59ae016 100644 --- a/src/blueapi/core/bluesky_types.py +++ b/src/blueapi/core/bluesky_types.py @@ -53,7 +53,7 @@ ) #: Protocols defining interface to hardware -BLUESKY_PROTOCOLS = list(Device.__args__) # type: ignore +BLUESKY_PROTOCOLS = list(Device.__args__) def is_bluesky_compatible_device(obj: Any) -> bool: diff --git a/src/blueapi/core/context.py b/src/blueapi/core/context.py index c92c38113..3a13703ba 100644 --- a/src/blueapi/core/context.py +++ b/src/blueapi/core/context.py @@ -23,9 +23,7 @@ from pydantic_core import CoreSchema, core_schema from blueapi.config import EnvironmentConfig, SourceKind -from blueapi.utils import BlueapiPlanModelConfig, load_module_all - -from .bluesky_types import ( +from blueapi.core.bluesky_types import ( BLUESKY_PROTOCOLS, Device, HasName, @@ -34,7 +32,8 @@ is_bluesky_compatible_device, is_bluesky_plan_generator, ) -from .device_lookup import find_component +from blueapi.core.device_lookup import find_component +from blueapi.utils import BlueapiPlanModelConfig, load_module_all LOGGER = logging.getLogger(__name__) @@ -61,7 +60,7 @@ def find_device(self, addr: str | list[str]) -> Device | None: Find a device in this context, allows for recursive search. Args: - addr (Union[str, List[str]]): Address of the device, examples: + addr (str | list[str]): Address of the device, examples: "motors", "motors.x" Returns: diff --git a/src/blueapi/core/device_lookup.py b/src/blueapi/core/device_lookup.py index 1bace1676..6ee7c6c98 100644 --- a/src/blueapi/core/device_lookup.py +++ b/src/blueapi/core/device_lookup.py @@ -1,6 +1,6 @@ from typing import Any, TypeVar -from .bluesky_types import Device, is_bluesky_compatible_device +from blueapi.core.bluesky_types import Device, is_bluesky_compatible_device #: Device obeying Bluesky protocols D = TypeVar("D", bound=Device) diff --git a/src/blueapi/service/__init__.py b/src/blueapi/service/__init__.py index 7c2fa404c..17b04729a 100644 --- a/src/blueapi/service/__init__.py +++ b/src/blueapi/service/__init__.py @@ -1,3 +1,3 @@ -from .model import DeviceModel, PlanModel +from blueapi.service.model import DeviceModel, PlanModel __all__ = ["PlanModel", "DeviceModel"] diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 11ad7271f..add1df9bb 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -27,10 +27,7 @@ from blueapi.config import ApplicationConfig from blueapi.service import interface -from blueapi.worker import Task, TrackableTask, WorkerState -from blueapi.worker.event import TaskStatusEnum - -from .model import ( +from blueapi.service.model import ( DeviceModel, DeviceResponse, EnvironmentResponse, @@ -41,7 +38,9 @@ TasksListResponse, WorkerTask, ) -from .runner import WorkerDispatcher +from blueapi.service.runner import WorkerDispatcher +from blueapi.worker import Task, TrackableTask, WorkerState +from blueapi.worker.event import TaskStatusEnum REST_API_VERSION = "0.0.5" diff --git a/src/blueapi/startup/example_devices.py b/src/blueapi/startup/example_devices.py index a472137dd..3397a34ab 100644 --- a/src/blueapi/startup/example_devices.py +++ b/src/blueapi/startup/example_devices.py @@ -1,6 +1,6 @@ from ophyd.sim import Syn2DGauss, SynGauss, SynSignal -from .simmotor import BrokenSynAxis, SynAxisWithMotionEvents +from blueapi.startup.simmotor import BrokenSynAxis, SynAxisWithMotionEvents def x(name="x") -> SynAxisWithMotionEvents: diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b871f842a..dd2f7c307 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,13 +1,16 @@ -from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig -from .invalid_config_error import InvalidConfigError -from .modules import load_module_all -from .serialization import serialize -from .thread_exception import handle_all_exceptions +from blueapi.utils.base_model import ( + BlueapiBaseModel, + BlueapiModelConfig, + BlueapiPlanModelConfig, +) +from blueapi.utils.invalid_config_error import InvalidConfigError +from blueapi.utils.modules import load_module_all +from blueapi.utils.serialization import serialize +from blueapi.utils.thread_exception import handle_all_exceptions __all__ = [ "handle_all_exceptions", "load_module_all", - "ConfigLoader", "serialize", "BlueapiBaseModel", "BlueapiModelConfig", diff --git a/src/blueapi/worker/__init__.py b/src/blueapi/worker/__init__.py index 7862912cc..3acabda09 100644 --- a/src/blueapi/worker/__init__.py +++ b/src/blueapi/worker/__init__.py @@ -1,12 +1,17 @@ -from .event import ProgressEvent, StatusView, TaskStatus, WorkerEvent, WorkerState -from .task import Task -from .task_worker import TaskWorker, TrackableTask -from .worker_errors import WorkerAlreadyStartedError, WorkerBusyError +from blueapi.worker.event import ( + ProgressEvent, + StatusView, + TaskStatus, + WorkerEvent, + WorkerState, +) +from blueapi.worker.task import Task +from blueapi.worker.task_worker import TaskWorker, TrackableTask +from blueapi.worker.worker_errors import WorkerAlreadyStartedError, WorkerBusyError __all__ = [ "TaskWorker", "Task", - "Worker", "WorkerEvent", "WorkerState", "StatusView", diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index d9aa38e86..0d5ef7b1a 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -32,8 +32,7 @@ from blueapi.core.bluesky_event_loop import configure_bluesky_event_loop from blueapi.utils.base_model import BlueapiBaseModel from blueapi.utils.thread_exception import handle_all_exceptions - -from .event import ( +from blueapi.worker.event import ( ProgressEvent, RawRunEngineState, StatusView, @@ -42,8 +41,8 @@ WorkerEvent, WorkerState, ) -from .task import Task -from .worker_errors import WorkerAlreadyStartedError, WorkerBusyError +from blueapi.worker.task import Task +from blueapi.worker.worker_errors import WorkerAlreadyStartedError, WorkerBusyError LOGGER = logging.getLogger(__name__) TRACER = get_tracer("task_worker") @@ -86,7 +85,7 @@ class TaskWorker: _state: WorkerState _errors: list[str] _warnings: list[str] - _task_channel: Queue # type: ignore + _task_channel: Queue _current: TrackableTask | None _status_lock: RLock _status_snapshot: dict[str, StatusView] @@ -474,7 +473,7 @@ def on_complete(status: Status) -> None: del self._status_snapshot[status_uuid] self._completed_statuses.add(status_uuid) - status.add_callback(on_complete) # type: ignore + status.add_callback(on_complete) def _on_status_event( self, diff --git a/tests/conftest.py b/tests/conftest.py index 8f311754b..9bc753b51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ import asyncio from typing import cast -# Based on https://docs.pytest.org/en/latest/example/simple.html#control-skipping-of-tests-according-to-command-line-option # noqa: E501 import pytest from bluesky import RunEngine from bluesky.run_engine import TransitionError @@ -12,7 +11,7 @@ @pytest.fixture(scope="function") -def RE(request): +def RE(request: pytest.FixtureRequest) -> RunEngine: loop = asyncio.new_event_loop() loop.set_debug(True) RE = RunEngine({}, call_returns_result=True, loop=loop) diff --git a/tests/system_tests/test_blueapi_system.py b/tests/system_tests/test_blueapi_system.py index 4297598bc..82fb99ce8 100644 --- a/tests/system_tests/test_blueapi_system.py +++ b/tests/system_tests/test_blueapi_system.py @@ -66,7 +66,7 @@ def test_get_plans_by_name(client: BlueapiClient, expected_plans: PlanResponse): def test_get_non_existent_plan(client: BlueapiClient): with pytest.raises(KeyError) as exception: client.get_plan("Not exists") - assert str(exception) == ("{'detail': 'Item not found'}") + assert str(exception) == ("{'detail': 'Item not found'}") def test_get_devices(client: BlueapiClient, expected_devices: DeviceResponse): @@ -80,8 +80,8 @@ def test_get_device_by_name(client: BlueapiClient, expected_devices: DeviceRespo def test_get_non_existent_device(client: BlueapiClient): with pytest.raises(KeyError) as exception: - assert client.get_device("Not exists") - assert str(exception) == ("{'detail': 'Item not found'}") + client.get_device("Not exists") + assert str(exception) == ("{'detail': 'Item not found'}") def test_create_task_and_delete_task_by_id(client: BlueapiClient): @@ -92,7 +92,7 @@ def test_create_task_and_delete_task_by_id(client: BlueapiClient): def test_create_task_validation_error(client: BlueapiClient): with pytest.raises(KeyError) as exception: client.create_task(Task(name="Not-exists", params={"Not-exists": 0.0})) - assert str(exception) == ("{'detail': 'Item not found'}") + assert str(exception) == ("{'detail': 'Item not found'}") def test_get_all_tasks(client: BlueapiClient): @@ -128,13 +128,13 @@ def test_get_task_by_id(client: BlueapiClient): def test_get_non_existent_task(client: BlueapiClient): with pytest.raises(KeyError) as exception: client.get_task("Not-exists") - assert str(exception) == "{'detail': 'Item not found'}" + assert str(exception) == "{'detail': 'Item not found'}" def test_delete_non_existent_task(client: BlueapiClient): with pytest.raises(KeyError) as exception: client.clear_task("Not-exists") - assert str(exception) == "{'detail': 'Item not found'}" + assert str(exception) == "{'detail': 'Item not found'}" def test_put_worker_task(client: BlueapiClient): @@ -155,7 +155,7 @@ def test_put_worker_task_fails_if_not_idle(client: BlueapiClient): with pytest.raises(BlueskyRemoteControlError) as exception: client.start_task(WorkerTask(task_id=small_task.task_id)) - assert str(exception) == "" + assert str(exception) == "" client.abort() client.clear_task(small_task.task_id) client.clear_task(long_task.task_id) @@ -168,11 +168,11 @@ def test_get_worker_state(client: BlueapiClient): def test_set_state_transition_error(client: BlueapiClient): with pytest.raises(BlueskyRemoteControlError) as exception: client.resume() - assert str(exception) == "" + assert str(exception) == "" with pytest.raises(BlueskyRemoteControlError) as exception: client.pause() - assert str(exception) == "" + assert str(exception) == "" def test_get_task_by_status(client: BlueapiClient): diff --git a/tests/unit_tests/cli/test_scratch.py b/tests/unit_tests/cli/test_scratch.py index 7e4ca3b9e..bba43d75b 100644 --- a/tests/unit_tests/cli/test_scratch.py +++ b/tests/unit_tests/cli/test_scratch.py @@ -1,8 +1,8 @@ import os import stat import uuid +from collections.abc import Generator from pathlib import Path -from tempfile import TemporaryDirectory from unittest.mock import Mock, call, patch import pytest @@ -12,15 +12,8 @@ @pytest.fixture -def directory_path() -> Path: # type: ignore - temporary_directory = TemporaryDirectory() - yield Path(temporary_directory.name) - temporary_directory.cleanup() - - -@pytest.fixture -def file_path(directory_path: Path) -> Path: # type: ignore - file_path = directory_path / str(uuid.uuid4()) +def file_path(tmp_path: Path) -> Generator[Path]: + file_path = tmp_path / str(uuid.uuid4()) with file_path.open("w") as stream: stream.write("foo") yield file_path @@ -28,8 +21,8 @@ def file_path(directory_path: Path) -> Path: # type: ignore @pytest.fixture -def nonexistant_path(directory_path: Path) -> Path: - file_path = directory_path / str(uuid.uuid4()) +def nonexistant_path(tmp_path: Path) -> Path: + file_path = tmp_path / str(uuid.uuid4()) assert not file_path.exists() return file_path @@ -37,13 +30,13 @@ def nonexistant_path(directory_path: Path) -> Path: @patch("blueapi.cli.scratch.Popen") def test_scratch_install_installs_path( mock_popen: Mock, - directory_path: Path, + tmp_path: Path, ): mock_process = Mock() mock_process.returncode = 0 mock_popen.return_value = mock_process - scratch_install(directory_path, timeout=1.0) + scratch_install(tmp_path, timeout=1.0) mock_popen.assert_called_once_with( [ @@ -53,7 +46,7 @@ def test_scratch_install_installs_path( "install", "--no-deps", "-e", - str(directory_path), + str(tmp_path), ] ) @@ -72,7 +65,7 @@ def test_scratch_install_fails_on_nonexistant_path(nonexistant_path: Path): @pytest.mark.parametrize("code", [1, 2, 65536]) def test_scratch_install_fails_on_non_zero_exit_code( mock_popen: Mock, - directory_path: Path, + tmp_path: Path, code: int, ): mock_process = Mock() @@ -80,16 +73,16 @@ def test_scratch_install_fails_on_non_zero_exit_code( mock_popen.return_value = mock_process with pytest.raises(RuntimeError): - scratch_install(directory_path, timeout=1.0) + scratch_install(tmp_path, timeout=1.0) @patch("blueapi.cli.scratch.Repo") def test_repo_not_cloned_and_validated_if_found_locally( mock_repo: Mock, - directory_path: Path, + tmp_path: Path, ): - ensure_repo("http://example.com/foo.git", directory_path) - mock_repo.assert_called_once_with(directory_path) + ensure_repo("http://example.com/foo.git", tmp_path) + mock_repo.assert_called_once_with(tmp_path) mock_repo.clone_from.assert_not_called() @@ -108,9 +101,9 @@ def test_repo_cloned_if_not_found_locally( @patch("blueapi.cli.scratch.Repo") def test_repo_cloned_with_correct_umask( mock_repo: Mock, - directory_path: Path, + tmp_path: Path, ): - repo_root = directory_path / "foo" + repo_root = tmp_path / "foo" file_path = repo_root / "a" def write_repo_files(): @@ -153,10 +146,10 @@ def test_setup_scratch_fails_on_non_directory_root( def test_setup_scratch_iterates_repos( mock_scratch_install: Mock, mock_ensure_repo: Mock, - directory_path: Path, + tmp_path: Path, ): config = ScratchConfig( - root=directory_path, + root=tmp_path, repositories=[ ScratchRepository( name="foo", @@ -172,15 +165,15 @@ def test_setup_scratch_iterates_repos( mock_ensure_repo.assert_has_calls( [ - call("http://example.com/foo.git", directory_path / "foo"), - call("http://example.com/bar.git", directory_path / "bar"), + call("http://example.com/foo.git", tmp_path / "foo"), + call("http://example.com/bar.git", tmp_path / "bar"), ] ) mock_scratch_install.assert_has_calls( [ - call(directory_path / "foo", timeout=120.0), - call(directory_path / "bar", timeout=120.0), + call(tmp_path / "foo", timeout=120.0), + call(tmp_path / "bar", timeout=120.0), ] ) @@ -190,10 +183,10 @@ def test_setup_scratch_iterates_repos( def test_setup_scratch_continues_after_failure( mock_scratch_install: Mock, mock_ensure_repo: Mock, - directory_path: Path, + tmp_path: Path, ): config = ScratchConfig( - root=directory_path, + root=tmp_path, repositories=[ ScratchRepository( name="foo", diff --git a/tests/unit_tests/client/test_client.py b/tests/unit_tests/client/test_client.py index bebda69b0..7e9acc2d6 100644 --- a/tests/unit_tests/client/test_client.py +++ b/tests/unit_tests/client/test_client.py @@ -1,5 +1,6 @@ from collections.abc import Callable -from unittest.mock import MagicMock, Mock, call +from typing import cast +from unittest.mock import MagicMock, call import pytest from bluesky_stomp.messaging import MessageContext @@ -63,7 +64,7 @@ @pytest.fixture def mock_rest() -> BlueapiRestClient: - mock = Mock(spec=BlueapiRestClient) + mock = MagicMock(spec=BlueapiRestClient) mock.get_plans.return_value = PLANS mock.get_plan.return_value = PLAN @@ -81,20 +82,22 @@ def mock_rest() -> BlueapiRestClient: @pytest.fixture def mock_events() -> EventBusClient: - mock_events = MagicMock(spec=EventBusClient) - ctx = Mock() + mock_events: EventBusClient = MagicMock(spec=EventBusClient) + ctx = MagicMock() ctx.correlation_id = "foo" - mock_events.subscribe_to_all_events = lambda on_event: on_event(ctx, COMPLETE_EVENT) + cast(MagicMock, mock_events).subscribe_to_all_events = lambda on_event: on_event( + ctx, COMPLETE_EVENT + ) return mock_events @pytest.fixture -def client(mock_rest: Mock) -> BlueapiClient: +def client(mock_rest: BlueapiRestClient) -> BlueapiClient: return BlueapiClient(rest=mock_rest) @pytest.fixture -def client_with_events(mock_rest: Mock, mock_events: MagicMock): +def client_with_events(mock_rest: BlueapiRestClient, mock_events: EventBusClient): return BlueapiClient(rest=mock_rest, events=mock_events) @@ -108,9 +111,9 @@ def test_get_plan(client: BlueapiClient): def test_get_nonexistant_plan( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.get_plan.side_effect = KeyError("Not found") + cast(MagicMock, mock_rest.get_plan).side_effect = KeyError("Not found") with pytest.raises(KeyError): client.get_plan("baz") @@ -125,9 +128,9 @@ def test_get_device(client: BlueapiClient): def test_get_nonexistant_device( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.get_device.side_effect = KeyError("Not found") + cast(MagicMock, mock_rest.get_device).side_effect = KeyError("Not found") with pytest.raises(KeyError): client.get_device("baz") @@ -142,9 +145,9 @@ def test_get_task(client: BlueapiClient): def test_get_nonexistent_task( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.get_task.side_effect = KeyError("Not found") + cast(MagicMock, mock_rest.get_task).side_effect = KeyError("Not found") with pytest.raises(KeyError): client.get_task("baz") @@ -163,26 +166,26 @@ def test_get_all_tasks( def test_create_task( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.create_task(task=Task(name="foo")) - mock_rest.create_task.assert_called_once_with(Task(name="foo")) + cast(MagicMock, mock_rest.create_task).assert_called_once_with(Task(name="foo")) def test_create_task_does_not_start_task( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.create_task(task=Task(name="foo")) - mock_rest.update_worker_task.assert_not_called() + cast(MagicMock, mock_rest.update_worker_task).assert_not_called() def test_clear_task( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.clear_task(task_id="foo") - mock_rest.clear_task.assert_called_once_with("foo") + cast(MagicMock, mock_rest.clear_task).assert_called_once_with("foo") def test_get_active_task(client: BlueapiClient): @@ -191,57 +194,69 @@ def test_get_active_task(client: BlueapiClient): def test_start_task( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.start_task(task=WorkerTask(task_id="bar")) - mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="bar")) + cast(MagicMock, mock_rest.update_worker_task).assert_called_once_with( + WorkerTask(task_id="bar") + ) def test_start_nonexistant_task( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.update_worker_task.side_effect = KeyError("Not found") + cast(MagicMock, mock_rest.update_worker_task).side_effect = KeyError("Not found") with pytest.raises(KeyError): client.start_task(task=WorkerTask(task_id="bar")) def test_create_and_start_task_calls_both_creating_and_starting_endpoints( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.create_task.return_value = TaskResponse(task_id="baz") - mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz") + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz") + cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse( + task_id="baz" + ) client.create_and_start_task(Task(name="baz")) - mock_rest.create_task.assert_called_once_with(Task(name="baz")) - mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="baz")) + cast(MagicMock, mock_rest.create_task).assert_called_once_with(Task(name="baz")) + cast(MagicMock, mock_rest.update_worker_task).assert_called_once_with( + WorkerTask(task_id="baz") + ) def test_create_and_start_task_fails_if_task_creation_fails( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.create_task.side_effect = BlueskyRemoteControlError("No can do") + cast(MagicMock, mock_rest.create_task).side_effect = BlueskyRemoteControlError( + "No can do" + ) with pytest.raises(BlueskyRemoteControlError): client.create_and_start_task(Task(name="baz")) def test_create_and_start_task_fails_if_task_id_is_wrong( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.create_task.return_value = TaskResponse(task_id="baz") - mock_rest.update_worker_task.return_value = TaskResponse(task_id="bar") + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz") + cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse( + task_id="bar" + ) with pytest.raises(BlueskyRemoteControlError): client.create_and_start_task(Task(name="baz")) def test_create_and_start_task_fails_if_task_start_fails( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.create_task.return_value = TaskResponse(task_id="baz") - mock_rest.update_worker_task.side_effect = BlueskyRemoteControlError("No can do") + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz") + cast( + MagicMock, mock_rest.update_worker_task + ).side_effect = BlueskyRemoteControlError("No can do") with pytest.raises(BlueskyRemoteControlError): client.create_and_start_task(Task(name="baz")) @@ -252,18 +267,18 @@ def test_get_environment(client: BlueapiClient): def test_reload_environment( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.reload_environment() - mock_rest.get_environment.assert_called_once() - mock_rest.delete_environment.assert_called_once() + cast(MagicMock, mock_rest.get_environment).assert_called_once() + cast(MagicMock, mock_rest.delete_environment).assert_called_once() def test_reload_environment_failure( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.get_environment.return_value = EnvironmentResponse( + cast(MagicMock, mock_rest.get_environment).return_value = EnvironmentResponse( initialized=False, error_message="foo" ) with pytest.raises(BlueskyRemoteControlError, match="foo"): @@ -272,10 +287,10 @@ def test_reload_environment_failure( def test_abort( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.abort(reason="foo") - mock_rest.cancel_current_task.assert_called_once_with( + cast(MagicMock, mock_rest.cancel_current_task).assert_called_once_with( WorkerState.ABORTING, reason="foo", ) @@ -283,18 +298,20 @@ def test_abort( def test_stop( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.stop() - mock_rest.cancel_current_task.assert_called_once_with(WorkerState.STOPPING) + cast(MagicMock, mock_rest.cancel_current_task).assert_called_once_with( + WorkerState.STOPPING + ) def test_pause( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.pause(defer=True) - mock_rest.set_state.assert_called_once_with( + cast(MagicMock, mock_rest.set_state).assert_called_once_with( WorkerState.PAUSED, defer=True, ) @@ -302,10 +319,10 @@ def test_pause( def test_resume( client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): client.resume() - mock_rest.set_state.assert_called_once_with( + cast(MagicMock, mock_rest.set_state).assert_called_once_with( WorkerState.RUNNING, defer=False, ) @@ -321,33 +338,43 @@ def test_cannot_run_task_without_message_bus(client: BlueapiClient): def test_run_task_sets_up_control( client_with_events: BlueapiClient, - mock_rest: Mock, - mock_events: MagicMock, + mock_rest: BlueapiRestClient, + mock_events: EventBusClient, ): - mock_rest.create_task.return_value = TaskResponse(task_id="foo") - mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") - ctx = Mock() + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo") + cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse( + task_id="foo" + ) + ctx = MagicMock() ctx.correlation_id = "foo" - mock_events.subscribe_to_all_events = lambda on_event: on_event(COMPLETE_EVENT, ctx) + cast(MagicMock, mock_events).subscribe_to_all_events = lambda on_event: on_event( + COMPLETE_EVENT, ctx + ) client_with_events.run_task(Task(name="foo")) - mock_rest.create_task.assert_called_once_with(Task(name="foo")) - mock_rest.update_worker_task.assert_called_once_with(WorkerTask(task_id="foo")) + cast(MagicMock, mock_rest.create_task).assert_called_once_with(Task(name="foo")) + cast(MagicMock, mock_rest.update_worker_task).assert_called_once_with( + WorkerTask(task_id="foo") + ) def test_run_task_fails_on_failing_event( client_with_events: BlueapiClient, - mock_rest: Mock, - mock_events: MagicMock, + mock_rest: BlueapiRestClient, + mock_events: EventBusClient, ): - mock_rest.create_task.return_value = TaskResponse(task_id="foo") - mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo") + cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse( + task_id="foo" + ) - ctx = Mock() + ctx = MagicMock() ctx.correlation_id = "foo" - mock_events.subscribe_to_all_events = lambda on_event: on_event(FAILED_EVENT, ctx) + cast(MagicMock, mock_events).subscribe_to_all_events = lambda on_event: on_event( + FAILED_EVENT, ctx + ) - on_event = Mock() + on_event = MagicMock() with pytest.raises(BlueskyStreamingError): client_with_events.run_task(Task(name="foo"), on_event=on_event) @@ -371,23 +398,25 @@ def test_run_task_fails_on_failing_event( ) def test_run_task_calls_event_callback( client_with_events: BlueapiClient, - mock_rest: Mock, - mock_events: MagicMock, + mock_rest: BlueapiRestClient, + mock_events: EventBusClient, test_event: AnyEvent, ): - mock_rest.create_task.return_value = TaskResponse(task_id="foo") - mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo") + cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse( + task_id="foo" + ) - ctx = Mock() + ctx = MagicMock() ctx.correlation_id = "foo" def callback(on_event: Callable[[AnyEvent, MessageContext], None]): on_event(test_event, ctx) on_event(COMPLETE_EVENT, ctx) - mock_events.subscribe_to_all_events = callback # type: ignore + cast(MagicMock, mock_events).subscribe_to_all_events = callback - mock_on_event = Mock() + mock_on_event = MagicMock() client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) assert mock_on_event.mock_calls == [call(test_event), call(COMPLETE_EVENT)] @@ -410,23 +439,25 @@ def callback(on_event: Callable[[AnyEvent, MessageContext], None]): ) def test_run_task_ignores_non_matching_events( client_with_events: BlueapiClient, - mock_rest: Mock, - mock_events: MagicMock, + mock_rest: BlueapiRestClient, + mock_events: EventBusClient, test_event: AnyEvent, ): - mock_rest.create_task.return_value = TaskResponse(task_id="foo") # type: ignore - mock_rest.update_worker_task.return_value = TaskResponse(task_id="foo") # type: ignore + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="foo") + cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse( + task_id="foo" + ) - ctx = Mock() + ctx = MagicMock() ctx.correlation_id = "foo" def callback(on_event: Callable[[AnyEvent, MessageContext], None]): - on_event(test_event, ctx) # type: ignore + on_event(test_event, ctx) on_event(COMPLETE_EVENT, ctx) - mock_events.subscribe_to_all_events = callback + cast(MagicMock, mock_events).subscribe_to_all_events = callback - mock_on_event = Mock() + mock_on_event = MagicMock() client_with_events.run_task(Task(name="foo"), on_event=mock_on_event) mock_on_event.assert_called_once_with(COMPLETE_EVENT) @@ -473,7 +504,6 @@ def test_get_all_tasks_span_ok( def test_create_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "create_task", "task"): client.create_task(task=Task(name="foo")) @@ -482,7 +512,6 @@ def test_create_task_span_ok( def test_clear_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "clear_task"): client.clear_task(task_id="foo") @@ -498,7 +527,6 @@ def test_get_active_task_span_ok( def test_start_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "start_task", "task"): client.start_task(task=WorkerTask(task_id="bar")) @@ -507,10 +535,12 @@ def test_start_task_span_ok( def test_create_and_start_task_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, + mock_rest: BlueapiRestClient, ): - mock_rest.create_task.return_value = TaskResponse(task_id="baz") - mock_rest.update_worker_task.return_value = TaskResponse(task_id="baz") + cast(MagicMock, mock_rest.create_task).return_value = TaskResponse(task_id="baz") + cast(MagicMock, mock_rest.update_worker_task).return_value = TaskResponse( + task_id="baz" + ) with asserting_span_exporter(exporter, "create_and_start_task", "task"): client.create_and_start_task(Task(name="baz")) @@ -525,7 +555,6 @@ def test_get_environment_span_ok( def test_reload_environment_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "reload_environment"): client.reload_environment() @@ -534,7 +563,6 @@ def test_reload_environment_span_ok( def test_abort_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "abort", "reason"): client.abort(reason="foo") @@ -543,7 +571,6 @@ def test_abort_span_ok( def test_stop_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "stop"): client.stop() @@ -552,7 +579,6 @@ def test_stop_span_ok( def test_pause_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "pause"): client.pause(defer=True) @@ -561,7 +587,6 @@ def test_pause_span_ok( def test_resume_span_ok( exporter: JsonObjectSpanExporter, client: BlueapiClient, - mock_rest: Mock, ): with asserting_span_exporter(exporter, "resume"): client.resume() diff --git a/tests/unit_tests/client/test_event_bus.py b/tests/unit_tests/client/test_event_bus.py index 89f5b5ef2..3ba8b9f5f 100644 --- a/tests/unit_tests/client/test_event_bus.py +++ b/tests/unit_tests/client/test_event_bus.py @@ -1,9 +1,13 @@ +from collections.abc import Callable +from typing import Any from unittest.mock import ANY, Mock import pytest from bluesky_stomp.messaging import StompClient from blueapi.client.event_bus import BlueskyStreamingError, EventBusClient +from blueapi.core.bluesky_types import DataEvent +from blueapi.worker.event import ProgressEvent, WorkerEvent @pytest.fixture @@ -18,7 +22,7 @@ def events(mock_stomp_client: StompClient) -> EventBusClient: def test_context_manager_connects_and_disconnects( events: EventBusClient, - mock_stomp_client: Mock, + mock_stomp_client: StompClient, ): mock_stomp_client.connect.assert_not_called() mock_stomp_client.disconnect.assert_not_called() @@ -32,23 +36,23 @@ def test_context_manager_connects_and_disconnects( def test_client_subscribes_to_all_events( events: EventBusClient, - mock_stomp_client: Mock, + mock_stomp_client: StompClient, ): - on_event = Mock + on_event = Mock(spec=Callable[[WorkerEvent | ProgressEvent | DataEvent, Any], None]) with events: - events.subscribe_to_all_events(on_event=on_event) # type: ignore + events.subscribe_to_all_events(on_event=on_event) mock_stomp_client.subscribe.assert_called_once_with(ANY, on_event) def test_client_raises_streaming_error_on_subscribe_failure( events: EventBusClient, - mock_stomp_client: Mock, + mock_stomp_client: StompClient, ): mock_stomp_client.subscribe.side_effect = RuntimeError("Foo") - on_event = Mock + on_event = Mock(spec=Callable[[WorkerEvent | ProgressEvent | DataEvent, Any], None]) with events: with pytest.raises( BlueskyStreamingError, match="Unable to subscribe to messages from blueapi", ): - events.subscribe_to_all_events(on_event=on_event) # type: ignore + events.subscribe_to_all_events(on_event=on_event) diff --git a/tests/unit_tests/core/test_context.py b/tests/unit_tests/core/test_context.py index 0540401f2..16f0f5c0b 100644 --- a/tests/unit_tests/core/test_context.py +++ b/tests/unit_tests/core/test_context.py @@ -24,32 +24,32 @@ # -def has_no_params() -> MsgGenerator: # type: ignore - ... +def has_no_params() -> MsgGenerator: + yield from {} -def has_one_param(foo: int) -> MsgGenerator: # type: ignore - ... +def has_one_param(foo: int) -> MsgGenerator: + yield from {} -def has_some_params(foo: int = 42, bar: str = "bar") -> MsgGenerator: # type: ignore - ... +def has_some_params(foo: int = 42, bar: str = "bar") -> MsgGenerator: + yield from {} -def has_typeless_param(foo) -> MsgGenerator: # type: ignore - ... +def has_typeless_param(foo) -> MsgGenerator: + yield from {} -def has_typed_and_typeless_params(foo: int, bar) -> MsgGenerator: # type: ignore - ... +def has_typed_and_typeless_params(foo: int, bar) -> MsgGenerator: + yield from {} -def has_typeless_params(foo, bar) -> MsgGenerator: # type: ignore - ... +def has_typeless_params(foo, bar) -> MsgGenerator: + yield from {} def has_default_reference(m: Movable = inject(SIM_MOTOR_NAME)) -> MsgGenerator: - yield from [] + yield from {} MOVABLE_DEFAULT = [inject(SIM_MOTOR_NAME)] @@ -58,7 +58,7 @@ def has_default_reference(m: Movable = inject(SIM_MOTOR_NAME)) -> MsgGenerator: def has_default_nested_reference( m: list[Movable] = MOVABLE_DEFAULT, ) -> MsgGenerator: - yield from [] + yield from {} # @@ -102,12 +102,11 @@ def devicey_context(sim_motor: SynAxis, sim_detector: SynGauss) -> BlueskyContex class SomeConfigurable: - def read_configuration(self) -> SyncOrAsync[dict[str, Reading]]: # type: ignore - ... + def read_configuration(self) -> SyncOrAsync[dict[str, Reading]]: + return {} - def describe_configuration( # type: ignore - self, - ) -> SyncOrAsync[dict[str, Descriptor]]: ... + def describe_configuration(self) -> SyncOrAsync[dict[str, Descriptor]]: + return {} @pytest.fixture @@ -124,11 +123,11 @@ def test_add_plan(empty_context: BlueskyContext, plan: PlanGenerator) -> None: def test_generated_schema( empty_context: BlueskyContext, ): - def demo_plan(foo: int, mov: Movable) -> MsgGenerator: # type: ignore - ... + def demo_plan(foo: int, mov: Movable) -> MsgGenerator: + yield from {} empty_context.register_plan(demo_plan) - schema = empty_context.plans["demo_plan"].model.schema() + schema = empty_context.plans["demo_plan"].model.model_json_schema() assert schema["properties"] == { "foo": {"title": "Foo", "type": "integer"}, "mov": {"title": "Mov", "type": "bluesky.protocols.Movable"}, @@ -246,12 +245,12 @@ def test_lookup_non_device(devicey_context: BlueskyContext) -> None: def test_add_non_plan(empty_context: BlueskyContext) -> None: with pytest.raises(TypeError): - empty_context.register_plan("not a plan") # type: ignore + empty_context.register_plan("not a plan") def test_add_non_device(empty_context: BlueskyContext) -> None: with pytest.raises(TypeError): - empty_context.register_device("not a device") # type: ignore + empty_context.register_device("not a device") def test_add_devices_and_plans_from_modules_with_config( @@ -307,12 +306,12 @@ def test_reference_type_conversion(empty_context: BlueskyContext) -> None: ) -def test_reference_type_conversion_union(empty_context: BlueskyContext) -> None: +def test_reference_type_conversion_explicit_union( + empty_context: BlueskyContext, +) -> None: movable_ref: type = empty_context._reference(Movable) assert empty_context._convert_type(Movable) == movable_ref - assert ( - empty_context._convert_type(Union[Movable, int]) == Union[movable_ref, int] # noqa # type: ignore - ) + assert empty_context._convert_type(Union[Movable, int]) == Union[movable_ref, int] # noqa # type: ignore def test_reference_type_conversion_new_style_union( @@ -320,14 +319,12 @@ def test_reference_type_conversion_new_style_union( ) -> None: movable_ref: type = empty_context._reference(Movable) assert empty_context._convert_type(Movable) == movable_ref - assert ( - empty_context._convert_type(Movable | int) == movable_ref | int # type: ignore - ) + assert empty_context._convert_type(Movable | int) == movable_ref | int def test_default_device_reference(empty_context: BlueskyContext) -> None: - def default_movable(mov: Movable = "demo") -> MsgGenerator: # type: ignore - ... + def default_movable(mov: Movable = inject("demo")) -> MsgGenerator: + yield from {} spec = empty_context._type_spec_for_function(default_movable) movable_ref = empty_context._reference(Movable) diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index 8faf91332..4858ac03c 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -18,12 +18,12 @@ @pytest.fixture -def mock_connection() -> Mock: +def mock_connection() -> Connection: return Mock(spec=Connection) @pytest.fixture -def mock_stomp_client(mock_connection: Mock) -> StompClient: +def mock_stomp_client(mock_connection: Connection) -> StompClient: stomp_client = StompClient(conn=mock_connection) stomp_client.disconnect = MagicMock() return stomp_client diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 8a1fe5219..869c9b759 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -1,6 +1,7 @@ import uuid from collections.abc import Iterator from dataclasses import dataclass +from typing import Any from unittest.mock import MagicMock, patch import pytest @@ -146,10 +147,7 @@ def test_get_non_existent_device_by_name( @patch("blueapi.service.interface.submit_task") -@patch("blueapi.service.interface.get_plan") -def test_create_task( - get_plan_mock: MagicMock, submit_task_mock: MagicMock, client: TestClient -) -> None: +def test_create_task(submit_task_mock: MagicMock, client: TestClient) -> None: task = Task(name="count", params={"detectors": ["x"]}) task_id = str(uuid.uuid4()) @@ -171,14 +169,20 @@ class MyModel(BaseModel): plan = Plan(name="my-plan", model=MyModel) get_plan_mock.return_value = PlanModel.from_plan(plan) - submit_task_mock.side_effect = ValidationError.from_exception_data( - title="ValueError", - line_errors=[ - InitErrorDetails( - type="missing", loc=("id",), msg="value is required for Identifier" - ) # type: ignore - ], - ) + + def raise_validation_error(bar: Any): + raise ValidationError.from_exception_data( + title="ValueError", + line_errors=[ + InitErrorDetails( + input=bar, + type="missing", + loc=("id",), + ) + ], + ) + + submit_task_mock.side_effect = raise_validation_error response = client.post("/tasks", json={"name": "my-plan"}) assert response.status_code == 422 assert response.json() == { diff --git a/tests/unit_tests/service/test_runner.py b/tests/unit_tests/service/test_runner.py index 1162d8108..c349e1eb3 100644 --- a/tests/unit_tests/service/test_runner.py +++ b/tests/unit_tests/service/test_runner.py @@ -1,3 +1,4 @@ +from collections.abc import Generator from typing import Any, Generic, TypeVar from unittest import mock from unittest.mock import MagicMock, patch @@ -20,17 +21,17 @@ @pytest.fixture -def local_runner(): +def local_runner() -> WorkerDispatcher: return WorkerDispatcher(use_subprocess=False) @pytest.fixture -def runner(): +def runner() -> WorkerDispatcher: return WorkerDispatcher() @pytest.fixture -def started_runner(runner: WorkerDispatcher): +def started_runner(runner: WorkerDispatcher) -> Generator[WorkerDispatcher]: runner.start() yield runner runner.stop() diff --git a/tests/unit_tests/test_cli.py b/tests/unit_tests/test_cli.py index b6c75eea9..76b1ab6ee 100644 --- a/tests/unit_tests/test_cli.py +++ b/tests/unit_tests/test_cli.py @@ -33,17 +33,17 @@ @pytest.fixture -def mock_connection() -> Mock: +def mock_connection() -> Connection: return Mock(spec=Connection) @pytest.fixture -def mock_stomp_client(mock_connection: Mock) -> StompClient: +def mock_stomp_client(mock_connection: Connection) -> StompClient: return StompClient(conn=mock_connection) @pytest.fixture -def runner(): +def runner() -> CliRunner: return CliRunner() @@ -154,7 +154,7 @@ def test_cannot_run_plans_without_stomp_config(runner: CliRunner): def test_valid_stomp_config_for_listener( mock_stomp_client: StompClient, runner: CliRunner, - mock_connection: Mock, + mock_connection: Connection, ): mock_connection.is_connected.return_value = True result = runner.invoke( diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 5e2ec84f9..eff522446 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -322,11 +322,11 @@ def test_config_yaml_parsed_complete(temp_yaml_config_file: dict): assert loaded_config.stomp.auth is not None assert ( loaded_config.stomp.auth.password.get_secret_value() - == config_data["stomp"]["auth"]["password"] # noqa: E501 + == config_data["stomp"]["auth"]["password"] ) # Remove the password field to not compare it again in the full dict comparison del target_dict_json["stomp"]["auth"]["password"] - del config_data["stomp"]["auth"]["password"] # noqa: E501 + del config_data["stomp"]["auth"]["password"] # Assert that the remaining config data is identical assert ( target_dict_json == config_data diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index 1b42af09e..1bf125158 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -453,12 +453,10 @@ def take_events_from_streams( The type for streams will be any combination of event streams each of a given event type, where the event type is generic: - List[ - Union[ - EventStream[WorkerEvent, int], - EventStream[DataEvent, int], - EventStream[ProgressEvent, int] - ] + list[ + EventStream[WorkerEvent, int] | + EventStream[DataEvent, int] | + EventStream[ProgressEvent, int] ] """