Skip to content

Commit

Permalink
Fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
dargueta committed Sep 8, 2024
1 parent 5ddc367 commit f530a73
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
43 changes: 29 additions & 14 deletions binobj/fields/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Union as _Union

from typing_extensions import override
from typing_extensions import TypeAlias

from binobj import errors
from binobj.fields.base import Field
Expand All @@ -33,18 +34,29 @@
TStruct = TypeVar("TStruct", bound=Struct)


HaltCheckFn = Callable[
HaltCheckFn: TypeAlias = Callable[
["Array[T]", BinaryIO, MutableSequence[Optional[T]], Any, StrDict], bool
]
"""A function used to detect the end of an array when deserializing.
FieldOrTStruct = _Union[Field[Any], type[Struct]]
FieldLoadDecider = Callable[[BinaryIO, Sequence[Field[Any]], Any, StrDict], Field[Any]]
FieldDumpDecider = Callable[[Any, Sequence[Field[Any]], Any, StrDict], Field[Any]]
See :meth:`Array.should_halt` for a full description of arguments.
"""

StructLoadDecider = Callable[

FieldLoadDecider: TypeAlias = Callable[
[BinaryIO, Sequence[Field[Any]], Any, StrDict], Field[Any]
]

FieldDumpDecider: TypeAlias = Callable[
[Any, Sequence[Field[Any]], Any, StrDict], Field[Any]
]

StructLoadDecider: TypeAlias = Callable[
[BinaryIO, Sequence[type[Struct]], Any, StrDict], type[Struct]
]
StructDumpDecider = Callable[[Any, Sequence[type[Struct]], Any, StrDict], type[Struct]]
StructDumpDecider: TypeAlias = Callable[
[Any, Sequence[type[Struct]], Any, StrDict], type[Struct]
]


class Array(Field[list[Optional[T]]]):
Expand All @@ -65,7 +77,8 @@ class Array(Field[list[Optional[T]]]):
:param callable halt_check:
A function taking five arguments. See :meth:`should_halt` for the default
implementation. Subclasses can override this function if desired to avoid having
to pass in a custom function every time.
to pass in a custom function every time. For further details, see
:data:`HaltCheckFn`.
.. versionchanged:: 0.3.0
``count`` can now be a :class:`~.fields.base.Field` or string.
Expand Down Expand Up @@ -358,7 +371,7 @@ def _do_load(
return self.struct_class.from_stream(stream, context)


class Union(Field[T]):
class Union(Field[Any]):
"""A field that can be one of several different types of structs or fields.
:param choices:
Expand Down Expand Up @@ -430,7 +443,7 @@ def __init__(
@overload
def __init__(
self,
*choices: TStruct,
*choices: type[Struct],
load_decider: StructLoadDecider,
dump_decider: StructDumpDecider,
**kwargs: Any,
Expand All @@ -439,24 +452,24 @@ def __init__(

def __init__(
self,
*choices: T,
*choices: Any,
load_decider: Any,
dump_decider: Any,
**kwargs: Any,
):
super().__init__(**kwargs)
if any(isinstance(c, type) and issubclass(c, Field) for c in choices):
raise errors.ConfigurationError(
"You must pass an instance of a Field, not a class.", field=self
"A `Union` must be passed Field instances, not classes.", field=self
)

super().__init__(**kwargs)
self.choices = choices
self.load_decider = load_decider
self.dump_decider = dump_decider

@override
def _do_dump(
self, stream: BinaryIO, data: T, context: object, all_fields: StrDict
self, stream: BinaryIO, data: Any, context: object, all_fields: StrDict
) -> None:
dumper = self.dump_decider(data, self.choices, context, all_fields)
if isinstance(dumper, Field):
Expand All @@ -474,7 +487,9 @@ def _do_dump(
)

@override
def _do_load(self, stream: BinaryIO, context: object, loaded_fields: StrDict) -> T:
def _do_load(
self, stream: BinaryIO, context: object, loaded_fields: StrDict
) -> Any:
loader = self.load_decider(stream, self.choices, context, loaded_fields)
if isinstance(loader, Field):
return loader._do_load(stream, context, loaded_fields) # noqa: SLF001
Expand Down
7 changes: 4 additions & 3 deletions tests/fields/containers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,11 +401,12 @@ def test_union__fields__load_basic():

def test_union__field_class_crashes():
"""Passing a Field class to a Union should crash."""
with pytest.raises(errors.ConfigurationError) as errinfo:
with pytest.raises(
errors.ConfigurationError,
match=r"^A `Union` must be passed Field instances, not classes\.$",
):
fields.Union(fields.StringZ, load_decider=None, dump_decider=None)

assert str(errinfo.value) == "You must pass an instance of a Field, not a class."


def test_union__dump_non_mapping_for_struct():
"""If the dump decider returns a Struct as the serializer,"""
Expand Down

0 comments on commit f530a73

Please sign in to comment.