Skip to content

Commit

Permalink
Fix item refresh when using model discriminators. Fixes #879 (#880)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft authored Nov 5, 2020
1 parent f680242 commit a365bc8
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 13 deletions.
2 changes: 1 addition & 1 deletion docs/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Release Notes
=============

v5.0.0b3
v5.0.0b4
-------------------

:date: 2020-xx-xx
Expand Down
2 changes: 1 addition & 1 deletion pynamodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
"""
__author__ = 'Jharrod LaFon'
__license__ = 'MIT'
__version__ = '5.0.0b3'
__version__ = '5.0.0b4'
26 changes: 16 additions & 10 deletions pynamodb/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def _get_discriminator_attribute(cls) -> Optional['DiscriminatorAttribute']:
def _set_discriminator(self) -> None:
discriminator_attr = self._get_discriminator_attribute()
if discriminator_attr and discriminator_attr.get_discriminator(self.__class__) is not None:
self.attribute_values[self._discriminator] = self.__class__ # type: ignore
setattr(self, self._discriminator, self.__class__) # type: ignore

def _set_defaults(self, _user_instantiated: bool = True) -> None:
"""
Expand Down Expand Up @@ -371,18 +371,22 @@ def deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None:
setattr(self, name, value)

@classmethod
def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT:
def _get_discriminator_class(cls, attribute_values: Dict[str, Dict[str, Any]]) -> Optional[Type]:
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr:
discriminator_attribute_value = attribute_values.pop(discriminator_attr.attr_name, None)
discriminator_attribute_value = attribute_values.get(discriminator_attr.attr_name, None)
if discriminator_attribute_value:
discriminator_value = discriminator_attr.get_value(discriminator_attribute_value)
stored_cls = discriminator_attr.deserialize(discriminator_value)
if not issubclass(stored_cls, cls):
raise ValueError("Cannot instantiate a {} from the returned class: {}".format(
cls.__name__, stored_cls.__name__))
cls = stored_cls
instance = cls(_user_instantiated=False)
return discriminator_attr.deserialize(discriminator_value)
return None

@classmethod
def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT:
stored_cls = cls._get_discriminator_class(attribute_values)
if stored_cls and not issubclass(stored_cls, cls):
raise ValueError("Cannot instantiate a {} from the returned class: {}".format(
cls.__name__, stored_cls.__name__))
instance = (stored_cls or cls)(_user_instantiated=False)
AttributeContainer.deserialize(instance, attribute_values)
return instance

Expand Down Expand Up @@ -422,7 +426,9 @@ def get_discriminator(self, cls: type) -> Optional[Any]:
return self._class_map.get(cls)

def __set__(self, instance: Any, value: Optional[type]) -> None:
raise TypeError("'{}' object does not support item assignment".format(self.__class__.__name__))
if type(instance) != value:
raise ValueError("The discriminator attribute must be set to the instance type: {}".format(type(instance)))
super().__set__(instance, value)

def serialize(self, value):
"""
Expand Down
9 changes: 8 additions & 1 deletion pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -424,7 +424,11 @@ def update(self, actions: Sequence[Action], condition: Optional[Condition] = Non
kwargs.update(actions=actions)

data = self._get_connection().update_item(*args, **kwargs)
self.deserialize(data[ATTRIBUTES])
item_data = data[ATTRIBUTES]
stored_cls = self._get_discriminator_class(item_data)
if stored_cls and stored_cls != type(self):
raise ValueError("Cannot update this item from the returned class: {}".format(stored_cls.__name__))
self.deserialize(item_data)
return data

def save(self, condition: Optional[Condition] = None) -> Dict[str, Any]:
Expand Down Expand Up @@ -453,6 +457,9 @@ def refresh(self, consistent_read: bool = False) -> None:
item_data = attrs.get(ITEM, None)
if item_data is None:
raise self.DoesNotExist("This item does not exist in the table.")
stored_cls = self._get_discriminator_class(item_data)
if stored_cls and stored_cls != type(self):
raise ValueError("Cannot refresh this item from the returned class: {}".format(stored_cls.__name__))
self.deserialize(item_data)

def get_operation_kwargs_from_instance(
Expand Down

0 comments on commit a365bc8

Please sign in to comment.