Skip to content

Commit

Permalink
Fix pyright issues
Browse files Browse the repository at this point in the history
Fixes #346
  • Loading branch information
coretl committed Sep 17, 2024
1 parent 300833e commit 9539dd8
Show file tree
Hide file tree
Showing 45 changed files with 229 additions and 196 deletions.
4 changes: 2 additions & 2 deletions src/ophyd_async/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@
from ._status import AsyncStatus, WatchableAsyncStatus, completed_status
from ._table import Table
from ._utils import (
CALCULATE_TIMEOUT,
DEFAULT_TIMEOUT,
CalculatableTimeout,
CalculateTimeout,
NotConnected,
ReadingValueCallback,
T,
Expand Down Expand Up @@ -154,7 +154,7 @@
"WatchableAsyncStatus",
"DEFAULT_TIMEOUT",
"CalculatableTimeout",
"CalculateTimeout",
"CALCULATE_TIMEOUT",
"NotConnected",
"ReadingValueCallback",
"Table",
Expand Down
5 changes: 3 additions & 2 deletions src/ophyd_async/core/_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

from bluesky.protocols import (
Collectable,
DataKey,
Flyable,
Preparable,
Reading,
Expand All @@ -20,10 +19,12 @@
Triggerable,
WritesStreamAssets,
)
from event_model import DataKey
from pydantic import BaseModel, Field

from ._device import Device
from ._protocol import AsyncConfigurable, AsyncReadable
from ._signal import SignalR
from ._status import AsyncStatus, WatchableAsyncStatus
from ._utils import DEFAULT_TIMEOUT, T, WatcherUpdate, merge_gathered_dicts

Expand Down Expand Up @@ -168,7 +169,7 @@ def __init__(
self,
controller: DetectorControl,
writer: DetectorWriter,
config_sigs: Sequence[AsyncReadable] = (),
config_sigs: Sequence[SignalR] = (),
name: str = "",
) -> None:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/core/_flyer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Generic

from bluesky.protocols import Flyable, Preparable, Reading, Stageable
from event_model.documents.event_descriptor import DataKey
from event_model import DataKey

from ._device import Device
from ._signal import SignalR
Expand Down
8 changes: 2 additions & 6 deletions src/ophyd_async/core/_hdf_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
ComposeStreamResource,
ComposeStreamResourceBundle,
StreamDatum,
StreamRange,
StreamResource,
)

Expand Down Expand Up @@ -79,15 +80,10 @@ def stream_resources(self) -> Iterator[StreamResource]:
def stream_data(self, indices_written: int) -> Iterator[StreamDatum]:
# Indices are relative to resource
if indices_written > self._last_emitted:
indices = {
indices: StreamRange = {
"start": self._last_emitted,
"stop": indices_written,
}
self._last_emitted = indices_written
for bundle in self._bundles:
yield bundle.compose_stream_datum(indices)
return None

def close(self) -> None:
for bundle in self._bundles:
bundle.close()
4 changes: 3 additions & 1 deletion src/ophyd_async/core/_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class ColoredFormatterWithDeviceName(colorlog.ColoredFormatter):
def format(self, record):
message = super().format(record)
if hasattr(record, "ophyd_async_device_name"):
message = f"[{record.ophyd_async_device_name}]{message}"
message = f"[{record.ophyd_async_device_name}]{message}" # type: ignore
return message


Expand All @@ -39,6 +39,8 @@ def _validate_level(level) -> int:
levelno = level
elif isinstance(level, str):
levelno = logging.getLevelName(level)
else:
raise TypeError(f"Level {level!r} is not an int or str")

if isinstance(levelno, int):
return levelno
Expand Down
4 changes: 2 additions & 2 deletions src/ophyd_async/core/_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
)

from bluesky.protocols import HasName, Reading
from event_model.documents.event_descriptor import DataKey
from event_model import DataKey

if TYPE_CHECKING:
from ._status import AsyncStatus
Expand Down Expand Up @@ -57,7 +57,7 @@ async def describe(self) -> dict[str, DataKey]:


@runtime_checkable
class AsyncConfigurable(Protocol):
class AsyncConfigurable(HasName, Protocol):
@abstractmethod
async def read_configuration(self) -> dict[str, Reading]:
"""Same API as ``read`` but for slow-changing fields related to configuration.
Expand Down
2 changes: 1 addition & 1 deletion src/ophyd_async/core/_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def __init__(
num_calls_per_inc: int = 1,
increment: int = 1,
inc_delimeter: str = "_",
base_name: str = None,
base_name: str | None = None,
) -> None:
self._filename_provider = filename_provider
self._base_directory_path = base_directory_path
Expand Down
8 changes: 6 additions & 2 deletions src/ophyd_async/core/_readable.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from contextlib import contextmanager

from bluesky.protocols import HasHints, Hints, Reading
from event_model.documents.event_descriptor import DataKey
from event_model import DataKey

from ._device import Device, DeviceVector
from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable
Expand Down Expand Up @@ -171,7 +171,7 @@ def add_children_as_readables(

def add_readables(
self,
devices: Sequence[Device],
devices: Sequence[ReadableChild],
wrapper: ReadableChildWrapper | None = None,
) -> None:
"""Add the given devices to the lists of known Devices
Expand Down Expand Up @@ -226,6 +226,10 @@ async def read_configuration(self) -> dict[str, Reading]:
async def describe_configuration(self) -> dict[str, DataKey]:
return await self.signal.describe()

@property
def name(self) -> str:
return self.signal.name


class HintedSignal(HasHints, AsyncReadable):
def __init__(self, signal: ReadableChild, allow_cache: bool = True) -> None:
Expand Down
46 changes: 28 additions & 18 deletions src/ophyd_async/core/_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,29 +3,25 @@
import asyncio
import functools
from collections.abc import AsyncGenerator, Callable, Mapping
from typing import (
Any,
Generic,
TypeVar,
)
from typing import Any, Generic, TypeVar, cast

from bluesky.protocols import (
DataKey,
Locatable,
Location,
Movable,
Reading,
Status,
Subscribable,
)
from event_model import DataKey

from ._device import Device
from ._mock_signal_backend import MockSignalBackend
from ._protocol import AsyncConfigurable, AsyncReadable, AsyncStageable
from ._signal_backend import SignalBackend
from ._soft_signal_backend import SignalMetadata, SoftSignalBackend
from ._status import AsyncStatus
from ._utils import DEFAULT_TIMEOUT, CalculatableTimeout, CalculateTimeout, Callback, T
from ._utils import CALCULATE_TIMEOUT, DEFAULT_TIMEOUT, CalculatableTimeout, Callback, T

S = TypeVar("S")

Expand All @@ -38,12 +34,25 @@ async def wrapper(self: Signal, *args, **kwargs):
return wrapper


def _fail(*args, **kwargs):
raise RuntimeError("Signal has not been supplied a backend yet")


class DisconnectedBackend(SignalBackend):
source = connect = put = get_datakey = get_reading = get_value = get_setpoint = (
set_callback
) = _fail


DISCONNECTED_BACKEND = DisconnectedBackend()


class Signal(Device, Generic[T]):
"""A Device with the concept of a value, with R, RW, W and X flavours"""

def __init__(
self,
backend: SignalBackend[T] | None = None,
backend: SignalBackend[T] = DISCONNECTED_BACKEND,
timeout: float | None = DEFAULT_TIMEOUT,
name: str = "",
) -> None:
Expand All @@ -59,7 +68,10 @@ async def connect(
backend: SignalBackend[T] | None = None,
):
if backend:
if self._backend and backend is not self._backend:
if (
self._backend is not DISCONNECTED_BACKEND
and backend is not self._backend
):
raise ValueError("Backend at connection different from previous one.")

self._backend = backend
Expand Down Expand Up @@ -230,10 +242,10 @@ class SignalW(Signal[T], Movable):
"""Signal that can be set"""

def set(
self, value: T, wait=True, timeout: CalculatableTimeout = CalculateTimeout
self, value: T, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT
) -> AsyncStatus:
"""Set the value and return a status saying when it's done"""
if timeout is CalculateTimeout:
if timeout is CALCULATE_TIMEOUT:
timeout = self._timeout

async def do_set():
Expand Down Expand Up @@ -261,10 +273,10 @@ class SignalX(Signal):
"""Signal that puts the default value"""

def trigger(
self, wait=True, timeout: CalculatableTimeout = CalculateTimeout
self, wait=True, timeout: CalculatableTimeout = CALCULATE_TIMEOUT
) -> AsyncStatus:
"""Trigger the action and return a status saying when it's done"""
if timeout is CalculateTimeout:
if timeout is CALCULATE_TIMEOUT:
timeout = self._timeout
coro = self._backend.put(None, wait=wait, timeout=timeout)
return AsyncStatus(coro)
Expand Down Expand Up @@ -307,9 +319,7 @@ def soft_signal_r_and_setter(
return (signal, backend.set_value)


def _generate_assert_error_msg(
name: str, expected_result: str, actual_result: str
) -> str:
def _generate_assert_error_msg(name: str, expected_result, actual_result) -> str:
WARNING = "\033[93m"
FAIL = "\033[91m"
ENDC = "\033[0m"
Expand Down Expand Up @@ -475,7 +485,7 @@ async def get_value():
else:
break
else:
yield item
yield cast(T, item)
finally:
signal.clear_sub(q.put_nowait)

Expand Down Expand Up @@ -531,7 +541,7 @@ async def wait_for_value(
wait_for_value(device.num_captured, lambda v: v > 45, timeout=1)
"""
if callable(match):
checker = _ValueChecker(match, match.__name__)
checker = _ValueChecker(match, match.__name__) # type: ignore
else:
checker = _ValueChecker(lambda v: v == match, repr(match))
await checker.wait_for_value(signal, timeout)
Expand Down
7 changes: 4 additions & 3 deletions src/ophyd_async/core/_signal_backend.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from abc import abstractmethod
from typing import (
TYPE_CHECKING,
Any,
ClassVar,
Generic,
Literal,
)

from bluesky.protocols import Reading
from event_model.documents.event_descriptor import DataKey
from event_model import DataKey

from ._utils import DEFAULT_TIMEOUT, ReadingValueCallback, T

Expand All @@ -20,7 +21,7 @@ class SignalBackend(Generic[T]):

@classmethod
@abstractmethod
def datatype_allowed(cls, dtype: type):
def datatype_allowed(cls, dtype: Any) -> bool:
"""Check if a given datatype is acceptable for this signal backend."""

#: Like ca://PV_PREFIX:SIGNAL
Expand Down Expand Up @@ -61,7 +62,7 @@ def set_callback(self, callback: ReadingValueCallback[T] | None) -> None:
class _RuntimeSubsetEnumMeta(type):
def __str__(cls):
if hasattr(cls, "choices"):
return f"SubsetEnum{list(cls.choices)}"
return f"SubsetEnum{list(cls.choices)}" # type: ignore
return "SubsetEnum"

def __getitem__(cls, _choices):
Expand Down
Loading

0 comments on commit 9539dd8

Please sign in to comment.