From c01ac48ab679a2da44644ddfe591835fa8b3240f Mon Sep 17 00:00:00 2001 From: tangkong Date: Mon, 26 Feb 2024 14:26:51 -0800 Subject: [PATCH 1/9] ENH: add initial model dataclasses, validation methods --- superscore/model.py | 223 +++++++++++++++++++++++++++++++++++++++ superscore/type_hints.py | 3 + 2 files changed, 226 insertions(+) create mode 100644 superscore/model.py create mode 100644 superscore/type_hints.py diff --git a/superscore/model.py b/superscore/model.py new file mode 100644 index 0000000..e03f9d1 --- /dev/null +++ b/superscore/model.py @@ -0,0 +1,223 @@ +""" +Dataclasses structures for data model. + +All data objects inherit from Entry, which specifies common metadata +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from datetime import datetime +from typing import (Any, List, Optional, Sequence, Union, get_args, get_origin, + get_type_hints) +from uuid import UUID, uuid4 + +from apischema.validation import validate, validator + +from superscore.type_hints import AnyEpicsType + +logger = logging.getLogger(__name__) +_default_uuid = 'd3c8b6b8-7d4d-47aa-bb55-ba9c1b99bd9e' + + +@dataclass +class Entry: + """ + Base class for items in the data model. + Holds common metadata and validation methods + """ + meta_id: UUID = field(default_factory=uuid4) + name: str = '' + description: str = '' + creation: datetime = field(default_factory=datetime.utcnow, compare=False) + + def validate(self, recursive: bool = True) -> None: + """ + Validate current values conform to type hints. Throws ValidationError on failure + + Parameters + ---------- + recursive : bool, optional + whether or not to validate , by default True + """ + # apischema validates on deserialization, but we want to validate at runtime + # Will gather validator decorated methods + validate(self) + + @validator + def validate_types(self): + """validate any types inheriting from Entry are valid""" + # This probably could just use typeguard, but let's see if I can do it + hint_dict = get_type_hints(type(self)) + + for field_name, hint in hint_dict.items(): + self.validate_field(field_name, hint) + + def validate_field(self, field_name: str, hint: Any) -> None: + """ + Validate `self.{field_name}` matches type hint `hint` + + Parameters + ---------- + field_name : str + the name of the field on self + hint : Any + the type hint we expect `self.{field_name}` to match + + Raises + ------ + TypeError + if a type mismatch is found + """ + field_value = getattr(self, field_name) + origin = get_origin(hint) + is_list = False + + while origin: # Drill down through nested types, only Lists currently + if origin is Union: + break + elif origin in (list, Sequence): + hint = get_args(hint)[0] # Lists only have one type + origin = get_origin(hint) + is_list = True + # Mark list and check each entry in list + else: + origin = get_origin(hint) + hint = get_args(hint) + + # end condition + if origin is None: + break + + if Any in get_args(hint): + return + elif (origin is None): + if not isinstance(field_value, hint): + raise TypeError('improper type found in field') + elif (origin is Union) and (UUID in get_args(hint)): + # Case of interest. A hint of Union[UUID, SomeType] + if is_list: + list_comp = (isinstance(it, get_args(hint)) for it in field_value) + if not all(list_comp): + raise TypeError('improper type in list-field') + elif not isinstance(field_value, get_args(hint)): + raise TypeError('improper type found in field') + + +@dataclass +class Parameter(Entry): + """An Entry that stores a PV name""" + pv_name: str = '' + read_only: bool = False + + +@dataclass +class Value(Entry): + """ + An Entry that attaches a piece of data to a Parameter. + Can be thought of a PV - data pair + """ + data: AnyEpicsType = '' + origin: Union[UUID, Parameter] = _default_uuid + + def __post_init__(self): + if self.origin is _default_uuid: + raise TypeError("__init__ missing required argument: 'origin'") + + @classmethod + def from_origin(cls, origin: Parameter, data: Optional[AnyEpicsType] = None) -> Value: + """ + Create a Value from its originating Parameter and corresponding `data` + Note that the returned Value may not be valid. + + Parameters + ---------- + origin : Parameter + the parameter used to + data : Optional[AnyEpicsType] + The data read from the Parameter - `origin` + + Returns + ------- + Value + A filled and valid Value object + """ + new_value = cls( + name=origin.name + '_value', + description=f'Value generated from {origin.name}', + origin=origin + ) + + if data is not None: + new_value.data = data + + return new_value + + +@dataclass +class Collection(Entry): + parameters: List[Union[UUID, Parameter]] = field(default_factory=list) + collections: List[Union[UUID, Collection]] = field(default_factory=list) + + +@dataclass +class Snapshot(Entry): + origin: Union[UUID, Collection] = _default_uuid + + values: List[Union[UUID, Value]] = field(default_factory=list) + snapshots: List[Union[UUID, Snapshot]] = field(default_factory=list) + + def __post_init__(self): + if self.origin is _default_uuid: + raise TypeError("__init__ missing required argument: 'origin'") + + @classmethod + def from_origin( + cls, + origin: Collection, + values: Optional[List[Union[UUID, Value]]] = None, + snapshots: Optional[List[Union[UUID, Snapshot]]] = None + ) -> Snapshot: + """ + Create a Snapshot from its originating Collection. + Note that the returned Snapshot may not be valid. + + Parameters + ---------- + origin : Collection + the Collection used to define this Snapshot + + Returns + ------- + Snapshot + A filled and valid Snapshot object + """ + new_snap = cls( + name=origin.name + '_snapshot', + description=f'Snapshot generated from {origin.name}', + origin=origin + ) + + if values is not None: + new_snap.values = values + + if snapshots is not None: + new_snap.snapshots = snapshots + + return new_snap + + @validator + def validate_tree(self) -> None: + """Validate the values and snapshots match those specified in origin""" + # TODO: complete this method + return + + +@dataclass +class Root: + """Base level structure holding Entry objects""" + entries: List[Entry] = field(default_factory=list) + + def validate(self): + for entry in self.entries: + entry.validate() diff --git a/superscore/type_hints.py b/superscore/type_hints.py new file mode 100644 index 0000000..94e6261 --- /dev/null +++ b/superscore/type_hints.py @@ -0,0 +1,3 @@ +from typing import Union + +AnyEpicsType = Union[int, str, float, bool] From 4323c132536074a8a6f24dbdc0b79a445e9f90f7 Mon Sep 17 00:00:00 2001 From: tangkong Date: Mon, 26 Feb 2024 15:32:10 -0800 Subject: [PATCH 2/9] TST/WIP: basic validation tests, simple database fixture --- superscore/tests/conftest.py | 58 ++++++++++++++++++++++++++++++++++ superscore/tests/test_model.py | 37 ++++++++++++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 superscore/tests/test_model.py diff --git a/superscore/tests/conftest.py b/superscore/tests/conftest.py index e69de29..59470ae 100644 --- a/superscore/tests/conftest.py +++ b/superscore/tests/conftest.py @@ -0,0 +1,58 @@ +import pytest + +from superscore.model import Collection, Parameter, Root, Snapshot, Value + + +@pytest.fixture(scope='function') +def sample_database(): + """ + A sample superscore database, including all the Entry types. + Corresponds to a caproto.ioc_examples.fake_motor_record, which mimics an IMS + motor record + """ + root = Root() + + param_1 = Parameter( + name='parameter 1', + description='parameter in root', + pv_name='MY:MOTOR:mtr1.ACCL' + ) + root.entries.append(param_1) + + value_1 = Value( + name='value 1', + description='Value created from parameter 1', + data=2, + origin=param_1 + ) + root.entries.append(value_1) + + coll_1 = Collection( + name='collection 1', + description='collection defining some motor fields', + ) + snap_1 = Snapshot( + name='snapshot 1', + description='Snapshot created from collection 1', + origin=coll_1 + ) + for fld, data in zip(['ACCL', 'VELO', 'PREC'], [2, 2, 6]): # Defaults[1, 1, 3] + sub_param = Parameter( + name=f'coll_1_field_{fld}', + description=f'motor field {fld}', + pv_name=f'MY:PREFIX:mtr1.{fld}' + ) + sub_value = Value( + name=f'coll_1_fld_{fld} value', + description=f'motor_field_{fld} value', + data=data, + origin=sub_param + ) + + coll_1.parameters.append(sub_param) + snap_1.values.append(sub_value) + + root.entries.append(coll_1) + root.entries.append(snap_1) + + return root diff --git a/superscore/tests/test_model.py b/superscore/tests/test_model.py new file mode 100644 index 0000000..027e1f2 --- /dev/null +++ b/superscore/tests/test_model.py @@ -0,0 +1,37 @@ +from typing import Any + +import pytest + +from superscore.model import Root + + +def test_validate_fields(sample_database: Root): + sample_database.validate() + + +# TODO: be comprehensive with these validation failures +@pytest.mark.parametrize( + 'entry_index,replace_fld,replace_obj', + [ + [0, 'pv_name', 1], + [2, 'parameters', [1, 2, 3, 4]] + ] +) +def test_validate_failure( + sample_database: Root, + entry_index: int, + replace_fld: str, + replace_obj: Any +): + entry = sample_database.entries[entry_index] + sample_database.validate() + setattr(entry, replace_fld, replace_obj) + + with pytest.raises(TypeError): + sample_database.validate() + with pytest.raises(TypeError): + entry.validate() + + +def test_backend_load(): + assert True From 7b9a861db18c1aa1dbebcfc514ffa719eec97eb0 Mon Sep 17 00:00:00 2001 From: tangkong Date: Mon, 26 Feb 2024 16:09:01 -0800 Subject: [PATCH 3/9] MNT: validators can only throw ValidationError --- superscore/model.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/superscore/model.py b/superscore/model.py index e03f9d1..eeae4c4 100644 --- a/superscore/model.py +++ b/superscore/model.py @@ -12,7 +12,7 @@ get_type_hints) from uuid import UUID, uuid4 -from apischema.validation import validate, validator +from apischema.validation import ValidationError, validate, validator from superscore.type_hints import AnyEpicsType @@ -66,7 +66,7 @@ def validate_field(self, field_name: str, hint: Any) -> None: Raises ------ - TypeError + ValidationError if a type mismatch is found """ field_value = getattr(self, field_name) @@ -93,15 +93,15 @@ def validate_field(self, field_name: str, hint: Any) -> None: return elif (origin is None): if not isinstance(field_value, hint): - raise TypeError('improper type found in field') + raise ValidationError('improper type found in field') elif (origin is Union) and (UUID in get_args(hint)): # Case of interest. A hint of Union[UUID, SomeType] if is_list: list_comp = (isinstance(it, get_args(hint)) for it in field_value) if not all(list_comp): - raise TypeError('improper type in list-field') + raise ValidationError('improper type in list-field') elif not isinstance(field_value, get_args(hint)): - raise TypeError('improper type found in field') + raise ValidationError('improper type found in field') @dataclass From da8eca21245ebe98e7ce28995266370ec8322cfd Mon Sep 17 00:00:00 2001 From: tangkong Date: Mon, 26 Feb 2024 16:14:47 -0800 Subject: [PATCH 4/9] BLD: add apischema --- conda-recipe/meta.yaml | 1 + requirements.txt | 1 + 2 files changed, 2 insertions(+) diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 15228dd..6408d19 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -21,6 +21,7 @@ requirements: - setuptools_scm run: - python >=3.9 + - apischema - pcdsutils - pyqt - qtpy diff --git a/requirements.txt b/requirements.txt index 7c41b91..8754683 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,5 @@ # List requirements here. +apischema pcdsutils PyQt5 qtpy From 0e8985bdf0927a503fd836d1193634e4fd3aa75a Mon Sep 17 00:00:00 2001 From: tangkong Date: Mon, 26 Feb 2024 16:34:22 -0800 Subject: [PATCH 5/9] TST: catch ValidationError instead of TypeError --- superscore/tests/test_model.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/superscore/tests/test_model.py b/superscore/tests/test_model.py index 027e1f2..3c9002a 100644 --- a/superscore/tests/test_model.py +++ b/superscore/tests/test_model.py @@ -1,6 +1,7 @@ from typing import Any import pytest +from apischema import ValidationError from superscore.model import Root @@ -27,9 +28,9 @@ def test_validate_failure( sample_database.validate() setattr(entry, replace_fld, replace_obj) - with pytest.raises(TypeError): + with pytest.raises(ValidationError): sample_database.validate() - with pytest.raises(TypeError): + with pytest.raises(ValidationError): entry.validate() From 24bcbb25022da7d2f9f4bce766e8b70cd5f56ae1 Mon Sep 17 00:00:00 2001 From: tangkong Date: Tue, 27 Feb 2024 12:14:49 -0800 Subject: [PATCH 6/9] MNT: fix type verification for AnyEpicsType --- superscore/model.py | 22 +++++++++++++----- superscore/tests/test_model.py | 41 +++++++++++++++++++++++++++++++--- 2 files changed, 54 insertions(+), 9 deletions(-) diff --git a/superscore/model.py b/superscore/model.py index eeae4c4..5799d13 100644 --- a/superscore/model.py +++ b/superscore/model.py @@ -55,7 +55,9 @@ def validate_types(self): def validate_field(self, field_name: str, hint: Any) -> None: """ - Validate `self.{field_name}` matches type hint `hint` + Validate `self.{field_name}` matches type hint `hint`. + Speciallized to the type hints present in this module, + additions may require modification. Parameters ---------- @@ -93,15 +95,23 @@ def validate_field(self, field_name: str, hint: Any) -> None: return elif (origin is None): if not isinstance(field_value, hint): - raise ValidationError('improper type found in field') - elif (origin is Union) and (UUID in get_args(hint)): - # Case of interest. A hint of Union[UUID, SomeType] + raise ValidationError( + f'improper type ({type(field_value)}) found in field ' + f'(expecting {hint})' + ) + elif (origin is Union): if is_list: list_comp = (isinstance(it, get_args(hint)) for it in field_value) if not all(list_comp): - raise ValidationError('improper type in list-field') + raise ValidationError( + f'improper type ({type(field_value)}) found in field ' + f'(expecting List[{get_args(hint)}])' + ) elif not isinstance(field_value, get_args(hint)): - raise ValidationError('improper type found in field') + raise ValidationError( + f'improper type ({type(field_value)}) found in field ' + f'(expecting {get_args(hint)})' + ) @dataclass diff --git a/superscore/tests/test_model.py b/superscore/tests/test_model.py index 3c9002a..6ac4ecd 100644 --- a/superscore/tests/test_model.py +++ b/superscore/tests/test_model.py @@ -1,9 +1,10 @@ from typing import Any +from uuid import uuid4 import pytest from apischema import ValidationError -from superscore.model import Root +from superscore.model import Root, Value def test_validate_fields(sample_database: Root): @@ -34,5 +35,39 @@ def test_validate_failure( entry.validate() -def test_backend_load(): - assert True +@pytest.mark.parametrize( + 'data,valid', + [ + [1, True], ['one', True], [True, True], [1.1, True], + [object(), False] + ] +) +def test_epics_type_validate(data: Any, valid: bool): + value = Value( + name='My Value', + description='description value', + data=data, + origin=uuid4() + ) + if not valid: + with pytest.raises(ValidationError): + value.validate() + else: + value.validate() + + +def test_uuid_validate(sample_database: Root): + """ + Passes if uuids can be validated. + Note that this does not check if said uuids reference valid objects + """ + def replace_origin_with_uuid(entry): + """Recursively replace origin with a random uuid""" + if hasattr(entry, 'origin'): + replace_origin_with_uuid(getattr(entry, 'origin')) + setattr(entry, 'origin', uuid4()) + + for entry in sample_database.entries: + replace_origin_with_uuid(entry) + + sample_database.validate() From 5faa1b8f642b214bde304afea0438ac88d6029ed Mon Sep 17 00:00:00 2001 From: tangkong Date: Tue, 27 Feb 2024 12:30:57 -0800 Subject: [PATCH 7/9] DOC: touch up docstrings, leave some notes --- superscore/model.py | 22 +++++++++++++++++++--- superscore/tests/test_model.py | 4 +++- 2 files changed, 22 insertions(+), 4 deletions(-) diff --git a/superscore/model.py b/superscore/model.py index 5799d13..6262176 100644 --- a/superscore/model.py +++ b/superscore/model.py @@ -33,12 +33,13 @@ class Entry: def validate(self, recursive: bool = True) -> None: """ - Validate current values conform to type hints. Throws ValidationError on failure + Validate current values conform to type hints. + Throws ValidationError on failure Parameters ---------- recursive : bool, optional - whether or not to validate , by default True + whether or not to validate, by default True """ # apischema validates on deserialization, but we want to validate at runtime # Will gather validator decorated methods @@ -166,12 +167,14 @@ def from_origin(cls, origin: Parameter, data: Optional[AnyEpicsType] = None) -> @dataclass class Collection(Entry): + """An Entry composed of Parameters and Collections.""" parameters: List[Union[UUID, Parameter]] = field(default_factory=list) collections: List[Union[UUID, Collection]] = field(default_factory=list) @dataclass class Snapshot(Entry): + """An Entry that attaches data to each sub-Entry of a Collection.""" origin: Union[UUID, Collection] = _default_uuid values: List[Union[UUID, Value]] = field(default_factory=list) @@ -196,6 +199,10 @@ def from_origin( ---------- origin : Collection the Collection used to define this Snapshot + values : Optional[List[UUID | Value] + a list of Values to attach to this Snapshot + snapshots : Optional[List[UUID | Snapshot]] + a list of Snapshots to attach to this Snapshot Returns ------- @@ -218,10 +225,19 @@ def from_origin( @validator def validate_tree(self) -> None: - """Validate the values and snapshots match those specified in origin""" + """ + Validate the values and snapshots match those specified in origin. + Structure should be identical, and Values/Snapshots should reference the + Parameters/Collections in self.origin + """ # TODO: complete this method return + @validator + def validate_loop(self) -> None: + """Check that the Snapshot is not self-referential (does not loop)""" + return + @dataclass class Root: diff --git a/superscore/tests/test_model.py b/superscore/tests/test_model.py index 6ac4ecd..261b814 100644 --- a/superscore/tests/test_model.py +++ b/superscore/tests/test_model.py @@ -25,6 +25,7 @@ def test_validate_failure( replace_fld: str, replace_obj: Any ): + """Passes if improper types fail to validate""" entry = sample_database.entries[entry_index] sample_database.validate() setattr(entry, replace_fld, replace_obj) @@ -39,10 +40,11 @@ def test_validate_failure( 'data,valid', [ [1, True], ['one', True], [True, True], [1.1, True], - [object(), False] + [object(), False], [Root(), False] ] ) def test_epics_type_validate(data: Any, valid: bool): + """Passes if EPICS types are validated correctly""" value = Value( name='My Value', description='description value', From 197115cf370f7d2b3c4bc4268269f56afc3bfd56 Mon Sep 17 00:00:00 2001 From: tangkong Date: Tue, 27 Feb 2024 12:33:05 -0800 Subject: [PATCH 8/9] DOC: add pre-release notes --- .../3-enh_data_model.rst | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) create mode 100644 docs/source/upcoming_release_notes/3-enh_data_model.rst diff --git a/docs/source/upcoming_release_notes/3-enh_data_model.rst b/docs/source/upcoming_release_notes/3-enh_data_model.rst new file mode 100644 index 0000000..ba697da --- /dev/null +++ b/docs/source/upcoming_release_notes/3-enh_data_model.rst @@ -0,0 +1,22 @@ +3 enh_data_model +################ + +API Breaks +---------- +- N/A + +Features +-------- +- Adds backend model dataclasses and basic validation methods + +Bugfixes +-------- +- N/A + +Maintenance +----------- +- N/A + +Contributors +------------ +- tangkong From 8ebdb6c4fd6966e82b8873c639bb0ffe05c30693 Mon Sep 17 00:00:00 2001 From: tangkong Date: Tue, 27 Feb 2024 15:29:30 -0800 Subject: [PATCH 9/9] MNT: add as_tagged_union decorator as in atef --- superscore/model.py | 2 + superscore/serialization.py | 138 ++++++++++++++++++++++++++++++++++++ 2 files changed, 140 insertions(+) create mode 100644 superscore/serialization.py diff --git a/superscore/model.py b/superscore/model.py index 6262176..3eb9ecf 100644 --- a/superscore/model.py +++ b/superscore/model.py @@ -14,6 +14,7 @@ from apischema.validation import ValidationError, validate, validator +from superscore.serialization import as_tagged_union from superscore.type_hints import AnyEpicsType logger = logging.getLogger(__name__) @@ -21,6 +22,7 @@ @dataclass +@as_tagged_union class Entry: """ Base class for items in the data model. diff --git a/superscore/serialization.py b/superscore/serialization.py new file mode 100644 index 0000000..69c96fd --- /dev/null +++ b/superscore/serialization.py @@ -0,0 +1,138 @@ +""" +Serialization helpers for apischema. +""" +# Largely based on issue discussions regarding tagged unions. + +from collections import defaultdict +from collections.abc import Callable, Iterator +from types import new_class +from typing import (Any, Dict, Generic, List, Tuple, TypeVar, get_origin, + get_type_hints) + +from apischema import deserializer, serializer, type_name +from apischema.conversions import Conversion +from apischema.metadata import conversion +from apischema.objects import object_deserialization +from apischema.tagged_unions import Tagged, TaggedUnion, get_tagged +from apischema.utils import to_pascal_case + +_alternative_constructors: Dict[type, List[Callable]] = defaultdict(list) +Func = TypeVar("Func", bound=Callable) + + +def alternative_constructor(func: Func) -> Func: + """Alternative constructor for a given type.""" + return_type = get_type_hints(func)["return"] + _alternative_constructors[get_origin(return_type) or return_type].append(func) + return func + + +def get_all_subclasses(cls: type) -> Iterator[type]: + """Recursive implementation of type.__subclasses__""" + for sub_cls in cls.__subclasses__(): + yield sub_cls + yield from get_all_subclasses(sub_cls) + + +Cls = TypeVar("Cls", bound=type) + + +def _get_generic_name_factory(cls: type, *args: type): + def _capitalized(name: str) -> str: + return name[0].upper() + name[1:] + + return "".join((cls.__name__, *(_capitalized(arg.__name__) for arg in args))) + + +generic_name = type_name(_get_generic_name_factory) + + +def as_tagged_union(cls: Cls) -> Cls: + """ + Tagged union decorator, to be used on base class. + + Supports generics as well, with names generated by way of + `_get_generic_name_factory`. + """ + params = tuple(getattr(cls, "__parameters__", ())) + tagged_union_bases: Tuple[type, ...] = (TaggedUnion,) + + # Generic handling is here: + if params: + tagged_union_bases = (TaggedUnion, Generic[params]) + generic_name(cls) + prev_init_subclass = getattr(cls, "__init_subclass__", None) + + def __init_subclass__(cls, **kwargs): + if prev_init_subclass is not None: + prev_init_subclass(**kwargs) + generic_name(cls) + + cls.__init_subclass__ = classmethod(__init_subclass__) + + def with_params(cls: type) -> Any: + """Specify type of Generic if set.""" + return cls[params] if params else cls + + def serialization() -> Conversion: + """ + Define the serializer Conversion for the tagged union. + + source is the base ``cls`` (or ``cls[T]``). + target is the new tagged union class ``TaggedUnion`` which gets the + dictionary {cls.__name__: obj} as its arguments. + """ + annotations = { + # Assume that subclasses have same generic parameters than cls + sub.__name__: Tagged[with_params(sub)] + for sub in get_all_subclasses(cls) + } + namespace = {"__annotations__": annotations} + tagged_union = new_class( + cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) + ) + return Conversion( + lambda obj: tagged_union(**{obj.__class__.__name__: obj}), + source=with_params(cls), + target=with_params(tagged_union), + # Conversion must not be inherited because it would lead to + # infinite recursion otherwise + inherited=False, + ) + + def deserialization() -> Conversion: + """ + Define the deserializer Conversion for the tagged union. + + Allows for alternative standalone constructors as per the apischema + example. + """ + annotations: dict[str, Any] = {} + namespace: dict[str, Any] = {"__annotations__": annotations} + for sub in get_all_subclasses(cls): + annotations[sub.__name__] = Tagged[with_params(sub)] + for constructor in _alternative_constructors.get(sub, ()): + # Build the alias of the field + alias = to_pascal_case(constructor.__name__) + # object_deserialization uses get_type_hints, but the constructor + # return type is stringified and the class not defined yet, + # so it must be assigned manually + constructor.__annotations__["return"] = with_params(sub) + # Use object_deserialization to wrap constructor as deserializer + deserialization = object_deserialization(constructor, generic_name) + # Add constructor tagged field with its conversion + annotations[alias] = Tagged[with_params(sub)] + namespace[alias] = Tagged(conversion(deserialization=deserialization)) + # Create the deserialization tagged union class + tagged_union = new_class( + cls.__name__, tagged_union_bases, exec_body=lambda ns: ns.update(namespace) + ) + return Conversion( + lambda obj: get_tagged(obj)[1], + source=with_params(tagged_union), + target=with_params(cls), + ) + + deserializer(lazy=deserialization, target=cls) + serializer(lazy=serialization, source=cls) + return cls