From 92e6e3f3cac806e450297fd95ba7696b667b95db Mon Sep 17 00:00:00 2001 From: Ivan Kang Date: Wed, 12 Jun 2019 21:40:26 -0700 Subject: [PATCH] Add TTLAttribute and default_as_new (#633) This is a backport to 3.x. --- pynamodb/attributes.py | 65 ++++++++++++++++++++++--- pynamodb/attributes.pyi | 10 +++- pynamodb/compat.py | 4 ++ pynamodb/connection/base.py | 21 +++++++- pynamodb/connection/table.py | 6 +++ pynamodb/constants.py | 5 ++ pynamodb/models.py | 55 ++++++++++++++++++--- pynamodb/tests/test_attributes.py | 53 +++++++++++++++++++- pynamodb/tests/test_base_connection.py | 6 +++ pynamodb/tests/test_model.py | 42 +++++++++++++++- pynamodb/tests/test_table_connection.py | 18 +++++++ 11 files changed, 266 insertions(+), 19 deletions(-) diff --git a/pynamodb/attributes.py b/pynamodb/attributes.py index d3c3ea53d..054359fb4 100644 --- a/pynamodb/attributes.py +++ b/pynamodb/attributes.py @@ -1,12 +1,14 @@ """ PynamoDB attributes """ +import calendar import six from six import add_metaclass import json +import time from base64 import b64encode, b64decode from copy import deepcopy -from datetime import datetime +from datetime import datetime, timedelta import warnings from dateutil.parser import parse from dateutil.tz import tzutc @@ -15,7 +17,7 @@ STRING, STRING_SHORT, NUMBER, BINARY, UTC, DATETIME_FORMAT, BINARY_SET, STRING_SET, NUMBER_SET, MAP, MAP_SHORT, LIST, LIST_SHORT, DEFAULT_ENCODING, BOOLEAN, ATTR_TYPE_MAP, NUMBER_SHORT, NULL, SHORT_ATTR_TYPES ) -from pynamodb.compat import getmembers_issubclass +from pynamodb.compat import getmembers_issubclass, timedelta_total_seconds from pynamodb.expressions.operand import Path import collections @@ -32,9 +34,15 @@ def __init__(self, range_key=False, null=None, default=None, + default_for_new=None, attr_name=None ): + if default and default_for_new: + raise ValueError("An attribute cannot have both default and default_for_new parameters") self.default = default + # This default is only set for new objects (ie: it's not set for re-saved objects) + self.default_for_new = default_for_new + if null is not None: self.null = null self.is_hash_key = hash_key @@ -212,14 +220,14 @@ def _initialize_attributes(cls): @add_metaclass(AttributeContainerMeta) class AttributeContainer(object): - def __init__(self, **attributes): + def __init__(self, _user_instantiated=True, **attributes): # The `attribute_values` dictionary is used by the Attribute data descriptors in cls._attributes # to store the values that are bound to this instance. Attributes store values in the dictionary # using the `python_attr_name` as the dictionary key. "Raw" (i.e. non-subclassed) MapAttribute # instances do not have any Attributes defined and instead use this dictionary to store their # collection of name-value pairs. self.attribute_values = {} - self._set_defaults() + self._set_defaults(_user_instantiated=_user_instantiated) self._set_attributes(**attributes) @classmethod @@ -250,12 +258,15 @@ def _dynamo_to_python_attr(cls, dynamo_key): """ return cls._dynamo_to_python_attrs.get(dynamo_key, dynamo_key) - def _set_defaults(self): + def _set_defaults(self, _user_instantiated=True): """ Sets and fields that provide a default value """ for name, attr in self.get_attributes().items(): - default = attr.default + if _user_instantiated and attr.default_for_new is not None: + default = attr.default_for_new + else: + default = attr.default if callable(default): value = default() else: @@ -525,6 +536,48 @@ def deserialize(self, value): return json.loads(value) +class TTLAttribute(Attribute): + """ + A time-to-live attribute that signifies when the item expires and can be automatically deleted. + It can be assigned with a timezone-aware datetime value (for absolute expiry time) + or a timedelta value (for expiry relative to the current time), + but always reads as a UTC datetime value. + """ + attr_type = NUMBER + + def __set__(self, instance, value): + """ + Converts assigned values to a UTC datetime + """ + if isinstance(value, timedelta): + value = int(time.time() + timedelta_total_seconds(value)) + elif isinstance(value, datetime): + if value.tzinfo is None: + raise ValueError("datetime must be timezone-aware") + value = calendar.timegm(value.utctimetuple()) + elif value is not None: + raise ValueError("TTLAttribute value must be a timedelta or datetime") + attr_name = instance._dynamo_to_python_attrs.get(self.attr_name, self.attr_name) + if value is not None: + value = datetime.utcfromtimestamp(value).replace(tzinfo=tzutc()) + instance.attribute_values[attr_name] = value + + def serialize(self, value): + """ + Serializes a datetime as a timestamp (Unix time). + """ + if value is None: + return None + return json.dumps(calendar.timegm(value.utctimetuple())) + + def deserialize(self, value): + """ + Deserializes a timestamp (Unix time) as a UTC datetime. + """ + timestamp = json.loads(value) + return datetime.utcfromtimestamp(timestamp).replace(tzinfo=tzutc()) + + class UTCDateTimeAttribute(Attribute): """ An attribute for storing a UTC Datetime diff --git a/pynamodb/attributes.pyi b/pynamodb/attributes.pyi index e3a1be06e..7702e4b73 100644 --- a/pynamodb/attributes.pyi +++ b/pynamodb/attributes.pyi @@ -24,9 +24,10 @@ class Attribute(Generic[_T]): attr_type: Text null: bool default: Any + default_for_new: Any is_hash_key: bool is_range_key: bool - def __init__(self, hash_key: bool = ..., range_key: bool = ..., null: Optional[bool] = ..., default: Optional[Union[_T, Callable[..., _T]]] = ..., attr_name: Optional[Text] = ...) -> None: ... + def __init__(self, hash_key: bool = ..., range_key: bool = ..., null: Optional[bool] = ..., default: Optional[Union[_T, Callable[..., _T]]] = ..., default_for_new: Optional[Union[Any, Callable[..., _T]]] = ..., attr_name: Optional[Text] = ...) -> None: ... def __set__(self, instance: Any, value: Optional[_T]) -> None: ... def serialize(self, value: Any) -> Any: ... def deserialize(self, value: Any) -> Any: ... @@ -118,6 +119,13 @@ class NumberAttribute(Attribute[float]): def __get__(self, instance: Any, owner: Any) -> float: ... +class TTLAttribute(Attribute[datetime]): + @overload + def __get__(self: _A, instance: None, owner: Any) -> _A: ... + @overload + def __get__(self, instance: Any, owner: Any) -> datetime: ... + + class UTCDateTimeAttribute(Attribute[datetime]): @overload def __get__(self: _A, instance: None, owner: Any) -> _A: ... diff --git a/pynamodb/compat.py b/pynamodb/compat.py index 51633999d..e8e5f7bac 100644 --- a/pynamodb/compat.py +++ b/pynamodb/compat.py @@ -77,3 +77,7 @@ def getmembers_issubclass(object, classinfo): results.append((key, value)) results.sort() return results + + +def timedelta_total_seconds(td): # compensate for Python 2.6 not having total_seconds + return (td.microseconds + (td.seconds + td.days * 24 * 3600) * 10 ** 6) / 10 ** 6 diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index f835293b0..a6cae9aa1 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -39,7 +39,8 @@ ITEMS, DEFAULT_ENCODING, BINARY_SHORT, BINARY_SET_SHORT, LAST_EVALUATED_KEY, RESPONSES, UNPROCESSED_KEYS, UNPROCESSED_ITEMS, STREAM_SPECIFICATION, STREAM_VIEW_TYPE, STREAM_ENABLED, UPDATE_EXPRESSION, EXPRESSION_ATTRIBUTE_NAMES, EXPRESSION_ATTRIBUTE_VALUES, KEY_CONDITION_OPERATOR_MAP, - CONDITION_EXPRESSION, FILTER_EXPRESSION, FILTER_EXPRESSION_OPERATOR_MAP, NOT_CONTAINS, AND) + CONDITION_EXPRESSION, FILTER_EXPRESSION, FILTER_EXPRESSION_OPERATOR_MAP, NOT_CONTAINS, AND, + TIME_TO_LIVE_SPECIFICATION, ENABLED, UPDATE_TIME_TO_LIVE) from pynamodb.exceptions import ( TableError, QueryError, PutError, DeleteError, UpdateError, GetError, ScanError, TableDoesNotExist, VerboseClientError @@ -301,7 +302,7 @@ def dispatch(self, operation_name, operation_kwargs): Raises TableDoesNotExist if the specified table does not exist """ - if operation_name not in [DESCRIBE_TABLE, LIST_TABLES, UPDATE_TABLE, DELETE_TABLE, CREATE_TABLE]: + if operation_name not in [DESCRIBE_TABLE, LIST_TABLES, UPDATE_TABLE, UPDATE_TIME_TO_LIVE, DELETE_TABLE, CREATE_TABLE]: if RETURN_CONSUMED_CAPACITY not in operation_kwargs: operation_kwargs.update(self.get_consumed_capacity_map(TOTAL)) self._log_debug(operation_name, operation_kwargs) @@ -584,6 +585,22 @@ def create_table(self, raise TableError("Failed to create table: {0}".format(e), e) return data + def update_time_to_live(self, table_name, ttl_attribute_name): + """ + Performs the UpdateTimeToLive operation + """ + operation_kwargs = { + TABLE_NAME: table_name, + TIME_TO_LIVE_SPECIFICATION: { + ATTR_NAME: ttl_attribute_name, + ENABLED: True, + } + } + try: + return self.dispatch(UPDATE_TIME_TO_LIVE, operation_kwargs) + except BOTOCORE_EXCEPTIONS as e: + raise TableError("Failed to update TTL on table: {0}".format(e), e) + def delete_table(self, table_name): """ Performs the DeleteTable operation diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 86e826e4a..aca8cb6de 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -268,6 +268,12 @@ def delete_table(self): """ return self.connection.delete_table(self.table_name) + def update_time_to_live(self, ttl_attr_name): + """ + Performs the UpdateTimeToLive operation and returns the result + """ + return self.connection.update_time_to_live(self.table_name, ttl_attr_name) + def update_table(self, read_capacity_units=None, write_capacity_units=None, diff --git a/pynamodb/constants.py b/pynamodb/constants.py index ce1f2cd07..994d75e7b 100644 --- a/pynamodb/constants.py +++ b/pynamodb/constants.py @@ -149,6 +149,11 @@ STREAM_NEW_AND_OLD_IMAGE = 'NEW_AND_OLD_IMAGES' STREAM_KEYS_ONLY = 'KEYS_ONLY' +# Constants for updating a table's TTL +UPDATE_TIME_TO_LIVE = 'UpdateTimeToLive' +TIME_TO_LIVE_SPECIFICATION = 'TimeToLiveSpecification' +ENABLED = 'Enabled' + # These are constants used in the KeyConditionExpression parameter # http://docs.aws.amazon.com/amazondynamodb/latest/APIReference/API_Query.html#DDB-Query-request-KeyConditionExpression EXCLUSIVE_START_KEY = 'ExclusiveStartKey' diff --git a/pynamodb/models.py b/pynamodb/models.py index caf738865..c065fc0f3 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -10,7 +10,8 @@ from six import add_metaclass from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError -from pynamodb.attributes import Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, ListAttribute +from pynamodb.attributes import ( + Attribute, AttributeContainer, AttributeContainerMeta, MapAttribute, ListAttribute, TTLAttribute) from pynamodb.connection.base import MetaTable from pynamodb.connection.table import TableConnection from pynamodb.connection.util import pythonic @@ -196,6 +197,10 @@ def __init__(cls, name, bases, attrs): if attr_obj.attr_name is None: attr_obj.attr_name = attr_name + ttl_attr_names = [name for name, attr_obj in attrs.items() if isinstance(attr_obj, TTLAttribute)] + if len(ttl_attr_names) > 1: + raise ValueError("The model has more than one TTL attribute: {}".format(", ".join(ttl_attr_names))) + if META_CLASS_NAME not in attrs: setattr(cls, META_CLASS_NAME, DefaultMeta) @@ -225,7 +230,7 @@ class Model(AttributeContainer): _index_classes = None DoesNotExist = DoesNotExist - def __init__(self, hash_key=None, range_key=None, **attributes): + def __init__(self, hash_key=None, range_key=None, _user_instantiated=True, **attributes): """ :param hash_key: Required. The hash key for this object. :param range_key: Only required if the table has a range key attribute. @@ -240,7 +245,7 @@ def __init__(self, hash_key=None, range_key=None, **attributes): "This table has no range key, but a range key value was provided: {0}".format(range_key) ) attributes[self._dynamo_to_python_attr(range_keyname)] = range_key - super(Model, self).__init__(**attributes) + super(Model, self).__init__(_user_instantiated=_user_instantiated, **attributes) @classmethod def has_map_or_list_attributes(cls): @@ -526,7 +531,7 @@ def from_raw_data(cls, data): attr = cls.get_attributes().get(attr_name, None) if attr: kwargs[attr_name] = attr.deserialize(attr.get_value(value)) - return cls(*args, **kwargs) + return cls(*args, _user_instantiated=False, **kwargs) @classmethod def count(cls, @@ -843,7 +848,12 @@ def describe_table(cls): return cls._get_connection().describe_table() @classmethod - def create_table(cls, wait=False, read_capacity_units=None, write_capacity_units=None): + def create_table( + cls, + wait=False, + read_capacity_units=None, + write_capacity_units=None, + ignore_update_ttl_errors=False): """ Create the table for this model @@ -885,12 +895,32 @@ def create_table(cls, wait=False, read_capacity_units=None, write_capacity_units if status: data = status.get(TABLE_STATUS) if data == ACTIVE: - return + break else: time.sleep(2) else: raise TableError("No TableStatus returned for table") + cls.update_ttl(ignore_update_ttl_errors) + + @classmethod + def update_ttl(cls, ignore_update_ttl_errors): + """ + Attempt to update the TTL on the table. + Certain implementations (eg: dynalite) do not support updating TTLs and will fail. + """ + ttl_attribute = cls._ttl_attribute() + if ttl_attribute: + # Some dynamoDB implementations (eg: dynalite) do not support updating TTLs so + # this will fail. It's fine for this to fail in those cases. + try: + cls._get_connection().update_time_to_live(ttl_attribute.attr_name) + except Exception: + if ignore_update_ttl_errors: + log.info("Unable to update the TTL for {}".format(cls.Meta.table_name)) + else: + raise + @classmethod def dumps(cls): """ @@ -939,7 +969,7 @@ def _from_data(cls, data): attributes[range_keyname] = { range_keytype: range_key } - item = cls() + item = cls(_user_instantiated=False) item._deserialize(attributes) return item @@ -1220,6 +1250,17 @@ def _hash_key_attribute(cls): hash_keyname = cls._get_meta_data().hash_keyname return attributes[cls._dynamo_to_python_attr(hash_keyname)] + @classmethod + def _ttl_attribute(cls): + """ + Returns the ttl attribute for this table + """ + attributes = cls.get_attributes() + for attr_obj in attributes.values(): + if isinstance(attr_obj, TTLAttribute): + return attr_obj + return None + def _get_keys(self): """ Returns the proper arguments for deleting diff --git a/pynamodb/tests/test_attributes.py b/pynamodb/tests/test_attributes.py index 5c318f767..66bb20b6f 100644 --- a/pynamodb/tests/test_attributes.py +++ b/pynamodb/tests/test_attributes.py @@ -3,10 +3,12 @@ """ import json import six +import time from base64 import b64encode from datetime import datetime +from datetime import timedelta from dateutil.parser import parse from dateutil.tz import tzutc @@ -16,7 +18,7 @@ from pynamodb.attributes import ( BinarySetAttribute, BinaryAttribute, NumberSetAttribute, NumberAttribute, UnicodeAttribute, UnicodeSetAttribute, UTCDateTimeAttribute, BooleanAttribute, LegacyBooleanAttribute, - MapAttribute, MapAttributeMeta, ListAttribute, JSONAttribute, _get_value_for_deserialize, + MapAttribute, MapAttributeMeta, ListAttribute, JSONAttribute, TTLAttribute, _get_value_for_deserialize, ) from pynamodb.constants import ( DATETIME_FORMAT, DEFAULT_ENCODING, NUMBER, STRING, STRING_SET, NUMBER_SET, BINARY_SET, @@ -44,6 +46,7 @@ class Meta: bool_attr = BooleanAttribute() json_attr = JSONAttribute() map_attr = MapAttribute() + ttl_attr = TTLAttribute() class CustomAttrMap(MapAttribute): @@ -494,6 +497,54 @@ def test_boolean_deserialize(self): assert attr.deserialize(False) is False +class TestTTLAttribute: + """ + Test TTLAttribute. + """ + def test_default_and_default_for_new(self): + with pytest.raises(ValueError, match='An attribute cannot have both default and default_for_new parameters'): + TTLAttribute(default=timedelta(seconds=1), default_for_new=timedelta(seconds=2)) + + @patch('time.time') + def test_timedelta_ttl(self, mock_time): + mock_time.side_effect = [1559692800] # 2019-06-05 00:00:00 UTC + model = AttributeTestModel() + model.ttl_attr = timedelta(seconds=60) + assert model.ttl_attr == datetime(2019, 6, 5, 0, 1, tzinfo=UTC) + + def test_datetime_naive_ttl(self): + model = AttributeTestModel() + with pytest.raises(ValueError, match='timezone-aware'): + model.ttl_attr = datetime(2019, 6, 5, 0, 1) + assert model.ttl_attr is None + + def test_datetime_with_tz_ttl(self): + model = AttributeTestModel() + model.ttl_attr = datetime(2019, 6, 5, 0, 1, tzinfo=UTC) + assert model.ttl_attr == datetime(2019, 6, 5, 0, 1, tzinfo=UTC) + + def test_ttl_attribute_wrong_type(self): + with pytest.raises(ValueError, match='TTLAttribute value must be a timedelta or datetime'): + model = AttributeTestModel() + model.ttl_attr = 'wrong type' + + def test_serialize_none(self): + model = AttributeTestModel() + model.ttl_attr = None + assert model.ttl_attr == None + assert TTLAttribute().serialize(model.ttl_attr) == None + + @patch('time.time') + def test_serialize_deserialize(self, mock_time): + mock_time.side_effect = [1559692800, 1559692800] # 2019-06-05 00:00:00 UTC + model = AttributeTestModel() + model.ttl_attr = timedelta(minutes=1) + assert model.ttl_attr == datetime(2019, 6, 5, 0, 1, tzinfo=UTC) + s = TTLAttribute().serialize(model.ttl_attr) + assert s == '1559692860' + assert TTLAttribute().deserialize(s) == datetime(2019, 6, 5, 0, 1, 0, tzinfo=UTC) + + class TestJSONAttribute: """ Tests json attributes diff --git a/pynamodb/tests/test_base_connection.py b/pynamodb/tests/test_base_connection.py index b9b33fa0e..afc9b368f 100644 --- a/pynamodb/tests/test_base_connection.py +++ b/pynamodb/tests/test_base_connection.py @@ -2673,6 +2673,12 @@ def test_get_expected_map(self): {'Expected': {'ForumName': {'ComparisonOperator': 'EQ', 'AttributeValueList': [{'S': 'foo'}]}}} ) + def test_update_time_to_live_fail(self): + conn = Connection(self.region) + with patch(PATCH_METHOD) as req: + req.side_effect = BotoCoreError + self.assertRaises(TableError, conn.update_time_to_live, 'test table', 'my_ttl') + def test_get_query_filter_map(self): conn = Connection(self.region) with patch(PATCH_METHOD) as req: diff --git a/pynamodb/tests/test_model.py b/pynamodb/tests/test_model.py index 9d6775e13..69fcc4aec 100644 --- a/pynamodb/tests/test_model.py +++ b/pynamodb/tests/test_model.py @@ -5,12 +5,13 @@ import random import json import copy -from datetime import datetime +from datetime import datetime, timedelta import six from botocore.client import ClientError from botocore.vendored import requests import pytest +from dateutil.tz import tzutc from pynamodb.compat import CompatTestCase as TestCase from pynamodb.tests.deep_eq import deep_eq @@ -30,7 +31,7 @@ from pynamodb.attributes import ( UnicodeAttribute, NumberAttribute, BinaryAttribute, UTCDateTimeAttribute, UnicodeSetAttribute, NumberSetAttribute, BinarySetAttribute, MapAttribute, - BooleanAttribute, ListAttribute) + BooleanAttribute, ListAttribute, TTLAttribute) from pynamodb.tests.data import ( MODEL_TABLE_DATA, GET_MODEL_ITEM_DATA, SIMPLE_MODEL_TABLE_DATA, BATCH_GET_ITEMS, SIMPLE_BATCH_GET_ITEMS, COMPLEX_TABLE_DATA, @@ -232,6 +233,7 @@ class Meta: zip_code = NumberAttribute(null=True) email = UnicodeAttribute(default='needs_email') callable_field = NumberAttribute(default=lambda: 42) + ttl = TTLAttribute(null=True) class HostSpecificModel(Model): @@ -444,6 +446,13 @@ class Meta: breed = UnicodeAttribute() +class TTLModel(Model): + class Meta: + table_name = 'TTLModel' + user_name = UnicodeAttribute(hash_key=True) + my_ttl = TTLAttribute(default_for_new=timedelta(minutes=1)) + + class ModelTestCase(TestCase): """ Tests for the models API @@ -4467,3 +4476,32 @@ def test_subclassed_map_attribute_with_map_attribute_member_with_initialized_ins self.assertEquals(actual.left.left.value, left_instance.left.value) self.assertEquals(actual.right.right.left.value, right_instance.right.left.value) self.assertEquals(actual.right.right.value, right_instance.right.value) + + def test_bad_ttl_model(self): + with self.assertRaises(ValueError): + class BadTTLModel(Model): + class Meta: + table_name = 'BadTTLModel' + ttl = TTLAttribute(default_for_new=timedelta(minutes=1)) + another_ttl = TTLAttribute() + BadTTLModel() + + def test_get_ttl_attribute_fails(self): + with patch(PATCH_METHOD) as req: + req.side_effect = Exception + self.assertRaises(Exception, TTLModel.update_ttl, False) + + def test_get_ttl_attribute(self): + assert TTLModel._ttl_attribute().attr_name == "my_ttl" + + def test_deserialized(self): + with patch(PATCH_METHOD) as req: + req.return_value = SIMPLE_MODEL_TABLE_DATA + m = TTLModel.from_raw_data({'user_name': {'S': 'mock'}}) + assert m.my_ttl is None + + def test_deserialized_with_ttl(self): + with patch(PATCH_METHOD) as req: + req.return_value = SIMPLE_MODEL_TABLE_DATA + m = TTLModel.from_raw_data({'user_name': {'S': 'mock'}, 'my_ttl': {'N': '1546300800'}}) + assert m.my_ttl == datetime(2019, 1, 1, tzinfo=tzutc()) diff --git a/pynamodb/tests/test_table_connection.py b/pynamodb/tests/test_table_connection.py index 81ebce3dd..8360ab81d 100644 --- a/pynamodb/tests/test_table_connection.py +++ b/pynamodb/tests/test_table_connection.py @@ -110,6 +110,24 @@ def test_create_table(self): kwargs = req.call_args[0][1] self.assertEqual(kwargs, params) + def test_update_time_to_live(self): + """ + TableConnection.update_time_to_live + """ + params = { + 'TableName': 'ci-table', + 'TimeToLiveSpecification': { + 'AttributeName': 'ttl_attr', + 'Enabled': True, + } + } + with patch(PATCH_METHOD) as req: + req.return_value = HttpOK(), None + conn = TableConnection(self.test_table_name) + conn.update_time_to_live('ttl_attr') + kwargs = req.call_args[0][1] + self.assertEqual(kwargs, params) + def test_delete_table(self): """ TableConnection.delete_table