Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add force mode to unions and hotfix union encoding #29

Merged
merged 1 commit into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pylint:
poetry run pylint chili

mypy:
poetry run mypy --install-types --non-interactive .
poetry run mypy --install-types --non-interactive chili

bandit:
poetry run bandit -r . -x ./tests,./test,./.venv
Expand Down
25 changes: 16 additions & 9 deletions chili/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
UNDEFINED,
TypeSchema,
create_schema,
get_non_optional_fields,
get_origin_type,
get_parameters_map,
get_type_args,
Expand Down Expand Up @@ -286,15 +287,19 @@ class UnionDecoder(TypeDecoder):
}
_CASTABLES_TYPES = {decimal.Decimal}

def __init__(self, valid_types: List[Type]):
def __init__(self, valid_types: List[Type], extra_decoders: TypeDecoders = None, force: bool = False):
self.valid_types = valid_types
self._type_decoders = {}

for a_type in valid_types:
if a_type in self._PRIMITIVE_TYPES:
self._type_decoders[a_type] = a_type
continue
self._type_decoders[a_type] = build_type_decoder(a_type) # type: ignore
self._type_decoders[a_type] = build_type_decoder(
a_type, extra_decoders=extra_decoders, force=force # type: ignore
)

self.force = force

def decode(self, value: Any) -> Any:
passed_type = type(value)
Expand All @@ -317,13 +322,15 @@ def decode(self, value: Any) -> Any:
continue

if passed_type is dict:
value_keys = value.keys()
for decodable, decoder in self._type_decoders.items():
provided_fields = set(value.keys())
for class_name, decoder in self._type_decoders.items():
try:
if not is_decodable(decodable) and value_keys == get_type_hints(decodable).keys():
return decoder.decode(value)
if is_decodable(decodable) and value_keys == getattr(decodable, _PROPERTIES, {}).keys():
return decoder.decode(value)
if is_decodable(class_name) or is_dataclass(class_name) or self.force:
expected_fields = set(get_non_optional_fields(class_name))
if provided_fields.issubset(expected_fields):
return decoder.decode(value)
continue
continue
except Exception:
continue

Expand Down Expand Up @@ -497,7 +504,7 @@ def build_type_decoder(
return OptionalTypeDecoder(
build_type_decoder(a_type=type_args[0], extra_decoders=extra_decoders) # type: ignore
)
return UnionDecoder(type_args)
return UnionDecoder(type_args, extra_decoders=extra_decoders, force=force)

if isinstance(a_type, typing.ForwardRef) and module is not None:
resolved_reference = resolve_forward_reference(module, a_type)
Expand Down
8 changes: 4 additions & 4 deletions chili/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
unpack_optional,
)


if sys.version_info >= (3, 10):
from types import UnionType
else:
Expand Down Expand Up @@ -341,16 +340,17 @@ def encode(self, value: Any) -> Any:


class UnionEncoder(TypeEncoder):
def __init__(self, supported_types: List[Type], extra_encoders: TypeEncoders = None):
def __init__(self, supported_types: List[Type], extra_encoders: TypeEncoders = None, force: bool = False):
self.supported_types = supported_types
self._extra_encoders = extra_encoders
self.force = force

def encode(self, value: Any) -> Any:
value_type = type(value)
if value_type not in self.supported_types:
raise EncoderError.invalid_input

return build_type_encoder(value_type, self._extra_encoders).encode(value) # type: ignore
return build_type_encoder(value_type, self._extra_encoders, force=self.force).encode(value) # type: ignore


_supported_generics = {
Expand Down Expand Up @@ -405,7 +405,7 @@ def build_type_encoder(
type_args = get_type_args(a_type)
if len(type_args) == 2 and type_args[-1] is type(None):
return OptionalTypeEncoder(build_type_encoder(type_args[0], extra_encoders)) # type: ignore
return UnionEncoder(type_args, extra_encoders)
return UnionEncoder(type_args, extra_encoders, force=force)

if isinstance(a_type, typing.ForwardRef) and module is not None:
resolved_reference = resolve_forward_reference(module, a_type)
Expand Down
12 changes: 11 additions & 1 deletion chili/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from enum import Enum
from functools import lru_cache
from inspect import isclass as is_class
from typing import Any, Callable, ClassVar, Dict, List, NewType, Optional, Type, Union
from typing import Any, Callable, ClassVar, Dict, List, NewType, Optional, Type, Union, get_type_hints

from chili.error import SerialisationError

Expand All @@ -27,6 +27,7 @@
"get_origin_type",
"get_parameters_map",
"get_type_args",
"get_type_hints",
"get_type_parameters",
"is_class",
"is_dataclass",
Expand All @@ -43,6 +44,15 @@
]


def get_non_optional_fields(type_name: Type) -> List[str]:
if is_encodable(type_name):
schema = getattr(type_name, _ENCODABLE)
else:
schema = create_schema(type_name) # type: ignore

return [field.name for field in schema.values() if not is_optional(field.type)]


def get_origin_type(type_name: Type) -> Optional[Type]:
return getattr(type_name, "__origin__", None)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ license = "MIT"
name = "chili"
readme = "README.md"
repository = "https://github.com/kodemore/chili"
version = "2.8.0"
version = "2.8.1"

[tool.poetry.dependencies]
gaffe = ">=0.3.0"
Expand Down
34 changes: 34 additions & 0 deletions tests/test_typing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from dataclasses import dataclass
from typing import Optional

from chili.typing import get_non_optional_fields


def test_get_non_optional_fields_from_data_class() -> None:
# given
@dataclass
class Person:
name: str
age: int
street_name: Optional[str]
street_number: Optional[int]

# when
fields = get_non_optional_fields(Person)

# then
assert fields == ["name", "age"]


def test_get_non_optional_fields_from_class() -> None:
class Person:
name: str
age: int
street_name: Optional[str]
street_number: Optional[int]

# when
fields = get_non_optional_fields(Person)

# then
assert fields == ["name", "age"]
3 changes: 2 additions & 1 deletion tests/usecases/newtype_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import sys
from typing import NewType

from chili import Decoder, Encoder, decodable, encodable
import sys


def test_can_encode_newtype_type() -> None:
Expand Down
47 changes: 46 additions & 1 deletion tests/usecases/union_usecase.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Union
from typing import List, Optional, Union

from chili import decode, encode

Expand Down Expand Up @@ -88,3 +88,48 @@ class EmailAddress:

# then
assert result == {"address": "simple@email.com", "label": ""}


def test_can_decode_nested_union() -> None:
# given
@dataclass
class HomeAddress:
home_street: str
number: Optional[int] = 0

@dataclass
class OfficeAddress:
office_street: str
number: Optional[int] = 0

@dataclass
class Address:
street: str
number: Optional[int] = 0

@dataclass
class Person:
name: str
address: HomeAddress | OfficeAddress | Address

data_office = {
"name": "Bob",
"address": {"office_street": "street"},
}

data_home = {
"name": "Bob",
"address": {"home_street": "home"},
}

# when
decoded = decode(data_office, Person)

# then
assert isinstance(decoded.address, OfficeAddress)

# when
decoded = decode(data_home, Person)

# then
assert isinstance(decoded.address, HomeAddress)
Loading