diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 86d39edc9..cae5512df 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -24,7 +24,7 @@ RETURN_CONSUMED_CAPACITY_VALUES, RETURN_ITEM_COLL_METRICS_VALUES, COMPARISON_OPERATOR_VALUES, RETURN_ITEM_COLL_METRICS, RETURN_CONSUMED_CAPACITY, RETURN_VALUES_VALUES, ATTR_UPDATE_ACTIONS, COMPARISON_OPERATOR, EXCLUSIVE_START_KEY, SCAN_INDEX_FORWARD, SCAN_FILTER_VALUES, ATTR_DEFINITIONS, - BATCH_WRITE_ITEM, CONSISTENT_READ, ATTR_VALUE_LIST, DESCRIBE_TABLE, DEFAULT_REGION, KEY_CONDITIONS, + BATCH_WRITE_ITEM, CONSISTENT_READ, ATTR_VALUE_LIST, DESCRIBE_TABLE, KEY_CONDITIONS, BATCH_GET_ITEM, DELETE_REQUEST, SELECT_VALUES, RETURN_VALUES, REQUEST_ITEMS, ATTR_UPDATES, ATTRS_TO_GET, SERVICE_NAME, DELETE_ITEM, PUT_REQUEST, UPDATE_ITEM, SCAN_FILTER, TABLE_NAME, INDEX_NAME, KEY_SCHEMA, ATTR_NAME, ATTR_TYPE, TABLE_KEY, EXPECTED, KEY_TYPE, GET_ITEM, UPDATE, @@ -36,14 +36,10 @@ CONDITIONAL_OPERATORS, NULL, NOT_NULL, SHORT_ATTR_TYPES, DELETE, ITEMS, DEFAULT_ENCODING, BINARY_SHORT, BINARY_SET_SHORT, LAST_EVALUATED_KEY, RESPONSES, UNPROCESSED_KEYS, UNPROCESSED_ITEMS, STREAM_SPECIFICATION, STREAM_VIEW_TYPE, STREAM_ENABLED) +from pynamodb.settings import get_settings_value BOTOCORE_EXCEPTIONS = (BotoCoreError, ClientError) -# retry parameters -DEFAULT_TIMEOUT = 60 # matches legacy retry timeout from botocore -DEFAULT_MAX_RETRY_ATTEMPTS_EXCEPTION = 3 -DEFAULT_BASE_BACKOFF_MS = 25 - log = logging.getLogger(__name__) log.addHandler(NullHandler()) @@ -178,7 +174,8 @@ class Connection(object): A higher level abstraction over botocore """ - def __init__(self, region=None, host=None, session_cls=None): + def __init__(self, region=None, host=None, session_cls=None, + request_timeout_seconds=None, max_retry_attempts=None, base_backoff_ms=None): self._tables = {} self.host = host self._session = None @@ -187,17 +184,27 @@ def __init__(self, region=None, host=None, session_cls=None): if region: self.region = region else: - self.region = DEFAULT_REGION - - # TODO: provide configurability of retry parameters via arguments - self._request_timeout_seconds = DEFAULT_TIMEOUT - self._max_retry_attempts_exception = DEFAULT_MAX_RETRY_ATTEMPTS_EXCEPTION - self._base_backoff_ms = DEFAULT_BASE_BACKOFF_MS + self.region = get_settings_value('region') if session_cls: self.session_cls = session_cls else: - self.session_cls = requests.Session + self.session_cls = get_settings_value('session_cls') + + if request_timeout_seconds is not None: + self._request_timeout_seconds = request_timeout_seconds + else: + self._request_timeout_seconds = get_settings_value('request_timeout_seconds') + + if max_retry_attempts is not None: + self._max_retry_attempts_exception = max_retry_attempts + else: + self._max_retry_attempts_exception = get_settings_value('max_retry_attempts') + + if base_backoff_ms is not None: + self._base_backoff_ms = base_backoff_ms + else: + self._base_backoff_ms = get_settings_value('base_backoff_ms') def __repr__(self): return six.u("Connection<{0}>".format(self.client.meta.endpoint_url)) @@ -254,8 +261,9 @@ def _make_api_call(self, operation_name, operation_kwargs): ) prepared_request = self.client._endpoint.create_request(request_dict, operation_model) - for attempt_number in range(1, self._max_retry_attempts_exception + 1): - is_last_attempt_for_exceptions = attempt_number == self._max_retry_attempts_exception + for i in range(0, self._max_retry_attempts_exception + 1): + attempt_number = i + 1 + is_last_attempt_for_exceptions = i == self._max_retry_attempts_exception try: response = self.requests_session.send( @@ -309,7 +317,7 @@ def _make_api_call(self, operation_name, operation_kwargs): else: # We use fully-jittered exponentially-backed-off retries: # https://www.awsarchitectureblog.com/2015/03/backoff.html - sleep_time_ms = random.randint(0, self._base_backoff_ms * (2 ** attempt_number)) + sleep_time_ms = random.randint(0, self._base_backoff_ms * (2 ** i)) log.debug( 'Retry with backoff needed for (%s) after attempt %s,' 'sleeping for %s milliseconds, retryable %s caught: %s', diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 69162b11f..d1d7e35e6 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -10,11 +10,23 @@ class TableConnection(object): A higher level abstraction over botocore """ - def __init__(self, table_name, region=None, host=None, session_cls=None,): + def __init__(self, + table_name, + region=None, + host=None, + session_cls=None, + request_timeout_seconds=None, + max_retry_attempts=None, + base_backoff_ms=None): self._hash_keyname = None self._range_keyname = None self.table_name = table_name - self.connection = Connection(region=region, host=host, session_cls=session_cls,) + self.connection = Connection(region=region, + host=host, + session_cls=session_cls, + request_timeout_seconds=request_timeout_seconds, + max_retry_attempts=max_retry_attempts, + base_backoff_ms=base_backoff_ms) def delete_item(self, hash_key, range_key=None, diff --git a/pynamodb/constants.py b/pynamodb/constants.py index 4bf540709..68bcae3db 100644 --- a/pynamodb/constants.py +++ b/pynamodb/constants.py @@ -239,3 +239,4 @@ AND = 'AND' OR = 'OR' CONDITIONAL_OPERATORS = [AND, OR] + diff --git a/pynamodb/models.py b/pynamodb/models.py index ca05cdcbb..9bd175ed9 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -17,6 +17,7 @@ from pynamodb.types import HASH, RANGE from pynamodb.compat import NullHandler from pynamodb.indexes import Index, GlobalSecondaryIndex +from pynamodb.settings import get_settings_value from pynamodb.constants import ( ATTR_TYPE_MAP, ATTR_DEFINITIONS, ATTR_NAME, ATTR_TYPE, KEY_SCHEMA, KEY_TYPE, ITEM, ITEMS, READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS, CAMEL_COUNT, @@ -27,7 +28,7 @@ TABLE_STATUS, ACTIVE, RETURN_VALUES, BATCH_GET_PAGE_LIMIT, UNPROCESSED_KEYS, PUT_REQUEST, DELETE_REQUEST, LAST_EVALUATED_KEY, QUERY_OPERATOR_MAP, NOT_NULL, SCAN_OPERATOR_MAP, CONSUMED_CAPACITY, BATCH_WRITE_PAGE_LIMIT, TABLE_NAME, - CAPACITY_UNITS, DEFAULT_REGION, META_CLASS_NAME, REGION, HOST, EXISTS, NULL, + CAPACITY_UNITS, META_CLASS_NAME, REGION, HOST, EXISTS, NULL, DELETE_FILTER_OPERATOR_MAP, UPDATE_FILTER_OPERATOR_MAP, PUT_FILTER_OPERATOR_MAP, COUNT, ITEM_COUNT, KEY, UNPROCESSED_ITEMS, STREAM_VIEW_TYPE, STREAM_SPECIFICATION, STREAM_ENABLED, EQ, NE) @@ -137,9 +138,7 @@ def commit(self): class DefaultMeta(object): - table_name = None - region = DEFAULT_REGION - host = None + pass class ResultSet(object): @@ -165,11 +164,17 @@ def __init__(cls, name, bases, attrs): for attr_name, attr_obj in attrs.items(): if attr_name == META_CLASS_NAME: if not hasattr(attr_obj, REGION): - setattr(attr_obj, REGION, DEFAULT_REGION) + setattr(attr_obj, REGION, get_settings_value('region')) if not hasattr(attr_obj, HOST): - setattr(attr_obj, HOST, None) + setattr(attr_obj, HOST, get_settings_value('host')) if not hasattr(attr_obj, 'session_cls'): - setattr(attr_obj, 'session_cls', None) + setattr(attr_obj, 'session_cls', get_settings_value('session_cls')) + if not hasattr(attr_obj, 'request_timeout_seconds'): + setattr(attr_obj, 'request_timeout_seconds', get_settings_value('request_timeout_seconds')) + if not hasattr(attr_obj, 'base_backoff_ms'): + setattr(attr_obj, 'base_backoff_ms', get_settings_value('base_backoff_ms')) + if not hasattr(attr_obj, 'max_retry_attempts'): + setattr(attr_obj, 'max_retry_attempts', get_settings_value('max_retry_attempts')) elif issubclass(attr_obj.__class__, (Index, )): attr_obj.Meta.model = cls if not hasattr(attr_obj.Meta, "index_name"): @@ -1167,8 +1172,13 @@ def _get_connection(cls): See https://pynamodb.readthedocs.io/en/latest/release_notes.html""" ) if cls._connection is None: - cls._connection = TableConnection(cls.Meta.table_name, region=cls.Meta.region, host=cls.Meta.host, - session_cls=cls.Meta.session_cls) + cls._connection = TableConnection(cls.Meta.table_name, + region=cls.Meta.region, + host=cls.Meta.host, + session_cls=cls.Meta.session_cls, + request_timeout_seconds=cls.Meta.request_timeout_seconds, + max_retry_attempts=cls.Meta.max_retry_attempts, + base_backoff_ms=cls.Meta.base_backoff_ms) return cls._connection def _deserialize(self, attrs): diff --git a/pynamodb/settings.py b/pynamodb/settings.py new file mode 100644 index 000000000..f2b62608c --- /dev/null +++ b/pynamodb/settings.py @@ -0,0 +1,40 @@ +import imp +import logging +import os +from os import getenv + +from botocore.vendored import requests + +log = logging.getLogger(__name__) + +default_settings_dict = { + 'request_timeout_seconds': 60, + 'max_retry_attempts': 3, + 'base_backoff_ms': 25, + 'region': 'us-east-1', + 'session_cls': requests.Session +} + +OVERRIDE_SETTINGS_PATH = getenv('PYNAMODB_CONFIG', '/etc/pynamodb/global_default_settings.py') + +override_settings = {} +if os.path.isfile(OVERRIDE_SETTINGS_PATH): + override_settings = imp.load_source(OVERRIDE_SETTINGS_PATH, OVERRIDE_SETTINGS_PATH) + log.info('Override settings for pynamo available {0}'.format(OVERRIDE_SETTINGS_PATH)) +else: + log.info('Override settings for pynamo not available {0}'.format(OVERRIDE_SETTINGS_PATH)) + log.info('Using Default settings value') + + +def get_settings_value(key): + """ + Fetches the value from the override file. + If the value is not present, then tries to fetch the values from constants.py + """ + if hasattr(override_settings, key): + return getattr(override_settings, key) + + if key in default_settings_dict: + return default_settings_dict[key] + + return None diff --git a/pynamodb/tests/test_base_connection.py b/pynamodb/tests/test_base_connection.py index 610339c76..2a026b2e4 100644 --- a/pynamodb/tests/test_base_connection.py +++ b/pynamodb/tests/test_base_connection.py @@ -14,7 +14,6 @@ from pynamodb.tests.deep_eq import deep_eq from botocore.exceptions import BotoCoreError from botocore.client import ClientError - if six.PY3: from unittest.mock import patch from unittest import mock @@ -1664,9 +1663,10 @@ def test_make_api_call_throws_verbose_error_after_backoff(self, requests_session ) raise + @mock.patch('random.randint') @mock.patch('pynamodb.connection.Connection.session') @mock.patch('pynamodb.connection.Connection.requests_session') - def test_make_api_call_throws_verbose_error_after_backoff_later_succeeds(self, requests_session_mock, session_mock): + def test_make_api_call_throws_verbose_error_after_backoff_later_succeeds(self, requests_session_mock, session_mock, rand_int_mock): # mock response bad_response = requests.Response() @@ -1681,14 +1681,19 @@ def test_make_api_call_throws_verbose_error_after_backoff_later_succeeds(self, r good_response._content = json.dumps(good_response_content).encode('utf-8') requests_session_mock.send.side_effect = [ + bad_response, bad_response, good_response, ] + rand_int_mock.return_value = 1 + c = Connection() self.assertEqual(good_response_content, c._make_api_call('CreateTable', {'TableName': 'MyTable'})) - self.assertEqual(len(requests_session_mock.send.mock_calls), 2) + self.assertEqual(len(requests_session_mock.send.mock_calls), 3) + + assert rand_int_mock.call_args_list == [mock.call(0, 25), mock.call(0, 50)] @mock.patch('pynamodb.connection.Connection.session') @mock.patch('pynamodb.connection.Connection.requests_session') @@ -1705,13 +1710,13 @@ def test_make_api_call_retries_properly(self, requests_session_mock, session_moc session_mock.create_client.return_value._endpoint.create_request.return_value = prepared_request requests_session_mock.send.side_effect = [ - requests.ConnectionError('problems!'), + bad_response, requests.Timeout('problems!'), bad_response, deserializable_response ] c = Connection() - c._max_retry_attempts_exception = 4 + c._max_retry_attempts_exception = 3 c._make_api_call('DescribeTable', {'TableName': 'MyTable'}) self.assertEqual(len(requests_session_mock.mock_calls), 4) @@ -1732,15 +1737,40 @@ def test_make_api_call_throws_when_retries_exhausted(self, requests_session_mock requests.Timeout('problems!'), ] c = Connection() - c._max_retry_attempts_exception = 4 + c._max_retry_attempts_exception = 3 with self.assertRaises(requests.Timeout): c._make_api_call('DescribeTable', {'TableName': 'MyTable'}) self.assertEqual(len(requests_session_mock.mock_calls), 4) + assert requests_session_mock.send.call_args[1]['timeout'] == 60 for call in requests_session_mock.mock_calls: self.assertEqual(call[:2], ('send', (prepared_request,))) + + @mock.patch('random.randint') + @mock.patch('pynamodb.connection.Connection.session') + @mock.patch('pynamodb.connection.Connection.requests_session') + def test_make_api_call_throws_retry_disabled(self, requests_session_mock, session_mock, rand_int_mock): + prepared_request = requests.Request('GET', 'http://lyft.com').prepare() + session_mock.create_client.return_value._endpoint.create_request.return_value = prepared_request + + requests_session_mock.send.side_effect = [ + requests.Timeout('problems!'), + ] + c = Connection(request_timeout_seconds=11, base_backoff_ms=3, max_retry_attempts=0) + assert c._base_backoff_ms == 3 + with self.assertRaises(requests.Timeout): + c._make_api_call('DescribeTable', {'TableName': 'MyTable'}) + + self.assertEqual(len(requests_session_mock.mock_calls), 1) + rand_int_mock.assert_not_called() + + assert requests_session_mock.send.call_args[1]['timeout'] == 11 + for call in requests_session_mock.mock_calls: + self.assertEqual(call[:2], ('send', (prepared_request,))) + + def test_handle_binary_attributes_for_unprocessed_items(self): binary_blob = six.b('\x00\xFF\x00\xFF') diff --git a/pynamodb/tests/test_model.py b/pynamodb/tests/test_model.py index aa6692aef..c7dfee493 100644 --- a/pynamodb/tests/test_model.py +++ b/pynamodb/tests/test_model.py @@ -9,6 +9,7 @@ import six from botocore.client import ClientError +from botocore.vendored import requests from pynamodb.compat import CompatTestCase as TestCase from pynamodb.tests.deep_eq import deep_eq @@ -282,6 +283,29 @@ class Meta: is_human = BooleanAttribute() +class OverriddenSession(requests.Session): + """ + A overridden session for test + """ + def __init__(self): + super(OverriddenSession, self).__init__() + + +class OverriddenSessionModel(Model): + """ + A testing model + """ + class Meta: + table_name = 'OverriddenSessionModel' + request_timeout_seconds = 9999 + max_retry_attempts = 200 + base_backoff_ms = 4120 + session_cls = OverriddenSession + + random_user_name = UnicodeAttribute(hash_key=True, attr_name='random_name_1') + random_attr = UnicodeAttribute(attr_name='random_attr_1', null=True) + + class ModelTestCase(TestCase): """ Tests for the models API @@ -334,6 +358,17 @@ def fake_dynamodb(*args): # Test for default region self.assertEqual(UserModel.Meta.region, 'us-east-1') + self.assertEqual(UserModel.Meta.request_timeout_seconds, 60) + self.assertEqual(UserModel.Meta.max_retry_attempts, 3) + self.assertEqual(UserModel.Meta.base_backoff_ms, 25) + self.assertTrue(UserModel.Meta.session_cls is requests.Session) + + self.assertEqual(UserModel._connection.connection._request_timeout_seconds, 60) + self.assertEqual(UserModel._connection.connection._max_retry_attempts_exception, 3) + self.assertEqual(UserModel._connection.connection._base_backoff_ms, 25) + + self.assertTrue(type(UserModel._connection.connection.requests_session) is requests.Session) + with patch(PATCH_METHOD) as req: req.return_value = MODEL_TABLE_DATA UserModel.create_table(read_capacity_units=2, write_capacity_units=2) @@ -481,6 +516,28 @@ def test_overidden_defaults(self): self.assert_dict_lists_equal(correct_schema['KeySchema'], schema['key_schema']) self.assert_dict_lists_equal(correct_schema['AttributeDefinitions'], schema['attribute_definitions']) + def test_overidden_session(self): + """ + Custom session + """ + fake_db = MagicMock() + + with patch(PATCH_METHOD, new=fake_db): + with patch("pynamodb.connection.TableConnection.describe_table") as req: + req.return_value = None + with self.assertRaises(TableError): + OverriddenSessionModel.create_table(read_capacity_units=2, write_capacity_units=2, wait=True) + + self.assertEqual(OverriddenSessionModel.Meta.request_timeout_seconds, 9999) + self.assertEqual(OverriddenSessionModel.Meta.max_retry_attempts, 200) + self.assertEqual(OverriddenSessionModel.Meta.base_backoff_ms, 4120) + self.assertTrue(OverriddenSessionModel.Meta.session_cls is OverriddenSession) + + self.assertEqual(OverriddenSessionModel._connection.connection._request_timeout_seconds, 9999) + self.assertEqual(OverriddenSessionModel._connection.connection._max_retry_attempts_exception, 200) + self.assertEqual(OverriddenSessionModel._connection.connection._base_backoff_ms, 4120) + self.assertTrue(type(OverriddenSessionModel._connection.connection.requests_session) is OverriddenSession) + def test_refresh(self): """ Model.refresh