Skip to content

Commit

Permalink
Model.batch_get: add guard-rails (#1184)
Browse files Browse the repository at this point in the history
For models with a range key, fail if:
- item is a `str` ("accidental" iterable)
- item is an iterable with != 2 items
  • Loading branch information
ikonst authored May 26, 2023
1 parent 0cf2e94 commit 12e127f
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 6 deletions.
18 changes: 13 additions & 5 deletions pynamodb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,15 +357,23 @@ def batch_get(
keys_to_get = []
item = items.pop()
if range_key_attribute:
hash_key, range_key = cls._serialize_keys(item[0], item[1]) # type: ignore
if isinstance(item, str):
raise ValueError(f'Invalid key value {item!r}: '
'expected non-str iterable with exactly 2 elements (hash key, range key)')
try:
hash_key, range_key = item
except (TypeError, ValueError):
raise ValueError(f'Invalid key value {item!r}: '
'expected iterable with exactly 2 elements (hash key, range key)')
hash_key_ser, range_key_ser = cls._serialize_keys(hash_key, range_key)
keys_to_get.append({
hash_key_attribute.attr_name: hash_key,
range_key_attribute.attr_name: range_key
hash_key_attribute.attr_name: hash_key_ser,
range_key_attribute.attr_name: range_key_ser,
})
else:
hash_key = cls._serialize_keys(item)[0]
hash_key_ser, _ = cls._serialize_keys(item)
keys_to_get.append({
hash_key_attribute.attr_name: hash_key
hash_key_attribute.attr_name: hash_key_ser
})

while keys_to_get:
Expand Down
52 changes: 51 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import base64
import json
import copy
import re
from datetime import datetime
from datetime import timedelta
from datetime import timezone
Expand Down Expand Up @@ -1845,7 +1846,6 @@ def test_batch_get(self):
}
self.assertEqual(params, req.call_args[0][1])


with patch(PATCH_METHOD) as req:
item_keys = [('hash-{}'.format(x), '{}'.format(x)) for x in range(10)]
item_keys_copy = list(item_keys)
Expand Down Expand Up @@ -1906,6 +1906,56 @@ def fake_batch_get(*batch_args):
for item in UserModel.batch_get(item_keys):
self.assertIsNotNone(item)

def test_batch_get__range_key(self):
with patch(PATCH_METHOD) as req:
req.return_value = {
'UnprocessedKeys': {},
'Responses': {
'UserModel': [],
}
}
items = [(f'hash-{x}', f'range-{x}') for x in range(10)]
_ = list(UserModel.batch_get(items))

actual_keys = req.call_args[0][1]['RequestItems']['UserModel']['Keys']
actual_keys.sort(key=json.dumps)
assert actual_keys == [
{'user_name': {'S': f'hash-{x}'}, 'user_id': {'S': f'range-{x}'}}
for x in range(10)
]

def test_batch_get__range_key__invalid__string(self):
with patch(PATCH_METHOD) as req:
req.return_value = {
'UnprocessedKeys': {},
'Responses': {
'UserModel': [],
}
}
with pytest.raises(
ValueError,
match=re.escape(
"Invalid key value 'ab': expected non-str iterable with exactly 2 elements (hash key, range key)"
)
):
_ = list(UserModel.batch_get(['ab']))

def test_batch_get__range_key__invalid__3_elements(self):
with patch(PATCH_METHOD) as req:
req.return_value = {
'UnprocessedKeys': {},
'Responses': {
'UserModel': [],
}
}
with pytest.raises(
ValueError,
match=re.escape(
"Invalid key value ('a', 'b', 'c'): expected iterable with exactly 2 elements (hash key, range key)"
)
):
_ = list(UserModel.batch_get([('a', 'b', 'c')]))

def test_batch_write(self):
"""
Model.batch_write
Expand Down

0 comments on commit 12e127f

Please sign in to comment.