From d7d8ce80bfaa235ec6caa7e77d4b9b98d6c8d246 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 5 Sep 2024 11:20:03 -0400 Subject: [PATCH 01/29] PYTHON-3193 - Add ResourceWarning for unclosed MongoClients in __del__ --- pymongo/asynchronous/mongo_client.py | 19 +++++++++++++++++++ pymongo/synchronous/mongo_client.py | 19 +++++++++++++++++++ pyproject.toml | 3 +++ 3 files changed, 41 insertions(+) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 2af773c440..187f934729 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -34,6 +34,7 @@ import contextlib import os +import warnings import weakref from collections import defaultdict from typing import ( @@ -864,6 +865,7 @@ def __init__( ) self._opened = False + self._closed = False self._init_background() if _IS_SYNC and connect: @@ -1173,6 +1175,22 @@ def __getitem__(self, name: str) -> database.AsyncDatabase[_DocumentType]: """ return database.AsyncDatabase(self, name) + def __del__(self) -> None: + """Check that this AsyncMongoClient has been closed and issue a warning if not.""" + # TODO: Remove in https://jira.mongodb.org/browse/PYTHON-4731 + try: + if not self._closed: + warnings.warn( + f"Unclosed {self}", + ResourceWarning, + stacklevel=2, + source=self, + ) + if _IS_SYNC and self._opened: + self.close() + except AttributeError: + pass + def _close_cursor_soon( self, cursor_id: int, @@ -1540,6 +1558,7 @@ async def close(self) -> None: if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. await self._encrypter.close() + self._closed = True if not _IS_SYNC: # Add support for contextlib.aclosing. diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 6c5f68b7eb..0e0a64b238 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -34,6 +34,7 @@ import contextlib import os +import warnings import weakref from collections import defaultdict from typing import ( @@ -863,6 +864,7 @@ def __init__( ) self._opened = False + self._closed = False self._init_background() if _IS_SYNC and connect: @@ -1172,6 +1174,22 @@ def __getitem__(self, name: str) -> database.Database[_DocumentType]: """ return database.Database(self, name) + def __del__(self) -> None: + """Check that this MongoClient has been closed and issue a warning if not.""" + # TODO: Remove in https://jira.mongodb.org/browse/PYTHON-4731 + try: + if not self._closed: + warnings.warn( + f"Unclosed {self}", + ResourceWarning, + stacklevel=2, + source=self, + ) + if _IS_SYNC and self._opened: + self.close() + except AttributeError: + pass + def _close_cursor_soon( self, cursor_id: int, @@ -1535,6 +1553,7 @@ def close(self) -> None: if self._encrypter: # TODO: PYTHON-1921 Encrypted MongoClients cannot be re-opened. self._encrypter.close() + self._closed = True if not _IS_SYNC: # Add support for contextlib.closing. diff --git a/pyproject.toml b/pyproject.toml index 8452bfe956..9147c623f1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,6 +94,9 @@ filterwarnings = [ "module:please use dns.resolver.Resolver.resolve:DeprecationWarning", # https://github.com/dateutil/dateutil/issues/1314 "module:datetime.datetime.utc:DeprecationWarning:dateutil", + # TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731 + "ignore:Unclosed AsyncMongoClient*", + "ignore:Unclosed MongoClient*", ] markers = [ "auth_aws: tests that rely on pymongo-auth-aws", From f5fbeef16316c5d06321194df87e00d38779ea9b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 5 Sep 2024 11:34:23 -0400 Subject: [PATCH 02/29] Fix typecheck --- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/synchronous/mongo_client.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 187f934729..a542650dc4 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1187,7 +1187,7 @@ def __del__(self) -> None: source=self, ) if _IS_SYNC and self._opened: - self.close() + self.close() # type: ignore[unused-coroutine] except AttributeError: pass diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 0e0a64b238..3af7d48c55 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1186,7 +1186,7 @@ def __del__(self) -> None: source=self, ) if _IS_SYNC and self._opened: - self.close() + self.close() # type: ignore[unused-coroutine] except AttributeError: pass From 2bac24c66c9ddf7bbebe9a9fbf19a47b7b918454 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 5 Sep 2024 14:07:32 -0400 Subject: [PATCH 03/29] Address review --- pymongo/asynchronous/mongo_client.py | 14 +++++++++++--- pymongo/synchronous/mongo_client.py | 14 +++++++++++--- 2 files changed, 22 insertions(+), 6 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index a542650dc4..3574097516 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1180,14 +1180,22 @@ def __del__(self) -> None: # TODO: Remove in https://jira.mongodb.org/browse/PYTHON-4731 try: if not self._closed: + if _IS_SYNC: + msg = ( + f"Unclosed {type(self)}. " + f"Call {type(self)}.close() to safely shut down your client and free up resources." + ) + else: + msg = ( + f"Unclosed {type(self)}. " + f"Call await {type(self)}.close() to safely shut down your client and free up resources." + ) warnings.warn( - f"Unclosed {self}", + msg, ResourceWarning, stacklevel=2, source=self, ) - if _IS_SYNC and self._opened: - self.close() # type: ignore[unused-coroutine] except AttributeError: pass diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 3af7d48c55..59bb94e11a 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1179,14 +1179,22 @@ def __del__(self) -> None: # TODO: Remove in https://jira.mongodb.org/browse/PYTHON-4731 try: if not self._closed: + if _IS_SYNC: + msg = ( + f"Unclosed {type(self)}. " + f"Call {type(self)}.close() to safely shut down your client and free up resources." + ) + else: + msg = ( + f"Unclosed {type(self)}. " + f"Call {type(self)}.close() to safely shut down your client and free up resources." + ) warnings.warn( - f"Unclosed {self}", + msg, ResourceWarning, stacklevel=2, source=self, ) - if _IS_SYNC and self._opened: - self.close() # type: ignore[unused-coroutine] except AttributeError: pass From c9f99a07f0c1d001431de3fd4714edd3c7d5b7e3 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 6 Sep 2024 10:11:57 -0400 Subject: [PATCH 04/29] Add traceback --- pymongo/asynchronous/mongo_client.py | 16 ++++------------ pymongo/synchronous/mongo_client.py | 16 ++++------------ 2 files changed, 8 insertions(+), 24 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 3574097516..bbfb39ebb0 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1177,21 +1177,13 @@ def __getitem__(self, name: str) -> database.AsyncDatabase[_DocumentType]: def __del__(self) -> None: """Check that this AsyncMongoClient has been closed and issue a warning if not.""" - # TODO: Remove in https://jira.mongodb.org/browse/PYTHON-4731 try: if not self._closed: - if _IS_SYNC: - msg = ( - f"Unclosed {type(self)}. " - f"Call {type(self)}.close() to safely shut down your client and free up resources." - ) - else: - msg = ( - f"Unclosed {type(self)}. " - f"Call await {type(self)}.close() to safely shut down your client and free up resources." - ) warnings.warn( - msg, + ( + f"Unclosed {type(self).__name__} opened at:\n{self._topology_settings._stack}" + f"Call {type(self).__name__}.close() to safely shut down your client and free up resources." + ), ResourceWarning, stacklevel=2, source=self, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 59bb94e11a..1da818fb8c 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1176,21 +1176,13 @@ def __getitem__(self, name: str) -> database.Database[_DocumentType]: def __del__(self) -> None: """Check that this MongoClient has been closed and issue a warning if not.""" - # TODO: Remove in https://jira.mongodb.org/browse/PYTHON-4731 try: if not self._closed: - if _IS_SYNC: - msg = ( - f"Unclosed {type(self)}. " - f"Call {type(self)}.close() to safely shut down your client and free up resources." - ) - else: - msg = ( - f"Unclosed {type(self)}. " - f"Call {type(self)}.close() to safely shut down your client and free up resources." - ) warnings.warn( - msg, + ( + f"Unclosed {type(self).__name__} opened at:\n{self._topology_settings._stack}" + f"Call {type(self).__name__}.close() to safely shut down your client and free up resources." + ), ResourceWarning, stacklevel=2, source=self, From 2fcf40b834a995fc7a0159de2f64d99c9058dd70 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 6 Sep 2024 12:04:46 -0400 Subject: [PATCH 05/29] WIP - done with test_client --- pyproject.toml | 3 - test/asynchronous/test_client.py | 573 ++++++++++++++++++------------- test/test_client.py | 516 ++++++++++++++++------------ 3 files changed, 628 insertions(+), 464 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9147c623f1..8452bfe956 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,9 +94,6 @@ filterwarnings = [ "module:please use dns.resolver.Resolver.resolve:DeprecationWarning", # https://github.com/dateutil/dateutil/issues/1314 "module:datetime.datetime.utc:DeprecationWarning:dateutil", - # TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731 - "ignore:Unclosed AsyncMongoClient*", - "ignore:Unclosed MongoClient*", ] markers = [ "auth_aws: tests that rely on pymongo-auth-aws", diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index d4f09cde33..9dbf9e3cb7 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -36,6 +36,7 @@ from unittest.mock import patch import pytest +import pytest_asyncio from pymongo.operations import _Op @@ -143,8 +144,8 @@ async def _tearDown_class(cls): def inject_fixtures(self, caplog): self._caplog = caplog - def test_keyword_arg_defaults(self): - client = AsyncMongoClient( + async def test_keyword_arg_defaults(self): + async with AsyncMongoClient( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -156,34 +157,36 @@ def test_keyword_arg_defaults(self): tlsCAFile=None, connect=False, serverSelectionTimeoutMS=12000, - ) - - options = client.options - pool_opts = options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - # socket.Socket.settimeout takes a float in seconds - self.assertEqual(20.0, pool_opts.connect_timeout) - self.assertEqual(None, pool_opts.wait_queue_timeout) - self.assertEqual(None, pool_opts._ssl_context) - self.assertEqual(None, options.replica_set_name) - self.assertEqual(ReadPreference.PRIMARY, client.read_preference) - self.assertAlmostEqual(12, client.options.server_selection_timeout) - - def test_connect_timeout(self): + ) as client: + options = client.options + pool_opts = options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + # socket.Socket.settimeout takes a float in seconds + self.assertEqual(20.0, pool_opts.connect_timeout) + self.assertEqual(None, pool_opts.wait_queue_timeout) + self.assertEqual(None, pool_opts._ssl_context) + self.assertEqual(None, options.replica_set_name) + self.assertEqual(ReadPreference.PRIMARY, client.read_preference) + self.assertAlmostEqual(12, client.options.server_selection_timeout) + + async def test_connect_timeout(self): client = AsyncMongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) + await client.close() client = AsyncMongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) + await client.close() client = AsyncMongoClient( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) + await client.close() def test_types(self): self.assertRaises(TypeError, AsyncMongoClient, 1) @@ -194,8 +197,9 @@ def test_types(self): self.assertRaises(ConfigurationError, AsyncMongoClient, []) - def test_max_pool_size_zero(self): - AsyncMongoClient(maxPoolSize=0) + async def test_max_pool_size_zero(self): + async with AsyncMongoClient(maxPoolSize=0): + pass def test_uri_detection(self): self.assertRaises(ConfigurationError, AsyncMongoClient, "/foo") @@ -260,36 +264,38 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) async def test_get_default_database(self): - c = await async_rs_or_single_client( + async with await async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, - ) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) - # Test that default doesn't override the URI value. - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("bar")) - - codec_options = CodecOptions(tz_aware=True) - write_concern = WriteConcern(w=2, j=True) - db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) + ) as c: + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + # Test that default doesn't override the URI value. + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("bar")) + + codec_options = CodecOptions(tz_aware=True) + write_concern = WriteConcern(w=2, j=True) + db = c.get_default_database( + None, codec_options, ReadPreference.SECONDARY, write_concern + ) + self.assertEqual("foo", db.name) + self.assertEqual(codec_options, db.codec_options) + self.assertEqual(ReadPreference.SECONDARY, db.read_preference) + self.assertEqual(write_concern, db.write_concern) - c = await async_rs_or_single_client( + async with await async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, - ) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) + ) as c: + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) async def test_get_default_database_error(self): # URI with no database. - c = await async_rs_or_single_client( + async with await async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, - ) - self.assertRaises(ConfigurationError, c.get_default_database) + ) as c: + self.assertRaises(ConfigurationError, c.get_default_database) async def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -297,16 +303,16 @@ async def test_get_default_database_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await async_rs_or_single_client(uri, connect=False) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + async with await async_rs_or_single_client(uri, connect=False) as c: + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) async def test_get_database_default(self): - c = await async_rs_or_single_client( + async with await async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, - ) - self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) + ) as c: + self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) async def test_get_database_default_error(self): # URI with no database. @@ -315,6 +321,7 @@ async def test_get_database_default_error(self): connect=False, ) self.assertRaises(ConfigurationError, c.get_database) + await c.close() async def test_get_database_default_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -324,6 +331,7 @@ async def test_get_database_default_with_authsource(self): ) c = await async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) + await c.close() def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". @@ -334,75 +342,84 @@ def test_primary_read_pref_with_tags(self): AsyncMongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") async def test_read_preference(self): - c = await async_rs_or_single_client( + async with await async_rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode - ) - self.assertEqual(c.read_preference, ReadPreference.NEAREST) + ) as c: + self.assertEqual(c.read_preference, ReadPreference.NEAREST) - def test_metadata(self): + async def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} - client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) - client = AsyncMongoClient("foo", 27017, appname="foobar", connect=False) - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + async with AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + async with AsyncMongoClient("foo", 27017, appname="foobar", connect=False) as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # No error - AsyncMongoClient(appname="x" * 128) - self.assertRaises(ValueError, AsyncMongoClient, appname="x" * 129) + async with AsyncMongoClient(appname="x" * 128): + pass + with self.assertRaises(ValueError): + async with AsyncMongoClient(appname="x" * 129): + pass # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) - self.assertRaises(TypeError, AsyncMongoClient, driver=1) - self.assertRaises(TypeError, AsyncMongoClient, driver="abc") - self.assertRaises(TypeError, AsyncMongoClient, driver=("Foo", "1", "a")) + with self.assertRaises(TypeError): + async with AsyncMongoClient(driver=1): + pass + with self.assertRaises(TypeError): + async with AsyncMongoClient(driver="abc"): + pass + with self.assertRaises(TypeError): + async with AsyncMongoClient(driver=("Foo", "1", "a")): + pass # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = AsyncMongoClient( + async with AsyncMongoClient( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", None), connect=False, - ) - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = AsyncMongoClient( + async with AsyncMongoClient( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", "FooPlatform"), connect=False, - ) - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = AsyncMongoClient( + async with AsyncMongoClient( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, - ) - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) - client = AsyncMongoClient( + ) as client: + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + async with AsyncMongoClient( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, - ) - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) + ) as client: + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) - def test_container_metadata(self): + async def test_container_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["env"] = {} @@ -410,8 +427,9 @@ def test_container_metadata(self): client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) + await client.close() - def test_kwargs_codec_options(self): + async def test_kwargs_codec_options(self): class MyFloatType: def __init__(self, x): self.__x = x @@ -433,7 +451,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = AsyncMongoClient( + async with AsyncMongoClient( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -441,16 +459,18 @@ def transform_python(self, value): unicode_decode_error_handler=unicode_decode_error_handler, tzinfo=tzinfo, connect=False, - ) - - self.assertEqual(c.codec_options.document_class, document_class) - self.assertEqual(c.codec_options.type_registry, type_registry) - self.assertEqual(c.codec_options.tz_aware, tz_aware) - self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] - ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual(c.codec_options.tzinfo, tzinfo) + ) as c: + self.assertEqual(c.codec_options.document_class, document_class) + self.assertEqual(c.codec_options.type_registry, type_registry) + self.assertEqual(c.codec_options.tz_aware, tz_aware) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual( + c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler + ) + self.assertEqual(c.codec_options.tzinfo, tzinfo) async def test_uri_codec_options(self): # Ensure codec options are passed in correctly @@ -469,39 +489,40 @@ async def test_uri_codec_options(self): datetime_conversion, ) ) - c = AsyncMongoClient(uri, connect=False) - - self.assertEqual(c.codec_options.tz_aware, True) - self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] - ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + async with AsyncMongoClient(uri, connect=False) as c: + self.assertEqual(c.codec_options.tz_aware, True) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual( + c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler + ) + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = AsyncMongoClient(uri, connect=False) - - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + async with AsyncMongoClient(uri, connect=False) as c: + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) - def test_uri_option_precedence(self): + async def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = AsyncMongoClient( + async with AsyncMongoClient( uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" - ) - clopts = c.options - opts = clopts._options + ) as c: + clopts = c.options + opts = clopts._options - self.assertEqual(opts["tls"], False) - self.assertEqual(clopts.replica_set_name, "newname") - self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(opts["tls"], False) + self.assertEqual(clopts.replica_set_name, "newname") + self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) - def test_connection_timeout_ms_propagates_to_DNS_resolver(self): + async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. from pymongo.srv_resolver import _resolve @@ -520,68 +541,74 @@ def reset_resolver(): uri_with_timeout = base_uri + "/?connectTimeoutMS=6000" expected_uri_value = 6.0 - def test_scenario(args, kwargs, expected_value): + async def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - AsyncMongoClient(*args, **kwargs) - for _, kw in patched_resolver.call_list(): - self.assertAlmostEqual(kw["lifetime"], expected_value) + async with AsyncMongoClient(*args, **kwargs): + for _, kw in patched_resolver.call_list(): + self.assertAlmostEqual(kw["lifetime"], expected_value) # No timeout specified. - test_scenario((base_uri,), {}, CONNECT_TIMEOUT) + await test_scenario((base_uri,), {}, CONNECT_TIMEOUT) # Timeout only specified in connection string. - test_scenario((uri_with_timeout,), {}, expected_uri_value) + await test_scenario((uri_with_timeout,), {}, expected_uri_value) # Timeout only specified in keyword arguments. kwarg = {"connectTimeoutMS": connectTimeoutMS} - test_scenario((base_uri,), kwarg, expected_kw_value) + await test_scenario((base_uri,), kwarg, expected_kw_value) # Timeout specified in both kwargs and connection string. - test_scenario((uri_with_timeout,), kwarg, expected_kw_value) + await test_scenario((uri_with_timeout,), kwarg, expected_kw_value) - def test_uri_security_options(self): + async def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - AsyncMongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) + async with AsyncMongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False): + pass # Matching SSL and TLS options should not cause errors. - c = AsyncMongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) - self.assertEqual(c.options._options["tls"], False) + async with AsyncMongoClient( + "mongodb://localhost/?ssl=false", tls=False, connect=False + ) as c: + self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - AsyncMongoClient( + async with AsyncMongoClient( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, - ) + ): + pass # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): - AsyncMongoClient( + async with AsyncMongoClient( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, - ) + ): + pass # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - AsyncMongoClient(ssl=True, tls=False) - - def test_event_listeners(self): - c = AsyncMongoClient(event_listeners=[], connect=False) - self.assertEqual(c.options.event_listeners, []) - listeners = [ - event_loggers.CommandLogger(), - event_loggers.HeartbeatLogger(), - event_loggers.ServerLogger(), - event_loggers.TopologyLogger(), - event_loggers.ConnectionPoolLogger(), - ] - c = AsyncMongoClient(event_listeners=listeners, connect=False) - self.assertEqual(c.options.event_listeners, listeners) + async with AsyncMongoClient(ssl=True, tls=False): + pass - def test_client_options(self): + async def test_event_listeners(self): + async with AsyncMongoClient(event_listeners=[], connect=False) as c: + self.assertEqual(c.options.event_listeners, []) + listeners = [ + event_loggers.CommandLogger(), + event_loggers.HeartbeatLogger(), + event_loggers.ServerLogger(), + event_loggers.TopologyLogger(), + event_loggers.ConnectionPoolLogger(), + ] + async with AsyncMongoClient(event_listeners=listeners, connect=False) as c: + self.assertEqual(c.options.event_listeners, listeners) + + async def test_client_options(self): c = AsyncMongoClient(connect=False) self.assertIsInstance(c.options, ClientOptions) self.assertIsInstance(c.options.pool_options, PoolOptions) @@ -589,6 +616,7 @@ def test_client_options(self): self.assertEqual(c.options.pool_options.max_idle_time_seconds, None) self.assertIsInstance(c.options.retry_writes, bool) self.assertIsInstance(c.options.retry_reads, bool) + await c.close() def test_validate_suggestion(self): """Validate kwargs in constructor.""" @@ -599,7 +627,7 @@ def test_validate_suggestion(self): AsyncMongoClient(**{typo: "standard"}) # type: ignore[arg-type] @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - def test_detected_environment_logging(self, mock_get_hosts): + async def test_detected_environment_logging(self, mock_get_hosts): normal_hosts = [ "normal.host.com", "host.cosmos.azure.com", @@ -612,16 +640,19 @@ def test_detected_environment_logging(self, mock_get_hosts): ) with self.assertLogs("pymongo", level="INFO") as cm: for host in normal_hosts: - AsyncMongoClient(host) + async with AsyncMongoClient(host): + pass for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] - AsyncMongoClient(host) - AsyncMongoClient(multi_host) + async with AsyncMongoClient(host): + pass + async with AsyncMongoClient(multi_host): + pass logs = [record.message for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - def test_detected_environment_warning(self, mock_get_hosts): + async def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ "host.cosmos.azure.com", @@ -634,13 +665,16 @@ def test_detected_environment_warning(self, mock_get_hosts): ) for host in normal_hosts: with self.assertWarns(UserWarning): - AsyncMongoClient(host) + async with AsyncMongoClient(host): + pass for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] with self.assertWarns(UserWarning): - AsyncMongoClient(host) + async with AsyncMongoClient(host): + pass with self.assertWarns(UserWarning): - AsyncMongoClient(multi_host) + async with AsyncMongoClient(multi_host): + pass class TestClient(AsyncIntegrationTest): @@ -724,6 +758,7 @@ async def test_max_idle_time_reaper_removes_stale(self): async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): client = await async_rs_or_single_client() + self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -731,6 +766,7 @@ async def test_min_pool_size(self): # Assert that pool started up at minPoolSize client = await async_rs_or_single_client(minPoolSize=10) + self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -752,6 +788,7 @@ async def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = await async_rs_or_single_client(maxIdleTimeMS=500) + self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -768,6 +805,7 @@ async def test_max_idle_time_checkout(self): # Test that connections are reused if maxIdleTimeMS is not set. client = await async_rs_or_single_client() + self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -793,36 +831,45 @@ async def test_constants(self): AsyncMongoClient.HOST = "somedomainthatdoesntexist.org" AsyncMongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - await connected(AsyncMongoClient(serverSelectionTimeoutMS=10, **kwargs)) + async with AsyncMongoClient(serverSelectionTimeoutMS=10, **kwargs) as c: + await connected(c) - # Override the defaults. No error. - await connected(AsyncMongoClient(host, port, **kwargs)) + async with AsyncMongoClient(host, port, **kwargs) as c: + # Override the defaults. No error. + await connected(c) # Set good defaults. AsyncMongoClient.HOST = host AsyncMongoClient.PORT = port # No error. - await connected(AsyncMongoClient(**kwargs)) + async with AsyncMongoClient(**kwargs) as c: + await connected(c) async def test_init_disconnected(self): host, port = await async_client_context.host, await async_client_context.port c = await async_rs_or_single_client(connect=False) + self.addAsyncCleanup(c.close) # is_primary causes client to block until connected self.assertIsInstance(await c.is_primary, bool) c = await async_rs_or_single_client(connect=False) + self.addAsyncCleanup(c.close) self.assertIsInstance(await c.is_mongos, bool) c = await async_rs_or_single_client(connect=False) + self.addAsyncCleanup(c.close) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) c = await async_rs_or_single_client(connect=False) + self.addAsyncCleanup(c.close) self.assertEqual(c.codec_options, CodecOptions()) c = await async_rs_or_single_client(connect=False) + self.addAsyncCleanup(c.close) self.assertFalse(await c.primary) self.assertFalse(await c.secondaries) c = await async_rs_or_single_client(connect=False) + self.addAsyncCleanup(c.close) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(await c.address) # PYTHON-2981 @@ -835,36 +882,43 @@ async def test_init_disconnected(self): bad_host = "somedomainthatdoesntexist.org" c = AsyncMongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + self.addAsyncCleanup(c.close) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" c = AsyncMongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + self.addAsyncCleanup(c.close) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.close) - self.assertEqual(async_client_context.client, c) - # Explicitly test inequality - self.assertFalse(async_client_context.client != c) + async with await async_rs_or_single_client(seed, connect=False) as c: + self.assertEqual(async_client_context.client, c) + # Explicitly test inequality + self.assertFalse(async_client_context.client != c) + + async with await async_rs_or_single_client("invalid.com", connect=False) as c: + self.assertNotEqual(async_client_context.client, c) + self.assertTrue(async_client_context.client != c) + + c1 = AsyncMongoClient("a", connect=False) + c2 = AsyncMongoClient("b", connect=False) + self.addAsyncCleanup(c1.close) + self.addAsyncCleanup(c2.close) - c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.close) - self.assertNotEqual(async_client_context.client, c) - self.assertTrue(async_client_context.client != c) # Seeds differ: - self.assertNotEqual( - AsyncMongoClient("a", connect=False), AsyncMongoClient("b", connect=False) - ) + self.assertNotEqual(c1, c2) + + c1 = AsyncMongoClient(["a", "b", "c"], connect=False) + c2 = AsyncMongoClient(["c", "a", "b"], connect=False) + self.addAsyncCleanup(c1.close) + self.addAsyncCleanup(c2.close) + # Same seeds but out of order still compares equal: - self.assertEqual( - AsyncMongoClient(["a", "b", "c"], connect=False), - AsyncMongoClient(["c", "a", "b"], connect=False), - ) + self.assertEqual(c1, c2) async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) @@ -886,7 +940,7 @@ async def test_host_w_port(self): ) ) - def test_repr(self): + async def test_repr(self): # Used to test 'eval' below. import bson @@ -896,6 +950,7 @@ def test_repr(self): connect=False, document_class=SON, ) + self.addAsyncCleanup(client.close) the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) @@ -905,7 +960,8 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + async with eval(the_repr) as client_two: + self.assertEqual(client_two, client) client = AsyncMongoClient( "localhost:27017,localhost:27018", @@ -916,6 +972,7 @@ def test_repr(self): wtimeoutms=100, connect=False, ) + self.addAsyncCleanup(client.close) the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -925,7 +982,8 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + async with eval(the_repr) as client_two: + self.assertEqual(client_two, client) def test_getters(self): wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") @@ -1032,6 +1090,7 @@ async def test_close_kills_cursors(self): # The killCursors task should not need to re-open the topology. await test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) + await test_client.close() async def test_close_stops_kill_cursors_thread(self): client = await async_rs_client() @@ -1192,11 +1251,10 @@ async def test_unix_socket(self): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - await connected( - AsyncMongoClient( - "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 - ), - ) + async with AsyncMongoClient( + "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 + ) as c: + await connected(c) async def test_document_class(self): c = self.client @@ -1221,6 +1279,7 @@ async def test_timeouts(self): maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500, ) + self.addAsyncCleanup(client.close) self.assertEqual(10.5, (await async_get_pool(client)).opts.connect_timeout) self.assertEqual(10.5, (await async_get_pool(client)).opts.socket_timeout) self.assertEqual(10.5, (await async_get_pool(client)).opts.max_idle_time_seconds) @@ -1229,22 +1288,28 @@ async def test_timeouts(self): async def test_socket_timeout_ms_validation(self): c = await async_rs_or_single_client(socketTimeoutMS=10 * 1000) + self.addAsyncCleanup(c.close) self.assertEqual(10, (await async_get_pool(c)).opts.socket_timeout) c = await connected(await async_rs_or_single_client(socketTimeoutMS=None)) + self.addAsyncCleanup(c.close) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) c = await connected(await async_rs_or_single_client(socketTimeoutMS=0)) + self.addAsyncCleanup(c.close) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS=-1) + async with await async_rs_or_single_client(socketTimeoutMS=-1): + pass with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS=1e10) + async with await async_rs_or_single_client(socketTimeoutMS=1e10): + pass with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS="foo") + async with await async_rs_or_single_client(socketTimeoutMS="foo"): + pass async def test_socket_timeout(self): no_timeout = self.client @@ -1266,11 +1331,13 @@ async def get_x(db): with self.assertRaises(NetworkTimeout): await get_x(timeout.pymongo_test) - def test_server_selection_timeout(self): + async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) + self.addAsyncCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) + self.addAsyncCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) @@ -1285,20 +1352,25 @@ def test_server_selection_timeout(self): client = AsyncMongoClient( "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False ) + self.addAsyncCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) + self.addAsyncCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) # Test invalid timeout in URI ignored and set to default. client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) + self.addAsyncCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) + self.addAsyncCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) async def test_waitQueueTimeoutMS(self): client = await async_rs_or_single_client(waitQueueTimeoutMS=2000) + self.addAsyncCleanup(client.close) self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) async def test_socketKeepAlive(self): @@ -1545,10 +1617,11 @@ async def test_auth_network_error(self): @async_client_context.require_no_replica_set async def test_connect_to_standalone_using_replica_set_name(self): - client = await async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - - with self.assertRaises(AutoReconnect): - await client.test.test.find_one() + async with await async_single_client( + replicaSet="anything", serverSelectionTimeoutMS=100 + ) as client: + with self.assertRaises(AutoReconnect): + await client.test.test.find_one() @async_client_context.require_replica_set async def test_stale_getmore(self): @@ -1630,84 +1703,84 @@ def compression_settings(client): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, 4) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) # According to the connection string spec, unsupported values # just raise a warning and are ignored. uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy"]) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy", "zlib"]) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd"]) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = AsyncMongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd", "zlib"]) + async with AsyncMongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = async_client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = await async_single_client(zlibcompressionlevel=level) - # No error - await client.pymongo_test.test.find_one() + async with await async_single_client(zlibcompressionlevel=level) as client: + # No error + await client.pymongo_test.test.find_one() async def test_reset_during_update_pool(self): client = await async_rs_or_single_client(minPoolSize=10) @@ -1873,12 +1946,13 @@ async def test_process_periodic_tasks(self): with self.assertRaises(InvalidOperation): await coll.insert_many([{} for _ in range(5)]) - def test_service_name_from_kwargs(self): + async def test_service_name_from_kwargs(self): client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc", srvServiceName="customname", connect=False, ) + self.addAsyncCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" @@ -1886,21 +1960,26 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + self.addAsyncCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + self.addAsyncCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") - def test_srv_max_hosts_kwarg(self): + async def test_srv_max_hosts_kwarg(self): client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") + self.addAsyncCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + self.addAsyncCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = AsyncMongoClient( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + self.addAsyncCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( @@ -2046,6 +2125,7 @@ async def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = await connected(await async_rs_or_single_client(maxPoolSize=1)) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) @@ -2069,6 +2149,7 @@ async def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = await async_rs_or_single_client(maxPoolSize=1) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.test await collection.drop() @@ -2108,6 +2189,7 @@ async def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = await connected(await async_rs_or_single_client(maxPoolSize=1, retryReads=False)) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2129,6 +2211,7 @@ async def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = await async_rs_or_single_client(maxPoolSize=1) + self.addAsyncCleanup(client.close) collection = client.pymongo_test.test await collection.drop() await collection.insert_many([{} for _ in range(200)]) # More than one batch. diff --git a/test/test_client.py b/test/test_client.py index 22e94dcddb..fcfd8926f2 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -142,7 +142,7 @@ def inject_fixtures(self, caplog): self._caplog = caplog def test_keyword_arg_defaults(self): - client = MongoClient( + with MongoClient( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -154,34 +154,36 @@ def test_keyword_arg_defaults(self): tlsCAFile=None, connect=False, serverSelectionTimeoutMS=12000, - ) - - options = client.options - pool_opts = options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - # socket.Socket.settimeout takes a float in seconds - self.assertEqual(20.0, pool_opts.connect_timeout) - self.assertEqual(None, pool_opts.wait_queue_timeout) - self.assertEqual(None, pool_opts._ssl_context) - self.assertEqual(None, options.replica_set_name) - self.assertEqual(ReadPreference.PRIMARY, client.read_preference) - self.assertAlmostEqual(12, client.options.server_selection_timeout) + ) as client: + options = client.options + pool_opts = options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + # socket.Socket.settimeout takes a float in seconds + self.assertEqual(20.0, pool_opts.connect_timeout) + self.assertEqual(None, pool_opts.wait_queue_timeout) + self.assertEqual(None, pool_opts._ssl_context) + self.assertEqual(None, options.replica_set_name) + self.assertEqual(ReadPreference.PRIMARY, client.read_preference) + self.assertAlmostEqual(12, client.options.server_selection_timeout) def test_connect_timeout(self): client = MongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) + client.close() client = MongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) + client.close() client = MongoClient( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) + client.close() def test_types(self): self.assertRaises(TypeError, MongoClient, 1) @@ -193,7 +195,8 @@ def test_types(self): self.assertRaises(ConfigurationError, MongoClient, []) def test_max_pool_size_zero(self): - MongoClient(maxPoolSize=0) + with MongoClient(maxPoolSize=0): + pass def test_uri_detection(self): self.assertRaises(ConfigurationError, MongoClient, "/foo") @@ -258,35 +261,37 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) def test_get_default_database(self): - c = rs_or_single_client( + with rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, - ) - self.assertEqual(Database(c, "foo"), c.get_default_database()) - # Test that default doesn't override the URI value. - self.assertEqual(Database(c, "foo"), c.get_default_database("bar")) - - codec_options = CodecOptions(tz_aware=True) - write_concern = WriteConcern(w=2, j=True) - db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) + ) as c: + self.assertEqual(Database(c, "foo"), c.get_default_database()) + # Test that default doesn't override the URI value. + self.assertEqual(Database(c, "foo"), c.get_default_database("bar")) + + codec_options = CodecOptions(tz_aware=True) + write_concern = WriteConcern(w=2, j=True) + db = c.get_default_database( + None, codec_options, ReadPreference.SECONDARY, write_concern + ) + self.assertEqual("foo", db.name) + self.assertEqual(codec_options, db.codec_options) + self.assertEqual(ReadPreference.SECONDARY, db.read_preference) + self.assertEqual(write_concern, db.write_concern) - c = rs_or_single_client( + with rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, - ) - self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) + ) as c: + self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) def test_get_default_database_error(self): # URI with no database. - c = rs_or_single_client( + with rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, - ) - self.assertRaises(ConfigurationError, c.get_default_database) + ) as c: + self.assertRaises(ConfigurationError, c.get_default_database) def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -294,15 +299,15 @@ def test_get_default_database_with_authsource(self): client_context.host, client_context.port, ) - c = rs_or_single_client(uri, connect=False) - self.assertEqual(Database(c, "foo"), c.get_default_database()) + with rs_or_single_client(uri, connect=False) as c: + self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): - c = rs_or_single_client( + with rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, - ) - self.assertEqual(Database(c, "foo"), c.get_database()) + ) as c: + self.assertEqual(Database(c, "foo"), c.get_database()) def test_get_database_default_error(self): # URI with no database. @@ -311,6 +316,7 @@ def test_get_database_default_error(self): connect=False, ) self.assertRaises(ConfigurationError, c.get_database) + c.close() def test_get_database_default_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -320,6 +326,7 @@ def test_get_database_default_with_authsource(self): ) c = rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_database()) + c.close() def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". @@ -330,72 +337,81 @@ def test_primary_read_pref_with_tags(self): MongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") def test_read_preference(self): - c = rs_or_single_client( + with rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode - ) - self.assertEqual(c.read_preference, ReadPreference.NEAREST) + ) as c: + self.assertEqual(c.read_preference, ReadPreference.NEAREST) def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} - client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) - client = MongoClient("foo", 27017, appname="foobar", connect=False) - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + with MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + with MongoClient("foo", 27017, appname="foobar", connect=False) as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # No error - MongoClient(appname="x" * 128) - self.assertRaises(ValueError, MongoClient, appname="x" * 129) + with MongoClient(appname="x" * 128): + pass + with self.assertRaises(ValueError): + with MongoClient(appname="x" * 129): + pass # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) - self.assertRaises(TypeError, MongoClient, driver=1) - self.assertRaises(TypeError, MongoClient, driver="abc") - self.assertRaises(TypeError, MongoClient, driver=("Foo", "1", "a")) + with self.assertRaises(TypeError): + with MongoClient(driver=1): + pass + with self.assertRaises(TypeError): + with MongoClient(driver="abc"): + pass + with self.assertRaises(TypeError): + with MongoClient(driver=("Foo", "1", "a")): + pass # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = MongoClient( + with MongoClient( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", None), connect=False, - ) - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = MongoClient( + with MongoClient( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", "FooPlatform"), connect=False, - ) - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) as client: + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = MongoClient( + with MongoClient( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, - ) - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) - client = MongoClient( + ) as client: + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + with MongoClient( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, - ) - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) + ) as client: + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) def test_container_metadata(self): @@ -406,6 +422,7 @@ def test_container_metadata(self): client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) + client.close() def test_kwargs_codec_options(self): class MyFloatType: @@ -429,7 +446,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = MongoClient( + with MongoClient( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -437,16 +454,18 @@ def transform_python(self, value): unicode_decode_error_handler=unicode_decode_error_handler, tzinfo=tzinfo, connect=False, - ) - - self.assertEqual(c.codec_options.document_class, document_class) - self.assertEqual(c.codec_options.type_registry, type_registry) - self.assertEqual(c.codec_options.tz_aware, tz_aware) - self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] - ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual(c.codec_options.tzinfo, tzinfo) + ) as c: + self.assertEqual(c.codec_options.document_class, document_class) + self.assertEqual(c.codec_options.type_registry, type_registry) + self.assertEqual(c.codec_options.tz_aware, tz_aware) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual( + c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler + ) + self.assertEqual(c.codec_options.tzinfo, tzinfo) def test_uri_codec_options(self): # Ensure codec options are passed in correctly @@ -465,35 +484,38 @@ def test_uri_codec_options(self): datetime_conversion, ) ) - c = MongoClient(uri, connect=False) - - self.assertEqual(c.codec_options.tz_aware, True) - self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] - ) - self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + with MongoClient(uri, connect=False) as c: + self.assertEqual(c.codec_options.tz_aware, True) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual( + c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler + ) + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = MongoClient(uri, connect=False) - - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + with MongoClient(uri, connect=False) as c: + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = MongoClient(uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred") - clopts = c.options - opts = clopts._options + with MongoClient( + uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" + ) as c: + clopts = c.options + opts = clopts._options - self.assertEqual(opts["tls"], False) - self.assertEqual(clopts.replica_set_name, "newname") - self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(opts["tls"], False) + self.assertEqual(clopts.replica_set_name, "newname") + self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. @@ -516,9 +538,9 @@ def reset_resolver(): def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - MongoClient(*args, **kwargs) - for _, kw in patched_resolver.call_list(): - self.assertAlmostEqual(kw["lifetime"], expected_value) + with MongoClient(*args, **kwargs): + for _, kw in patched_resolver.call_list(): + self.assertAlmostEqual(kw["lifetime"], expected_value) # No timeout specified. test_scenario((base_uri,), {}, CONNECT_TIMEOUT) @@ -536,44 +558,48 @@ def test_scenario(args, kwargs, expected_value): def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - MongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) + with MongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False): + pass # Matching SSL and TLS options should not cause errors. - c = MongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) - self.assertEqual(c.options._options["tls"], False) + with MongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) as c: + self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - MongoClient( + with MongoClient( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, - ) + ): + pass # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): - MongoClient( + with MongoClient( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, - ) + ): + pass # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - MongoClient(ssl=True, tls=False) + with MongoClient(ssl=True, tls=False): + pass def test_event_listeners(self): - c = MongoClient(event_listeners=[], connect=False) - self.assertEqual(c.options.event_listeners, []) - listeners = [ - event_loggers.CommandLogger(), - event_loggers.HeartbeatLogger(), - event_loggers.ServerLogger(), - event_loggers.TopologyLogger(), - event_loggers.ConnectionPoolLogger(), - ] - c = MongoClient(event_listeners=listeners, connect=False) - self.assertEqual(c.options.event_listeners, listeners) + with MongoClient(event_listeners=[], connect=False) as c: + self.assertEqual(c.options.event_listeners, []) + listeners = [ + event_loggers.CommandLogger(), + event_loggers.HeartbeatLogger(), + event_loggers.ServerLogger(), + event_loggers.TopologyLogger(), + event_loggers.ConnectionPoolLogger(), + ] + with MongoClient(event_listeners=listeners, connect=False) as c: + self.assertEqual(c.options.event_listeners, listeners) def test_client_options(self): c = MongoClient(connect=False) @@ -583,6 +609,7 @@ def test_client_options(self): self.assertEqual(c.options.pool_options.max_idle_time_seconds, None) self.assertIsInstance(c.options.retry_writes, bool) self.assertIsInstance(c.options.retry_reads, bool) + c.close() def test_validate_suggestion(self): """Validate kwargs in constructor.""" @@ -606,11 +633,14 @@ def test_detected_environment_logging(self, mock_get_hosts): ) with self.assertLogs("pymongo", level="INFO") as cm: for host in normal_hosts: - MongoClient(host) + with MongoClient(host): + pass for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] - MongoClient(host) - MongoClient(multi_host) + with MongoClient(host): + pass + with MongoClient(multi_host): + pass logs = [record.message for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) @@ -628,13 +658,16 @@ def test_detected_environment_warning(self, mock_get_hosts): ) for host in normal_hosts: with self.assertWarns(UserWarning): - MongoClient(host) + with MongoClient(host): + pass for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] with self.assertWarns(UserWarning): - MongoClient(host) + with MongoClient(host): + pass with self.assertWarns(UserWarning): - MongoClient(multi_host) + with MongoClient(multi_host): + pass class TestClient(IntegrationTest): @@ -708,11 +741,13 @@ def test_max_idle_time_reaper_removes_stale(self): def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): client = rs_or_single_client() + self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize client = rs_or_single_client(minPoolSize=10) + self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) wait_until( lambda: len(server._pool.conns) == 10, @@ -732,6 +767,7 @@ def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = rs_or_single_client(maxIdleTimeMS=500) + self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -746,6 +782,7 @@ def test_max_idle_time_checkout(self): # Test that connections are reused if maxIdleTimeMS is not set. client = rs_or_single_client() + self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -769,36 +806,45 @@ def test_constants(self): MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - connected(MongoClient(serverSelectionTimeoutMS=10, **kwargs)) + with MongoClient(serverSelectionTimeoutMS=10, **kwargs) as c: + connected(c) - # Override the defaults. No error. - connected(MongoClient(host, port, **kwargs)) + with MongoClient(host, port, **kwargs) as c: + # Override the defaults. No error. + connected(c) # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port # No error. - connected(MongoClient(**kwargs)) + with MongoClient(**kwargs) as c: + connected(c) def test_init_disconnected(self): host, port = client_context.host, client_context.port c = rs_or_single_client(connect=False) + self.addCleanup(c.close) # is_primary causes client to block until connected self.assertIsInstance(c.is_primary, bool) c = rs_or_single_client(connect=False) + self.addCleanup(c.close) self.assertIsInstance(c.is_mongos, bool) c = rs_or_single_client(connect=False) + self.addCleanup(c.close) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) c = rs_or_single_client(connect=False) + self.addCleanup(c.close) self.assertEqual(c.codec_options, CodecOptions()) c = rs_or_single_client(connect=False) + self.addCleanup(c.close) self.assertFalse(c.primary) self.assertFalse(c.secondaries) c = rs_or_single_client(connect=False) + self.addCleanup(c.close) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(c.address) # PYTHON-2981 @@ -811,34 +857,43 @@ def test_init_disconnected(self): bad_host = "somedomainthatdoesntexist.org" c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + self.addCleanup(c.close) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + self.addCleanup(c.close) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = rs_or_single_client(seed, connect=False) - self.addCleanup(c.close) - self.assertEqual(client_context.client, c) - # Explicitly test inequality - self.assertFalse(client_context.client != c) + with rs_or_single_client(seed, connect=False) as c: + self.assertEqual(client_context.client, c) + # Explicitly test inequality + self.assertFalse(client_context.client != c) + + with rs_or_single_client("invalid.com", connect=False) as c: + self.assertNotEqual(client_context.client, c) + self.assertTrue(client_context.client != c) + + c1 = MongoClient("a", connect=False) + c2 = MongoClient("b", connect=False) + self.addCleanup(c1.close) + self.addCleanup(c2.close) - c = rs_or_single_client("invalid.com", connect=False) - self.addCleanup(c.close) - self.assertNotEqual(client_context.client, c) - self.assertTrue(client_context.client != c) # Seeds differ: - self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False)) + self.assertNotEqual(c1, c2) + + c1 = MongoClient(["a", "b", "c"], connect=False) + c2 = MongoClient(["c", "a", "b"], connect=False) + self.addCleanup(c1.close) + self.addCleanup(c2.close) + # Same seeds but out of order still compares equal: - self.assertEqual( - MongoClient(["a", "b", "c"], connect=False), - MongoClient(["c", "a", "b"], connect=False), - ) + self.assertEqual(c1, c2) def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) @@ -870,6 +925,7 @@ def test_repr(self): connect=False, document_class=SON, ) + self.addCleanup(client.close) the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) @@ -879,7 +935,8 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + with eval(the_repr) as client_two: + self.assertEqual(client_two, client) client = MongoClient( "localhost:27017,localhost:27018", @@ -890,6 +947,7 @@ def test_repr(self): wtimeoutms=100, connect=False, ) + self.addCleanup(client.close) the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -899,7 +957,8 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + with eval(the_repr) as client_two: + self.assertEqual(client_two, client) def test_getters(self): wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") @@ -1006,6 +1065,7 @@ def test_close_kills_cursors(self): # The killCursors task should not need to re-open the topology. test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) + test_client.close() def test_close_stops_kill_cursors_thread(self): client = rs_client() @@ -1156,9 +1216,10 @@ def test_unix_socket(self): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - connected( - MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100), - ) + with MongoClient( + "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 + ) as c: + connected(c) def test_document_class(self): c = self.client @@ -1183,6 +1244,7 @@ def test_timeouts(self): maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500, ) + self.addCleanup(client.close) self.assertEqual(10.5, (get_pool(client)).opts.connect_timeout) self.assertEqual(10.5, (get_pool(client)).opts.socket_timeout) self.assertEqual(10.5, (get_pool(client)).opts.max_idle_time_seconds) @@ -1191,22 +1253,28 @@ def test_timeouts(self): def test_socket_timeout_ms_validation(self): c = rs_or_single_client(socketTimeoutMS=10 * 1000) + self.addCleanup(c.close) self.assertEqual(10, (get_pool(c)).opts.socket_timeout) c = connected(rs_or_single_client(socketTimeoutMS=None)) + self.addCleanup(c.close) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) c = connected(rs_or_single_client(socketTimeoutMS=0)) + self.addCleanup(c.close) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS=-1) + with rs_or_single_client(socketTimeoutMS=-1): + pass with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS=1e10) + with rs_or_single_client(socketTimeoutMS=1e10): + pass with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS="foo") + with rs_or_single_client(socketTimeoutMS="foo"): + pass def test_socket_timeout(self): no_timeout = self.client @@ -1230,9 +1298,11 @@ def get_x(db): def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) + self.addCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = MongoClient(serverSelectionTimeoutMS=0, connect=False) + self.addCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) @@ -1243,20 +1313,25 @@ def test_server_selection_timeout(self): ) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) + self.addCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) + self.addCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) # Test invalid timeout in URI ignored and set to default. client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) + self.addCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) + self.addCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) def test_waitQueueTimeoutMS(self): client = rs_or_single_client(waitQueueTimeoutMS=2000) + self.addCleanup(client.close) self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) def test_socketKeepAlive(self): @@ -1501,10 +1576,9 @@ def test_auth_network_error(self): @client_context.require_no_replica_set def test_connect_to_standalone_using_replica_set_name(self): - client = single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - - with self.assertRaises(AutoReconnect): - client.test.test.find_one() + with single_client(replicaSet="anything", serverSelectionTimeoutMS=100) as client: + with self.assertRaises(AutoReconnect): + client.test.test.find_one() @client_context.require_replica_set def test_stale_getmore(self): @@ -1586,84 +1660,84 @@ def compression_settings(client): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, 4) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) # According to the connection string spec, unsupported values # just raise a warning and are ignored. uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy"]) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy", "zlib"]) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd"]) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = MongoClient(uri, connect=False) - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd", "zlib"]) + with MongoClient(uri, connect=False) as client: + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = single_client(zlibcompressionlevel=level) - # No error - client.pymongo_test.test.find_one() + with single_client(zlibcompressionlevel=level) as client: + # No error + client.pymongo_test.test.find_one() def test_reset_during_update_pool(self): client = rs_or_single_client(minPoolSize=10) @@ -1835,6 +1909,7 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + self.addCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" @@ -1842,21 +1917,26 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) + self.addCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) + self.addCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") def test_srv_max_hosts_kwarg(self): client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") + self.addCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + self.addCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = MongoClient( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) + self.addCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( @@ -2002,6 +2082,7 @@ def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = connected(rs_or_single_client(maxPoolSize=1)) + self.addCleanup(client.close) collection = client.pymongo_test.test pool = get_pool(client) @@ -2025,6 +2106,7 @@ def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = rs_or_single_client(maxPoolSize=1) + self.addCleanup(client.close) collection = client.pymongo_test.test collection.drop() @@ -2064,6 +2146,7 @@ def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = connected(rs_or_single_client(maxPoolSize=1, retryReads=False)) + self.addCleanup(client.close) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2085,6 +2168,7 @@ def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = rs_or_single_client(maxPoolSize=1) + self.addCleanup(client.close) collection = client.pymongo_test.test collection.drop() collection.insert_many([{} for _ in range(200)]) # More than one batch. From d45a8c7d26e6864c73e7a2231cdd921a041455df Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 11 Sep 2024 16:59:22 -0400 Subject: [PATCH 06/29] All async tests converted --- pymongo/asynchronous/mongo_client.py | 1 + pymongo/synchronous/mongo_client.py | 1 + test/__init__.py | 177 +++++++++++++++++- test/asynchronous/__init__.py | 197 +++++++++++++++++++- test/asynchronous/test_bulk.py | 11 +- test/asynchronous/test_client.py | 181 +++++++++--------- test/asynchronous/test_client_bulk_write.py | 33 ++-- test/asynchronous/test_collection.py | 20 +- test/asynchronous/test_cursor.py | 29 ++- test/asynchronous/test_database.py | 7 +- test/asynchronous/test_encryption.py | 66 +++---- test/asynchronous/test_grid_file.py | 7 +- test/asynchronous/test_logger.py | 3 +- test/asynchronous/test_session.py | 18 +- test/asynchronous/test_transactions.py | 30 ++- test/asynchronous/utils_spec_runner.py | 5 +- test/test_bulk.py | 11 +- test/test_client.py | 176 ++++++++--------- test/test_client_bulk_write.py | 33 ++-- test/test_collection.py | 20 +- test/test_cursor.py | 29 ++- test/test_custom_types.py | 3 +- test/test_database.py | 7 +- test/test_encryption.py | 54 +++--- test/test_grid_file.py | 7 +- test/test_logger.py | 3 +- test/test_session.py | 15 +- test/test_transactions.py | 26 +-- test/unified_format.py | 8 +- test/utils.py | 159 ---------------- test/utils_spec_runner.py | 5 +- 31 files changed, 755 insertions(+), 587 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index bbfb39ebb0..3cb01462f7 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -865,6 +865,7 @@ def __init__( ) self._opened = False + self._has_resources = False self._closed = False self._init_background() diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 1da818fb8c..086fe28e97 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -864,6 +864,7 @@ def __init__( ) self._opened = False + self._has_resources = False self._closed = False self._init_background() diff --git a/test/__init__.py b/test/__init__.py index 41af81f979..5dc2b132fd 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -16,8 +16,6 @@ from __future__ import annotations import asyncio -import base64 -import contextlib import gc import multiprocessing import os @@ -27,7 +25,6 @@ import sys import threading import time -import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -54,6 +51,8 @@ sanitize_reply, ) +from pymongo.uri_parser import parse_uri + try: import ipaddress @@ -80,6 +79,12 @@ _IS_SYNC = True +def _connection_string(h): + if h.startswith(("mongodb://", "mongodb+srv://")): + return h + return f"mongodb://{h!s}" + + class ClientContext: client: MongoClient @@ -257,6 +262,8 @@ def _init_client(self): self.replica_set_name = str(hello["setName"]) self.is_rs = True if self.auth_enabled: + if self.client: + self.client.close() # It doesn't matter which member we use as the seed here. self.client = pymongo.MongoClient( host, @@ -267,6 +274,8 @@ def _init_client(self): **self.default_client_options, ) else: + if self.client: + self.client.close() self.client = pymongo.MongoClient( host, port, replicaSet=self.replica_set_name, **self.default_client_options ) @@ -318,6 +327,7 @@ def _init_client(self): hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") == "isdbgrid": self.mongoses.append(next_address) + mongos_client.close() def init(self): with self.conn_lock: @@ -537,12 +547,6 @@ def require_auth(self, func): lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func ) - def require_no_fips(self, func): - """Run a test only if the host does not have FIPS enabled.""" - return self._require( - lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func - ) - def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" return self._require( @@ -930,6 +934,161 @@ def _target() -> None: self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?") self.assertEqual(proc.exitcode, 0) + @classmethod + def _unmanaged_async_mongo_client( + cls, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or client_context.host + port = port or client_context.port + client_options: dict = client_context.default_client_options.copy() + if client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = MongoClient(uri, port, **client_options) + if client._options.connect: + client._connect() + return client + + def _async_mongo_client(self, host, port, authenticate=True, directConnection=None, **kwargs): + """Create a new client over SSL/TLS if necessary.""" + host = host or client_context.host + port = port or client_context.port + client_options: dict = client_context.default_client_options.copy() + if client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = MongoClient(uri, port, **client_options) + if client._options.connect: + client._connect() + self.addCleanup(client.close) + return client + + @classmethod + def unmanaged_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return PyMongoTestCase._unmanaged_async_mongo_client( + h, p, authenticate=False, directConnection=True, **kwargs + ) + + @classmethod + def unmanaged_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return PyMongoTestCase._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs) + + @classmethod + def unmanaged_rs_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return PyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + + @classmethod + def unmanaged_rs_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return PyMongoTestCase._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + + @classmethod + def unmanaged_rs_or_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return PyMongoTestCase._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + + @classmethod + def unmanaged_rs_or_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return PyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + + def single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return self._async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) + + def single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Make a direct connection, and authenticate if necessary.""" + return self._async_mongo_client(h, p, directConnection=True, **kwargs) + + def rs_client_noauth(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Connect to the replica set. Don't authenticate.""" + return self._async_mongo_client(h, p, authenticate=False, **kwargs) + + def rs_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return self._async_mongo_client(h, p, **kwargs) + + def rs_or_single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> MongoClient[dict]: + """Connect to the replica set if there is one, otherwise the standalone. + + Like rs_or_single_client, but does not authenticate. + """ + return self._async_mongo_client(h, p, authenticate=False, **kwargs) + + def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]: + """Connect to the replica set if there is one, otherwise the standalone. + + Authenticates if necessary. + """ + return self._async_mongo_client(h, p, **kwargs) + + def simple_client(self, **kwargs: Any) -> MongoClient: + client = MongoClient(**kwargs) + self.addCleanup(client.close) + return client + + def disable_replication(self, client): + """Disable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = self.single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + + def enable_replication(self, client): + """Enable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = self.single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + class UnitTest(PyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index d1af89c184..3da10c5ec6 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -16,8 +16,6 @@ from __future__ import annotations import asyncio -import base64 -import contextlib import gc import multiprocessing import os @@ -27,7 +25,6 @@ import sys import threading import time -import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -54,6 +51,8 @@ sanitize_reply, ) +from pymongo.uri_parser import parse_uri + try: import ipaddress @@ -80,6 +79,12 @@ _IS_SYNC = False +def _connection_string(h): + if h.startswith(("mongodb://", "mongodb+srv://")): + return h + return f"mongodb://{h!s}" + + class AsyncClientContext: client: AsyncMongoClient @@ -257,6 +262,8 @@ async def _init_client(self): self.replica_set_name = str(hello["setName"]) self.is_rs = True if self.auth_enabled: + if self.client: + await self.client.close() # It doesn't matter which member we use as the seed here. self.client = pymongo.AsyncMongoClient( host, @@ -267,6 +274,8 @@ async def _init_client(self): **self.default_client_options, ) else: + if self.client: + await self.client.close() self.client = pymongo.AsyncMongoClient( host, port, replicaSet=self.replica_set_name, **self.default_client_options ) @@ -320,6 +329,7 @@ async def _init_client(self): hello = await mongos_client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") == "isdbgrid": self.mongoses.append(next_address) + await mongos_client.close() async def init(self): with self.conn_lock: @@ -539,12 +549,6 @@ def require_auth(self, func): lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func ) - def require_no_fips(self, func): - """Run a test only if the host does not have FIPS enabled.""" - return self._require( - lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func - ) - def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" return self._require( @@ -932,6 +936,181 @@ def _target() -> None: self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?") self.assertEqual(proc.exitcode, 0) + @classmethod + async def _unmanaged_async_mongo_client( + cls, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or await async_client_context.host + port = port or await async_client_context.port + client_options: dict = async_client_context.default_client_options.copy() + if async_client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = async_client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = AsyncMongoClient(uri, port, **client_options) + if client._options.connect: + await client.aconnect() + return client + + async def _async_mongo_client( + self, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or await async_client_context.host + port = port or await async_client_context.port + client_options: dict = async_client_context.default_client_options.copy() + if async_client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = async_client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = AsyncMongoClient(uri, port, **client_options) + if client._options.connect: + await client.aconnect() + self.addAsyncCleanup(client.close) + return client + + @classmethod + async def unmanaged_async_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( + h, p, authenticate=False, directConnection=True, **kwargs + ) + + @classmethod + async def unmanaged_async_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( + h, p, directConnection=True, **kwargs + ) + + @classmethod + async def unmanaged_async_rs_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return await AsyncPyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + + @classmethod + async def unmanaged_async_rs_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( + h, p, authenticate=False, **kwargs + ) + + @classmethod + async def unmanaged_async_rs_or_single_client_noauth( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( + h, p, authenticate=False, **kwargs + ) + + @classmethod + async def unmanaged_async_rs_or_single_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await AsyncPyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + + async def async_single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection. Don't authenticate.""" + return await self._async_mongo_client( + h, p, authenticate=False, directConnection=True, **kwargs + ) + + async def async_single_client( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Make a direct connection, and authenticate if necessary.""" + return await self._async_mongo_client(h, p, directConnection=True, **kwargs) + + async def async_rs_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set. Don't authenticate.""" + return await self._async_mongo_client(h, p, authenticate=False, **kwargs) + + async def async_rs_client( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set and authenticate if necessary.""" + return await self._async_mongo_client(h, p, **kwargs) + + async def async_rs_or_single_client_noauth( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[dict]: + """Connect to the replica set if there is one, otherwise the standalone. + + Like rs_or_single_client, but does not authenticate. + """ + return await self._async_mongo_client(h, p, authenticate=False, **kwargs) + + async def async_rs_or_single_client( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[Any]: + """Connect to the replica set if there is one, otherwise the standalone. + + Authenticates if necessary. + """ + return await self._async_mongo_client(h, p, **kwargs) + + def simple_client(self, **kwargs: Any) -> AsyncMongoClient: + client = AsyncMongoClient(**kwargs) + self.addAsyncCleanup(client.close) + return client + + async def disable_replication(self, client): + """Disable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = await self.async_single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + + async def enable_replication(self, client): + """Enable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = await self.async_single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + class AsyncUnitTest(AsyncPyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 24111ad7c0..a90c237890 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -25,9 +25,7 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest from test.utils import ( - async_rs_or_single_client_noauth, async_wait_until, - single_client, ) from bson.binary import Binary, UuidRepresentation @@ -38,7 +36,6 @@ from pymongo.errors import ( BulkWriteError, ConfigurationError, - InvalidOperation, OperationFailure, ) from pymongo.operations import * @@ -915,7 +912,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase): async def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = await async_rs_or_single_client_noauth( + cli = await self.async_rs_or_single_client_noauth( username="readonly", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -926,7 +923,7 @@ async def test_readonly(self): async def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = await async_rs_or_single_client_noauth( + cli = await self.async_rs_or_single_client_noauth( username="noremove", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -954,7 +951,9 @@ async def _setup_class(cls): if cls.w is not None and cls.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = single_client(*partition_node(member)) + cls.secondary = await AsyncTestBulkWriteConcern.unmanaged_async_single_client( + *partition_node(member) + ) break @classmethod diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 9dbf9e3cb7..f3a354fee4 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -62,10 +62,6 @@ CMAPListener, FunctionCallRecorder, async_get_pool, - async_rs_client, - async_rs_or_single_client, - async_rs_or_single_client_noauth, - async_single_client, async_wait_until, asyncAssertRaisesExactly, delay, @@ -73,7 +69,6 @@ is_greenthread_patched, lazy_client_trial, one, - rs_or_single_client, wait_until, ) @@ -134,7 +129,9 @@ class AsyncClientUnitTest(AsyncUnitTest): @classmethod async def _setup_class(cls): - cls.client = await async_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + cls.client = await AsyncClientUnitTest.unmanaged_async_rs_or_single_client( + connect=False, serverSelectionTimeoutMS=100 + ) @classmethod async def _tearDown_class(cls): @@ -264,7 +261,7 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) async def test_get_default_database(self): - async with await async_rs_or_single_client( + async with await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, @@ -283,7 +280,7 @@ async def test_get_default_database(self): self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) - async with await async_rs_or_single_client( + async with await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) as c: @@ -291,7 +288,7 @@ async def test_get_default_database(self): async def test_get_default_database_error(self): # URI with no database. - async with await async_rs_or_single_client( + async with await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) as c: @@ -303,11 +300,11 @@ async def test_get_default_database_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - async with await async_rs_or_single_client(uri, connect=False) as c: + async with await self.async_rs_or_single_client(uri, connect=False) as c: self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) async def test_get_database_default(self): - async with await async_rs_or_single_client( + async with await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, @@ -316,7 +313,7 @@ async def test_get_database_default(self): async def test_get_database_default_error(self): # URI with no database. - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) @@ -329,7 +326,7 @@ async def test_get_database_default_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await async_rs_or_single_client(uri, connect=False) + c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) await c.close() @@ -342,7 +339,7 @@ def test_primary_read_pref_with_tags(self): AsyncMongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") async def test_read_preference(self): - async with await async_rs_or_single_client( + async with await self.async_rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) as c: self.assertEqual(c.read_preference, ReadPreference.NEAREST) @@ -691,7 +688,7 @@ def test_multiple_uris(self): async def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -704,7 +701,7 @@ async def test_max_idle_time_reaper_default(self): async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = await async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -720,7 +717,7 @@ async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1 ) server = await (await client._get_topology()).select_server( @@ -738,7 +735,7 @@ async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it - client = await async_rs_or_single_client(maxIdleTimeMS=500) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -757,7 +754,7 @@ async def test_max_idle_time_reaper_removes_stale(self): async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST @@ -765,7 +762,7 @@ async def test_min_pool_size(self): self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize - client = await async_rs_or_single_client(minPoolSize=10) + client = await self.async_rs_or_single_client(minPoolSize=10) self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST @@ -787,7 +784,7 @@ async def test_min_pool_size(self): async def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = await async_rs_or_single_client(maxIdleTimeMS=500) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500) self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST @@ -804,7 +801,7 @@ async def test_max_idle_time_checkout(self): self.assertTrue(new_con in server._pool.conns) # Test that connections are reused if maxIdleTimeMS is not set. - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST @@ -848,27 +845,27 @@ async def test_constants(self): async def test_init_disconnected(self): host, port = await async_client_context.host, await async_client_context.port - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.addAsyncCleanup(c.close) # is_primary causes client to block until connected self.assertIsInstance(await c.is_primary, bool) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.addAsyncCleanup(c.close) self.assertIsInstance(await c.is_mongos, bool) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.addAsyncCleanup(c.close) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.addAsyncCleanup(c.close) self.assertEqual(c.codec_options, CodecOptions()) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.addAsyncCleanup(c.close) self.assertFalse(await c.primary) self.assertFalse(await c.secondaries) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.addAsyncCleanup(c.close) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) @@ -895,12 +892,12 @@ async def test_init_disconnected_with_auth(self): async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - async with await async_rs_or_single_client(seed, connect=False) as c: + async with await self.async_rs_or_single_client(seed, connect=False) as c: self.assertEqual(async_client_context.client, c) # Explicitly test inequality self.assertFalse(async_client_context.client != c) - async with await async_rs_or_single_client("invalid.com", connect=False) as c: + async with await self.async_rs_or_single_client("invalid.com", connect=False) as c: self.assertNotEqual(async_client_context.client, c) self.assertTrue(async_client_context.client != c) @@ -922,10 +919,10 @@ async def test_equality(self): async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await async_rs_or_single_client(seed, connect=False) + c = await self.async_rs_or_single_client(seed, connect=False) self.addAsyncCleanup(c.close) self.assertIn(c, {async_client_context.client}) - c = await async_rs_or_single_client("invalid.com", connect=False) + c = await self.async_rs_or_single_client("invalid.com", connect=False) self.addAsyncCleanup(c.close) self.assertNotIn(c, {async_client_context.client}) @@ -999,7 +996,7 @@ async def test_list_databases(self): for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = await async_rs_or_single_client(document_class=SON) + client = await self.async_rs_or_single_client(document_class=SON) self.addAsyncCleanup(client.close) async for doc in await client.list_databases(): self.assertIs(type(doc), dict) @@ -1039,7 +1036,7 @@ async def test_drop_database(self): await self.client.drop_database("pymongo_test") if async_client_context.is_rs: - wc_client = await async_rs_or_single_client(w=len(async_client_context.nodes) + 1) + wc_client = await self.async_rs_or_single_client(w=len(async_client_context.nodes) + 1) with self.assertRaises(WriteConcernError): await wc_client.drop_database("pymongo_test2") @@ -1049,7 +1046,7 @@ async def test_drop_database(self): self.assertNotIn("pymongo_test2", dbs) async def test_close(self): - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() coll = test_client.pymongo_test.bar await test_client.close() with self.assertRaises(InvalidOperation): @@ -1059,7 +1056,7 @@ async def test_close_kills_cursors(self): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() await test_client._process_periodic_tasks() @@ -1086,14 +1083,14 @@ async def test_close_kills_cursors(self): self.assertTrue(test_client._topology._opened) await test_client.close() self.assertFalse(test_client._topology._opened) - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() # The killCursors task should not need to re-open the topology. await test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) await test_client.close() async def test_close_stops_kill_cursors_thread(self): - client = await async_rs_client() + client = await self.async_rs_client() await client.test.test.find_one() self.assertFalse(client._kill_cursors_executor._stopped) @@ -1109,7 +1106,7 @@ async def test_close_stops_kill_cursors_thread(self): async def test_uri_connect_option(self): # Ensure that topology is not opened if connect=False. - client = await async_rs_client(connect=False) + client = await self.async_rs_client(connect=False) self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. @@ -1126,14 +1123,14 @@ async def test_uri_connect_option(self): await client.close() async def test_close_does_not_open_servers(self): - client = await async_rs_client(connect=False) + client = await self.async_rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) await client.close() self.assertEqual(topology._servers, {}) async def test_close_closes_sockets(self): - client = await async_rs_client() + client = await self.async_rs_client() self.addAsyncCleanup(client.close) await client.test.test.find_one() topology = client._topology @@ -1163,35 +1160,35 @@ async def test_auth_from_uri(self): with self.assertRaises(OperationFailure): await connected( - await async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) + await self.async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) ) # No error. await connected( - await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + await self.async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) ) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) with self.assertRaises(OperationFailure): - await connected(await async_rs_or_single_client_noauth(uri)) + await connected(await self.async_rs_or_single_client_noauth(uri)) # No error. await connected( - await async_rs_or_single_client_noauth( + await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) ) ) # Auth with lazy connection. await ( - await async_rs_or_single_client_noauth( + await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = await async_rs_or_single_client_noauth( + bad_client = await self.async_rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) @@ -1203,7 +1200,7 @@ async def test_username_and_password(self): await async_client_context.create_user("admin", "ad min", "pa/ss") self.addAsyncCleanup(async_client_context.drop_user, "admin", "ad min") - c = await async_rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = await self.async_rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. self.assertNotIn("ad min", repr(c)) @@ -1216,14 +1213,14 @@ async def test_username_and_password(self): with self.assertRaises(OperationFailure): await ( - await async_rs_or_single_client_noauth(username="ad min", password="foo") + await self.async_rs_or_single_client_noauth(username="ad min", password="foo") ).server_info() @async_client_context.require_auth @async_client_context.require_no_fips async def test_lazy_auth_raises_operation_failure(self): host = await async_client_context.host - lazy_client = await async_rs_or_single_client_noauth( + lazy_client = await self.async_rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) @@ -1241,7 +1238,7 @@ async def test_unix_socket(self): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = await async_rs_or_single_client(uri) + client = await self.async_rs_or_single_client(uri) self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() @@ -1265,7 +1262,7 @@ async def test_document_class(self): self.assertTrue(isinstance(await db.test.find_one(), dict)) self.assertFalse(isinstance(await db.test.find_one(), SON)) - c = await async_rs_or_single_client(document_class=SON) + c = await self.async_rs_or_single_client(document_class=SON) self.addAsyncCleanup(c.close) db = c.pymongo_test @@ -1273,7 +1270,7 @@ async def test_document_class(self): self.assertTrue(isinstance(await db.test.find_one(), SON)) async def test_timeouts(self): - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, @@ -1287,34 +1284,34 @@ async def test_timeouts(self): self.assertEqual(10.5, client.options.server_selection_timeout) async def test_socket_timeout_ms_validation(self): - c = await async_rs_or_single_client(socketTimeoutMS=10 * 1000) + c = await self.async_rs_or_single_client(socketTimeoutMS=10 * 1000) self.addAsyncCleanup(c.close) self.assertEqual(10, (await async_get_pool(c)).opts.socket_timeout) - c = await connected(await async_rs_or_single_client(socketTimeoutMS=None)) + c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=None)) self.addAsyncCleanup(c.close) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) - c = await connected(await async_rs_or_single_client(socketTimeoutMS=0)) + c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=0)) self.addAsyncCleanup(c.close) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - async with await async_rs_or_single_client(socketTimeoutMS=-1): + async with await self.async_rs_or_single_client(socketTimeoutMS=-1): pass with self.assertRaises(ValueError): - async with await async_rs_or_single_client(socketTimeoutMS=1e10): + async with await self.async_rs_or_single_client(socketTimeoutMS=1e10): pass with self.assertRaises(ValueError): - async with await async_rs_or_single_client(socketTimeoutMS="foo"): + async with await self.async_rs_or_single_client(socketTimeoutMS="foo"): pass async def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 - timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) + timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) self.addAsyncCleanup(timeout.close) await no_timeout.pymongo_test.drop_collection("test") @@ -1369,7 +1366,7 @@ async def test_server_selection_timeout(self): self.assertAlmostEqual(30, client.options.server_selection_timeout) async def test_waitQueueTimeoutMS(self): - client = await async_rs_or_single_client(waitQueueTimeoutMS=2000) + client = await self.async_rs_or_single_client(waitQueueTimeoutMS=2000) self.addAsyncCleanup(client.close) self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) @@ -1383,7 +1380,7 @@ async def test_socketKeepAlive(self): async def test_tz_aware(self): self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo") - aware = await async_rs_or_single_client(tz_aware=True) + aware = await self.async_rs_or_single_client(tz_aware=True) self.addAsyncCleanup(aware.close) naive = self.client await aware.pymongo_test.drop_collection("test") @@ -1413,7 +1410,7 @@ async def test_ipv6(self): if async_client_context.is_rs: uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") - client = await async_rs_or_single_client_noauth(uri) + client = await self.async_rs_or_single_client_noauth(uri) self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1423,7 +1420,7 @@ async def test_ipv6(self): self.assertTrue("pymongo_test_bernie" in dbs) async def test_contextlib(self): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() await client.pymongo_test.drop_collection("test") await client.pymongo_test.test.insert_one({"foo": "bar"}) @@ -1437,7 +1434,7 @@ async def test_contextlib(self): self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): await client.pymongo_test.test.find_one() - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() async with client as client: self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): @@ -1515,7 +1512,7 @@ async def test_operation_failure(self): # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. - client = await async_single_client() + client = await self.async_single_client() self.addAsyncCleanup(client.close) await client.pymongo_test.test.find_one() pool = await async_get_pool(client) @@ -1540,7 +1537,7 @@ async def test_lazy_connect_w0(self): await async_client_context.client.drop_database("test_lazy_connect_w0") self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") - client = await async_rs_or_single_client(connect=False, w=0) + client = await self.async_rs_or_single_client(connect=False, w=0) self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.insert_one({}) @@ -1549,7 +1546,7 @@ async def predicate(): await async_wait_until(predicate, "find one document") - client = await async_rs_or_single_client(connect=False, w=0) + client = await self.async_rs_or_single_client(connect=False, w=0) self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) @@ -1558,7 +1555,7 @@ async def predicate(): await async_wait_until(predicate, "update one document") - client = await async_rs_or_single_client(connect=False, w=0) + client = await self.async_rs_or_single_client(connect=False, w=0) self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.delete_one({}) @@ -1571,7 +1568,7 @@ async def predicate(): async def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False) + client = await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) @@ -1599,7 +1596,9 @@ async def test_auth_network_error(self): # Get a client with one socket so we detect if it's leaked. c = await connected( - await async_rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) + await self.async_rs_or_single_client( + maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False + ) ) # Cause a network error on the actual socket. @@ -1617,7 +1616,7 @@ async def test_auth_network_error(self): @async_client_context.require_no_replica_set async def test_connect_to_standalone_using_replica_set_name(self): - async with await async_single_client( + async with await self.async_single_client( replicaSet="anything", serverSelectionTimeoutMS=100 ) as client: with self.assertRaises(AutoReconnect): @@ -1629,7 +1628,7 @@ async def test_stale_getmore(self): # the topology before the getMore message is sent. Test that # AsyncMongoClient._run_operation_with_response handles the error. with self.assertRaises(AutoReconnect): - client = await async_rs_client(connect=False, serverSelectionTimeoutMS=100) + client = await self.async_rs_client(connect=False, serverSelectionTimeoutMS=100) await client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1677,7 +1676,7 @@ def init(self, *args): await async_client_context.host, await async_client_context.port, ) - client = await async_single_client(uri, event_listeners=[listener]) + client = await self.async_single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1778,12 +1777,12 @@ def compression_settings(client): options = async_client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - async with await async_single_client(zlibcompressionlevel=level) as client: + async with await self.async_single_client(zlibcompressionlevel=level) as client: # No error await client.pymongo_test.test.find_one() async def test_reset_during_update_pool(self): - client = await async_rs_or_single_client(minPoolSize=10) + client = await self.async_rs_or_single_client(minPoolSize=10) self.addAsyncCleanup(client.close) await client.admin.command("ping") pool = await async_get_pool(client) @@ -1830,7 +1829,7 @@ def run(self): async def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) self.addAsyncCleanup(client.close) @@ -1864,14 +1863,14 @@ def stall_connect(*args, **kwargs): @async_client_context.require_replica_set async def test_direct_connection(self): # direct_connection=True should result in Single topology. - client = await async_rs_or_single_client(directConnection=True) + client = await self.async_rs_or_single_client(directConnection=True) await client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) await client.close() # direct_connection=False should result in RS topology. - client = await async_rs_or_single_client(directConnection=False) + client = await self.async_rs_or_single_client(directConnection=False) await client.admin.command("ping") self.assertGreaterEqual(len(client.nodes), 1) self.assertIn( @@ -1915,7 +1914,7 @@ def server_description_count(): @async_client_context.require_failCommand_fail_point async def test_network_error_message(self): - client = await async_single_client(retryReads=False) + client = await self.async_single_client(retryReads=False) self.addAsyncCleanup(client.close) await client.admin.command("ping") # connect async with self.fail_point( @@ -1928,7 +1927,7 @@ async def test_network_error_message(self): @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") async def test_process_periodic_tasks(self): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() coll = client.db.collection await coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) @@ -2025,7 +2024,9 @@ async def _test_handshake(self, env_vars, expected_env): if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - async with await async_rs_or_single_client(serverSelectionTimeoutMS=10000) as client: + async with await self.async_rs_or_single_client( + serverSelectionTimeoutMS=10000 + ) as client: await client.admin.command("ping") options = client.options self.assertEqual(options.pool_options.metadata, metadata) @@ -2124,7 +2125,7 @@ def setUp(self): async def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected(await async_rs_or_single_client(maxPoolSize=1)) + client = await connected(await self.async_rs_or_single_client(maxPoolSize=1)) self.addAsyncCleanup(client.close) collection = client.pymongo_test.test @@ -2148,7 +2149,7 @@ async def test_exhaust_query_server_error(self): async def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) self.addAsyncCleanup(client.close) collection = client.pymongo_test.test await collection.drop() @@ -2188,7 +2189,9 @@ async def receive_message(request_id): async def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected(await async_rs_or_single_client(maxPoolSize=1, retryReads=False)) + client = await connected( + await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) + ) self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) @@ -2210,7 +2213,7 @@ async def test_exhaust_query_network_error(self): async def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) self.addAsyncCleanup(client.close) collection = client.pymongo_test.test await collection.drop() @@ -2260,7 +2263,7 @@ def test_gevent_timeout(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client(maxPoolSize=1) + client = self.async_rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2292,7 +2295,7 @@ def test_gevent_timeout_when_creating_connection(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client() + client = self.async_rs_or_single_client() self.addCleanup(client.close) coll = client.pymongo_test.test pool = async_get_pool(client) @@ -2329,7 +2332,7 @@ class TestClientLazyConnect(AsyncIntegrationTest): """Test concurrent operations on a lazily-connecting MongoClient.""" def _get_client(self): - return rs_or_single_client(connect=False) + return self.async_rs_or_single_client(connect=False) @async_client_context.require_sync def test_insert_one(self): diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index c35e823d03..02d49c207a 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -27,7 +27,6 @@ ) from test.utils import ( OvertCommandListener, - async_rs_or_single_client, ) from unittest.mock import patch @@ -36,10 +35,8 @@ from pymongo.errors import ( ClientBulkWriteException, DocumentTooLarge, - InvalidOperation, NetworkTimeout, ) -from pymongo.monitoring import * from pymongo.operations import * from pymongo.write_concern import WriteConcern @@ -97,7 +94,7 @@ async def asyncSetUp(self): @async_client_context.require_no_serverless async def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) models = [] @@ -123,7 +120,7 @@ async def test_batch_splits_if_num_operations_too_large(self): @async_client_context.require_no_serverless async def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) models = [] @@ -157,7 +154,7 @@ async def test_batch_splits_if_ops_payload_too_large(self): @async_client_context.require_failCommand_fail_point async def test_collects_write_concern_errors_across_batches(self): listener = OvertCommandListener() - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], retryWrites=False, ) @@ -200,7 +197,7 @@ async def test_collects_write_concern_errors_across_batches(self): @async_client_context.require_no_serverless async def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) collection = client.db["coll"] @@ -231,7 +228,7 @@ async def test_collects_write_errors_across_batches_unordered(self): @async_client_context.require_no_serverless async def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) collection = client.db["coll"] @@ -262,7 +259,7 @@ async def test_collects_write_errors_across_batches_ordered(self): @async_client_context.require_no_serverless async def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) collection = client.db["coll"] @@ -304,7 +301,7 @@ async def test_handles_cursor_requiring_getMore(self): @async_client_context.require_no_standalone async def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) collection = client.db["coll"] @@ -348,7 +345,7 @@ async def test_handles_cursor_requiring_getMore_within_transaction(self): @async_client_context.require_failCommand_fail_point async def test_handles_getMore_error(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) collection = client.db["coll"] @@ -403,7 +400,7 @@ async def test_handles_getMore_error(self): @async_client_context.require_no_serverless async def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) b_repeated = "b" * self.max_bson_object_size @@ -460,7 +457,7 @@ async def _setup_namespace_test_models(self): @async_client_context.require_no_serverless async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() @@ -492,7 +489,7 @@ async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): @async_client_context.require_no_serverless async def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() @@ -530,7 +527,7 @@ async def test_batch_splits_if_new_namespace_is_too_large(self): @async_client_context.require_version_min(8, 0, 0, -24) @async_client_context.require_no_serverless async def test_returns_error_if_no_writes_can_be_added_to_ops(self): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() self.addAsyncCleanup(client.close) # Document too large. @@ -554,7 +551,7 @@ async def test_returns_error_if_auto_encryption_configured(self): key_vault_namespace="db.coll", kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) - client = await async_rs_or_single_client(auto_encryption_opts=opts) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client.close) models = [InsertOne(namespace="db.coll", document={"a": "b"})] @@ -580,7 +577,7 @@ async def asyncSetUp(self): async def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 - internal_client = await async_rs_or_single_client(timeoutMS=None) + internal_client = await self.async_rs_or_single_client(timeoutMS=None) self.addAsyncCleanup(internal_client.close) collection = internal_client.db["coll"] @@ -605,7 +602,7 @@ async def test_timeout_in_multi_batch_bulk_write(self): ) listener = OvertCommandListener() - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], readConcernLevel="majority", readPreference="primary", diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 10d64a525c..74a4a5151d 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -30,6 +30,7 @@ from test import unittest from test.asynchronous import ( # TODO: fix sync imports in PYTHON-4528 AsyncIntegrationTest, + AsyncUnitTest, async_client_context, ) from test.utils import ( @@ -37,8 +38,6 @@ EventListener, async_get_pool, async_is_mongos, - async_rs_or_single_client, - async_single_client, async_wait_until, wait_until, ) @@ -82,14 +81,20 @@ _IS_SYNC = False -class TestCollectionNoConnect(unittest.TestCase): +class TestCollectionNoConnect(AsyncUnitTest): """Test Collection features on a client that does not connect.""" db: AsyncDatabase + client: AsyncMongoClient @classmethod - def setUpClass(cls): - cls.db = AsyncMongoClient(connect=False).pymongo_test + async def _setup_class(cls): + cls.client = AsyncMongoClient(connect=False) + cls.db = cls.client.pymongo_test + + @classmethod + async def _tearDown_class(cls): + await cls.client.close() def test_collection(self): self.assertRaises(TypeError, AsyncCollection, self.db, 5) @@ -1819,8 +1824,7 @@ async def test_exhaust(self): # Insert enough documents to require more than one batch await self.db.test.insert_many([{"i": i} for i in range(150)]) - client = await async_rs_or_single_client(maxPoolSize=1) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(maxPoolSize=1) pool = await async_get_pool(client) # Make sure the socket is returned after exhaustion. @@ -2100,7 +2104,7 @@ async def test_find_one_and(self): async def test_find_one_and_write_concern(self): listener = EventListener() - db = (await async_single_client(event_listeners=[listener]))[self.db.name] + db = (await self.async_single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 6967205fe3..8431f4369b 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -34,7 +34,6 @@ AllowListEventListener, EventListener, OvertCommandListener, - async_rs_or_single_client, ignore_deprecations, wait_until, ) @@ -232,7 +231,7 @@ async def test_max_await_time_ms(self): self.assertEqual(90, cursor._max_await_time_ms) listener = AllowListEventListener("find", "getMore") - coll = (await async_rs_or_single_client(event_listeners=[listener]))[ + coll = (await self.async_rs_or_single_client(event_listeners=[listener]))[ self.db.name ].pymongo_test @@ -353,7 +352,7 @@ async def test_explain(self): async def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(await coll.find().explain()) @@ -1261,7 +1260,7 @@ async def test_close_kills_cursor_synchronously(self): await self.client._process_periodic_tasks() listener = AllowListEventListener("killCursors") - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) coll = client[self.db.name].test_close_kills_cursors @@ -1300,7 +1299,7 @@ def assertCursorKilled(): @async_client_context.require_failCommand_appName async def test_timeout_kills_cursor_asynchronously(self): listener = AllowListEventListener("killCursors") - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) coll = client[self.db.name].test_timeout_kills_cursor @@ -1358,7 +1357,7 @@ def test_delete_not_initialized(self): async def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( @@ -1463,7 +1462,7 @@ async def test_find_raw_transaction(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) async with client.start_session() as session: async with await session.start_transaction(): batches = await ( @@ -1493,7 +1492,7 @@ async def test_find_raw_retryable_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): @@ -1514,7 +1513,7 @@ async def test_find_raw_snapshot_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) @@ -1577,7 +1576,7 @@ async def test_read_concern(self): async def test_monitoring(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() await c.insert_many([{"_id": i} for i in range(10)]) @@ -1643,7 +1642,7 @@ async def test_aggregate_raw_transaction(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) async with client.start_session() as session: async with await session.start_transaction(): batches = await ( @@ -1674,7 +1673,7 @@ async def test_aggregate_raw_retryable_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}} ): @@ -1698,7 +1697,7 @@ async def test_aggregate_raw_snapshot_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) @@ -1744,7 +1743,7 @@ async def test_collation(self): async def test_monitoring(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() await c.insert_many([{"_id": i} for i in range(10)]) @@ -1788,7 +1787,7 @@ async def test_monitoring(self): @async_client_context.require_no_mongos async def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) c = client.pymongo_test.test await c.delete_many({}) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 8f6886a2a7..2e1f8e0450 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -29,7 +29,6 @@ from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, - async_rs_or_single_client, async_wait_until, ) @@ -208,7 +207,7 @@ async def test_list_collection_names(self): async def test_list_collection_names_filter(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] await db.capped.drop() await db.create_collection("capped", capped=True, size=4096) @@ -235,7 +234,7 @@ async def test_list_collection_names_filter(self): async def test_check_exists(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) db = client[self.db.name] await db.drop_collection("unique") @@ -326,7 +325,7 @@ async def test_list_collections(self): await self.client.drop_database("pymongo_test") async def test_list_collection_names_single_socket(self): - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) await client.drop_database("test_collection_names_single_socket") db = client.test_collection_names_single_socket for i in range(200): diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index eb431e1d50..dea3571aad 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -43,6 +43,8 @@ from test import ( unittest, ) +from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.utils_spec_runner import AsyncSpecRunner from test.helpers import ( AWS_CREDS, AZURE_CREDS, @@ -52,19 +54,16 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) -from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, OvertCommandListener, SpecTestCreator, TopologyEventListener, - async_rs_or_single_client, async_wait_until, camel_to_snake_args, is_greenthread_patched, ) -from test.utils_spec_runner import SpecRunner from bson import DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation @@ -260,7 +259,7 @@ def bson_data(*paths): class TestClientSimple(AsyncEncryptionIntegrationTest): async def _test_auto_encrypt(self, opts): - client = await async_rs_or_single_client(auto_encryption_opts=opts) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client.aclose) # Create the encrypted field's data key. @@ -342,7 +341,7 @@ async def test_auto_encrypt_local_schema_map(self): async def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client.aclose) await client.admin.command("ping") @@ -360,7 +359,7 @@ async def test_use_after_close(self): ) async def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client.aclose) async def target(): @@ -372,10 +371,10 @@ async def target(): await target() -class TestEncryptedBulkWrite(BulkTestBase, AsyncEncryptionIntegrationTest): +class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest): async def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client.aclose) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) @@ -416,7 +415,7 @@ async def _setup_class(cls): @async_client_context.require_version_max(4, 0, 99) async def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client.aclose) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): @@ -430,7 +429,7 @@ async def test_raise_max_wire_version_error(self): async def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client.aclose) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): @@ -613,7 +612,7 @@ async def test_with_statement(self): if _IS_SYNC: # TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): + class TestSpec(AsyncSpecRunner): @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def setUpClass(cls): @@ -811,7 +810,9 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = OvertCommandListener() - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( + event_listeners=[cls.listener] + ) await cls.client.db.coll.drop() cls.vault = await create_key_vault(cls.client.keyvault.datakeys) @@ -833,7 +834,7 @@ async def _setup_class(cls): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = await async_rs_or_single_client( + cls.client_encrypted = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) cls.client_encryption = AsyncClientEncryption( @@ -923,7 +924,7 @@ async def _test_external_key_vault(self, with_external_key_vault): # Configure the encrypted field via the local schema_map option. schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: - key_vault_client = await async_rs_or_single_client( + key_vault_client = await self.async_rs_or_single_client( username="fake-user", password="fake-pwd" ) self.addAsyncCleanup(key_vault_client.close) @@ -936,7 +937,7 @@ async def _test_external_key_vault(self, with_external_key_vault): key_vault_client=key_vault_client, ) - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) self.addAsyncCleanup(client_encrypted.close) @@ -990,7 +991,7 @@ async def test_views_are_prohibited(self): self.addAsyncCleanup(self.client.db.view.drop) opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys") - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) self.addAsyncCleanup(client_encrypted.aclose) @@ -1050,7 +1051,7 @@ async def _test_corpus(self, opts): ) self.addAsyncCleanup(vault.drop) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client_encrypted.close) client_encryption = AsyncClientEncryption( @@ -1203,7 +1204,7 @@ async def _setup_class(cls): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = await async_rs_or_single_client( + cls.client_encrypted = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll @@ -1525,7 +1526,7 @@ async def _test_automatic(self, expectation_extjson, payload): ) insert_listener = AllowListEventListener("insert") - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) self.addAsyncCleanup(client.aclose) @@ -1604,13 +1605,13 @@ async def test_automatic(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests class TestDeadlockProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): - self.client_test = await async_rs_or_single_client( + self.client_test = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) self.addAsyncCleanup(self.client_test.aclose) self.client_keyvault_listener = OvertCommandListener() - self.client_keyvault = await async_rs_or_single_client( + self.client_keyvault = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", @@ -1645,7 +1646,7 @@ async def asyncSetUp(self): self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") async def _run_test(self, max_pool_size, auto_encryption_opts): - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( readConcernLevel="majority", w="majority", maxPoolSize=max_pool_size, @@ -1855,7 +1856,7 @@ async def asyncSetUp(self): key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map ) self.listener = AllowListEventListener("aggregate") - self.encrypted_client = await async_rs_or_single_client( + self.encrypted_client = await self.async_rs_or_single_client( auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] ) self.addAsyncCleanup(self.encrypted_client.close) @@ -1935,7 +1936,7 @@ def reset_timeout(): "--port=27027", ], ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client_encrypted.close) with self.assertRaisesRegex(EncryptionError, "Timeout"): await client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1950,13 +1951,14 @@ async def test_bypassAutoEncryption(self): "--port=27027", ], ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) + self.addAsyncCleanup(client_encrypted.close) await client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: mongocryptd_client = AsyncMongoClient( "mongodb://localhost:27027/?serverSelectionTimeoutMS=500" ) + self.addAsyncCleanup(mongocryptd_client.close) with self.assertRaises(ServerSelectionTimeoutError): await mongocryptd_client.admin.command("ping") @@ -1978,7 +1980,7 @@ async def test_via_loading_shared_library(self): ], crypt_shared_lib_required=True, ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client_encrypted.aclose) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -2020,7 +2022,7 @@ def listener(): mongocryptd_uri="mongodb://localhost:47021", crypt_shared_lib_required=False, ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(client_encrypted.aclose) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -2336,7 +2338,7 @@ async def asyncSetUp(self): key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.addAsyncCleanup(self.encrypted_client.aclose) async def test_01_insert_encrypted_indexed_and_find(self): @@ -2484,7 +2486,7 @@ async def run_test(self, src_provider, dst_provider): ) # Step 5. Create a ``AsyncClientEncryption`` object named ``client_encryption2`` - client2 = await async_rs_or_single_client() + client2 = await self.async_rs_or_single_client() self.addAsyncCleanup(client2.aclose) client_encryption2 = AsyncClientEncryption( key_vault_client=client2, @@ -2559,7 +2561,7 @@ async def test_queryable_encryption(self): # AsyncMongoClient to use in testing that handles auth/tls/etc, # and cleanup. async def AsyncMongoClient(**kwargs): - c = await async_rs_or_single_client(**kwargs) + c = await self.async_rs_or_single_client(**kwargs) self.addAsyncCleanup(c.aclose) return c @@ -2661,7 +2663,7 @@ async def asyncSetUp(self): key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db self.addAsyncCleanup(self.encrypted_client.aclose) diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 7071fc76f4..8b5c4344d1 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, qcheck, unittest -from test.utils import EventListener, async_rs_or_single_client, rs_or_single_client +from test.utils import EventListener from bson.objectid import ObjectId from gridfs import GridFS @@ -791,6 +791,7 @@ async def test_grid_out_lazy_connect(self): async def test_grid_in_lazy_connect(self): client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + self.addAsyncCleanup(client.close) fs = client.db.fs infile = AsyncGridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): @@ -801,7 +802,7 @@ async def test_grid_in_lazy_connect(self): async def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs) + AsyncGridIn((await self.async_rs_or_single_client(w=0)).pymongo_test.fs) async def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. @@ -809,7 +810,7 @@ async def test_survive_cursor_not_found(self): chunk_size = 1024 data = b"d" * (102 * chunk_size) listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile: await infile.write(data) diff --git a/test/asynchronous/test_logger.py b/test/asynchronous/test_logger.py index 7a58846515..856087ec7c 100644 --- a/test/asynchronous/test_logger.py +++ b/test/asynchronous/test_logger.py @@ -16,7 +16,6 @@ import os from test import unittest from test.asynchronous import AsyncIntegrationTest -from test.utils import async_single_client from unittest.mock import patch from bson import json_util @@ -86,7 +85,7 @@ async def test_truncation_multi_byte_codepoints(self): self.assertEqual(last_3_bytes, str_to_repeat) async def test_logging_without_listeners(self): - c = await async_single_client() + c = await self.async_single_client() self.assertEqual(len(c._event_listeners.event_listeners()), 0) with self.assertLogs("pymongo.connection", level="DEBUG") as cm: await c.db.test.insert_one({"x": "1"}) diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 1e1f5659ba..073864ef45 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -36,9 +36,7 @@ from test.utils import ( EventListener, ExceptionCatchingThread, - async_rs_or_single_client, async_wait_until, - rs_or_single_client, wait_until, ) @@ -90,7 +88,7 @@ async def _setup_class(cls): await super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await async_rs_or_single_client() + cls.client2 = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -105,7 +103,7 @@ async def _tearDown_class(cls): async def asyncSetUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() - self.client = await async_rs_or_single_client( + self.client = await self.async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) self.addAsyncCleanup(self.client.close) @@ -202,7 +200,7 @@ def test_implicit_sessions_checkout(self): failures = 0 for _ in range(5): listener = EventListener() - client = async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -285,7 +283,7 @@ async def test_end_session(self): async def test_end_sessions(self): # Use a new client so that the tearDown hook does not error. listener = SessionTestListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: @@ -789,7 +787,7 @@ async def _test_unacknowledged_ops(self, client, *ops): async def test_unacknowledged_writes(self): # Ensure the collection exists. await self.client.pymongo_test.test_unacked_writes.insert_one({}) - client = await async_rs_or_single_client(w=0, event_listeners=[self.listener]) + client = await self.async_rs_or_single_client(w=0, event_listeners=[self.listener]) self.addAsyncCleanup(client.close) db = client.pymongo_test coll = db.test_unacked_writes @@ -838,7 +836,9 @@ class TestCausalConsistency(AsyncUnitTest): @classmethod async def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( + event_listeners=[cls.listener] + ) @classmethod async def _tearDown_class(cls): @@ -1153,7 +1153,7 @@ async def asyncSetUp(self): async def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], heartbeatFrequencyMS=999999 ) self.addAsyncCleanup(client.close) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 8fa1e70d01..3a05290e10 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -27,8 +27,6 @@ from test.asynchronous.utils_spec_runner import AsyncSpecRunner from test.utils import ( OvertCommandListener, - async_rs_client, - async_single_client, wait_until, ) from typing import List @@ -69,13 +67,7 @@ async def _setup_class(cls): await super()._setup_class() if async_client_context.supports_transactions(): for address in async_client_context.mongoses: - cls.mongos_clients.append(await async_single_client("{}:{}".format(*address))) - - @classmethod - async def _tearDown_class(cls): - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() + cls.mongos_clients.append(await cls.async_single_client("{}:{}".format(*address))) def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) @@ -120,7 +112,7 @@ def test_transaction_options_validation(self): @async_client_context.require_transactions async def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" - client = await async_rs_client(w=0) + client = await self.async_rs_client(w=0) self.addAsyncCleanup(client.close) db = client.test coll = db.test @@ -178,7 +170,9 @@ async def test_transaction_write_concern_override(self): async def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000) + client = await self.async_rs_client( + async_client_context.mongos_seeds(), localThresholdMS=1000 + ) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. @@ -206,7 +200,9 @@ async def test_unpin_for_next_transaction(self): async def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000) + client = await self.async_rs_client( + async_client_context.mongos_seeds(), localThresholdMS=1000 + ) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. @@ -335,7 +331,7 @@ async def test_transaction_starts_with_batched_write(self): # Start a transaction with a batch of operations that needs to be # split. listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test await coll.delete_many({}) listener.reset() @@ -364,7 +360,7 @@ async def test_transaction_starts_with_batched_write(self): @async_client_context.require_transactions async def test_transaction_direct_connection(self): - client = await async_single_client() + client = await self.async_single_client() self.addAsyncCleanup(client.close) coll = client.pymongo_test.test @@ -454,7 +450,7 @@ async def callback2(session): @async_client_context.require_transactions async def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) + client = await self.async_rs_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) coll = client[self.db.name].test @@ -483,7 +479,7 @@ async def callback(session): @async_client_context.require_transactions async def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) + client = await self.async_rs_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) coll = client[self.db.name].test @@ -518,7 +514,7 @@ async def callback(session): @async_client_context.require_transactions async def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) + client = await self.async_rs_client(event_listeners=[listener]) self.addAsyncCleanup(client.close) coll = client[self.db.name].test diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 71044d1530..12cb13c2cd 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -25,7 +25,6 @@ EventListener, OvertCommandListener, ServerAndTopologyEventListener, - async_rs_client, camel_to_snake, camel_to_snake_args, parse_spec_options, @@ -101,6 +100,8 @@ async def _setup_class(cls): @classmethod async def _tearDown_class(cls): cls.knobs.disable() + for client in cls.mongos_clients: + await client.close() await super()._tearDown_class() def setUp(self): @@ -527,7 +528,7 @@ async def run_scenario(self, scenario_def, test): host = async_client_context.MULTI_MONGOS_LB_URI elif async_client_context.is_mongos: host = async_client_context.mongos_seeds() - client = await async_rs_client( + client = await self.async_rs_client( h=host, event_listeners=[listener, pool_listener, server_listener], **client_options ) self.scenario_client = client diff --git a/test/test_bulk.py b/test/test_bulk.py index 9069109cfa..751600804e 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -25,8 +25,6 @@ from test import IntegrationTest, client_context, remove_all_users, unittest from test.utils import ( - rs_or_single_client_noauth, - single_client, wait_until, ) @@ -37,7 +35,6 @@ from pymongo.errors import ( BulkWriteError, ConfigurationError, - InvalidOperation, OperationFailure, ) from pymongo.operations import * @@ -913,7 +910,7 @@ class TestBulkAuthorization(BulkAuthorizationTestBase): def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth( + cli = self.rs_or_single_client_noauth( username="readonly", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -924,7 +921,7 @@ def test_readonly(self): def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth( + cli = self.rs_or_single_client_noauth( username="noremove", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -952,7 +949,9 @@ def _setup_class(cls): if cls.w is not None and cls.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = single_client(*partition_node(member)) + cls.secondary = TestBulkWriteConcern.unmanaged_single_client( + *partition_node(member) + ) break @classmethod diff --git a/test/test_client.py b/test/test_client.py index fcfd8926f2..464341d6e4 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -67,10 +67,6 @@ is_greenthread_patched, lazy_client_trial, one, - rs_client, - rs_or_single_client, - rs_or_single_client_noauth, - single_client, wait_until, ) @@ -131,7 +127,9 @@ class ClientUnitTest(UnitTest): @classmethod def _setup_class(cls): - cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + cls.client = ClientUnitTest.unmanaged_rs_or_single_client( + connect=False, serverSelectionTimeoutMS=100 + ) @classmethod def _tearDown_class(cls): @@ -261,7 +259,7 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) def test_get_default_database(self): - with rs_or_single_client( + with self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, ) as c: @@ -279,7 +277,7 @@ def test_get_default_database(self): self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) - with rs_or_single_client( + with self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) as c: @@ -287,7 +285,7 @@ def test_get_default_database(self): def test_get_default_database_error(self): # URI with no database. - with rs_or_single_client( + with self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) as c: @@ -299,11 +297,11 @@ def test_get_default_database_with_authsource(self): client_context.host, client_context.port, ) - with rs_or_single_client(uri, connect=False) as c: + with self.rs_or_single_client(uri, connect=False) as c: self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): - with rs_or_single_client( + with self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, ) as c: @@ -311,7 +309,7 @@ def test_get_database_default(self): def test_get_database_default_error(self): # URI with no database. - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) @@ -324,7 +322,7 @@ def test_get_database_default_with_authsource(self): client_context.host, client_context.port, ) - c = rs_or_single_client(uri, connect=False) + c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_database()) c.close() @@ -337,7 +335,7 @@ def test_primary_read_pref_with_tags(self): MongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") def test_read_preference(self): - with rs_or_single_client( + with self.rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) as c: self.assertEqual(c.read_preference, ReadPreference.NEAREST) @@ -684,7 +682,7 @@ def test_multiple_uris(self): def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = rs_or_single_client() + client = self.rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -695,7 +693,7 @@ def test_max_idle_time_reaper_default(self): def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -709,7 +707,7 @@ def test_max_idle_time_reaper_removes_stale_minPoolSize(self): def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) + client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -723,7 +721,7 @@ def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it - client = rs_or_single_client(maxIdleTimeMS=500) + client = self.rs_or_single_client(maxIdleTimeMS=500) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn_one: pass @@ -740,13 +738,13 @@ def test_max_idle_time_reaper_removes_stale(self): def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize - client = rs_or_single_client(minPoolSize=10) + client = self.rs_or_single_client(minPoolSize=10) self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) wait_until( @@ -766,7 +764,7 @@ def test_min_pool_size(self): def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = rs_or_single_client(maxIdleTimeMS=500) + client = self.rs_or_single_client(maxIdleTimeMS=500) self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: @@ -781,7 +779,7 @@ def test_max_idle_time_checkout(self): self.assertTrue(new_con in server._pool.conns) # Test that connections are reused if maxIdleTimeMS is not set. - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: @@ -823,27 +821,27 @@ def test_constants(self): def test_init_disconnected(self): host, port = client_context.host, client_context.port - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.addCleanup(c.close) # is_primary causes client to block until connected self.assertIsInstance(c.is_primary, bool) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.addCleanup(c.close) self.assertIsInstance(c.is_mongos, bool) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.addCleanup(c.close) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.addCleanup(c.close) self.assertEqual(c.codec_options, CodecOptions()) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.addCleanup(c.close) self.assertFalse(c.primary) self.assertFalse(c.secondaries) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.addCleanup(c.close) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) @@ -870,12 +868,12 @@ def test_init_disconnected_with_auth(self): def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - with rs_or_single_client(seed, connect=False) as c: + with self.rs_or_single_client(seed, connect=False) as c: self.assertEqual(client_context.client, c) # Explicitly test inequality self.assertFalse(client_context.client != c) - with rs_or_single_client("invalid.com", connect=False) as c: + with self.rs_or_single_client("invalid.com", connect=False) as c: self.assertNotEqual(client_context.client, c) self.assertTrue(client_context.client != c) @@ -897,10 +895,10 @@ def test_equality(self): def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = rs_or_single_client(seed, connect=False) + c = self.rs_or_single_client(seed, connect=False) self.addCleanup(c.close) self.assertIn(c, {client_context.client}) - c = rs_or_single_client("invalid.com", connect=False) + c = self.rs_or_single_client("invalid.com", connect=False) self.addCleanup(c.close) self.assertNotIn(c, {client_context.client}) @@ -974,7 +972,7 @@ def test_list_databases(self): for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = rs_or_single_client(document_class=SON) + client = self.rs_or_single_client(document_class=SON) self.addCleanup(client.close) for doc in client.list_databases(): self.assertIs(type(doc), dict) @@ -1014,7 +1012,7 @@ def test_drop_database(self): self.client.drop_database("pymongo_test") if client_context.is_rs: - wc_client = rs_or_single_client(w=len(client_context.nodes) + 1) + wc_client = self.rs_or_single_client(w=len(client_context.nodes) + 1) with self.assertRaises(WriteConcernError): wc_client.drop_database("pymongo_test2") @@ -1024,7 +1022,7 @@ def test_drop_database(self): self.assertNotIn("pymongo_test2", dbs) def test_close(self): - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() coll = test_client.pymongo_test.bar test_client.close() with self.assertRaises(InvalidOperation): @@ -1034,7 +1032,7 @@ def test_close_kills_cursors(self): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() test_client._process_periodic_tasks() @@ -1061,14 +1059,14 @@ def test_close_kills_cursors(self): self.assertTrue(test_client._topology._opened) test_client.close() self.assertFalse(test_client._topology._opened) - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() # The killCursors task should not need to re-open the topology. test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) test_client.close() def test_close_stops_kill_cursors_thread(self): - client = rs_client() + client = self.rs_client() client.test.test.find_one() self.assertFalse(client._kill_cursors_executor._stopped) @@ -1084,7 +1082,7 @@ def test_close_stops_kill_cursors_thread(self): def test_uri_connect_option(self): # Ensure that topology is not opened if connect=False. - client = rs_client(connect=False) + client = self.rs_client(connect=False) self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. @@ -1101,14 +1099,14 @@ def test_uri_connect_option(self): client.close() def test_close_does_not_open_servers(self): - client = rs_client(connect=False) + client = self.rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) client.close() self.assertEqual(topology._servers, {}) def test_close_closes_sockets(self): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) client.test.test.find_one() topology = client._topology @@ -1135,30 +1133,30 @@ def test_auth_from_uri(self): client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"]) with self.assertRaises(OperationFailure): - connected(rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) + connected(self.rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) # No error. - connected(rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) + connected(self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) with self.assertRaises(OperationFailure): - connected(rs_or_single_client_noauth(uri)) + connected(self.rs_or_single_client_noauth(uri)) # No error. connected( - rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) + self.rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) ) # Auth with lazy connection. ( - rs_or_single_client_noauth( + self.rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = rs_or_single_client_noauth( + bad_client = self.rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) @@ -1170,7 +1168,7 @@ def test_username_and_password(self): client_context.create_user("admin", "ad min", "pa/ss") self.addCleanup(client_context.drop_user, "admin", "ad min") - c = rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = self.rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. self.assertNotIn("ad min", repr(c)) @@ -1182,13 +1180,13 @@ def test_username_and_password(self): c.server_info() with self.assertRaises(OperationFailure): - (rs_or_single_client_noauth(username="ad min", password="foo")).server_info() + (self.rs_or_single_client_noauth(username="ad min", password="foo")).server_info() @client_context.require_auth @client_context.require_no_fips def test_lazy_auth_raises_operation_failure(self): host = client_context.host - lazy_client = rs_or_single_client_noauth( + lazy_client = self.rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) @@ -1206,7 +1204,7 @@ def test_unix_socket(self): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = rs_or_single_client(uri) + client = self.rs_or_single_client(uri) self.addCleanup(client.close) client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = client.list_database_names() @@ -1230,7 +1228,7 @@ def test_document_class(self): self.assertTrue(isinstance(db.test.find_one(), dict)) self.assertFalse(isinstance(db.test.find_one(), SON)) - c = rs_or_single_client(document_class=SON) + c = self.rs_or_single_client(document_class=SON) self.addCleanup(c.close) db = c.pymongo_test @@ -1238,7 +1236,7 @@ def test_document_class(self): self.assertTrue(isinstance(db.test.find_one(), SON)) def test_timeouts(self): - client = rs_or_single_client( + client = self.rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, @@ -1252,34 +1250,34 @@ def test_timeouts(self): self.assertEqual(10.5, client.options.server_selection_timeout) def test_socket_timeout_ms_validation(self): - c = rs_or_single_client(socketTimeoutMS=10 * 1000) + c = self.rs_or_single_client(socketTimeoutMS=10 * 1000) self.addCleanup(c.close) self.assertEqual(10, (get_pool(c)).opts.socket_timeout) - c = connected(rs_or_single_client(socketTimeoutMS=None)) + c = connected(self.rs_or_single_client(socketTimeoutMS=None)) self.addCleanup(c.close) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) - c = connected(rs_or_single_client(socketTimeoutMS=0)) + c = connected(self.rs_or_single_client(socketTimeoutMS=0)) self.addCleanup(c.close) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - with rs_or_single_client(socketTimeoutMS=-1): + with self.rs_or_single_client(socketTimeoutMS=-1): pass with self.assertRaises(ValueError): - with rs_or_single_client(socketTimeoutMS=1e10): + with self.rs_or_single_client(socketTimeoutMS=1e10): pass with self.assertRaises(ValueError): - with rs_or_single_client(socketTimeoutMS="foo"): + with self.rs_or_single_client(socketTimeoutMS="foo"): pass def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 - timeout = rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) + timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) self.addCleanup(timeout.close) no_timeout.pymongo_test.drop_collection("test") @@ -1330,7 +1328,7 @@ def test_server_selection_timeout(self): self.assertAlmostEqual(30, client.options.server_selection_timeout) def test_waitQueueTimeoutMS(self): - client = rs_or_single_client(waitQueueTimeoutMS=2000) + client = self.rs_or_single_client(waitQueueTimeoutMS=2000) self.addCleanup(client.close) self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) @@ -1344,7 +1342,7 @@ def test_socketKeepAlive(self): def test_tz_aware(self): self.assertRaises(ValueError, MongoClient, tz_aware="foo") - aware = rs_or_single_client(tz_aware=True) + aware = self.rs_or_single_client(tz_aware=True) self.addCleanup(aware.close) naive = self.client aware.pymongo_test.drop_collection("test") @@ -1374,7 +1372,7 @@ def test_ipv6(self): if client_context.is_rs: uri += "/?replicaSet=" + (client_context.replica_set_name or "") - client = rs_or_single_client_noauth(uri) + client = self.rs_or_single_client_noauth(uri) self.addCleanup(client.close) client.pymongo_test.test.insert_one({"dummy": "object"}) client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1384,7 +1382,7 @@ def test_ipv6(self): self.assertTrue("pymongo_test_bernie" in dbs) def test_contextlib(self): - client = rs_or_single_client() + client = self.rs_or_single_client() client.pymongo_test.drop_collection("test") client.pymongo_test.test.insert_one({"foo": "bar"}) @@ -1398,7 +1396,7 @@ def test_contextlib(self): self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): client.pymongo_test.test.find_one() - client = rs_or_single_client() + client = self.rs_or_single_client() with client as client: self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): @@ -1476,7 +1474,7 @@ def test_operation_failure(self): # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. - client = single_client() + client = self.single_client() self.addCleanup(client.close) client.pymongo_test.test.find_one() pool = get_pool(client) @@ -1501,7 +1499,7 @@ def test_lazy_connect_w0(self): client_context.client.drop_database("test_lazy_connect_w0") self.addCleanup(client_context.client.drop_database, "test_lazy_connect_w0") - client = rs_or_single_client(connect=False, w=0) + client = self.rs_or_single_client(connect=False, w=0) self.addCleanup(client.close) client.test_lazy_connect_w0.test.insert_one({}) @@ -1510,7 +1508,7 @@ def predicate(): wait_until(predicate, "find one document") - client = rs_or_single_client(connect=False, w=0) + client = self.rs_or_single_client(connect=False, w=0) self.addCleanup(client.close) client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) @@ -1519,7 +1517,7 @@ def predicate(): wait_until(predicate, "update one document") - client = rs_or_single_client(connect=False, w=0) + client = self.rs_or_single_client(connect=False, w=0) self.addCleanup(client.close) client.test_lazy_connect_w0.test.delete_one({}) @@ -1532,7 +1530,7 @@ def predicate(): def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1, retryReads=False) + client = self.rs_or_single_client(maxPoolSize=1, retryReads=False) self.addCleanup(client.close) collection = client.pymongo_test.test pool = get_pool(client) @@ -1559,7 +1557,9 @@ def test_auth_network_error(self): # when authenticating a new socket with cached credentials. # Get a client with one socket so we detect if it's leaked. - c = connected(rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False)) + c = connected( + self.rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) + ) # Cause a network error on the actual socket. pool = get_pool(c) @@ -1576,7 +1576,7 @@ def test_auth_network_error(self): @client_context.require_no_replica_set def test_connect_to_standalone_using_replica_set_name(self): - with single_client(replicaSet="anything", serverSelectionTimeoutMS=100) as client: + with self.single_client(replicaSet="anything", serverSelectionTimeoutMS=100) as client: with self.assertRaises(AutoReconnect): client.test.test.find_one() @@ -1586,7 +1586,7 @@ def test_stale_getmore(self): # the topology before the getMore message is sent. Test that # MongoClient._run_operation_with_response handles the error. with self.assertRaises(AutoReconnect): - client = rs_client(connect=False, serverSelectionTimeoutMS=100) + client = self.rs_client(connect=False, serverSelectionTimeoutMS=100) client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1634,7 +1634,7 @@ def init(self, *args): client_context.host, client_context.port, ) - client = single_client(uri, event_listeners=[listener]) + client = self.single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1735,12 +1735,12 @@ def compression_settings(client): options = client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - with single_client(zlibcompressionlevel=level) as client: + with self.single_client(zlibcompressionlevel=level) as client: # No error client.pymongo_test.test.find_one() def test_reset_during_update_pool(self): - client = rs_or_single_client(minPoolSize=10) + client = self.rs_or_single_client(minPoolSize=10) self.addCleanup(client.close) client.admin.command("ping") pool = get_pool(client) @@ -1787,7 +1787,7 @@ def run(self): def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 - client = rs_or_single_client( + client = self.rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) self.addCleanup(client.close) @@ -1821,14 +1821,14 @@ def stall_connect(*args, **kwargs): @client_context.require_replica_set def test_direct_connection(self): # direct_connection=True should result in Single topology. - client = rs_or_single_client(directConnection=True) + client = self.rs_or_single_client(directConnection=True) client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) client.close() # direct_connection=False should result in RS topology. - client = rs_or_single_client(directConnection=False) + client = self.rs_or_single_client(directConnection=False) client.admin.command("ping") self.assertGreaterEqual(len(client.nodes), 1) self.assertIn( @@ -1872,7 +1872,7 @@ def server_description_count(): @client_context.require_failCommand_fail_point def test_network_error_message(self): - client = single_client(retryReads=False) + client = self.single_client(retryReads=False) self.addCleanup(client.close) client.admin.command("ping") # connect with self.fail_point( @@ -1885,7 +1885,7 @@ def test_network_error_message(self): @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") def test_process_periodic_tasks(self): - client = rs_or_single_client() + client = self.rs_or_single_client() coll = client.db.collection coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) @@ -1982,7 +1982,7 @@ def _test_handshake(self, env_vars, expected_env): if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - with rs_or_single_client(serverSelectionTimeoutMS=10000) as client: + with self.rs_or_single_client(serverSelectionTimeoutMS=10000) as client: client.admin.command("ping") options = client.options self.assertEqual(options.pool_options.metadata, metadata) @@ -2081,7 +2081,7 @@ def setUp(self): def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(rs_or_single_client(maxPoolSize=1)) + client = connected(self.rs_or_single_client(maxPoolSize=1)) self.addCleanup(client.close) collection = client.pymongo_test.test @@ -2105,7 +2105,7 @@ def test_exhaust_query_server_error(self): def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) self.addCleanup(client.close) collection = client.pymongo_test.test collection.drop() @@ -2145,7 +2145,7 @@ def receive_message(request_id): def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(rs_or_single_client(maxPoolSize=1, retryReads=False)) + client = connected(self.rs_or_single_client(maxPoolSize=1, retryReads=False)) self.addCleanup(client.close) collection = client.pymongo_test.test pool = get_pool(client) @@ -2167,7 +2167,7 @@ def test_exhaust_query_network_error(self): def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) self.addCleanup(client.close) collection = client.pymongo_test.test collection.drop() @@ -2217,7 +2217,7 @@ def test_gevent_timeout(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2249,7 +2249,7 @@ def test_gevent_timeout_when_creating_connection(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) coll = client.pymongo_test.test pool = get_pool(client) @@ -2286,7 +2286,7 @@ class TestClientLazyConnect(IntegrationTest): """Test concurrent operations on a lazily-connecting MongoClient.""" def _get_client(self): - return rs_or_single_client(connect=False) + return self.rs_or_single_client(connect=False) @client_context.require_sync def test_insert_one(self): diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index ee19a04176..ebc74ef98e 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -27,7 +27,6 @@ ) from test.utils import ( OvertCommandListener, - rs_or_single_client, ) from unittest.mock import patch @@ -35,10 +34,8 @@ from pymongo.errors import ( ClientBulkWriteException, DocumentTooLarge, - InvalidOperation, NetworkTimeout, ) -from pymongo.monitoring import * from pymongo.operations import * from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.write_concern import WriteConcern @@ -97,7 +94,7 @@ def setUp(self): @client_context.require_no_serverless def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) models = [] @@ -123,7 +120,7 @@ def test_batch_splits_if_num_operations_too_large(self): @client_context.require_no_serverless def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) models = [] @@ -157,7 +154,7 @@ def test_batch_splits_if_ops_payload_too_large(self): @client_context.require_failCommand_fail_point def test_collects_write_concern_errors_across_batches(self): listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener], retryWrites=False, ) @@ -200,7 +197,7 @@ def test_collects_write_concern_errors_across_batches(self): @client_context.require_no_serverless def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.db["coll"] @@ -231,7 +228,7 @@ def test_collects_write_errors_across_batches_unordered(self): @client_context.require_no_serverless def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.db["coll"] @@ -262,7 +259,7 @@ def test_collects_write_errors_across_batches_ordered(self): @client_context.require_no_serverless def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.db["coll"] @@ -304,7 +301,7 @@ def test_handles_cursor_requiring_getMore(self): @client_context.require_no_standalone def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.db["coll"] @@ -348,7 +345,7 @@ def test_handles_cursor_requiring_getMore_within_transaction(self): @client_context.require_failCommand_fail_point def test_handles_getMore_error(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.db["coll"] @@ -403,7 +400,7 @@ def test_handles_getMore_error(self): @client_context.require_no_serverless def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) b_repeated = "b" * self.max_bson_object_size @@ -460,7 +457,7 @@ def _setup_namespace_test_models(self): @client_context.require_no_serverless def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) num_models, models = self._setup_namespace_test_models() @@ -492,7 +489,7 @@ def test_no_batch_splits_if_new_namespace_is_not_too_large(self): @client_context.require_no_serverless def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) num_models, models = self._setup_namespace_test_models() @@ -530,7 +527,7 @@ def test_batch_splits_if_new_namespace_is_too_large(self): @client_context.require_version_min(8, 0, 0, -24) @client_context.require_no_serverless def test_returns_error_if_no_writes_can_be_added_to_ops(self): - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) # Document too large. @@ -554,7 +551,7 @@ def test_returns_error_if_auto_encryption_configured(self): key_vault_namespace="db.coll", kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) - client = rs_or_single_client(auto_encryption_opts=opts) + client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) models = [InsertOne(namespace="db.coll", document={"a": "b"})] @@ -580,7 +577,7 @@ def setUp(self): def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 - internal_client = rs_or_single_client(timeoutMS=None) + internal_client = self.rs_or_single_client(timeoutMS=None) self.addCleanup(internal_client.close) collection = internal_client.db["coll"] @@ -605,7 +602,7 @@ def test_timeout_in_multi_batch_bulk_write(self): ) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener], readConcernLevel="majority", readPreference="primary", diff --git a/test/test_collection.py b/test/test_collection.py index b68aa74f73..dab59cf1b2 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -29,6 +29,7 @@ from test import ( # TODO: fix sync imports in PYTHON-4528 IntegrationTest, + UnitTest, client_context, unittest, ) @@ -37,8 +38,6 @@ EventListener, get_pool, is_mongos, - rs_or_single_client, - single_client, wait_until, ) @@ -81,14 +80,20 @@ _IS_SYNC = True -class TestCollectionNoConnect(unittest.TestCase): +class TestCollectionNoConnect(UnitTest): """Test Collection features on a client that does not connect.""" db: Database + client: MongoClient @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test + def _setup_class(cls): + cls.client = MongoClient(connect=False) + cls.db = cls.client.pymongo_test + + @classmethod + def _tearDown_class(cls): + cls.client.close() def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -1800,8 +1805,7 @@ def test_exhaust(self): # Insert enough documents to require more than one batch self.db.test.insert_many([{"i": i} for i in range(150)]) - client = rs_or_single_client(maxPoolSize=1) - self.addCleanup(client.close) + client = self.rs_or_single_client(maxPoolSize=1) pool = get_pool(client) # Make sure the socket is returned after exhaustion. @@ -2077,7 +2081,7 @@ def test_find_one_and(self): def test_find_one_and_write_concern(self): listener = EventListener() - db = (single_client(event_listeners=[listener]))[self.db.name] + db = (self.single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. diff --git a/test/test_cursor.py b/test/test_cursor.py index 8e6fade1ec..520229902b 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -35,7 +35,6 @@ EventListener, OvertCommandListener, ignore_deprecations, - rs_or_single_client, wait_until, ) @@ -230,7 +229,7 @@ def test_max_await_time_ms(self): self.assertEqual(90, cursor._max_await_time_ms) listener = AllowListEventListener("find", "getMore") - coll = (rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test + coll = (self.rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test # Tailable_defaults. coll.find(cursor_type=CursorType.TAILABLE_AWAIT).to_list() @@ -345,7 +344,7 @@ def test_explain(self): def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(coll.find().explain()) @@ -1252,7 +1251,7 @@ def test_close_kills_cursor_synchronously(self): self.client._process_periodic_tasks() listener = AllowListEventListener("killCursors") - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test_close_kills_cursors @@ -1291,7 +1290,7 @@ def assertCursorKilled(): @client_context.require_failCommand_appName def test_timeout_kills_cursor_synchronously(self): listener = AllowListEventListener("killCursors") - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test_timeout_kills_cursor @@ -1349,7 +1348,7 @@ def test_delete_not_initialized(self): def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( @@ -1454,7 +1453,7 @@ def test_find_raw_transaction(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): batches = ( @@ -1484,7 +1483,7 @@ def test_find_raw_retryable_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): @@ -1505,7 +1504,7 @@ def test_find_raw_snapshot_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) @@ -1566,7 +1565,7 @@ def test_read_concern(self): def test_monitoring(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{"_id": i} for i in range(10)]) @@ -1632,7 +1631,7 @@ def test_aggregate_raw_transaction(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): batches = ( @@ -1663,7 +1662,7 @@ def test_aggregate_raw_retryable_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}} ): @@ -1687,7 +1686,7 @@ def test_aggregate_raw_snapshot_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) @@ -1733,7 +1732,7 @@ def test_collation(self): def test_monitoring(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{"_id": i} for i in range(10)]) @@ -1777,7 +1776,7 @@ def test_monitoring(self): @client_context.require_no_mongos def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) c = client.pymongo_test.test c.delete_many({}) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index c30c62b1b1..abaa820cb7 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -27,7 +27,6 @@ from test import client_context, unittest from test.test_client import IntegrationTest -from test.utils import rs_client from bson import ( _BUILT_IN_TYPES, @@ -971,7 +970,7 @@ def create_targets(self, *args, **kwargs): if codec_options: kwargs["type_registry"] = codec_options.type_registry kwargs["document_class"] = codec_options.document_class - self.watched_target = rs_client(*args, **kwargs) + self.watched_target = self.rs_client(*args, **kwargs) self.addCleanup(self.watched_target.close) self.input_target = self.watched_target[self.db.name].test # Insert a record to ensure db, coll are created. diff --git a/test/test_database.py b/test/test_database.py index 12d4eb666a..144c357c52 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -28,7 +28,6 @@ from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, - rs_or_single_client, wait_until, ) @@ -207,7 +206,7 @@ def test_list_collection_names(self): def test_list_collection_names_filter(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] db.capped.drop() db.create_collection("capped", capped=True, size=4096) @@ -234,7 +233,7 @@ def test_list_collection_names_filter(self): def test_check_exists(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) db = client[self.db.name] db.drop_collection("unique") @@ -323,7 +322,7 @@ def test_list_collections(self): self.client.drop_database("pymongo_test") def test_list_collection_names_single_socket(self): - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) client.drop_database("test_collection_names_single_socket") db = client.test_collection_names_single_socket for i in range(200): diff --git a/test/test_encryption.py b/test/test_encryption.py index 568ebffc9e..75fe9a4e34 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -61,7 +61,6 @@ TopologyEventListener, camel_to_snake_args, is_greenthread_patched, - rs_or_single_client, wait_until, ) from test.utils_spec_runner import SpecRunner @@ -260,7 +259,7 @@ def bson_data(*paths): class TestClientSimple(EncryptionIntegrationTest): def _test_auto_encrypt(self, opts): - client = rs_or_single_client(auto_encryption_opts=opts) + client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) # Create the encrypted field's data key. @@ -342,7 +341,7 @@ def test_auto_encrypt_local_schema_map(self): def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) + client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) client.admin.command("ping") @@ -360,7 +359,7 @@ def test_use_after_close(self): ) def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) + client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) def target(): @@ -375,7 +374,7 @@ def target(): class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest): def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) + client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) @@ -416,7 +415,7 @@ def _setup_class(cls): @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) + client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): @@ -430,7 +429,7 @@ def test_raise_max_wire_version_error(self): def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) + client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client.close) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): @@ -807,7 +806,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.client.db.coll.drop() cls.vault = create_key_vault(cls.client.keyvault.datakeys) @@ -829,7 +828,7 @@ def _setup_class(cls): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = rs_or_single_client( + cls.client_encrypted = PyMongoTestCase.unmanaged_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) cls.client_encryption = ClientEncryption( @@ -919,7 +918,7 @@ def _test_external_key_vault(self, with_external_key_vault): # Configure the encrypted field via the local schema_map option. schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: - key_vault_client = rs_or_single_client(username="fake-user", password="fake-pwd") + key_vault_client = self.rs_or_single_client(username="fake-user", password="fake-pwd") self.addCleanup(key_vault_client.close) else: key_vault_client = client_context.client @@ -930,7 +929,7 @@ def _test_external_key_vault(self, with_external_key_vault): key_vault_client=key_vault_client, ) - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) self.addCleanup(client_encrypted.close) @@ -984,7 +983,7 @@ def test_views_are_prohibited(self): self.addCleanup(self.client.db.view.drop) opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys") - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) self.addCleanup(client_encrypted.close) @@ -1044,7 +1043,7 @@ def _test_corpus(self, opts): ) self.addCleanup(vault.drop) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client_encrypted.close) client_encryption = ClientEncryption( @@ -1197,7 +1196,7 @@ def _setup_class(cls): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = rs_or_single_client( + cls.client_encrypted = PyMongoTestCase.unmanaged_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll @@ -1517,7 +1516,7 @@ def _test_automatic(self, expectation_extjson, payload): ) insert_listener = AllowListEventListener("insert") - client = rs_or_single_client( + client = self.rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) self.addCleanup(client.close) @@ -1596,13 +1595,13 @@ def test_automatic(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests class TestDeadlockProse(EncryptionIntegrationTest): def setUp(self): - self.client_test = rs_or_single_client( + self.client_test = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) self.addCleanup(self.client_test.close) self.client_keyvault_listener = OvertCommandListener() - self.client_keyvault = rs_or_single_client( + self.client_keyvault = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", @@ -1635,7 +1634,7 @@ def setUp(self): self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") def _run_test(self, max_pool_size, auto_encryption_opts): - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( readConcernLevel="majority", w="majority", maxPoolSize=max_pool_size, @@ -1845,7 +1844,7 @@ def setUp(self): key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map ) self.listener = AllowListEventListener("aggregate") - self.encrypted_client = rs_or_single_client( + self.encrypted_client = self.rs_or_single_client( auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] ) self.addCleanup(self.encrypted_client.close) @@ -1925,7 +1924,7 @@ def reset_timeout(): "--port=27027", ], ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client_encrypted.close) with self.assertRaisesRegex(EncryptionError, "Timeout"): client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1940,11 +1939,12 @@ def test_bypassAutoEncryption(self): "--port=27027", ], ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client_encrypted.close) client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: mongocryptd_client = MongoClient("mongodb://localhost:27027/?serverSelectionTimeoutMS=500") + self.addCleanup(mongocryptd_client.close) with self.assertRaises(ServerSelectionTimeoutError): mongocryptd_client.admin.command("ping") @@ -1966,7 +1966,7 @@ def test_via_loading_shared_library(self): ], crypt_shared_lib_required=True, ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client_encrypted.close) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -2008,7 +2008,7 @@ def listener(): mongocryptd_uri="mongodb://localhost:47021", crypt_shared_lib_required=False, ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(client_encrypted.close) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -2320,7 +2320,7 @@ def setUp(self): key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) self.addCleanup(self.encrypted_client.close) def test_01_insert_encrypted_indexed_and_find(self): @@ -2464,7 +2464,7 @@ def run_test(self, src_provider, dst_provider): ) # Step 5. Create a ``ClientEncryption`` object named ``client_encryption2`` - client2 = rs_or_single_client() + client2 = self.rs_or_single_client() self.addCleanup(client2.close) client_encryption2 = ClientEncryption( key_vault_client=client2, @@ -2539,7 +2539,7 @@ def test_queryable_encryption(self): # MongoClient to use in testing that handles auth/tls/etc, # and cleanup. def MongoClient(**kwargs): - c = rs_or_single_client(**kwargs) + c = self.rs_or_single_client(**kwargs) self.addCleanup(c.close) return c @@ -2641,7 +2641,7 @@ def setUp(self): key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db self.addCleanup(self.encrypted_client.close) diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 0e806eb5cb..5b0daf8d7f 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -28,7 +28,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, qcheck, unittest -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from bson.objectid import ObjectId from gridfs import GridFS @@ -789,6 +789,7 @@ def test_grid_out_lazy_connect(self): def test_grid_in_lazy_connect(self): client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + self.addCleanup(client.close) fs = client.db.fs infile = GridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): @@ -799,7 +800,7 @@ def test_grid_in_lazy_connect(self): def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - GridIn((rs_or_single_client(w=0)).pymongo_test.fs) + GridIn((self.rs_or_single_client(w=0)).pymongo_test.fs) def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. @@ -807,7 +808,7 @@ def test_survive_cursor_not_found(self): chunk_size = 1024 data = b"d" * (102 * chunk_size) listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test with GridIn(db.fs, chunk_size=chunk_size) as infile: infile.write(data) diff --git a/test/test_logger.py b/test/test_logger.py index d6c30b68a8..71ada17c42 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -15,7 +15,6 @@ import os from test import IntegrationTest, unittest -from test.utils import single_client from unittest.mock import patch from bson import json_util @@ -85,7 +84,7 @@ def test_truncation_multi_byte_codepoints(self): self.assertEqual(last_3_bytes, str_to_repeat) def test_logging_without_listeners(self): - c = single_client() + c = self.single_client() self.assertEqual(len(c._event_listeners.event_listeners()), 0) with self.assertLogs("pymongo.connection", level="DEBUG") as cm: c.db.test.insert_one({"x": "1"}) diff --git a/test/test_session.py b/test/test_session.py index 563b33c70e..6988ef8667 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -36,7 +36,6 @@ from test.utils import ( EventListener, ExceptionCatchingThread, - rs_or_single_client, wait_until, ) @@ -88,7 +87,7 @@ def _setup_class(cls): super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = rs_or_single_client() + cls.client2 = PyMongoTestCase.unmanaged_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -103,7 +102,7 @@ def _tearDown_class(cls): def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() - self.client = rs_or_single_client( + self.client = self.rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) self.addCleanup(self.client.close) @@ -200,7 +199,7 @@ def test_implicit_sessions_checkout(self): failures = 0 for _ in range(5): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -283,7 +282,7 @@ def test_end_session(self): def test_end_sessions(self): # Use a new client so that the tearDown hook does not error. listener = SessionTestListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: @@ -787,7 +786,7 @@ def _test_unacknowledged_ops(self, client, *ops): def test_unacknowledged_writes(self): # Ensure the collection exists. self.client.pymongo_test.test_unacked_writes.insert_one({}) - client = rs_or_single_client(w=0, event_listeners=[self.listener]) + client = self.rs_or_single_client(w=0, event_listeners=[self.listener]) self.addCleanup(client.close) db = client.pymongo_test coll = db.test_unacked_writes @@ -836,7 +835,7 @@ class TestCausalConsistency(UnitTest): @classmethod def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) @classmethod def _tearDown_class(cls): @@ -1137,7 +1136,7 @@ def setUp(self): def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. - client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) + client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). diff --git a/test/test_transactions.py b/test/test_transactions.py index b1869bec79..069f2fb29b 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -26,8 +26,6 @@ from test import client_context, unittest from test.utils import ( OvertCommandListener, - rs_client, - single_client, wait_until, ) from test.utils_spec_runner import SpecRunner @@ -69,13 +67,7 @@ def _setup_class(cls): super()._setup_class() if client_context.supports_transactions(): for address in client_context.mongoses: - cls.mongos_clients.append(single_client("{}:{}".format(*address))) - - @classmethod - def _tearDown_class(cls): - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() + cls.mongos_clients.append(cls.single_client("{}:{}".format(*address))) def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) @@ -120,7 +112,7 @@ def test_transaction_options_validation(self): @client_context.require_transactions def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" - client = rs_client(w=0) + client = self.rs_client(w=0) self.addCleanup(client.close) db = client.test coll = db.test @@ -174,7 +166,7 @@ def test_transaction_write_concern_override(self): def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) + client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. @@ -202,7 +194,7 @@ def test_unpin_for_next_transaction(self): def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) + client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. @@ -331,7 +323,7 @@ def test_transaction_starts_with_batched_write(self): # Start a transaction with a batch of operations that needs to be # split. listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test coll.delete_many({}) listener.reset() @@ -360,7 +352,7 @@ def test_transaction_starts_with_batched_write(self): @client_context.require_transactions def test_transaction_direct_connection(self): - client = single_client() + client = self.single_client() self.addCleanup(client.close) coll = client.pymongo_test.test @@ -450,7 +442,7 @@ def callback2(session): @client_context.require_transactions def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test @@ -479,7 +471,7 @@ def callback(session): @client_context.require_transactions def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test @@ -512,7 +504,7 @@ def callback(session): @client_context.require_transactions def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) self.addCleanup(client.close) coll = client[self.db.name].test diff --git a/test/unified_format.py b/test/unified_format.py index e4ebf677e2..7140b83e1f 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -55,8 +55,6 @@ parse_collection_options, parse_spec_options, prepare_spec_arguments, - rs_or_single_client, - single_client, snake_to_camel, wait_until, ) @@ -574,7 +572,7 @@ def _create_entity(self, entity_spec, uri=None): ) if uri: kwargs["h"] = uri - client = rs_or_single_client(**kwargs) + client = self.test.rs_or_single_client(**kwargs) self[spec["id"]] = client self.test.addCleanup(client.close) return @@ -1108,7 +1106,7 @@ def setUpClass(cls): and not client_context.serverless ): for address in client_context.mongoses: - cls.mongos_clients.append(single_client("{}:{}".format(*address))) + cls.mongos_clients.append(cls.single_client("{}:{}".format(*address))) # add any special-casing for skipping tests here if client_context.storage_engine == "mmapv1": @@ -1647,7 +1645,7 @@ def _testOperation_targetedFailPoint(self, spec): ) ) - client = single_client("{}:{}".format(*session._pinned_address)) + client = self.single_client("{}:{}".format(*session._pinned_address)) self.addCleanup(client.close) self.__set_fail_point(client=client, command_args=spec["failPoint"]) diff --git a/test/utils.py b/test/utils.py index fa198b1c64..6eefd1c7ea 100644 --- a/test/utils.py +++ b/test/utils.py @@ -565,151 +565,6 @@ def create_tests(self): setattr(self._test_class, new_test.__name__, new_test) -def _connection_string(h): - if h.startswith(("mongodb://", "mongodb+srv://")): - return h - return f"mongodb://{h!s}" - - -def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs): - """Create a new client over SSL/TLS if necessary.""" - host = host or client_context.host - port = port or client_context.port - client_options: dict = client_context.default_client_options.copy() - if client_context.replica_set_name and not directConnection: - client_options["replicaSet"] = client_context.replica_set_name - if directConnection is not None: - client_options["directConnection"] = directConnection - client_options.update(kwargs) - - uri = _connection_string(host) - auth_mech = kwargs.get("authMechanism", "") - if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": - # Only add the default username or password if one is not provided. - res = parse_uri(uri) - if ( - not res["username"] - and not res["password"] - and "username" not in client_options - and "password" not in client_options - ): - client_options["username"] = db_user - client_options["password"] = db_pwd - return MongoClient(uri, port, **client_options) - - -async def _async_mongo_client(host, port, authenticate=True, directConnection=None, **kwargs): - """Create a new client over SSL/TLS if necessary.""" - host = host or await async_client_context.host - port = port or await async_client_context.port - client_options: dict = async_client_context.default_client_options.copy() - if async_client_context.replica_set_name and not directConnection: - client_options["replicaSet"] = async_client_context.replica_set_name - if directConnection is not None: - client_options["directConnection"] = directConnection - client_options.update(kwargs) - - uri = _connection_string(host) - auth_mech = kwargs.get("authMechanism", "") - if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": - # Only add the default username or password if one is not provided. - res = parse_uri(uri) - if ( - not res["username"] - and not res["password"] - and "username" not in client_options - and "password" not in client_options - ): - client_options["username"] = db_user - client_options["password"] = db_pwd - client = AsyncMongoClient(uri, port, **client_options) - if client._options.connect: - await client.aconnect() - return client - - -def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Make a direct connection. Don't authenticate.""" - return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) - - -def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Make a direct connection, and authenticate if necessary.""" - return _mongo_client(h, p, directConnection=True, **kwargs) - - -def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Connect to the replica set. Don't authenticate.""" - return _mongo_client(h, p, authenticate=False, **kwargs) - - -def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Connect to the replica set and authenticate if necessary.""" - return _mongo_client(h, p, **kwargs) - - -def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: - """Connect to the replica set if there is one, otherwise the standalone. - - Like rs_or_single_client, but does not authenticate. - """ - return _mongo_client(h, p, authenticate=False, **kwargs) - - -def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]: - """Connect to the replica set if there is one, otherwise the standalone. - - Authenticates if necessary. - """ - return _mongo_client(h, p, **kwargs) - - -async def async_single_client_noauth( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Make a direct connection. Don't authenticate.""" - return await _async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs) - - -async def async_single_client( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Make a direct connection, and authenticate if necessary.""" - return await _async_mongo_client(h, p, directConnection=True, **kwargs) - - -async def async_rs_client_noauth( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Connect to the replica set. Don't authenticate.""" - return await _async_mongo_client(h, p, authenticate=False, **kwargs) - - -async def async_rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient[dict]: - """Connect to the replica set and authenticate if necessary.""" - return await _async_mongo_client(h, p, **kwargs) - - -async def async_rs_or_single_client_noauth( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[dict]: - """Connect to the replica set if there is one, otherwise the standalone. - - Like rs_or_single_client, but does not authenticate. - """ - return await _async_mongo_client(h, p, authenticate=False, **kwargs) - - -async def async_rs_or_single_client( - h: Any = None, p: Any = None, **kwargs: Any -) -> AsyncMongoClient[Any]: - """Connect to the replica set if there is one, otherwise the standalone. - - Authenticates if necessary. - """ - return await _async_mongo_client(h, p, **kwargs) - - def ensure_all_connected(client: MongoClient) -> None: """Ensure that the client's connection pool has socket connections to all members of a replica set. Raises ConfigurationError when called with a @@ -1108,20 +963,6 @@ def is_greenthread_patched(): return gevent_monkey_patched() or eventlet_monkey_patched() -def disable_replication(client): - """Disable replication on all secondaries.""" - for host, port in client.secondaries: - secondary = single_client(host, port) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") - - -def enable_replication(client): - """Enable replication on all secondaries.""" - for host, port in client.secondaries: - secondary = single_client(host, port) - secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") - - class ExceptionCatchingThread(threading.Thread): """A thread that stores any exception encountered from run().""" diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py index 0b882a8bc3..06a40351cd 100644 --- a/test/utils_spec_runner.py +++ b/test/utils_spec_runner.py @@ -29,7 +29,6 @@ camel_to_snake_args, parse_spec_options, prepare_spec_arguments, - rs_client, ) from typing import List @@ -101,6 +100,8 @@ def _setup_class(cls): @classmethod def _tearDown_class(cls): cls.knobs.disable() + for client in cls.mongos_clients: + client.close() super()._tearDown_class() def setUp(self): @@ -524,7 +525,7 @@ def run_scenario(self, scenario_def, test): host = client_context.MULTI_MONGOS_LB_URI elif client_context.is_mongos: host = client_context.mongos_seeds() - client = rs_client( + client = self.rs_client( h=host, event_listeners=[listener, pool_listener, server_listener], **client_options ) self.scenario_client = client From 4ddb75e17e48a86e0b15fa853c48eac70f224595 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 11 Sep 2024 17:32:22 -0400 Subject: [PATCH 07/29] WIP all --- test/__init__.py | 12 +-- test/asynchronous/__init__.py | 18 ++--- test/asynchronous/test_bulk.py | 4 +- test/asynchronous/test_client.py | 2 +- test/asynchronous/test_encryption.py | 8 +- test/asynchronous/test_session.py | 6 +- test/test_auth.py | 74 ++++++++++--------- test/test_bulk.py | 4 +- test/test_change_stream.py | 11 ++- test/test_client.py | 4 +- test/test_collation.py | 4 +- test/test_comment.py | 8 +- test/test_common.py | 19 ++--- test/test_connection_monitoring.py | 19 ++--- ...nnections_survive_primary_stepdown_spec.py | 3 +- test/test_data_lake.py | 8 +- test/test_discovery_and_monitoring.py | 12 ++- test/test_encryption.py | 6 +- test/test_examples.py | 8 +- test/test_gridfs.py | 10 +-- test/test_gridfs_bucket.py | 10 +-- test/test_heartbeat_monitoring.py | 4 +- test/test_load_balancer.py | 8 +- test/test_max_staleness.py | 11 ++- test/test_monitor.py | 18 ++--- test/test_monitoring.py | 10 ++- test/test_pooling.py | 18 ++--- test/test_read_concern.py | 6 +- test/test_read_preferences.py | 46 ++++++------ test/test_read_write_concern_spec.py | 19 ++--- test/test_retryable_reads.py | 10 +-- test/test_retryable_writes.py | 25 ++++--- test/test_sdam_monitoring_spec.py | 3 +- test/test_server_selection.py | 7 +- test/test_server_selection_in_window.py | 3 +- test/test_session.py | 4 +- test/test_streaming_protocol.py | 10 +-- test/test_typing.py | 5 +- test/test_versioned_api.py | 6 +- 39 files changed, 220 insertions(+), 243 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 5dc2b132fd..355f06426d 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1001,7 +1001,7 @@ def unmanaged_single_client_noauth( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> MongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return PyMongoTestCase._unmanaged_async_mongo_client( + return cls._unmanaged_async_mongo_client( h, p, authenticate=False, directConnection=True, **kwargs ) @@ -1010,33 +1010,33 @@ def unmanaged_single_client( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> MongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return PyMongoTestCase._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs) + return cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs) @classmethod def unmanaged_rs_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]: """Connect to the replica set and authenticate if necessary.""" - return PyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + return cls._unmanaged_async_mongo_client(h, p, **kwargs) @classmethod def unmanaged_rs_client_noauth( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> MongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return PyMongoTestCase._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) @classmethod def unmanaged_rs_or_single_client_noauth( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> MongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return PyMongoTestCase._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) + return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) @classmethod def unmanaged_rs_or_single_client( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> MongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return PyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + return cls._unmanaged_async_mongo_client(h, p, **kwargs) def single_client_noauth( self, h: Any = None, p: Any = None, **kwargs: Any diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 3da10c5ec6..a3b7e18507 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1005,7 +1005,7 @@ async def unmanaged_async_single_client_noauth( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> AsyncMongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( + return await cls._unmanaged_async_mongo_client( h, p, authenticate=False, directConnection=True, **kwargs ) @@ -1014,41 +1014,35 @@ async def unmanaged_async_single_client( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> AsyncMongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( - h, p, directConnection=True, **kwargs - ) + return await cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs) @classmethod async def unmanaged_async_rs_client( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> AsyncMongoClient[dict]: """Connect to the replica set and authenticate if necessary.""" - return await AsyncPyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + return await cls._unmanaged_async_mongo_client(h, p, **kwargs) @classmethod async def unmanaged_async_rs_client_noauth( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> AsyncMongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( - h, p, authenticate=False, **kwargs - ) + return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) @classmethod async def unmanaged_async_rs_or_single_client_noauth( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> AsyncMongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return await AsyncPyMongoTestCase._unmanaged_async_mongo_client( - h, p, authenticate=False, **kwargs - ) + return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs) @classmethod async def unmanaged_async_rs_or_single_client( cls, h: Any = None, p: Any = None, **kwargs: Any ) -> AsyncMongoClient[dict]: """Make a direct connection. Don't authenticate.""" - return await AsyncPyMongoTestCase._unmanaged_async_mongo_client(h, p, **kwargs) + return await cls._unmanaged_async_mongo_client(h, p, **kwargs) async def async_single_client_noauth( self, h: Any = None, p: Any = None, **kwargs: Any diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index a90c237890..09d27f5e66 100644 --- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -951,9 +951,7 @@ async def _setup_class(cls): if cls.w is not None and cls.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = await AsyncTestBulkWriteConcern.unmanaged_async_single_client( - *partition_node(member) - ) + cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) break @classmethod diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index f3a354fee4..ad6efca837 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -129,7 +129,7 @@ class AsyncClientUnitTest(AsyncUnitTest): @classmethod async def _setup_class(cls): - cls.client = await AsyncClientUnitTest.unmanaged_async_rs_or_single_client( + cls.client = await cls.unmanaged_async_rs_or_single_client( connect=False, serverSelectionTimeoutMS=100 ) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index dea3571aad..7fda6890b5 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -810,9 +810,7 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = OvertCommandListener() - cls.client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener] - ) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) await cls.client.db.coll.drop() cls.vault = await create_key_vault(cls.client.keyvault.datakeys) @@ -834,7 +832,7 @@ async def _setup_class(cls): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( + cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) cls.client_encryption = AsyncClientEncryption( @@ -1204,7 +1202,7 @@ async def _setup_class(cls): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( + cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 073864ef45..046d091f3b 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -88,7 +88,7 @@ async def _setup_class(cls): await super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client() + cls.client2 = await cls.unmanaged_async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -836,9 +836,7 @@ class TestCausalConsistency(AsyncUnitTest): @classmethod async def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( - event_listeners=[cls.listener] - ) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) @classmethod async def _tearDown_class(cls): diff --git a/test/test_auth.py b/test/test_auth.py index 2ae0eae129..ec17a71bfe 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -27,10 +27,6 @@ AllowListEventListener, delay, ignore_deprecations, - rs_or_single_client, - rs_or_single_client_noauth, - single_client, - single_client_noauth, ) from pymongo import MongoClient, monitoring @@ -348,7 +344,7 @@ def tearDown(self): def test_scram_sha1(self): host, port = client_context.host, client_context.port - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port) ) client.pymongo_test.command("dbstats") @@ -359,7 +355,7 @@ def test_scram_sha1(self): "@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" "&replicaSet=%s" % (host, port, client_context.replica_set_name) ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) client.pymongo_test.command("dbstats") db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) db.command("dbstats") @@ -387,7 +383,7 @@ def test_scram_skip_empty_exchange(self): "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", event_listeners=[listener] ) client.testscram.command("dbstats") @@ -424,36 +420,38 @@ def test_scram(self): ) # Step 2: verify auth success cases - client = rs_or_single_client_noauth(username="sha1", password="pwd", authSource="testscram") + client = self.rs_or_single_client_noauth( + username="sha1", password="pwd", authSource="testscram" + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") # Step 2: SCRAM-SHA-1 and SCRAM-SHA-256 - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") self.listener.reset() - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", event_listeners=[self.listener] ) client.testscram.command("dbstats") @@ -466,19 +464,19 @@ def test_scram(self): self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256") # Step 3: verify auth failure conditions - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) with self.assertRaises(OperationFailure): client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) with self.assertRaises(OperationFailure): client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="not-a-user", password="pwd", authSource="testscram" ) with self.assertRaises(OperationFailure): @@ -491,7 +489,7 @@ def test_scram(self): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) client.testscram.command("dbstats") db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY) db.command("dbstats") @@ -511,12 +509,12 @@ def test_scram_saslprep(self): "testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram", @@ -524,17 +522,17 @@ def test_scram_saslprep(self): ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="IV", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram", @@ -542,25 +540,29 @@ def test_scram_saslprep(self): ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( "mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port) ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://\u2168:IV@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth( + "mongodb://\u2168:IV@%s:%d/testscram" % (host, port) + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth( + "mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port) + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port)) client.testscram.command("dbstats") def test_cache(self): - client = single_client() + client = self.single_client() credentials = client.options.pool_options._credentials cache = credentials.cache self.assertIsNotNone(cache) @@ -584,7 +586,7 @@ def test_scram_threaded(self): coll.insert_one({"_id": 1}) # The first thread to call find() will authenticate - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) coll = client.db.test threads = [] @@ -612,7 +614,7 @@ def tearDown(self): def test_uri_options(self): # Test default to admin host, port = client_context.host, client_context.port - client = rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + client = self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) self.assertTrue(client.admin.command("dbstats")) if client_context.is_rs: @@ -621,14 +623,14 @@ def test_uri_options(self): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) self.assertTrue(client.admin.command("dbstats")) db = client.get_database("admin", read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command("dbstats")) # Test explicit database uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) - client = rs_or_single_client_noauth(uri) + client = self.rs_or_single_client_noauth(uri) self.assertRaises(OperationFailure, client.admin.command, "dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -638,7 +640,7 @@ def test_uri_options(self): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) self.assertRaises(OperationFailure, client.admin.command, "dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) @@ -646,7 +648,7 @@ def test_uri_options(self): # Test authSource uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port) - client = rs_or_single_client_noauth(uri) + client = self.rs_or_single_client_noauth(uri) self.assertRaises(OperationFailure, client.pymongo_test2.command, "dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -655,7 +657,7 @@ def test_uri_options(self): "mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=" "%s;authSource=pymongo_test" % (host, port, client_context.replica_set_name) ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) self.assertRaises(OperationFailure, client.pymongo_test2.command, "dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) diff --git a/test/test_bulk.py b/test/test_bulk.py index 751600804e..d6b4bd26c3 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -949,9 +949,7 @@ def _setup_class(cls): if cls.w is not None and cls.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = TestBulkWriteConcern.unmanaged_single_client( - *partition_node(member) - ) + cls.secondary = cls.unmanaged_single_client(*partition_node(member)) break @classmethod diff --git a/test/test_change_stream.py b/test/test_change_stream.py index b71f5613d8..ee98069c4f 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -32,7 +32,6 @@ from test.utils import ( AllowListEventListener, EventListener, - rs_or_single_client, wait_until, ) @@ -65,7 +64,7 @@ def change_stream(self, *args, **kwargs): def client_with_listener(self, *commands): """Return a client with a AllowListEventListener.""" listener = AllowListEventListener(*commands) - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener @@ -166,7 +165,7 @@ def test_try_next(self): @no_type_check def test_try_next_runs_one_getmore(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command("ping") listener.reset() @@ -216,7 +215,7 @@ def test_try_next_runs_one_getmore(self): @no_type_check def test_batch_size_is_honored(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command("ping") listener.reset() @@ -453,7 +452,7 @@ class ProseSpecTestsMixin: @no_type_check def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener @@ -1090,7 +1089,7 @@ class TestAllLegacyScenarios(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) @classmethod def tearDownClass(cls): diff --git a/test/test_client.py b/test/test_client.py index 464341d6e4..c2fb3c0e35 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -127,9 +127,7 @@ class ClientUnitTest(UnitTest): @classmethod def _setup_class(cls): - cls.client = ClientUnitTest.unmanaged_rs_or_single_client( - connect=False, serverSelectionTimeoutMS=100 - ) + cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @classmethod def _tearDown_class(cls): diff --git a/test/test_collation.py b/test/test_collation.py index bedf0a2eaa..19df25c1c0 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from typing import Any from pymongo.collation import ( @@ -99,7 +99,7 @@ class TestCollation(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = EventListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test cls.collation = Collation("en_US") cls.warn_context = warnings.catch_warnings() diff --git a/test/test_comment.py b/test/test_comment.py index 931446ef3a..c0f037ea44 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from bson.dbref import DBRef from pymongo.operations import IndexModel @@ -109,7 +109,7 @@ def _test_ops( @client_context.require_replica_set def test_database_helpers(self): listener = EventListener() - db = rs_or_single_client(event_listeners=[listener]).db + db = self.rs_or_single_client(event_listeners=[listener]).db helpers = [ (db.watch, []), (db.command, ["hello"]), @@ -126,7 +126,7 @@ def test_database_helpers(self): @client_context.require_replica_set def test_client_helpers(self): listener = EventListener() - cli = rs_or_single_client(event_listeners=[listener]) + cli = self.rs_or_single_client(event_listeners=[listener]) helpers = [ (cli.watch, []), (cli.list_databases, []), @@ -141,7 +141,7 @@ def test_client_helpers(self): @client_context.require_version_min(4, 7, -1) def test_collection_helpers(self): listener = EventListener() - db = rs_or_single_client(event_listeners=[listener])[self.db.name] + db = self.rs_or_single_client(event_listeners=[listener])[self.db.name] coll = db.get_collection("test") helpers = [ diff --git a/test/test_common.py b/test/test_common.py index 358cd29b81..3228dc97fb 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -21,7 +21,6 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, connected, unittest -from test.utils import rs_or_single_client, single_client from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation from bson.codec_options import CodecOptions @@ -111,10 +110,10 @@ def test_uuid_representation(self): ) def test_write_concern(self): - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertEqual(WriteConcern(), c.write_concern) - c = rs_or_single_client(connect=False, w=2, wTimeoutMS=1000) + c = self.rs_or_single_client(connect=False, w=2, wTimeoutMS=1000) wc = WriteConcern(w=2, wtimeout=1000) self.assertEqual(wc, c.write_concern) @@ -134,7 +133,7 @@ def test_write_concern(self): def test_mongo_client(self): pair = client_context.pair - m = rs_or_single_client(w=0) + m = self.rs_or_single_client(w=0) coll = m.pymongo_test.write_concern_test coll.drop() doc = {"_id": ObjectId()} @@ -143,17 +142,19 @@ def test_mongo_client(self): coll = coll.with_options(write_concern=WriteConcern(w=1)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client() + m = self.rs_or_single_client() coll = m.pymongo_test.write_concern_test new_coll = coll.with_options(write_concern=WriteConcern(w=0)) self.assertTrue(new_coll.insert_one(doc)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client(f"mongodb://{pair}/", replicaSet=client_context.replica_set_name) + m = self.rs_or_single_client( + f"mongodb://{pair}/", replicaSet=client_context.replica_set_name + ) coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client( + m = self.rs_or_single_client( f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name ) @@ -161,8 +162,8 @@ def test_mongo_client(self): coll.insert_one(doc) # Equality tests - direct = connected(single_client(w=0)) - direct2 = connected(single_client(f"mongodb://{pair}/?w=0", **self.credentials)) + direct = connected(self.single_client(w=0)) + direct2 = connected(self.single_client(f"mongodb://{pair}/?w=0", **self.credentials)) self.assertEqual(direct, direct2) self.assertFalse(direct != direct2) diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 9ee3202e13..142af0f9a7 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -30,9 +30,6 @@ client_context, get_pool, get_pools, - rs_or_single_client, - single_client, - single_client_noauth, wait_until, ) from test.utils_spec_runner import SpecRunnerThread @@ -250,7 +247,7 @@ def run_scenario(self, scenario_def, test): else: kill_cursor_frequency = interval / 1000.0 with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): - client = single_client(**opts) + client = self.single_client(**opts) # Update the SD to a known type because the DummyMonitor will not. # Note we cannot simply call topology.on_change because that would # internally call pool.ready() which introduces unexpected @@ -323,13 +320,13 @@ def cleanup(): # Prose tests. Numbers correspond to the prose test number in the spec. # def test_1_client_connection_pool_options(self): - client = rs_or_single_client(**self.POOL_OPTIONS) + client = self.rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): - client = rs_or_single_client(**self.POOL_OPTIONS) + client = self.rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) client.admin.command("ping") # Discover at least one secondary. @@ -345,14 +342,14 @@ def test_2_all_client_pools_have_same_options(self): def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{client_context.pair}/?{opts}" - client = rs_or_single_client(uri) + client = self.rs_or_single_client(uri) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) @@ -376,7 +373,7 @@ def test_4_subscribe_to_events(self): def test_5_check_out_fails_connection_error(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) self.addCleanup(client.close) pool = get_pool(client) @@ -403,7 +400,7 @@ def mock_connect(*args, **kwargs): @client_context.require_no_fips def test_5_check_out_fails_auth_error(self): listener = CMAPListener() - client = single_client_noauth( + client = self.single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) self.addCleanup(client.close) @@ -449,7 +446,7 @@ def test_events_repr(self): def test_close_leaves_pool_unpaused(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) client.admin.command("ping") pool = get_pool(client) client.close() diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 674612693c..fba7675743 100644 --- a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -24,7 +24,6 @@ CMAPListener, ensure_all_connected, repl_set_step_down, - rs_or_single_client, ) from bson import SON @@ -43,7 +42,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = CMAPListener() - cls.client = rs_or_single_client( + cls.client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 ) diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 8ba83ab190..a374db550e 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -27,8 +27,6 @@ from test.unified_format import generate_test_classes from test.utils import ( OvertCommandListener, - rs_client_noauth, - rs_or_single_client, ) pytestmark = pytest.mark.data_lake @@ -65,7 +63,7 @@ def setUpClass(cls): # Test killCursors def test_1(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2) next(cursor) @@ -90,13 +88,13 @@ def test_1(self): # Test no auth def test_2(self): - client = rs_client_noauth() + client = self.rs_client_noauth() client.admin.command("ping") # Test with auth def test_3(self): for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]: - client = rs_or_single_client(authMechanism=mechanism) + client = self.rs_or_single_client(authMechanism=mechanism) client[self.TEST_DB][self.TEST_COLLECTION].find_one() diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ef32afbcd4..131a4c3531 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -32,9 +32,7 @@ assertion_context, client_context, get_pool, - rs_or_single_client, server_name_to_type, - single_client, wait_until, ) from unittest.mock import patch @@ -272,7 +270,7 @@ class TestIgnoreStaleErrors(IntegrationTest): def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) - client = rs_or_single_client(minPoolSize=N_THREADS) + client = self.rs_or_single_client(minPoolSize=N_THREADS) self.addCleanup(client.close) # Wait for initial discovery. @@ -319,7 +317,7 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = single_client( + client = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) self.addCleanup(client.close) @@ -353,7 +351,7 @@ def setUp(self): super().setUp() def test_rtt_connection_is_enabled_stream(self): - client = rs_or_single_client(serverMonitoringMode="stream") + client = self.rs_or_single_client(serverMonitoringMode="stream") self.addCleanup(client.close) client.admin.command("ping") @@ -373,7 +371,7 @@ def predicate(): wait_until(predicate, "find all RTT monitors") def test_rtt_connection_is_disabled_poll(self): - client = rs_or_single_client(serverMonitoringMode="poll") + client = self.rs_or_single_client(serverMonitoringMode="poll") self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) @@ -387,7 +385,7 @@ def test_rtt_connection_is_disabled_auto(self): ] for env in envs: with patch.dict("os.environ", env): - client = rs_or_single_client(serverMonitoringMode="auto") + client = self.rs_or_single_client(serverMonitoringMode="auto") self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) diff --git a/test/test_encryption.py b/test/test_encryption.py index 75fe9a4e34..21fcc78fb8 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -806,7 +806,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = OvertCommandListener() - cls.client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.client.db.coll.drop() cls.vault = create_key_vault(cls.client.keyvault.datakeys) @@ -828,7 +828,7 @@ def _setup_class(cls): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = PyMongoTestCase.unmanaged_rs_or_single_client( + cls.client_encrypted = cls.unmanaged_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) cls.client_encryption = ClientEncryption( @@ -1196,7 +1196,7 @@ def _setup_class(cls): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = PyMongoTestCase.unmanaged_rs_or_single_client( + cls.client_encrypted = cls.unmanaged_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll diff --git a/test/test_examples.py b/test/test_examples.py index 296283db28..ebf1d784a3 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import rs_client, wait_until +from test.utils import wait_until import pymongo from pymongo.errors import ConnectionFailure, OperationFailure @@ -1128,7 +1128,7 @@ def update_employee_info(session): self.assertEqual(employee["status"], "Inactive") def MongoClient(_): - return rs_client() + return self.rs_client() uriString = None @@ -1220,7 +1220,7 @@ class TestVersionedApiExamples(IntegrationTest): def test_versioned_api(self): # Versioned API examples def MongoClient(_, server_api): - return rs_client(server_api=server_api, connect=False) + return self.rs_client(server_api=server_api, connect=False) uri = None @@ -1251,7 +1251,7 @@ def test_versioned_api_migration(self): ): self.skipTest("This test needs MongoDB 5.0.2 or newer") - client = rs_client(server_api=ServerApi("1", strict=True)) + client = self.rs_client(server_api=ServerApi("1", strict=True)) client.db.sales.drop() # Start Versioned API Example 5 diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 19ec152bd1..c8ecf8b560 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one, rs_client, rs_or_single_client, single_client +from test.utils import joinall, one import gridfs from bson.binary import Binary @@ -492,7 +492,7 @@ def test_grid_in_non_int_chunksize(self): def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - gridfs.GridFS(rs_or_single_client(w=0).pymongo_test) + gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test) def test_md5(self): gin = self.fs.new_file() @@ -519,7 +519,7 @@ def tearDownClass(cls): client_context.client.drop_database("gfsreplica") def test_gridfs_replica_set(self): - rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) + rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) fs = gridfs.GridFS(rsc.gfsreplica, "gfsreplicatest") @@ -532,7 +532,7 @@ def test_gridfs_replica_set(self): def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) - secondary_connection = single_client( + secondary_connection = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY ) @@ -547,7 +547,7 @@ def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.client.secondaries) - client = single_client( + client = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False ) diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index c3945d1053..3ff63f1fa2 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -27,7 +27,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one, rs_client, rs_or_single_client, single_client +from test.utils import joinall, one import gridfs from bson.binary import Binary @@ -391,7 +391,7 @@ def test_grid_in_non_int_chunksize(self): def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - gridfs.GridFSBucket(rs_or_single_client(w=0).pymongo_test) + gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test) def test_rename(self): _id = self.fs.upload_from_stream("first_name", b"testing") @@ -489,7 +489,7 @@ def tearDownClass(cls): client_context.client.drop_database("gfsbucketreplica") def test_gridfs_replica_set(self): - rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) + rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") oid = gfs.upload_from_stream("test_filename", b"foo") @@ -498,7 +498,7 @@ def test_gridfs_replica_set(self): def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) - secondary_connection = single_client( + secondary_connection = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY ) @@ -513,7 +513,7 @@ def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.client.secondaries) - client = single_client( + client = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False ) diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 1302df8fde..5e203a33b3 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_knobs, unittest -from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until +from test.utils import HeartbeatEventListener, MockPool, wait_until from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat @@ -40,7 +40,7 @@ def _check_with_socket(self, *args, **kwargs): raise responses[1] return Hello(responses[1]), 99 - m = single_client( + m = self.single_client( h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool ) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index a4db7395f1..23bea4d984 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -26,7 +26,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ExceptionCatchingThread, get_pool, rs_client, wait_until +from test.utils import ExceptionCatchingThread, get_pool, wait_until pytestmark = pytest.mark.load_balancer @@ -54,7 +54,7 @@ def test_connections_are_only_returned_once(self): @client_context.require_load_balancer def test_unpin_committed_transaction(self): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test @@ -85,7 +85,7 @@ def create_resource(coll): self._test_no_gc_deadlock(create_resource) def _test_no_gc_deadlock(self, create_resource): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test @@ -124,7 +124,7 @@ def _test_no_gc_deadlock(self, create_resource): @client_context.require_transactions def test_session_gc(self): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) session = client.start_session() diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 1b0130f7d8..38dc499a9c 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -24,8 +24,7 @@ sys.path[0:0] = [""] -from test import client_context, unittest -from test.utils import rs_or_single_client +from test import PyMongoTestCase, client_context, unittest from test.utils_selection_tests import create_selection_tests from pymongo import MongoClient @@ -40,7 +39,7 @@ class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore pass -class TestMaxStaleness(unittest.TestCase): +class TestMaxStaleness(PyMongoTestCase): def test_max_staleness(self): client = MongoClient() self.assertEqual(-1, client.read_preference.max_staleness) @@ -81,7 +80,7 @@ def test_max_staleness(self): def test_max_staleness_float(self): with self.assertRaises(TypeError) as ctx: - rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") + self.rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") self.assertIn("must be an integer", str(ctx.exception)) @@ -96,7 +95,7 @@ def test_max_staleness_float(self): def test_max_staleness_zero(self): # Zero is too small. with self.assertRaises(ValueError) as ctx: - rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") + self.rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") self.assertIn("must be a positive integer", str(ctx.exception)) @@ -111,7 +110,7 @@ def test_max_staleness_zero(self): @client_context.require_replica_set def test_last_write_date(self): # From max-staleness-tests.rst, "Parse lastWriteDate". - client = rs_or_single_client(heartbeatFrequencyMS=500) + client = self.rs_or_single_client(heartbeatFrequencyMS=500) client.pymongo_test.test.insert_one({}) # Wait for the server description to be updated. time.sleep(1) diff --git a/test/test_monitor.py b/test/test_monitor.py index fd82fc1ca4..5fb9f3f267 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -25,7 +25,6 @@ from test import IntegrationTest, connected, unittest from test.utils import ( ServerAndTopologyEventListener, - single_client, wait_until, ) @@ -47,16 +46,15 @@ def get_executors(client): return [e for e in executors if e is not None] -def create_client(): - listener = ServerAndTopologyEventListener() - client = single_client(event_listeners=[listener]) - connected(client) - return client - - class TestMonitor(IntegrationTest): + def create_client(self): + listener = ServerAndTopologyEventListener() + client = self.single_client(event_listeners=[listener]) + connected(client) + return client + def test_cleanup_executors_on_client_del(self): - client = create_client() + client = self.create_client() executors = get_executors(client) self.assertEqual(len(executors), 4) @@ -70,7 +68,7 @@ def test_cleanup_executors_on_client_del(self): wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) def test_cleanup_executors_on_client_close(self): - client = create_client() + client = self.create_client() executors = get_executors(client) self.assertEqual(len(executors), 4) diff --git a/test/test_monitoring.py b/test/test_monitoring.py index ed6a3d0bc2..90c4957186 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, client_knobs, sanitize_cmd, unittest -from test.utils import EventListener, rs_or_single_client, single_client, wait_until +from test.utils import EventListener, wait_until from bson.int64 import Int64 from bson.objectid import ObjectId @@ -42,7 +42,9 @@ class TestCommandMonitoring(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = EventListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False) + cls.client = cls.unmanaged_rs_or_single_client( + event_listeners=[cls.listener], retryWrites=False + ) @classmethod def tearDownClass(cls): @@ -390,7 +392,7 @@ def test_get_more_failure(self): @client_context.require_secondaries_count(1) def test_not_primary_error(self): address = next(iter(client_context.client.secondaries)) - client = single_client(*address, event_listeners=[self.listener]) + client = self.single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. client.admin.command("ping") self.listener.reset() @@ -1125,7 +1127,7 @@ def setUpClass(cls): # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = single_client() + cls.client = cls.unmanaged_single_client() # Get one (authenticated) socket in the pool. cls.client.pymongo_test.command("ping") diff --git a/test/test_pooling.py b/test/test_pooling.py index 31259d7b3a..3b867965bd 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import delay, get_pool, joinall, rs_or_single_client +from test.utils import delay, get_pool, joinall from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions @@ -151,7 +151,7 @@ class _TestPoolingBase(IntegrationTest): def setUp(self): super().setUp() - self.c = rs_or_single_client() + self.c = self.rs_or_single_client() db = self.c[DB] db.unique.drop() db.test.drop() @@ -378,7 +378,7 @@ def test_checkout_more_than_max_pool_size(self): socket_info.close_conn(None) def test_maxConnecting(self): - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) self.client.test.test.insert_one({}) self.addCleanup(self.client.test.test.delete_many, {}) @@ -415,7 +415,7 @@ def find_one(): @client_context.require_failCommand_appName def test_csot_timeout_message(self): - client = rs_or_single_client(appName="connectionTimeoutApp") + client = self.rs_or_single_client(appName="connectionTimeoutApp") self.addCleanup(client.close) # Mock an operation failing due to pymongo.timeout(). mock_connection_timeout = { @@ -440,7 +440,7 @@ def test_csot_timeout_message(self): @client_context.require_failCommand_appName def test_socket_timeout_message(self): - client = rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") + client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") self.addCleanup(client.close) # Mock an operation failing due to socketTimeoutMS. mock_connection_timeout = { @@ -479,7 +479,7 @@ def test_connection_timeout_message(self): }, } - client = rs_or_single_client( + client = self.rs_or_single_client( connectTimeoutMS=500, socketTimeoutMS=500, appName="connectionTimeoutApp", @@ -502,7 +502,7 @@ def test_connection_timeout_message(self): class TestPoolMaxSize(_TestPoolingBase): def test_max_pool_size(self): max_pool_size = 4 - c = rs_or_single_client(maxPoolSize=max_pool_size) + c = self.rs_or_single_client(maxPoolSize=max_pool_size) self.addCleanup(c.close) collection = c[DB].test @@ -538,7 +538,7 @@ def f(): self.assertEqual(0, cx_pool.requests) def test_max_pool_size_none(self): - c = rs_or_single_client(maxPoolSize=None) + c = self.rs_or_single_client(maxPoolSize=None) self.addCleanup(c.close) collection = c[DB].test @@ -570,7 +570,7 @@ def f(): self.assertEqual(cx_pool.max_pool_size, float("inf")) def test_max_pool_size_zero(self): - c = rs_or_single_client(maxPoolSize=0) + c = self.rs_or_single_client(maxPoolSize=0) self.addCleanup(c.close) pool = get_pool(c) self.assertEqual(pool.max_pool_size, float("inf")) diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 97855872cf..ea9ce49a30 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context -from test.utils import OvertCommandListener, rs_or_single_client +from test.utils import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure @@ -36,7 +36,7 @@ class TestReadConcern(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test client_context.client.pymongo_test.create_collection("coll") @@ -67,7 +67,7 @@ def test_read_concern(self): def test_read_concern_uri(self): uri = f"mongodb://{client_context.pair}/?readConcernLevel=majority" - client = rs_or_single_client(uri, connect=False) + client = self.rs_or_single_client(uri, connect=False) self.assertEqual(ReadConcern("majority"), client.read_concern) def test_invalid_read_concern(self): diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 2cd3195f40..32883399e1 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -30,8 +30,6 @@ from test.utils import ( OvertCommandListener, one, - rs_client, - single_client, wait_until, ) from test.version import Version @@ -58,7 +56,7 @@ class TestSelections(IntegrationTest): @client_context.require_connection def test_bool(self): - client = single_client() + client = self.single_client() wait_until(lambda: client.address, "discover primary") selection = Selection.from_topology_description(client._topology.description) @@ -128,7 +126,7 @@ def read_from_which_kind(self, client): return None def assertReadsFrom(self, expected, **kwargs): - c = rs_client(**kwargs) + c = self.rs_client(**kwargs) wait_until(lambda: len(c.nodes - c.arbiters) == client_context.w, "discovered all nodes") used = self.read_from_which_kind(c) @@ -139,7 +137,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase): def test_reads_from_secondary(self): host, port = next(iter(self.client.secondaries)) # Direct connection to a secondary. - client = single_client(host, port) + client = self.single_client(host, port) self.assertFalse(client.is_primary) # Regardless of read preference, we should be able to do @@ -175,19 +173,21 @@ def test_mode_validation(self): ReadPreference.SECONDARY_PREFERRED, ReadPreference.NEAREST, ): - self.assertEqual(mode, rs_client(read_preference=mode).read_preference) + self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference) - self.assertRaises(TypeError, rs_client, read_preference="foo") + self.assertRaises(TypeError, self.rs_client, read_preference="foo") def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) - self.assertEqual([{}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}]) - self.assertEqual([{"k": "v"}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}, {}]) - self.assertEqual([{"k": "v"}, {}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual( + [{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets + ) self.assertRaises(ValueError, Secondary, tag_sets=[]) @@ -200,20 +200,22 @@ def test_tag_sets_validation(self): def test_threshold_validation(self): self.assertEqual( - 17, rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms + 17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms ) self.assertEqual( - 42, rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms + 42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms ) self.assertEqual( - 666, rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms + 666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms ) - self.assertEqual(0, rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms) + self.assertEqual( + 0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms + ) - self.assertRaises(ValueError, rs_client, localthresholdms=-1) + self.assertRaises(ValueError, self.rs_client, localthresholdms=-1) def test_zero_latency(self): ping_times: set = set() @@ -223,7 +225,7 @@ def test_zero_latency(self): for ping_time, host in zip(ping_times, self.client.nodes): ServerDescription._host_to_round_trip_time[host] = ping_time try: - client = connected(rs_client(readPreference="nearest", localThresholdMS=0)) + client = connected(self.rs_client(readPreference="nearest", localThresholdMS=0)) wait_until(lambda: client.nodes == self.client.nodes, "discovered all nodes") host = self.read_from_which_host(client) for _ in range(5): @@ -236,7 +238,7 @@ def test_primary(self): def test_primary_with_tags(self): # Tags not allowed with PRIMARY - self.assertRaises(ConfigurationError, rs_client, tag_sets=[{"dc": "ny"}]) + self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}]) def test_primary_preferred(self): self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) @@ -250,7 +252,9 @@ def test_secondary_preferred(self): def test_nearest(self): # With high localThresholdMS, expect to read from any # member - c = rs_client(read_preference=ReadPreference.NEAREST, localThresholdMS=10000) # 10 seconds + c = self.rs_client( + read_preference=ReadPreference.NEAREST, localThresholdMS=10000 + ) # 10 seconds data_members = {self.client.primary} | self.client.secondaries @@ -540,7 +544,7 @@ def test_send_hedge(self): if client_context.supports_secondary_read_pref: cases["secondary"] = Secondary listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) self.addCleanup(client.close) client.admin.command("ping") for _mode, cls in cases.items(): @@ -667,13 +671,13 @@ def test_mongos_max_staleness(self): else: self.fail("mongos accepted invalid staleness") - coll = single_client( + coll = self.single_client( readPreference="secondaryPreferred", maxStalenessSeconds=120 ).pymongo_test.test # No error coll.find_one() - coll = single_client( + coll = self.single_client( readPreference="secondaryPreferred", maxStalenessSeconds=10 ).pymongo_test.test try: diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 3e37e8f9a5..67943d495d 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -24,12 +24,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( - EventListener, - disable_replication, - enable_replication, - rs_or_single_client, -) +from test.utils import EventListener from pymongo import DESCENDING from pymongo.errors import ( @@ -51,7 +46,7 @@ class TestReadWriteConcernSpec(IntegrationTest): def test_omit_default_read_write_concern(self): listener = EventListener() # Client with default readConcern and writeConcern - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). @@ -104,7 +99,9 @@ def insert_command_default_write_concern(): def assertWriteOpsRaise(self, write_concern, expected_exception): wc = write_concern.document # Set socket timeout to avoid indefinite stalls - client = rs_or_single_client(w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000) + client = self.rs_or_single_client( + w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 + ) db = client.get_database("pymongo_test") coll = db.test @@ -167,9 +164,9 @@ def test_raise_write_concern_error(self): @client_context.require_test_commands def test_raise_wtimeout(self): self.addCleanup(client_context.client.drop_database, "pymongo_test") - self.addCleanup(enable_replication, client_context.client) + self.addCleanup(self.enable_replication, client_context.client) # Disable replication to guarantee a wtimeout error. - disable_replication(client_context.client) + self.disable_replication(client_context.client) self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError) @client_context.require_failCommand_fail_point @@ -209,7 +206,7 @@ def test_error_includes_errInfo(self): @client_context.require_version_min(4, 9) def test_write_error_details_exposes_errinfo(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) db = client.errinfotest self.addCleanup(client.drop_database, "errinfotest") diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index 9ea546ba9b..571384eb1d 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -37,8 +37,6 @@ EventListener, OvertCommandListener, SpecTestCreator, - rs_client, - rs_or_single_client, set_fail_point, ) from test.utils_spec_runner import SpecRunner @@ -174,7 +172,9 @@ def test_pool_paused_error_is_retryable(self): self.skipTest("Test is flakey on PyPy") cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) + client = self.rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() @@ -244,13 +244,13 @@ def test_retryable_reads_in_sharded_cluster_multiple_available(self): mongos_clients = [] for mongos in client_context.mongos_seeds().split(","): - client = rs_or_single_client(mongos) + client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( client_context.mongos_seeds(), appName="retryableReadTest", event_listeners=[listener], diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 45a740e844..e687372d79 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -30,7 +30,6 @@ EventListener, OvertCommandListener, SpecTestCreator, - rs_or_single_client, set_fail_point, ) from test.utils_spec_runner import SpecRunner @@ -189,7 +188,7 @@ def setUpClass(cls): # Speed up the tests by decreasing the heartbeat frequency. cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() - cls.client = rs_or_single_client(retryWrites=True) + cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) cls.db = cls.client.pymongo_test @classmethod @@ -225,7 +224,9 @@ def setUpClass(cls): cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(retryWrites=True, event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client( + retryWrites=True, event_listeners=[cls.listener] + ) cls.db = cls.client.pymongo_test @classmethod @@ -248,7 +249,7 @@ def tearDown(self): def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=False, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener]) self.addCleanup(client.close) for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" @@ -361,7 +362,7 @@ def test_retry_timeout_raises_original_error(self): original error. """ listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server @@ -487,13 +488,13 @@ def test_retryable_writes_in_sharded_cluster_multiple_available(self): mongos_clients = [] for mongos in client_context.mongos_seeds().split(","): - client = rs_or_single_client(mongos) + client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( client_context.mongos_seeds(), appName="retryableWriteTest", event_listeners=[listener], @@ -536,7 +537,7 @@ def setUpClass(cls): @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) def test_RetryableWriteError_error_label(self): listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) # Ensure collection exists. @@ -595,7 +596,9 @@ class TestPoolPausedError(IntegrationTest): def test_pool_paused_error_is_retryable(self): cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) + client = self.rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() @@ -657,7 +660,7 @@ def test_returns_original_error_code( self, ): cmd_listener = InsertEventListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener]) client.test.test.drop() self.addCleanup(client.close) cmd_listener.reset() @@ -694,7 +697,7 @@ def test_increment_transaction_id_without_sending_command(self): the first attempt fails before sending the command. """ listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 8e0a3cbbb4..81b208d511 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -25,7 +25,6 @@ from test import IntegrationTest, client_context, client_knobs, unittest from test.utils import ( ServerAndTopologyEventListener, - rs_or_single_client, server_name_to_type, wait_until, ) @@ -279,7 +278,7 @@ def setUpClass(cls): cls.knobs.enable() cls.listener = ServerAndTopologyEventListener() retry_writes = client_context.supports_transactions() - cls.test_client = rs_or_single_client( + cls.test_client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=retry_writes ) cls.coll = cls.test_client[cls.client.db.name].test diff --git a/test/test_server_selection.py b/test/test_server_selection.py index d3526617f6..67e9716bf4 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -33,7 +33,6 @@ from test.utils import ( EventListener, FunctionCallRecorder, - rs_or_single_client, wait_until, ) from test.utils_selection_tests import ( @@ -76,7 +75,9 @@ def custom_selector(servers): # Initialize client with appropriate listeners. listener = EventListener() - client = rs_or_single_client(server_selector=custom_selector, event_listeners=[listener]) + client = self.rs_or_single_client( + server_selector=custom_selector, event_listeners=[listener] + ) self.addCleanup(client.close) coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll self.addCleanup(client.drop_database, "testdb") @@ -117,7 +118,7 @@ def test_selector_called(self): selector = FunctionCallRecorder(lambda x: x) # Client setup. - mongo_client = rs_or_single_client(server_selector=selector) + mongo_client = self.rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection self.addCleanup(mongo_client.close) self.addCleanup(mongo_client.drop_database, "testdb") diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 9dced595c9..8e030f61e8 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -22,7 +22,6 @@ OvertCommandListener, SpecTestCreator, get_pool, - rs_client, wait_until, ) from test.utils_selection_tests import create_topology @@ -134,7 +133,7 @@ def test_load_balancing(self): listener = OvertCommandListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. - client = rs_client( + client = self.rs_client( client_context.mongos_seeds(), appName="loadBalancingTest", event_listeners=[listener], diff --git a/test/test_session.py b/test/test_session.py index 6988ef8667..d0f5c6e6d9 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -87,7 +87,7 @@ def _setup_class(cls): super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = PyMongoTestCase.unmanaged_rs_or_single_client() + cls.client2 = cls.unmanaged_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -835,7 +835,7 @@ class TestCausalConsistency(UnitTest): @classmethod def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) @classmethod def _tearDown_class(cls): diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 9bca899a48..b3b68703a4 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -24,8 +24,6 @@ from test.utils import ( HeartbeatEventListener, ServerEventListener, - rs_or_single_client, - single_client, wait_until, ) @@ -38,7 +36,7 @@ class TestStreamingProtocol(IntegrationTest): def test_failCommand_streaming(self): listener = ServerEventListener() hb_listener = HeartbeatEventListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName="failingHeartbeatTest", @@ -107,7 +105,7 @@ def test_streaming_rtt(self): }, } with self.fail_point(delay_hello): - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name ) self.addCleanup(client.close) @@ -155,7 +153,7 @@ def test_monitor_waits_after_server_check_error(self): } with self.fail_point(fail_hello): start = time.time() - client = single_client( + client = self.single_client( appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 ) self.addCleanup(client.close) @@ -180,7 +178,7 @@ def test_monitor_waits_after_server_check_error(self): @client_context.require_failCommand_appName def test_heartbeat_awaited_flag(self): hb_listener = HeartbeatEventListener() - client = single_client( + client = self.single_client( event_listeners=[hb_listener], heartbeatFrequencyMS=500, appName="heartbeatEventAwaitedFlag", diff --git a/test/test_typing.py b/test/test_typing.py index f423b70a3e..7eb6f80460 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -69,7 +69,6 @@ class ImplicitMovie(TypedDict): sys.path[0:0] = [""] from test import IntegrationTest, client_context -from test.utils import rs_or_single_client from bson import CodecOptions, decode, decode_all, decode_file_iter, decode_iter, encode from bson.raw_bson import RawBSONDocument @@ -194,7 +193,7 @@ def test_list_databases(self) -> None: value.items() def test_default_document_type(self) -> None: - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) coll = client.test.test doc = {"my": "doc"} @@ -480,7 +479,7 @@ def test_typeddict_empty_document_type(self) -> None: def test_typeddict_find_notrequired(self): if NotRequired is None or ImplicitMovie is None: raise unittest.SkipTest("Python 3.11+ is required to use NotRequired.") - client: MongoClient[ImplicitMovie] = rs_or_single_client() + client: MongoClient[ImplicitMovie] = self.rs_or_single_client() coll = client.test.test coll.insert_one(ImplicitMovie(name="THX-1138", year=1971)) out = coll.find_one({}) diff --git a/test/test_versioned_api.py b/test/test_versioned_api.py index 7fe8ebd76f..7a25a507dc 100644 --- a/test/test_versioned_api.py +++ b/test/test_versioned_api.py @@ -20,7 +20,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import OvertCommandListener, rs_or_single_client +from test.utils import OvertCommandListener from pymongo.server_api import ServerApi, ServerApiVersion from pymongo.synchronous.mongo_client import MongoClient @@ -77,7 +77,7 @@ def assertServerApiInAllCommands(self, events): @client_context.require_version_min(4, 7) def test_command_options(self): listener = OvertCommandListener() - client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) self.addCleanup(client.close) coll = client.test.test coll.insert_many([{} for _ in range(100)]) @@ -90,7 +90,7 @@ def test_command_options(self): @client_context.require_transactions def test_command_options_txn(self): listener = OvertCommandListener() - client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) + client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener]) self.addCleanup(client.close) coll = client.test.test coll.insert_many([{} for _ in range(100)]) From 96519b8a8789eea4c0d066eac9a28b05dceee766 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 12 Sep 2024 11:45:58 -0400 Subject: [PATCH 08/29] Reduce test_client diff --- test/__init__.py | 4 +- test/asynchronous/__init__.py | 4 +- test/asynchronous/test_client.py | 490 ++++++++++++++----------------- test/test_client.py | 480 ++++++++++++++---------------- 4 files changed, 449 insertions(+), 529 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 355f06426d..38d705ce46 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1072,8 +1072,8 @@ def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> Mo """ return self._async_mongo_client(h, p, **kwargs) - def simple_client(self, **kwargs: Any) -> MongoClient: - client = MongoClient(**kwargs) + def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient: + client = MongoClient(h, p, **kwargs) self.addCleanup(client.close) return client diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index a3b7e18507..e807720c05 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1088,8 +1088,8 @@ async def async_rs_or_single_client( """ return await self._async_mongo_client(h, p, **kwargs) - def simple_client(self, **kwargs: Any) -> AsyncMongoClient: - client = AsyncMongoClient(**kwargs) + def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient: + client = AsyncMongoClient(h, p, **kwargs) self.addAsyncCleanup(client.close) return client diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index b434bd8ac2..29864d2ca6 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -142,7 +142,7 @@ def inject_fixtures(self, caplog): self._caplog = caplog async def test_keyword_arg_defaults(self): - async with AsyncMongoClient( + client = self.simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -154,36 +154,36 @@ async def test_keyword_arg_defaults(self): tlsCAFile=None, connect=False, serverSelectionTimeoutMS=12000, - ) as client: - options = client.options - pool_opts = options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - # socket.Socket.settimeout takes a float in seconds - self.assertEqual(20.0, pool_opts.connect_timeout) - self.assertEqual(None, pool_opts.wait_queue_timeout) - self.assertEqual(None, pool_opts._ssl_context) - self.assertEqual(None, options.replica_set_name) - self.assertEqual(ReadPreference.PRIMARY, client.read_preference) - self.assertAlmostEqual(12, client.options.server_selection_timeout) + ) + + options = client.options + pool_opts = options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + # socket.Socket.settimeout takes a float in seconds + self.assertEqual(20.0, pool_opts.connect_timeout) + self.assertEqual(None, pool_opts.wait_queue_timeout) + self.assertEqual(None, pool_opts._ssl_context) + self.assertEqual(None, options.replica_set_name) + self.assertEqual(ReadPreference.PRIMARY, client.read_preference) + self.assertAlmostEqual(12, client.options.server_selection_timeout) async def test_connect_timeout(self): - client = AsyncMongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - await client.close() - client = AsyncMongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + + client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - await client.close() - client = AsyncMongoClient( + + client = await self.async_single_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - await client.close() def test_types(self): self.assertRaises(TypeError, AsyncMongoClient, 1) @@ -195,8 +195,7 @@ def test_types(self): self.assertRaises(ConfigurationError, AsyncMongoClient, []) async def test_max_pool_size_zero(self): - async with AsyncMongoClient(maxPoolSize=0): - pass + self.simple_client(maxPoolSize=0) def test_uri_detection(self): self.assertRaises(ConfigurationError, AsyncMongoClient, "/foo") @@ -261,38 +260,36 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) async def test_get_default_database(self): - async with await self.async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, - ) as c: - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) - # Test that default doesn't override the URI value. - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("bar")) - - codec_options = CodecOptions(tz_aware=True) - write_concern = WriteConcern(w=2, j=True) - db = c.get_default_database( - None, codec_options, ReadPreference.SECONDARY, write_concern - ) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) + ) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + # Test that default doesn't override the URI value. + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("bar")) - async with await self.async_rs_or_single_client( + codec_options = CodecOptions(tz_aware=True) + write_concern = WriteConcern(w=2, j=True) + db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) + self.assertEqual("foo", db.name) + self.assertEqual(codec_options, db.codec_options) + self.assertEqual(ReadPreference.SECONDARY, db.read_preference) + self.assertEqual(write_concern, db.write_concern) + + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, - ) as c: - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) + ) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database("foo")) async def test_get_default_database_error(self): # URI with no database. - async with await self.async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, - ) as c: - self.assertRaises(ConfigurationError, c.get_default_database) + ) + self.assertRaises(ConfigurationError, c.get_default_database) async def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -300,16 +297,16 @@ async def test_get_default_database_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - async with await self.async_rs_or_single_client(uri, connect=False) as c: - self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) + c = await self.async_rs_or_single_client(uri, connect=False) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) async def test_get_database_default(self): - async with await self.async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, - ) as c: - self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) + ) + self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) async def test_get_database_default_error(self): # URI with no database. @@ -318,7 +315,6 @@ async def test_get_database_default_error(self): connect=False, ) self.assertRaises(ConfigurationError, c.get_database) - await c.close() async def test_get_database_default_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -328,92 +324,88 @@ async def test_get_database_default_with_authsource(self): ) c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) - await c.close() - def test_primary_read_pref_with_tags(self): + async def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". with self.assertRaises(ConfigurationError): - AsyncMongoClient("mongodb://host/?readpreferencetags=dc:east") + await self.async_single_client("mongodb://host/?readpreferencetags=dc:east") with self.assertRaises(ConfigurationError): - AsyncMongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + await self.async_single_client( + "mongodb://host/?readpreference=primary&readpreferencetags=dc:east" + ) async def test_read_preference(self): - async with await self.async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode - ) as c: - self.assertEqual(c.read_preference, ReadPreference.NEAREST) + ) + self.assertEqual(c.read_preference, ReadPreference.NEAREST) async def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} - async with AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) - async with AsyncMongoClient("foo", 27017, appname="foobar", connect=False) as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = await self.async_single_client("mongodb://foo:27017/?appname=foobar&connect=false") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + client = await self.async_single_client("foo", 27017, appname="foobar", connect=False) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # No error - async with AsyncMongoClient(appname="x" * 128): - pass + self.simple_client(appname="x" * 128) with self.assertRaises(ValueError): - async with AsyncMongoClient(appname="x" * 129): - pass + self.simple_client(appname="x" * 129) # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) with self.assertRaises(TypeError): - async with AsyncMongoClient(driver=1): - pass + self.simple_client(driver=1) with self.assertRaises(TypeError): - async with AsyncMongoClient(driver="abc"): - pass + self.simple_client(driver="abc") with self.assertRaises(TypeError): - async with AsyncMongoClient(driver=("Foo", "1", "a")): - pass + self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - async with AsyncMongoClient( + client = await self.async_single_client( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", None), connect=False, - ) as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - async with AsyncMongoClient( + client = await self.async_single_client( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", "FooPlatform"), connect=False, - ) as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - async with AsyncMongoClient( + client = await self.async_single_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, - ) as client: - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) - async with AsyncMongoClient( + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + client = await self.async_single_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, - ) as client: - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) async def test_container_metadata(self): @@ -421,10 +413,9 @@ async def test_container_metadata(self): metadata["driver"]["name"] = "PyMongo|async" metadata["env"] = {} metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) - await client.close() async def test_kwargs_codec_options(self): class MyFloatType: @@ -448,7 +439,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - async with AsyncMongoClient( + c = self.simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -456,18 +447,16 @@ def transform_python(self, value): unicode_decode_error_handler=unicode_decode_error_handler, tzinfo=tzinfo, connect=False, - ) as c: - self.assertEqual(c.codec_options.document_class, document_class) - self.assertEqual(c.codec_options.type_registry, type_registry) - self.assertEqual(c.codec_options.tz_aware, tz_aware) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], - ) - self.assertEqual( - c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler - ) - self.assertEqual(c.codec_options.tzinfo, tzinfo) + ) + self.assertEqual(c.codec_options.document_class, document_class) + self.assertEqual(c.codec_options.type_registry, type_registry) + self.assertEqual(c.codec_options.tz_aware, tz_aware) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) + self.assertEqual(c.codec_options.tzinfo, tzinfo) async def test_uri_codec_options(self): # Ensure codec options are passed in correctly @@ -486,38 +475,36 @@ async def test_uri_codec_options(self): datetime_conversion, ) ) - async with AsyncMongoClient(uri, connect=False) as c: - self.assertEqual(c.codec_options.tz_aware, True) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], - ) - self.assertEqual( - c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler - ) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + c = self.simple_client(uri, connect=False) + self.assertEqual(c.codec_options.tz_aware, True) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - async with AsyncMongoClient(uri, connect=False) as c: - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + c = self.simple_client(uri, connect=False) + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) async def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - async with AsyncMongoClient( + c = self.simple_client( uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" - ) as c: - clopts = c.options - opts = clopts._options + ) + clopts = c.options + opts = clopts._options - self.assertEqual(opts["tls"], False) - self.assertEqual(clopts.replica_set_name, "newname") - self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(opts["tls"], False) + self.assertEqual(clopts.replica_set_name, "newname") + self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. @@ -540,9 +527,9 @@ def reset_resolver(): async def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - async with AsyncMongoClient(*args, **kwargs): - for _, kw in patched_resolver.call_list(): - self.assertAlmostEqual(kw["lifetime"], expected_value) + self.simple_client(*args, **kwargs) + for _, kw in patched_resolver.call_list(): + self.assertAlmostEqual(kw["lifetime"], expected_value) # No timeout specified. await test_scenario((base_uri,), {}, CONNECT_TIMEOUT) @@ -560,60 +547,53 @@ async def test_scenario(args, kwargs, expected_value): async def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - async with AsyncMongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False): - pass + self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - async with AsyncMongoClient( - "mongodb://localhost/?ssl=false", tls=False, connect=False - ) as c: - self.assertEqual(c.options._options["tls"], False) + c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) + self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - async with AsyncMongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, - ): - pass + ) # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): - async with AsyncMongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, - ): - pass + ) # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - async with AsyncMongoClient(ssl=True, tls=False): - pass + self.simple_client(ssl=True, tls=False) async def test_event_listeners(self): - async with AsyncMongoClient(event_listeners=[], connect=False) as c: - self.assertEqual(c.options.event_listeners, []) - listeners = [ - event_loggers.CommandLogger(), - event_loggers.HeartbeatLogger(), - event_loggers.ServerLogger(), - event_loggers.TopologyLogger(), - event_loggers.ConnectionPoolLogger(), - ] - async with AsyncMongoClient(event_listeners=listeners, connect=False) as c: - self.assertEqual(c.options.event_listeners, listeners) + c = self.simple_client(event_listeners=[], connect=False) + self.assertEqual(c.options.event_listeners, []) + listeners = [ + event_loggers.CommandLogger(), + event_loggers.HeartbeatLogger(), + event_loggers.ServerLogger(), + event_loggers.TopologyLogger(), + event_loggers.ConnectionPoolLogger(), + ] + c = self.simple_client(event_listeners=listeners, connect=False) + self.assertEqual(c.options.event_listeners, listeners) async def test_client_options(self): - c = AsyncMongoClient(connect=False) + c = self.simple_client(connect=False) self.assertIsInstance(c.options, ClientOptions) self.assertIsInstance(c.options.pool_options, PoolOptions) self.assertEqual(c.options.server_selection_timeout, 30) self.assertEqual(c.options.pool_options.max_idle_time_seconds, None) self.assertIsInstance(c.options.retry_writes, bool) self.assertIsInstance(c.options.retry_reads, bool) - await c.close() def test_validate_suggestion(self): """Validate kwargs in constructor.""" @@ -659,16 +639,13 @@ async def test_detected_environment_warning(self, mock_get_hosts): ) for host in normal_hosts: with self.assertWarns(UserWarning): - async with AsyncMongoClient(host): - pass + self.simple_client(host) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] with self.assertWarns(UserWarning): - async with AsyncMongoClient(host): - pass + self.simple_client(host) with self.assertWarns(UserWarning): - async with AsyncMongoClient(multi_host): - pass + self.simple_client(multi_host) class TestClient(AsyncIntegrationTest): @@ -693,7 +670,6 @@ async def test_max_idle_time_reaper_default(self): pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - await client.close() async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -709,7 +685,6 @@ async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - await client.close() async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -727,7 +702,6 @@ async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - await client.close() async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): @@ -747,7 +721,6 @@ async def test_max_idle_time_reaper_removes_stale(self): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - await client.close() async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): @@ -825,20 +798,20 @@ async def test_constants(self): AsyncMongoClient.HOST = "somedomainthatdoesntexist.org" AsyncMongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - async with AsyncMongoClient(serverSelectionTimeoutMS=10, **kwargs) as c: - await connected(c) - - async with AsyncMongoClient(host, port, **kwargs) as c: - # Override the defaults. No error. + c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) await connected(c) + c = self.simple_client(host, port, **kwargs) + # Override the defaults. No error. + await connected(c) + # Set good defaults. AsyncMongoClient.HOST = host AsyncMongoClient.PORT = port # No error. - async with AsyncMongoClient(**kwargs) as c: - await connected(c) + c = self.simple_client(**kwargs) + await connected(c) async def test_init_disconnected(self): host, port = await async_client_context.host, await async_client_context.port @@ -889,27 +862,23 @@ async def test_init_disconnected_with_auth(self): async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - async with await self.async_rs_or_single_client(seed, connect=False) as c: - self.assertEqual(async_client_context.client, c) - # Explicitly test inequality - self.assertFalse(async_client_context.client != c) + c = await self.async_rs_or_single_client(seed, connect=False) + self.assertEqual(async_client_context.client, c) + # Explicitly test inequality + self.assertFalse(async_client_context.client != c) - async with await self.async_rs_or_single_client("invalid.com", connect=False) as c: - self.assertNotEqual(async_client_context.client, c) - self.assertTrue(async_client_context.client != c) + c = await self.async_rs_or_single_client("invalid.com", connect=False) + self.assertNotEqual(async_client_context.client, c) + self.assertTrue(async_client_context.client != c) - c1 = AsyncMongoClient("a", connect=False) - c2 = AsyncMongoClient("b", connect=False) - self.addAsyncCleanup(c1.close) - self.addAsyncCleanup(c2.close) + c1 = self.simple_client("a", connect=False) + c2 = self.simple_client("b", connect=False) # Seeds differ: self.assertNotEqual(c1, c2) - c1 = AsyncMongoClient(["a", "b", "c"], connect=False) - c2 = AsyncMongoClient(["c", "a", "b"], connect=False) - self.addAsyncCleanup(c1.close) - self.addAsyncCleanup(c2.close) + c1 = self.simple_client(["a", "b", "c"], connect=False) + c2 = self.simple_client(["c", "a", "b"], connect=False) # Same seeds but out of order still compares equal: self.assertEqual(c1, c2) @@ -1084,7 +1053,6 @@ async def test_close_kills_cursors(self): # The killCursors task should not need to re-open the topology. await test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) - await test_client.close() async def test_close_stops_kill_cursors_thread(self): client = await self.async_rs_client() @@ -1116,9 +1084,6 @@ async def test_uri_connect_option(self): kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) - # Tear down. - await client.close() - async def test_close_does_not_open_servers(self): client = await self.async_rs_client(connect=False) topology = client._topology @@ -1245,10 +1210,10 @@ async def test_unix_socket(self): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - async with AsyncMongoClient( + c = self.simple_client( "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 - ) as c: - await connected(c) + ) + await connected(c) async def test_document_class(self): c = self.client @@ -1613,11 +1578,9 @@ async def test_auth_network_error(self): @async_client_context.require_no_replica_set async def test_connect_to_standalone_using_replica_set_name(self): - async with await self.async_single_client( - replicaSet="anything", serverSelectionTimeoutMS=100 - ) as client: - with self.assertRaises(AutoReconnect): - await client.test.test.find_one() + client = await self.async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) + with self.assertRaises(AutoReconnect): + await client.test.test.find_one() @async_client_context.require_replica_set async def test_stale_getmore(self): @@ -1673,7 +1636,7 @@ def init(self, *args): await async_client_context.host, await async_client_context.port, ) - client = await self.async_single_client(uri, event_listeners=[listener]) + await self.async_single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1682,7 +1645,6 @@ def init(self, *args): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - await client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1699,84 +1661,84 @@ def compression_settings(client): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, 4) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) # According to the connection string spec, unsupported values # just raise a warning and are ignored. uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy", "zlib"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - async with AsyncMongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd", "zlib"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = async_client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - async with await self.async_single_client(zlibcompressionlevel=level) as client: - # No error - await client.pymongo_test.test.find_one() + client = await self.async_single_client(zlibcompressionlevel=level) + # No error + await client.pymongo_test.test.find_one() async def test_reset_during_update_pool(self): client = await self.async_rs_or_single_client(minPoolSize=10) @@ -1864,7 +1826,6 @@ async def test_direct_connection(self): await client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - await client.close() # direct_connection=False should result in RS topology. client = await self.async_rs_or_single_client(directConnection=False) @@ -1874,7 +1835,6 @@ async def test_direct_connection(self): client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - await client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -2021,12 +1981,10 @@ async def _test_handshake(self, env_vars, expected_env): if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - async with await self.async_rs_or_single_client( - serverSelectionTimeoutMS=10000 - ) as client: - await client.admin.command("ping") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = await self.async_rs_or_single_client(serverSelectionTimeoutMS=10000) + await client.admin.command("ping") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) async def test_handshake_01_aws(self): await self._test_handshake( diff --git a/test/test_client.py b/test/test_client.py index 872018ee41..84e919ff6d 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -138,7 +138,7 @@ def inject_fixtures(self, caplog): self._caplog = caplog def test_keyword_arg_defaults(self): - with MongoClient( + client = self.simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -150,36 +150,36 @@ def test_keyword_arg_defaults(self): tlsCAFile=None, connect=False, serverSelectionTimeoutMS=12000, - ) as client: - options = client.options - pool_opts = options.pool_options - self.assertEqual(None, pool_opts.socket_timeout) - # socket.Socket.settimeout takes a float in seconds - self.assertEqual(20.0, pool_opts.connect_timeout) - self.assertEqual(None, pool_opts.wait_queue_timeout) - self.assertEqual(None, pool_opts._ssl_context) - self.assertEqual(None, options.replica_set_name) - self.assertEqual(ReadPreference.PRIMARY, client.read_preference) - self.assertAlmostEqual(12, client.options.server_selection_timeout) + ) + + options = client.options + pool_opts = options.pool_options + self.assertEqual(None, pool_opts.socket_timeout) + # socket.Socket.settimeout takes a float in seconds + self.assertEqual(20.0, pool_opts.connect_timeout) + self.assertEqual(None, pool_opts.wait_queue_timeout) + self.assertEqual(None, pool_opts._ssl_context) + self.assertEqual(None, options.replica_set_name) + self.assertEqual(ReadPreference.PRIMARY, client.read_preference) + self.assertAlmostEqual(12, client.options.server_selection_timeout) def test_connect_timeout(self): - client = MongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client.close() - client = MongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + + client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client.close() - client = MongoClient( + + client = self.single_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client.close() def test_types(self): self.assertRaises(TypeError, MongoClient, 1) @@ -191,8 +191,7 @@ def test_types(self): self.assertRaises(ConfigurationError, MongoClient, []) def test_max_pool_size_zero(self): - with MongoClient(maxPoolSize=0): - pass + self.simple_client(maxPoolSize=0) def test_uri_detection(self): self.assertRaises(ConfigurationError, MongoClient, "/foo") @@ -257,37 +256,35 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) def test_get_default_database(self): - with self.rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, - ) as c: - self.assertEqual(Database(c, "foo"), c.get_default_database()) - # Test that default doesn't override the URI value. - self.assertEqual(Database(c, "foo"), c.get_default_database("bar")) - - codec_options = CodecOptions(tz_aware=True) - write_concern = WriteConcern(w=2, j=True) - db = c.get_default_database( - None, codec_options, ReadPreference.SECONDARY, write_concern - ) - self.assertEqual("foo", db.name) - self.assertEqual(codec_options, db.codec_options) - self.assertEqual(ReadPreference.SECONDARY, db.read_preference) - self.assertEqual(write_concern, db.write_concern) + ) + self.assertEqual(Database(c, "foo"), c.get_default_database()) + # Test that default doesn't override the URI value. + self.assertEqual(Database(c, "foo"), c.get_default_database("bar")) - with self.rs_or_single_client( + codec_options = CodecOptions(tz_aware=True) + write_concern = WriteConcern(w=2, j=True) + db = c.get_default_database(None, codec_options, ReadPreference.SECONDARY, write_concern) + self.assertEqual("foo", db.name) + self.assertEqual(codec_options, db.codec_options) + self.assertEqual(ReadPreference.SECONDARY, db.read_preference) + self.assertEqual(write_concern, db.write_concern) + + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, - ) as c: - self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) + ) + self.assertEqual(Database(c, "foo"), c.get_default_database("foo")) def test_get_default_database_error(self): # URI with no database. - with self.rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, - ) as c: - self.assertRaises(ConfigurationError, c.get_default_database) + ) + self.assertRaises(ConfigurationError, c.get_default_database) def test_get_default_database_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -295,15 +292,15 @@ def test_get_default_database_with_authsource(self): client_context.host, client_context.port, ) - with self.rs_or_single_client(uri, connect=False) as c: - self.assertEqual(Database(c, "foo"), c.get_default_database()) + c = self.rs_or_single_client(uri, connect=False) + self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): - with self.rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, - ) as c: - self.assertEqual(Database(c, "foo"), c.get_database()) + ) + self.assertEqual(Database(c, "foo"), c.get_database()) def test_get_database_default_error(self): # URI with no database. @@ -312,7 +309,6 @@ def test_get_database_default_error(self): connect=False, ) self.assertRaises(ConfigurationError, c.get_database) - c.close() def test_get_database_default_with_authsource(self): # Ensure we distinguish database name from authSource. @@ -322,92 +318,86 @@ def test_get_database_default_with_authsource(self): ) c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_database()) - c.close() def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". with self.assertRaises(ConfigurationError): - MongoClient("mongodb://host/?readpreferencetags=dc:east") + self.single_client("mongodb://host/?readpreferencetags=dc:east") with self.assertRaises(ConfigurationError): - MongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + self.single_client("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") def test_read_preference(self): - with self.rs_or_single_client( + c = self.rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode - ) as c: - self.assertEqual(c.read_preference, ReadPreference.NEAREST) + ) + self.assertEqual(c.read_preference, ReadPreference.NEAREST) def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} - with MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) - with MongoClient("foo", 27017, appname="foobar", connect=False) as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = self.single_client("mongodb://foo:27017/?appname=foobar&connect=false") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) + client = self.single_client("foo", 27017, appname="foobar", connect=False) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # No error - with MongoClient(appname="x" * 128): - pass + self.simple_client(appname="x" * 128) with self.assertRaises(ValueError): - with MongoClient(appname="x" * 129): - pass + self.simple_client(appname="x" * 129) # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) with self.assertRaises(TypeError): - with MongoClient(driver=1): - pass + self.simple_client(driver=1) with self.assertRaises(TypeError): - with MongoClient(driver="abc"): - pass + self.simple_client(driver="abc") with self.assertRaises(TypeError): - with MongoClient(driver=("Foo", "1", "a")): - pass + self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - with MongoClient( + client = self.single_client( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", None), connect=False, - ) as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - with MongoClient( + client = self.single_client( "foo", 27017, appname="foobar", driver=DriverInfo("FooDriver", "1.2.3", "FooPlatform"), connect=False, - ) as client: - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + ) + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - with MongoClient( + client = self.single_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, - ) as client: - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) - with MongoClient( + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + client = self.single_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, - ) as client: - options = client.options - self.assertLessEqual( - len(bson.encode(options.pool_options.metadata)), - _MAX_METADATA_SIZE, - ) + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) def test_container_metadata(self): @@ -415,10 +405,9 @@ def test_container_metadata(self): metadata["driver"]["name"] = "PyMongo" metadata["env"] = {} metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) - client.close() def test_kwargs_codec_options(self): class MyFloatType: @@ -442,7 +431,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - with MongoClient( + c = self.simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -450,18 +439,16 @@ def transform_python(self, value): unicode_decode_error_handler=unicode_decode_error_handler, tzinfo=tzinfo, connect=False, - ) as c: - self.assertEqual(c.codec_options.document_class, document_class) - self.assertEqual(c.codec_options.type_registry, type_registry) - self.assertEqual(c.codec_options.tz_aware, tz_aware) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], - ) - self.assertEqual( - c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler - ) - self.assertEqual(c.codec_options.tzinfo, tzinfo) + ) + self.assertEqual(c.codec_options.document_class, document_class) + self.assertEqual(c.codec_options.type_registry, type_registry) + self.assertEqual(c.codec_options.tz_aware, tz_aware) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) + self.assertEqual(c.codec_options.tzinfo, tzinfo) def test_uri_codec_options(self): # Ensure codec options are passed in correctly @@ -480,38 +467,36 @@ def test_uri_codec_options(self): datetime_conversion, ) ) - with MongoClient(uri, connect=False) as c: - self.assertEqual(c.codec_options.tz_aware, True) - self.assertEqual( - c.codec_options.uuid_representation, - _UUID_REPRESENTATIONS[uuid_representation_label], - ) - self.assertEqual( - c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler - ) - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + c = self.simple_client(uri, connect=False) + self.assertEqual(c.codec_options.tz_aware, True) + self.assertEqual( + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], + ) + self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - with MongoClient(uri, connect=False) as c: - self.assertEqual( - c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] - ) + c = self.simple_client(uri, connect=False) + self.assertEqual( + c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] + ) def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - with MongoClient( + c = self.simple_client( uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" - ) as c: - clopts = c.options - opts = clopts._options + ) + clopts = c.options + opts = clopts._options - self.assertEqual(opts["tls"], False) - self.assertEqual(clopts.replica_set_name, "newname") - self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) + self.assertEqual(opts["tls"], False) + self.assertEqual(clopts.replica_set_name, "newname") + self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. @@ -534,9 +519,9 @@ def reset_resolver(): def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - with MongoClient(*args, **kwargs): - for _, kw in patched_resolver.call_list(): - self.assertAlmostEqual(kw["lifetime"], expected_value) + self.simple_client(*args, **kwargs) + for _, kw in patched_resolver.call_list(): + self.assertAlmostEqual(kw["lifetime"], expected_value) # No timeout specified. test_scenario((base_uri,), {}, CONNECT_TIMEOUT) @@ -554,58 +539,53 @@ def test_scenario(args, kwargs, expected_value): def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - with MongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False): - pass + self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - with MongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) as c: - self.assertEqual(c.options._options["tls"], False) + c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) + self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - with MongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, - ): - pass + ) # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): - with MongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, - ): - pass + ) # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - with MongoClient(ssl=True, tls=False): - pass + self.simple_client(ssl=True, tls=False) def test_event_listeners(self): - with MongoClient(event_listeners=[], connect=False) as c: - self.assertEqual(c.options.event_listeners, []) - listeners = [ - event_loggers.CommandLogger(), - event_loggers.HeartbeatLogger(), - event_loggers.ServerLogger(), - event_loggers.TopologyLogger(), - event_loggers.ConnectionPoolLogger(), - ] - with MongoClient(event_listeners=listeners, connect=False) as c: - self.assertEqual(c.options.event_listeners, listeners) + c = self.simple_client(event_listeners=[], connect=False) + self.assertEqual(c.options.event_listeners, []) + listeners = [ + event_loggers.CommandLogger(), + event_loggers.HeartbeatLogger(), + event_loggers.ServerLogger(), + event_loggers.TopologyLogger(), + event_loggers.ConnectionPoolLogger(), + ] + c = self.simple_client(event_listeners=listeners, connect=False) + self.assertEqual(c.options.event_listeners, listeners) def test_client_options(self): - c = MongoClient(connect=False) + c = self.simple_client(connect=False) self.assertIsInstance(c.options, ClientOptions) self.assertIsInstance(c.options.pool_options, PoolOptions) self.assertEqual(c.options.server_selection_timeout, 30) self.assertEqual(c.options.pool_options.max_idle_time_seconds, None) self.assertIsInstance(c.options.retry_writes, bool) self.assertIsInstance(c.options.retry_reads, bool) - c.close() def test_validate_suggestion(self): """Validate kwargs in constructor.""" @@ -651,16 +631,13 @@ def test_detected_environment_warning(self, mock_get_hosts): ) for host in normal_hosts: with self.assertWarns(UserWarning): - with MongoClient(host): - pass + self.simple_client(host) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] with self.assertWarns(UserWarning): - with MongoClient(host): - pass + self.simple_client(host) with self.assertWarns(UserWarning): - with MongoClient(multi_host): - pass + self.simple_client(multi_host) class TestClient(IntegrationTest): @@ -683,7 +660,6 @@ def test_max_idle_time_reaper_default(self): pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - client.close() def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -697,7 +673,6 @@ def test_max_idle_time_reaper_removes_stale_minPoolSize(self): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - client.close() def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): @@ -711,7 +686,6 @@ def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - client.close() def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): @@ -729,7 +703,6 @@ def test_max_idle_time_reaper_removes_stale(self): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - client.close() def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): @@ -799,20 +772,20 @@ def test_constants(self): MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - with MongoClient(serverSelectionTimeoutMS=10, **kwargs) as c: - connected(c) - - with MongoClient(host, port, **kwargs) as c: - # Override the defaults. No error. + c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) connected(c) + c = self.simple_client(host, port, **kwargs) + # Override the defaults. No error. + connected(c) + # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port # No error. - with MongoClient(**kwargs) as c: - connected(c) + c = self.simple_client(**kwargs) + connected(c) def test_init_disconnected(self): host, port = client_context.host, client_context.port @@ -863,27 +836,23 @@ def test_init_disconnected_with_auth(self): def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - with self.rs_or_single_client(seed, connect=False) as c: - self.assertEqual(client_context.client, c) - # Explicitly test inequality - self.assertFalse(client_context.client != c) + c = self.rs_or_single_client(seed, connect=False) + self.assertEqual(client_context.client, c) + # Explicitly test inequality + self.assertFalse(client_context.client != c) - with self.rs_or_single_client("invalid.com", connect=False) as c: - self.assertNotEqual(client_context.client, c) - self.assertTrue(client_context.client != c) + c = self.rs_or_single_client("invalid.com", connect=False) + self.assertNotEqual(client_context.client, c) + self.assertTrue(client_context.client != c) - c1 = MongoClient("a", connect=False) - c2 = MongoClient("b", connect=False) - self.addCleanup(c1.close) - self.addCleanup(c2.close) + c1 = self.simple_client("a", connect=False) + c2 = self.simple_client("b", connect=False) # Seeds differ: self.assertNotEqual(c1, c2) - c1 = MongoClient(["a", "b", "c"], connect=False) - c2 = MongoClient(["c", "a", "b"], connect=False) - self.addCleanup(c1.close) - self.addCleanup(c2.close) + c1 = self.simple_client(["a", "b", "c"], connect=False) + c2 = self.simple_client(["c", "a", "b"], connect=False) # Same seeds but out of order still compares equal: self.assertEqual(c1, c2) @@ -1058,7 +1027,6 @@ def test_close_kills_cursors(self): # The killCursors task should not need to re-open the topology. test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) - test_client.close() def test_close_stops_kill_cursors_thread(self): client = self.rs_client() @@ -1090,9 +1058,6 @@ def test_uri_connect_option(self): kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) - # Tear down. - client.close() - def test_close_does_not_open_servers(self): client = self.rs_client(connect=False) topology = client._topology @@ -1209,10 +1174,10 @@ def test_unix_socket(self): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - with MongoClient( + c = self.simple_client( "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 - ) as c: - connected(c) + ) + connected(c) def test_document_class(self): c = self.client @@ -1571,9 +1536,9 @@ def test_auth_network_error(self): @client_context.require_no_replica_set def test_connect_to_standalone_using_replica_set_name(self): - with self.single_client(replicaSet="anything", serverSelectionTimeoutMS=100) as client: - with self.assertRaises(AutoReconnect): - client.test.test.find_one() + client = self.single_client(replicaSet="anything", serverSelectionTimeoutMS=100) + with self.assertRaises(AutoReconnect): + client.test.test.find_one() @client_context.require_replica_set def test_stale_getmore(self): @@ -1629,7 +1594,7 @@ def init(self, *args): client_context.host, client_context.port, ) - client = self.single_client(uri, event_listeners=[listener]) + self.single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1638,7 +1603,6 @@ def init(self, *args): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1655,84 +1619,84 @@ def compression_settings(client): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, 4) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) # According to the connection string spec, unsupported values # just raise a warning and are ignored. uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zlib"]) - self.assertEqual(opts.zlib_compression_level, -1) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zlib"]) + self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["snappy", "zlib"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, []) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - with MongoClient(uri, connect=False) as client: - opts = compression_settings(client) - self.assertEqual(opts.compressors, ["zstd", "zlib"]) + client = self.simple_client(uri, connect=False) + opts = compression_settings(client) + self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - with self.single_client(zlibcompressionlevel=level) as client: - # No error - client.pymongo_test.test.find_one() + client = self.single_client(zlibcompressionlevel=level) + # No error + client.pymongo_test.test.find_one() def test_reset_during_update_pool(self): client = self.rs_or_single_client(minPoolSize=10) @@ -1820,7 +1784,6 @@ def test_direct_connection(self): client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - client.close() # direct_connection=False should result in RS topology. client = self.rs_or_single_client(directConnection=False) @@ -1830,7 +1793,6 @@ def test_direct_connection(self): client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -1977,10 +1939,10 @@ def _test_handshake(self, env_vars, expected_env): if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - with self.rs_or_single_client(serverSelectionTimeoutMS=10000) as client: - client.admin.command("ping") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = self.rs_or_single_client(serverSelectionTimeoutMS=10000) + client.admin.command("ping") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) def test_handshake_01_aws(self): self._test_handshake( From 2c9da4593f2dbf1c9bd7a559b35f88676c555c22 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 12 Sep 2024 11:51:28 -0400 Subject: [PATCH 09/29] WIP --- test/asynchronous/test_client.py | 21 +-------------------- test/asynchronous/test_transactions.py | 8 -------- test/test_client.py | 21 +-------------------- test/test_transactions.py | 8 -------- 4 files changed, 2 insertions(+), 56 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 29864d2ca6..18b4103edb 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -816,27 +816,20 @@ async def test_constants(self): async def test_init_disconnected(self): host, port = await async_client_context.host, await async_client_context.port c = await self.async_rs_or_single_client(connect=False) - self.addAsyncCleanup(c.close) # is_primary causes client to block until connected self.assertIsInstance(await c.is_primary, bool) - c = await self.async_rs_or_single_client(connect=False) - self.addAsyncCleanup(c.close) self.assertIsInstance(await c.is_mongos, bool) c = await self.async_rs_or_single_client(connect=False) - self.addAsyncCleanup(c.close) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) c = await self.async_rs_or_single_client(connect=False) - self.addAsyncCleanup(c.close) self.assertEqual(c.codec_options, CodecOptions()) c = await self.async_rs_or_single_client(connect=False) - self.addAsyncCleanup(c.close) self.assertFalse(await c.primary) self.assertFalse(await c.secondaries) c = await self.async_rs_or_single_client(connect=False) - self.addAsyncCleanup(c.close) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(await c.address) # PYTHON-2981 @@ -849,14 +842,12 @@ async def test_init_disconnected(self): bad_host = "somedomainthatdoesntexist.org" c = AsyncMongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - self.addAsyncCleanup(c.close) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" c = AsyncMongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - self.addAsyncCleanup(c.close) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() @@ -886,10 +877,8 @@ async def test_equality(self): async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = await self.async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.close) self.assertIn(c, {async_client_context.client}) c = await self.async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.close) self.assertNotIn(c, {async_client_context.client}) async def test_host_w_port(self): @@ -1225,7 +1214,7 @@ async def test_document_class(self): self.assertFalse(isinstance(await db.test.find_one(), SON)) c = await self.async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(c.close) + db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) @@ -1247,15 +1236,12 @@ async def test_timeouts(self): async def test_socket_timeout_ms_validation(self): c = await self.async_rs_or_single_client(socketTimeoutMS=10 * 1000) - self.addAsyncCleanup(c.close) self.assertEqual(10, (await async_get_pool(c)).opts.socket_timeout) c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=None)) - self.addAsyncCleanup(c.close) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=0)) - self.addAsyncCleanup(c.close) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): @@ -2359,7 +2345,6 @@ async def test_discover_primary(self): replicaSet="rs", heartbeatFrequencyMS=500, ) - self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2386,7 +2371,6 @@ async def test_reconnect(self): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2421,7 +2405,6 @@ async def _test_network_error(self, operation_callback): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addAsyncCleanup(c.close) # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, 6) @@ -2497,7 +2480,6 @@ async def test_rs_client_does_not_maintain_pool_to_arbiters(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) @@ -2527,7 +2509,6 @@ async def test_direct_client_maintains_pool_to_arbiter(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 72d5e025b4..309ff5d6ae 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -112,7 +112,6 @@ def test_transaction_options_validation(self): async def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" client = await self.async_rs_client(w=0) - self.addAsyncCleanup(client.close) db = client.test coll = db.test await coll.insert_one({}) @@ -176,7 +175,6 @@ async def test_unpin_for_next_transaction(self): coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -206,7 +204,6 @@ async def test_unpin_for_non_transaction_operation(self): coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -334,7 +331,6 @@ async def test_transaction_starts_with_batched_write(self): coll = client[self.db.name].test await coll.delete_many({}) listener.reset() - self.addAsyncCleanup(client.close) self.addAsyncCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -360,7 +356,6 @@ async def test_transaction_starts_with_batched_write(self): @async_client_context.require_transactions async def test_transaction_direct_connection(self): client = await self.async_single_client() - self.addAsyncCleanup(client.close) coll = client.pymongo_test.test # Make sure the collection exists. @@ -476,7 +471,6 @@ async def callback2(session): async def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): @@ -505,7 +499,6 @@ async def callback(session): async def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): @@ -540,7 +533,6 @@ async def callback(session): async def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = await self.async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) coll = client[self.db.name].test async def callback(session): diff --git a/test/test_client.py b/test/test_client.py index 84e919ff6d..2971c169ad 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -790,27 +790,20 @@ def test_constants(self): def test_init_disconnected(self): host, port = client_context.host, client_context.port c = self.rs_or_single_client(connect=False) - self.addCleanup(c.close) # is_primary causes client to block until connected self.assertIsInstance(c.is_primary, bool) - c = self.rs_or_single_client(connect=False) - self.addCleanup(c.close) self.assertIsInstance(c.is_mongos, bool) c = self.rs_or_single_client(connect=False) - self.addCleanup(c.close) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) c = self.rs_or_single_client(connect=False) - self.addCleanup(c.close) self.assertEqual(c.codec_options, CodecOptions()) c = self.rs_or_single_client(connect=False) - self.addCleanup(c.close) self.assertFalse(c.primary) self.assertFalse(c.secondaries) c = self.rs_or_single_client(connect=False) - self.addCleanup(c.close) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(c.address) # PYTHON-2981 @@ -823,14 +816,12 @@ def test_init_disconnected(self): bad_host = "somedomainthatdoesntexist.org" c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - self.addCleanup(c.close) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) - self.addCleanup(c.close) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() @@ -860,10 +851,8 @@ def test_equality(self): def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) c = self.rs_or_single_client(seed, connect=False) - self.addCleanup(c.close) self.assertIn(c, {client_context.client}) c = self.rs_or_single_client("invalid.com", connect=False) - self.addCleanup(c.close) self.assertNotIn(c, {client_context.client}) def test_host_w_port(self): @@ -1189,7 +1178,7 @@ def test_document_class(self): self.assertFalse(isinstance(db.test.find_one(), SON)) c = self.rs_or_single_client(document_class=SON) - self.addCleanup(c.close) + db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) @@ -1211,15 +1200,12 @@ def test_timeouts(self): def test_socket_timeout_ms_validation(self): c = self.rs_or_single_client(socketTimeoutMS=10 * 1000) - self.addCleanup(c.close) self.assertEqual(10, (get_pool(c)).opts.socket_timeout) c = connected(self.rs_or_single_client(socketTimeoutMS=None)) - self.addCleanup(c.close) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) c = connected(self.rs_or_single_client(socketTimeoutMS=0)) - self.addCleanup(c.close) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): @@ -2315,7 +2301,6 @@ def test_discover_primary(self): replicaSet="rs", heartbeatFrequencyMS=500, ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2342,7 +2327,6 @@ def test_reconnect(self): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2377,7 +2361,6 @@ def _test_network_error(self, operation_callback): retryReads=False, serverSelectionTimeoutMS=1000, ) - self.addCleanup(c.close) # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, 6) @@ -2453,7 +2436,6 @@ def test_rs_client_does_not_maintain_pool_to_arbiters(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(c.address, ("a", 1)) @@ -2483,7 +2465,6 @@ def test_direct_client_maintains_pool_to_arbiter(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) - self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(c.address, ("c", 3)) diff --git a/test/test_transactions.py b/test/test_transactions.py index f7a3afc694..8110705600 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -110,7 +110,6 @@ def test_transaction_options_validation(self): def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" client = self.rs_client(w=0) - self.addCleanup(client.close) db = client.test coll = db.test coll.insert_one({}) @@ -168,7 +167,6 @@ def test_unpin_for_next_transaction(self): coll = client.test.test # Create the collection. coll.insert_one({}) - self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. with s.start_transaction(): @@ -196,7 +194,6 @@ def test_unpin_for_non_transaction_operation(self): coll = client.test.test # Create the collection. coll.insert_one({}) - self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. with s.start_transaction(): @@ -324,7 +321,6 @@ def test_transaction_starts_with_batched_write(self): coll = client[self.db.name].test coll.delete_many({}) listener.reset() - self.addCleanup(client.close) self.addCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -350,7 +346,6 @@ def test_transaction_starts_with_batched_write(self): @client_context.require_transactions def test_transaction_direct_connection(self): client = self.single_client() - self.addCleanup(client.close) coll = client.pymongo_test.test # Make sure the collection exists. @@ -464,7 +459,6 @@ def callback2(session): def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) - self.addCleanup(client.close) coll = client[self.db.name].test def callback(session): @@ -493,7 +487,6 @@ def callback(session): def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) - self.addCleanup(client.close) coll = client[self.db.name].test def callback(session): @@ -526,7 +519,6 @@ def callback(session): def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() client = self.rs_client(event_listeners=[listener]) - self.addCleanup(client.close) coll = client[self.db.name].test def callback(session): From 3f7dce519e46fb81f4e3b1bd33110a263114c0b5 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 12 Sep 2024 16:51:21 -0400 Subject: [PATCH 10/29] Cleanup --- test/asynchronous/test_auth.py | 1 - test/asynchronous/test_client.py | 52 ++----- test/asynchronous/test_client_bulk_write.py | 14 -- test/asynchronous/test_cursor.py | 5 - test/asynchronous/test_database.py | 1 - test/asynchronous/test_encryption.py | 161 +++++++++---------- test/asynchronous/test_grid_file.py | 3 +- test/asynchronous/test_session.py | 2 - test/test_auth.py | 1 - test/test_client.py | 52 ++----- test/test_client_bulk_write.py | 14 -- test/test_cursor.py | 5 - test/test_database.py | 1 - test/test_encryption.py | 163 +++++++++----------- test/test_grid_file.py | 3 +- test/test_session.py | 2 - 16 files changed, 164 insertions(+), 316 deletions(-) diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 0f9c2a7886..c34af16710 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -594,7 +594,6 @@ async def test_scram_threaded(self): # The first thread to call find() will authenticate client = await self.async_rs_or_single_client() - self.addAsyncCleanup(client.close) coll = client.db.test threads = [] for _ in range(4): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 18b4103edb..a25ecccf63 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -725,7 +725,6 @@ async def test_max_idle_time_reaper_removes_stale(self): async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): client = await self.async_rs_or_single_client() - self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -733,7 +732,6 @@ async def test_min_pool_size(self): # Assert that pool started up at minPoolSize client = await self.async_rs_or_single_client(minPoolSize=10) - self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -755,7 +753,6 @@ async def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = await self.async_rs_or_single_client(maxIdleTimeMS=500) - self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -772,7 +769,6 @@ async def test_max_idle_time_checkout(self): # Test that connections are reused if maxIdleTimeMS is not set. client = await self.async_rs_or_single_client() - self.addAsyncCleanup(client.close) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -841,13 +837,13 @@ async def test_init_disconnected(self): self.assertEqual(await c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" - c = AsyncMongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = AsyncMongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() @@ -902,7 +898,6 @@ async def test_repr(self): connect=False, document_class=SON, ) - self.addAsyncCleanup(client.close) the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) @@ -915,7 +910,7 @@ async def test_repr(self): async with eval(the_repr) as client_two: self.assertEqual(client_two, client) - client = AsyncMongoClient( + client = self.simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -924,7 +919,6 @@ async def test_repr(self): wtimeoutms=100, connect=False, ) - self.addAsyncCleanup(client.close) the_repr = repr(client) self.assertIn("AsyncMongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -952,7 +946,6 @@ async def test_list_databases(self): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) client = await self.async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(client.close) async for doc in await client.list_databases(): self.assertIs(type(doc), dict) @@ -1082,7 +1075,6 @@ async def test_close_does_not_open_servers(self): async def test_close_closes_sockets(self): client = await self.async_rs_client() - self.addAsyncCleanup(client.close) await client.test.test.find_one() topology = client._topology await client.close() @@ -1190,7 +1182,6 @@ async def test_unix_socket(self): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. client = await self.async_rs_or_single_client(uri) - self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1227,7 +1218,6 @@ async def test_timeouts(self): maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500, ) - self.addAsyncCleanup(client.close) self.assertEqual(10.5, (await async_get_pool(client)).opts.connect_timeout) self.assertEqual(10.5, (await async_get_pool(client)).opts.socket_timeout) self.assertEqual(10.5, (await async_get_pool(client)).opts.max_idle_time_seconds) @@ -1278,11 +1268,9 @@ async def get_x(db): async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) - self.addAsyncCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = AsyncMongoClient(serverSelectionTimeoutMS=0, connect=False) - self.addAsyncCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) @@ -1297,25 +1285,20 @@ async def test_server_selection_timeout(self): client = AsyncMongoClient( "mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False ) - self.addAsyncCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) - self.addAsyncCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) # Test invalid timeout in URI ignored and set to default. client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) - self.addAsyncCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) client = AsyncMongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) - self.addAsyncCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) async def test_waitQueueTimeoutMS(self): client = await self.async_rs_or_single_client(waitQueueTimeoutMS=2000) - self.addAsyncCleanup(client.close) self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) async def test_socketKeepAlive(self): @@ -1359,7 +1342,6 @@ async def test_ipv6(self): uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") client = await self.async_rs_or_single_client_noauth(uri) - self.addAsyncCleanup(client.close) await client.pymongo_test.test.insert_one({"dummy": "object"}) await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1461,7 +1443,6 @@ async def test_operation_failure(self): # to avoid race conditions caused by replica set failover or idle # socket reaping. client = await self.async_single_client() - self.addAsyncCleanup(client.close) await client.pymongo_test.test.find_one() pool = await async_get_pool(client) socket_count = len(pool.conns) @@ -1486,7 +1467,6 @@ async def test_lazy_connect_w0(self): self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") client = await self.async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.insert_one({}) async def predicate(): @@ -1495,7 +1475,6 @@ async def predicate(): await async_wait_until(predicate, "find one document") client = await self.async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) async def predicate(): @@ -1504,7 +1483,6 @@ async def predicate(): await async_wait_until(predicate, "update one document") client = await self.async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) await client.test_lazy_connect_w0.test.delete_one({}) async def predicate(): @@ -1517,7 +1495,6 @@ async def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1728,7 +1705,6 @@ def compression_settings(client): async def test_reset_during_update_pool(self): client = await self.async_rs_or_single_client(minPoolSize=10) - self.addAsyncCleanup(client.close) await client.admin.command("ping") pool = await async_get_pool(client) generation = pool.gen.get_overall() @@ -1777,8 +1753,6 @@ async def test_background_connections_do_not_hold_locks(self): client = await self.async_rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addAsyncCleanup(client.close) - # Create a single connection in the pool. await client.admin.command("ping") @@ -1840,11 +1814,10 @@ def server_description_count(): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = AsyncMongoClient( + client = self.simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addAsyncCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): await client.test.test.find_one() gc.collect() @@ -1858,7 +1831,6 @@ def server_description_count(): @async_client_context.require_failCommand_fail_point async def test_network_error_message(self): client = await self.async_single_client(retryReads=False) - self.addAsyncCleanup(client.close) await client.admin.command("ping") # connect async with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1894,7 +1866,6 @@ async def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) - self.addAsyncCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" @@ -1902,26 +1873,21 @@ async def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) - self.addAsyncCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) - self.addAsyncCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") async def test_srv_max_hosts_kwarg(self): client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") - self.addAsyncCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) - self.addAsyncCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = AsyncMongoClient( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) - self.addAsyncCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( @@ -2067,7 +2033,6 @@ async def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = await connected(await self.async_rs_or_single_client(maxPoolSize=1)) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) @@ -2091,7 +2056,6 @@ async def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = await self.async_rs_or_single_client(maxPoolSize=1) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.test await collection.drop() @@ -2133,7 +2097,6 @@ async def test_exhaust_query_network_error(self): client = await connected( await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) ) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2155,7 +2118,6 @@ async def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = await self.async_rs_or_single_client(maxPoolSize=1) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.test await collection.drop() await collection.insert_many([{} for _ in range(200)]) # More than one batch. @@ -2345,6 +2307,7 @@ async def test_discover_primary(self): replicaSet="rs", heartbeatFrequencyMS=500, ) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2371,6 +2334,7 @@ async def test_reconnect(self): retryReads=False, serverSelectionTimeoutMS=1000, ) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2406,6 +2370,8 @@ async def _test_network_error(self, operation_callback): serverSelectionTimeoutMS=1000, ) + self.addAsyncCleanup(c.close) + # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, 6) c.set_wire_version_range("b:2", 2, 7) @@ -2480,6 +2446,7 @@ async def test_rs_client_does_not_maintain_pool_to_arbiters(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(await c.address, ("a", 1)) @@ -2509,6 +2476,7 @@ async def test_direct_client_maintains_pool_to_arbiter(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) + self.addAsyncCleanup(c.close) wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(await c.address, ("c", 3)) diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index 997db94d55..3a17299453 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -96,7 +96,6 @@ async def asyncSetUp(self): async def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) models = [] for _ in range(self.max_write_batch_size + 1): @@ -122,7 +121,6 @@ async def test_batch_splits_if_num_operations_too_large(self): async def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -159,7 +157,6 @@ async def test_collects_write_concern_errors_across_batches(self): event_listeners=[listener], retryWrites=False, ) - self.addAsyncCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -199,7 +196,6 @@ async def test_collects_write_concern_errors_across_batches(self): async def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -230,7 +226,6 @@ async def test_collects_write_errors_across_batches_unordered(self): async def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -261,7 +256,6 @@ async def test_collects_write_errors_across_batches_ordered(self): async def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -303,7 +297,6 @@ async def test_handles_cursor_requiring_getMore(self): async def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -347,7 +340,6 @@ async def test_handles_cursor_requiring_getMore_within_transaction(self): async def test_handles_getMore_error(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -402,7 +394,6 @@ async def test_handles_getMore_error(self): async def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) b_repeated = "b" * self.max_bson_object_size @@ -459,7 +450,6 @@ async def _setup_namespace_test_models(self): async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() models.append( @@ -491,7 +481,6 @@ async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): async def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) num_models, models = await self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -529,7 +518,6 @@ async def test_batch_splits_if_new_namespace_is_too_large(self): @async_client_context.require_no_serverless async def test_returns_error_if_no_writes_can_be_added_to_ops(self): client = await self.async_rs_or_single_client() - self.addAsyncCleanup(client.close) # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -553,7 +541,6 @@ async def test_returns_error_if_auto_encryption_configured(self): kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.close) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -610,7 +597,6 @@ async def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - self.addAsyncCleanup(client.close) await client.admin.command("ping") # Init the client first. with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 8431f4369b..d6773d832e 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -353,7 +353,6 @@ async def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(await coll.find().explain()) started = listener.started_events @@ -1261,7 +1260,6 @@ async def test_close_kills_cursor_synchronously(self): listener = AllowListEventListener("killCursors") client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1300,7 +1298,6 @@ def assertCursorKilled(): async def test_timeout_kills_cursor_asynchronously(self): listener = AllowListEventListener("killCursors") client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1358,7 +1355,6 @@ def test_delete_not_initialized(self): async def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1788,7 +1784,6 @@ async def test_monitoring(self): async def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) c = client.pymongo_test.test await c.delete_many({}) await c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 2e1f8e0450..c5d62323df 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -235,7 +235,6 @@ async def test_list_collection_names_filter(self): async def test_check_exists(self): listener = OvertCommandListener() client = await self.async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) db = client[self.db.name] await db.drop_collection("unique") await db.create_collection("unique", check_exists=True) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 5fda81e35c..3000d2361d 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -31,7 +31,7 @@ from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context from test.asynchronous.test_bulk import AsyncBulkTestBase from threading import Thread -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional import pytest @@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase): @unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed") async def test_crypt_shared(self): # Test that we can pick up crypt_shared lib automatically - client = AsyncMongoClient( + self.simple_client( auto_encryption_opts=AutoEncryptionOpts( KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True ), connect=False, ) - self.addAsyncCleanup(client.aclose) @unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed") def test_init_requires_pymongocrypt(self): @@ -196,19 +195,16 @@ def test_init_kms_tls_options(self): class TestClientOptions(AsyncPyMongoTestCase): async def test_default(self): - client = AsyncMongoClient(connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) - client = AsyncMongoClient(auto_encryption_opts=None, connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(auto_encryption_opts=None, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") async def test_kwargs(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = AsyncMongoClient(auto_encryption_opts=opts, connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(auto_encryption_opts=opts, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) @@ -229,6 +225,34 @@ def assertBinaryUUID(self, val): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, UUID_SUBTYPE) + def create_client_encryption( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: AsyncMongoClient, + codec_options: Mapping[str, Any], + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = AsyncClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + self.addAsyncCleanup(client_encryption.close) + return client_encryption + + @classmethod + def unmanaged_create_client_encryption( + cls, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: AsyncMongoClient, + codec_options: Mapping[str, Any], + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = AsyncClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + return client_encryption + # Location of JSON test files. if _IS_SYNC: @@ -261,7 +285,6 @@ def bson_data(*paths): class TestClientSimple(AsyncEncryptionIntegrationTest): async def _test_auto_encrypt(self, opts): client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) # Create the encrypted field's data key. key_vault = await create_key_vault( @@ -343,7 +366,6 @@ async def test_auto_encrypt_local_schema_map(self): async def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) await client.admin.command("ping") await client.aclose() @@ -361,7 +383,6 @@ async def test_use_after_close(self): async def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) async def target(): with warnings.catch_warnings(): @@ -376,7 +397,6 @@ class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest): async def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) encrypted_coll = client.pymongo_test.test @@ -417,7 +437,6 @@ async def _setup_class(cls): async def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): await client.test.test.insert_one({}) @@ -431,7 +450,6 @@ async def test_raise_max_wire_version_error(self): async def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): await client.test.test.find_raw_batches({}) @@ -450,10 +468,9 @@ async def test_raise_unsupported_error(self): class TestExplicitSimple(AsyncEncryptionIntegrationTest): async def test_encrypt_decrypt(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) # Use standard UUID representation. key_vault = async_client_context.client.keyvault.get_collection( "datakeys", codec_options=OPTS @@ -495,10 +512,9 @@ async def test_encrypt_decrypt(self): self.assertEqual(decrypted_ssn, doc["ssn"]) async def test_validation(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) msg = "value to decrypt must be a bson.binary.Binary with subtype 6" with self.assertRaisesRegex(TypeError, msg): @@ -512,10 +528,9 @@ async def test_validation(self): await client_encryption.encrypt("str", algo, key_id=Binary(b"123")) async def test_bson_errors(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. unencodable_value = object() @@ -528,7 +543,7 @@ async def test_bson_errors(self): async def test_codec_options(self): with self.assertRaisesRegex(TypeError, "codec_options must be"): - AsyncClientEncryption( + self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, @@ -536,10 +551,9 @@ async def test_codec_options(self): ) opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY) - client_encryption_legacy = AsyncClientEncryption( + client_encryption_legacy = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts ) - self.addAsyncCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. key_id = await client_encryption_legacy.create_data_key("local") @@ -554,10 +568,9 @@ async def test_codec_options(self): # Encrypt the same UUID with STANDARD codec options. opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts ) - self.addAsyncCleanup(client_encryption.close) encrypted_standard = await client_encryption.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id ) @@ -573,7 +586,7 @@ async def test_codec_options(self): self.assertNotEqual(await client_encryption.decrypt(encrypted_legacy), value) async def test_close(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) await client_encryption.close() @@ -589,7 +602,7 @@ async def test_close(self): await client_encryption.decrypt(Binary(b"", 6)) async def test_with_statement(self): - async with AsyncClientEncryption( + async with self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) as client_encryption: pass @@ -836,7 +849,7 @@ async def _setup_class(cls): cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = AsyncClientEncryption( + cls.client_encryption = cls.unmanaged_create_client_encryption( cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) @@ -926,7 +939,6 @@ async def _test_external_key_vault(self, with_external_key_vault): key_vault_client = await self.async_rs_or_single_client( username="fake-user", password="fake-pwd" ) - self.addAsyncCleanup(key_vault_client.close) else: key_vault_client = async_client_context.client opts = AutoEncryptionOpts( @@ -939,12 +951,10 @@ async def _test_external_key_vault(self, with_external_key_vault): client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addAsyncCleanup(client_encrypted.close) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS ) - self.addAsyncCleanup(client_encryption.close) if with_external_key_vault: # Authentication error. @@ -993,7 +1003,6 @@ async def test_views_are_prohibited(self): client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addAsyncCleanup(client_encrypted.aclose) with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"): await client_encrypted.db.view.insert_one({}) @@ -1051,16 +1060,14 @@ async def _test_corpus(self, opts): self.addAsyncCleanup(vault.drop) client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.close) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", async_client_context.client, OPTS, kms_tls_options=KMS_TLS_OPTS, ) - self.addAsyncCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data("corpus", "corpus.json")) corpus_copied: SON = SON() @@ -1291,7 +1298,7 @@ def setUp(self): "gcp": GCP_CREDS, "kmip": KMIP_CREDS, } - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers=kms_providers, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -1303,7 +1310,7 @@ def setUp(self): kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698" - self.client_encryption_invalid = AsyncClientEncryption( + self.client_encryption_invalid = self.create_client_encryption( kms_providers=kms_providers_invalid, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -1484,7 +1491,7 @@ async def test_12_kmip_master_key_invalid_endpoint(self): await self.client_encryption.create_data_key("kmip", key) -class AzureGCPEncryptionTestMixin: +class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): DEK = None KMS_PROVIDER_MAP = None KEYVAULT_DB = "keyvault" @@ -1496,7 +1503,7 @@ async def asyncSetUp(self): await create_key_vault(keyvault, self.DEK) async def _test_explicit(self, expectation): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), async_client_context.client, @@ -1607,7 +1614,6 @@ async def asyncSetUp(self): self.client_test = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) - self.addAsyncCleanup(self.client_test.aclose) self.client_keyvault_listener = OvertCommandListener() self.client_keyvault = await self.async_rs_or_single_client( @@ -1616,7 +1622,6 @@ async def asyncSetUp(self): w="majority", event_listeners=[self.client_keyvault_listener], ) - self.addAsyncCleanup(self.client_keyvault.aclose) await self.client_test.keyvault.datakeys.drop() await self.client_test.db.coll.drop() @@ -1629,7 +1634,7 @@ async def asyncSetUp(self): codec_options=OPTS, ) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, key_vault_namespace="keyvault.datakeys", key_vault_client=self.client_test, @@ -1663,8 +1668,6 @@ async def _run_test(self, max_pool_size, auto_encryption_opts): result = await client_encrypted.db.coll.find_one({"_id": 0}) self.assertEqual(result, {"_id": 0, "encrypted": "string0"}) - self.addAsyncCleanup(client_encrypted.close) - async def test_case_1(self): await self._run_test( max_pool_size=1, @@ -1840,7 +1843,7 @@ async def asyncSetUp(self): await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) keyID = await self.client_encryption.create_data_key("local") @@ -1858,7 +1861,6 @@ async def asyncSetUp(self): self.encrypted_client = await self.async_rs_or_single_client( auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] ) - self.addAsyncCleanup(self.encrypted_client.close) async def test_01_command_error(self): async with self.fail_point( @@ -1936,7 +1938,6 @@ def reset_timeout(): ], ) client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.close) with self.assertRaisesRegex(EncryptionError, "Timeout"): await client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1951,13 +1952,11 @@ async def test_bypassAutoEncryption(self): ], ) client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.close) await client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: - mongocryptd_client = AsyncMongoClient( + mongocryptd_client = self.simple_client( "mongodb://localhost:27027/?serverSelectionTimeoutMS=500" ) - self.addAsyncCleanup(mongocryptd_client.close) with self.assertRaises(ServerSelectionTimeoutError): await mongocryptd_client.admin.command("ping") @@ -1980,14 +1979,12 @@ async def test_via_loading_shared_library(self): crypt_shared_lib_required=True, ) client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) self.assertEncrypted((await async_client_context.client.db.coll.find_one({}))["encrypted"]) - no_mongocryptd_client = AsyncMongoClient( + no_mongocryptd_client = self.simple_client( host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000" ) - self.addAsyncCleanup(no_mongocryptd_client.aclose) with self.assertRaises(ServerSelectionTimeoutError): await no_mongocryptd_client.db.command("ping") @@ -2022,7 +2019,6 @@ def listener(): crypt_shared_lib_required=False, ) client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) server.shutdown() @@ -2036,10 +2032,9 @@ class TestKmsTLSProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): await super().asyncSetUp() self.patch_system_certs(CA_PEM) - self.client_encrypted = AsyncClientEncryption( + self.client_encrypted = self.create_client_encryption( {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS ) - self.addAsyncCleanup(self.client_encrypted.close) async def test_invalid_kms_certificate_expired(self): key = { @@ -2084,36 +2079,32 @@ async def asyncSetUp(self): "gcp": {"tlsCAFile": CA_PEM}, "kmip": {"tlsCAFile": CA_PEM}, } - self.client_encryption_no_client_cert = AsyncClientEncryption( + self.client_encryption_no_client_cert = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_no_client_cert.close) # 2, same providers as above but with tlsCertificateKeyFile. kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only) for p in kms_tls_opts: kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM - self.client_encryption_with_tls = AsyncClientEncryption( + self.client_encryption_with_tls = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts ) - self.addAsyncCleanup(self.client_encryption_with_tls.close) # 3, update endpoints to expired host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000" providers["gcp"]["endpoint"] = "127.0.0.1:9000" providers["kmip"]["endpoint"] = "127.0.0.1:9000" - self.client_encryption_expired = AsyncClientEncryption( + self.client_encryption_expired = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_expired.close) # 3, update endpoints to invalid host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001" providers["gcp"]["endpoint"] = "127.0.0.1:9001" providers["kmip"]["endpoint"] = "127.0.0.1:9001" - self.client_encryption_invalid_hostname = AsyncClientEncryption( + self.client_encryption_invalid_hostname = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_invalid_hostname.close) # Errors when client has no cert, some examples: # [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623) self.cert_error = ( @@ -2151,7 +2142,7 @@ async def asyncSetUp(self): "gcp:with_tls": with_cert, "kmip:with_tls": with_cert, } - self.client_encryption_with_names = AsyncClientEncryption( + self.client_encryption_with_names = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4 ) @@ -2233,10 +2224,9 @@ async def test_04_kmip(self): async def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}} options = {"aws": {"tlsDisableOCSPEndpointCheck": True}} - encryption = AsyncClientEncryption( + encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - self.addAsyncCleanup(encryption.close) ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") @@ -2286,7 +2276,7 @@ async def asyncSetUp(self): self.client = async_client_context.client await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) self.def_key_id = await self.client_encryption.create_data_key( @@ -2328,17 +2318,15 @@ async def asyncSetUp(self): key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(self.encrypted_client.aclose) async def test_01_insert_encrypted_indexed_and_find(self): val = "encrypted indexed value" @@ -2465,14 +2453,13 @@ async def run_test(self, src_provider, dst_provider): await self.client.keyvault.drop_collection("datakeys") # Step 2. Create a ``AsyncClientEncryption`` object named ``client_encryption1`` - client_encryption1 = AsyncClientEncryption( + client_encryption1 = self.create_client_encryption( key_vault_client=self.client, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addAsyncCleanup(client_encryption1.close) # Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``. key_id = await client_encryption1.create_data_key( @@ -2486,15 +2473,13 @@ async def run_test(self, src_provider, dst_provider): # Step 5. Create a ``AsyncClientEncryption`` object named ``client_encryption2`` client2 = await self.async_rs_or_single_client() - self.addAsyncCleanup(client2.aclose) - client_encryption2 = AsyncClientEncryption( + client_encryption2 = self.create_client_encryption( key_vault_client=client2, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addAsyncCleanup(client_encryption2.close) # Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``. rewrap_many_data_key_result = await client_encryption2.rewrap_many_data_key( @@ -2529,7 +2514,7 @@ async def asyncSetUp(self): @unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set") async def test_01_failure(self): - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -2540,7 +2525,7 @@ async def test_01_failure(self): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") async def test_02_success(self): - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -2561,7 +2546,6 @@ async def test_queryable_encryption(self): # and cleanup. async def AsyncMongoClient(**kwargs): c = await self.async_rs_or_single_client(**kwargs) - self.addAsyncCleanup(c.aclose) return c # Drop data from prior test runs. @@ -2572,7 +2556,7 @@ async def AsyncMongoClient(**kwargs): # Create two data keys. key_vault_client = await AsyncMongoClient() - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions() ) key1_id = await client_encryption.create_data_key("local") @@ -2653,10 +2637,9 @@ async def asyncSetUp(self): key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, @@ -2664,7 +2647,6 @@ async def asyncSetUp(self): ) self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db - self.addAsyncCleanup(self.encrypted_client.aclose) async def run_expression_find( self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None @@ -2861,10 +2843,9 @@ async def asyncSetUp(self): await super().asyncSetUp() await self.client.drop_database(self.db) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) self.key_id = await self.client_encryption.create_data_key("local") opts = RangeOpts(min=0, max=1000) self.payload_defaults = await self.client_encryption.encrypt( @@ -2897,13 +2878,12 @@ async def asyncSetUp(self): await self.client.drop_database(self.db) self.key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addAsyncCleanup(self.key_vault.drop) - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, self.key_vault.full_name, self.client, OPTS, ) - self.addAsyncCleanup(self.client_encryption.close) async def test_01_simple_create(self): coll, _ = await self.client_encryption.create_encrypted_collection( @@ -3119,10 +3099,9 @@ async def _tearDown_class(cls): async def asyncSetUp(self) -> None: self.listener = OvertCommandListener() - self.mongocryptd_client = AsyncMongoClient( + self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] ) - self.addAsyncCleanup(self.mongocryptd_client.aclose) hello = await self.mongocryptd_client.db.command("hello") self.assertNotIn("logicalSessionTimeoutMinutes", hello) diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index f8b0c60dad..9c57c15c5a 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -792,8 +792,7 @@ async def test_grid_out_lazy_connect(self): await outfile.readchunk() async def test_grid_in_lazy_connect(self): - client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) - self.addAsyncCleanup(client.close) + client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = AsyncGridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 046d091f3b..d264b5ecb0 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -788,7 +788,6 @@ async def test_unacknowledged_writes(self): # Ensure the collection exists. await self.client.pymongo_test.test_unacked_writes.insert_one({}) client = await self.async_rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addAsyncCleanup(client.close) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -1154,7 +1153,6 @@ async def test_cluster_time(self): client = await self.async_rs_or_single_client( event_listeners=[listener], heartbeatFrequencyMS=999999 ) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). await collection.insert_many([{} for _ in range(10)]) diff --git a/test/test_auth.py b/test/test_auth.py index 8094f86428..2140f2d6f3 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -590,7 +590,6 @@ def test_scram_threaded(self): # The first thread to call find() will authenticate client = self.rs_or_single_client() - self.addCleanup(client.close) coll = client.db.test threads = [] for _ in range(4): diff --git a/test/test_client.py b/test/test_client.py index 2971c169ad..4a8b7dfcb9 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -707,13 +707,11 @@ def test_max_idle_time_reaper_removes_stale(self): def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): client = self.rs_or_single_client() - self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize client = self.rs_or_single_client(minPoolSize=10) - self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) wait_until( lambda: len(server._pool.conns) == 10, @@ -733,7 +731,6 @@ def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): client = self.rs_or_single_client(maxIdleTimeMS=500) - self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -748,7 +745,6 @@ def test_max_idle_time_checkout(self): # Test that connections are reused if maxIdleTimeMS is not set. client = self.rs_or_single_client() - self.addCleanup(client.close) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -815,13 +811,13 @@ def test_init_disconnected(self): self.assertEqual(c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" - c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() @@ -876,7 +872,6 @@ def test_repr(self): connect=False, document_class=SON, ) - self.addCleanup(client.close) the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) @@ -889,7 +884,7 @@ def test_repr(self): with eval(the_repr) as client_two: self.assertEqual(client_two, client) - client = MongoClient( + client = self.simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -898,7 +893,6 @@ def test_repr(self): wtimeoutms=100, connect=False, ) - self.addCleanup(client.close) the_repr = repr(client) self.assertIn("MongoClient(host=", the_repr) self.assertIn("document_class=dict, tz_aware=False, connect=False, ", the_repr) @@ -926,7 +920,6 @@ def test_list_databases(self): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) client = self.rs_or_single_client(document_class=SON) - self.addCleanup(client.close) for doc in client.list_databases(): self.assertIs(type(doc), dict) @@ -1056,7 +1049,6 @@ def test_close_does_not_open_servers(self): def test_close_closes_sockets(self): client = self.rs_client() - self.addCleanup(client.close) client.test.test.find_one() topology = client._topology client.close() @@ -1154,7 +1146,6 @@ def test_unix_socket(self): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. client = self.rs_or_single_client(uri) - self.addCleanup(client.close) client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1191,7 +1182,6 @@ def test_timeouts(self): maxIdleTimeMS=10500, serverSelectionTimeoutMS=10500, ) - self.addCleanup(client.close) self.assertEqual(10.5, (get_pool(client)).opts.connect_timeout) self.assertEqual(10.5, (get_pool(client)).opts.socket_timeout) self.assertEqual(10.5, (get_pool(client)).opts.max_idle_time_seconds) @@ -1242,11 +1232,9 @@ def get_x(db): def test_server_selection_timeout(self): client = MongoClient(serverSelectionTimeoutMS=100, connect=False) - self.addCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = MongoClient(serverSelectionTimeoutMS=0, connect=False) - self.addCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) @@ -1257,25 +1245,20 @@ def test_server_selection_timeout(self): ) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=100", connect=False) - self.addCleanup(client.close) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=0", connect=False) - self.addCleanup(client.close) self.assertAlmostEqual(0, client.options.server_selection_timeout) # Test invalid timeout in URI ignored and set to default. client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=-1", connect=False) - self.addCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) client = MongoClient("mongodb://localhost/?serverSelectionTimeoutMS=", connect=False) - self.addCleanup(client.close) self.assertAlmostEqual(30, client.options.server_selection_timeout) def test_waitQueueTimeoutMS(self): client = self.rs_or_single_client(waitQueueTimeoutMS=2000) - self.addCleanup(client.close) self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) def test_socketKeepAlive(self): @@ -1319,7 +1302,6 @@ def test_ipv6(self): uri += "/?replicaSet=" + (client_context.replica_set_name or "") client = self.rs_or_single_client_noauth(uri) - self.addCleanup(client.close) client.pymongo_test.test.insert_one({"dummy": "object"}) client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1421,7 +1403,6 @@ def test_operation_failure(self): # to avoid race conditions caused by replica set failover or idle # socket reaping. client = self.single_client() - self.addCleanup(client.close) client.pymongo_test.test.find_one() pool = get_pool(client) socket_count = len(pool.conns) @@ -1446,7 +1427,6 @@ def test_lazy_connect_w0(self): self.addCleanup(client_context.client.drop_database, "test_lazy_connect_w0") client = self.rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) client.test_lazy_connect_w0.test.insert_one({}) def predicate(): @@ -1455,7 +1435,6 @@ def predicate(): wait_until(predicate, "find one document") client = self.rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) def predicate(): @@ -1464,7 +1443,6 @@ def predicate(): wait_until(predicate, "update one document") client = self.rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) client.test_lazy_connect_w0.test.delete_one({}) def predicate(): @@ -1477,7 +1455,6 @@ def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = self.rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addCleanup(client.close) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1686,7 +1663,6 @@ def compression_settings(client): def test_reset_during_update_pool(self): client = self.rs_or_single_client(minPoolSize=10) - self.addCleanup(client.close) client.admin.command("ping") pool = get_pool(client) generation = pool.gen.get_overall() @@ -1735,8 +1711,6 @@ def test_background_connections_do_not_hold_locks(self): client = self.rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addCleanup(client.close) - # Create a single connection in the pool. client.admin.command("ping") @@ -1798,11 +1772,10 @@ def server_description_count(): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = MongoClient( + client = self.simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): client.test.test.find_one() gc.collect() @@ -1816,7 +1789,6 @@ def server_description_count(): @client_context.require_failCommand_fail_point def test_network_error_message(self): client = self.single_client(retryReads=False) - self.addCleanup(client.close) client.admin.command("ping") # connect with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1852,7 +1824,6 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) - self.addCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc" @@ -1860,26 +1831,21 @@ def test_service_name_from_kwargs(self): srvServiceName="customname", connect=False, ) - self.addCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") client = MongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc/?srvServiceName=customname", connect=False, ) - self.addCleanup(client.close) self.assertEqual(client._topology_settings.srv_service_name, "customname") def test_srv_max_hosts_kwarg(self): client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") - self.addCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) - self.addCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 1) client = MongoClient( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) - self.addCleanup(client.close) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @unittest.skipIf( @@ -2025,7 +1991,6 @@ def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = connected(self.rs_or_single_client(maxPoolSize=1)) - self.addCleanup(client.close) collection = client.pymongo_test.test pool = get_pool(client) @@ -2049,7 +2014,6 @@ def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = self.rs_or_single_client(maxPoolSize=1) - self.addCleanup(client.close) collection = client.pymongo_test.test collection.drop() @@ -2089,7 +2053,6 @@ def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. client = connected(self.rs_or_single_client(maxPoolSize=1, retryReads=False)) - self.addCleanup(client.close) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2111,7 +2074,6 @@ def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. client = self.rs_or_single_client(maxPoolSize=1) - self.addCleanup(client.close) collection = client.pymongo_test.test collection.drop() collection.insert_many([{} for _ in range(200)]) # More than one batch. @@ -2301,6 +2263,7 @@ def test_discover_primary(self): replicaSet="rs", heartbeatFrequencyMS=500, ) + self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2327,6 +2290,7 @@ def test_reconnect(self): retryReads=False, serverSelectionTimeoutMS=1000, ) + self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") @@ -2362,6 +2326,8 @@ def _test_network_error(self, operation_callback): serverSelectionTimeoutMS=1000, ) + self.addCleanup(c.close) + # Set host-specific information so we can test whether it is reset. c.set_wire_version_range("a:1", 2, 6) c.set_wire_version_range("b:2", 2, 7) @@ -2436,6 +2402,7 @@ def test_rs_client_does_not_maintain_pool_to_arbiters(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) + self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 3, "connect") self.assertEqual(c.address, ("a", 1)) @@ -2465,6 +2432,7 @@ def test_direct_client_maintains_pool_to_arbiter(self): minPoolSize=1, # minPoolSize event_listeners=[listener], ) + self.addCleanup(c.close) wait_until(lambda: len(c.nodes) == 1, "connect") self.assertEqual(c.address, ("c", 3)) diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index 4b3b5a2a97..ebbdc74c1c 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -96,7 +96,6 @@ def setUp(self): def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) models = [] for _ in range(self.max_write_batch_size + 1): @@ -122,7 +121,6 @@ def test_batch_splits_if_num_operations_too_large(self): def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -159,7 +157,6 @@ def test_collects_write_concern_errors_across_batches(self): event_listeners=[listener], retryWrites=False, ) - self.addCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -199,7 +196,6 @@ def test_collects_write_concern_errors_across_batches(self): def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -230,7 +226,6 @@ def test_collects_write_errors_across_batches_unordered(self): def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -261,7 +256,6 @@ def test_collects_write_errors_across_batches_ordered(self): def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -303,7 +297,6 @@ def test_handles_cursor_requiring_getMore(self): def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -347,7 +340,6 @@ def test_handles_cursor_requiring_getMore_within_transaction(self): def test_handles_getMore_error(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -402,7 +394,6 @@ def test_handles_getMore_error(self): def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) b_repeated = "b" * self.max_bson_object_size @@ -459,7 +450,6 @@ def _setup_namespace_test_models(self): def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) num_models, models = self._setup_namespace_test_models() models.append( @@ -491,7 +481,6 @@ def test_no_batch_splits_if_new_namespace_is_not_too_large(self): def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) num_models, models = self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -529,7 +518,6 @@ def test_batch_splits_if_new_namespace_is_too_large(self): @client_context.require_no_serverless def test_returns_error_if_no_writes_can_be_added_to_ops(self): client = self.rs_or_single_client() - self.addCleanup(client.close) # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -553,7 +541,6 @@ def test_returns_error_if_auto_encryption_configured(self): kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -610,7 +597,6 @@ def test_timeout_in_multi_batch_bulk_write(self): timeoutMS=2000, w="majority", ) - self.addCleanup(client.close) client.admin.command("ping") # Init the client first. with self.assertRaises(ClientBulkWriteException) as context: client.bulk_write(models=models) diff --git a/test/test_cursor.py b/test/test_cursor.py index 520229902b..9bc22aca3c 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -345,7 +345,6 @@ def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(coll.find().explain()) started = listener.started_events @@ -1252,7 +1251,6 @@ def test_close_kills_cursor_synchronously(self): listener = AllowListEventListener("killCursors") client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1291,7 +1289,6 @@ def assertCursorKilled(): def test_timeout_kills_cursor_synchronously(self): listener = AllowListEventListener("killCursors") client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1349,7 +1346,6 @@ def test_delete_not_initialized(self): def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) # We never send primary read preference so override the default. coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1777,7 +1773,6 @@ def test_monitoring(self): def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) c = client.pymongo_test.test c.delete_many({}) c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/test_database.py b/test/test_database.py index 144c357c52..fe07f343c5 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -234,7 +234,6 @@ def test_list_collection_names_filter(self): def test_check_exists(self): listener = OvertCommandListener() client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) db = client[self.db.name] db.drop_collection("unique") db.create_collection("unique", check_exists=True) diff --git a/test/test_encryption.py b/test/test_encryption.py index ec146e7f2b..2f9b99fd09 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -31,7 +31,7 @@ from test import IntegrationTest, PyMongoTestCase, client_context from test.test_bulk import BulkTestBase from threading import Thread -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional import pytest @@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(PyMongoTestCase): @unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed") def test_crypt_shared(self): # Test that we can pick up crypt_shared lib automatically - client = MongoClient( + self.simple_client( auto_encryption_opts=AutoEncryptionOpts( KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True ), connect=False, ) - self.addCleanup(client.close) @unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed") def test_init_requires_pymongocrypt(self): @@ -196,19 +195,16 @@ def test_init_kms_tls_options(self): class TestClientOptions(PyMongoTestCase): def test_default(self): - client = MongoClient(connect=False) - self.addCleanup(client.close) + client = self.simple_client(connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) - client = MongoClient(auto_encryption_opts=None, connect=False) - self.addCleanup(client.close) + client = self.simple_client(auto_encryption_opts=None, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_kwargs(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = MongoClient(auto_encryption_opts=opts, connect=False) - self.addCleanup(client.close) + client = self.simple_client(auto_encryption_opts=opts, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) @@ -229,6 +225,34 @@ def assertBinaryUUID(self, val): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, UUID_SUBTYPE) + def create_client_encryption( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient, + codec_options: Mapping[str, Any], + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = ClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + self.addCleanup(client_encryption.close) + return client_encryption + + @classmethod + def unmanaged_create_client_encryption( + cls, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient, + codec_options: Mapping[str, Any], + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = ClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + return client_encryption + # Location of JSON test files. if _IS_SYNC: @@ -261,7 +285,6 @@ def bson_data(*paths): class TestClientSimple(EncryptionIntegrationTest): def _test_auto_encrypt(self, opts): client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) # Create the encrypted field's data key. key_vault = create_key_vault( @@ -343,7 +366,6 @@ def test_auto_encrypt_local_schema_map(self): def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) client.admin.command("ping") client.close() @@ -361,7 +383,6 @@ def test_use_after_close(self): def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) def target(): with warnings.catch_warnings(): @@ -376,7 +397,6 @@ class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest): def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) encrypted_coll = client.pymongo_test.test @@ -417,7 +437,6 @@ def _setup_class(cls): def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.insert_one({}) @@ -431,7 +450,6 @@ def test_raise_max_wire_version_error(self): def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): client.test.test.find_raw_batches({}) @@ -450,10 +468,9 @@ def test_raise_unsupported_error(self): class TestExplicitSimple(EncryptionIntegrationTest): def test_encrypt_decrypt(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) # Use standard UUID representation. key_vault = client_context.client.keyvault.get_collection("datakeys", codec_options=OPTS) self.addCleanup(key_vault.drop) @@ -493,10 +510,9 @@ def test_encrypt_decrypt(self): self.assertEqual(decrypted_ssn, doc["ssn"]) def test_validation(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) msg = "value to decrypt must be a bson.binary.Binary with subtype 6" with self.assertRaisesRegex(TypeError, msg): @@ -510,10 +526,9 @@ def test_validation(self): client_encryption.encrypt("str", algo, key_id=Binary(b"123")) def test_bson_errors(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. unencodable_value = object() @@ -526,7 +541,7 @@ def test_bson_errors(self): def test_codec_options(self): with self.assertRaisesRegex(TypeError, "codec_options must be"): - ClientEncryption( + self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, @@ -534,10 +549,9 @@ def test_codec_options(self): ) opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY) - client_encryption_legacy = ClientEncryption( + client_encryption_legacy = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts ) - self.addCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. key_id = client_encryption_legacy.create_data_key("local") @@ -552,10 +566,9 @@ def test_codec_options(self): # Encrypt the same UUID with STANDARD codec options. opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts ) - self.addCleanup(client_encryption.close) encrypted_standard = client_encryption.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id ) @@ -571,7 +584,7 @@ def test_codec_options(self): self.assertNotEqual(client_encryption.decrypt(encrypted_legacy), value) def test_close(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) client_encryption.close() @@ -587,7 +600,7 @@ def test_close(self): client_encryption.decrypt(Binary(b"", 6)) def test_with_statement(self): - with ClientEncryption( + with self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) as client_encryption: pass @@ -832,7 +845,7 @@ def _setup_class(cls): cls.client_encrypted = cls.unmanaged_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = ClientEncryption( + cls.client_encryption = cls.unmanaged_create_client_encryption( cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) @@ -920,7 +933,6 @@ def _test_external_key_vault(self, with_external_key_vault): schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: key_vault_client = self.rs_or_single_client(username="fake-user", password="fake-pwd") - self.addCleanup(key_vault_client.close) else: key_vault_client = client_context.client opts = AutoEncryptionOpts( @@ -933,12 +945,10 @@ def _test_external_key_vault(self, with_external_key_vault): client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addCleanup(client_encrypted.close) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS ) - self.addCleanup(client_encryption.close) if with_external_key_vault: # Authentication error. @@ -987,7 +997,6 @@ def test_views_are_prohibited(self): client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addCleanup(client_encrypted.close) with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"): client_encrypted.db.view.insert_one({}) @@ -1045,16 +1054,14 @@ def _test_corpus(self, opts): self.addCleanup(vault.drop) client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", client_context.client, OPTS, kms_tls_options=KMS_TLS_OPTS, ) - self.addCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data("corpus", "corpus.json")) corpus_copied: SON = SON() @@ -1285,7 +1292,7 @@ def setUp(self): "gcp": GCP_CREDS, "kmip": KMIP_CREDS, } - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers=kms_providers, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -1297,7 +1304,7 @@ def setUp(self): kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698" - self.client_encryption_invalid = ClientEncryption( + self.client_encryption_invalid = self.create_client_encryption( kms_providers=kms_providers_invalid, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -1476,7 +1483,7 @@ def test_12_kmip_master_key_invalid_endpoint(self): self.client_encryption.create_data_key("kmip", key) -class AzureGCPEncryptionTestMixin: +class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): DEK = None KMS_PROVIDER_MAP = None KEYVAULT_DB = "keyvault" @@ -1488,7 +1495,7 @@ def setUp(self): create_key_vault(keyvault, self.DEK) def _test_explicit(self, expectation): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, @@ -1599,7 +1606,6 @@ def setUp(self): self.client_test = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) - self.addCleanup(self.client_test.close) self.client_keyvault_listener = OvertCommandListener() self.client_keyvault = self.rs_or_single_client( @@ -1608,7 +1614,6 @@ def setUp(self): w="majority", event_listeners=[self.client_keyvault_listener], ) - self.addCleanup(self.client_keyvault.close) self.client_test.keyvault.datakeys.drop() self.client_test.db.coll.drop() @@ -1619,7 +1624,7 @@ def setUp(self): codec_options=OPTS, ) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, key_vault_namespace="keyvault.datakeys", key_vault_client=self.client_test, @@ -1653,8 +1658,6 @@ def _run_test(self, max_pool_size, auto_encryption_opts): result = client_encrypted.db.coll.find_one({"_id": 0}) self.assertEqual(result, {"_id": 0, "encrypted": "string0"}) - self.addCleanup(client_encrypted.close) - def test_case_1(self): self._run_test( max_pool_size=1, @@ -1830,7 +1833,7 @@ def setUp(self): create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) keyID = self.client_encryption.create_data_key("local") @@ -1848,7 +1851,6 @@ def setUp(self): self.encrypted_client = self.rs_or_single_client( auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] ) - self.addCleanup(self.encrypted_client.close) def test_01_command_error(self): with self.fail_point( @@ -1926,7 +1928,6 @@ def reset_timeout(): ], ) client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) with self.assertRaisesRegex(EncryptionError, "Timeout"): client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1941,11 +1942,11 @@ def test_bypassAutoEncryption(self): ], ) client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: - mongocryptd_client = MongoClient("mongodb://localhost:27027/?serverSelectionTimeoutMS=500") - self.addCleanup(mongocryptd_client.close) + mongocryptd_client = self.simple_client( + "mongodb://localhost:27027/?serverSelectionTimeoutMS=500" + ) with self.assertRaises(ServerSelectionTimeoutError): mongocryptd_client.admin.command("ping") @@ -1968,14 +1969,12 @@ def test_via_loading_shared_library(self): crypt_shared_lib_required=True, ) client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) self.assertEncrypted((client_context.client.db.coll.find_one({}))["encrypted"]) - no_mongocryptd_client = MongoClient( + no_mongocryptd_client = self.simple_client( host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000" ) - self.addCleanup(no_mongocryptd_client.close) with self.assertRaises(ServerSelectionTimeoutError): no_mongocryptd_client.db.command("ping") @@ -2010,7 +2009,6 @@ def listener(): crypt_shared_lib_required=False, ) client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) server.shutdown() @@ -2024,10 +2022,9 @@ class TestKmsTLSProse(EncryptionIntegrationTest): def setUp(self): super().setUp() self.patch_system_certs(CA_PEM) - self.client_encrypted = ClientEncryption( + self.client_encrypted = self.create_client_encryption( {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS ) - self.addCleanup(self.client_encrypted.close) def test_invalid_kms_certificate_expired(self): key = { @@ -2072,36 +2069,32 @@ def setUp(self): "gcp": {"tlsCAFile": CA_PEM}, "kmip": {"tlsCAFile": CA_PEM}, } - self.client_encryption_no_client_cert = ClientEncryption( + self.client_encryption_no_client_cert = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_no_client_cert.close) # 2, same providers as above but with tlsCertificateKeyFile. kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only) for p in kms_tls_opts: kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM - self.client_encryption_with_tls = ClientEncryption( + self.client_encryption_with_tls = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts ) - self.addCleanup(self.client_encryption_with_tls.close) # 3, update endpoints to expired host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000" providers["gcp"]["endpoint"] = "127.0.0.1:9000" providers["kmip"]["endpoint"] = "127.0.0.1:9000" - self.client_encryption_expired = ClientEncryption( + self.client_encryption_expired = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_expired.close) # 3, update endpoints to invalid host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001" providers["gcp"]["endpoint"] = "127.0.0.1:9001" providers["kmip"]["endpoint"] = "127.0.0.1:9001" - self.client_encryption_invalid_hostname = ClientEncryption( + self.client_encryption_invalid_hostname = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_invalid_hostname.close) # Errors when client has no cert, some examples: # [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623) self.cert_error = ( @@ -2139,7 +2132,7 @@ def setUp(self): "gcp:with_tls": with_cert, "kmip:with_tls": with_cert, } - self.client_encryption_with_names = ClientEncryption( + self.client_encryption_with_names = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4 ) @@ -2221,10 +2214,9 @@ def test_04_kmip(self): def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}} options = {"aws": {"tlsDisableOCSPEndpointCheck": True}} - encryption = ClientEncryption( + encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - self.addCleanup(encryption.close) ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") @@ -2274,7 +2266,7 @@ def setUp(self): self.client = client_context.client create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) self.def_key_id = self.client_encryption.create_data_key("local", key_alt_names=["def"]) @@ -2312,17 +2304,15 @@ def setUp(self): key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(self.encrypted_client.close) def test_01_insert_encrypted_indexed_and_find(self): val = "encrypted indexed value" @@ -2445,14 +2435,13 @@ def run_test(self, src_provider, dst_provider): self.client.keyvault.drop_collection("datakeys") # Step 2. Create a ``ClientEncryption`` object named ``client_encryption1`` - client_encryption1 = ClientEncryption( + client_encryption1 = self.create_client_encryption( key_vault_client=self.client, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addCleanup(client_encryption1.close) # Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``. key_id = client_encryption1.create_data_key( @@ -2466,15 +2455,13 @@ def run_test(self, src_provider, dst_provider): # Step 5. Create a ``ClientEncryption`` object named ``client_encryption2`` client2 = self.rs_or_single_client() - self.addCleanup(client2.close) - client_encryption2 = ClientEncryption( + client_encryption2 = self.create_client_encryption( key_vault_client=client2, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addCleanup(client_encryption2.close) # Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``. rewrap_many_data_key_result = client_encryption2.rewrap_many_data_key( @@ -2509,7 +2496,7 @@ def setUp(self): @unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set") def test_01_failure(self): - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -2520,7 +2507,7 @@ def test_01_failure(self): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_02_success(self): - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -2541,7 +2528,6 @@ def test_queryable_encryption(self): # and cleanup. def MongoClient(**kwargs): c = self.rs_or_single_client(**kwargs) - self.addCleanup(c.close) return c # Drop data from prior test runs. @@ -2552,7 +2538,7 @@ def MongoClient(**kwargs): # Create two data keys. key_vault_client = MongoClient() - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions() ) key1_id = client_encryption.create_data_key("local") @@ -2633,10 +2619,9 @@ def setUp(self): key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, @@ -2644,7 +2629,6 @@ def setUp(self): ) self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db - self.addCleanup(self.encrypted_client.close) def run_expression_find( self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None @@ -2839,10 +2823,9 @@ def setUp(self): super().setUp() self.client.drop_database(self.db) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) self.key_id = self.client_encryption.create_data_key("local") opts = RangeOpts(min=0, max=1000) self.payload_defaults = self.client_encryption.encrypt( @@ -2875,13 +2858,12 @@ def setUp(self): self.client.drop_database(self.db) self.key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(self.key_vault.drop) - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, self.key_vault.full_name, self.client, OPTS, ) - self.addCleanup(self.client_encryption.close) def test_01_simple_create(self): coll, _ = self.client_encryption.create_encrypted_collection( @@ -3097,10 +3079,9 @@ def _tearDown_class(cls): def setUp(self) -> None: self.listener = OvertCommandListener() - self.mongocryptd_client = MongoClient( + self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] ) - self.addCleanup(self.mongocryptd_client.close) hello = self.mongocryptd_client.db.command("hello") self.assertNotIn("logicalSessionTimeoutMinutes", hello) diff --git a/test/test_grid_file.py b/test/test_grid_file.py index 815d95cf03..fe88aec5ff 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -790,8 +790,7 @@ def test_grid_out_lazy_connect(self): outfile.readchunk() def test_grid_in_lazy_connect(self): - client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) - self.addCleanup(client.close) + client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = GridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): diff --git a/test/test_session.py b/test/test_session.py index d0f5c6e6d9..9f94ded927 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -787,7 +787,6 @@ def test_unacknowledged_writes(self): # Ensure the collection exists. self.client.pymongo_test.test_unacked_writes.insert_one({}) client = self.rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addCleanup(client.close) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -1137,7 +1136,6 @@ def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) - self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) From 4df385c76f1d8625cf9750bcca0970de0869cc0f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 12 Sep 2024 17:15:53 -0400 Subject: [PATCH 11/29] Cleanup part 2 --- test/asynchronous/test_client.py | 6 +++--- test/asynchronous/test_encryption.py | 4 ++-- test/test_client.py | 6 +++--- test/test_encryption.py | 4 ++-- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index a25ecccf63..4369bf1481 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -1881,11 +1881,11 @@ async def test_service_name_from_kwargs(self): self.assertEqual(client._topology_settings.srv_service_name, "customname") async def test_srv_max_hosts_kwarg(self): - client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = AsyncMongoClient( + client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) self.assertEqual(len(client.topology_description.server_descriptions()), 2) diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 3000d2361d..3f3714eeb4 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -230,7 +230,7 @@ def create_client_encryption( kms_providers: Mapping[str, Any], key_vault_namespace: str, key_vault_client: AsyncMongoClient, - codec_options: Mapping[str, Any], + codec_options: CodecOptions, kms_tls_options: Optional[Mapping[str, Any]] = None, ): client_encryption = AsyncClientEncryption( @@ -245,7 +245,7 @@ def unmanaged_create_client_encryption( kms_providers: Mapping[str, Any], key_vault_namespace: str, key_vault_client: AsyncMongoClient, - codec_options: Mapping[str, Any], + codec_options: CodecOptions, kms_tls_options: Optional[Mapping[str, Any]] = None, ): client_encryption = AsyncClientEncryption( diff --git a/test/test_client.py b/test/test_client.py index 4a8b7dfcb9..b293158fc0 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -1839,11 +1839,11 @@ def test_service_name_from_kwargs(self): self.assertEqual(client._topology_settings.srv_service_name, "customname") def test_srv_max_hosts_kwarg(self): - client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = MongoClient( + client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) self.assertEqual(len(client.topology_description.server_descriptions()), 2) diff --git a/test/test_encryption.py b/test/test_encryption.py index 2f9b99fd09..96d40c4a34 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -230,7 +230,7 @@ def create_client_encryption( kms_providers: Mapping[str, Any], key_vault_namespace: str, key_vault_client: MongoClient, - codec_options: Mapping[str, Any], + codec_options: CodecOptions, kms_tls_options: Optional[Mapping[str, Any]] = None, ): client_encryption = ClientEncryption( @@ -245,7 +245,7 @@ def unmanaged_create_client_encryption( kms_providers: Mapping[str, Any], key_vault_namespace: str, key_vault_client: MongoClient, - codec_options: Mapping[str, Any], + codec_options: CodecOptions, kms_tls_options: Optional[Mapping[str, Any]] = None, ): client_encryption = ClientEncryption( From 772dcf0ff992f1b37efe5aedecd659156a311561 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 13 Sep 2024 10:03:10 -0400 Subject: [PATCH 12/29] Fix failures --- test/asynchronous/test_client.py | 2 +- test/test_dns.py | 24 ++++---- test/test_max_staleness.py | 1 - test/test_ssl.py | 101 ++++++++++++++++--------------- 4 files changed, 66 insertions(+), 62 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 4369bf1481..d67fb2aa4a 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -408,7 +408,7 @@ async def test_metadata(self): ) @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) - async def test_container_metadata(self): + def test_container_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["env"] = {} diff --git a/test/test_dns.py b/test/test_dns.py index b4c5e3684c..ac2059fa99 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -22,16 +22,15 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import IntegrationTest, PyMongoTestCase, client_context, unittest from test.utils import wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.synchronous.mongo_client import MongoClient from pymongo.uri_parser import parse_uri, split_hosts -class TestDNSRepl(unittest.TestCase): +class TestDNSRepl(PyMongoTestCase): TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set" ) @@ -42,7 +41,7 @@ def setUp(self): pass -class TestDNSLoadBalanced(unittest.TestCase): +class TestDNSLoadBalanced(PyMongoTestCase): TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced" ) @@ -53,7 +52,7 @@ def setUp(self): pass -class TestDNSSharded(unittest.TestCase): +class TestDNSSharded(PyMongoTestCase): TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded") load_balanced = False @@ -120,7 +119,7 @@ def run_test(self): # tests. copts["tlsAllowInvalidHostnames"] = True - client = MongoClient(uri, **copts) + client = PyMongoTestCase.unmanaged_single_client(uri, **copts) if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: @@ -133,6 +132,7 @@ def run_test(self): client.admin.command("ping") # XXX: we should block until SRV poller runs at least once # and re-run these assertions. + client.close() else: try: parse_uri(uri) @@ -157,37 +157,37 @@ def create_tests(cls): create_tests(TestDNSSharded) -class TestParsingErrors(unittest.TestCase): +class TestParsingErrors(PyMongoTestCase): def test_invalid_host(self): self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb is not", - MongoClient, + self.simple_client, "mongodb+srv://mongodb", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb.com is not", - MongoClient, + self.simple_client, "mongodb+srv://mongodb.com", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, + self.simple_client, "mongodb+srv://127.0.0.1", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, + self.simple_client, "mongodb+srv://[::1]", ) class TestCaseInsensitive(IntegrationTest): def test_connect_case_insensitive(self): - client = MongoClient("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") self.addCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index fb7313d1f2..283f5ae5d0 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -27,7 +27,6 @@ from test import PyMongoTestCase, client_context, unittest from test.utils_selection_tests import create_selection_tests -from pymongo import MongoClient from pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector diff --git a/test/test_ssl.py b/test/test_ssl.py index 5b3855a82a..36d7ba12b6 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -24,6 +24,7 @@ from test import ( HAVE_IPADDRESS, IntegrationTest, + PyMongoTestCase, SkipTest, client_context, connected, @@ -82,45 +83,45 @@ # use 'localhost' for the hostname of all hosts. -class TestClientSSL(unittest.TestCase): +class TestClientSSL(PyMongoTestCase): @unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.") def test_no_ssl_module(self): # Explicit - self.assertRaises(ConfigurationError, MongoClient, ssl=True) + self.assertRaises(ConfigurationError, self.simple_client, ssl=True) # Implied - self.assertRaises(ConfigurationError, MongoClient, tlsCertificateKeyFile=CLIENT_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM) @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") @ignore_deprecations def test_config_ssl(self): # Tests various ssl configurations - self.assertRaises(ValueError, MongoClient, ssl="foo") + self.assertRaises(ValueError, self.simple_client, ssl="foo") self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM ) - self.assertRaises(TypeError, MongoClient, ssl=0) - self.assertRaises(TypeError, MongoClient, ssl=5.5) - self.assertRaises(TypeError, MongoClient, ssl=[]) + self.assertRaises(TypeError, self.simple_client, ssl=0) + self.assertRaises(TypeError, self.simple_client, ssl=5.5) + self.assertRaises(TypeError, self.simple_client, ssl=[]) - self.assertRaises(IOError, MongoClient, tlsCertificateKeyFile="NoSuchFile") - self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=True) - self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=[]) + self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile") + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True) + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[]) # Test invalid combinations self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM ) - self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCAFile=CA_PEM) - self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCRLFile=CRL_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsAllowInvalidCertificates=False + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False ) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsAllowInvalidHostnames=False + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False ) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsDisableOCSPEndpointCheck=False + ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False ) @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") @@ -174,7 +175,7 @@ def test_tlsCertificateKeyFilePassword(self): if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, - MongoClient, + self.simple_client, "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, @@ -184,7 +185,7 @@ def test_tlsCertificateKeyFilePassword(self): ) else: connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, @@ -201,7 +202,7 @@ def test_tlsCertificateKeyFilePassword(self): "&tlsCAFile=%s&serverSelectionTimeoutMS=5000" ) connected( - MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] ) @client_context.require_tlsCertificateKeyFile @@ -215,7 +216,7 @@ def test_cert_ssl_implicitly_set(self): # # test that setting tlsCertificateKeyFile causes ssl to be set to True - client = MongoClient( + client = self.simple_client( client_context.host, client_context.port, tlsAllowInvalidCertificates=True, @@ -223,7 +224,7 @@ def test_cert_ssl_implicitly_set(self): ) response = client.admin.command(HelloCompat.LEGACY_CMD) if "setName" in response: - client = MongoClient( + client = self.simple_client( client_context.pair, replicaSet=response["setName"], w=len(response["hosts"]), @@ -242,7 +243,7 @@ def test_cert_ssl_validation(self): # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # - client = MongoClient( + client = self.simple_client( "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -257,7 +258,7 @@ def test_cert_ssl_validation(self): "Cannot validate hostname in the certificate" ) - client = MongoClient( + client = self.simple_client( "localhost", replicaSet=response["setName"], w=len(response["hosts"]), @@ -270,7 +271,7 @@ def test_cert_ssl_validation(self): self.assertClientWorks(client) if HAVE_IPADDRESS: - client = MongoClient( + client = self.simple_client( "127.0.0.1", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -292,7 +293,7 @@ def test_cert_ssl_uri_support(self): "mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates" "=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false" ) - client = MongoClient(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) + client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) self.assertClientWorks(client) @client_context.require_tlsCertificateKeyFile @@ -316,7 +317,7 @@ def test_cert_ssl_validation_hostname_matching(self): with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -328,7 +329,7 @@ def test_cert_ssl_validation_hostname_matching(self): ) connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -343,7 +344,7 @@ def test_cert_ssl_validation_hostname_matching(self): if "setName" in response: with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "server", replicaSet=response["setName"], ssl=True, @@ -356,7 +357,7 @@ def test_cert_ssl_validation_hostname_matching(self): ) connected( - MongoClient( + self.simple_client( "server", replicaSet=response["setName"], ssl=True, @@ -375,7 +376,7 @@ def test_tlsCRLFile_support(self): if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, - MongoClient, + self.simple_client, "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -384,7 +385,7 @@ def test_tlsCRLFile_support(self): ) else: connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -395,7 +396,7 @@ def test_tlsCRLFile_support(self): with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -406,7 +407,7 @@ def test_tlsCRLFile_support(self): ) uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000" - connected(MongoClient(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore + connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore uri_fmt = ( "mongodb://localhost/?ssl=true&tlsCRLFile=%s" @@ -414,7 +415,7 @@ def test_tlsCRLFile_support(self): ) with self.assertRaises(ConnectionFailure): connected( - MongoClient(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] ) @client_context.require_tlsCertificateKeyFile @@ -431,12 +432,14 @@ def test_validation_with_system_ca_certs(self): with self.assertRaises(ConnectionFailure): # Server cert is verified but hostname matching fails connected( - MongoClient("server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type] + self.simple_client( + "server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] ) # Server cert is verified. Disable hostname matching. connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsAllowInvalidHostnames=True, @@ -447,12 +450,14 @@ def test_validation_with_system_ca_certs(self): # Server cert and hostname are verified. connected( - MongoClient("localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type] + self.simple_client( + "localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] ) # Server cert and hostname are verified. connected( - MongoClient( + self.simple_client( "mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000", **self.credentials, # type: ignore[arg-type] ) @@ -472,7 +477,7 @@ def test_system_certs_config_error(self): ssl_support.HAVE_WINCERTSTORE = False try: with self.assertRaises(ConfigurationError): - MongoClient("mongodb://localhost/?ssl=true") + self.simple_client("mongodb://localhost/?ssl=true") finally: ssl_support.HAVE_CERTIFI = have_certifi ssl_support.HAVE_WINCERTSTORE = have_wincertstore @@ -536,7 +541,7 @@ def test_mongodb_x509_auth(self): ], ) - noauth = MongoClient( + noauth = self.simple_client( client_context.pair, ssl=True, tlsAllowInvalidCertificates=True, @@ -548,7 +553,7 @@ def test_mongodb_x509_auth(self): noauth.pymongo_test.test.find_one() listener = EventListener() - auth = MongoClient( + auth = self.simple_client( client_context.pair, authMechanism="MONGODB-X509", ssl=True, @@ -572,7 +577,7 @@ def test_mongodb_x509_auth(self): host, port, ) - client = MongoClient( + client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(client.close) @@ -580,7 +585,7 @@ def test_mongodb_x509_auth(self): client.pymongo_test.test.find_one() uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port) - client = MongoClient( + client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(client.close) @@ -593,7 +598,7 @@ def test_mongodb_x509_auth(self): port, ) - bad_client = MongoClient( + bad_client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(bad_client.close) @@ -601,7 +606,7 @@ def test_mongodb_x509_auth(self): with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() - bad_client = MongoClient( + bad_client = self.simple_client( client_context.pair, username="not the username", authMechanism="MONGODB-X509", @@ -622,7 +627,7 @@ def test_mongodb_x509_auth(self): ) try: connected( - MongoClient( + self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, @@ -648,7 +653,7 @@ def remove(path): self.addCleanup(remove, temp_ca_bundle) # Add the CA cert file to the bundle. cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM) - with MongoClient( + with self.simple_client( "localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle ) as client: self.assertTrue(client.admin.command("ping")) From 1ea15cba837f78a25dea75045098d3cd5021ddb4 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 13 Sep 2024 10:48:24 -0400 Subject: [PATCH 13/29] More fixes --- test/__init__.py | 5 ++++- test/asynchronous/__init__.py | 5 ++++- test/test_dns.py | 2 +- 3 files changed, 9 insertions(+), 3 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 38d705ce46..32b3c96017 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1073,7 +1073,10 @@ def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> Mo return self._async_mongo_client(h, p, **kwargs) def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient: - client = MongoClient(h, p, **kwargs) + if not h and not p: + client = MongoClient(**kwargs) + else: + client = MongoClient(h, p, **kwargs) self.addCleanup(client.close) return client diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index e807720c05..959edbd368 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1089,7 +1089,10 @@ async def async_rs_or_single_client( return await self._async_mongo_client(h, p, **kwargs) def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient: - client = AsyncMongoClient(h, p, **kwargs) + if not h and not p: + client = AsyncMongoClient(**kwargs) + else: + client = AsyncMongoClient(h, p, **kwargs) self.addAsyncCleanup(client.close) return client diff --git a/test/test_dns.py b/test/test_dns.py index ac2059fa99..9123439073 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -119,7 +119,7 @@ def run_test(self): # tests. copts["tlsAllowInvalidHostnames"] = True - client = PyMongoTestCase.unmanaged_single_client(uri, **copts) + client = PyMongoTestCase.unmanaged_rs_or_single_client(uri, **copts) if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: From 0512fa08663555fa6eeee198a2d7c5cf547ef0d9 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 13 Sep 2024 11:29:53 -0400 Subject: [PATCH 14/29] Use simple_client instead of raw clients --- test/test_max_staleness.py | 5 +++-- test/test_monitor.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 283f5ae5d0..101a8745eb 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -20,6 +20,7 @@ import time import warnings +from pymongo import MongoClient from pymongo.operations import _Op sys.path[0:0] = [""] @@ -87,7 +88,7 @@ def test_max_staleness_float(self): with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = self.single_client( + client = self.simple_client( "mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest" ) @@ -104,7 +105,7 @@ def test_max_staleness_zero(self): with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = self.single_client( + client = self.simple_client( "mongodb://host/?maxStalenessSeconds=0&readPreference=nearest" ) diff --git a/test/test_monitor.py b/test/test_monitor.py index 5fb9f3f267..e66866bb89 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -49,7 +49,7 @@ def get_executors(client): class TestMonitor(IntegrationTest): def create_client(self): listener = ServerAndTopologyEventListener() - client = self.single_client(event_listeners=[listener]) + client = self.simple_client(event_listeners=[listener]) connected(client) return client From 55e12cbe33807d6941e5f8861da4dd50a4d5441b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Fri, 13 Sep 2024 12:40:31 -0400 Subject: [PATCH 15/29] More DNS fixes --- test/__init__.py | 8 ++++++++ test/asynchronous/__init__.py | 10 ++++++++++ test/test_dns.py | 2 +- test/test_monitor.py | 2 +- 4 files changed, 20 insertions(+), 2 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index 32b3c96017..c69b807a7a 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -1080,6 +1080,14 @@ def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoCli self.addCleanup(client.close) return client + @classmethod + def unmanaged_simple_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient: + if not h and not p: + client = MongoClient(**kwargs) + else: + client = MongoClient(h, p, **kwargs) + return client + def disable_replication(self, client): """Disable replication on all secondaries.""" for h, p in client.secondaries: diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 959edbd368..92b9c83245 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -1096,6 +1096,16 @@ def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMon self.addAsyncCleanup(client.close) return client + @classmethod + def unmanaged_simple_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient: + if not h and not p: + client = AsyncMongoClient(**kwargs) + else: + client = AsyncMongoClient(h, p, **kwargs) + return client + async def disable_replication(self, client): """Disable replication on all secondaries.""" for h, p in client.secondaries: diff --git a/test/test_dns.py b/test/test_dns.py index 9123439073..f2185efb1b 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -119,7 +119,7 @@ def run_test(self): # tests. copts["tlsAllowInvalidHostnames"] = True - client = PyMongoTestCase.unmanaged_rs_or_single_client(uri, **copts) + client = PyMongoTestCase.unmanaged_simple_client(uri, **copts) if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: diff --git a/test/test_monitor.py b/test/test_monitor.py index e66866bb89..5fb9f3f267 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -49,7 +49,7 @@ def get_executors(client): class TestMonitor(IntegrationTest): def create_client(self): listener = ServerAndTopologyEventListener() - client = self.simple_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) connected(client) return client From a27c852123f85851c1e8350cd30dca026a19a2c5 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 09:01:54 -0400 Subject: [PATCH 16/29] Fix test_cleanup_executors_on_client_del --- test/test_monitor.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/test/test_monitor.py b/test/test_monitor.py index 5fb9f3f267..f8e9443fae 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -18,6 +18,7 @@ import gc import subprocess import sys +import warnings from functools import partial sys.path[0:0] = [""] @@ -49,23 +50,25 @@ def get_executors(client): class TestMonitor(IntegrationTest): def create_client(self): listener = ServerAndTopologyEventListener() - client = self.single_client(event_listeners=[listener]) + client = self.unmanaged_single_client(event_listeners=[listener]) connected(client) return client def test_cleanup_executors_on_client_del(self): - client = self.create_client() - executors = get_executors(client) - self.assertEqual(len(executors), 4) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client = self.create_client() + executors = get_executors(client) + self.assertEqual(len(executors), 4) - # Each executor stores a weakref to itself in _EXECUTORS. - executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] + # Each executor stores a weakref to itself in _EXECUTORS. + executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] - del executors - del client + del executors + del client - for ref, name in executor_refs: - wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) + for ref, name in executor_refs: + wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) def test_cleanup_executors_on_client_close(self): client = self.create_client() From a54d4e3bad7581112300ef58a32bec72c5759399 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 09:15:08 -0400 Subject: [PATCH 17/29] Remove source reference from MongoClient.__del__ --- pymongo/asynchronous/mongo_client.py | 1 - pymongo/synchronous/mongo_client.py | 1 - 2 files changed, 2 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index f7fc8e5e81..6d0e5d5280 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1193,7 +1193,6 @@ def __del__(self) -> None: ), ResourceWarning, stacklevel=2, - source=self, ) except AttributeError: pass diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 5786bbf5a9..b2dff5b4ab 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1193,7 +1193,6 @@ def __del__(self) -> None: ), ResourceWarning, stacklevel=2, - source=self, ) except AttributeError: pass From 9cf059bf205d25c83c0b99aa94f9c137502b38c4 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 09:58:59 -0400 Subject: [PATCH 18/29] Fix test_srv_polling --- test/test_srv_polling.py | 20 ++++++++------------ 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 0b79867182..e01552bf7d 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -167,7 +167,7 @@ def dns_resolver_response(): # Patch timeouts to ensure short test running times. with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = self.rs_or_single_client(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) # Patch list of hosts returned by DNS query. with SrvPollingKnobs( @@ -231,7 +231,7 @@ def final_callback(): count_resolver_calls=True, ): # Client uses unpatched method to get initial nodelist - client = self.rs_or_single_client(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) # Invalid DNS resolver response should not change nodelist. self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) @@ -264,8 +264,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=0) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -279,8 +278,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -295,8 +293,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) with SrvPollingKnobs(nodelist_callback=nodelist_callback): sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) final_topology = set(client.topology_description.server_descriptions()) @@ -305,8 +302,7 @@ def nodelist_callback(): def test_does_not_flipflop(self): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=1) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1) old = set(client.topology_description.server_descriptions()) sleep(4 * WAIT_TIME) new = set(client.topology_description.server_descriptions()) @@ -323,7 +319,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = self.rs_or_single_client( + client = self.simple_client( "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname" ) with SrvPollingKnobs(nodelist_callback=nodelist_callback): @@ -340,7 +336,7 @@ def resolver_response(): min_srv_rescan_interval=WAIT_TIME, nodelist_callback=resolver_response, ): - client = self.rs_or_single_client(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) self.assertRaises( AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2 ) From ab018d7b8dd203ebccc37354d1db90290e168bec Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 10:47:30 -0400 Subject: [PATCH 19/29] More test fixes --- test/asynchronous/test_auth_spec.py | 2 +- test/asynchronous/test_transactions.py | 7 +++++++ test/test_auth_spec.py | 2 +- test/test_transactions.py | 7 +++++++ 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index 329b3eec62..476dc5b5c0 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -54,7 +54,7 @@ def run_test(self): warnings.simplefilter("default") self.assertRaises(Exception, AsyncMongoClient, uri, connect=False) else: - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 309ff5d6ae..1aa36c5eac 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -68,6 +68,13 @@ async def _setup_class(cls): await cls.unmanaged_async_single_client("{}:{}".format(*address)) ) + @classmethod + async def _tearDown_class(cls): + await super()._tearDown_class() + if cls.mongos_clients: + for client in cls.mongos_clients: + await client.close() + def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) if ( diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 38e5f19bf8..1e6bf6bc33 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -54,7 +54,7 @@ def run_test(self): warnings.simplefilter("default") self.assertRaises(Exception, MongoClient, uri, connect=False) else: - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) diff --git a/test/test_transactions.py b/test/test_transactions.py index 8110705600..c29a24c64d 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -66,6 +66,13 @@ def _setup_class(cls): for address in client_context.mongoses: cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) + @classmethod + def _tearDown_class(cls): + super()._tearDown_class() + if cls.mongos_clients: + for client in cls.mongos_clients: + client.close() + def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) if ( From a41e554c11417f622363375d1d8b308dd470c288 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 11:05:46 -0400 Subject: [PATCH 20/29] More fixes --- test/asynchronous/test_auth_spec.py | 3 ++- test/asynchronous/test_transactions.py | 5 ++--- test/test_auth_spec.py | 3 ++- test/test_transactions.py | 5 ++--- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index 476dc5b5c0..a6ab1cb331 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -20,6 +20,7 @@ import os import sys import warnings +from test.asynchronous import AsyncPyMongoTestCase sys.path[0:0] = [""] @@ -34,7 +35,7 @@ _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") -class TestAuthSpec(unittest.IsolatedAsyncioTestCase): +class TestAuthSpec(AsyncPyMongoTestCase): pass diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 1aa36c5eac..c47e37ba9b 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -70,10 +70,9 @@ async def _setup_class(cls): @classmethod async def _tearDown_class(cls): + for client in cls.mongos_clients: + await client.close() await super()._tearDown_class() - if cls.mongos_clients: - for client in cls.mongos_clients: - await client.close() def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 1e6bf6bc33..3c3a1a67ae 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -20,6 +20,7 @@ import os import sys import warnings +from test import PyMongoTestCase sys.path[0:0] = [""] @@ -34,7 +35,7 @@ _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") -class TestAuthSpec(unittest.TestCase): +class TestAuthSpec(PyMongoTestCase): pass diff --git a/test/test_transactions.py b/test/test_transactions.py index c29a24c64d..4eb7ecc08a 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -68,10 +68,9 @@ def _setup_class(cls): @classmethod def _tearDown_class(cls): + for client in cls.mongos_clients: + client.close() super()._tearDown_class() - if cls.mongos_clients: - for client in cls.mongos_clients: - client.close() def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) From b899e641240ec78864ef9717fde1b270c76c005a Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 13:06:53 -0400 Subject: [PATCH 21/29] More fixes --- test/__init__.py | 9 +++++---- test/asynchronous/__init__.py | 9 +++++---- test/asynchronous/test_transactions.py | 15 --------------- test/test_transactions.py | 13 ------------- 4 files changed, 10 insertions(+), 36 deletions(-) diff --git a/test/__init__.py b/test/__init__.py index c69b807a7a..1a17ff14c5 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -235,6 +235,9 @@ def _init_client(self): if not self._check_user_provided(): _create_user(self.client.admin, db_user, db_pwd) + if self.client: + self.client.close() + self.client = self._connect( host, port, @@ -261,9 +264,9 @@ def _init_client(self): if "setName" in hello: self.replica_set_name = str(hello["setName"]) self.is_rs = True + if self.client: + self.client.close() if self.auth_enabled: - if self.client: - self.client.close() # It doesn't matter which member we use as the seed here. self.client = pymongo.MongoClient( host, @@ -274,8 +277,6 @@ def _init_client(self): **self.default_client_options, ) else: - if self.client: - self.client.close() self.client = pymongo.MongoClient( host, port, replicaSet=self.replica_set_name, **self.default_client_options ) diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index 92b9c83245..0d94331587 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -235,6 +235,9 @@ async def _init_client(self): if not await self._check_user_provided(): await _create_user(self.client.admin, db_user, db_pwd) + if self.client: + await self.client.close() + self.client = await self._connect( host, port, @@ -261,9 +264,9 @@ async def _init_client(self): if "setName" in hello: self.replica_set_name = str(hello["setName"]) self.is_rs = True + if self.client: + await self.client.close() if self.auth_enabled: - if self.client: - await self.client.close() # It doesn't matter which member we use as the seed here. self.client = pymongo.AsyncMongoClient( host, @@ -274,8 +277,6 @@ async def _init_client(self): **self.default_client_options, ) else: - if self.client: - await self.client.close() self.client = pymongo.AsyncMongoClient( host, port, replicaSet=self.replica_set_name, **self.default_client_options ) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index c47e37ba9b..b5d0686417 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -59,21 +59,6 @@ class AsyncTransactionsBase(AsyncSpecRunner): - @classmethod - async def _setup_class(cls): - await super()._setup_class() - if async_client_context.supports_transactions(): - for address in async_client_context.mongoses: - cls.mongos_clients.append( - await cls.unmanaged_async_single_client("{}:{}".format(*address)) - ) - - @classmethod - async def _tearDown_class(cls): - for client in cls.mongos_clients: - await client.close() - await super()._tearDown_class() - def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) if ( diff --git a/test/test_transactions.py b/test/test_transactions.py index 4eb7ecc08a..3cecbe9d38 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -59,19 +59,6 @@ class TransactionsBase(SpecRunner): - @classmethod - def _setup_class(cls): - super()._setup_class() - if client_context.supports_transactions(): - for address in client_context.mongoses: - cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address))) - - @classmethod - def _tearDown_class(cls): - for client in cls.mongos_clients: - client.close() - super()._tearDown_class() - def maybe_skip_scenario(self, test): super().maybe_skip_scenario(test) if ( From 58aeb162ea04229470efd7f9a441703aea86c89f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 14:17:54 -0400 Subject: [PATCH 22/29] Fix change stream tests --- test/asynchronous/test_change_stream.py | 22 ++++++++++++++-------- test/test_change_stream.py | 11 ++++++++--- 2 files changed, 22 insertions(+), 11 deletions(-) diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 1b89c43bb7..883ed72c4c 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -28,12 +28,17 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, Version, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + Version, + async_client_context, + unittest, +) from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, EventListener, - async_rs_or_single_client, async_wait_until, ) @@ -69,8 +74,7 @@ async def change_stream(self, *args, **kwargs): async def client_with_listener(self, *commands): """Return a client with a AllowListEventListener.""" listener = AllowListEventListener(*commands) - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) return client, listener def watched_collection(self, *args, **kwargs): @@ -176,7 +180,7 @@ async def _wait_until(): @no_type_check async def test_try_next_runs_one_getmore(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. await client.admin.command("ping") listener.reset() @@ -234,7 +238,7 @@ async def _wait_until(): @no_type_check async def test_batch_size_is_honored(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. await client.admin.command("ping") listener.reset() @@ -481,7 +485,9 @@ class ProseSpecTestsMixin: @no_type_check async def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( + event_listeners=[listener] + ) self.addAsyncCleanup(client.close) return client, listener @@ -1131,7 +1137,7 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) @classmethod async def _tearDown_class(cls): diff --git a/test/test_change_stream.py b/test/test_change_stream.py index fb4542946e..dae224c5e0 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -28,7 +28,13 @@ sys.path[0:0] = [""] -from test import IntegrationTest, Version, client_context, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + Version, + client_context, + unittest, +) from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, @@ -69,7 +75,6 @@ def client_with_listener(self, *commands): """Return a client with a AllowListEventListener.""" listener = AllowListEventListener(*commands) client = self.rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) return client, listener def watched_collection(self, *args, **kwargs): @@ -472,7 +477,7 @@ class ProseSpecTestsMixin: @no_type_check def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) - client = self.rs_or_single_client(event_listeners=[listener]) + client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener From 62d27162fbfcf1ec7f0345458fa88ced44b80262 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 14:47:55 -0400 Subject: [PATCH 23/29] Fix oidc and ocsp tests --- test/auth_oidc/test_auth_oidc.py | 1 + test/ocsp/test_ocsp.py | 1 + 2 files changed, 2 insertions(+) diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index fa4b7d6697..a739ab4ac7 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -94,6 +94,7 @@ def fail_point(self, command_args): yield finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") + client.close() @pytest.mark.auth_oidc diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index fe7f21160e..cecf4ab37c 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -50,6 +50,7 @@ def _connect(options): print(uri) client = pymongo.MongoClient(uri) client.admin.command("ping") + client.close() class TestOCSP(unittest.TestCase): From 0f67b0d45d6af5796b6784c463eef1348dfe65ef Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 14:54:36 -0400 Subject: [PATCH 24/29] More fixes --- test/auth_aws/test_auth_aws.py | 5 +++-- test/auth_oidc/test_auth_oidc.py | 7 +++++-- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index 10416ae5fe..9798b99868 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -18,6 +18,7 @@ import os import sys import unittest +from test import PyMongoTestCase from unittest.mock import patch import pytest @@ -185,14 +186,14 @@ def test_no_cache_environment_variables(self): client2.get_database().test.find_one() -class TestAWSLambdaExamples(unittest.TestCase): +class TestAWSLambdaExamples(PyMongoTestCase): def test_shared_client(self): # Start AWS Lambda Example 1 import os from pymongo import MongoClient - client = MongoClient(host=os.environ["MONGODB_URI"]) + client = self.simple_client(host=os.environ["MONGODB_URI"]) def lambda_handler(event, context): return client.db.command("ping") diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index a739ab4ac7..6d31f3db4e 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -23,6 +23,7 @@ import warnings from contextlib import contextmanager from pathlib import Path +from test import PyMongoTestCase from typing import Dict import pytest @@ -56,7 +57,7 @@ pytestmark = pytest.mark.auth_oidc -class OIDCTestBase(unittest.TestCase): +class OIDCTestBase(PyMongoTestCase): @classmethod def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] @@ -150,7 +151,9 @@ def create_client(self, *args, **kwargs): if not len(args): args = [self.uri_single] - return MongoClient(*args, authmechanismproperties=props, **kwargs) + client = self.simple_client(*args, authmechanismproperties=props, **kwargs) + + return client def test_1_1_single_principal_implicit_username(self): # Create default OIDC client with authMechanism=MONGODB-OIDC. From 06b5d0eacaca03ca2e1db54edeac5ec2e1615dae Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 15:35:24 -0400 Subject: [PATCH 25/29] More fixes --- test/auth_aws/test_auth_aws.py | 30 ++++++++++-------------------- test/ocsp/test_ocsp.py | 8 +++++--- 2 files changed, 15 insertions(+), 23 deletions(-) diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index 9798b99868..a7660f2f67 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -37,7 +37,7 @@ pytestmark = pytest.mark.auth_aws -class TestAuthAWS(unittest.TestCase): +class TestAuthAWS(PyMongoTestCase): uri: str @classmethod @@ -70,7 +70,7 @@ def setup_cache(self): self.skipTest("Not testing cached credentials") # Make a connection to ensure that we enable caching. - client = MongoClient(self.uri) + client = self.simple_client(self.uri) client.get_database().test.find_one() client.close() @@ -80,7 +80,7 @@ def setup_cache(self): auth.set_cached_credentials(None) self.assertEqual(auth.get_cached_credentials(), None) - client = MongoClient(self.uri) + client = self.simple_client(self.uri) client.get_database().test.find_one() client.close() return auth.get_cached_credentials() @@ -91,8 +91,7 @@ def test_cache_credentials(self): def test_cache_about_to_expire(self): creds = self.setup_cache() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) # Make the creds about to expire. creds = auth.get_cached_credentials() @@ -108,8 +107,7 @@ def test_cache_about_to_expire(self): def test_poisoned_cache(self): creds = self.setup_cache() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) # Poison the creds with invalid password. assert creds is not None @@ -131,8 +129,7 @@ def test_environment_variables_ignored(self): self.assertIsNotNone(creds) os.environ.copy() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) client.get_database().test.find_one() @@ -150,8 +147,7 @@ def test_environment_variables_ignored(self): auth.set_cached_credentials(None) - client2 = MongoClient(self.uri) - self.addCleanup(client2.close) + client2 = self.simple_client(self.uri) with patch.dict("os.environ", mock_env): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") @@ -167,8 +163,7 @@ def test_no_cache_environment_variables(self): if creds.token: mock_env["AWS_SESSION_TOKEN"] = creds.token - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) with patch.dict(os.environ, mock_env): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], creds.username) @@ -178,8 +173,7 @@ def test_no_cache_environment_variables(self): mock_env["AWS_ACCESS_KEY_ID"] = "foo" - client2 = MongoClient(self.uri) - self.addCleanup(client2.close) + client2 = self.simple_client(self.uri) with patch.dict("os.environ", mock_env), self.assertRaises(OperationFailure): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") @@ -191,8 +185,6 @@ def test_shared_client(self): # Start AWS Lambda Example 1 import os - from pymongo import MongoClient - client = self.simple_client(host=os.environ["MONGODB_URI"]) def lambda_handler(event, context): @@ -204,9 +196,7 @@ def test_IAM_auth(self): # Start AWS Lambda Example 2 import os - from pymongo import MongoClient - - client = MongoClient( + client = self.simple_client( host=os.environ["MONGODB_URI"], authSource="$external", authMechanism="MONGODB-AWS", diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index cecf4ab37c..a42b3a34ee 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -48,9 +48,11 @@ def _connect(options): uri = f"mongodb://localhost:27017/?serverSelectionTimeoutMS={TIMEOUT_MS}&tlsCAFile={CA_FILE}&{options}" print(uri) - client = pymongo.MongoClient(uri) - client.admin.command("ping") - client.close() + try: + client = pymongo.MongoClient(uri) + client.admin.command("ping") + finally: + client.close() class TestOCSP(unittest.TestCase): From 12732a6618d483c4a5caf379223687379c09693e Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 15:55:45 -0400 Subject: [PATCH 26/29] Fixes woo --- test/asynchronous/test_auth.py | 44 +++++++++++++++------------ test/asynchronous/test_client.py | 6 ++-- test/mockupdb/test_cursor.py | 7 ++--- test/test_auth.py | 44 +++++++++++++++------------ test/test_client.py | 6 ++-- test/test_discovery_and_monitoring.py | 2 +- test/test_max_staleness.py | 22 +++++++------- 7 files changed, 71 insertions(+), 60 deletions(-) diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index ff0c3e46a9..fbaca41f09 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -23,7 +23,13 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, SkipTest, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + SkipTest, + async_client_context, + unittest, +) from test.utils import AllowListEventListener, delay, ignore_deprecations from pymongo import AsyncMongoClient, monitoring @@ -73,7 +79,7 @@ def run(self): self.success = True -class TestGSSAPI(unittest.IsolatedAsyncioTestCase): +class TestGSSAPI(AsyncPyMongoTestCase): mech_properties: str service_realm_required: bool @@ -130,7 +136,7 @@ async def test_gssapi_simple(self): if not self.service_realm_required: # Without authMechanismProperties. - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -141,11 +147,11 @@ async def test_gssapi_simple(self): await client[GSSAPI_DB].collection.find_one() # Log in using URI, without authMechanismProperties. - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -158,14 +164,14 @@ async def test_gssapi_simple(self): # Log in using URI, with authMechanismProperties. mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" - client = AsyncMongoClient(mech_uri) + client = self.simple_client(mech_uri) await client[GSSAPI_DB].collection.find_one() set_name = async_client_context.replica_set_name if set_name: if not self.service_realm_required: # Without authMechanismProperties - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -177,11 +183,11 @@ async def test_gssapi_simple(self): await client[GSSAPI_DB].list_collection_names() uri = uri + f"&replicaSet={set_name!s}" - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client[GSSAPI_DB].list_collection_names() # With authMechanismProperties - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -194,13 +200,13 @@ async def test_gssapi_simple(self): await client[GSSAPI_DB].list_collection_names() mech_uri = mech_uri + f"&replicaSet={set_name!s}" - client = AsyncMongoClient(mech_uri) + client = self.simple_client(mech_uri) await client[GSSAPI_DB].list_collection_names() @ignore_deprecations @async_client_context.require_sync async def test_gssapi_threaded(self): - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -236,7 +242,7 @@ async def test_gssapi_threaded(self): set_name = async_client_context.replica_set_name if set_name: - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -259,14 +265,14 @@ async def test_gssapi_threaded(self): self.assertTrue(thread.success) -class TestSASLPlain(unittest.IsolatedAsyncioTestCase): +class TestSASLPlain(AsyncPyMongoTestCase): @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL") async def test_sasl_plain(self): - client = AsyncMongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, username=SASL_USER, @@ -285,12 +291,12 @@ async def test_sasl_plain(self): SASL_PORT, SASL_DB, ) - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client.ldap.test.find_one() set_name = async_client_context.replica_set_name if set_name: - client = AsyncMongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, replicaSet=set_name, @@ -309,7 +315,7 @@ async def test_sasl_plain(self): SASL_DB, str(set_name), ) - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client.ldap.test.find_one() async def test_sasl_plain_bad_credentials(self): @@ -323,8 +329,8 @@ def auth_string(user, password): ) return uri - bad_user = AsyncMongoClient(auth_string("not-user", SASL_PASS)) - bad_pwd = AsyncMongoClient(auth_string(SASL_USER, "not-pwd")) + bad_user = self.simple_client(auth_string("not-user", SASL_PASS)) + bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. with self.assertRaises(OperationFailure): await bad_user.admin.command("ping") diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index d67fb2aa4a..d6be685085 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -178,7 +178,7 @@ async def test_connect_timeout(self): self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = await self.async_single_client( + client = await self.simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options @@ -345,10 +345,10 @@ async def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} - client = await self.async_single_client("mongodb://foo:27017/?appname=foobar&connect=false") + client = await self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata, metadata) - client = await self.async_single_client("foo", 27017, appname="foobar", connect=False) + client = await self.simple_client("foo", 27017, appname="foobar", connect=False) options = client.options self.assertEqual(options.pool_options.metadata, metadata) # No error diff --git a/test/mockupdb/test_cursor.py b/test/mockupdb/test_cursor.py index 46af39c7b9..2300297218 100644 --- a/test/mockupdb/test_cursor.py +++ b/test/mockupdb/test_cursor.py @@ -29,13 +29,12 @@ from bson.objectid import ObjectId -from pymongo import MongoClient from pymongo.errors import OperationFailure pytestmark = pytest.mark.mockupdb -class TestCursor(unittest.TestCase): +class TestCursor(PyMongoTestCase): def test_getmore_load_balanced(self): server = MockupDB() server.autoresponds( @@ -50,7 +49,7 @@ def test_getmore_load_balanced(self): server.run() self.addCleanup(server.stop) - client = MongoClient(server.uri, loadBalanced=True) + client = self.simple_client(server.uri, loadBalanced=True) self.addCleanup(client.close) collection = client.db.coll cursor = collection.find() @@ -77,7 +76,7 @@ def _test_fail_on_operation_failure_with_code(self, code): self.addCleanup(server.stop) server.autoresponds("ismaster", maxWireVersion=6) - client = MongoClient(server.uri) + client = self.simple_client(server.uri) with going(lambda: server.receives(OpMsg({"find": "collection"})).command_err(code=code)): cursor = client.db.collection.find() diff --git a/test/test_auth.py b/test/test_auth.py index 9ce8fc2a1b..b311d330bc 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -23,7 +23,13 @@ sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + SkipTest, + client_context, + unittest, +) from test.utils import AllowListEventListener, delay, ignore_deprecations from pymongo import MongoClient, monitoring @@ -73,7 +79,7 @@ def run(self): self.success = True -class TestGSSAPI(unittest.TestCase): +class TestGSSAPI(PyMongoTestCase): mech_properties: str service_realm_required: bool @@ -130,7 +136,7 @@ def test_gssapi_simple(self): if not self.service_realm_required: # Without authMechanismProperties. - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -141,11 +147,11 @@ def test_gssapi_simple(self): client[GSSAPI_DB].collection.find_one() # Log in using URI, without authMechanismProperties. - client = MongoClient(uri) + client = self.simple_client(uri) client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -158,14 +164,14 @@ def test_gssapi_simple(self): # Log in using URI, with authMechanismProperties. mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" - client = MongoClient(mech_uri) + client = self.simple_client(mech_uri) client[GSSAPI_DB].collection.find_one() set_name = client_context.replica_set_name if set_name: if not self.service_realm_required: # Without authMechanismProperties - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -177,11 +183,11 @@ def test_gssapi_simple(self): client[GSSAPI_DB].list_collection_names() uri = uri + f"&replicaSet={set_name!s}" - client = MongoClient(uri) + client = self.simple_client(uri) client[GSSAPI_DB].list_collection_names() # With authMechanismProperties - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -194,13 +200,13 @@ def test_gssapi_simple(self): client[GSSAPI_DB].list_collection_names() mech_uri = mech_uri + f"&replicaSet={set_name!s}" - client = MongoClient(mech_uri) + client = self.simple_client(mech_uri) client[GSSAPI_DB].list_collection_names() @ignore_deprecations @client_context.require_sync def test_gssapi_threaded(self): - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -236,7 +242,7 @@ def test_gssapi_threaded(self): set_name = client_context.replica_set_name if set_name: - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -259,14 +265,14 @@ def test_gssapi_threaded(self): self.assertTrue(thread.success) -class TestSASLPlain(unittest.TestCase): +class TestSASLPlain(PyMongoTestCase): @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL") def test_sasl_plain(self): - client = MongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, username=SASL_USER, @@ -285,12 +291,12 @@ def test_sasl_plain(self): SASL_PORT, SASL_DB, ) - client = MongoClient(uri) + client = self.simple_client(uri) client.ldap.test.find_one() set_name = client_context.replica_set_name if set_name: - client = MongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, replicaSet=set_name, @@ -309,7 +315,7 @@ def test_sasl_plain(self): SASL_DB, str(set_name), ) - client = MongoClient(uri) + client = self.simple_client(uri) client.ldap.test.find_one() def test_sasl_plain_bad_credentials(self): @@ -323,8 +329,8 @@ def auth_string(user, password): ) return uri - bad_user = MongoClient(auth_string("not-user", SASL_PASS)) - bad_pwd = MongoClient(auth_string(SASL_USER, "not-pwd")) + bad_user = self.simple_client(auth_string("not-user", SASL_PASS)) + bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. with self.assertRaises(OperationFailure): bad_user.admin.command("ping") diff --git a/test/test_client.py b/test/test_client.py index b293158fc0..20df2796d4 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -174,7 +174,7 @@ def test_connect_timeout(self): self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = self.single_client( + client = self.simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options @@ -337,10 +337,10 @@ def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} - client = self.single_client("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata, metadata) - client = self.single_client("foo", 27017, appname="foobar", connect=False) + client = self.simple_client("foo", 27017, appname="foobar", connect=False) options = client.options self.assertEqual(options.pool_options.metadata, metadata) # No error diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index 860be89e80..3554619f12 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -421,7 +421,7 @@ def test_heartbeat_start_ordering(self): server.events = events server_thread = threading.Thread(target=server.handle_request_and_shutdown) server_thread.start() - _c = self.single_client( + _c = self.simple_client( "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) ) server_thread.join() diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 101a8745eb..32d09ada9a 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -41,44 +41,44 @@ class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore class TestMaxStaleness(PyMongoTestCase): def test_max_staleness(self): - client = self.single_client() + client = self.simple_client() self.assertEqual(-1, client.read_preference.max_staleness) - client = self.single_client("mongodb://a/?readPreference=secondary") + client = self.simple_client("mongodb://a/?readPreference=secondary") self.assertEqual(-1, client.read_preference.max_staleness) # These tests are specified in max-staleness-tests.rst. with self.assertRaises(ConfigurationError): # Default read pref "primary" can't be used with max staleness. - self.single_client("mongodb://a/?maxStalenessSeconds=120") + self.simple_client("mongodb://a/?maxStalenessSeconds=120") with self.assertRaises(ConfigurationError): # Read pref "primary" can't be used with max staleness. - self.single_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") + self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") - client = self.single_client("mongodb://host/?maxStalenessSeconds=-1") + client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = self.single_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") + client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = self.single_client( + client = self.simple_client( "mongodb://host/?readPreference=secondary&maxStalenessSeconds=120" ) self.assertEqual(120, client.read_preference.max_staleness) - client = self.single_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") self.assertEqual(1, client.read_preference.max_staleness) - client = self.single_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = self.single_client(maxStalenessSeconds=-1, readPreference="nearest") + client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest") self.assertEqual(-1, client.read_preference.max_staleness) with self.assertRaises(TypeError): # Prohibit None. - self.single_client(maxStalenessSeconds=None, readPreference="nearest") + self.simple_client(maxStalenessSeconds=None, readPreference="nearest") def test_max_staleness_float(self): with self.assertRaises(TypeError) as ctx: From a770fd1dc7cd86a4dd57750a3bd0ac0e161edd8f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 16:17:11 -0400 Subject: [PATCH 27/29] Fixes --- test/asynchronous/test_client.py | 6 +++--- test/test_retryable_writes.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index d6be685085..089e172d5f 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -178,7 +178,7 @@ async def test_connect_timeout(self): self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = await self.simple_client( + client = self.simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options @@ -345,10 +345,10 @@ async def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} - client = await self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata, metadata) - client = await self.simple_client("foo", 27017, appname="foobar", connect=False) + client = self.simple_client("foo", 27017, appname="foobar", connect=False) options = client.options self.assertEqual(options.pool_options.metadata, metadata) # No error diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index f84b85816b..89454ad236 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -298,7 +298,7 @@ def test_unsupported_single_statement(self): def test_server_selection_timeout_not_retried(self): """A ServerSelectionTimeoutError is not retried.""" listener = OvertCommandListener() - client = self.single_client( + client = self.simple_client( "somedomainthatdoesntexist.org", serverSelectionTimeoutMS=1, retryWrites=True, From 4adcffd9e94f873a6b44d8ba941f8ee7b2ed7a62 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 16:35:18 -0400 Subject: [PATCH 28/29] Fix serverless --- test/asynchronous/test_client.py | 8 ++++---- test/test_client.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 089e172d5f..fc4e7075e6 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -368,7 +368,7 @@ async def test_metadata(self): # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = await self.async_single_client( + client = await self.simple_client( "foo", 27017, appname="foobar", @@ -378,7 +378,7 @@ async def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = await self.async_single_client( + client = await self.simple_client( "foo", 27017, appname="foobar", @@ -388,7 +388,7 @@ async def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = await self.async_single_client( + client = await self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -397,7 +397,7 @@ async def test_metadata(self): len(bson.encode(options.pool_options.metadata)), _MAX_METADATA_SIZE, ) - client = await self.async_single_client( + client = await self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) diff --git a/test/test_client.py b/test/test_client.py index 20df2796d4..bc45325f0b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -360,7 +360,7 @@ def test_metadata(self): # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = self.single_client( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -370,7 +370,7 @@ def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = self.single_client( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -380,7 +380,7 @@ def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = self.single_client( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -389,7 +389,7 @@ def test_metadata(self): len(bson.encode(options.pool_options.metadata)), _MAX_METADATA_SIZE, ) - client = self.single_client( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) From 0678a89561d98049563aa6d61fb0203a83b3eded Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 16 Sep 2024 16:54:43 -0400 Subject: [PATCH 29/29] fixes --- test/asynchronous/test_client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index fc4e7075e6..f610f32779 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -368,7 +368,7 @@ async def test_metadata(self): # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = await self.simple_client( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -378,7 +378,7 @@ async def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = await self.simple_client( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -388,7 +388,7 @@ async def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = await self.simple_client( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -397,7 +397,7 @@ async def test_metadata(self): len(bson.encode(options.pool_options.metadata)), _MAX_METADATA_SIZE, ) - client = await self.simple_client( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, )