Skip to content

Commit

Permalink
Merge pull request #358 from DiamondLightSource/hyperion_1068_1213_op…
Browse files Browse the repository at this point in the history
…hyd_oav_for_grid_detect

Make the PinTipDetection device also store top and bottom edges
  • Loading branch information
DominicOram authored Mar 6, 2024
2 parents 350455f + e5e629e commit 56ef04b
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 26 deletions.
9 changes: 6 additions & 3 deletions src/dodal/devices/motors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing import List, Tuple, Union

import numpy as np
from ophyd import Component, Device, EpicsMotor
Expand All @@ -24,8 +25,8 @@ def is_within(self, position: float) -> bool:
:param position: The position to check
:return: True if position is within the limits
"""
low = self.motor.low_limit_travel.get()
high = self.motor.high_limit_travel.get()
low = float(self.motor.low_limit_travel.get())
high = float(self.motor.high_limit_travel.get())
return low <= position <= high


Expand All @@ -39,7 +40,9 @@ class XYZLimitBundle:
y: MotorLimitHelper
z: MotorLimitHelper

def position_valid(self, position: np.ndarray):
def position_valid(
self, position: Union[np.ndarray, List[float], Tuple[float, float, float]]
):
if len(position) != 3:
raise ValueError(
f"Position valid expects a 3-vector, got {position} instead"
Expand Down
39 changes: 27 additions & 12 deletions src/dodal/devices/oav/pin_image_recognition/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from dodal.devices.oav.pin_image_recognition.utils import (
ARRAY_PROCESSING_FUNCTIONS_MAP,
MxSampleDetect,
SampleLocation,
ScanDirections,
identity,
)
Expand Down Expand Up @@ -46,6 +47,12 @@ def __init__(self, prefix: str, name: str = ""):
self._name = name

self.triggered_tip = create_soft_signal_r(Tip, "triggered_tip", self.name)
self.triggered_top_edge = create_soft_signal_r(
NDArray[np.uint32], "triggered_top_edge", self.name
)
self.triggered_bottom_edge = create_soft_signal_r(
NDArray[np.uint32], "triggered_bottom_edge", self.name
)
self.array_data = epics_signal_r(NDArray[np.uint8], f"pva://{prefix}PVA:ARRAY")

# Soft parameters for pin-tip detection.
Expand Down Expand Up @@ -73,23 +80,29 @@ def __init__(self, prefix: str, name: str = ""):
)

self.set_readable_signals(
read=[self.triggered_tip],
read=[
self.triggered_tip,
self.triggered_top_edge,
self.triggered_bottom_edge,
],
)

super().__init__(name=name)

async def _set_triggered_tip(self, value):
if value == self.INVALID_POSITION:
async def _set_triggered_values(self, results: SampleLocation):
tip = (results.tip_x, results.tip_y)
if tip == self.INVALID_POSITION:
raise InvalidPinException
else:
await self.triggered_tip._backend.put(value)
await self.triggered_tip._backend.put(tip)
await self.triggered_top_edge._backend.put(results.edge_top)
await self.triggered_bottom_edge._backend.put(results.edge_bottom)

async def _get_tip_position(self, array_data: NDArray[np.uint8]) -> Tip:
async def _get_tip_and_edge_data(
self, array_data: NDArray[np.uint8]
) -> SampleLocation:
"""
Gets the location of the pin tip.
Returns tuple of:
(tip_x, tip_y)
Gets the location of the pin tip and the top and bottom edges.
"""
preprocess_key = await self.preprocess_operation.get_value()
preprocess_iter = await self.preprocess_iterations.get_value()
Expand Down Expand Up @@ -127,8 +140,7 @@ async def _get_tip_position(self, array_data: NDArray[np.uint8]) -> Tip:
(end_time - start_time) * 1000.0
)
)

return (location.tip_x, location.tip_y)
return location

async def connect(self, sim: bool = False):
await super().connect(sim)
Expand Down Expand Up @@ -156,7 +168,8 @@ async def _set_triggered_tip():
"""
async for value in observe_value(self.array_data):
try:
await self._set_triggered_tip(await self._get_tip_position(value))
location = await self._get_tip_and_edge_data(value)
await self._set_triggered_values(location)
except Exception as e:
LOGGER.warn(
f"Failed to detect pin-tip location, will retry with next image: {e}"
Expand All @@ -173,3 +186,5 @@ async def _set_triggered_tip():
f"No tip found in {await self.validity_timeout.get_value()} seconds."
)
await self.triggered_tip._backend.put(self.INVALID_POSITION)
await self.triggered_bottom_edge._backend.put(np.array([]))
await self.triggered_top_edge._backend.put(np.array([]))
6 changes: 3 additions & 3 deletions src/dodal/devices/oav/pin_image_recognition/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ class SampleLocation:
Holder type for results from sample detection.
"""

tip_y: Optional[int]
tip_x: Optional[int]
tip_y: Optional[int]
edge_top: np.ndarray
edge_bottom: np.ndarray

Expand Down Expand Up @@ -209,7 +209,7 @@ def _locate_sample(self, edge_arr: np.ndarray) -> SampleLocation:
"pin-tip detection: No non-narrow edges found - cannot locate pin tip"
)
return SampleLocation(
tip_y=None, tip_x=None, edge_bottom=bottom, edge_top=top
tip_x=None, tip_y=None, edge_bottom=bottom, edge_top=top
)

# Choose our starting point - i.e. first column with non-narrow width for positive scan, last one for negative scan.
Expand Down Expand Up @@ -248,5 +248,5 @@ def _locate_sample(self, edge_arr: np.ndarray) -> SampleLocation:
)
)
return SampleLocation(
tip_y=tip_y, tip_x=tip_x, edge_bottom=bottom, edge_top=top
tip_x=tip_x, tip_y=tip_y, edge_bottom=bottom, edge_top=top
)
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
pytest_plugins = ("pytest_asyncio",)
DEVICE_NAME = "pin_tip_detection"
TRIGGERED_TIP_READING = DEVICE_NAME + "-triggered_tip"
TRIGGERED_TOP_EDGE_READING = DEVICE_NAME + "-triggered_top_edge"
TRIGGERED_BOTTOM_EDGE_READING = DEVICE_NAME + "-triggered_bottom_edge"


async def _get_pin_tip_detection_device() -> PinTipDetection:
Expand Down Expand Up @@ -82,7 +84,7 @@ async def test_invalid_processing_func_uses_identity_function():
patch.object(MxSampleDetect, "__init__", return_value=None) as mock_init,
patch.object(MxSampleDetect, "processArray", return_value=test_sample_location),
):
await device._get_tip_position(np.array([]))
await device._get_tip_and_edge_data(np.array([]))

mock_init.assert_called_once()

Expand All @@ -97,7 +99,9 @@ async def test_invalid_processing_func_uses_identity_function():
async def test_given_valid_data_reading_then_used_to_find_location():
device = await _get_pin_tip_detection_device()
image_array = np.array([1, 2, 3])
test_sample_location = SampleLocation(100, 200, np.array([]), np.array([]))
test_sample_location = SampleLocation(
100, 200, np.array([1, 2, 3]), np.array([4, 5, 6])
)
set_sim_value(device.array_data, image_array)

with (
Expand All @@ -111,7 +115,13 @@ async def test_given_valid_data_reading_then_used_to_find_location():

process_call = mock_process_array.call_args[0][0]
assert np.array_equal(process_call, image_array)
assert location[TRIGGERED_TIP_READING]["value"] == (200, 100)
assert location[TRIGGERED_TIP_READING]["value"] == (100, 200)
assert np.all(
location[TRIGGERED_TOP_EDGE_READING]["value"] == np.array([1, 2, 3])
)
assert np.all(
location[TRIGGERED_BOTTOM_EDGE_READING]["value"] == np.array([4, 5, 6])
)
assert location[TRIGGERED_TIP_READING]["timestamp"] > 0


Expand All @@ -128,6 +138,8 @@ async def test_given_find_tip_fails_when_triggered_then_tip_invalid():
await device.trigger()
reading = await device.read()
assert reading[TRIGGERED_TIP_READING]["value"] == device.INVALID_POSITION
assert len(reading[TRIGGERED_TOP_EDGE_READING]["value"]) == 0
assert len(reading[TRIGGERED_BOTTOM_EDGE_READING]["value"]) == 0


@pytest.mark.asyncio
Expand Down Expand Up @@ -172,14 +184,25 @@ async def get_array_data(_):
device = await _get_pin_tip_detection_device()

class FakeLocation:
def __init__(self, tip_x, tip_y):
def __init__(self, tip_x, tip_y, edge_top, edge_bottom):
self.tip_x = tip_x
self.tip_y = tip_y
self.edge_top = edge_top
self.edge_bottom = edge_bottom

fake_top_edge = np.array([1, 2, 3])
fake_bottom_edge = np.array([4, 5, 6])

with patch.object(MxSampleDetect, "__init__", return_value=None), patch.object(
MxSampleDetect,
"processArray",
side_effect=[FakeLocation(None, None), FakeLocation(1, 1)],
with (
patch.object(MxSampleDetect, "__init__", return_value=None),
patch.object(
MxSampleDetect,
"processArray",
side_effect=[
FakeLocation(None, None, fake_top_edge, fake_bottom_edge),
FakeLocation(1, 1, fake_top_edge, fake_bottom_edge),
],
),
):
await device.trigger()
mock_logger.assert_called_once()

0 comments on commit 56ef04b

Please sign in to comment.