From 12d8fd23025c4b6a1f4ec2a98b14d2f8277e1782 Mon Sep 17 00:00:00 2001 From: Eric Muller Date: Mon, 5 Aug 2019 16:46:14 -0500 Subject: [PATCH] Support for versioned optimistic locking a la DynamoDBMapper (#664) --- AUTHORS.rst | 3 +- docs/optimistic_locking.rst | 172 +++++++++++++++++ examples/optimistic_locking.py | 140 ++++++++++++++ pynamodb/attributes.py | 32 ++++ pynamodb/attributes.pyi | 5 + pynamodb/connection/base.py | 2 +- pynamodb/models.py | 74 +++++++- pynamodb/transactions.py | 8 +- tests/data.py | 35 ++++ tests/integration/model_integration_test.py | 6 +- .../test_transaction_integration.py | 93 +++++++++- tests/test_attributes.py | 17 +- tests/test_model.py | 174 +++++++++++++++++- tests/test_transaction.py | 17 +- 14 files changed, 759 insertions(+), 19 deletions(-) create mode 100644 docs/optimistic_locking.rst create mode 100644 examples/optimistic_locking.py diff --git a/AUTHORS.rst b/AUTHORS.rst index f914fc0a5..dc8ede754 100644 --- a/AUTHORS.rst +++ b/AUTHORS.rst @@ -3,4 +3,5 @@ PynamoDB is written and maintained by Jharrod LaFon and numerous contributors: * Craig Bruce * Adam Chainz * Andy Wolfe -* Pior Bastida \ No newline at end of file +* Pior Bastida +* Eric Muller \ No newline at end of file diff --git a/docs/optimistic_locking.rst b/docs/optimistic_locking.rst new file mode 100644 index 000000000..93b66f63d --- /dev/null +++ b/docs/optimistic_locking.rst @@ -0,0 +1,172 @@ +================== +Optimistic Locking +================== + +Optimistic Locking is a strategy for ensuring that your database writes are not overwritten by the writes of others. +With optimistic locking, each item has an attribute that acts as a version number. If you retrieve an item from a +table, the application records the version number of that item. You can update the item, but only if the version number +on the server side has not changed. If there is a version mismatch, it means that someone else has modified the item +before you did. The update attempt fails, because you have a stale version of the item. If this happens, you simply +try again by retrieving the item and then trying to update it. Optimistic locking prevents you from accidentally +overwriting changes that were made by others. It also prevents others from accidentally overwriting your changes. + +.. warning:: - Optimistic locking will not work properly if you use DynamoDB global tables as they use last-write-wins for concurrent updates. + +See also: +`DynamoDBMapper Documentation on Optimistic Locking `_. + +Version Attribute +----------------- + +To enable optimistic locking for a table simply add a ``VersionAttribute`` to your model definition. + +.. code-block:: python + + class OfficeEmployeeMap(MapAttribute): + office_employee_id = UnicodeAttribute() + person = UnicodeAttribute() + + def __eq__(self, other): + return isinstance(other, OfficeEmployeeMap) and self.person == other.person + + def __repr__(self): + return str(vars(self)) + + + class Office(Model): + class Meta: + read_capacity_units = 1 + write_capacity_units = 1 + table_name = 'Office' + host = "http://localhost:8000" + office_id = UnicodeAttribute(hash_key=True) + employees = ListAttribute(of=OfficeEmployeeMap) + name = UnicodeAttribute() + version = VersionAttribute() + +The attribute is underpinned by an integer which is initialized with 1 when an item is saved for the first time +and is incremented by 1 with each subsequent write operation. + +.. code-block:: python + + justin = OfficeEmployeeMap(office_employee_id=str(uuid4()), person='justin') + garrett = OfficeEmployeeMap(office_employee_id=str(uuid4()), person='garrett') + office = Office(office_id=str(uuid4()), name="office", employees=[justin, garrett]) + office.save() + assert office.version == 1 + + # Get a second local copy of Office + office_out_of_date = Office.get(office.office_id) + + # Add another employee and persist the change. + office.employees.append(OfficeEmployeeMap(office_employee_id=str(uuid4()), person='lita')) + office.save() + # On subsequent save or update operations the version is also incremented locally to match the persisted value so + # there's no need to refresh between operations when reusing the local copy. + assert office.version == 2 + assert office_out_of_date.version == 1 + +The version checking is implemented using DynamoDB conditional write constraints, asserting that no value exists +for the version attribute on the initial save and that the persisted value matches the local value on subsequent writes. + + +Model.{update, save, delete} +---------------------------- +These operations will fail if the local object is out-of-date. + +.. code-block:: python + + @contextmanager + def assert_condition_check_fails(): + try: + yield + except (PutError, UpdateError, DeleteError) as e: + assert isinstance(e.cause, ClientError) + assert e.cause_response_code == "ConditionalCheckFailedException" + except TransactWriteError as e: + assert isinstance(e.cause, ClientError) + assert e.cause_response_code == "TransactionCanceledException" + assert "ConditionalCheckFailed" in e.cause_response_message + else: + raise AssertionError("The version attribute conditional check should have failed.") + + + with assert_condition_check_fails(): + office_out_of_date.update(actions=[Office.name.set('new office name')]) + + office_out_of_date.employees.remove(garrett) + with assert_condition_check_fails(): + office_out_of_date.save() + + # After refreshing the local copy our write operations succeed. + office_out_of_date.refresh() + office_out_of_date.employees.remove(garrett) + office_out_of_date.save() + assert office_out_of_date.version == 3 + + with assert_condition_check_fails(): + office.delete() + +Transactions +------------ + +Transactions are supported. + +Successful +__________ + +.. code-block:: python + + connection = Connection(host='http://localhost:8000') + + office2 = Office(office_id=str(uuid4()), name="second office", employees=[justin]) + office2.save() + assert office2.version == 1 + office3 = Office(office_id=str(uuid4()), name="third office", employees=[garrett]) + office3.save() + assert office3.version == 1 + + with TransactWrite(connection=connection) as transaction: + transaction.condition_check(Office, office.office_id, condition=(Office.name.exists())) + transaction.delete(office2) + transaction.save(Office(office_id=str(uuid4()), name="new office", employees=[justin, garrett])) + transaction.update( + office3, + actions=[ + Office.name.set('birdistheword'), + ] + ) + + try: + office2.refresh() + except DoesNotExist: + pass + else: + raise AssertionError( + 'Office with office_id="{}" should have been deleted in the transaction.' + .format(office2.office_id) + ) + + assert office.version == 2 + assert office3.version == 2 + +Failed +______ + +.. code-block:: python + + with assert_condition_check_fails(), TransactWrite(connection=connection) as transaction: + transaction.save(Office(office.office_id, name='newer name', employees=[])) + + with assert_condition_check_fails(), TransactWrite(connection=connection) as transaction: + transaction.update( + Office(office.office_id, name='newer name', employees=[]), + actions=[Office.name.set('Newer Office Name')] + ) + + with assert_condition_check_fails(), TransactWrite(connection=connection) as transaction: + transaction.delete(Office(office.office_id, name='newer name', employees=[])) + +Batch Operations +---------------- +*Unsupported* as they do not support conditional writes. diff --git a/examples/optimistic_locking.py b/examples/optimistic_locking.py new file mode 100644 index 000000000..d80e9833a --- /dev/null +++ b/examples/optimistic_locking.py @@ -0,0 +1,140 @@ +from contextlib import contextmanager +from uuid import uuid4 +from botocore.client import ClientError + +from pynamodb.connection import Connection +from pynamodb.attributes import ListAttribute, MapAttribute, UnicodeAttribute, VersionAttribute +from pynamodb.exceptions import PutError, UpdateError, TransactWriteError, DeleteError, DoesNotExist +from pynamodb.models import Model +from pynamodb.transactions import TransactWrite + + +class OfficeEmployeeMap(MapAttribute): + office_employee_id = UnicodeAttribute() + person = UnicodeAttribute() + + def __eq__(self, other): + return isinstance(other, OfficeEmployeeMap) and self.person == other.person + + def __repr__(self): + return str(vars(self)) + + +class Office(Model): + class Meta: + read_capacity_units = 1 + write_capacity_units = 1 + table_name = 'Office' + host = "http://localhost:8000" + office_id = UnicodeAttribute(hash_key=True) + employees = ListAttribute(of=OfficeEmployeeMap) + name = UnicodeAttribute() + version = VersionAttribute() + + +if not Office.exists(): + Office.create_table(wait=True) + + +@contextmanager +def assert_condition_check_fails(): + try: + yield + except (PutError, UpdateError, DeleteError) as e: + assert isinstance(e.cause, ClientError) + assert e.cause_response_code == "ConditionalCheckFailedException" + except TransactWriteError as e: + assert isinstance(e.cause, ClientError) + assert e.cause_response_code == "TransactionCanceledException" + assert "ConditionalCheckFailed" in e.cause_response_message + else: + raise AssertionError("The version attribute conditional check should have failed.") + + +justin = OfficeEmployeeMap(office_employee_id=str(uuid4()), person='justin') +garrett = OfficeEmployeeMap(office_employee_id=str(uuid4()), person='garrett') +office = Office(office_id=str(uuid4()), name="office 3", employees=[justin, garrett]) +office.save() +assert office.version == 1 + +# Get a second local copy of Office +office_out_of_date = Office.get(office.office_id) +# Add another employee and save the changes. +office.employees.append(OfficeEmployeeMap(office_employee_id=str(uuid4()), person='lita')) +office.save() +# After a successful save or update operation the version is set or incremented locally so there's no need to refresh +# between operations using the same local copy. +assert office.version == 2 +assert office_out_of_date.version == 1 + +# Condition check fails for update. +with assert_condition_check_fails(): + office_out_of_date.update(actions=[Office.name.set('new office name')]) + +# Condition check fails for save. +office_out_of_date.employees.remove(garrett) +with assert_condition_check_fails(): + office_out_of_date.save() + +# After refreshing the local copy the operation will succeed. +office_out_of_date.refresh() +office_out_of_date.employees.remove(garrett) +office_out_of_date.save() +assert office_out_of_date.version == 3 + +# Condition check fails for delete. +with assert_condition_check_fails(): + office.delete() + +# Example failed transactions. +connection = Connection(host='http://localhost:8000') + +with assert_condition_check_fails(), TransactWrite(connection=connection) as transaction: + transaction.save(Office(office.office_id, name='newer name', employees=[])) + +with assert_condition_check_fails(), TransactWrite(connection=connection) as transaction: + transaction.update( + Office(office.office_id, name='newer name', employees=[]), + actions=[ + Office.name.set('Newer Office Name'), + ] + ) + +with assert_condition_check_fails(), TransactWrite(connection=connection) as transaction: + transaction.delete(Office(office.office_id, name='newer name', employees=[])) + +# Example successful transaction. +office2 = Office(office_id=str(uuid4()), name="second office", employees=[justin]) +office2.save() +assert office2.version == 1 +office3 = Office(office_id=str(uuid4()), name="third office", employees=[garrett]) +office3.save() +assert office3.version == 1 + +with TransactWrite(connection=connection) as transaction: + transaction.condition_check(Office, office.office_id, condition=(Office.name.exists())) + transaction.delete(office2) + transaction.save(Office(office_id=str(uuid4()), name="new office", employees=[justin, garrett])) + transaction.update( + office3, + actions=[ + Office.name.set('birdistheword'), + ] + ) + +try: + office2.refresh() +except DoesNotExist: + pass +else: + raise AssertionError( + "This item should have been deleted, but no DoesNotExist " + "exception was raised when attempting to refresh a local copy." + ) + +assert office.version == 2 +# The version attribute of items which are saved or updated in a transaction are updated automatically to match the +# persisted value. +assert office3.version == 2 +office.refresh() +assert office.version == 3 diff --git a/pynamodb/attributes.py b/pynamodb/attributes.py index 482628160..4d23557da 100644 --- a/pynamodb/attributes.py +++ b/pynamodb/attributes.py @@ -494,6 +494,38 @@ def deserialize(self, value): return json.loads(value) +class VersionAttribute(NumberAttribute): + """ + A version attribute + """ + null = True + + def __set__(self, instance, value): + """ + Cast assigned value to int. + """ + super(VersionAttribute, self).__set__(instance, int(value)) + + def __get__(self, instance, owner): + """ + Cast retrieved value to int. + """ + val = super(VersionAttribute, self).__get__(instance, owner) + return int(val) if isinstance(val, float) else val + + def serialize(self, value): + """ + Cast value to int then encode as JSON + """ + return super(VersionAttribute, self).serialize(int(value)) + + def deserialize(self, value): + """ + Decode numbers from JSON and cast to int. + """ + return int(super(VersionAttribute, self).deserialize(value)) + + class TTLAttribute(Attribute): """ A time-to-live attribute that signifies when the item expires and can be automatically deleted. diff --git a/pynamodb/attributes.pyi b/pynamodb/attributes.pyi index e168ee97f..9937d3ec6 100644 --- a/pynamodb/attributes.pyi +++ b/pynamodb/attributes.pyi @@ -119,6 +119,11 @@ class NumberAttribute(Attribute[float]): @overload def __get__(self, instance: Any, owner: Any) -> float: ... +class VersionAttribute(NumberAttribute): + @overload + def __get__(self: _A, instance: None, owner: Any) -> _A: ... + @overload + def __get__(self, instance: Any, owner: Any) -> int: ... class TTLAttribute(Attribute[datetime]): @overload diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 58a1704ef..1123edbba 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -849,7 +849,7 @@ def get_operation_kwargs(self, operation_kwargs[TABLE_NAME] = table_name operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key, key=key)) - if attributes: + if attributes and operation_kwargs.get(ITEM) is not None: attrs = self.get_item_attribute_map(table_name, attributes) operation_kwargs[ITEM].update(attrs[ITEM]) if attributes_to_get is not None: diff --git a/pynamodb/models.py b/pynamodb/models.py index 18329fa64..6dd7fe764 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -9,8 +9,12 @@ from inspect import getmembers from six import add_metaclass + +from pynamodb.expressions.condition import NotExists, Comparison from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError -from pynamodb.attributes import Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, TTLAttribute +from pynamodb.attributes import ( + Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, TTLAttribute, VersionAttribute +) from pynamodb.connection.table import TableConnection from pynamodb.connection.util import pythonic from pynamodb.types import HASH, RANGE @@ -172,6 +176,13 @@ def __init__(cls, name, bases, attrs): cls._hash_keyname = attr_name if attribute.is_range_key: cls._range_keyname = attr_name + if isinstance(attribute, VersionAttribute): + if cls._version_attribute_name: + raise ValueError( + "The model has more than one Version attribute: {}, {}" + .format(cls._version_attribute_name, attr_name) + ) + cls._version_attribute_name = attr_name if isinstance(attrs, dict): for attr_name, attr_obj in attrs.items(): if attr_name == META_CLASS_NAME: @@ -238,6 +249,7 @@ class Model(AttributeContainer): _connection = None _index_classes = None DoesNotExist = DoesNotExist + _version_attribute_name = None def __init__(self, hash_key=None, range_key=None, _user_instantiated=True, **attributes): """ @@ -334,6 +346,10 @@ def delete(self, condition=None): Deletes this object from dynamodb """ args, kwargs = self._get_save_args(attributes=False, null_check=False) + version_condition = self._handle_version_attribute(kwargs) + if version_condition is not None: + condition &= version_condition + kwargs.update(condition=condition) return self._get_connection().delete_item(*args, **kwargs) @@ -348,6 +364,9 @@ def update(self, actions, condition=None): raise TypeError("the value of `actions` is expected to be a non-empty list") args, save_kwargs = self._get_save_args(null_check=False) + version_condition = self._handle_version_attribute(save_kwargs, actions=actions) + if version_condition is not None: + condition &= version_condition kwargs = { pythonic(RETURN_VALUES): ALL_NEW, } @@ -371,8 +390,13 @@ def save(self, condition=None): Save this object to dynamodb """ args, kwargs = self._get_save_args() + version_condition = self._handle_version_attribute(serialized_attributes=kwargs) + if version_condition is not None: + condition &= version_condition kwargs.update(condition=condition) - return self._get_connection().put_item(*args, **kwargs) + data = self._get_connection().put_item(*args, **kwargs) + self.update_local_version_attribute() + return data def refresh(self, consistent_read=False): """ @@ -394,7 +418,16 @@ def get_operation_kwargs_from_instance(self, condition=None, return_values_on_condition_failure=None): is_update = actions is not None + is_delete = actions is None and key is KEY args, save_kwargs = self._get_save_args(null_check=not is_update) + + version_condition = self._handle_version_attribute( + serialized_attributes={} if is_delete else save_kwargs, + actions=actions + ) + if version_condition is not None: + condition &= version_condition + kwargs = dict( key=key, actions=actions, @@ -881,6 +914,43 @@ def _get_save_args(self, attributes=True, null_check=True): kwargs[pythonic(ATTRIBUTES)] = serialized[pythonic(ATTRIBUTES)] return args, kwargs + def _handle_version_attribute(self, serialized_attributes, actions=None): + """ + Handles modifying the request to set or increment the version attribute. + + :param serialized_attributes: A dictionary mapping attribute names to serialized values. + :param actions: A non-empty list when performing an update, otherwise None. + """ + if self._version_attribute_name is None: + return + + version_attribute = self.get_attributes()[self._version_attribute_name] + version_attribute_value = getattr(self, self._version_attribute_name) + + if version_attribute_value: + version_condition = version_attribute == version_attribute_value + if actions: + actions.append(version_attribute.add(1)) + elif pythonic(ATTRIBUTES) in serialized_attributes: + serialized_attributes[pythonic(ATTRIBUTES)][version_attribute.attr_name] = self._serialize_value( + version_attribute, version_attribute_value + 1, null_check=True + ) + else: + version_condition = version_attribute.does_not_exist() + if actions: + actions.append(version_attribute.set(1)) + elif pythonic(ATTRIBUTES) in serialized_attributes: + serialized_attributes[pythonic(ATTRIBUTES)][version_attribute.attr_name] = self._serialize_value( + version_attribute, 1, null_check=True + ) + + return version_condition + + def update_local_version_attribute(self): + if self._version_attribute_name: + value = getattr(self, self._version_attribute_name, None) or 0 + setattr(self, self._version_attribute_name, value + 1) + @classmethod def _hash_key_attribute(cls): """ diff --git a/pynamodb/transactions.py b/pynamodb/transactions.py index 331a526d8..9d19e72ee 100644 --- a/pynamodb/transactions.py +++ b/pynamodb/transactions.py @@ -72,6 +72,7 @@ def __init__(self, client_request_token=None, return_item_collection_metrics=Non self._delete_items = [] self._put_items = [] self._update_items = [] + self._models_for_version_attribute_update = [] def condition_check(self, model_cls, hash_key, range_key=None, condition=None): if condition is None: @@ -94,6 +95,7 @@ def save(self, model, condition=None, return_values=None): return_values_on_condition_failure=return_values ) self._put_items.append(operation_kwargs) + self._models_for_version_attribute_update.append(model) def update(self, model, actions, condition=None, return_values=None): operation_kwargs = model.get_operation_kwargs_from_instance( @@ -102,9 +104,10 @@ def update(self, model, actions, condition=None, return_values=None): return_values_on_condition_failure=return_values ) self._update_items.append(operation_kwargs) + self._models_for_version_attribute_update.append(model) def _commit(self): - return self._connection.transact_write_items( + response = self._connection.transact_write_items( condition_check_items=self._condition_check_items, delete_items=self._delete_items, put_items=self._put_items, @@ -113,3 +116,6 @@ def _commit(self): return_consumed_capacity=self._return_consumed_capacity, return_item_collection_metrics=self._return_item_collection_metrics, ) + for model in self._models_for_version_attribute_update: + model.update_local_version_attribute() + return response diff --git a/tests/data.py b/tests/data.py index 7538a1006..b81287a09 100644 --- a/tests/data.py +++ b/tests/data.py @@ -1412,3 +1412,38 @@ "TableStatus": "ACTIVE" } } + +VERSIONED_TABLE_DATA = { + "Table": { + "AttributeDefinitions": [ + { + "AttributeName": "name", + "AttributeType": "S" + }, + { + "AttributeName": "email", + "AttributeType": "S" + }, + { + "AttributeName": "version", + "AttributeType": "N" + } + ], + "CreationDateTime": 1.363729002358E9, + "ItemCount": 42, + "KeySchema": [ + { + "AttributeName": "name", + "KeyType": "HASH" + }, + ], + "ProvisionedThroughput": { + "NumberOfDecreasesToday": 0, + "ReadCapacityUnits": 5, + "WriteCapacityUnits": 5 + }, + "TableName": "VersionedModel", + "TableSizeBytes": 0, + "TableStatus": "ACTIVE" + } +} diff --git a/tests/integration/model_integration_test.py b/tests/integration/model_integration_test.py index 222750e0f..92649e295 100644 --- a/tests/integration/model_integration_test.py +++ b/tests/integration/model_integration_test.py @@ -5,8 +5,8 @@ from pynamodb.models import Model from pynamodb.indexes import GlobalSecondaryIndex, AllProjection, LocalSecondaryIndex from pynamodb.attributes import ( - UnicodeAttribute, BinaryAttribute, UTCDateTimeAttribute, NumberSetAttribute, NumberAttribute -) + UnicodeAttribute, BinaryAttribute, UTCDateTimeAttribute, NumberSetAttribute, NumberAttribute, + VersionAttribute) import pytest @@ -51,6 +51,7 @@ class Meta: epoch = UTCDateTimeAttribute(default=datetime.now) content = BinaryAttribute(null=True) scores = NumberSetAttribute() + version = VersionAttribute() if not TestModel.exists(): print("Creating table") @@ -100,3 +101,4 @@ class Meta: print("Item queried from index: {}".format(item.view)) print(query_obj.update([TestModel.view.add(1)], condition=TestModel.forum.exists())) + TestModel.delete_table() diff --git a/tests/integration/test_transaction_integration.py b/tests/integration/test_transaction_integration.py index 518527eee..ace251fec 100644 --- a/tests/integration/test_transaction_integration.py +++ b/tests/integration/test_transaction_integration.py @@ -6,7 +6,10 @@ from pynamodb.connection import Connection from pynamodb.exceptions import DoesNotExist, TransactWriteError, TransactGetError, InvalidStateError -from pynamodb.attributes import NumberAttribute, UnicodeAttribute, UTCDateTimeAttribute, BooleanAttribute + +from pynamodb.attributes import ( + NumberAttribute, UnicodeAttribute, UTCDateTimeAttribute, BooleanAttribute, VersionAttribute +) from pynamodb.transactions import TransactGet, TransactWrite from pynamodb.models import Model @@ -59,11 +62,22 @@ class Meta: entry_index = NumberAttribute(hash_key=True) +class Foo(Model): + class Meta: + region = 'us-east-1' + table_name = 'foo' + + bar = NumberAttribute(hash_key=True) + star = UnicodeAttribute(null=True) + version = VersionAttribute() + + TEST_MODELS = [ BankStatement, DifferentRegion, LineItem, User, + Foo ] @@ -271,3 +285,80 @@ def test_transact_write__one_of_each(connection): statement.refresh() assert not statement.active assert statement.balance == 0 + + +@pytest.mark.ddblocal +def test_transaction_write_with_version_attribute(connection): + foo1 = Foo(1) + foo1.save() + foo2 = Foo(2, star='bar') + foo2.save() + foo3 = Foo(3) + foo3.save() + + with TransactWrite(connection=connection) as transaction: + transaction.condition_check(Foo, 1, condition=(Foo.bar.exists())) + transaction.delete(foo2) + transaction.save(Foo(4)) + transaction.update( + foo3, + actions=[ + Foo.star.set('birdistheword'), + ] + ) + + assert Foo.get(1).version == 1 + with pytest.raises(DoesNotExist): + Foo.get(2) + # Local object's version attribute is updated automatically. + assert foo3.version == 2 + assert Foo.get(4).version == 1 + + +@pytest.mark.ddblocal +def test_transaction_get_with_version_attribute(connection): + Foo(11).save() + Foo(12, star='bar').save() + + with TransactGet(connection=connection) as transaction: + foo1_future = transaction.get(Foo, 11) + foo2_future = transaction.get(Foo, 12) + + foo1 = foo1_future.get() + assert foo1.version == 1 + foo2 = foo2_future.get() + assert foo2.version == 1 + assert foo2.star == 'bar' + + +@pytest.mark.ddblocal +def test_transaction_write_with_version_attribute_condition_failure(connection): + foo = Foo(21) + foo.save() + + foo2 = Foo(21) + + with pytest.raises(TransactWriteError) as exc_info: + with TransactWrite(connection=connection) as transaction: + transaction.save(Foo(21)) + assert get_error_code(exc_info.value) == TRANSACTION_CANCELLED + assert 'ConditionalCheckFailed' in get_error_message(exc_info.value) + + with pytest.raises(TransactWriteError) as exc_info: + with TransactWrite(connection=connection) as transaction: + transaction.update( + foo2, + actions=[ + Foo.star.set('birdistheword'), + ] + ) + assert get_error_code(exc_info.value) == TRANSACTION_CANCELLED + assert 'ConditionalCheckFailed' in get_error_message(exc_info.value) + # Version attribute is not updated on failure. + assert foo2.version is None + + with pytest.raises(TransactWriteError) as exc_info: + with TransactWrite(connection=connection) as transaction: + transaction.delete(foo2) + assert get_error_code(exc_info.value) == TRANSACTION_CANCELLED + assert 'ConditionalCheckFailed' in get_error_message(exc_info.value) diff --git a/tests/test_attributes.py b/tests/test_attributes.py index 4fc69ae2d..035df2723 100644 --- a/tests/test_attributes.py +++ b/tests/test_attributes.py @@ -18,7 +18,7 @@ BinarySetAttribute, BinaryAttribute, NumberSetAttribute, NumberAttribute, UnicodeAttribute, UnicodeSetAttribute, UTCDateTimeAttribute, BooleanAttribute, MapAttribute, ListAttribute, JSONAttribute, TTLAttribute, _get_value_for_deserialize, _fast_parse_utc_datestring, -) + VersionAttribute) from pynamodb.constants import ( DATETIME_FORMAT, DEFAULT_ENCODING, NUMBER, STRING, STRING_SET, NUMBER_SET, BINARY_SET, BINARY, BOOLEAN, @@ -1004,3 +1004,18 @@ def __eq__(self, other): assert deserialized == inp assert serialize_mock.call_args_list == [call(1), call(2)] assert deserialize_mock.call_args_list == [call('1'), call('2')] + + +class TestVersionAttribute: + def test_serialize(self): + attr = VersionAttribute() + assert attr.attr_type == NUMBER + assert attr.serialize(3.141) == '3' + assert attr.serialize(1) == '1' + assert attr.serialize(12345678909876543211234234324234) == '12345678909876543211234234324234' + + def test_deserialize(self): + attr = VersionAttribute() + assert attr.deserialize('1') == 1 + assert attr.deserialize('3.141') == 3 + assert attr.deserialize('12345678909876543211234234324234') == 12345678909876543211234234324234 diff --git a/tests/test_model.py b/tests/test_model.py index d6d94e55b..ea6b092f6 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -30,7 +30,7 @@ from pynamodb.attributes import ( UnicodeAttribute, NumberAttribute, BinaryAttribute, UTCDateTimeAttribute, UnicodeSetAttribute, NumberSetAttribute, BinarySetAttribute, MapAttribute, - BooleanAttribute, ListAttribute, TTLAttribute) + BooleanAttribute, ListAttribute, TTLAttribute, VersionAttribute) from .data import ( MODEL_TABLE_DATA, GET_MODEL_ITEM_DATA, SIMPLE_MODEL_TABLE_DATA, DESCRIBE_TABLE_DATA_PAY_PER_REQUEST, @@ -41,12 +41,13 @@ GET_OFFICE_EMPLOYEE_ITEM_DATA, GET_OFFICE_EMPLOYEE_ITEM_DATA_WITH_NULL, GROCERY_LIST_MODEL_TABLE_DATA, GET_GROCERY_LIST_ITEM_DATA, GET_OFFICE_ITEM_DATA, OFFICE_MODEL_TABLE_DATA, COMPLEX_MODEL_TABLE_DATA, COMPLEX_MODEL_ITEM_DATA, - CAR_MODEL_TABLE_DATA, FULL_CAR_MODEL_ITEM_DATA, CAR_MODEL_WITH_NULL_ITEM_DATA, INVALID_CAR_MODEL_WITH_NULL_ITEM_DATA, + CAR_MODEL_TABLE_DATA, FULL_CAR_MODEL_ITEM_DATA, CAR_MODEL_WITH_NULL_ITEM_DATA, + INVALID_CAR_MODEL_WITH_NULL_ITEM_DATA, BOOLEAN_MODEL_TABLE_DATA, BOOLEAN_MODEL_FALSE_ITEM_DATA, BOOLEAN_MODEL_TRUE_ITEM_DATA, TREE_MODEL_TABLE_DATA, TREE_MODEL_ITEM_DATA, EXPLICIT_RAW_MAP_MODEL_TABLE_DATA, EXPLICIT_RAW_MAP_MODEL_ITEM_DATA, - EXPLICIT_RAW_MAP_MODEL_AS_SUB_MAP_IN_TYPED_MAP_ITEM_DATA, EXPLICIT_RAW_MAP_MODEL_AS_SUB_MAP_IN_TYPED_MAP_TABLE_DATA -) + EXPLICIT_RAW_MAP_MODEL_AS_SUB_MAP_IN_TYPED_MAP_ITEM_DATA, EXPLICIT_RAW_MAP_MODEL_AS_SUB_MAP_IN_TYPED_MAP_TABLE_DATA, + VERSIONED_TABLE_DATA) if six.PY3: from unittest.mock import patch, MagicMock @@ -440,6 +441,15 @@ class Meta: my_ttl = TTLAttribute(default_for_new=timedelta(minutes=1)) +class VersionedModel(Model): + class Meta: + table_name = 'VersionedModel' + + name = UnicodeAttribute(hash_key=True) + email = UnicodeAttribute() + version = VersionAttribute() + + class ModelTestCase(TestCase): """ Tests for the models API @@ -2964,6 +2974,153 @@ def fake_dynamodb(*args, **kwargs): self.assert_dict_lists_equal(actual['AttributeDefinitions'], DOG_TABLE_DATA['Table']['AttributeDefinitions']) + def test_model_version_attribute_save(self): + self.init_table_meta(VersionedModel, VERSIONED_TABLE_DATA) + item = VersionedModel('test_user_name', email='test_user@email.com') + + with patch(PATCH_METHOD) as req: + req.return_value = {} + item.save() + args = req.call_args[0][1] + params = { + 'Item': { + 'name': { + 'S': 'test_user_name' + }, + 'email': { + 'S': 'test_user@email.com' + }, + 'version': { + 'N': '1' + }, + }, + 'ReturnConsumedCapacity': 'TOTAL', + 'TableName': 'VersionedModel', + 'ConditionExpression': 'attribute_not_exists (#0)', + 'ExpressionAttributeNames': {'#0': 'version'}, + } + + deep_eq(args, params, _assert=True) + item.version = 1 + item.name = "test_new_username" + item.save() + args = req.call_args[0][1] + + params = { + 'Item': { + 'name': { + 'S': 'test_new_username' + }, + 'email': { + 'S': 'test_user@email.com' + }, + 'version': { + 'N': '2' + }, + }, + 'ReturnConsumedCapacity': 'TOTAL', + 'TableName': 'VersionedModel', + 'ConditionExpression': '#0 = :0', + 'ExpressionAttributeNames': {'#0': 'version'}, + 'ExpressionAttributeValues': {':0': {'N': '1'}} + } + + deep_eq(args, params, _assert=True) + + def test_version_attribute_increments_on_update(self): + self.init_table_meta(VersionedModel, VERSIONED_TABLE_DATA) + item = VersionedModel('test_user_name', email='test_user@email.com') + + with patch(PATCH_METHOD) as req: + req.return_value = { + ATTRIBUTES: { + 'name': { + 'S': 'test_user_name' + }, + 'email': { + 'S': 'new@email.com' + }, + 'version': { + 'N': '1' + }, + } + } + item.update(actions=[VersionedModel.email.set('new@email.com')]) + args = req.call_args[0][1] + params = { + 'ConditionExpression': 'attribute_not_exists (#0)', + 'ExpressionAttributeNames': { + '#0': 'version', + '#1': 'email' + }, + 'ExpressionAttributeValues': { + ':0': { + 'S': 'new@email.com' + }, + ':1': { + 'N': '1' + } + }, + 'Key': { + 'name': { + 'S': 'test_user_name' + } + }, + 'ReturnConsumedCapacity': 'TOTAL', + 'ReturnValues': 'ALL_NEW', + 'TableName': 'VersionedModel', + 'UpdateExpression': 'SET #1 = :0, #0 = :1' + } + + deep_eq(args, params, _assert=True) + assert item.version == 1 + + req.return_value = { + ATTRIBUTES: { + 'name': { + 'S': 'test_user_name' + }, + 'email': { + 'S': 'newer@email.com' + }, + 'version': { + 'N': '2' + }, + } + } + + item.update(actions=[VersionedModel.email.set('newer@email.com')]) + args = req.call_args[0][1] + params = { + 'ConditionExpression': '#0 = :0', + 'ExpressionAttributeNames': { + '#0': 'version', + '#1': 'email' + }, + 'ExpressionAttributeValues': { + ':0': { + 'N': '1' + }, + ':1': { + 'S': 'newer@email.com' + }, + ':2': { + 'N': '1' + } + }, + 'Key': { + 'name': { + 'S': 'test_user_name' + } + }, + 'ReturnConsumedCapacity': 'TOTAL', + 'ReturnValues': 'ALL_NEW', + 'TableName': 'VersionedModel', + 'UpdateExpression': 'SET #1 = :1 ADD #0 :2' + } + + deep_eq(args, params, _assert=True) + class ModelInitTestCase(TestCase): @@ -3075,3 +3232,12 @@ def test_deserialized_with_ttl(self): req.return_value = SIMPLE_MODEL_TABLE_DATA m = TTLModel.from_raw_data({'user_name': {'S': 'mock'}, 'my_ttl': {'N': '1546300800'}}) assert m.my_ttl == datetime(2019, 1, 1, tzinfo=tzutc()) + + def test_multiple_version_attributes(self): + with self.assertRaises(ValueError): + class BadVersionedModel(Model): + class Meta: + table_name = 'BadVersionedModel' + + version = VersionAttribute() + another_version = VersionAttribute() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 7b87b3a6b..33fead69d 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,6 +1,6 @@ import pytest import six -from pynamodb.attributes import NumberAttribute, UnicodeAttribute +from pynamodb.attributes import NumberAttribute, UnicodeAttribute, VersionAttribute from pynamodb.connection import Connection from pynamodb.transactions import Transaction, TransactGet, TransactWrite @@ -20,6 +20,7 @@ class Meta: mock_hash = NumberAttribute(hash_key=True) mock_range = NumberAttribute(range_key=True) mock_toot = UnicodeAttribute(null=True) + mock_version = VersionAttribute() MOCK_TABLE_DESCRIPTOR = { @@ -99,22 +100,26 @@ def test_commit(self, mocker): 'TableName': 'mock'} ] expected_deletes = [{ + 'ConditionExpression': 'attribute_not_exists (#0)', + 'ExpressionAttributeNames': {'#0': 'mock_version'}, 'Key': {'MockHash': {'N': '2'}, 'MockRange': {'N': '4'}}, 'TableName': 'mock' }] expected_puts = [{ - 'Item': {'MockHash': {'N': '3'}, 'MockRange': {'N': '5'}}, + 'ConditionExpression': 'attribute_not_exists (#0)', + 'ExpressionAttributeNames': {'#0': 'mock_version'}, + 'Item': {'MockHash': {'N': '3'}, 'MockRange': {'N': '5'}, 'mock_version': {'N': '1'}}, 'TableName': 'mock' }] expected_updates = [{ + 'ConditionExpression': 'attribute_not_exists (#0)', 'TableName': 'mock', 'Key': {'MockHash': {'N': '4'}, 'MockRange': {'N': '6'}}, 'ReturnValuesOnConditionCheckFailure': 'ALL_OLD', - 'UpdateExpression': 'SET #0 = :0', - 'ExpressionAttributeNames': {'#0': 'mock_toot'}, - 'ExpressionAttributeValues': {':0': {'S': 'hello'}} + 'UpdateExpression': 'SET #1 = :0, #0 = :1', + 'ExpressionAttributeNames': {'#0': 'mock_version', '#1': 'mock_toot'}, + 'ExpressionAttributeValues': {':0': {'S': 'hello'}, ':1': {'N': '1'}} }] - mock_connection_transact_write.assert_called_once_with( condition_check_items=expected_condition_checks, delete_items=expected_deletes,