diff --git a/docs/release_notes.rst b/docs/release_notes.rst index 09ba3732e..314cf3506 100644 --- a/docs/release_notes.rst +++ b/docs/release_notes.rst @@ -1,6 +1,19 @@ Release Notes ============= +v5.3.0 +---------- +* No longer call ``DescribeTable`` API before first operation + + Before this change, we would call ``DescribeTable`` before the first operation + on a given table in order to discover its schema. This slowed down bootstrap + (particularly important for lambdas), complicated testing and could potentially + cause inconsistent behavior since queries were serialized using the table's + (key) schema but deserialized using the model's schema. + + With this change, both queries and models now use the model's schema. + + v5.2.3 ---------- * Update for botocore 1.28 private API change (#1087) which caused the following exception:: diff --git a/pynamodb/__init__.py b/pynamodb/__init__.py index 222fbd95d..43017d49b 100644 --- a/pynamodb/__init__.py +++ b/pynamodb/__init__.py @@ -7,4 +7,4 @@ """ __author__ = 'Jharrod LaFon' __license__ = 'MIT' -__version__ = '5.2.3' +__version__ = '5.3.0' diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index 706e0535d..61df60d74 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -79,6 +79,13 @@ def __repr__(self) -> str: return "MetaTable<{}>".format(self.data.get(TABLE_NAME)) return "" + @property + def table_name(self) -> str: + """ + Returns the table name + """ + return self.data[TABLE_NAME] + @property def range_keyname(self) -> Optional[str]: """ @@ -559,25 +566,22 @@ def client(self) -> BotocoreBaseClientPrivate: self._convert_to_request_dict__endpoint_url = 'endpoint_url' in inspect.signature(self._client._convert_to_request_dict).parameters return self._client - def get_meta_table(self, table_name: str, refresh: bool = False): + def add_meta_table(self, meta_table: MetaTable) -> None: """ - Returns a MetaTable + Adds information about the table's schema. """ - if table_name not in self._tables or refresh: - operation_kwargs = { - TABLE_NAME: table_name - } - try: - data = self.dispatch(DESCRIBE_TABLE, operation_kwargs) - self._tables[table_name] = MetaTable(data.get(TABLE_KEY)) - except BotoCoreError as e: - raise TableError("Unable to describe table: {}".format(e), e) - except ClientError as e: - if 'ResourceNotFound' in e.response['Error']['Code']: - raise TableDoesNotExist(e.response['Error']['Message']) - else: - raise - return self._tables[table_name] + if meta_table.table_name in self._tables: + raise ValueError(f"Meta-table for '{meta_table.table_name}' already added") + self._tables[meta_table.table_name] = meta_table + + def get_meta_table(self, table_name: str) -> MetaTable: + """ + Returns information about the table's schema. + """ + try: + return self._tables[table_name] + except KeyError: + raise TableError(f"Meta-table for '{table_name}' not initialized") from None def create_table( self, @@ -608,8 +612,8 @@ def create_table( raise ValueError("attribute_definitions argument is required") for attr in attribute_definitions: attrs_list.append({ - ATTR_NAME: attr.get('attribute_name'), - ATTR_TYPE: attr.get('attribute_type') + ATTR_NAME: attr.get(ATTR_NAME) or attr['attribute_name'], + ATTR_TYPE: attr.get(ATTR_TYPE) or attr['attribute_type'] }) operation_kwargs[ATTR_DEFINITIONS] = attrs_list @@ -639,8 +643,8 @@ def create_table( key_schema_list = [] for item in key_schema: key_schema_list.append({ - ATTR_NAME: item.get('attribute_name'), - KEY_TYPE: str(item.get('key_type')).upper() + ATTR_NAME: item.get(ATTR_NAME) or item['attribute_name'], + KEY_TYPE: str(item.get(KEY_TYPE) or item['key_type']).upper() }) operation_kwargs[KEY_SCHEMA] = sorted(key_schema_list, key=lambda x: x.get(KEY_TYPE)) @@ -767,13 +771,26 @@ def describe_table(self, table_name: str) -> Dict: """ Performs the DescribeTable operation """ + operation_kwargs = { + TABLE_NAME: table_name + } try: - tbl = self.get_meta_table(table_name, refresh=True) - if tbl: - return tbl.data - except ValueError: - pass - raise TableDoesNotExist(table_name) + data = self.dispatch(DESCRIBE_TABLE, operation_kwargs) + table_data = data.get(TABLE_KEY) + # For compatibility with existing code which uses Connection directly, + # we can let DescribeTable set the meta table. + if table_data: + meta_table = MetaTable(table_data) + if meta_table.table_name not in self._tables: + self.add_meta_table(meta_table) + return table_data + except BotoCoreError as e: + raise TableError("Unable to describe table: {}".format(e), e) + except ClientError as e: + if 'ResourceNotFound' in e.response['Error']['Code']: + raise TableDoesNotExist(e.response['Error']['Message']) + else: + raise def get_item_attribute_map( self, diff --git a/pynamodb/connection/table.py b/pynamodb/connection/table.py index 183467a9f..fb7720e00 100644 --- a/pynamodb/connection/table.py +++ b/pynamodb/connection/table.py @@ -30,6 +30,8 @@ def __init__( aws_access_key_id: Optional[str] = None, aws_secret_access_key: Optional[str] = None, aws_session_token: Optional[str] = None, + *, + meta_table: Optional[MetaTable] = None, ) -> None: self.table_name = table_name self.connection = Connection(region=region, @@ -40,17 +42,19 @@ def __init__( base_backoff_ms=base_backoff_ms, max_pool_connections=max_pool_connections, extra_headers=extra_headers) + if meta_table is not None: + self.connection.add_meta_table(meta_table) if aws_access_key_id and aws_secret_access_key: self.connection.session.set_credentials(aws_access_key_id, aws_secret_access_key, aws_session_token) - def get_meta_table(self, refresh: bool = False) -> MetaTable: + def get_meta_table(self) -> MetaTable: """ Returns a MetaTable """ - return self.connection.get_meta_table(self.table_name, refresh=refresh) + return self.connection.get_meta_table(self.table_name) def get_operation_kwargs( self, diff --git a/pynamodb/models.py b/pynamodb/models.py index b1a6aa4cc..08b90e814 100644 --- a/pynamodb/models.py +++ b/pynamodb/models.py @@ -24,6 +24,8 @@ from typing import Union from typing import cast +from pynamodb.connection.base import MetaTable + if sys.version_info >= (3, 8): from typing import Protocol else: @@ -38,9 +40,10 @@ from pynamodb.connection.table import TableConnection from pynamodb.expressions.condition import Condition from pynamodb.types import HASH, RANGE -from pynamodb.indexes import Index, GlobalSecondaryIndex +from pynamodb.indexes import Index, GlobalSecondaryIndex, LocalSecondaryIndex from pynamodb.pagination import ResultIterator from pynamodb.settings import get_settings_value, OperationSettings +from pynamodb import constants from pynamodb.constants import ( ATTR_DEFINITIONS, ATTR_NAME, ATTR_TYPE, KEY_SCHEMA, KEY_TYPE, ITEM, READ_CAPACITY_UNITS, WRITE_CAPACITY_UNITS, @@ -53,7 +56,7 @@ BATCH_WRITE_PAGE_LIMIT, META_CLASS_NAME, REGION, HOST, NULL, COUNT, ITEM_COUNT, KEY, UNPROCESSED_ITEMS, STREAM_VIEW_TYPE, - STREAM_SPECIFICATION, STREAM_ENABLED, BILLING_MODE, PAY_PER_REQUEST_BILLING_MODE, TAGS + STREAM_SPECIFICATION, STREAM_ENABLED, BILLING_MODE, PAY_PER_REQUEST_BILLING_MODE, TAGS, TABLE_NAME ) from pynamodb.util import attribute_value_to_json from pynamodb.util import json_to_attribute_value @@ -863,18 +866,18 @@ def _get_schema(cls) -> Dict[str, Any]: for attr_name, attr_cls in cls.get_attributes().items(): if attr_cls.is_hash_key or attr_cls.is_range_key: schema['attribute_definitions'].append({ - 'attribute_name': attr_cls.attr_name, - 'attribute_type': attr_cls.attr_type + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type }) if attr_cls.is_hash_key: schema['key_schema'].append({ - 'key_type': HASH, - 'attribute_name': attr_cls.attr_name + KEY_TYPE: HASH, + ATTR_NAME: attr_cls.attr_name }) elif attr_cls.is_range_key: schema['key_schema'].append({ - 'key_type': RANGE, - 'attribute_name': attr_cls.attr_name + KEY_TYPE: RANGE, + ATTR_NAME: attr_cls.attr_name }) for index in cls._indexes.values(): index_schema = index._get_schema() @@ -887,13 +890,13 @@ def _get_schema(cls) -> Dict[str, Any]: attr_names = {key_schema[ATTR_NAME] for index_schema in (*schema['global_secondary_indexes'], *schema['local_secondary_indexes']) for key_schema in index_schema['key_schema']} - attr_keys = {attr.get('attribute_name') for attr in schema['attribute_definitions']} + attr_keys = {attr[ATTR_NAME] for attr in schema['attribute_definitions']} for attr_name in attr_names: if attr_name not in attr_keys: attr_cls = cls.get_attributes()[cls._dynamo_to_python_attr(attr_name)] schema['attribute_definitions'].append({ - 'attribute_name': attr_cls.attr_name, - 'attribute_type': attr_cls.attr_type + ATTR_NAME: attr_cls.attr_name, + ATTR_TYPE: attr_cls.attr_type }) return schema @@ -1057,7 +1060,28 @@ def _get_connection(cls) -> TableConnection: # For now we just check that the connection exists and (in the case of model inheritance) # points to the same table. In the future we should update the connection if any of the attributes differ. if cls._connection is None or cls._connection.table_name != cls.Meta.table_name: + schema = cls._get_schema() + meta_table = MetaTable({ + constants.TABLE_NAME: cls.Meta.table_name, + constants.KEY_SCHEMA: schema['key_schema'], + constants.ATTR_DEFINITIONS: schema['attribute_definitions'], + constants.GLOBAL_SECONDARY_INDEXES: [ + { + constants.INDEX_NAME: index_schema['index_name'], + constants.KEY_SCHEMA: index_schema['key_schema'], + } + for index_schema in schema['global_secondary_indexes'] + ], + constants.LOCAL_SECONDARY_INDEXES: [ + { + constants.INDEX_NAME: index_schema['index_name'], + constants.KEY_SCHEMA: index_schema['key_schema'], + } + for index_schema in schema['local_secondary_indexes'] + ], + }) cls._connection = TableConnection(cls.Meta.table_name, + meta_table=meta_table, region=cls.Meta.region, host=cls.Meta.host, connect_timeout_seconds=cls.Meta.connect_timeout_seconds, diff --git a/tests/data.py b/tests/data.py index 3b99a8aa0..5d7514a30 100644 --- a/tests/data.py +++ b/tests/data.py @@ -39,7 +39,6 @@ } } - MODEL_TABLE_DATA = { "Table": { "AttributeDefinitions": [ @@ -345,89 +344,6 @@ } } -DESCRIBE_TABLE_DATA_PAY_PER_REQUEST = { - "Table": { - "AttributeDefinitions": [ - { - "AttributeName": "ForumName", - "AttributeType": "S" - }, - { - "AttributeName": "LastPostDateTime", - "AttributeType": "S" - }, - { - "AttributeName": "Subject", - "AttributeType": "S" - } - ], - "CreationDateTime": 1.363729002358E9, - "ItemCount": 0, - "KeySchema": [ - { - "AttributeName": "ForumName", - "KeyType": "HASH" - }, - { - "AttributeName": "Subject", - "KeyType": "RANGE" - } - ], - "GlobalSecondaryIndexes": [ - { - "IndexName": "LastPostIndex", - "IndexSizeBytes": 0, - "ItemCount": 0, - "KeySchema": [ - { - "AttributeName": "ForumName", - "KeyType": "HASH" - }, - { - "AttributeName": "LastPostDateTime", - "KeyType": "RANGE" - } - ], - "Projection": { - "ProjectionType": "KEYS_ONLY" - } - } - ], - "LocalSecondaryIndexes": [ - { - "IndexName": "LastPostIndex", - "IndexSizeBytes": 0, - "ItemCount": 0, - "KeySchema": [ - { - "AttributeName": "ForumName", - "KeyType": "HASH" - }, - { - "AttributeName": "LastPostDateTime", - "KeyType": "RANGE" - } - ], - "Projection": { - "ProjectionType": "KEYS_ONLY" - } - } - ], - "ProvisionedThroughput": { - "NumberOfDecreasesToday": 0, - "ReadCapacityUnits": 0, - "WriteCapacityUnits": 0 - }, - "TableName": "Thread", - "TableSizeBytes": 0, - "TableStatus": "ACTIVE", - "BillingModeSummary": { - "BillingMode": "PAY_PER_REQUEST", - "LastUpdateToPayPerRequestDateTime": 1548353644.074 - } - } -} - GET_MODEL_ITEM_DATA = { 'Item': { 'user_name': { diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index 901182753..1ffbfcb05 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -64,7 +64,7 @@ class ConnectionTestCase(TestCase): """ def setUp(self): - self.test_table_name = 'ci-table' + self.test_table_name = 'Thread' self.region = DEFAULT_REGION def test_create_connection(self): @@ -141,7 +141,7 @@ def test_create_table(self): } ] params = { - 'TableName': 'ci-table', + 'TableName': 'Thread', 'ProvisionedThroughput': { 'WriteCapacityUnits': 1, 'ReadCapacityUnits': 1 @@ -283,7 +283,7 @@ def test_delete_table(self): """ Connection.delete_table """ - params = {'TableName': 'ci-table'} + params = {'TableName': 'Thread'} with patch(PATCH_METHOD) as req: req.return_value = None conn = Connection(self.region) @@ -308,7 +308,7 @@ def test_update_table(self): 'WriteCapacityUnits': 2, 'ReadCapacityUnits': 2 }, - 'TableName': 'ci-table' + 'TableName': 'Thread' } conn.update_table( self.test_table_name, @@ -341,7 +341,7 @@ def test_update_table(self): } ] params = { - 'TableName': 'ci-table', + 'TableName': 'Thread', 'ProvisionedThroughput': { 'ReadCapacityUnits': 2, 'WriteCapacityUnits': 2, @@ -375,7 +375,7 @@ def test_describe_table(self): req.return_value = DESCRIBE_TABLE_DATA conn = Connection(self.region) conn.describe_table(self.test_table_name) - self.assertEqual(req.call_args[0][1], {'TableName': 'ci-table'}) + self.assertEqual(req.call_args[0][1], {'TableName': 'Thread'}) with self.assertRaises(TableDoesNotExist): with patch(PATCH_METHOD) as req: @@ -383,12 +383,6 @@ def test_describe_table(self): conn = Connection(self.region) conn.describe_table(self.test_table_name) - with self.assertRaises(TableDoesNotExist): - with patch(PATCH_METHOD) as req: - req.side_effect = ValueError() - conn = Connection(self.region) - conn.describe_table(self.test_table_name) - def test_list_tables(self): """ Connection.list_tables @@ -422,9 +416,7 @@ def test_delete_item(self): Connection.delete_item """ conn = Connection(self.region) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(self.test_table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.side_effect = BotoCoreError @@ -575,9 +567,7 @@ def test_get_item(self): """ conn = Connection(self.region) table_name = 'Thread' - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = GET_ITEM_DATA @@ -627,21 +617,12 @@ def test_update_item(self): Connection.update_item """ conn = Connection() - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(self.test_table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) self.assertRaises(ValueError, conn.update_item, self.test_table_name, 'foo-key') self.assertRaises(ValueError, conn.update_item, self.test_table_name, 'foo', actions=[]) - attr_updates = { - 'Subject': { - 'Value': 'foo-subject', - 'Action': 'PUT' - }, - } - with patch(PATCH_METHOD) as req: req.return_value = {} conn.update_item( @@ -677,7 +658,7 @@ def test_update_item(self): 'S': 'foo-subject' } }, - 'TableName': 'ci-table' + 'TableName': 'Thread' } self.assertEqual(req.call_args[0][1], params) @@ -724,7 +705,7 @@ def test_update_item(self): } }, 'ReturnConsumedCapacity': 'TOTAL', - 'TableName': 'ci-table' + 'TableName': 'Thread' } self.assertEqual(req.call_args[0][1], params) @@ -744,9 +725,7 @@ def test_put_item(self): Connection.put_item """ conn = Connection(self.region) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(self.test_table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.side_effect = BotoCoreError @@ -943,9 +922,7 @@ def test_batch_write_item(self): conn.batch_write_item, table_name) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = {} @@ -1072,9 +1049,7 @@ def test_batch_get_item(self): items.append( {"ForumName": "FooForum", "Subject": "thread-{}".format(i)} ) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.side_effect = BotoCoreError @@ -1156,9 +1131,7 @@ def test_query(self): """ conn = Connection() table_name = 'Thread' - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with pytest.raises(ValueError, match="Table Thread has no index: NonExistentIndexName"): conn.query(table_name, "FooForum", limit=1, index_name='NonExistentIndexName') @@ -1291,9 +1264,7 @@ def test_scan(self): conn = Connection() table_name = 'Thread' - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table(table_name) + conn.add_meta_table(MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = {} diff --git a/tests/test_model.py b/tests/test_model.py index 357fb157c..5d9598c22 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -30,7 +30,6 @@ BooleanAttribute, ListAttribute, TTLAttribute, VersionAttribute) from .data import ( MODEL_TABLE_DATA, GET_MODEL_ITEM_DATA, SIMPLE_MODEL_TABLE_DATA, - DESCRIBE_TABLE_DATA_PAY_PER_REQUEST, BATCH_GET_ITEMS, SIMPLE_BATCH_GET_ITEMS, COMPLEX_TABLE_DATA, COMPLEX_ITEM_DATA, INDEX_TABLE_DATA, LOCAL_INDEX_TABLE_DATA, DOG_TABLE_DATA, CUSTOM_ATTR_NAME_INDEX_TABLE_DATA, CUSTOM_ATTR_NAME_ITEM_DATA, @@ -556,24 +555,11 @@ def fake_dynamodb(*args): self.assertEqual(UserModel.Meta.read_capacity_units, 25) self.assertEqual(UserModel.Meta.write_capacity_units, 25) - # Test for wrong billing_mode - setattr(UserModel.Meta, 'billing_mode', 'WRONG') - with patch(PATCH_METHOD) as req: - req.return_value = MODEL_TABLE_DATA - self.assertRaises(ValueError) - delattr(UserModel.Meta, 'billing_mode') - # A table with billing_mode set as on_demand self.assertEqual(BillingModeOnDemandModel.Meta.billing_mode, 'PAY_PER_REQUEST') with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA_PAY_PER_REQUEST + req.return_value = MODEL_TABLE_DATA BillingModeOnDemandModel.create_table(read_capacity_units=2, write_capacity_units=2) - self.assertEqual(BillingModeOnDemandModel._connection.get_meta_table().data - .get('BillingModeSummary', {}).get('BillingMode', None), 'PAY_PER_REQUEST') - self.assertEqual(BillingModeOnDemandModel._connection.get_meta_table().data - .get('ProvisionedThroughput', {}).get('ReadCapacityUnits', None), 0) - self.assertEqual(BillingModeOnDemandModel._connection.get_meta_table().data - .get('ProvisionedThroughput', {}).get('WriteCapacityUnits', None), 0) UserModel._connection = None @@ -683,12 +669,12 @@ def test_overidden_defaults(self): schema = CustomAttrNameModel._get_schema() correct_schema = { 'KeySchema': [ - {'key_type': 'HASH', 'attribute_name': 'user_name'}, - {'key_type': 'RANGE', 'attribute_name': 'user_id'} + {'KeyType': 'HASH', 'AttributeName': 'user_name'}, + {'KeyType': 'RANGE', 'AttributeName': 'user_id'} ], 'AttributeDefinitions': [ - {'attribute_type': 'S', 'attribute_name': 'user_name'}, - {'attribute_type': 'S', 'attribute_name': 'user_id'} + {'AttributeType': 'S', 'AttributeName': 'user_name'}, + {'AttributeType': 'S', 'AttributeName': 'user_id'} ] } self.assert_dict_lists_equal(correct_schema['KeySchema'], schema['key_schema']) @@ -2077,21 +2063,16 @@ def test_batch_write_with_unprocessed(self): { UNPROCESSED_ITEMS: { UserModel.Meta.table_name: unprocessed_items[:2], - } - }, - { - UNPROCESSED_ITEMS: { - UserModel.Meta.table_name: unprocessed_items[2:], - } + }, }, - {} + {}, ] with UserModel.batch_write() as batch: for item in items: batch.save(item) - self.assertEqual(len(req.mock_calls), 3) + self.assertEqual(len(req.mock_calls), 2) def test_batch_write_raises_put_error(self): items = [] @@ -2111,19 +2092,11 @@ def test_batch_write_raises_put_error(self): }) with patch(PATCH_METHOD) as req: - req.side_effect = [ - { - UNPROCESSED_ITEMS: { - BatchModel.Meta.table_name: unprocessed_items[:2], - } - }, - { - UNPROCESSED_ITEMS: { - BatchModel.Meta.table_name: unprocessed_items[2:], - } - }, - {} - ] + req.return_value = { + UNPROCESSED_ITEMS: { + BatchModel.Meta.table_name: unprocessed_items[2:], + } + } with self.assertRaises(PutError): with BatchModel.batch_write() as batch: for item in items: @@ -2428,9 +2401,9 @@ def test_local_index(self): } ], 'attribute_definitions': [ - {'attribute_type': 'S', 'attribute_name': 'user_name'}, - {'attribute_type': 'S', 'attribute_name': 'email'}, - {'attribute_type': 'NS', 'attribute_name': 'numbers'} + {'AttributeType': 'S', 'AttributeName': 'user_name'}, + {'AttributeType': 'S', 'AttributeName': 'email'}, + {'AttributeType': 'NS', 'AttributeName': 'numbers'} ] } self.assert_dict_lists_equal( diff --git a/tests/test_table_connection.py b/tests/test_table_connection.py index b5cb6694f..87e955fd6 100644 --- a/tests/test_table_connection.py +++ b/tests/test_table_connection.py @@ -5,6 +5,8 @@ from pynamodb.connection import TableConnection from pynamodb.constants import DEFAULT_REGION +from pynamodb.connection.base import MetaTable +from pynamodb.constants import TABLE_KEY from pynamodb.expressions.operand import Path from .data import DESCRIBE_TABLE_DATA, GET_ITEM_DATA from .response import HttpOK @@ -20,19 +22,20 @@ class ConnectionTestCase(TestCase): """ def setUp(self): - self.test_table_name = 'ci-table' + self.test_table_name = 'Thread' self.region = DEFAULT_REGION def test_create_connection(self): """ TableConnection() """ - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) self.assertIsNotNone(conn) def test_connection_session_set_credentials(self): conn = TableConnection( self.test_table_name, + meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY]), aws_access_key_id='access_key_id', aws_secret_access_key='secret_access_key') @@ -44,6 +47,7 @@ def test_connection_session_set_credentials(self): def test_connection_session_set_credentials_with_session_token(self): conn = TableConnection( self.test_table_name, + meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY]), aws_access_key_id='access_key_id', aws_secret_access_key='secret_access_key', aws_session_token='session_token') @@ -58,7 +62,7 @@ def test_create_table(self): """ TableConnection.create_table """ - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) kwargs = { 'read_capacity_units': 1, 'write_capacity_units': 1, @@ -86,7 +90,7 @@ def test_create_table(self): } ] params = { - 'TableName': 'ci-table', + 'TableName': 'Thread', 'ProvisionedThroughput': { 'WriteCapacityUnits': 1, 'ReadCapacityUnits': 1 @@ -121,7 +125,7 @@ def test_create_table(self): self.assertEqual(kwargs, params) def test_create_table_with_tags(self): - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) kwargs = { 'read_capacity_units': 1, 'write_capacity_units': 1, @@ -151,7 +155,7 @@ def test_create_table_with_tags(self): } } params = { - 'TableName': 'ci-table', + 'TableName': 'Thread', 'ProvisionedThroughput': { 'WriteCapacityUnits': 1, 'ReadCapacityUnits': 1 @@ -200,7 +204,7 @@ def test_update_time_to_live(self): TableConnection.update_time_to_live """ params = { - 'TableName': 'ci-table', + 'TableName': 'Thread', 'TimeToLiveSpecification': { 'AttributeName': 'ttl_attr', 'Enabled': True, @@ -208,7 +212,7 @@ def test_update_time_to_live(self): } with patch(PATCH_METHOD) as req: req.return_value = HttpOK(), None - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) conn.update_time_to_live('ttl_attr') kwargs = req.call_args[0][1] self.assertEqual(kwargs, params) @@ -217,10 +221,10 @@ def test_delete_table(self): """ TableConnection.delete_table """ - params = {'TableName': 'ci-table'} + params = {'TableName': 'Thread'} with patch(PATCH_METHOD) as req: req.return_value = HttpOK(), None - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) conn.delete_table() kwargs = req.call_args[0][1] self.assertEqual(kwargs, params) @@ -231,7 +235,7 @@ def test_update_table(self): """ with patch(PATCH_METHOD) as req: req.return_value = HttpOK(), None - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) params = { 'ProvisionedThroughput': { 'WriteCapacityUnits': 2, @@ -247,7 +251,7 @@ def test_update_table(self): with patch(PATCH_METHOD) as req: req.return_value = HttpOK(), None - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) global_secondary_index_updates = [ { @@ -288,19 +292,16 @@ def test_describe_table(self): """ with patch(PATCH_METHOD) as req: req.return_value = DESCRIBE_TABLE_DATA - conn = TableConnection(self.test_table_name) - conn.describe_table() - self.assertEqual(conn.table_name, self.test_table_name) - self.assertEqual(req.call_args[0][1], {'TableName': 'ci-table'}) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) + data = conn.describe_table() + self.assertEqual(data, DESCRIBE_TABLE_DATA[TABLE_KEY]) + self.assertEqual(req.call_args[0][1], {'TableName': 'Thread'}) def test_delete_item(self): """ TableConnection.delete_item """ - conn = TableConnection(self.test_table_name) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = {} @@ -325,17 +326,7 @@ def test_update_item(self): """ TableConnection.update_item """ - conn = TableConnection(self.test_table_name) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() - - attr_updates = { - 'Subject': { - 'Value': 'foo-subject', - 'Action': 'PUT' - }, - } + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = HttpOK(), {} @@ -363,7 +354,7 @@ def test_update_item(self): } }, 'ReturnConsumedCapacity': 'TOTAL', - 'TableName': 'ci-table' + 'TableName': 'Thread' } self.assertEqual(req.call_args[0][1], params) @@ -371,10 +362,7 @@ def test_get_item(self): """ TableConnection.get_item """ - conn = TableConnection(self.test_table_name) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = GET_ITEM_DATA @@ -385,10 +373,7 @@ def test_put_item(self): """ TableConnection.put_item """ - conn = TableConnection(self.test_table_name) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = {} @@ -456,14 +441,11 @@ def test_batch_write_item(self): TableConnection.batch_write_item """ items = [] - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) for i in range(10): items.append( {"ForumName": "FooForum", "Subject": "thread-{}".format(i)} ) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() with patch(PATCH_METHOD) as req: req.return_value = {} conn.batch_write_item( @@ -493,14 +475,11 @@ def test_batch_get_item(self): TableConnection.batch_get_item """ items = [] - conn = TableConnection(self.test_table_name) + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) for i in range(10): items.append( {"ForumName": "FooForum", "Subject": "thread-{}".format(i)} ) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() with patch(PATCH_METHOD) as req: req.return_value = {} @@ -532,10 +511,7 @@ def test_query(self): """ TableConnection.query """ - conn = TableConnection(self.test_table_name) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = {} @@ -566,10 +542,7 @@ def test_scan(self): """ TableConnection.scan """ - conn = TableConnection(self.test_table_name) - with patch(PATCH_METHOD) as req: - req.return_value = DESCRIBE_TABLE_DATA - conn.describe_table() + conn = TableConnection(self.test_table_name, meta_table=MetaTable(DESCRIBE_TABLE_DATA[TABLE_KEY])) with patch(PATCH_METHOD) as req: req.return_value = HttpOK(), {} conn.scan() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index 608bc90c8..71e6cd8f3 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -2,6 +2,8 @@ from pynamodb.attributes import NumberAttribute, UnicodeAttribute, VersionAttribute from pynamodb.connection import Connection +from pynamodb.connection.base import MetaTable +from pynamodb.constants import TABLE_KEY from pynamodb.transactions import Transaction, TransactGet, TransactWrite from pynamodb.models import Model from tests.test_base_connection import PATCH_METHOD @@ -21,24 +23,24 @@ class Meta: MOCK_TABLE_DESCRIPTOR = { "Table": { - "TableName": "Mock", + "TableName": "mock", "KeySchema": [ { - "AttributeName": "MockHash", + "AttributeName": "mock_hash", "KeyType": "HASH" }, { - "AttributeName": "MockRange", + "AttributeName": "mock_range", "KeyType": "RANGE" } ], "AttributeDefinitions": [ { - "AttributeName": "MockHash", + "AttributeName": "mock_hash", "AttributeType": "N" }, { - "AttributeName": "MockRange", + "AttributeName": "mock_range", "AttributeType": "N" } ] @@ -58,15 +60,15 @@ class TestTransactGet: def test_commit(self, mocker): connection = Connection() + connection.add_meta_table(MetaTable(MOCK_TABLE_DESCRIPTOR[TABLE_KEY])) + mock_connection_transact_get = mocker.patch.object(connection, 'transact_get_items') - with patch(PATCH_METHOD) as req: - req.return_value = MOCK_TABLE_DESCRIPTOR - with TransactGet(connection=connection) as t: - t.get(MockModel, 1, 2) + with TransactGet(connection=connection) as t: + t.get(MockModel, 1, 2) mock_connection_transact_get.assert_called_once_with( - get_items=[{'Key': {'MockHash': {'N': '1'}, 'MockRange': {'N': '2'}}, 'TableName': 'mock'}], + get_items=[{'Key': {'mock_hash': {'N': '1'}, 'mock_range': {'N': '2'}}, 'TableName': 'mock'}], return_consumed_capacity=None ) @@ -92,25 +94,25 @@ def test_commit(self, mocker): expected_condition_checks = [{ 'ConditionExpression': 'attribute_not_exists (#0)', 'ExpressionAttributeNames': {'#0': 'mock_hash'}, - 'Key': {'MockHash': {'N': '1'}, 'MockRange': {'N': '3'}}, + 'Key': {'mock_hash': {'N': '1'}, 'mock_range': {'N': '3'}}, 'TableName': 'mock'} ] expected_deletes = [{ 'ConditionExpression': 'attribute_not_exists (#0)', 'ExpressionAttributeNames': {'#0': 'mock_version'}, - 'Key': {'MockHash': {'N': '2'}, 'MockRange': {'N': '4'}}, + 'Key': {'mock_hash': {'N': '2'}, 'mock_range': {'N': '4'}}, 'TableName': 'mock' }] expected_puts = [{ 'ConditionExpression': 'attribute_not_exists (#0)', 'ExpressionAttributeNames': {'#0': 'mock_version'}, - 'Item': {'MockHash': {'N': '3'}, 'MockRange': {'N': '5'}, 'mock_version': {'N': '1'}}, + 'Item': {'mock_hash': {'N': '3'}, 'mock_range': {'N': '5'}, 'mock_version': {'N': '1'}}, 'TableName': 'mock' }] expected_updates = [{ 'ConditionExpression': 'attribute_not_exists (#0)', 'TableName': 'mock', - 'Key': {'MockHash': {'N': '4'}, 'MockRange': {'N': '6'}}, + 'Key': {'mock_hash': {'N': '4'}, 'mock_range': {'N': '6'}}, 'ReturnValuesOnConditionCheckFailure': 'ALL_OLD', 'UpdateExpression': 'SET #1 = :0, #0 = :1', 'ExpressionAttributeNames': {'#0': 'mock_version', '#1': 'mock_toot'},