Skip to content

Commit

Permalink
Merge pull request #336 from DiamondLightSource/hyperion_1125_pin_det…
Browse files Browse the repository at this point in the history
…ect_hotfixes

Make the ophyd pin tip detection triggerable like the AD plugin
  • Loading branch information
olliesilvester authored Feb 21, 2024
2 parents d6c3872 + 73ab698 commit b938678
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 88 deletions.
160 changes: 80 additions & 80 deletions src/dodal/devices/oav/pin_image_recognition/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import asyncio
import time
from collections import OrderedDict
from typing import Optional, Tuple
from typing import Optional

import numpy as np
from bluesky.protocols import Descriptor, Reading
from numpy.typing import NDArray
from ophyd_async.core import SignalR, SignalRW, StandardReadable
from ophyd_async.core import AsyncStatus, StandardReadable, observe_value, set_sim_value
from ophyd_async.epics.signal import epics_signal_r

from dodal.devices.oav.pin_image_recognition.utils import (
Expand All @@ -15,20 +13,36 @@
ScanDirections,
identity,
)
from dodal.devices.ophyd_async_utils import create_soft_signal_rw
from dodal.devices.ophyd_async_utils import create_soft_signal_r, create_soft_signal_rw
from dodal.log import LOGGER

Tip = tuple[Optional[int], Optional[int]]


class InvalidPinException(Exception):
pass


def set_pin_value(*args, **kwargs):
if args[1] == PinTipDetection.INVALID_POSITION:
raise InvalidPinException
else:
return set_sim_value(*args, **kwargs)


class PinTipDetection(StandardReadable):
"""
A device which will read a single frame from an on-axis view and use that frame
to calculate the pin-tip offset (in pixels) of that frame.
A device which will read from an on-axis view and calculate the location of the
pin-tip (in pixels) of that frame.
Used for pin tip centring workflow.
Note that if the tip of the sample is off-screen, this class will return the centre as the "edge"
of the image. If the entire sample if off-screen (i.e. no suitable edges were detected at all)
then it will return (None, None).
Note that if the tip of the sample is off-screen, this class will return the tip as
the "edge" of the image.
If no tip is found it will return {INVALID_POSITION}. However, it will also
occassionally give incorrect data. Therefore, it is recommended that you trigger
this device, which will attempt to find a pin within {validity_timeout} seconds.
"""

INVALID_POSITION = (None, None)
Expand All @@ -37,55 +51,45 @@ def __init__(self, prefix: str, name: str = ""):
self._prefix: str = prefix
self._name = name

self.array_data: SignalR[NDArray[np.uint8]] = epics_signal_r(
NDArray[np.uint8], f"pva://{prefix}PVA:ARRAY"
)
self.triggered_tip = create_soft_signal_r(Tip, "triggered_tip", self.name)
self.array_data = epics_signal_r(NDArray[np.uint8], f"pva://{prefix}PVA:ARRAY")

# Soft parameters for pin-tip detection.
self.timeout: SignalRW[float] = create_soft_signal_rw(
float, "timeout", self.name
)
self.preprocess_operation: SignalRW[int] = create_soft_signal_rw(
int, "preprocess", self.name
)
self.preprocess_ksize: SignalRW[int] = create_soft_signal_rw(
self.preprocess_operation = create_soft_signal_rw(int, "preprocess", self.name)
self.preprocess_ksize = create_soft_signal_rw(
int, "preprocess_ksize", self.name
)
self.preprocess_iterations: SignalRW[int] = create_soft_signal_rw(
self.preprocess_iterations = create_soft_signal_rw(
int, "preprocess_iterations", self.name
)
self.canny_upper_threshold: SignalRW[int] = create_soft_signal_rw(
self.canny_upper_threshold = create_soft_signal_rw(
int, "canny_upper", self.name
)
self.canny_lower_threshold: SignalRW[int] = create_soft_signal_rw(
self.canny_lower_threshold = create_soft_signal_rw(
int, "canny_lower", self.name
)
self.close_ksize: SignalRW[int] = create_soft_signal_rw(
int, "close_ksize", self.name
)
self.close_iterations: SignalRW[int] = create_soft_signal_rw(
self.close_ksize = create_soft_signal_rw(int, "close_ksize", self.name)
self.close_iterations = create_soft_signal_rw(
int, "close_iterations", self.name
)
self.scan_direction: SignalRW[int] = create_soft_signal_rw(
int, "scan_direction", self.name
)
self.min_tip_height: SignalRW[int] = create_soft_signal_rw(
int, "min_tip_height", self.name
)
self.validity_timeout: SignalR[float] = create_soft_signal_rw(
self.scan_direction = create_soft_signal_rw(int, "scan_direction", self.name)
self.min_tip_height = create_soft_signal_rw(int, "min_tip_height", self.name)
self.validity_timeout = create_soft_signal_rw(
float, "validity_timeout", self.name
)

self.set_readable_signals(
read=[self.triggered_tip],
)

super().__init__(name=name)

async def _get_tip_position(
self,
) -> Tuple[Tuple[Optional[int], Optional[int]], float]:
async def _get_tip_position(self, array_data: NDArray[np.uint8]) -> Tip:
"""
Gets the location of the pin tip.
Returns tuple of:
((tip_x, tip_y), timestamp)
(tip_x, tip_y)
"""
preprocess_key = await self.preprocess_operation.get_value()
preprocess_iter = await self.preprocess_iterations.get_value()
Expand Down Expand Up @@ -115,32 +119,22 @@ async def _get_tip_position(
min_tip_height=await self.min_tip_height.get_value(),
)

array_reading: dict[str, Reading] = await self.array_data.read()
array_data: NDArray[np.uint8] = array_reading[self.array_data.name]["value"]
timestamp: float = array_reading[self.array_data.name]["timestamp"]

try:
start_time = time.time()
location = sample_detection.processArray(array_data)
end_time = time.time()
LOGGER.debug(
"Sample location detection took {}ms".format(
(end_time - start_time) * 1000.0
)
start_time = time.time()
location = sample_detection.processArray(array_data)
end_time = time.time()
LOGGER.debug(
"Sample location detection took {}ms".format(
(end_time - start_time) * 1000.0
)
tip_x = location.tip_x
tip_y = location.tip_y
except Exception as e:
LOGGER.error(f"Failed to detect pin-tip location due to exception: {e}")
tip_x, tip_y = self.INVALID_POSITION
)

return (tip_x, tip_y), timestamp
return (location.tip_x, location.tip_y)

async def connect(self, sim: bool = False):
await super().connect(sim)

# Set defaults for soft parameters
await self.timeout.set(10.0)
await self.validity_timeout.set(5.0)
await self.canny_upper_threshold.set(100)
await self.canny_lower_threshold.set(50)
await self.close_iterations.set(5)
Expand All @@ -151,27 +145,33 @@ async def connect(self, sim: bool = False):
await self.preprocess_iterations.set(5)
await self.preprocess_ksize.set(5)

async def read(self) -> dict[str, Reading]:
tip_pos, timestamp = await asyncio.wait_for(
self._get_tip_position(), timeout=await self.timeout.get_value()
)

return OrderedDict(
[
(self._name, {"value": tip_pos, "timestamp": timestamp}),
]
)
@AsyncStatus.wrap
async def trigger(self):
async def _set_triggered_tip():
"""Monitors the camera data and updates the triggered_tip signal.
If a tip is found it will update the signal and stop monitoring
If no tip is found it will retry with the next monitored value
This loop will serve as a good example of using 'observe_value' in the ophyd_async documentation
"""
async for value in observe_value(self.array_data):
try:
set_pin_value(
self.triggered_tip, await self._get_tip_position(value)
)
except Exception as e:
LOGGER.warn(
f"Failed to detect pin-tip location, will retry with next image: {e}"
)
else:
return

async def describe(self) -> dict[str, Descriptor]:
return OrderedDict(
[
(
self._name,
{
"source": f"pva://{self._prefix}PVA:ARRAY",
"dtype": "number",
"shape": [2], # Tuple of (x, y) tip position
},
)
],
)
try:
await asyncio.wait_for(
_set_triggered_tip(), timeout=await self.validity_timeout.get_value()
)
except asyncio.exceptions.TimeoutError:
LOGGER.error(
f"No tip found in {await self.validity_timeout.get_value()} seconds."
)
set_sim_value(self.triggered_tip, self.INVALID_POSITION)
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import patch
from unittest.mock import MagicMock, patch

import numpy as np
import pytest
Expand All @@ -13,6 +13,7 @@

pytest_plugins = ("pytest_asyncio",)
DEVICE_NAME = "pin_tip_detection"
TRIGGERED_TIP_READING = DEVICE_NAME + "-triggered_tip"


async def _get_pin_tip_detection_device() -> PinTipDetection:
Expand All @@ -31,7 +32,7 @@ async def test_pin_tip_detect_can_be_connected_in_sim_mode():
async def test_soft_parameter_defaults_are_correct():
device = await _get_pin_tip_detection_device()

assert await device.timeout.get_value() == 10.0
assert await device.validity_timeout.get_value() == 5.0
assert await device.canny_lower_threshold.get_value() == 50
assert await device.canny_upper_threshold.get_value() == 100
assert await device.close_ksize.get_value() == 5
Expand All @@ -47,7 +48,7 @@ async def test_soft_parameter_defaults_are_correct():
async def test_numeric_soft_parameters_can_be_changed():
device = await _get_pin_tip_detection_device()

await device.timeout.set(100.0)
await device.validity_timeout.set(100.0)
await device.canny_lower_threshold.set(5)
await device.canny_upper_threshold.set(10)
await device.close_ksize.set(15)
Expand All @@ -58,7 +59,7 @@ async def test_numeric_soft_parameters_can_be_changed():
await device.preprocess_ksize.set(3)
await device.preprocess_iterations.set(4)

assert await device.timeout.get_value() == 100.0
assert await device.validity_timeout.get_value() == 100.0
assert await device.canny_lower_threshold.get_value() == 5
assert await device.canny_upper_threshold.get_value() == 10
assert await device.close_ksize.get_value() == 15
Expand All @@ -73,14 +74,15 @@ async def test_numeric_soft_parameters_can_be_changed():
@pytest.mark.asyncio
async def test_invalid_processing_func_uses_identity_function():
device = await _get_pin_tip_detection_device()
test_sample_location = SampleLocation(100, 200, np.array([]), np.array([]))

set_sim_value(device.preprocess_operation, 50) # Invalid index

with (
patch.object(MxSampleDetect, "__init__", return_value=None) as mock_init,
patch.object(MxSampleDetect, "processArray", return_value=((None, None), None)),
patch.object(MxSampleDetect, "processArray", return_value=test_sample_location),
):
await device.read()
await device._get_tip_position(np.array([]))

mock_init.assert_called_once()

Expand All @@ -104,9 +106,80 @@ async def test_given_valid_data_reading_then_used_to_find_location():
MxSampleDetect, "processArray", return_value=test_sample_location
) as mock_process_array,
):
await device.trigger()
location = await device.read()

process_call = mock_process_array.call_args[0][0]
assert np.array_equal(process_call, image_array)
assert location[DEVICE_NAME]["value"] == (200, 100)
assert location[DEVICE_NAME]["timestamp"] > 0
assert location[TRIGGERED_TIP_READING]["value"] == (200, 100)
assert location[TRIGGERED_TIP_READING]["timestamp"] > 0


@pytest.mark.asyncio
async def test_given_find_tip_fails_when_triggered_then_tip_invalid():
device = await _get_pin_tip_detection_device()
await device.validity_timeout.set(0.1)
set_sim_value(device.array_data, np.array([1, 2, 3]))

with (
patch.object(MxSampleDetect, "__init__", return_value=None),
patch.object(MxSampleDetect, "processArray", side_effect=Exception()),
):
await device.trigger()
reading = await device.read()
assert reading[TRIGGERED_TIP_READING]["value"] == device.INVALID_POSITION


@pytest.mark.asyncio
@patch("dodal.devices.oav.pin_image_recognition.observe_value")
async def test_given_find_tip_fails_twice_when_triggered_then_tip_invalid_and_tried_twice(
mock_image_read,
):
async def get_array_data(_):
yield np.array([1, 2, 3])
yield np.array([1, 2])
await asyncio.sleep(100)

mock_image_read.side_effect = get_array_data
device = await _get_pin_tip_detection_device()
await device.validity_timeout.set(0.1)

with (
patch.object(MxSampleDetect, "__init__", return_value=None),
patch.object(
MxSampleDetect, "processArray", side_effect=Exception()
) as mock_process_array,
):
await device.trigger()
reading = await device.read()
assert reading[TRIGGERED_TIP_READING]["value"] == device.INVALID_POSITION
assert mock_process_array.call_count > 1


@pytest.mark.asyncio
@patch("dodal.devices.oav.pin_image_recognition.LOGGER.warn")
@patch("dodal.devices.oav.pin_image_recognition.observe_value")
async def test_given_tip_invalid_then_loop_keeps_retrying_until_valid(
mock_image_read: MagicMock,
mock_logger: MagicMock,
):
async def get_array_data(_):
yield np.array([1, 2, 3])
yield np.array([1, 2])
await asyncio.sleep(100)

mock_image_read.side_effect = get_array_data
device = await _get_pin_tip_detection_device()

class FakeLocation:
def __init__(self, tip_x, tip_y):
self.tip_x = tip_x
self.tip_y = tip_y

with patch.object(MxSampleDetect, "__init__", return_value=None), patch.object(
MxSampleDetect,
"processArray",
side_effect=[FakeLocation(None, None), FakeLocation(1, 1)],
):
await device.trigger()
mock_logger.assert_called_once()

0 comments on commit b938678

Please sign in to comment.