Commit 395b58b

Merge branch 'chunk-size-in-sres' of https://github.com/jwlodek/ophyd-async into chunk-size-in-sres

jwlodek committed Sep 16, 2024
2 parents: f105e34 + 4a76859
Showing 46 changed files with 1,182 additions and 480 deletions.
8 changes: 7 additions & 1 deletion src/ophyd_async/core/__init__.py
@@ -61,9 +61,14 @@
     soft_signal_rw,
     wait_for_value,
 )
-from ._signal_backend import RuntimeSubsetEnum, SignalBackend, SubsetEnum
+from ._signal_backend import (
+    RuntimeSubsetEnum,
+    SignalBackend,
+    SubsetEnum,
+)
 from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
 from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
+from ._table import Table
 from ._utils import (
     DEFAULT_TIMEOUT,
     CalculatableTimeout,
@@ -152,6 +157,7 @@
     "CalculateTimeout",
     "NotConnected",
     "ReadingValueCallback",
+    "Table",
     "T",
     "WatcherUpdate",
     "get_dtype",
92 changes: 53 additions & 39 deletions src/ophyd_async/core/_detector.py
@@ -51,20 +51,24 @@ class TriggerInfo(BaseModel):
     """Minimal set of information required to setup triggering on a detector"""
 
     #: Number of triggers that will be sent, 0 means infinite
-    number: int = Field(gt=0)
+    number: int = Field(ge=0)
     #: Sort of triggers that will be sent
-    trigger: DetectorTrigger = Field()
+    trigger: DetectorTrigger = Field(default=DetectorTrigger.internal)
     #: What is the minimum deadtime between triggers
-    deadtime: float | None = Field(ge=0)
+    deadtime: float | None = Field(default=None, ge=0)
     #: What is the maximum high time of the triggers
-    livetime: float | None = Field(ge=0)
+    livetime: float | None = Field(default=None, ge=0)
     #: What is the maximum timeout on waiting for a frame
-    frame_timeout: float | None = Field(None, gt=0)
+    frame_timeout: float | None = Field(default=None, gt=0)
     #: How many triggers make up a single StreamDatum index, to allow multiple frames
     #: from a faster detector to be zipped with a single frame from a slow detector
     #: e.g. if num=10 and multiplier=5 then the detector will take 10 frames,
     #: but publish 2 indices, and describe() will show a shape of (5, h, w)
     multiplier: int = 1
+    #: The number of times the detector can go through a complete cycle of kickoff and
+    #: complete without needing to re-arm. This is important for detectors where the
+    #: process of arming is expensive in terms of time
+    iteration: int = 1
 
 
 class DetectorControl(ABC):
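
The net effect of the TriggerInfo changes above: every field except number now has a usable default, number=0 is newly allowed to mean "trigger indefinitely", and the new iteration field advertises how many kickoff/complete cycles a single arm can serve. A construction sketch under those assumptions (the values are invented; DetectorTrigger.constant_gate is assumed to exist):

    from ophyd_async.core import DetectorTrigger, TriggerInfo

    # Internal triggering now needs only the frame count; everything else defaults.
    internal = TriggerInfo(number=10)

    # External triggering spells out its timing constraints explicitly.
    external = TriggerInfo(
        number=100,
        trigger=DetectorTrigger.constant_gate,
        deadtime=0.001,  # minimum gap between triggers (s)
        livetime=0.01,   # maximum trigger high time (s)
        multiplier=5,    # 5 frames per StreamDatum index
        iteration=2,     # two kickoff/complete cycles per arm
    )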
@@ -78,27 +82,35 @@ def get_deadtime(self, exposure: float | None) -> float:
         """For a given exposure, how long should the time between exposures be"""
 
     @abstractmethod
-    async def arm(
-        self,
-        num: int,
-        trigger: DetectorTrigger = DetectorTrigger.internal,
-        exposure: Optional[float] = None,
-    ) -> AsyncStatus:
+    async def prepare(self, trigger_info: TriggerInfo):
         """
-        Arm detector, do all necessary steps to prepare detector for triggers.
+        Do all necessary steps to prepare the detector for triggers.
 
         Args:
-            num: Expected number of frames
-            trigger: Type of trigger for which to prepare the detector. Defaults to
-                DetectorTrigger.internal.
-            exposure: Exposure time with which to set up the detector. Defaults to None
-                if not applicable or the detector is expected to use its previously-set
-                exposure time.
-
-        Returns:
-            AsyncStatus: Status representing the arm operation. This function returning
-                represents the start of the arm. The returned status completing means
-                the detector is now armed.
+            trigger_info: This is a Pydantic model which contains
+                number Expected number of frames.
+                trigger Type of trigger for which to prepare the detector. Defaults
+                to DetectorTrigger.internal.
+                livetime Livetime / Exposure time with which to set up the detector.
+                Defaults to None
+                if not applicable or the detector is expected to use its previously-set
+                exposure time.
+                deadtime Defaults to None. This is the minimum deadtime between
+                triggers.
+                multiplier The number of triggers grouped into a single StreamDatum
+                index.
         """
 
+    @abstractmethod
+    async def arm(self) -> None:
+        """
+        Arm the detector
+        """
+
+    @abstractmethod
+    async def wait_for_idle(self):
+        """
+        This will wait on the internal _arm_status and wait for it to get disarmed/idle
+        """
+
     @abstractmethod
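
Since arm() no longer takes arguments or returns a status, a concrete controller is now expected to cache the TriggerInfo in prepare(), start acquisition in arm(), and expose the wait-for-disarm separately. A minimal sketch of one possible implementation (DummyController and its internals are hypothetical, not part of this commit):

    from ophyd_async.core import AsyncStatus, DetectorControl, TriggerInfo

    class DummyController(DetectorControl):
        _arm_status: AsyncStatus | None = None

        def get_deadtime(self, exposure: float | None) -> float:
            return 0.001  # fixed readout gap for this imaginary detector

        async def prepare(self, trigger_info: TriggerInfo):
            # Push trigger mode, exposure etc. to hardware; nothing is armed yet.
            self._trigger_info = trigger_info

        async def arm(self) -> None:
            # Start acquisition, keeping a status to await later.
            self._arm_status = AsyncStatus(self._acquire())

        async def wait_for_idle(self):
            if self._arm_status:
                await self._arm_status  # resolves when acquisition finishes

        async def disarm(self):
            ...  # stop the imaginary hardware

        async def _acquire(self) -> None:
            ...  # runs for the duration of the acquisition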
@@ -186,7 +198,7 @@ def __init__(
         self._watchers: List[Callable] = []
         self._fly_status: Optional[WatchableAsyncStatus] = None
         self._fly_start: float
-
+        self._iterations_completed: int = 0
         self._intial_frame: int
         self._last_frame: int
         super().__init__(name)
@@ -248,15 +260,15 @@ async def trigger(self) -> None:
                     trigger=DetectorTrigger.internal,
                     deadtime=None,
                     livetime=None,
+                    frame_timeout=None,
                 )
             )
+        assert self._trigger_info
+        assert self._trigger_info.trigger is DetectorTrigger.internal
         # Arm the detector and wait for it to finish.
         indices_written = await self.writer.get_indices_written()
-        written_status = await self.controller.arm(
-            num=self._trigger_info.number,
-            trigger=self._trigger_info.trigger,
-        )
-        await written_status
+        await self.controller.arm()
+        await self.controller.wait_for_idle()
         end_observation = indices_written + 1
 
         async for index in self.writer.observe_indices_written(
@@ -283,35 +295,35 @@ async def prepare(self, value: TriggerInfo) -> None:
         Args:
             value: TriggerInfo describing how to trigger the detector
         """
-        self._trigger_info = value
         if value.trigger != DetectorTrigger.internal:
             assert (
                 value.deadtime
             ), "Deadtime must be supplied when in externally triggered mode"
         if value.deadtime:
-            required = self.controller.get_deadtime(self._trigger_info.livetime)
+            required = self.controller.get_deadtime(value.livetime)
             assert required <= value.deadtime, (
                 f"Detector {self.controller} needs at least {required}s deadtime, "
                 f"but trigger logic provides only {value.deadtime}s"
             )
+        self._trigger_info = value
         self._initial_frame = await self.writer.get_indices_written()
         self._last_frame = self._initial_frame + self._trigger_info.number
-        self._arm_status = await self.controller.arm(
-            num=self._trigger_info.number,
-            trigger=self._trigger_info.trigger,
-            exposure=self._trigger_info.livetime,
+        self._describe, _ = await asyncio.gather(
+            self.writer.open(value.multiplier), self.controller.prepare(value)
         )
-        self._fly_start = time.monotonic()
-        self._describe = await self.writer.open(value.multiplier)
+        if value.trigger != DetectorTrigger.internal:
+            await self.controller.arm()
+            self._fly_start = time.monotonic()
 
     @AsyncStatus.wrap
     async def kickoff(self):
-        if not self._arm_status:
-            raise Exception("Detector not armed!")
+        assert self._trigger_info, "Prepare must be called before kickoff!"
+        if self._iterations_completed >= self._trigger_info.iteration:
+            raise Exception(f"Kickoff called more than {self._trigger_info.iteration}")
+        self._iterations_completed += 1
 
     @WatchableAsyncStatus.wrap
     async def complete(self):
-        assert self._arm_status, "Prepare not run"
         assert self._trigger_info
         async for index in self.writer.observe_indices_written(
             self._trigger_info.frame_timeout
@@ -332,6 +344,8 @@ async def complete(self):
                 )
             if index >= self._trigger_info.number:
                 break
+        if self._iterations_completed == self._trigger_info.iteration:
+            await self.controller.wait_for_idle()
 
     async def describe_collect(self) -> Dict[str, DataKey]:
         return self._describe
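
Plan-side, these detector changes make the fly-scan lifecycle explicit: one prepare, then up to iteration kickoff/complete pairs, with the final complete waiting for the controller to go idle. A rough sketch using bluesky's prepare/kickoff/complete plan stubs (det and trigger_info are placeholders):

    import bluesky.plan_stubs as bps

    def fly_detector(det, trigger_info):
        # prepare() opens the writer and, for external triggering, arms once.
        yield from bps.prepare(det, trigger_info, wait=True)
        # kickoff/complete may repeat up to trigger_info.iteration times
        # before the detector must be re-armed.
        for _ in range(trigger_info.iteration):
            yield from bps.kickoff(det, wait=True)
            yield from bps.complete(det, wait=True)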
12 changes: 12 additions & 0 deletions src/ophyd_async/core/_device_save_loader.py
@@ -7,6 +7,7 @@
 from bluesky.plan_stubs import abs_set, wait
 from bluesky.protocols import Location
 from bluesky.utils import Msg
+from pydantic import BaseModel
 
 from ._device import Device
 from ._signal import SignalRW
@@ -18,6 +19,12 @@ def ndarray_representer(dumper: yaml.Dumper, array: npt.NDArray[Any]) -> yaml.Node:
     )
 
 
+def pydantic_model_abstraction_representer(
+    dumper: yaml.Dumper, model: BaseModel
+) -> yaml.Node:
+    return dumper.represent_data(model.model_dump(mode="python"))
+
+
 class OphydDumper(yaml.Dumper):
     def represent_data(self, data: Any) -> Any:
         if isinstance(data, Enum):
@@ -134,6 +141,11 @@ def save_to_yaml(phases: Sequence[Dict[str, Any]], save_path: str) -> None:
     """
 
     yaml.add_representer(np.ndarray, ndarray_representer, Dumper=yaml.Dumper)
+    yaml.add_multi_representer(
+        BaseModel,
+        pydantic_model_abstraction_representer,
+        Dumper=yaml.Dumper,
+    )
 
     with open(save_path, "w") as file:
         yaml.dump(phases, file, Dumper=OphydDumper, default_flow_style=False)
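
Because add_multi_representer registers for BaseModel and every subclass, any pydantic model stored in a device phase is dumped as a plain mapping. A self-contained sketch of the mechanism outside ophyd-async (the Exposure model is invented for illustration):

    import yaml
    from pydantic import BaseModel

    class Exposure(BaseModel):
        time: float
        count: int

    def pydantic_model_abstraction_representer(dumper, model):
        # Dump the model as a plain dict so YAML sees ordinary scalars.
        return dumper.represent_data(model.model_dump(mode="python"))

    yaml.add_multi_representer(
        BaseModel, pydantic_model_abstraction_representer, Dumper=yaml.Dumper
    )
    print(yaml.dump({"exposure": Exposure(time=0.1, count=3)}))
    # exposure:
    #   count: 3
    #   time: 0.1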
2 changes: 1 addition & 1 deletion src/ophyd_async/core/_signal.py
@@ -505,7 +505,7 @@ async def wait_for_value(self, signal: SignalR[T], timeout: Optional[float]):
         try:
             await asyncio.wait_for(self._wait_for_value(signal), timeout)
         except asyncio.TimeoutError as e:
-            raise TimeoutError(
+            raise asyncio.TimeoutError(
                 f"{signal.name} didn't match {self._matcher_name} in {timeout}s, "
                 f"last value {self._last_value!r}"
             ) from e
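
Callers now see asyncio.TimeoutError (still carrying the descriptive message) rather than the builtin TimeoutError; on Python 3.11+ the two are aliases, so the distinction mainly matters on older interpreters. A usage sketch under those assumptions (status_signal and the "DONE" value are illustrative):

    import asyncio

    from ophyd_async.core import wait_for_value

    async def wait_for_done(status_signal):
        try:
            await wait_for_value(status_signal, "DONE", timeout=5.0)
        except asyncio.TimeoutError as e:
            print(f"Timed out: {e}")  # message includes matcher and last value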
15 changes: 14 additions & 1 deletion src/ophyd_async/core/_signal_backend.py
@@ -1,5 +1,13 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING, ClassVar, Generic, Literal, Optional, Tuple, Type
+from typing import (
+    TYPE_CHECKING,
+    ClassVar,
+    Generic,
+    Literal,
+    Optional,
+    Tuple,
+    Type,
+)
 
 from ._protocol import DataKey, Reading
 from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T
@@ -11,6 +19,11 @@ class SignalBackend(Generic[T]):
     #: Datatype of the signal value
     datatype: Optional[Type[T]] = None
 
+    @classmethod
+    @abstractmethod
+    def datatype_allowed(cls, dtype: type):
+        """Check if a given datatype is acceptable for this signal backend."""
+
     #: Like ca://PV_PREFIX:SIGNAL
     @abstractmethod
     def source(self, name: str) -> str:
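
Concrete backends are now expected to declare up front which Python datatypes they can serve. A fragmentary sketch of how a subclass might implement the new hook (ScalarBackend is invented, and the rest of the backend interface is omitted):

    from ophyd_async.core import SignalBackend

    class ScalarBackend(SignalBackend[float]):
        SCALARS = (int, float, str, bool)

        @classmethod
        def datatype_allowed(cls, dtype: type) -> bool:
            # Serve plain scalars only; reject containers, enums and models.
            return dtype in cls.SCALARS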
32 changes: 30 additions & 2 deletions src/ophyd_async/core/_soft_signal_backend.py
@@ -2,15 +2,20 @@
 
 import inspect
 import time
+from abc import ABCMeta
 from collections import abc
 from enum import Enum
 from typing import Dict, Generic, Optional, Tuple, Type, Union, cast, get_origin
 
 import numpy as np
 from bluesky.protocols import DataKey, Dtype, Reading
+from pydantic import BaseModel
 from typing_extensions import TypedDict
 
-from ._signal_backend import RuntimeSubsetEnum, SignalBackend
+from ._signal_backend import (
+    RuntimeSubsetEnum,
+    SignalBackend,
+)
 from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T, get_dtype
 
 primitive_dtypes: Dict[type, Dtype] = {
@@ -94,7 +99,7 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
 class SoftEnumConverter(SoftConverter):
     choices: Tuple[str, ...]
 
-    def __init__(self, datatype: Union[RuntimeSubsetEnum, Enum]):
+    def __init__(self, datatype: Union[RuntimeSubsetEnum, Type[Enum]]):
         if issubclass(datatype, Enum):
             self.choices = tuple(v.value for v in datatype)
         else:
@@ -122,17 +127,36 @@ def make_initial_value(self, datatype: Optional[Type[T]]) -> T:
         return cast(T, self.choices[0])
 
 
+class SoftPydanticModelConverter(SoftConverter):
+    def __init__(self, datatype: Type[BaseModel]):
+        self.datatype = datatype
+
+    def write_value(self, value):
+        if isinstance(value, dict):
+            return self.datatype(**value)
+        return value
+
+
 def make_converter(datatype):
     is_array = get_dtype(datatype) is not None
     is_sequence = get_origin(datatype) == abc.Sequence
     is_enum = inspect.isclass(datatype) and (
         issubclass(datatype, Enum) or issubclass(datatype, RuntimeSubsetEnum)
     )
 
+    is_pydantic_model = (
+        inspect.isclass(datatype)
+        # Necessary to avoid weirdness in ABCMeta.__subclasscheck__
+        and isinstance(datatype, ABCMeta)
+        and issubclass(datatype, BaseModel)
+    )
+
     if is_array or is_sequence:
         return SoftArrayConverter()
     if is_enum:
         return SoftEnumConverter(datatype)
+    if is_pydantic_model:
+        return SoftPydanticModelConverter(datatype)
 
     return SoftConverter()
 
@@ -145,6 +169,10 @@ class SoftSignalBackend(SignalBackend[T]):
     _timestamp: float
     _severity: int
 
+    @classmethod
+    def datatype_allowed(cls, datatype: Type) -> bool:
+        return True  # Any value allowed in a soft signal
+
     def __init__(
         self,
         datatype: Optional[Type[T]],
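
With the converter in place, a soft signal can hold a pydantic model, and writing either a model instance or a plain dict should work, since write_value routes dicts through the model constructor. A sketch under those assumptions (Position is an invented model):

    import asyncio

    from pydantic import BaseModel

    from ophyd_async.core import soft_signal_rw

    class Position(BaseModel):
        x: float = 0.0
        y: float = 0.0

    async def main():
        sig = soft_signal_rw(Position, initial_value=Position(x=1.0, y=2.0))
        await sig.connect()
        await sig.set({"x": 3.0, "y": 4.0})  # dict coerced to Position
        print(await sig.get_value())         # Position(x=3.0, y=4.0)

    asyncio.run(main())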
58 changes: 58 additions & 0 deletions src/ophyd_async/core/_table.py
@@ -0,0 +1,58 @@
+import numpy as np
+from pydantic import BaseModel, ConfigDict, model_validator
+
+
+class Table(BaseModel):
+    """An abstraction of a Table of str to numpy array."""
+
+    model_config = ConfigDict(validate_assignment=True, strict=False)
+
+    @classmethod
+    def row(cls, sub_cls, **kwargs) -> "Table":
+        arrayified_kwargs = {
+            field_name: np.concatenate(
+                (
+                    (default_arr := field_value.default_factory()),
+                    np.array([kwargs[field_name]], dtype=default_arr.dtype),
+                )
+            )
+            for field_name, field_value in sub_cls.model_fields.items()
+        }
+        return sub_cls(**arrayified_kwargs)
+
+    def __add__(self, right: "Table") -> "Table":
+        """Concatenate the arrays in field values."""
+
+        assert isinstance(right, type(self)), (
+            f"{right} is not a `Table`, or is not the same "
+            f"type of `Table` as {self}."
+        )
+
+        return type(self)(
+            **{
+                field_name: np.concatenate(
+                    (getattr(self, field_name), getattr(right, field_name))
+                )
+                for field_name in self.model_fields
+            }
+        )
+
+    @model_validator(mode="after")
+    def validate_arrays(self) -> "Table":
+        first_length = len(next(iter(self))[1])
+        assert all(
+            len(field_value) == first_length for _, field_value in self
+        ), "Rows should all be of equal size."
+
+        if not all(
+            np.issubdtype(
+                self.model_fields[field_name].default_factory().dtype, field_value.dtype
+            )
+            for field_name, field_value in self
+        ):
+            raise ValueError(
+                f"Cannot construct a `{type(self).__name__}`, "
+                "some rows have incorrect types."
+            )
+
+        return self
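
A usage sketch of the new base class: a subclass pins each column's dtype through default_factory, Table.row builds a one-row table, and + concatenates tables of the same type. PointTable is invented here, and arbitrary_types_allowed is assumed so that pydantic accepts raw numpy arrays:

    import numpy as np
    from pydantic import ConfigDict, Field

    from ophyd_async.core import Table

    class PointTable(Table):
        model_config = ConfigDict(arbitrary_types_allowed=True)

        # Each column is a numpy array; default_factory fixes its dtype.
        x: np.ndarray = Field(default_factory=lambda: np.array([], dtype=np.float64))
        y: np.ndarray = Field(default_factory=lambda: np.array([], dtype=np.float64))

    one = Table.row(PointTable, x=1.0, y=2.0)  # single-row table
    two = Table.row(PointTable, x=3.0, y=4.0)
    both = one + two                           # concatenates columns
    assert len(both.x) == 2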