diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 15228dd..6408d19 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -21,6 +21,7 @@ requirements: - setuptools_scm run: - python >=3.9 + - apischema - pcdsutils - pyqt - qtpy diff --git a/docs/source/upcoming_release_notes/3-enh_data_model.rst b/docs/source/upcoming_release_notes/3-enh_data_model.rst new file mode 100644 index 0000000..ba697da --- /dev/null +++ b/docs/source/upcoming_release_notes/3-enh_data_model.rst @@ -0,0 +1,22 @@ +3 enh_data_model +################ + +API Breaks +---------- +- N/A + +Features +-------- +- Adds backend model dataclasses and basic validation methods + +Bugfixes +-------- +- N/A + +Maintenance +----------- +- N/A + +Contributors +------------ +- tangkong diff --git a/requirements.txt b/requirements.txt index 7c41b91..8754683 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ # List requirements here. +apischema pcdsutils PyQt5 qtpy diff --git a/superscore/model.py b/superscore/model.py new file mode 100644 index 0000000..3eb9ecf --- /dev/null +++ b/superscore/model.py @@ -0,0 +1,251 @@ +""" +Dataclasses structures for data model. + +All data objects inherit from Entry, which specifies common metadata +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import (Any, List, Optional, Sequence, Union, get_args, get_origin, + get_type_hints) +from uuid import UUID, uuid4 + +from apischema.validation import ValidationError, validate, validator + +from superscore.serialization import as_tagged_union +from superscore.type_hints import AnyEpicsType + +logger = logging.getLogger(__name__) +_default_uuid = 'd3c8b6b8-7d4d-47aa-bb55-ba9c1b99bd9e' + + +@dataclass +@as_tagged_union +class Entry: + """ + Base class for items in the data model. 
+ Holds common metadata and validation methods + """ + meta_id: UUID = field(default_factory=uuid4) + name: str = '' + description: str = '' + creation: datetime = field(default_factory=datetime.utcnow, compare=False) + + def validate(self, recursive: bool = True) -> None: + """ + Validate current values conform to type hints. + Throws ValidationError on failure + + Parameters + ---------- + recursive : bool, optional + whether to also validate nested Entries, by default True (currently unused) + """ + # apischema validates on deserialization, but we want to validate at runtime + # Will gather validator decorated methods + validate(self) + + @validator + def validate_types(self): + """validate any types inheriting from Entry are valid""" + # This probably could just use typeguard, but let's see if I can do it + hint_dict = get_type_hints(type(self)) + + for field_name, hint in hint_dict.items(): + self.validate_field(field_name, hint) + + def validate_field(self, field_name: str, hint: Any) -> None: + """ + Validate `self.{field_name}` matches type hint `hint`. + Specialized to the type hints present in this module, + additions may require modification.
+ + Parameters + ---------- + field_name : str + the name of the field on self + hint : Any + the type hint we expect `self.{field_name}` to match + + Raises + ------ + ValidationError + if a type mismatch is found + """ + field_value = getattr(self, field_name) + origin = get_origin(hint) + is_list = False + + while origin: # Drill down through nested types, only Lists currently + if origin is Union: + break + elif origin in (list, Sequence): + hint = get_args(hint)[0] # Lists only have one type + origin = get_origin(hint) + is_list = True + # Mark list and check each entry in list + else: + origin = get_origin(hint) + hint = get_args(hint) + + # end condition + if origin is None: + break + + if Any in get_args(hint): + return + elif (origin is None): + if not isinstance(field_value, hint): + raise ValidationError( + f'improper type ({type(field_value)}) found in field ' + f'(expecting {hint})' + ) + elif (origin is Union): + if is_list: + list_comp = (isinstance(it, get_args(hint)) for it in field_value) + if not all(list_comp): + raise ValidationError( + f'improper type ({type(field_value)}) found in field ' + f'(expecting List[{get_args(hint)}])' + ) + elif not isinstance(field_value, get_args(hint)): + raise ValidationError( + f'improper type ({type(field_value)}) found in field ' + f'(expecting {get_args(hint)})' + ) + + +@dataclass +class Parameter(Entry): + """An Entry that stores a PV name""" + pv_name: str = '' + read_only: bool = False + + +@dataclass +class Value(Entry): + """ + An Entry that attaches a piece of data to a Parameter. 
+ Can be thought of a PV - data pair + """ + data: AnyEpicsType = '' + origin: Union[UUID, Parameter] = _default_uuid + + def __post_init__(self): + if self.origin is _default_uuid: + raise TypeError("__init__ missing required argument: 'origin'") + + @classmethod + def from_origin(cls, origin: Parameter, data: Optional[AnyEpicsType] = None) -> Value: + """ + Create a Value from its originating Parameter and corresponding `data` + Note that the returned Value may not be valid. + + Parameters + ---------- + origin : Parameter + the Parameter this Value was generated from + data : Optional[AnyEpicsType] + The data read from the Parameter - `origin` + + Returns + ------- + Value + A filled Value object, which may not yet be valid + """ + new_value = cls( + name=origin.name + '_value', + description=f'Value generated from {origin.name}', + origin=origin + ) + + if data is not None: + new_value.data = data + + return new_value + + +@dataclass +class Collection(Entry): + """An Entry composed of Parameters and Collections.""" + parameters: List[Union[UUID, Parameter]] = field(default_factory=list) + collections: List[Union[UUID, Collection]] = field(default_factory=list) + + +@dataclass +class Snapshot(Entry): + """An Entry that attaches data to each sub-Entry of a Collection.""" + origin: Union[UUID, Collection] = _default_uuid + + values: List[Union[UUID, Value]] = field(default_factory=list) + snapshots: List[Union[UUID, Snapshot]] = field(default_factory=list) + + def __post_init__(self): + if self.origin is _default_uuid: + raise TypeError("__init__ missing required argument: 'origin'") + + @classmethod + def from_origin( + cls, + origin: Collection, + values: Optional[List[Union[UUID, Value]]] = None, + snapshots: Optional[List[Union[UUID, Snapshot]]] = None + ) -> Snapshot: + """ + Create a Snapshot from its originating Collection. + Note that the returned Snapshot may not be valid.
+ + Parameters + ---------- + origin : Collection + the Collection used to define this Snapshot + values : Optional[List[UUID | Value]] + a list of Values to attach to this Snapshot + snapshots : Optional[List[UUID | Snapshot]] + a list of Snapshots to attach to this Snapshot + + Returns + ------- + Snapshot + A filled Snapshot object, which may not yet be valid + """ + new_snap = cls( + name=origin.name + '_snapshot', + description=f'Snapshot generated from {origin.name}', + origin=origin + ) + + if values is not None: + new_snap.values = values + + if snapshots is not None: + new_snap.snapshots = snapshots + + return new_snap + + @validator + def validate_tree(self) -> None: + """ + Validate the values and snapshots match those specified in origin. + Structure should be identical, and Values/Snapshots should reference the + Parameters/Collections in self.origin + """ + # TODO: complete this method + return + + @validator + def validate_loop(self) -> None: + """Check that the Snapshot is not self-referential (does not loop)""" + return + + +@dataclass +class Root: + """Base level structure holding Entry objects""" + entries: List[Entry] = field(default_factory=list) + + def validate(self): + for entry in self.entries: + entry.validate() diff --git a/superscore/serialization.py b/superscore/serialization.py new file mode 100644 index 0000000..69c96fd --- /dev/null +++ b/superscore/serialization.py @@ -0,0 +1,138 @@ +""" +Serialization helpers for apischema. +""" +# Largely based on issue discussions regarding tagged unions.
+ +from collections import defaultdict +from collections.abc import Callable, Iterator +from types import new_class +from typing import (Any, Dict, Generic, List, Tuple, TypeVar, get_origin, + get_type_hints) + +from apischema import deserializer, serializer, type_name +from apischema.conversions import Conversion +from apischema.metadata import conversion +from apischema.objects import object_deserialization +from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged +from apischema.utils import to_pascal_case + +_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list) +Func = TypeVar("Func", bound=Callable) + + +def alternative_constructor(func: Func) -> Func: + """Alternative constructor for a given type.""" + return_type = get_type_hints(func)["return"] + _alternative_constructors[get_origin(return_type) or return_type].append(func) + return func + + +def get_all_subclasses(cls: type) -> Iterator[type]: + """Recursive implementation of type.__subclasses__""" + for sub_cls in cls.__subclasses__(): + yield sub_cls + yield from get_all_subclasses(sub_cls) + + +Cls = TypeVar("Cls", bound=type) + + +def _get_generic_name_factory(cls: type, *args: type): + def _capitalized(name: str) -> str: + return name[0].upper() + name[1:] + + return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args))) + + +generic_name = type_name(_get_generic_name_factory) + + +def as_tagged_union(cls: Cls) -> Cls: + """ + Tagged union decorator, to be used on base class. + + Supports generics as well, with names generated by way of + `_get_generic_name_factory`. + """ + params = tuple(getattr(cls, "__parameters__", ())) + tagged_union_bases: Tuple[type, ...] 
= (TaggedUnion,) + + # Generic handling is here: + if params: + tagged_union_bases = (TaggedUnion, Generic[params]) + generic_name(cls) + prev_init_subclass = getattr(cls, "__init_subclass__", None) + + def __init_subclass__(cls, **kwargs): + if prev_init_subclass is not None: + prev_init_subclass(**kwargs) + generic_name(cls) + + cls.__init_subclass__ = classmethod(__init_subclass__) + + def with_params(cls: type) -> Any: + """Specify type of Generic if set.""" + return cls[params] if params else cls + + def serialization() -> Conversion: + """ + Define the serializer Conversion for the tagged union. + + source is the base ``cls`` (or ``cls[T]``). + target is the new tagged union class ``TaggedUnion`` which gets the + dictionary {cls.__name__: obj} as its arguments. + """ + annotations = { + # Assume that subclasses have same generic parameters than cls + sub.__name__: Tagged[with_params(sub)] + for sub in get_all_subclasses(cls) + } + namespace = {"__annotations__": annotations} + tagged_union = new_class( + cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) + ) + return Conversion( + lambda obj: tagged_union(**{obj.__class__.__name__: obj}), + source=with_params(cls), + target=with_params(tagged_union), + # Conversion must not be inherited because it would lead to + # infinite recursion otherwise + inherited=False, + ) + + def deserialization() -> Conversion: + """ + Define the deserializer Conversion for the tagged union. + + Allows for alternative standalone constructors as per the apischema + example. 
+ """ + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {"__annotations__": annotations} + for sub in get_all_subclasses(cls): + annotations[sub.__name__] = Tagged[with_params(sub)] + for constructor in _alternative_constructors.get(sub, ()): + # Build the alias of the field + alias = to_pascal_case(constructor.__name__) + # object_deserialization uses get_type_hints, but the constructor + # return type is stringified and the class not defined yet, + # so it must be assigned manually + constructor.__annotations__["return"] = with_params(sub) + # Use object_deserialization to wrap constructor as deserializer + deserialization = object_deserialization(constructor, generic_name) + # Add constructor tagged field with its conversion + annotations[alias] = Tagged[with_params(sub)] + namespace[alias] = Tagged(conversion(deserialization=deserialization)) + # Create the deserialization tagged union class + tagged_union = new_class( + cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) + ) + return Conversion( + lambda obj: get_tagged(obj)[1], + source=with_params(tagged_union), + target=with_params(cls), + ) + + deserializer(lazy=deserialization, target=cls) + serializer(lazy=serialization, source=cls) + return cls diff --git a/superscore/tests/conftest.py b/superscore/tests/conftest.py index e69de29..59470ae 100644 --- a/superscore/tests/conftest.py +++ b/superscore/tests/conftest.py @@ -0,0 +1,58 @@ +import pytest + +from superscore.model import Collection, Parameter, Root, Snapshot, Value + + +@pytest.fixture(scope='function') +def sample_database(): + """ + A sample superscore database, including all the Entry types. 
+ Corresponds to a caproto.ioc_examples.fake_motor_record, which mimics an IMS + motor record + """ + root = Root() + + param_1 = Parameter( + name='parameter 1', + description='parameter in root', + pv_name='MY:MOTOR:mtr1.ACCL' + ) + root.entries.append(param_1) + + value_1 = Value( + name='value 1', + description='Value created from parameter 1', + data=2, + origin=param_1 + ) + root.entries.append(value_1) + + coll_1 = Collection( + name='collection 1', + description='collection defining some motor fields', + ) + snap_1 = Snapshot( + name='snapshot 1', + description='Snapshot created from collection 1', + origin=coll_1 + ) + for fld, data in zip(['ACCL', 'VELO', 'PREC'], [2, 2, 6]): # Defaults[1, 1, 3] + sub_param = Parameter( + name=f'coll_1_field_{fld}', + description=f'motor field {fld}', + pv_name=f'MY:PREFIX:mtr1.{fld}' + ) + sub_value = Value( + name=f'coll_1_fld_{fld} value', + description=f'motor_field_{fld} value', + data=data, + origin=sub_param + ) + + coll_1.parameters.append(sub_param) + snap_1.values.append(sub_value) + + root.entries.append(coll_1) + root.entries.append(snap_1) + + return root diff --git a/superscore/tests/test_model.py b/superscore/tests/test_model.py new file mode 100644 index 0000000..261b814 --- /dev/null +++ b/superscore/tests/test_model.py @@ -0,0 +1,75 @@ +from typing import Any +from uuid import uuid4 + +import pytest +from apischema import ValidationError + +from superscore.model import Root, Value + + +def test_validate_fields(sample_database: Root): + sample_database.validate() + + +# TODO: be comprehensive with these validation failures +@pytest.mark.parametrize( + 'entry_index,replace_fld,replace_obj', + [ + [0, 'pv_name', 1], + [2, 'parameters', [1, 2, 3, 4]] + ] +) +def test_validate_failure( + sample_database: Root, + entry_index: int, + replace_fld: str, + replace_obj: Any +): + """Passes if improper types fail to validate""" + entry = sample_database.entries[entry_index] + sample_database.validate() + 
setattr(entry, replace_fld, replace_obj) + + with pytest.raises(ValidationError): + sample_database.validate() + with pytest.raises(ValidationError): + entry.validate() + + +@pytest.mark.parametrize( + 'data,valid', + [ + [1, True], ['one', True], [True, True], [1.1, True], + [object(), False], [Root(), False] + ] +) +def test_epics_type_validate(data: Any, valid: bool): + """Passes if EPICS types are validated correctly""" + value = Value( + name='My Value', + description='description value', + data=data, + origin=uuid4() + ) + if not valid: + with pytest.raises(ValidationError): + value.validate() + else: + value.validate() + + +def test_uuid_validate(sample_database: Root): + """ + Passes if uuids can be validated. + Note that this does not check if said uuids reference valid objects + """ + def replace_origin_with_uuid(entry): + """Recursively replace origin with a random uuid""" + if hasattr(entry, 'origin'): + replace_origin_with_uuid(getattr(entry, 'origin')) + setattr(entry, 'origin', uuid4()) + + for entry in sample_database.entries: + replace_origin_with_uuid(entry) + + sample_database.validate() diff --git a/superscore/type_hints.py b/superscore/type_hints.py new file mode 100644 index 0000000..94e6261 --- /dev/null +++ b/superscore/type_hints.py @@ -0,0 +1,3 @@ +from typing import Union + +AnyEpicsType = Union[int, str, float, bool]