Skip to content

Commit

Permalink
Support model polymorphism using discriminators. Fixes #247, #328 (#864)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpinner-lyft authored Oct 9, 2020
1 parent 5c8862e commit 2e2a37c
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 39 deletions.
30 changes: 29 additions & 1 deletion docs/polymorphism.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ Polymorphism
PynamoDB supports polymorphism through the use of discriminators.

A discriminator is a value that is written to DynamoDB that identifies the python class being stored.
(Note: currently discriminators are only supported on MapAttribute subclasses; support for model subclasses coming soon.)

Discriminator Attributes
^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -49,3 +48,32 @@ A class can also be registered manually:
Discriminator values are written to DynamoDB.
Changing the value after items have been saved to the database can result in deserialization failures.
In order to read items with an old discriminator value, the old value must be manually registered.


Model Discriminators
^^^^^^^^^^^^^^^^^^^^

Model classes also support polymorphism through the use of discriminators.
(Note: currently discriminator attributes cannot be used as the hash or range key of a table.)

.. code-block:: python
class ParentModel(Model):
class Meta:
table_name = 'polymorphic_table'
id = UnicodeAttribute(hash_key=True)
cls = DiscriminatorAttribute()
class FooModel(ParentModel, discriminator='Foo'):
foo = UnicodeAttribute()
class BarModel(ParentModel, discriminator='Bar'):
bar = UnicodeAttribute()
BarModel(id='Hello', bar='World!').serialize()
# {'id': {'S': 'Hello'}, 'cls': {'S': 'Bar'}, 'bar': {'S': 'World!'}}
.. note::

Read operations that are performed on a class that has a discriminator value are slightly modified to ensure that only instances of the class are returned.
Query and scan operations transparently add a filter condition to ensure that only items with a matching discriminator value are returned.
Get and batch get operations will raise a ``ValueError`` if the returned item(s) are not a subclass of the model being read.
2 changes: 1 addition & 1 deletion docs/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Release Notes
=============

v5.0.0b1
v5.0.0b2
-------------------

:date: 2020-xx-xx
Expand Down
2 changes: 1 addition & 1 deletion pynamodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
"""
__author__ = 'Jharrod LaFon'
__license__ = 'MIT'
__version__ = '5.0.0b1'
__version__ = '5.0.0b2'
31 changes: 18 additions & 13 deletions pynamodb/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
_KT = TypeVar('_KT', bound=str)
_VT = TypeVar('_VT')
_MT = TypeVar('_MT', bound='MapAttribute')
_ACT = TypeVar('_ACT', bound = 'AttributeContainer')

_A = TypeVar('_A', bound='Attribute')

Expand Down Expand Up @@ -259,9 +260,6 @@ def _initialize_attributes(cls, discriminator_value):
raise ValueError("{} has more than one discriminator attribute: {}".format(
cls.__name__, ", ".join(discriminators)))
cls._discriminator = discriminators[0] if discriminators else None
# TODO(jpinner) add support for model polymorphism
if cls._discriminator and not issubclass(cls, MapAttribute):
raise NotImplementedError("Discriminators are not yet supported in model classes.")
if discriminator_value is not None:
if not cls._discriminator:
raise ValueError("{} does not have a discriminator attribute".format(cls.__name__))
Expand Down Expand Up @@ -372,6 +370,22 @@ def deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None:
value = attr.deserialize(attr.get_value(attribute_value))
setattr(self, name, value)

@classmethod
def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT:
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr:
discriminator_attribute_value = attribute_values.pop(discriminator_attr.attr_name, None)
if discriminator_attribute_value:
discriminator_value = discriminator_attr.get_value(discriminator_attribute_value)
stored_cls = discriminator_attr.deserialize(discriminator_value)
if not issubclass(stored_cls, cls):
raise ValueError("Cannot instantiate a {} from the returned class: {}".format(
cls.__name__, stored_cls.__name__))
cls = stored_cls
instance = cls(_user_instantiated=False)
AttributeContainer.deserialize(instance, attribute_values)
return instance

def __eq__(self, other: Any) -> bool:
# This is required so that MapAttribute can call this method.
return self is other
Expand Down Expand Up @@ -940,16 +954,7 @@ def deserialize(self, values):
"""
if not self.is_raw():
# If this is a subclass of a MapAttribute (i.e typed), instantiate an instance
cls = type(self)
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr:
discriminator_attribute_value = values.pop(discriminator_attr.attr_name, None)
if discriminator_attribute_value:
discriminator_value = discriminator_attr.get_value(discriminator_attribute_value)
cls = discriminator_attr.deserialize(discriminator_value)
instance = cls()
AttributeContainer.deserialize(instance, values)
return instance
return self._instantiate(values)

return {
k: DESERIALIZE_CLASS_MAP[attr_type].deserialize(attr_value)
Expand Down
67 changes: 51 additions & 16 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,28 @@
import time
import logging
import warnings
import sys
from inspect import getmembers
from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Mapping, Type, TypeVar, Text, \
Tuple, Union, cast
from typing import Any
from typing import Dict
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Mapping
from typing import Optional
from typing import Sequence
from typing import Text
from typing import Tuple
from typing import Type
from typing import TypeVar
from typing import Union
from typing import cast

if sys.version_info >= (3, 8):
from typing import Protocol
else:
from typing_extensions import Protocol

from pynamodb.expressions.update import Action
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError
Expand Down Expand Up @@ -151,7 +170,7 @@ def commit(self) -> None:
unprocessed_items = data.get(UNPROCESSED_ITEMS, {}).get(self.model.Meta.table_name)


class MetaModel(AttributeContainerMeta):
class MetaProtocol(Protocol):
table_name: str
read_capacity_units: Optional[int]
write_capacity_units: Optional[int]
Expand All @@ -169,14 +188,17 @@ class MetaModel(AttributeContainerMeta):
billing_mode: Optional[str]
stream_view_type: Optional[str]


class MetaModel(AttributeContainerMeta):
"""
Model meta class
This class is just here so that index queries have nice syntax.
Model.index.query()
"""
def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
super().__init__(name, bases, attrs)
def __new__(cls, name, bases, namespace, discriminator=None):
# Defined so that the discriminator can be set in the class definition.
return super().__new__(cls, name, bases, namespace)

def __init__(self, name, bases, namespace, discriminator=None) -> None:
super().__init__(name, bases, namespace, discriminator)
cls = cast(Type['Model'], self)
for attr_name, attribute in cls.get_attributes().items():
if attribute.is_hash_key:
Expand All @@ -200,8 +222,8 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
raise ValueError("{} has more than one TTL attribute: {}".format(
cls.__name__, ", ".join(ttl_attr_names)))

if isinstance(attrs, dict):
for attr_name, attr_obj in attrs.items():
if isinstance(namespace, dict):
for attr_name, attr_obj in namespace.items():
if attr_name == META_CLASS_NAME:
if not hasattr(attr_obj, REGION):
setattr(attr_obj, REGION, get_settings_value('region'))
Expand Down Expand Up @@ -234,9 +256,9 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:

# create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist,
# so that "except Model.DoesNotExist:" would not catch other models' exceptions
if 'DoesNotExist' not in attrs:
if 'DoesNotExist' not in namespace:
exception_attrs = {
'__module__': attrs.get('__module__'),
'__module__': namespace.get('__module__'),
'__qualname__': f'{cls.__qualname__}.{"DoesNotExist"}',
}
cls.DoesNotExist = type('DoesNotExist', (DoesNotExist, ), exception_attrs)
Expand All @@ -260,7 +282,7 @@ class Model(AttributeContainer, metaclass=MetaModel):
DoesNotExist: Type[DoesNotExist] = DoesNotExist
_version_attribute_name: Optional[str] = None

Meta: MetaModel
Meta: MetaProtocol

def __init__(
self,
Expand Down Expand Up @@ -520,9 +542,7 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T:
if data is None:
raise ValueError("Received no data to construct object")

model = cls(_user_instantiated=False)
model.deserialize(data)
return model
return cls._instantiate(data)

@classmethod
def count(
Expand Down Expand Up @@ -556,6 +576,11 @@ def count(
else:
hash_key = cls._serialize_keys(hash_key)[0]

# If this class has a discriminator value, filter the query to only return instances of this class.
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr and discriminator_attr.get_discriminator(cls):
filter_condition &= discriminator_attr == cls

query_args = (hash_key,)
query_kwargs = dict(
range_key_condition=range_key_condition,
Expand Down Expand Up @@ -616,6 +641,11 @@ def query(
else:
hash_key = cls._serialize_keys(hash_key)[0]

# If this class has a discriminator value, filter the query to only return instances of this class.
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr and discriminator_attr.get_discriminator(cls):
filter_condition &= discriminator_attr == cls

if page_size is None:
page_size = limit

Expand Down Expand Up @@ -668,6 +698,11 @@ def scan(
:param rate_limit: If set then consumed capacity will be limited to this amount per second
:param attributes_to_get: If set, specifies the properties to include in the projection expression
"""
# If this class has a discriminator value, filter the scan to only return instances of this class.
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr and discriminator_attr.get_discriminator(cls):
filter_condition &= discriminator_attr == cls

if page_size is None:
page_size = limit

Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

install_requires = [
'botocore>=1.12.54',
'typing-extensions>=3.7; python_version<"3.8"'
]

setup(
Expand Down
39 changes: 33 additions & 6 deletions tests/test_discriminator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,18 @@ class RenamedValue(TypedValue, discriminator='custom_name'):
value = UnicodeAttribute()


class DiscriminatorTestModel(Model):
class DiscriminatorTestModel(Model, discriminator='Parent'):
class Meta:
host = 'http://localhost:8000'
table_name = 'test'
hash_key = UnicodeAttribute(hash_key=True)
value = TypedValue()
values = ListAttribute(of=TypedValue)
type = DiscriminatorAttribute()


class ChildModel(DiscriminatorTestModel, discriminator='Child'):
value = UnicodeAttribute()


class TestDiscriminatorAttribute:
Expand All @@ -46,6 +51,7 @@ def test_serialize(self):
dtm.values = [NumberValue(name='bar', value=5), RenamedValue(name='baz', value='World')]
assert dtm.serialize() == {
'hash_key': {'S': 'foo'},
'type': {'S': 'Parent'},
'value': {'M': {'cls': {'S': 'StringValue'}, 'name': {'S': 'foo'}, 'value': {'S': 'Hello'}}},
'values': {'L': [
{'M': {'cls': {'S': 'NumberValue'}, 'name': {'S': 'bar'}, 'value': {'N': '5'}}},
Expand All @@ -56,6 +62,7 @@ def test_serialize(self):
def test_deserialize(self):
item = {
'hash_key': {'S': 'foo'},
'type': {'S': 'Parent'},
'value': {'M': {'cls': {'S': 'StringValue'}, 'name': {'S': 'foo'}, 'value': {'S': 'Hello'}}},
'values': {'L': [
{'M': {'cls': {'S': 'NumberValue'}, 'name': {'S': 'bar'}, 'value': {'N': '5'}}},
Expand Down Expand Up @@ -96,8 +103,28 @@ def test_multiple_discriminator_classes(self):
class RenamedValue2(TypedValue, discriminator='custom_name'):
pass

def test_model(self):
with pytest.raises(NotImplementedError):
class DiscriminatedModel(Model):
hash_key = UnicodeAttribute(hash_key=True)
_cls = DiscriminatorAttribute()
class TestDiscriminatorModel:

def test_serialize(self):
cm = ChildModel()
cm.hash_key = 'foo'
cm.value = 'bar'
cm.values = []
assert cm.serialize() == {
'hash_key': {'S': 'foo'},
'type': {'S': 'Child'},
'value': {'S': 'bar'},
'values': {'L': []}
}

def test_deserialize(self):
item = {
'hash_key': {'S': 'foo'},
'type': {'S': 'Child'},
'value': {'S': 'bar'},
'values': {'L': []}
}
cm = DiscriminatorTestModel.from_raw_data(item)
assert isinstance(cm, ChildModel)
assert cm.hash_key == 'foo'
assert cm.value == 'bar'
Loading

0 comments on commit 2e2a37c

Please sign in to comment.