diff --git a/pynamodb/connection/base.py b/pynamodb/connection/base.py index d0181a81..0b70cc1f 100644 --- a/pynamodb/connection/base.py +++ b/pynamodb/connection/base.py @@ -331,10 +331,9 @@ def send_pre_boto_callback(self, operation_name, req_uuid, table_name): except Exception: log.exception("pre_boto callback threw an exception.") - def _before_sign(self, request, **_) -> None: + def _before_send(self, request, **_) -> None: if self._extra_headers is not None: - for k, v in self._extra_headers.items(): - request.headers.add_header(k, v) + request.headers.update(self._extra_headers) def _make_api_call(self, operation_name: str, operation_kwargs: Dict) -> Dict: try: @@ -412,7 +411,7 @@ def client(self) -> BotocoreBaseClientPrivate: ) self._client = cast(BotocoreBaseClientPrivate, self.session.create_client(SERVICE_NAME, self.region, endpoint_url=self.host, config=config)) - self._client.meta.events.register_first('before-sign.*.*', self._before_sign) + self._client.meta.events.register_first('before-send.*.*', self._before_send) return self._client def add_meta_table(self, meta_table: MetaTable) -> None: diff --git a/tests/test_base_connection.py b/tests/test_base_connection.py index 7e5f5223..0510374a 100644 --- a/tests/test_base_connection.py +++ b/tests/test_base_connection.py @@ -3,6 +3,8 @@ """ import base64 import json +from uuid import UUID + import urllib3 from unittest import mock from unittest.mock import patch @@ -1516,11 +1518,14 @@ def test_connection__botocore_config(): @mock.patch('botocore.httpsession.URLLib3Session.send') -def test_connection_make_api_call___extra_headers(send_mock): +def test_connection_make_api_call___extra_headers(send_mock, mocker): good_response = mock.Mock(spec=AWSResponse, status_code=200, headers={}, text='{}', content=b'{}') send_mock.return_value = good_response + # return constant UUID + mocker.patch('uuid.uuid4', return_value=UUID('01FC4BDB-B223-4B86-88F4-DEE79B77F275')) + c = Connection(extra_headers={'foo': 'bar'}, max_retry_attempts=0) c._make_api_call( 'DescribeTable', @@ -1529,7 +1534,18 @@ def test_connection_make_api_call___extra_headers(send_mock): assert send_mock.call_count == 1 request = send_mock.call_args[0][0] - assert request.headers.get('foo').decode() == 'bar' + assert request.headers['foo'] == 'bar' + + c = Connection(extra_headers={'foo': 'baz'}, max_retry_attempts=0) + c._make_api_call( + 'DescribeTable', + {'TableName': 'MyTable'}, + ) + + assert send_mock.call_count == 2 + request2 = send_mock.call_args[0][0] + # all headers, including signatures, and except 'foo', should match + assert {**request.headers, 'foo': ''} == {**request2.headers, 'foo': ''} @mock.patch('botocore.httpsession.URLLib3Session.send')