diff --git a/src/ophyd_async/core/_detector.py b/src/ophyd_async/core/_detector.py index 45d48fa43f..374c38adc6 100644 --- a/src/ophyd_async/core/_detector.py +++ b/src/ophyd_async/core/_detector.py @@ -236,7 +236,7 @@ async def _check_config_sigs(self): @AsyncStatus.wrap async def unstage(self) -> None: # Stop data writing. - await self.writer.close() + await asyncio.gather(self.writer.close(), self.controller.disarm()) async def read_configuration(self) -> Dict[str, Reading]: return await merge_gathered_dicts(sig.read() for sig in self._config_sigs) diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py index 039ddb066c..756b4177ed 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_detector_controller.py @@ -1,8 +1,6 @@ import asyncio from typing import Optional -from pydantic import Field - from ophyd_async.core import DetectorControl, PathProvider from ophyd_async.core._detector import TriggerInfo @@ -14,7 +12,7 @@ def __init__( self, pattern_generator: PatternGenerator, path_provider: PathProvider, - exposure: float = Field(default=0.1), + exposure: Optional[float] = 0.1, ) -> None: self.pattern_generator: PatternGenerator = pattern_generator self.pattern_generator.set_exposure(exposure) @@ -46,13 +44,13 @@ async def wait_for_idle(self): await self.task async def disarm(self): - if self.task: + if self.task and not self.task.done(): self.task.cancel() try: await self.task except asyncio.CancelledError: pass - self.task = None + self.task = None def get_deadtime(self, exposure: float | None) -> float: return 0.001 diff --git a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py index 7b9269f0bb..0031a931e4 100644 --- a/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py +++ b/src/ophyd_async/sim/demo/_pattern_detector/_pattern_generator.py @@ -199,7 +199,6 @@ async def collect_stream_docs( def close(self) -> None: if self._handle_for_h5_file: self._handle_for_h5_file.close() - print("file closed") self._handle_for_h5_file = None async def observe_indices_written( diff --git a/tests/conftest.py b/tests/conftest.py index d9a39888af..6f042ac19b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,6 @@ import asyncio import os +import pprint import subprocess import sys import time @@ -8,6 +9,7 @@ import pytest from bluesky.run_engine import RunEngine, TransitionError +from pytest import FixtureRequest from ophyd_async.core import ( DetectorTrigger, @@ -58,11 +60,70 @@ def configure_epics_environment(): os.environ["EPICS_PVA_AUTO_ADDR_LIST"] = "NO" +_ALLOWED_PYTEST_TASKS = {"async_finalizer", "async_setup", "async_teardown"} + + +def _error_and_kill_pending_tasks( + loop: asyncio.AbstractEventLoop, test_name: str, test_passed: bool +) -> set[asyncio.Task]: + """Cancels pending tasks in the event loop for a test. Raises an exception if + the test hasn't already. + + Args: + loop: The event loop to check for pending tasks. + test_name: The name of the test. + test_passed: Indicates whether the test passed. + + Returns: + set[asyncio.Task]: The set of unfinished tasks that were cancelled. + + Raises: + RuntimeError: If there are unfinished tasks and the test didn't fail. 
+ """ + unfinished_tasks = { + task + for task in asyncio.all_tasks(loop) + if task.get_coro().__name__ not in _ALLOWED_PYTEST_TASKS and not task.done() + } + for task in unfinished_tasks: + task.cancel() + + # We only raise an exception here if the test didn't fail anyway. + # If it did then it makes sense that there's some tasks we need to cancel, + # but an exception will already have been raised. + if unfinished_tasks and test_passed: + raise RuntimeError( + f"Not all tasks closed during test {test_name}:\n" + f"{pprint.pformat(unfinished_tasks, width=88)}" + ) + + return unfinished_tasks + + +@pytest.fixture(autouse=True, scope="function") +def fail_test_on_unclosed_tasks(request: FixtureRequest): + """ + Used on every test to ensure failure if there are pending tasks + by the end of the test. + """ + + fail_count = request.session.testsfailed + loop = asyncio.get_event_loop() + loop.set_debug(True) + + request.addfinalizer( + lambda: _error_and_kill_pending_tasks( + loop, request.node.name, request.session.testsfailed == fail_count + ) + ) + + @pytest.fixture(scope="function") -def RE(request): +def RE(request: FixtureRequest): loop = asyncio.new_event_loop() loop.set_debug(True) RE = RunEngine({}, call_returns_result=True, loop=loop) + fail_count = request.session.testsfailed def clean_event_loop(): if RE.state not in ("idle", "panicked"): @@ -70,9 +131,16 @@ def clean_event_loop(): RE.halt() except TransitionError: pass + loop.call_soon_threadsafe(loop.stop) RE._th.join() - loop.close() + + try: + _error_and_kill_pending_tasks( + loop, request.node.name, request.session.testsfailed == fail_count + ) + finally: + loop.close() request.addfinalizer(clean_event_loop) return RE diff --git a/tests/core/test_status.py b/tests/core/test_status.py index f8af39bbc4..d2ec206bc8 100644 --- a/tests/core/test_status.py +++ b/tests/core/test_status.py @@ -147,9 +147,18 @@ async def test_status_propogates_traceback_under_RE(RE) -> None: async def test_async_status_exception_timeout(): - st = AsyncStatus(asyncio.sleep(0.1)) - with pytest.raises(Exception): - st.exception(timeout=1.0) + try: + st = AsyncStatus(asyncio.sleep(0.1)) + with pytest.raises( + ValueError, + match=( + "cannot honour any timeout other than 0 in an asynchronous function" + ), + ): + st.exception(timeout=1.0) + finally: + if not st.done: + st.task.cancel() @pytest.fixture diff --git a/tests/epics/adcore/test_single_trigger.py b/tests/epics/adcore/test_single_trigger.py index 93e42ed400..9bb5135cc0 100644 --- a/tests/epics/adcore/test_single_trigger.py +++ b/tests/epics/adcore/test_single_trigger.py @@ -1,42 +1,48 @@ +import bluesky.plan_stubs as bps import bluesky.plans as bp import pytest from bluesky import RunEngine -from ophyd_async.core import DeviceCollector, set_mock_value +import ophyd_async.plan_stubs as ops from ophyd_async.epics import adcore @pytest.fixture -async def single_trigger_det(): - async with DeviceCollector(mock=True): - stats = adcore.NDPluginStatsIO("PREFIX:STATS") - det = adcore.SingleTriggerDetector( - drv=adcore.ADBaseIO("PREFIX:DRV"), - stats=stats, - read_uncached=[stats.unique_id], - ) +async def single_trigger_det_with_stats(): + stats = adcore.NDPluginStatsIO("PREFIX:STATS", name="stats") + det = adcore.SingleTriggerDetector( + drv=adcore.ADBaseIO("PREFIX:DRV"), + stats=stats, + read_uncached=[stats.unique_id], + name="det", + ) - assert det.name == "det" - assert stats.name == "det-stats" # Set non-default values to check they are set back # These are using set_mock_value to simulate the 
backend IOC being setup # in a particular way, rather than values being set by the Ophyd signals - set_mock_value(det.drv.acquire_time, 0.5) - set_mock_value(det.drv.array_counter, 1) - set_mock_value(det.drv.image_mode, adcore.ImageMode.continuous) - set_mock_value(stats.unique_id, 3) - yield det + yield det, stats async def test_single_trigger_det( - single_trigger_det: adcore.SingleTriggerDetector, RE: RunEngine + single_trigger_det_with_stats: adcore.SingleTriggerDetector, RE: RunEngine ): + single_trigger_det, stats = single_trigger_det_with_stats names = [] docs = [] RE.subscribe(lambda name, _: names.append(name)) RE.subscribe(lambda _, doc: docs.append(doc)) - RE(bp.count([single_trigger_det])) + def plan(): + yield from ops.ensure_connected(single_trigger_det, mock=True) + yield from bps.abs_set(single_trigger_det.drv.acquire_time, 0.5) + yield from bps.abs_set(single_trigger_det.drv.array_counter, 1) + yield from bps.abs_set( + single_trigger_det.drv.image_mode, adcore.ImageMode.continuous + ) + # set_mock_value(stats.unique_id, 3) + yield from bp.count([single_trigger_det]) + + RE(plan()) drv = single_trigger_det.drv assert 1 == await drv.acquire.get_value() @@ -47,4 +53,4 @@ async def test_single_trigger_det( _, descriptor, event, _ = docs assert descriptor["configuration"]["det"]["data"]["det-drv-acquire_time"] == 0.5 assert event["data"]["det-drv-array_counter"] == 1 - assert event["data"]["det-stats-unique_id"] == 3 + assert event["data"]["det-stats-unique_id"] == 0 diff --git a/tests/epics/adsimdetector/test_sim.py b/tests/epics/adsimdetector/test_sim.py index 891d89c33c..9e69733680 100644 --- a/tests/epics/adsimdetector/test_sim.py +++ b/tests/epics/adsimdetector/test_sim.py @@ -11,6 +11,7 @@ from bluesky import RunEngine from bluesky.utils import new_uid +import ophyd_async.plan_stubs as ops from ophyd_async.core import ( AsyncStatus, DetectorTrigger, @@ -185,22 +186,37 @@ async def test_two_detectors_step( for det in two_detectors ] - RE(count_sim(two_detectors, times=1)) - controller_a = cast(adsimdetector.SimController, two_detectors[0].controller) writer_a = cast(adcore.ADHDFWriter, two_detectors[0].writer) writer_b = cast(adcore.ADHDFWriter, two_detectors[1].writer) + info_a = writer_a._path_provider(device_name=writer_a.hdf.name) + info_b = writer_b._path_provider(device_name=writer_b.hdf.name) + file_name_a = None + file_name_b = None + + def plan(): + nonlocal file_name_a, file_name_b + yield from count_sim(two_detectors, times=1) + + drv = controller_a.driver + assert False is (yield from bps.rd(drv.acquire)) + assert adcore.ImageMode.multiple == (yield from bps.rd(drv.image_mode)) + + hdfb = writer_b.hdf + assert True is (yield from bps.rd(hdfb.lazy_open)) + assert True is (yield from bps.rd(hdfb.swmr_mode)) + assert 0 == (yield from bps.rd(hdfb.num_capture)) + assert adcore.FileWriteMode.stream == (yield from bps.rd(hdfb.file_write_mode)) - drv = controller_a.driver - assert 1 == await drv.acquire.get_value() - assert adcore.ImageMode.multiple == await drv.image_mode.get_value() + assert (yield from bps.rd(writer_a.hdf.file_path)) == str(info_a.directory_path) + file_name_a = yield from bps.rd(writer_a.hdf.file_name) + assert file_name_a == info_a.filename - hdfb = writer_b.hdf - assert True is await hdfb.lazy_open.get_value() - assert True is await hdfb.swmr_mode.get_value() - assert 0 == await hdfb.num_capture.get_value() - assert adcore.FileWriteMode.stream == await hdfb.file_write_mode.get_value() + assert (yield from bps.rd(writer_b.hdf.file_path)) == 
str(info_b.directory_path) + file_name_b = yield from bps.rd(writer_b.hdf.file_name) + assert file_name_b == info_b.filename + RE(plan()) assert names == [ "start", "descriptor", @@ -211,16 +227,6 @@ async def test_two_detectors_step( "event", "stop", ] - info_a = writer_a._path_provider(device_name=writer_a.hdf.name) - info_b = writer_b._path_provider(device_name=writer_b.hdf.name) - - assert await writer_a.hdf.file_path.get_value() == str(info_a.directory_path) - file_name_a = await writer_a.hdf.file_name.get_value() - assert file_name_a == info_a.filename - - assert await writer_b.hdf.file_path.get_value() == str(info_b.directory_path) - file_name_b = await writer_b.hdf.file_name.get_value() - assert file_name_b == info_b.filename _, descriptor, sra, sda, srb, sdb, event, _ = docs assert descriptor["configuration"]["testa"]["data"]["testa-drv-acquire_time"] == 0.8 @@ -322,39 +328,52 @@ async def test_trigger_logic(): ... -async def test_detector_with_unnamed_or_disconnected_config_sigs( - RE, static_filename_provider: StaticFilenameProvider, tmp_path: Path +@pytest.mark.parametrize( + "driver_name, error_output", + [ + ("", "config signal must be named before it is passed to the detector"), + ( + "some-name", + ( + "config signal some-name-acquire_time must be connected " + "before it is passed to the detector" + ), + ), + ], +) +def test_detector_with_unnamed_or_disconnected_config_sigs( + RE, + static_filename_provider: StaticFilenameProvider, + tmp_path: Path, + driver_name, + error_output, ): dp = StaticPathProvider(static_filename_provider, tmp_path) - some_other_driver = adcore.ADBaseIO("TEST") + some_other_driver = adcore.ADBaseIO("TEST", name=driver_name) - async with DeviceCollector(mock=True): - det = adsimdetector.SimDetector( - "FOO:", - dp, - name="foo", - ) + det = adsimdetector.SimDetector( + "FOO:", + dp, + name="foo", + ) det._config_sigs = [some_other_driver.acquire_time, det.drv.acquire] - with pytest.raises(Exception) as exc: - RE(count_sim([det], times=1)) - - assert isinstance(exc.value.args[0], AsyncStatus) - assert ( - str(exc.value.args[0].exception()) - == "config signal must be named before it is passed to the detector" - ) + def my_plan(): + yield from ops.ensure_connected(det, mock=True) + assert det.drv.acquire.name == "foo-drv-acquire" + assert some_other_driver.acquire_time.name == ( + driver_name + "-acquire_time" if driver_name else "" + ) - some_other_driver.set_name("some-name") + yield from count_sim([det], times=1) with pytest.raises(Exception) as exc: - RE(count_sim([det], times=1)) + RE(my_plan()) assert isinstance(exc.value.args[0], AsyncStatus) - assert ( - str(exc.value.args[0].exception()) - == "config signal some-name-acquire_time must be connected before it is " - + "passed to the detector" - ) + assert str(exc.value.args[0].exception()) == error_output + + # Need to unstage to properly kill tasks + RE(bps.unstage(det, wait=True)) diff --git a/tests/sim/conftest.py b/tests/sim/conftest.py index fe0871f5e0..b02740c082 100644 --- a/tests/sim/conftest.py +++ b/tests/sim/conftest.py @@ -2,14 +2,10 @@ import pytest -from ophyd_async.core import DeviceCollector from ophyd_async.sim.demo import PatternDetector @pytest.fixture async def sim_pattern_detector(tmp_path_factory) -> PatternDetector: path: Path = tmp_path_factory.mktemp("tmp") - async with DeviceCollector(mock=True): - sim_pattern_detector = PatternDetector(name="PATTERN1", path=path) - - return sim_pattern_detector + return PatternDetector(name="PATTERN1", path=path) diff --git 
a/tests/sim/test_sim_detector.py b/tests/sim/test_sim_detector.py index 785afee3d9..7631a5b558 100644 --- a/tests/sim/test_sim_detector.py +++ b/tests/sim/test_sim_detector.py @@ -3,21 +3,13 @@ import bluesky.plans as bp import h5py import numpy as np -import pytest from bluesky import RunEngine -from ophyd_async.core import DeviceCollector, assert_emitted -from ophyd_async.epics import motor +from ophyd_async.core import assert_emitted +from ophyd_async.plan_stubs import ensure_connected from ophyd_async.sim.demo import PatternDetector -@pytest.fixture -async def sim_motor(): - async with DeviceCollector(mock=True): - sim_motor = motor.Motor("test") - return sim_motor - - async def test_sim_pattern_detector_initialization( sim_pattern_detector: PatternDetector, ): @@ -33,7 +25,7 @@ async def test_detector_creates_controller_and_writer( assert sim_pattern_detector.controller -async def test_writes_pattern_to_file( +def test_writes_pattern_to_file( sim_pattern_detector: PatternDetector, RE: RunEngine, ): @@ -43,7 +35,11 @@ async def test_writes_pattern_to_file( def capture_emitted(name, doc): docs[name].append(doc) - RE(bp.count([sim_pattern_detector]), capture_emitted) + def plan(): + yield from ensure_connected(sim_pattern_detector, mock=True) + yield from bp.count([sim_pattern_detector]) + + RE(plan(), capture_emitted) assert_emitted( docs, start=1, descriptor=1, stream_resource=2, stream_datum=2, event=1, stop=1 ) diff --git a/tests/sim/test_streaming_plan.py b/tests/sim/test_streaming_plan.py index 9a4b46a80e..222f92d685 100644 --- a/tests/sim/test_streaming_plan.py +++ b/tests/sim/test_streaming_plan.py @@ -4,6 +4,7 @@ from bluesky.run_engine import RunEngine from ophyd_async.core import assert_emitted +from ophyd_async.plan_stubs import ensure_connected from ophyd_async.sim.demo import PatternDetector @@ -20,7 +21,11 @@ def append_and_print(name, doc): RE.subscribe(append_and_print) - RE(bp.count([sim_pattern_detector], num=1)) + def plan(): + yield from ensure_connected(sim_pattern_detector, mock=True) + yield from bp.count([sim_pattern_detector], num=1) + + RE(plan()) # NOTE - double resource because double stream assert names == [ @@ -38,8 +43,13 @@ def append_and_print(name, doc): async def test_plan(RE: RunEngine, sim_pattern_detector: PatternDetector): docs = defaultdict(list) - RE(bp.count([sim_pattern_detector]), lambda name, doc: docs[name].append(doc)) + + def plan(): + yield from ensure_connected(sim_pattern_detector, mock=True) + yield from bp.count([sim_pattern_detector]) + + RE(plan(), lambda name, doc: docs[name].append(doc)) + assert_emitted( docs, start=1, descriptor=1, stream_resource=2, stream_datum=2, event=1, stop=1 ) - await sim_pattern_detector.writer.close()
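
Usage sketch (not part of the diff above): the test changes all follow the same pattern of connecting devices inside the plan with ophyd_async.plan_stubs.ensure_connected(..., mock=True) instead of eagerly via DeviceCollector(mock=True), so connection happens on the RunEngine's own event loop and no stray task is left behind for the new conftest finalizer to flag. A minimal, self-contained illustration of that pattern, assuming a temporary directory for the PatternDetector (the count_pattern_detector name and the standalone RunEngine construction are illustrative only, not taken from the diff):

    from pathlib import Path

    import bluesky.plans as bp
    from bluesky.run_engine import RunEngine

    from ophyd_async.plan_stubs import ensure_connected
    from ophyd_async.sim.demo import PatternDetector


    def count_pattern_detector(tmp_dir: Path) -> None:
        # Build the device without connecting it; the plan connects it in mock
        # mode on the RunEngine's loop, then runs a single count.
        det = PatternDetector(name="PATTERN1", path=tmp_dir)

        def plan():
            yield from ensure_connected(det, mock=True)
            yield from bp.count([det])

        RE = RunEngine(call_returns_result=True)
        RE(plan())

Because the connection tasks now live on the RunEngine loop and unstage disarms the controller as well as closing the writer, the autouse fail_test_on_unclosed_tasks fixture added in tests/conftest.py can cancel any remaining tasks at the end of each test and raise only when an otherwise-passing test has genuinely leaked one.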