Skip to content

Commit

Permalink
Ensure queries and scans return model subclasses when using discrimin…
Browse files Browse the repository at this point in the history
…ators. (#873)
  • Loading branch information
jpinner-lyft authored Oct 21, 2020
1 parent ce43a2f commit ec56f95
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 12 deletions.
2 changes: 1 addition & 1 deletion docs/release_notes.rst
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Release Notes
=============

v5.0.0b2
v5.0.0b3
-------------------

:date: 2020-xx-xx
Expand Down
2 changes: 1 addition & 1 deletion pynamodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
"""
__author__ = 'Jharrod LaFon'
__license__ = 'MIT'
__version__ = '5.0.0b2'
__version__ = '5.0.0b3'
3 changes: 3 additions & 0 deletions pynamodb/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,9 @@ def register_class(self, cls: type, discriminator: Any):

self._discriminator_map[discriminator] = cls

def get_registered_subclasses(self, cls: type) -> List[type]:
return [k for k in self._class_map.keys() if issubclass(k, cls)]

def get_discriminator(self, cls: type) -> Optional[Any]:
return self._class_map.get(cls)

Expand Down
18 changes: 9 additions & 9 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,10 +575,10 @@ def count(
else:
hash_key = cls._serialize_keys(hash_key)[0]

# If this class has a discriminator value, filter the query to only return instances of this class.
# If this class has a discriminator attribute, filter the query to only return instances of this class.
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr and discriminator_attr.get_discriminator(cls):
filter_condition &= discriminator_attr == cls
if discriminator_attr:
filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls))

query_args = (hash_key,)
query_kwargs = dict(
Expand Down Expand Up @@ -640,10 +640,10 @@ def query(
else:
hash_key = cls._serialize_keys(hash_key)[0]

# If this class has a discriminator value, filter the query to only return instances of this class.
# If this class has a discriminator attribute, filter the query to only return instances of this class.
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr and discriminator_attr.get_discriminator(cls):
filter_condition &= discriminator_attr == cls
if discriminator_attr:
filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls))

if page_size is None:
page_size = limit
Expand Down Expand Up @@ -697,10 +697,10 @@ def scan(
:param rate_limit: If set then consumed capacity will be limited to this amount per second
:param attributes_to_get: If set, specifies the properties to include in the projection expression
"""
# If this class has a discriminator value, filter the scan to only return instances of this class.
# If this class has a discriminator attribute, filter the scan to only return instances of this class.
discriminator_attr = cls._get_discriminator_attribute()
if discriminator_attr and discriminator_attr.get_discriminator(cls):
filter_condition &= discriminator_attr == cls
if discriminator_attr:
filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls))

if page_size is None:
page_size = limit
Expand Down
9 changes: 8 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1553,6 +1553,10 @@ class Meta:
class ChildModel(ParentModel, discriminator='Child'):
foo = UnicodeAttribute()

# register a model that subclasses Child to ensure queries return model subclasses
class GrandchildModel(ChildModel, discriminator='Grandchild'):
bar = UnicodeAttribute()

with patch(PATCH_METHOD) as req:
req.return_value = {
"Table": {
Expand Down Expand Up @@ -1588,7 +1592,7 @@ class ChildModel(ParentModel, discriminator='Child'):
pass
params = {
'KeyConditionExpression': '#0 = :0',
'FilterExpression': '#1 = :1',
'FilterExpression': '#1 IN (:1, :2)',
'ExpressionAttributeNames': {
'#0': 'id',
'#1': 'cls'
Expand All @@ -1599,6 +1603,9 @@ class ChildModel(ParentModel, discriminator='Child'):
},
':1': {
'S': u'Child'
},
':2': {
'S': u'Grandchild'
}
},
'ReturnConsumedCapacity': 'TOTAL',
Expand Down

0 comments on commit ec56f95

Please sign in to comment.