From 3a0c475b7916ef71410069f3264875a334994967 Mon Sep 17 00:00:00 2001 From: henribru <6639509+henribru@users.noreply.github.com> Date: Sat, 11 Nov 2023 15:36:04 +0100 Subject: [PATCH] Improve message base class (#2) --- proto-stubs/message.pyi | 58 +++++++++++++++++++++++++---------------- pyproject.toml | 2 +- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/proto-stubs/message.pyi b/proto-stubs/message.pyi index 8cbd6f7..8864655 100644 --- a/proto-stubs/message.pyi +++ b/proto-stubs/message.pyi @@ -1,49 +1,63 @@ -from typing import Any, List, Type +from collections.abc import Mapping +from typing import Any, TypeVar, overload from google.protobuf import descriptor_pb2, message from proto.fields import Field from proto.marshal import Marshal +_M = TypeVar("_M") + class MessageMeta(type): def __new__(mcls, name, bases, attrs): ... @classmethod def __prepare__(mcls, name, bases, **kwargs): ... @property def meta(cls): ... - def pb(cls, obj: Any | None = ..., *, coerce: bool = ...): ... - def wrap(cls, pb): ... - def serialize(cls, instance) -> bytes: ... - def deserialize(cls, payload: bytes) -> Message: ... + @overload + def pb( + cls: type[_M], obj: None = ..., *, coerce: bool = ... + ) -> type[message.Message]: ... + @overload + def pb(cls: type[_M], obj: _M, *, coerce: bool = ...) -> message.Message: ... + def wrap(cls: type[_M], pb: message.Message) -> _M: ... + def serialize(cls: type[_M], instance: _M | Mapping | message.Message) -> bytes: ... + def deserialize(cls: type[_M], payload: bytes) -> _M: ... def to_json( - cls, - instance, + cls: type[_M], + instance: _M, *, use_integers_for_enums: bool = ..., including_default_value_fields: bool = ..., preserving_proto_field_name: bool = ... ) -> str: ... - def from_json(cls, payload, *, ignore_unknown_fields: bool = ...) -> Message: ... + def from_json( + cls: type[_M], payload: str, *, ignore_unknown_fields: bool = ... + ) -> _M: ... def to_dict( - cls, - instance, + cls: type[_M], + instance: _M, *, use_integers_for_enums: bool = ..., preserving_proto_field_name: bool = ... - ) -> Message: ... - def copy_from(cls, instance, other) -> None: ... + ) -> dict[str, Any]: ... + def copy_from( + cls: type[_M], instance: _M | Mapping | message.Message, other + ) -> None: ... class Message(metaclass=MessageMeta): def __init__( - self, mapping: Any | None = ..., *, ignore_unknown_fields: bool = ..., **kwargs + self: _M, + mapping: _M | Mapping | message.Message | None = ..., + *, + ignore_unknown_fields: bool = ..., + **kwargs ) -> None: ... - def __bool__(self): ... - def __contains__(self, key): ... - def __delattr__(self, key) -> None: ... - def __eq__(self, other): ... - def __getattr__(self, key): ... - def __ne__(self, other): ... - def __setattr__(self, key, value): ... + def __bool__(self) -> bool: ... + def __contains__(self, key: str) -> bool: ... + def __delattr__(self, key: str) -> None: ... + def __eq__(self, other: Any) -> bool: ... + def __ne__(self, other: Any) -> bool: ... class _MessageInfo: package: Any @@ -54,11 +68,11 @@ class _MessageInfo: marshal: Any def __init__( self, - fields: List[Field], + fields: list[Field], package: str, full_name: str, marshal: Marshal, options: descriptor_pb2.MessageOptions, ) -> None: ... @property - def pb(self) -> Type[message.Message]: ... + def pb(self) -> type[message.Message]: ... diff --git a/pyproject.toml b/pyproject.toml index e3bf6b1..8059d2c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ proto-plus = ">=1.18.0" types-protobuf = ">=3.17.4" [tool.poetry.group.dev.dependencies] -mypy = "^0.990" +mypy = {version = "^1.7.0", python = "^3.8"} black = "^22.10.0" isort = "^5.10.1"