From 78db9e2b6fd8c56c33c48f912c03a7ab564a2ba8 Mon Sep 17 00:00:00 2001 From: Hallie Lomax Date: Mon, 8 Jul 2019 12:18:06 -0700 Subject: [PATCH] Add support for transactions (#618) --- .gitignore | 1 + pynamodb/connection/base.py | 255 +++++++++++----- pynamodb/connection/base.pyi | 24 +- pynamodb/connection/table.py | 35 ++- pynamodb/connection/table.pyi | 16 + pynamodb/connection/transactions.py | 115 ++++++++ pynamodb/connection/transactions.pyi | 41 +++ pynamodb/constants.py | 16 + pynamodb/constants.pyi | 14 +- pynamodb/exceptions.py | 21 ++ pynamodb/exceptions.pyi | 4 + pynamodb/models.py | 65 ++++- pynamodb/models.pyi | 11 +- tests/integration/__init__.py | 0 tests/integration/conftest.py | 2 +- .../test_transaction_integration.py | 273 ++++++++++++++++++ tests/test_base_connection.py | 24 ++ tests/test_transaction.py | 126 ++++++++ 18 files changed, 955 insertions(+), 88 deletions(-) create mode 100644 pynamodb/connection/transactions.py create mode 100644 pynamodb/connection/transactions.pyi create mode 100644 tests/integration/__init__.py create mode 100644 tests/integration/test_transaction_integration.py create mode 100644 tests/test_transaction.py diff --git a/.gitignore b/.gitignore index 040eb66d5..63d070304 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,7 @@ venv pip-log.txt # Unit test / coverage reports +build/ .coverage cover/ .tox diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 63a776660..58a1704ef 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -5,7 +5,6 @@ import json import logging -import math import random import sys import time @@ -39,15 +38,19 @@ CONSUMED_CAPACITY, CAPACITY_UNITS, SHORT_ATTR_TYPES, ITEMS, DEFAULT_ENCODING, BINARY_SHORT, BINARY_SET_SHORT, LAST_EVALUATED_KEY, RESPONSES, UNPROCESSED_KEYS, - UNPROCESSED_ITEMS, STREAM_SPECIFICATION, STREAM_VIEW_TYPE, STREAM_ENABLED, UPDATE_EXPRESSION, + UNPROCESSED_ITEMS, STREAM_SPECIFICATION, STREAM_VIEW_TYPE, STREAM_ENABLED, EXPRESSION_ATTRIBUTE_NAMES, EXPRESSION_ATTRIBUTE_VALUES, CONDITION_EXPRESSION, FILTER_EXPRESSION, + TRANSACT_WRITE_ITEMS, TRANSACT_GET_ITEMS, CLIENT_REQUEST_TOKEN, TRANSACT_ITEMS, TRANSACT_CONDITION_CHECK, + TRANSACT_GET, TRANSACT_PUT, TRANSACT_DELETE, TRANSACT_UPDATE, UPDATE_EXPRESSION, + RETURN_VALUES_ON_CONDITION_FAILURE_VALUES, RETURN_VALUES_ON_CONDITION_FAILURE, AVAILABLE_BILLING_MODES, DEFAULT_BILLING_MODE, BILLING_MODE, PAY_PER_REQUEST_BILLING_MODE, - TIME_TO_LIVE_SPECIFICATION, ENABLED, UPDATE_TIME_TO_LIVE) + TIME_TO_LIVE_SPECIFICATION, ENABLED, UPDATE_TIME_TO_LIVE +) from pynamodb.exceptions import ( TableError, QueryError, PutError, DeleteError, UpdateError, GetError, ScanError, TableDoesNotExist, - VerboseClientError -) + VerboseClientError, + TransactGetError, TransactWriteError) from pynamodb.expressions.condition import Condition from pynamodb.expressions.operand import Path from pynamodb.expressions.projection import create_projection_expression @@ -473,10 +476,15 @@ def _handle_binary_attributes(data): for attr in six.itervalues(item): _convert_binary(attr) if RESPONSES in data: - for item_list in six.itervalues(data[RESPONSES]): - for item in item_list: + if isinstance(data[RESPONSES], list): + for item in data[RESPONSES]: for attr in six.itervalues(item): _convert_binary(attr) + else: + for item_list in six.itervalues(data[RESPONSES]): + for item in item_list: + for attr in six.itervalues(item): + _convert_binary(attr) if LAST_EVALUATED_KEY in data: for attr in six.itervalues(data[LAST_EVALUATED_KEY]): _convert_binary(attr) @@ -787,6 +795,19 @@ def get_return_values_map(self, return_values): RETURN_VALUES: str(return_values).upper() } + def get_return_values_on_condition_failure_map(self, return_values_on_condition_failure): + """ + Builds the return values map that is common to several operations + """ + if return_values_on_condition_failure.upper() not in RETURN_VALUES_VALUES: + raise ValueError("{} must be one of {}".format( + RETURN_VALUES_ON_CONDITION_FAILURE, + RETURN_VALUES_ON_CONDITION_FAILURE_VALUES + )) + return { + RETURN_VALUES_ON_CONDITION_FAILURE: str(return_values_on_condition_failure).upper() + } + def get_item_collection_map(self, return_item_collection_metrics): """ Builds the item collection map @@ -806,38 +827,79 @@ def get_exclusive_start_key_map(self, table_name, exclusive_start_key): raise TableError("No such table {}".format(table_name)) return tbl.get_exclusive_start_key_map(exclusive_start_key) - def delete_item(self, - table_name, - hash_key, - range_key=None, - condition=None, - return_values=None, - return_consumed_capacity=None, - return_item_collection_metrics=None): - """ - Performs the DeleteItem operation and returns the result - """ + def get_operation_kwargs(self, + table_name, + hash_key, + range_key=None, + key=KEY, + attributes=None, + attributes_to_get=None, + actions=None, + condition=None, + consistent_read=None, + return_values=None, + return_consumed_capacity=None, + return_item_collection_metrics=None, + return_values_on_condition_failure=None): self._check_condition('condition', condition) - operation_kwargs = {TABLE_NAME: table_name} - operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key)) + operation_kwargs = {} name_placeholders = {} expression_attribute_values = {} + operation_kwargs[TABLE_NAME] = table_name + operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key, key=key)) + if attributes: + attrs = self.get_item_attribute_map(table_name, attributes) + operation_kwargs[ITEM].update(attrs[ITEM]) + if attributes_to_get is not None: + projection_expression = create_projection_expression(attributes_to_get, name_placeholders) + operation_kwargs[PROJECTION_EXPRESSION] = projection_expression if condition is not None: condition_expression = condition.serialize(name_placeholders, expression_attribute_values) operation_kwargs[CONDITION_EXPRESSION] = condition_expression - if return_values: + if consistent_read is not None: + operation_kwargs[CONSISTENT_READ] = consistent_read + if return_values is not None: operation_kwargs.update(self.get_return_values_map(return_values)) - if return_consumed_capacity: + if return_values_on_condition_failure is not None: + operation_kwargs.update(self.get_return_values_on_condition_failure_map(return_values_on_condition_failure)) + if return_consumed_capacity is not None: operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) - if return_item_collection_metrics: + if return_item_collection_metrics is not None: operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) + if actions is not None: + update_expression = Update(*actions) + operation_kwargs[UPDATE_EXPRESSION] = update_expression.serialize( + name_placeholders, + expression_attribute_values + ) if name_placeholders: operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) if expression_attribute_values: operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values + return operation_kwargs + def delete_item(self, + table_name, + hash_key, + range_key=None, + condition=None, + return_values=None, + return_consumed_capacity=None, + return_item_collection_metrics=None): + """ + Performs the DeleteItem operation and returns the result + """ + operation_kwargs = self.get_operation_kwargs( + table_name, + hash_key, + range_key=range_key, + condition=condition, + return_values=return_values, + return_consumed_capacity=return_consumed_capacity, + return_item_collection_metrics=return_item_collection_metrics + ) try: return self.dispatch(DELETE_ITEM, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -855,33 +917,19 @@ def update_item(self, """ Performs the UpdateItem operation """ - self._check_condition('condition', condition) - - operation_kwargs = {TABLE_NAME: table_name} - operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key)) - name_placeholders = {} - expression_attribute_values = {} - - if condition is not None: - condition_expression = condition.serialize(name_placeholders, expression_attribute_values) - operation_kwargs[CONDITION_EXPRESSION] = condition_expression - if return_consumed_capacity: - operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) - if return_item_collection_metrics: - operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) - if return_values: - operation_kwargs.update(self.get_return_values_map(return_values)) if not actions: raise ValueError("'actions' cannot be empty") - update_expression = Update(*actions) - operation_kwargs[UPDATE_EXPRESSION] = update_expression.serialize(name_placeholders, expression_attribute_values) - - if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) - if expression_attribute_values: - operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values - + operation_kwargs = self.get_operation_kwargs( + table_name=table_name, + hash_key=hash_key, + range_key=range_key, + actions=actions, + condition=condition, + return_values=return_values, + return_consumed_capacity=return_consumed_capacity, + return_item_collection_metrics=return_item_collection_metrics + ) try: return self.dispatch(UPDATE_ITEM, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: @@ -899,34 +947,86 @@ def put_item(self, """ Performs the PutItem operation and returns the result """ - self._check_condition('condition', condition) - - operation_kwargs = {TABLE_NAME: table_name} - operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key, key=ITEM)) - name_placeholders = {} - expression_attribute_values = {} + operation_kwargs = self.get_operation_kwargs( + table_name=table_name, + hash_key=hash_key, + range_key=range_key, + key=ITEM, + attributes=attributes, + condition=condition, + return_values=return_values, + return_consumed_capacity=return_consumed_capacity, + return_item_collection_metrics=return_item_collection_metrics + ) + try: + return self.dispatch(PUT_ITEM, operation_kwargs) + except BOTOCORE_EXCEPTIONS as e: + raise PutError("Failed to put item: {}".format(e), e) - if attributes: - attrs = self.get_item_attribute_map(table_name, attributes) - operation_kwargs[ITEM].update(attrs[ITEM]) - if condition is not None: - condition_expression = condition.serialize(name_placeholders, expression_attribute_values) - operation_kwargs[CONDITION_EXPRESSION] = condition_expression - if return_consumed_capacity: + def _get_transact_operation_kwargs(self, + client_request_token=None, + return_consumed_capacity=None, + return_item_collection_metrics=None): + operation_kwargs = {} + if client_request_token is not None: + operation_kwargs[CLIENT_REQUEST_TOKEN] = client_request_token + if return_consumed_capacity is not None: operation_kwargs.update(self.get_consumed_capacity_map(return_consumed_capacity)) - if return_item_collection_metrics: + if return_item_collection_metrics is not None: operation_kwargs.update(self.get_item_collection_map(return_item_collection_metrics)) - if return_values: - operation_kwargs.update(self.get_return_values_map(return_values)) - if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) - if expression_attribute_values: - operation_kwargs[EXPRESSION_ATTRIBUTE_VALUES] = expression_attribute_values + + return operation_kwargs + + def transact_write_items(self, + condition_check_items, + delete_items, + put_items, + update_items, + client_request_token=None, + return_consumed_capacity=None, + return_item_collection_metrics=None): + """ + Performs the TransactWrite operation and returns the result + """ + transact_items = [] + transact_items.extend( + {TRANSACT_CONDITION_CHECK: item} for item in condition_check_items + ) + transact_items.extend( + {TRANSACT_DELETE: item} for item in delete_items + ) + transact_items.extend( + {TRANSACT_PUT: item} for item in put_items + ) + transact_items.extend( + {TRANSACT_UPDATE: item} for item in update_items + ) + + operation_kwargs = self._get_transact_operation_kwargs( + client_request_token=client_request_token, + return_consumed_capacity=return_consumed_capacity, + return_item_collection_metrics=return_item_collection_metrics + ) + operation_kwargs[TRANSACT_ITEMS] = transact_items try: - return self.dispatch(PUT_ITEM, operation_kwargs) + return self.dispatch(TRANSACT_WRITE_ITEMS, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: - raise PutError("Failed to put item: {}".format(e), e) + raise TransactWriteError("Failed to write transaction items", e) + + def transact_get_items(self, get_items, return_consumed_capacity=None): + """ + Performs the TransactGet operation and returns the result + """ + operation_kwargs = self._get_transact_operation_kwargs(return_consumed_capacity=return_consumed_capacity) + operation_kwargs[TRANSACT_ITEMS] = [ + {TRANSACT_GET: item} for item in get_items + ] + + try: + return self.dispatch(TRANSACT_GET_ITEMS, operation_kwargs) + except BOTOCORE_EXCEPTIONS as e: + raise TransactGetError("Failed to get transaction items", e) def batch_write_item(self, table_name, @@ -1014,16 +1114,13 @@ def get_item(self, """ Performs the GetItem operation and returns the result """ - operation_kwargs = {} - name_placeholders = {} - if attributes_to_get is not None: - projection_expression = create_projection_expression(attributes_to_get, name_placeholders) - operation_kwargs[PROJECTION_EXPRESSION] = projection_expression - if name_placeholders: - operation_kwargs[EXPRESSION_ATTRIBUTE_NAMES] = self._reverse_dict(name_placeholders) - operation_kwargs[CONSISTENT_READ] = consistent_read - operation_kwargs[TABLE_NAME] = table_name - operation_kwargs.update(self.get_identifier_map(table_name, hash_key, range_key)) + operation_kwargs = self.get_operation_kwargs( + table_name=table_name, + hash_key=hash_key, + range_key=range_key, + consistent_read=consistent_read, + attributes_to_get=attributes_to_get + ) try: return self.dispatch(GET_ITEM, operation_kwargs) except BOTOCORE_EXCEPTIONS as e: diff --git a/pynamodb/connection/base.pyi b/pynamodb/connection/base.pyi index 2cbe0befb..8abaa452b 100644 --- a/pynamodb/connection/base.pyi +++ b/pynamodb/connection/base.pyi @@ -1,4 +1,4 @@ -from typing import Any, Dict, Iterator, MutableMapping, Optional, Sequence, Text +from typing import Any, Dict, Iterator, MutableMapping, Optional, Sequence, Text, List import botocore.session from botocore.awsrequest import AWSPreparedRequest @@ -56,9 +56,27 @@ class Connection: def get_identifier_map(self, table_name: Text, hash_key, range_key: Optional[Any] = ..., key: Any = ...): ... def get_consumed_capacity_map(self, return_consumed_capacity): ... def get_return_values_map(self, return_values): ... + def get_return_values_on_condition_failure_map(self, return_values_on_condition_failure: str) -> Dict[str, str]: ... def get_item_collection_map(self, return_item_collection_metrics): ... def get_exclusive_start_key_map(self, table_name: Text, exclusive_start_key): ... + def get_operation_kwargs( + self, + table_name: Text, + hash_key: Any, + range_key: Optional[Any] = ..., + key: Text = ..., + attributes: Optional[Any] = ..., + attributes_to_get: Optional[Any] = ..., + actions: Optional[Sequence[Action]] = ..., + client_request_token: Optional[Text] = ..., + condition: Optional[Condition] = ..., + consistent_read: Optional[bool] = ..., + return_values: Optional[Any] = ..., + return_consumed_capacity: Optional[Any] = ..., + return_item_collection_metrics: Optional[Any] = ... + ) -> Dict: ... + def delete_item( self, table_name: Text, @@ -94,6 +112,10 @@ class Connection: return_item_collection_metrics: Optional[Any] = ... ) -> Dict: ... + def _get_transact_operation_kwargs(self, client_request_token: Optional[str] = ..., return_consumed_capacity: Optional[Any] = ..., return_item_collection_metrics: Optional[Any] = ...) -> Dict: ... + def transact_get_items(self, get_items: List[Dict], return_consumed_capacity: Optional[Any] = ...) -> Dict: ... + def transact_write_items(self, condition_check_items: List[Dict], delete_items: List[Dict], put_items: List[Dict], update_items: List[Dict], client_request_token: Optional[Text] = ..., return_consumed_capacity: Optional[Any] = ..., return_item_collection_metrics: Optional[Any] = ...) -> Dict: ... + def batch_write_item(self, table_name: Text, put_items: Optional[Any] = ..., delete_items: Optional[Any] = ..., return_consumed_capacity: Optional[Any] = ..., return_item_collection_metrics: Optional[Any] = ...): ... def batch_get_item(self, table_name: Text, keys, consistent_read: Optional[Any] = ..., return_consumed_capacity: Optional[Any] = ..., attributes_to_get: Optional[Any] = ...): ... def get_item(self, table_name: Text, hash_key, range_key: Optional[Any] = ..., consistent_read: bool = ..., attributes_to_get: Optional[Any] = ...): ... diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 2eec6e3fb..9b6d5c49b 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -3,7 +3,8 @@ ~~~~~~~~~~~~~~~~~~~~~~~~~~~ """ from pynamodb.connection.base import Connection -from pynamodb.constants import DEFAULT_BILLING_MODE +from pynamodb.constants import DEFAULT_BILLING_MODE, KEY + class TableConnection(object): """ @@ -44,6 +45,35 @@ def get_meta_table(self, refresh=False): """ return self.connection.get_meta_table(self.table_name, refresh=refresh) + def get_operation_kwargs(self, + hash_key, + range_key=None, + key=KEY, + attributes=None, + attributes_to_get=None, + actions=None, + condition=None, + consistent_read=None, + return_values=None, + return_consumed_capacity=None, + return_item_collection_metrics=None, + return_values_on_condition_failure=None): + return self.connection.get_operation_kwargs( + self.table_name, + hash_key, + range_key=range_key, + key=key, + attributes=attributes, + attributes_to_get=attributes_to_get, + actions=actions, + condition=condition, + consistent_read=consistent_read, + return_values=return_values, + return_consumed_capacity=return_consumed_capacity, + return_item_collection_metrics=return_item_collection_metrics, + return_values_on_condition_failure=return_values_on_condition_failure + ) + def delete_item(self, hash_key, range_key=None, @@ -85,7 +115,8 @@ def update_item(self, return_item_collection_metrics=return_item_collection_metrics, return_values=return_values) - def put_item(self, hash_key, + def put_item(self, + hash_key, range_key=None, attributes=None, condition=None, diff --git a/pynamodb/connection/table.pyi b/pynamodb/connection/table.pyi index e86f063ec..2d1c7c121 100644 --- a/pynamodb/connection/table.pyi +++ b/pynamodb/connection/table.pyi @@ -22,6 +22,22 @@ class TableConnection: aws_secret_access_key: Optional[str] = ..., ) -> None: ... + def get_operation_kwargs( + self, + hash_key, + range_key: Optional[Any] = ..., + key: Text = ..., + attributes: Optional[Any] = ..., + attributes_to_get: Optional[Any] = ..., + actions: Optional[Sequence[Action]] = ..., + condition: Optional[Condition] = ..., + consistent_read: bool = ..., + return_values: Optional[Any] = ..., + return_consumed_capacity: Optional[Any] = ..., + return_item_collection_metrics: Optional[Any] = ..., + return_values_on_condition_failure: Optional[Any] = ... + ) -> Dict: ... + def delete_item( self, hash_key, diff --git a/pynamodb/connection/transactions.py b/pynamodb/connection/transactions.py new file mode 100644 index 000000000..331a526d8 --- /dev/null +++ b/pynamodb/connection/transactions.py @@ -0,0 +1,115 @@ +from pynamodb.constants import ITEM, RESPONSES +from pynamodb.models import _ModelFuture + + +class Transaction(object): + + """ + Base class for a type of transaction operation + """ + + def __init__(self, connection, return_consumed_capacity=None): + self._connection = connection + self._return_consumed_capacity = return_consumed_capacity + + def _commit(self): + raise NotImplementedError() + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is None and exc_val is None and exc_tb is None: + self._commit() + + +class TransactGet(Transaction): + + _results = None + + def __init__(self, *args, **kwargs): + self._get_items = [] + self._futures = [] + super(TransactGet, self).__init__(*args, **kwargs) + + def get(self, model_cls, hash_key, range_key=None): + """ + Adds the operation arguments for an item to list of models to get + returns a _ModelFuture object as a placeholder + + :param model_cls: + :param hash_key: + :param range_key: + :return: + """ + operation_kwargs = model_cls.get_operation_kwargs_from_class(hash_key, range_key=range_key) + model_future = _ModelFuture(model_cls) + self._futures.append(model_future) + self._get_items.append(operation_kwargs) + return model_future + + def _update_futures(self): + for model, data in zip(self._futures, self._results): + model.update_with_raw_data(data.get(ITEM)) + + def _commit(self): + response = self._connection.transact_get_items( + get_items=self._get_items, + return_consumed_capacity=self._return_consumed_capacity + ) + self._results = response[RESPONSES] + self._update_futures() + return response + + +class TransactWrite(Transaction): + + def __init__(self, client_request_token=None, return_item_collection_metrics=None, **kwargs): + super(TransactWrite, self).__init__(**kwargs) + self._client_request_token = client_request_token + self._return_item_collection_metrics = return_item_collection_metrics + self._condition_check_items = [] + self._delete_items = [] + self._put_items = [] + self._update_items = [] + + def condition_check(self, model_cls, hash_key, range_key=None, condition=None): + if condition is None: + raise TypeError('`condition` cannot be None') + operation_kwargs = model_cls.get_operation_kwargs_from_class( + hash_key, + range_key=range_key, + condition=condition + ) + self._condition_check_items.append(operation_kwargs) + + def delete(self, model, condition=None): + operation_kwargs = model.get_operation_kwargs_from_instance(condition=condition) + self._delete_items.append(operation_kwargs) + + def save(self, model, condition=None, return_values=None): + operation_kwargs = model.get_operation_kwargs_from_instance( + key=ITEM, + condition=condition, + return_values_on_condition_failure=return_values + ) + self._put_items.append(operation_kwargs) + + def update(self, model, actions, condition=None, return_values=None): + operation_kwargs = model.get_operation_kwargs_from_instance( + actions=actions, + condition=condition, + return_values_on_condition_failure=return_values + ) + self._update_items.append(operation_kwargs) + + def _commit(self): + return self._connection.transact_write_items( + condition_check_items=self._condition_check_items, + delete_items=self._delete_items, + put_items=self._put_items, + update_items=self._update_items, + client_request_token=self._client_request_token, + return_consumed_capacity=self._return_consumed_capacity, + return_item_collection_metrics=self._return_item_collection_metrics, + ) diff --git a/pynamodb/connection/transactions.pyi b/pynamodb/connection/transactions.pyi new file mode 100644 index 000000000..0d55b6a22 --- /dev/null +++ b/pynamodb/connection/transactions.pyi @@ -0,0 +1,41 @@ +from typing import Set, Tuple, TypeVar, Type, Any, List, Optional, Dict, Union, Text + +from pynamodb.expressions.condition import Condition +from pynamodb.models import Model, _ModelFuture + +from pynamodb.connection import Connection + + +KeyType = Union[Text, bytes, float, int, Tuple] +ModelType = TypeVar('ModelType', bound=Model) + +class Transaction: + _connection: Connection + _return_consumed_capacity: Optional[Any] + + def __enter__(self) -> Transaction: ... + def __exit__(self, exc_type, exc_val, exc_tb) -> Any: ... + def __init__(self, connection: Connection, return_consumed_capacity: Optional[Any] = ...) -> None: ... + def _commit(self) -> Any: ... + +class TransactGet(Transaction): + _get_items: List[Dict] + _futures = List[_ModelFuture] + _results: List[Dict] + + def _update_futures(self) -> None: ... + def get(self, model_cls: Type[ModelType], hash_key: KeyType, range_key: Optional[KeyType] = ...) -> ModelType: ... + +class TransactWrite(Transaction): + _condition_check_items: List[Dict] + _delete_items: List[Dict] + _put_items: List[Dict] + _update_items: List[Dict] + _client_request_token: Optional[str] + _return_item_collection_metrics: Optional[Any] + + def __int__(self, connection: Connection, return_consumed_capacity: Optional[Any] = ..., client_request_token: Optional[Text] = ..., return_item_collection_metrics: Optional[Text] = ...) -> None: ... + def condition_check(self, model_cls: Type[ModelType], hash_key: KeyType, range_key: Optional[KeyType] = ..., condition: Condition = ...) -> None: ... + def delete(self, model: ModelType, condition: Optional[Condition] = ...) -> None: ... + def save(self, model: ModelType, condition: Optional[Condition] = ..., return_values: Optional[Any] = ...) -> None: ... + def update(self, model: ModelType, actions: List[Any], condition: Optional[Condition] = ..., return_values: Optional[Any] = ...) -> None: ... diff --git a/pynamodb/constants.py b/pynamodb/constants.py index 26ed3ba66..1d1a11aa5 100644 --- a/pynamodb/constants.py +++ b/pynamodb/constants.py @@ -3,6 +3,8 @@ """ # Operations +TRANSACT_WRITE_ITEMS = 'TransactWriteItems' +TRANSACT_GET_ITEMS = 'TransactGetItems' BATCH_WRITE_ITEM = 'BatchWriteItem' DESCRIBE_TABLE = 'DescribeTable' BATCH_GET_ITEM = 'BatchGetItem' @@ -18,10 +20,13 @@ SCAN = 'Scan' # Request Parameters +RETURN_VALUES_ON_CONDITION_FAILURE = 'ReturnValuesOnConditionCheckFailure' GLOBAL_SECONDARY_INDEX_UPDATES = 'GlobalSecondaryIndexUpdates' RETURN_ITEM_COLL_METRICS = 'ReturnItemCollectionMetrics' EXCLUSIVE_START_TABLE_NAME = 'ExclusiveStartTableName' RETURN_CONSUMED_CAPACITY = 'ReturnConsumedCapacity' +CLIENT_REQUEST_TOKEN = 'ClientRequestToken' +COMPARISON_OPERATOR = 'ComparisonOperator' SCAN_INDEX_FORWARD = 'ScanIndexForward' ATTR_DEFINITIONS = 'AttributeDefinitions' TABLE_DESCRIPTION = 'TableDescription' @@ -29,6 +34,7 @@ UNPROCESSED_ITEMS = 'UnprocessedItems' CONSISTENT_READ = 'ConsistentRead' DELETE_REQUEST = 'DeleteRequest' +TRANSACT_ITEMS = 'TransactItems' RETURN_VALUES = 'ReturnValues' REQUEST_ITEMS = 'RequestItems' ATTRS_TO_GET = 'AttributesToGet' @@ -55,6 +61,15 @@ KEYS = 'Keys' UTC = 'UTC' KEY = 'Key' +GET = 'Get' + +# transaction operators +TRANSACT_CONDITION_CHECK = 'ConditionCheck' +TRANSACT_DELETE = 'Delete' +TRANSACT_GET = 'Get' +TRANSACT_PUT = 'Put' +TRANSACT_UPDATE = 'Update' + ACTION = 'Action' # Response Parameters @@ -200,6 +215,7 @@ ALL_NEW = 'ALL_NEW' UPDATED_NEW = 'UPDATED_NEW' RETURN_VALUES_VALUES = [NONE, ALL_OLD, UPDATED_OLD, ALL_NEW, UPDATED_NEW] +RETURN_VALUES_ON_CONDITION_FAILURE_VALUES = [NONE, ALL_OLD] # These are constants used in the AttributeUpdates parameter for UpdateItem # See: http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_UpdateItem.html#DDB-UpdateItem-request-AttributeUpdates diff --git a/pynamodb/constants.pyi b/pynamodb/constants.pyi index 4a3975f74..cc83485c8 100644 --- a/pynamodb/constants.pyi +++ b/pynamodb/constants.pyi @@ -1,4 +1,4 @@ -from typing import Any, List +from typing import Any, List, Set BATCH_WRITE_ITEM: str DESCRIBE_TABLE: str @@ -13,6 +13,7 @@ GET_ITEM: str PUT_ITEM: str QUERY: str SCAN: str +RETURN_VALUES_ON_CONDITION_FAILURE: str GLOBAL_SECONDARY_INDEX_UPDATES: str RETURN_ITEM_COLL_METRICS: str EXCLUSIVE_START_TABLE_NAME: str @@ -124,6 +125,7 @@ UPDATED_OLD: str ALL_NEW: str UPDATED_NEW: str RETURN_VALUES_VALUES: Any +RETURN_VALUES_ON_CONDITION_FAILURE_VALUES: List[Any] PUT: str DELETE: str ADD: str @@ -148,6 +150,16 @@ SELECT_VALUES: List[str] DEFAULT_BILLING_MODE: str AND: str OR: str +GET: str +TRANSACT_GET_ITEMS: str +TRANSACT_WRITE_ITEMS: str +TRANSACT_ITEMS: str +TRANSACT_CONDITION_CHECK: str +TRANSACT_DELETE: str +TRANSACT_GET: str +TRANSACT_PUT: str +TRANSACT_UPDATE: str +CLIENT_REQUEST_TOKEN: str BETWEEN: str IN: str NULL: str diff --git a/pynamodb/exceptions.py b/pynamodb/exceptions.py index aa0a5ed7e..bc6a072ef 100644 --- a/pynamodb/exceptions.py +++ b/pynamodb/exceptions.py @@ -95,6 +95,27 @@ def __init__(self, table_name): super(TableDoesNotExist, self).__init__(msg) +class TransactWriteError(PynamoDBException): + """ + Raised when a TransactWrite operation fails + """ + pass + + +class TransactGetError(PynamoDBException): + """ + Raised when a TransactGet operation fails + """ + pass + + +class InvalidStateError(PynamoDBException): + """ + Raises when the internal state of an operation context is invalid + """ + msg = "Operation in invalid state" + + class VerboseClientError(botocore.exceptions.ClientError): def __init__(self, error_response, operation_name, verbose_properties=None): """ Modify the message template to include the desired verbose properties """ diff --git a/pynamodb/exceptions.pyi b/pynamodb/exceptions.pyi index 3b1df2e6e..5c9b83638 100644 --- a/pynamodb/exceptions.pyi +++ b/pynamodb/exceptions.pyi @@ -18,6 +18,10 @@ class UpdateError(PynamoDBConnectionError): ... class GetError(PynamoDBConnectionError): ... class TableError(PynamoDBConnectionError): ... class DoesNotExist(PynamoDBException): ... +class TransactWriteError(PynamoDBException): ... +class TransactGetError(PynamoDBException): ... +class InvalidStateError(PynamoDBException): ... + class TableDoesNotExist(PynamoDBException): def __init__(self, table_name) -> None: ... diff --git a/pynamodb/models.py b/pynamodb/models.py index b21b61fbe..18329fa64 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -9,7 +9,7 @@ from inspect import getmembers from six import add_metaclass -from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError +from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError from pynamodb.attributes import Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, TTLAttribute from pynamodb.connection.table import TableConnection from pynamodb.connection.util import pythonic @@ -29,7 +29,8 @@ BATCH_WRITE_PAGE_LIMIT, META_CLASS_NAME, REGION, HOST, NULL, COUNT, ITEM_COUNT, KEY, UNPROCESSED_ITEMS, STREAM_VIEW_TYPE, - STREAM_SPECIFICATION, STREAM_ENABLED, BILLING_MODE) + STREAM_SPECIFICATION, STREAM_ENABLED, BILLING_MODE +) log = logging.getLogger(__name__) @@ -356,6 +357,7 @@ def update(self, actions, condition=None): kwargs.update(condition=condition) kwargs.update(actions=actions) + data = self._get_connection().update_item(*args, **kwargs) for name, value in data[ATTRIBUTES].items(): attr_name = self._dynamo_to_python_attr(name) @@ -386,6 +388,37 @@ def refresh(self, consistent_read=False): raise self.DoesNotExist("This item does not exist in the table.") self._deserialize(item_data) + def get_operation_kwargs_from_instance(self, + key=KEY, + actions=None, + condition=None, + return_values_on_condition_failure=None): + is_update = actions is not None + args, save_kwargs = self._get_save_args(null_check=not is_update) + kwargs = dict( + key=key, + actions=actions, + condition=condition, + return_values_on_condition_failure=return_values_on_condition_failure + ) + if not is_update: + kwargs.update(save_kwargs) + elif pythonic(RANGE_KEY) in save_kwargs: + kwargs[pythonic(RANGE_KEY)] = save_kwargs[pythonic(RANGE_KEY)] + return self._get_connection().get_operation_kwargs(*args, **kwargs) + + @classmethod + def get_operation_kwargs_from_class(cls, + hash_key, + range_key=None, + condition=None): + hash_key, range_key = cls._serialize_keys(hash_key, range_key) + return cls._get_connection().get_operation_kwargs( + hash_key=hash_key, + range_key=range_key, + condition=condition + ) + @classmethod def get(cls, hash_key, @@ -397,8 +430,11 @@ def get(cls, :param hash_key: The hash key of the desired item :param range_key: The range key of the desired item, only used when appropriate. + :param consistent_read + :param attributes_to_get """ hash_key, range_key = cls._serialize_keys(hash_key, range_key) + data = cls._get_connection().get_item( hash_key, range_key=range_key, @@ -1016,3 +1052,28 @@ def _serialize_keys(cls, hash_key, range_key=None): if range_key is not None: range_key = cls._range_key_attribute().serialize(range_key) return hash_key, range_key + + +class _ModelFuture: + """ + A placeholder object for a model that does not exist yet + + For example: when performing a TransactGet request, this is a stand-in for a model that will be returned + when the operation is complete + """ + def __init__(self, model_cls): + self._model_cls = model_cls + self._model = None + self._resolved = False + + def update_with_raw_data(self, data): + if data is not None and data != {}: + self._model = self._model_cls.from_raw_data(data=data) + self._resolved = True + + def get(self): + if not self._resolved: + raise InvalidStateError() + if self._model: + return self._model + raise self._model_cls.DoesNotExist() diff --git a/pynamodb/models.pyi b/pynamodb/models.pyi index 6d03eece1..2bb871da9 100644 --- a/pynamodb/models.pyi +++ b/pynamodb/models.pyi @@ -1,3 +1,4 @@ + from .attributes import Attribute from .exceptions import DoesNotExist as DoesNotExist from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Tuple, Type, TypeVar, Text, Union @@ -40,10 +41,8 @@ class Model(metaclass=MetaModel): def update(self, actions: List[Any], condition: Optional[Condition] = ...) -> Any: ... def save(self, condition: Optional[Condition] = ...) -> Dict[str, Any]: ... def refresh(self, consistent_read: bool = ...): ... - @classmethod def get(cls: Type[_T], hash_key: KeyType, range_key: Optional[KeyType] = ..., consistent_read: bool = ...) -> _T: ... - @classmethod def from_raw_data(cls: Type[_T], data) -> _T: ... @@ -140,3 +139,11 @@ class BatchWrite(Generic[_T], ModelContextManager[_T]): def __exit__(self, exc_type, exc_val, exc_tb) -> None: ... pending_operations: Any def commit(self) -> None: ... + +class _ModelFuture(Generic[_T]): + _model_cls: Type[_T] + _model: Optional[_T] + _resolved: bool + def __init__(self, model_cls: Type[_T]) -> None: ... + def update_with_raw_data(self, data: Dict) -> None: ... + def get(self) -> _T: ... diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 77d191258..6a9a93e0b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -3,7 +3,7 @@ import pytest -@pytest.fixture +@pytest.fixture(scope='module') def ddb_url(): """Obtain the URL of a local DynamoDB instance. diff --git a/tests/integration/test_transaction_integration.py b/tests/integration/test_transaction_integration.py new file mode 100644 index 000000000..9e057d972 --- /dev/null +++ b/tests/integration/test_transaction_integration.py @@ -0,0 +1,273 @@ +import uuid +from datetime import datetime + +import pytest + +from pynamodb.connection import Connection +from pynamodb.exceptions import DoesNotExist, TransactWriteError, TransactGetError, InvalidStateError + +from pynamodb.attributes import NumberAttribute, UnicodeAttribute, UTCDateTimeAttribute, BooleanAttribute +from pynamodb.connection.transactions import TransactGet, TransactWrite + +from pynamodb.models import Model + +IDEMPOTENT_PARAMETER_MISMATCH = 'IdempotentParameterMismatchException' +PROVISIONED_THROUGHPUT_EXCEEDED = 'ProvisionedThroughputExceededException' +RESOURCE_NOT_FOUND = 'ResourceNotFoundException' +TRANSACTION_CANCELLED = 'TransactionCanceledException' +TRANSACTION_IN_PROGRESS = 'TransactionInProgressException' +VALIDATION_EXCEPTION = 'ValidationException' + + +class User(Model): + class Meta: + region = 'us-east-1' + table_name = 'user' + + user_id = NumberAttribute(hash_key=True) + + +class BankStatement(Model): + + class Meta: + region = 'us-east-1' + table_name = 'statement' + + user_id = NumberAttribute(hash_key=True) + balance = NumberAttribute(default=0) + active = BooleanAttribute(default=True) + + +class LineItem(Model): + + class Meta: + region = 'us-east-1' + table_name = 'line-item' + + user_id = NumberAttribute(hash_key=True) + created_at = UTCDateTimeAttribute(range_key=True, default=datetime.now()) + amount = NumberAttribute() + currency = UnicodeAttribute() + + +class DifferentRegion(Model): + + class Meta: + region = 'us-east-2' + table_name = 'different-region' + + entry_index = NumberAttribute(hash_key=True) + + +TEST_MODELS = [ + BankStatement, + DifferentRegion, + LineItem, + User, +] + + +@pytest.fixture(scope='module') +def connection(ddb_url): + yield Connection(host=ddb_url) + + +@pytest.fixture(scope='module', autouse=True) +def create_tables(ddb_url): + for m in TEST_MODELS: + m.Meta.host = ddb_url + m.create_table( + read_capacity_units=10, + write_capacity_units=10, + wait=True + ) + + yield + + for m in TEST_MODELS: + if m.exists(): + m.delete_table() + + +def get_error_code(error): + return error.cause.response['Error'].get('Code') + + +def get_error_message(error): + return error.cause.response['Error'].get('Message') + + +@pytest.mark.ddblocal +def test_transact_write__error__idempotent_parameter_mismatch(connection): + client_token = str(uuid.uuid4()) + + with TransactWrite(connection=connection, client_request_token=client_token) as transaction: + transaction.save(User(1)) + transaction.save(User(2)) + + with pytest.raises(TransactWriteError) as exc_info: + # committing the first time, then adding more info and committing again + with TransactWrite(connection=connection, client_request_token=client_token) as transaction: + transaction.save(User(3)) + assert get_error_code(exc_info.value) == IDEMPOTENT_PARAMETER_MISMATCH + + # ensure that the first request succeeded in creating new users + assert User.get(1) + assert User.get(2) + + with pytest.raises(DoesNotExist): + # ensure it did not create the user from second request + User.get(3) + + +@pytest.mark.ddblocal +def test_transact_write__error__different_regions(connection): + with pytest.raises(TransactWriteError) as exc_info: + with TransactWrite(connection=connection) as transact_write: + # creating a model in a table outside the region everyone else operates in + transact_write.save(DifferentRegion(entry_index=0)) + transact_write.save(BankStatement(1)) + transact_write.save(User(1)) + assert get_error_code(exc_info.value) == RESOURCE_NOT_FOUND + + +@pytest.mark.ddblocal +def test_transact_write__error__transaction_cancelled__condition_check_failure(connection): + # create a users and a bank statements for them + User(1).save() + BankStatement(1).save() + + # attempt to do this as a transaction with the condition that they don't already exist + with pytest.raises(TransactWriteError) as exc_info: + with TransactWrite(connection=connection) as transaction: + transaction.save(User(1), condition=(User.user_id.does_not_exist())) + transaction.save(BankStatement(1), condition=(BankStatement.user_id.does_not_exist())) + assert get_error_code(exc_info.value) == TRANSACTION_CANCELLED + assert 'ConditionalCheckFailed' in get_error_message(exc_info.value) + + +@pytest.mark.ddblocal +def test_transact_write__error__multiple_operations_on_same_record(connection): + BankStatement(1).save() + + # attempt to do a transaction with multiple operations on the same record + with pytest.raises(TransactWriteError) as exc_info: + with TransactWrite(connection=connection) as transaction: + transaction.condition_check(BankStatement, 1, condition=(BankStatement.user_id.exists())) + transaction.update(BankStatement(1), actions=[(BankStatement.balance.add(10))]) + assert get_error_code(exc_info.value) == VALIDATION_EXCEPTION + + +@pytest.mark.ddblocal +def test_transact_get(connection): + # making sure these entries exist, and with the expected info + User(1).save() + BankStatement(1).save() + User(2).save() + BankStatement(2, balance=100).save() + + # get users and statements we just created and assign them to variables + with TransactGet(connection=connection) as transaction: + _user1_future = transaction.get(User, 1) + _statement1_future = transaction.get(BankStatement, 1) + _user2_future = transaction.get(User, 2) + _statement2_future = transaction.get(BankStatement, 2) + + user1 = _user1_future.get() + statement1 = _statement1_future.get() + user2 = _user2_future.get() + statement2 = _statement2_future.get() + + assert user1.user_id == statement1.user_id == 1 + assert statement1.balance == 0 + assert user2.user_id == statement2.user_id == 2 + assert statement2.balance == 100 + + +@pytest.mark.ddblocal +def test_transact_get__does_not_exist(connection): + with TransactGet(connection=connection) as transaction: + _user_future = transaction.get(User, 100) + with pytest.raises(User.DoesNotExist): + _user_future.get() + + +@pytest.mark.ddblocal +def test_transact_get__invalid_state(connection): + with TransactGet(connection=connection) as transaction: + _user_future = transaction.get(User, 100) + with pytest.raises(InvalidStateError): + _user_future.get() + + +@pytest.mark.ddblocal +def test_transact_write(connection): + # making sure these entries exist, and with the expected info + BankStatement(1, balance=0).save() + BankStatement(2, balance=100).save() + + # assert values are what we think they should be + statement1 = BankStatement.get(1) + statement2 = BankStatement.get(2) + assert statement1.balance == 0 + assert statement2.balance == 100 + + with TransactWrite(connection=connection) as transaction: + # let the users send money to one another + # create a credit line item to user 1's account + transaction.save( + LineItem(user_id=1, amount=50, currency='USD'), + condition=(LineItem.user_id.does_not_exist()), + ) + # create a debit to user 2's account + transaction.save( + LineItem(user_id=2, amount=-50, currency='USD'), + condition=(LineItem.user_id.does_not_exist()), + ) + + # add credit to user 1's account + transaction.update(statement1, actions=[BankStatement.balance.add(50)]) + # debit from user 2's account if they have enough in the bank + transaction.update( + statement2, + actions=[BankStatement.balance.add(-50)], + condition=(BankStatement.balance >= 50) + ) + + statement1.refresh() + statement2.refresh() + assert statement1.balance == statement2.balance == 50 + + +@pytest.mark.ddblocal +def test_transact_write__one_of_each(connection): + User(1).save() + User(2).save() + statement = BankStatement(1, balance=100, active=True) + statement.save() + + with TransactWrite(connection=connection) as transaction: + transaction.condition_check(User, 1, condition=(User.user_id.exists())) + transaction.delete(User(2)) + transaction.save(LineItem(4, amount=100, currency='USD'), condition=(LineItem.user_id.does_not_exist())) + transaction.update( + statement, + actions=[ + BankStatement.active.set(False), + BankStatement.balance.set(0), + ] + ) + + # confirming transaction correct and successful + assert User.get(1) + with pytest.raises(DoesNotExist): + User.get(2) + + new_line_item = next(LineItem.query(4, scan_index_forward=False, limit=1), None) + assert new_line_item + assert new_line_item.amount == 100 + assert new_line_item.currency == 'USD' + + statement.refresh() + assert not statement.active + assert statement.balance == 0 diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index c652bb678..7bb5a2f88 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -907,6 +907,30 @@ def test_put_item(self): } self.assertEqual(req.call_args[0][1], params) + def test_transact_write_items(self): + conn = Connection() + with patch(PATCH_METHOD) as req: + conn.transact_write_items([], [], [], []) + self.assertEqual(req.call_args[0][0], 'TransactWriteItems') + self.assertDictEqual( + req.call_args[0][1], { + 'TransactItems': [], + 'ReturnConsumedCapacity': 'TOTAL' + } + ) + + def test_transact_get_items(self): + conn = Connection() + with patch(PATCH_METHOD) as req: + conn.transact_get_items([]) + self.assertEqual(req.call_args[0][0], 'TransactGetItems') + self.assertDictEqual( + req.call_args[0][1], { + 'TransactItems': [], + 'ReturnConsumedCapacity': 'TOTAL' + } + ) + def test_batch_write_item(self): """ Connection.batch_write_item diff --git a/tests/test_transaction.py b/tests/test_transaction.py new file mode 100644 index 000000000..36beb0524 --- /dev/null +++ b/tests/test_transaction.py @@ -0,0 +1,126 @@ +import pytest +import six +from pynamodb.attributes import NumberAttribute, UnicodeAttribute + +from pynamodb.connection import Connection +from pynamodb.connection.transactions import Transaction, TransactGet, TransactWrite +from pynamodb.models import Model +from tests.test_base_connection import PATCH_METHOD + +if six.PY3: + from unittest.mock import patch +else: + from mock import patch + + +class MockModel(Model): + class Meta: + table_name = 'mock' + + mock_hash = NumberAttribute(hash_key=True) + mock_range = NumberAttribute(range_key=True) + mock_toot = UnicodeAttribute(null=True) + + +MOCK_TABLE_DESCRIPTOR = { + "Table": { + "TableName": "Mock", + "KeySchema": [ + { + "AttributeName": "MockHash", + "KeyType": "HASH" + }, + { + "AttributeName": "MockRange", + "KeyType": "RANGE" + } + ], + "AttributeDefinitions": [ + { + "AttributeName": "MockHash", + "AttributeType": "N" + }, + { + "AttributeName": "MockRange", + "AttributeType": "N" + } + ] + } +} + + +class TestTransaction: + + def test_commit__not_implemented(self): + t = Transaction(connection=Connection()) + with pytest.raises(NotImplementedError): + t._commit() + + +class TestTransactGet: + + def test_commit(self, mocker): + connection = Connection() + mock_connection_transact_get = mocker.patch.object(connection, 'transact_get_items') + + with patch(PATCH_METHOD) as req: + req.return_value = MOCK_TABLE_DESCRIPTOR + with TransactGet(connection=connection) as t: + t.get(MockModel, 1, 2) + + mock_connection_transact_get.assert_called_once_with( + get_items=[{'Key': {'MockHash': {'N': '1'}, 'MockRange': {'N': '2'}}, 'TableName': 'mock'}], + return_consumed_capacity=None + ) + + +class TestTransactWrite: + + def test_condition_check__no_condition(self): + with pytest.raises(TypeError): + with TransactWrite(connection=Connection()) as transaction: + transaction.condition_check(MockModel, hash_key=1, condition=None) + + def test_commit(self, mocker): + connection = Connection() + mock_connection_transact_write = mocker.patch.object(connection, 'transact_write_items') + with patch(PATCH_METHOD) as req: + req.return_value = MOCK_TABLE_DESCRIPTOR + with TransactWrite(connection=connection) as t: + t.condition_check(MockModel, 1, 3, condition=(MockModel.mock_hash.does_not_exist())) + t.delete(MockModel(2, 4)) + t.save(MockModel(3, 5)) + t.update(MockModel(4, 6), actions=[MockModel.mock_toot.set('hello')], return_values='ALL_OLD') + + expected_condition_checks = [{ + 'ConditionExpression': 'attribute_not_exists (#0)', + 'ExpressionAttributeNames': {'#0': 'mock_hash'}, + 'Key': {'MockHash': {'N': '1'}, 'MockRange': {'N': '3'}}, + 'TableName': 'mock'} + ] + expected_deletes = [{ + 'Key': {'MockHash': {'N': '2'}, 'MockRange': {'N': '4'}}, + 'TableName': 'mock' + }] + expected_puts = [{ + 'Item': {'MockHash': {'N': '3'}, 'MockRange': {'N': '5'}}, + 'TableName': 'mock' + }] + expected_updates = [{ + 'TableName': 'mock', + 'Key': {'MockHash': {'N': '4'}, 'MockRange': {'N': '6'}}, + 'ReturnValuesOnConditionCheckFailure': 'ALL_OLD', + 'UpdateExpression': 'SET #0 = :0', + 'ExpressionAttributeNames': {'#0': 'mock_toot'}, + 'ExpressionAttributeValues': {':0': {'S': 'hello'}} + }] + + mock_connection_transact_write.assert_called_once_with( + condition_check_items=expected_condition_checks, + delete_items=expected_deletes, + put_items=expected_puts, + update_items=expected_updates, + client_request_token=None, + return_consumed_capacity=None, + return_item_collection_metrics=None + )