From a365bc8f43abdb1d52887b2a33bd7a1c412bdaac Mon Sep 17 00:00:00 2001 From: jpinner-lyft Date: Thu, 5 Nov 2020 15:17:32 -0800 Subject: [PATCH] Fix item refresh when using model discriminators. Fixes #879 (#880) --- docs/release_notes.rst | 2 +- pynamodb/__init__.py | 2 +- pynamodb/attributes.py | 26 ++++++++++++++++---------- pynamodb/models.py | 9 ++++++++- 4 files changed, 26 insertions(+), 13 deletions(-) diff --git a/docs/release_notes.rst b/docs/release_notes.rst index 571101442..9b6f09608 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -1,7 +1,7 @@ Release Notes ============= -v5.0.0b3 +v5.0.0b4 ------------------- :date: 2020-xx-xx diff --git a/pynamodb/__init__.py b/pynamodb/__init__.py index f2750bbad..c84fb8f0b 100644 --- a/pynamodb/__init__.py +++ b/pynamodb/__init__.py @@ -7,4 +7,4 @@ """ __author__ = 'Jharrod LaFon' __license__ = 'MIT' -__version__ = '5.0.0b3' +__version__ = '5.0.0b4' diff --git a/pynamodb/attributes.py b/pynamodb/attributes.py index 504412969..2eacd9e5a 100644 --- a/pynamodb/attributes.py +++ b/pynamodb/attributes.py @@ -312,7 +312,7 @@ def _get_discriminator_attribute(cls) -> Optional['DiscriminatorAttribute']: def _set_discriminator(self) -> None: discriminator_attr = self._get_discriminator_attribute() if discriminator_attr and discriminator_attr.get_discriminator(self.__class__) is not None: - self.attribute_values[self._discriminator] = self.__class__ # type: ignore + setattr(self, self._discriminator, self.__class__) # type: ignore def _set_defaults(self, _user_instantiated: bool = True) -> None: """ @@ -371,18 +371,22 @@ def deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None: setattr(self, name, value) @classmethod - def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT: + def _get_discriminator_class(cls, attribute_values: Dict[str, Dict[str, Any]]) -> Optional[Type]: discriminator_attr = cls._get_discriminator_attribute() if discriminator_attr: - discriminator_attribute_value = attribute_values.pop(discriminator_attr.attr_name, None) + discriminator_attribute_value = attribute_values.get(discriminator_attr.attr_name, None) if discriminator_attribute_value: discriminator_value = discriminator_attr.get_value(discriminator_attribute_value) - stored_cls = discriminator_attr.deserialize(discriminator_value) - if not issubclass(stored_cls, cls): - raise ValueError("Cannot instantiate a {} from the returned class: {}".format( - cls.__name__, stored_cls.__name__)) - cls = stored_cls - instance = cls(_user_instantiated=False) + return discriminator_attr.deserialize(discriminator_value) + return None + + @classmethod + def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT: + stored_cls = cls._get_discriminator_class(attribute_values) + if stored_cls and not issubclass(stored_cls, cls): + raise ValueError("Cannot instantiate a {} from the returned class: {}".format( + cls.__name__, stored_cls.__name__)) + instance = (stored_cls or cls)(_user_instantiated=False) AttributeContainer.deserialize(instance, attribute_values) return instance @@ -422,7 +426,9 @@ def get_discriminator(self, cls: type) -> Optional[Any]: return self._class_map.get(cls) def __set__(self, instance: Any, value: Optional[type]) -> None: - raise TypeError("'{}' object does not support item assignment".format(self.__class__.__name__)) + if type(instance) != value: + raise ValueError("The discriminator attribute must be set to the instance type: {}".format(type(instance))) + super().__set__(instance, value) def serialize(self, value): """ diff --git a/pynamodb/models.py b/pynamodb/models.py index 75dc24c70..aee6bb8e8 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -424,7 +424,11 @@ def update(self, actions: Sequence[Action], condition: Optional[Condition] = Non kwargs.update(actions=actions) data = self._get_connection().update_item(*args, **kwargs) - self.deserialize(data[ATTRIBUTES]) + item_data = data[ATTRIBUTES] + stored_cls = self._get_discriminator_class(item_data) + if stored_cls and stored_cls != type(self): + raise ValueError("Cannot update this item from the returned class: {}".format(stored_cls.__name__)) + self.deserialize(item_data) return data def save(self, condition: Optional[Condition] = None) -> Dict[str, Any]: @@ -453,6 +457,9 @@ def refresh(self, consistent_read: bool = False) -> None: item_data = attrs.get(ITEM, None) if item_data is None: raise self.DoesNotExist("This item does not exist in the table.") + stored_cls = self._get_discriminator_class(item_data) + if stored_cls and stored_cls != type(self): + raise ValueError("Cannot refresh this item from the returned class: {}".format(stored_cls.__name__)) self.deserialize(item_data) def get_operation_kwargs_from_instance(