Skip to content

Commit

Permalink
Support private properties, fix bug with p/c serialisables
Browse files Browse the repository at this point in the history
  • Loading branch information
dkraczkowski committed Nov 8, 2023
1 parent 20ea4d3 commit b75276e
Show file tree
Hide file tree
Showing 10 changed files with 353 additions and 113 deletions.
106 changes: 61 additions & 45 deletions chili/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,23 @@ def decode(self, value):


@final
class ProxyDecoder(Generic[T]):
def __init__(self, func: Callable[[Any], T]):
class SimpleDecoder(Generic[T]):
def __init__(self, func: Callable[[Any], T]) -> None:
self._decoder = func

def decode(self, value: Any) -> T:
return self._decoder(value)


@final
class ProxyDecoder(Generic[T]):
def __init__(self, type_annotation: Any) -> None:
self._decoder = build_type_decoder(type_annotation)

def decode(self, value: Any) -> T:
return self._decoder.decode(value)


_REGEX_FLAGS = {
"i": re.I,
"m": re.M,
Expand Down Expand Up @@ -158,44 +167,44 @@ def ordered_dict(value: List[List[Any]]) -> collections.OrderedDict:

_builtin_type_decoders = TypeDecoders(
{
bool: ProxyDecoder[bool](bool),
int: ProxyDecoder[int](int),
float: ProxyDecoder[float](float),
str: ProxyDecoder[str](str),
bytes: ProxyDecoder[bytes](lambda value: b64decode(value.encode("utf8"))),
bytearray: ProxyDecoder[bytearray](lambda value: bytearray(b64decode(value.encode("utf8")))),
list: ProxyDecoder[list](list),
set: ProxyDecoder[set](set),
frozenset: ProxyDecoder[frozenset](frozenset),
tuple: ProxyDecoder[tuple](tuple),
dict: ProxyDecoder[dict](dict),
collections.OrderedDict: ProxyDecoder[collections.OrderedDict](ordered_dict),
collections.deque: ProxyDecoder[collections.deque](collections.deque),
typing.TypedDict: ProxyDecoder[typing.TypedDict](typing.TypedDict), # type: ignore
typing.Dict: ProxyDecoder[dict](dict),
typing.List: ProxyDecoder[list](list),
typing.Sequence: ProxyDecoder[list](list),
typing.Tuple: ProxyDecoder[tuple](tuple), # type: ignore
typing.Set: ProxyDecoder[set](set),
typing.FrozenSet: ProxyDecoder[frozenset](frozenset),
typing.Deque: ProxyDecoder[typing.Deque](typing.Deque),
typing.AnyStr: ProxyDecoder[str](str), # type: ignore
decimal.Decimal: ProxyDecoder[decimal.Decimal](decimal.Decimal),
datetime.time: ProxyDecoder[datetime.time](parse_iso_time),
datetime.date: ProxyDecoder[datetime.date](parse_iso_date),
datetime.datetime: ProxyDecoder[datetime.datetime](parse_iso_datetime),
datetime.timedelta: ProxyDecoder[datetime.timedelta](parse_iso_duration),
PurePath: ProxyDecoder[PurePath](PurePath),
PureWindowsPath: ProxyDecoder[PureWindowsPath](PureWindowsPath),
PurePosixPath: ProxyDecoder[PurePosixPath](PurePosixPath),
Path: ProxyDecoder[Path](Path),
PosixPath: ProxyDecoder[PosixPath](PosixPath),
WindowsPath: ProxyDecoder[WindowsPath](WindowsPath),
Pattern: ProxyDecoder[Pattern](decode_regex_from_string),
re.Pattern: ProxyDecoder[re.Pattern](decode_regex_from_string),
IPv4Address: ProxyDecoder[IPv4Address](IPv4Address),
IPv6Address: ProxyDecoder[IPv6Address](IPv6Address),
UUID: ProxyDecoder[UUID](UUID),
bool: SimpleDecoder[bool](bool),
int: SimpleDecoder[int](int),
float: SimpleDecoder[float](float),
str: SimpleDecoder[str](str),
bytes: SimpleDecoder[bytes](lambda value: b64decode(value.encode("utf8"))),
bytearray: SimpleDecoder[bytearray](lambda value: bytearray(b64decode(value.encode("utf8")))),
list: SimpleDecoder[list](list),
set: SimpleDecoder[set](set),
frozenset: SimpleDecoder[frozenset](frozenset),
tuple: SimpleDecoder[tuple](tuple),
dict: SimpleDecoder[dict](dict),
collections.OrderedDict: SimpleDecoder[collections.OrderedDict](ordered_dict),
collections.deque: SimpleDecoder[collections.deque](collections.deque),
typing.TypedDict: SimpleDecoder[typing.TypedDict](typing.TypedDict), # type: ignore
typing.Dict: SimpleDecoder[dict](dict),
typing.List: SimpleDecoder[list](list),
typing.Sequence: SimpleDecoder[list](list),
typing.Tuple: SimpleDecoder[tuple](tuple), # type: ignore
typing.Set: SimpleDecoder[set](set),
typing.FrozenSet: SimpleDecoder[frozenset](frozenset),
typing.Deque: SimpleDecoder[typing.Deque](typing.Deque),
typing.AnyStr: SimpleDecoder[str](str), # type: ignore
decimal.Decimal: SimpleDecoder[decimal.Decimal](decimal.Decimal),
datetime.time: SimpleDecoder[datetime.time](parse_iso_time),
datetime.date: SimpleDecoder[datetime.date](parse_iso_date),
datetime.datetime: SimpleDecoder[datetime.datetime](parse_iso_datetime),
datetime.timedelta: SimpleDecoder[datetime.timedelta](parse_iso_duration),
PurePath: SimpleDecoder[PurePath](PurePath),
PureWindowsPath: SimpleDecoder[PureWindowsPath](PureWindowsPath),
PurePosixPath: SimpleDecoder[PurePosixPath](PurePosixPath),
Path: SimpleDecoder[Path](Path),
PosixPath: SimpleDecoder[PosixPath](PosixPath),
WindowsPath: SimpleDecoder[WindowsPath](WindowsPath),
Pattern: SimpleDecoder[Pattern](decode_regex_from_string),
re.Pattern: SimpleDecoder[re.Pattern](decode_regex_from_string),
IPv4Address: SimpleDecoder[IPv4Address](IPv4Address),
IPv6Address: SimpleDecoder[IPv6Address](IPv6Address),
UUID: SimpleDecoder[UUID](UUID),
}
)

Expand Down Expand Up @@ -321,7 +330,7 @@ class ClassDecoder(TypeDecoder):

def __init__(self, class_name: Type, extra_decoders: TypeDecoders = None):
self.class_name = class_name
self._schema = create_schema(class_name)
self._schema = create_schema(class_name) # type: ignore
self._extra_decoders = extra_decoders

def decode(self, value: StateObject) -> Any:
Expand Down Expand Up @@ -440,7 +449,9 @@ def decode(self, value: Any) -> Any:


@lru_cache(maxsize=None)
def build_type_decoder(a_type: Type, extra_decoders: TypeDecoders = None, module: Any = None) -> TypeDecoder:
def build_type_decoder(
a_type: Type, extra_decoders: TypeDecoders = None, module: Any = None, force: bool = False
) -> TypeDecoder:
if extra_decoders and a_type in extra_decoders:
return extra_decoders[a_type]

Expand Down Expand Up @@ -472,7 +483,7 @@ def build_type_decoder(a_type: Type, extra_decoders: TypeDecoders = None, module
return TypedDictDecoder(origin_type, extra_decoders)

if is_class(origin_type) and is_user_string(origin_type):
return ProxyDecoder[origin_type](origin_type) # type: ignore
return SimpleDecoder[origin_type](origin_type) # type: ignore

if origin_type is Union:
type_args = get_type_args(a_type)
Expand Down Expand Up @@ -505,6 +516,8 @@ def build_type_decoder(a_type: Type, extra_decoders: TypeDecoders = None, module
return OptionalTypeDecoder(build_type_decoder(unpack_optional(a_type))) # type: ignore

if origin_type not in _supported_generics:
if force and is_class(origin_type):
return Decoder[origin_type](extra_decoders) # type: ignore
raise DecoderError.invalid_type(a_type)

type_attributes: List[Union[TypeDecoder, Any]] = [
Expand Down Expand Up @@ -551,15 +564,18 @@ def decode(self, obj: Dict[str, StateObject]) -> T:
else:
value = self._decoders[prop.name].decode(obj[key])

setattr(instance, prop.name, value)
try:
setattr(instance, prop.name, value)
except AttributeError:
setattr(instance, f"_{prop.name}", value)

return instance

def _build_decoders(self) -> Dict[str, TypeDecoder]:
schema: TypeSchema = getattr(self.__generic__, _PROPERTIES)

return {
prop.name: build_type_decoder(prop.type, extra_decoders=self.type_decoders) # type: ignore
prop.name: build_type_decoder(prop.type, extra_decoders=self.type_decoders, force=True) # type: ignore
for prop in schema.values()
}

Expand Down
104 changes: 57 additions & 47 deletions chili/encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,23 @@ def encode(self, value):


@final
class ProxyEncoder(TypeEncoder, Generic[T]):
class SimpleEncoder(TypeEncoder, Generic[T]):
def __init__(self, func: Callable[[Any], T]):
self._encoder = func

def encode(self, value: Any) -> T:
return self._encoder(value)


@final
class ProxyEncoder(Generic[T]):
def __init__(self, type_annotation: Any) -> None:
self._encoder = build_type_encoder(type_annotation)

def decode(self, value: Any) -> T:
return self._encoder.encode(value)


def encode_regex_to_string(value: Pattern) -> str:
"""
Encodes regex into string and preserves flags if they are set. Then regex is normally wrapped between slashes.
Expand Down Expand Up @@ -126,44 +135,44 @@ def ordered_dict(value: collections.OrderedDict) -> List[List[Any]]:

_builtin_type_encoders = TypeEncoders(
{
bool: ProxyEncoder[bool](bool),
int: ProxyEncoder[int](int),
float: ProxyEncoder[float](float),
str: ProxyEncoder[str](str),
bytes: ProxyEncoder[str](lambda value: b64encode(value).decode("utf8")),
bytearray: ProxyEncoder[str](lambda value: b64encode(value).decode("utf8")),
list: ProxyEncoder[list](list),
set: ProxyEncoder[list](list),
frozenset: ProxyEncoder[list](list),
tuple: ProxyEncoder[list](list),
dict: ProxyEncoder[dict](dict),
collections.OrderedDict: ProxyEncoder[list](ordered_dict),
collections.deque: ProxyEncoder[list](list),
typing.TypedDict: ProxyEncoder[dict](dict), # type: ignore
typing.Dict: ProxyEncoder[dict](dict),
typing.List: ProxyEncoder[list](list),
typing.Sequence: ProxyEncoder[list](list),
typing.Tuple: ProxyEncoder[list](list), # type: ignore
typing.Set: ProxyEncoder[list](list),
typing.FrozenSet: ProxyEncoder[list](list),
typing.Deque: ProxyEncoder[list](list),
typing.AnyStr: ProxyEncoder[str](str), # type: ignore
decimal.Decimal: ProxyEncoder[str](str),
datetime.time: ProxyEncoder[str](lambda value: value.isoformat()),
datetime.date: ProxyEncoder[str](lambda value: value.isoformat()),
datetime.datetime: ProxyEncoder[str](lambda value: value.isoformat()),
datetime.timedelta: ProxyEncoder[str](timedelta_to_iso_duration),
PurePath: ProxyEncoder[str](str),
PureWindowsPath: ProxyEncoder[str](str),
PurePosixPath: ProxyEncoder[str](str),
Path: ProxyEncoder[str](str),
PosixPath: ProxyEncoder[str](str),
WindowsPath: ProxyEncoder[str](str),
Pattern: ProxyEncoder[str](encode_regex_to_string),
re.Pattern: ProxyEncoder[str](encode_regex_to_string),
IPv6Address: ProxyEncoder[str](str),
IPv4Address: ProxyEncoder[str](str),
UUID: ProxyEncoder[str](str),
bool: SimpleEncoder[bool](bool),
int: SimpleEncoder[int](int),
float: SimpleEncoder[float](float),
str: SimpleEncoder[str](str),
bytes: SimpleEncoder[str](lambda value: b64encode(value).decode("utf8")),
bytearray: SimpleEncoder[str](lambda value: b64encode(value).decode("utf8")),
list: SimpleEncoder[list](list),
set: SimpleEncoder[list](list),
frozenset: SimpleEncoder[list](list),
tuple: SimpleEncoder[list](list),
dict: SimpleEncoder[dict](dict),
collections.OrderedDict: SimpleEncoder[list](ordered_dict),
collections.deque: SimpleEncoder[list](list),
typing.TypedDict: SimpleEncoder[dict](dict), # type: ignore
typing.Dict: SimpleEncoder[dict](dict),
typing.List: SimpleEncoder[list](list),
typing.Sequence: SimpleEncoder[list](list),
typing.Tuple: SimpleEncoder[list](list), # type: ignore
typing.Set: SimpleEncoder[list](list),
typing.FrozenSet: SimpleEncoder[list](list),
typing.Deque: SimpleEncoder[list](list),
typing.AnyStr: SimpleEncoder[str](str), # type: ignore
decimal.Decimal: SimpleEncoder[str](str),
datetime.time: SimpleEncoder[str](lambda value: value.isoformat()),
datetime.date: SimpleEncoder[str](lambda value: value.isoformat()),
datetime.datetime: SimpleEncoder[str](lambda value: value.isoformat()),
datetime.timedelta: SimpleEncoder[str](timedelta_to_iso_duration),
PurePath: SimpleEncoder[str](str),
PureWindowsPath: SimpleEncoder[str](str),
PurePosixPath: SimpleEncoder[str](str),
Path: SimpleEncoder[str](str),
PosixPath: SimpleEncoder[str](str),
WindowsPath: SimpleEncoder[str](str),
Pattern: SimpleEncoder[str](encode_regex_to_string),
re.Pattern: SimpleEncoder[str](encode_regex_to_string),
IPv6Address: SimpleEncoder[str](str),
IPv4Address: SimpleEncoder[str](str),
UUID: SimpleEncoder[str](str),
}
)

Expand Down Expand Up @@ -227,7 +236,7 @@ class ClassEncoder(TypeEncoder):
def __init__(self, class_name: Type, extra_encoders: TypeEncoders = None):
self.class_name = class_name
self._extra_encoders = extra_encoders
self._schema = create_schema(class_name)
self._schema = create_schema(class_name) # type: ignore

def encode(self, value: Any) -> StateObject:
if not isinstance(value, self.class_name):
Expand Down Expand Up @@ -349,7 +358,9 @@ def encode(self, value: Any) -> Any:


@lru_cache(maxsize=None)
def build_type_encoder(a_type: Type, extra_encoders: TypeEncoders = None, module: Any = None) -> TypeEncoder:
def build_type_encoder(
a_type: Type, extra_encoders: TypeEncoders = None, module: Any = None, force: bool = False
) -> TypeEncoder:
if extra_encoders and a_type in extra_encoders:
return extra_encoders[a_type]

Expand Down Expand Up @@ -381,7 +392,7 @@ def build_type_encoder(a_type: Type, extra_encoders: TypeEncoders = None, module
return TypedDictEncoder(origin_type, extra_encoders)

if is_class(origin_type) and is_user_string(origin_type):
return ProxyEncoder[str](str)
return SimpleEncoder[str](str)

if origin_type is Union:
type_args = get_type_args(a_type)
Expand All @@ -402,9 +413,6 @@ def build_type_encoder(a_type: Type, extra_encoders: TypeEncoders = None, module
return GenericClassEncoder(a_type)
return Encoder[origin_type](encoders=extra_encoders) # type: ignore[valid-type]

if is_optional(a_type):
return OptionalTypeEncoder(build_type_encoder(unpack_optional(a_type), extra_encoders)) # type: ignore

if isinstance(a_type, TypeVar):
if a_type.__bound__ is None:
raise EncoderError.invalid_type(a_type)
Expand All @@ -414,7 +422,9 @@ def build_type_encoder(a_type: Type, extra_encoders: TypeEncoders = None, module
return build_type_encoder(a_type.__supertype__, extra_encoders, module)

if origin_type not in _supported_generics:
raise EncoderError.invalid_type(a_type)
if is_class(origin_type) and force:
return Encoder[origin_type](encoders=extra_encoders) # type: ignore[valid-type]
raise EncoderError.invalid_type(type=a_type)

type_attributes: List[TypeEncoder] = [
build_type_encoder(subtype, extra_encoders=extra_encoders, module=module) # type: ignore
Expand Down Expand Up @@ -467,7 +477,7 @@ def _build_encoders(self) -> Dict[str, TypeEncoder]:
schema: TypeSchema = self.schema

return {
prop.name: build_type_encoder(prop.type, extra_encoders=self.type_encoders) # type: ignore
prop.name: build_type_encoder(prop.type, extra_encoders=self.type_encoders, force=True) # type: ignore
for prop in schema.values()
}

Expand Down
12 changes: 6 additions & 6 deletions chili/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def __class_getitem__(cls, item: Type[T]) -> Type[Serializer]: # noqa: E501

def serializable(_cls=None, in_mapper: Optional[Mapper] = None, out_mapper: Optional[Mapper] = None) -> Any:
def _decorate(cls) -> Type[C]:
if not hasattr(cls, _PROPERTIES):
setattr(cls, _PROPERTIES, create_schema(cls))
if in_mapper is not None:
setattr(cls, _DECODE_MAPPER, in_mapper)
if out_mapper is not None:
setattr(cls, _ENCODE_MAPPER, out_mapper)

setattr(cls, _PROPERTIES, create_schema(cls))
if in_mapper is not None:
setattr(cls, _DECODE_MAPPER, in_mapper)
if out_mapper is not None:
setattr(cls, _ENCODE_MAPPER, out_mapper)

setattr(cls, _DECODABLE, True)
setattr(cls, _ENCODABLE, True)
Expand Down
2 changes: 2 additions & 0 deletions chili/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from collections import UserString
from dataclasses import MISSING, Field, InitVar, is_dataclass
from enum import Enum
from functools import lru_cache
from inspect import isclass as is_class
from typing import Any, Callable, ClassVar, Dict, List, NewType, Optional, Type, Union

Expand Down Expand Up @@ -184,6 +185,7 @@ def __eq__(self, other: Property) -> bool: # type: ignore
_default_factories = (list, dict, tuple, set, bytes, bytearray, frozenset)


@lru_cache
def create_schema(cls: Type) -> TypeSchema:
try:
properties = typing.get_type_hints(cls, localns=cls.__dict__) # type: ignore
Expand Down
Loading

0 comments on commit b75276e

Please sign in to comment.