diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py
index f7fc8e5e81..6d0e5d5280 100644
--- a/pymongo/asynchronous/mongo_client.py
+++ b/pymongo/asynchronous/mongo_client.py
@@ -1193,7 +1193,6 @@ def __del__(self) -> None:
                     ),
                     ResourceWarning,
                     stacklevel=2,
-                    source=self,
                 )
         except AttributeError:
             pass
diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py
index 5786bbf5a9..b2dff5b4ab 100644
--- a/pymongo/synchronous/mongo_client.py
+++ b/pymongo/synchronous/mongo_client.py
@@ -1193,7 +1193,6 @@ def __del__(self) -> None:
                     ),
                     ResourceWarning,
                     stacklevel=2,
-                    source=self,
                 )
         except AttributeError:
             pass
diff --git a/pyproject.toml b/pyproject.toml
index b64c7d6031..1702f6d16d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,9 +96,6 @@ filterwarnings = [
     "module:please use dns.resolver.Resolver.resolve:DeprecationWarning",
     # https://github.com/dateutil/dateutil/issues/1314
     "module:datetime.datetime.utc:DeprecationWarning:dateutil",
-    # TODO: Remove both of these in https://jira.mongodb.org/browse/PYTHON-4731
-    "ignore:Unclosed AsyncMongoClient*",
-    "ignore:Unclosed MongoClient*",
 ]
 markers = [
     "auth_aws: tests that rely on pymongo-auth-aws",
diff --git a/test/__init__.py b/test/__init__.py
index 41af81f979..1a17ff14c5 100644
--- a/test/__init__.py
+++ b/test/__init__.py
@@ -16,8 +16,6 @@ from __future__ import annotations
 import asyncio
-import base64
-import contextlib
 import gc
 import multiprocessing
 import os
@@ -27,7 +25,6 @@ import sys
 import threading
 import time
-import traceback
 import unittest
 import warnings
 from asyncio import iscoroutinefunction
@@ -54,6 +51,8 @@ sanitize_reply,
 )
 
+from pymongo.uri_parser import parse_uri
+
 try:
     import ipaddress
@@ -80,6 +79,12 @@ _IS_SYNC = True
 
+def _connection_string(h):
+    if h.startswith(("mongodb://", "mongodb+srv://")):
+        return h
+    return f"mongodb://{h!s}"
+
+
 class ClientContext:
     client: MongoClient
@@ -230,6 +235,9 @@ def _init_client(self):
         if not self._check_user_provided():
             _create_user(self.client.admin, db_user, db_pwd)
 
+        if self.client:
+            self.client.close()
+
         self.client = self._connect(
             host,
             port,
@@ -256,6 +264,8 @@ def _init_client(self):
         if "setName" in hello:
             self.replica_set_name = str(hello["setName"])
             self.is_rs = True
+        if self.client:
+            self.client.close()
         if self.auth_enabled:
             # It doesn't matter which member we use as the seed here.
             self.client = pymongo.MongoClient(
@@ -318,6 +328,7 @@ def _init_client(self):
                 hello = mongos_client.admin.command(HelloCompat.LEGACY_CMD)
                 if hello.get("msg") == "isdbgrid":
                     self.mongoses.append(next_address)
+                mongos_client.close()
 
     def init(self):
         with self.conn_lock:
@@ -537,12 +548,6 @@ def require_auth(self, func):
             lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func
         )
 
-    def require_no_fips(self, func):
-        """Run a test only if the host does not have FIPS enabled."""
-        return self._require(
-            lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func
-        )
-
     def require_no_auth(self, func):
         """Run a test only if the server is running without auth enabled."""
         return self._require(
@@ -930,6 +935,172 @@ def _target() -> None:
             self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?")
         self.assertEqual(proc.exitcode, 0)
 
+    @classmethod
+    def _unmanaged_async_mongo_client(
+        cls, host, port, authenticate=True, directConnection=None, **kwargs
+    ):
+        """Create a new client over SSL/TLS if necessary."""
+        host = host or client_context.host
+        port = port or client_context.port
+        client_options: dict = client_context.default_client_options.copy()
+        if client_context.replica_set_name and not directConnection:
+            client_options["replicaSet"] = client_context.replica_set_name
+        if directConnection is not None:
+            client_options["directConnection"] = directConnection
+        client_options.update(kwargs)
+
+        uri = _connection_string(host)
+        auth_mech = kwargs.get("authMechanism", "")
+        if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
+            # Only add the default username or password if one is not provided.
+            res = parse_uri(uri)
+            if (
+                not res["username"]
+                and not res["password"]
+                and "username" not in client_options
+                and "password" not in client_options
+            ):
+                client_options["username"] = db_user
+                client_options["password"] = db_pwd
+        client = MongoClient(uri, port, **client_options)
+        if client._options.connect:
+            client._connect()
+        return client
+
+    def _async_mongo_client(self, host, port, authenticate=True, directConnection=None, **kwargs):
+        """Create a new client over SSL/TLS if necessary."""
+        host = host or client_context.host
+        port = port or client_context.port
+        client_options: dict = client_context.default_client_options.copy()
+        if client_context.replica_set_name and not directConnection:
+            client_options["replicaSet"] = client_context.replica_set_name
+        if directConnection is not None:
+            client_options["directConnection"] = directConnection
+        client_options.update(kwargs)
+
+        uri = _connection_string(host)
+        auth_mech = kwargs.get("authMechanism", "")
+        if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
+            # Only add the default username or password if one is not provided.
+            res = parse_uri(uri)
+            if (
+                not res["username"]
+                and not res["password"]
+                and "username" not in client_options
+                and "password" not in client_options
+            ):
+                client_options["username"] = db_user
+                client_options["password"] = db_pwd
+        client = MongoClient(uri, port, **client_options)
+        if client._options.connect:
+            client._connect()
+        self.addCleanup(client.close)
+        return client
+
+    @classmethod
+    def unmanaged_single_client_noauth(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> MongoClient[dict]:
+        """Make a direct connection. Don't authenticate."""
+        return cls._unmanaged_async_mongo_client(
+            h, p, authenticate=False, directConnection=True, **kwargs
+        )
+
+    @classmethod
+    def unmanaged_single_client(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> MongoClient[dict]:
+        """Make a direct connection, and authenticate if necessary."""
+        return cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs)
+
+    @classmethod
+    def unmanaged_rs_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
+        """Connect to the replica set and authenticate if necessary."""
+        return cls._unmanaged_async_mongo_client(h, p, **kwargs)
+
+    @classmethod
+    def unmanaged_rs_client_noauth(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> MongoClient[dict]:
+        """Connect to the replica set. Don't authenticate."""
+        return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
+
+    @classmethod
+    def unmanaged_rs_or_single_client_noauth(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> MongoClient[dict]:
+        """Connect to the replica set if there is one, otherwise the standalone. Don't authenticate."""
+        return cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
+
+    @classmethod
+    def unmanaged_rs_or_single_client(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> MongoClient[dict]:
+        """Connect to the replica set if there is one, otherwise the standalone. Authenticate if necessary."""
+        return cls._unmanaged_async_mongo_client(h, p, **kwargs)
+
+    def single_client_noauth(
+        self, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> MongoClient[dict]:
+        """Make a direct connection. Don't authenticate."""
+        return self._async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
+
+    def single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
+        """Make a direct connection, and authenticate if necessary."""
+        return self._async_mongo_client(h, p, directConnection=True, **kwargs)
+
+    def rs_client_noauth(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
+        """Connect to the replica set. Don't authenticate."""
+        return self._async_mongo_client(h, p, authenticate=False, **kwargs)
+
+    def rs_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
+        """Connect to the replica set and authenticate if necessary."""
+        return self._async_mongo_client(h, p, **kwargs)
+
+    def rs_or_single_client_noauth(
+        self, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> MongoClient[dict]:
+        """Connect to the replica set if there is one, otherwise the standalone.
+
+        Like rs_or_single_client, but does not authenticate.
+        """
+        return self._async_mongo_client(h, p, authenticate=False, **kwargs)
+
+    def rs_or_single_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]:
+        """Connect to the replica set if there is one, otherwise the standalone.
+
+        Authenticates if necessary.
+ """ + return self._async_mongo_client(h, p, **kwargs) + + def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient: + if not h and not p: + client = MongoClient(**kwargs) + else: + client = MongoClient(h, p, **kwargs) + self.addCleanup(client.close) + return client + + @classmethod + def unmanaged_simple_client(cls, h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient: + if not h and not p: + client = MongoClient(**kwargs) + else: + client = MongoClient(h, p, **kwargs) + return client + + def disable_replication(self, client): + """Disable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = self.single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + + def enable_replication(self, client): + """Enable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = self.single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + class UnitTest(PyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index d1af89c184..0d94331587 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -16,8 +16,6 @@ from __future__ import annotations import asyncio -import base64 -import contextlib import gc import multiprocessing import os @@ -27,7 +25,6 @@ import sys import threading import time -import traceback import unittest import warnings from asyncio import iscoroutinefunction @@ -54,6 +51,8 @@ sanitize_reply, ) +from pymongo.uri_parser import parse_uri + try: import ipaddress @@ -80,6 +79,12 @@ _IS_SYNC = False +def _connection_string(h): + if h.startswith(("mongodb://", "mongodb+srv://")): + return h + return f"mongodb://{h!s}" + + class AsyncClientContext: client: AsyncMongoClient @@ -230,6 +235,9 @@ async def _init_client(self): if not await self._check_user_provided(): await _create_user(self.client.admin, db_user, db_pwd) + if self.client: + await self.client.close() + self.client = await self._connect( host, port, @@ -256,6 +264,8 @@ async def _init_client(self): if "setName" in hello: self.replica_set_name = str(hello["setName"]) self.is_rs = True + if self.client: + await self.client.close() if self.auth_enabled: # It doesn't matter which member we use as the seed here. 
self.client = pymongo.AsyncMongoClient( @@ -320,6 +330,7 @@ async def _init_client(self): hello = await mongos_client.admin.command(HelloCompat.LEGACY_CMD) if hello.get("msg") == "isdbgrid": self.mongoses.append(next_address) + await mongos_client.close() async def init(self): with self.conn_lock: @@ -539,12 +550,6 @@ def require_auth(self, func): lambda: self.auth_enabled, "Authentication is not enabled on the server", func=func ) - def require_no_fips(self, func): - """Run a test only if the host does not have FIPS enabled.""" - return self._require( - lambda: not self.fips_enabled, "Test cannot run on a FIPS-enabled host", func=func - ) - def require_no_auth(self, func): """Run a test only if the server is running without auth enabled.""" return self._require( @@ -932,6 +937,188 @@ def _target() -> None: self.fail(f"child timed out after {timeout}s (see traceback in logs): deadlock?") self.assertEqual(proc.exitcode, 0) + @classmethod + async def _unmanaged_async_mongo_client( + cls, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or await async_client_context.host + port = port or await async_client_context.port + client_options: dict = async_client_context.default_client_options.copy() + if async_client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = async_client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. + res = parse_uri(uri) + if ( + not res["username"] + and not res["password"] + and "username" not in client_options + and "password" not in client_options + ): + client_options["username"] = db_user + client_options["password"] = db_pwd + client = AsyncMongoClient(uri, port, **client_options) + if client._options.connect: + await client.aconnect() + return client + + async def _async_mongo_client( + self, host, port, authenticate=True, directConnection=None, **kwargs + ): + """Create a new client over SSL/TLS if necessary.""" + host = host or await async_client_context.host + port = port or await async_client_context.port + client_options: dict = async_client_context.default_client_options.copy() + if async_client_context.replica_set_name and not directConnection: + client_options["replicaSet"] = async_client_context.replica_set_name + if directConnection is not None: + client_options["directConnection"] = directConnection + client_options.update(kwargs) + + uri = _connection_string(host) + auth_mech = kwargs.get("authMechanism", "") + if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC": + # Only add the default username or password if one is not provided. 
+            res = parse_uri(uri)
+            if (
+                not res["username"]
+                and not res["password"]
+                and "username" not in client_options
+                and "password" not in client_options
+            ):
+                client_options["username"] = db_user
+                client_options["password"] = db_pwd
+        client = AsyncMongoClient(uri, port, **client_options)
+        if client._options.connect:
+            await client.aconnect()
+        self.addAsyncCleanup(client.close)
+        return client
+
+    @classmethod
+    async def unmanaged_async_single_client_noauth(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Make a direct connection. Don't authenticate."""
+        return await cls._unmanaged_async_mongo_client(
+            h, p, authenticate=False, directConnection=True, **kwargs
+        )
+
+    @classmethod
+    async def unmanaged_async_single_client(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Make a direct connection, and authenticate if necessary."""
+        return await cls._unmanaged_async_mongo_client(h, p, directConnection=True, **kwargs)
+
+    @classmethod
+    async def unmanaged_async_rs_client(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Connect to the replica set and authenticate if necessary."""
+        return await cls._unmanaged_async_mongo_client(h, p, **kwargs)
+
+    @classmethod
+    async def unmanaged_async_rs_client_noauth(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Connect to the replica set. Don't authenticate."""
+        return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
+
+    @classmethod
+    async def unmanaged_async_rs_or_single_client_noauth(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Connect to the replica set if there is one, otherwise the standalone. Don't authenticate."""
+        return await cls._unmanaged_async_mongo_client(h, p, authenticate=False, **kwargs)
+
+    @classmethod
+    async def unmanaged_async_rs_or_single_client(
+        cls, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Connect to the replica set if there is one, otherwise the standalone. Authenticate if necessary."""
+        return await cls._unmanaged_async_mongo_client(h, p, **kwargs)
+
+    async def async_single_client_noauth(
+        self, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Make a direct connection. Don't authenticate."""
+        return await self._async_mongo_client(
+            h, p, authenticate=False, directConnection=True, **kwargs
+        )
+
+    async def async_single_client(
+        self, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Make a direct connection, and authenticate if necessary."""
+        return await self._async_mongo_client(h, p, directConnection=True, **kwargs)
+
+    async def async_rs_client_noauth(
+        self, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Connect to the replica set. Don't authenticate."""
+        return await self._async_mongo_client(h, p, authenticate=False, **kwargs)
+
+    async def async_rs_client(
+        self, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Connect to the replica set and authenticate if necessary."""
+        return await self._async_mongo_client(h, p, **kwargs)
+
+    async def async_rs_or_single_client_noauth(
+        self, h: Any = None, p: Any = None, **kwargs: Any
+    ) -> AsyncMongoClient[dict]:
+        """Connect to the replica set if there is one, otherwise the standalone.
+
+        Like rs_or_single_client, but does not authenticate.
+ """ + return await self._async_mongo_client(h, p, authenticate=False, **kwargs) + + async def async_rs_or_single_client( + self, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient[Any]: + """Connect to the replica set if there is one, otherwise the standalone. + + Authenticates if necessary. + """ + return await self._async_mongo_client(h, p, **kwargs) + + def simple_client(self, h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient: + if not h and not p: + client = AsyncMongoClient(**kwargs) + else: + client = AsyncMongoClient(h, p, **kwargs) + self.addAsyncCleanup(client.close) + return client + + @classmethod + def unmanaged_simple_client( + cls, h: Any = None, p: Any = None, **kwargs: Any + ) -> AsyncMongoClient: + if not h and not p: + client = AsyncMongoClient(**kwargs) + else: + client = AsyncMongoClient(h, p, **kwargs) + return client + + async def disable_replication(self, client): + """Disable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = await self.async_single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn") + + async def enable_replication(self, client): + """Enable replication on all secondaries.""" + for h, p in client.secondaries: + secondary = await self.async_single_client(h, p) + secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off") + class AsyncUnitTest(AsyncPyMongoTestCase): """Async base class for TestCases that don't require a connection to MongoDB.""" diff --git a/test/asynchronous/test_auth.py b/test/asynchronous/test_auth.py index 06f7fb9ca8..fbaca41f09 100644 --- a/test/asynchronous/test_auth.py +++ b/test/asynchronous/test_auth.py @@ -23,16 +23,14 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, SkipTest, async_client_context, unittest -from test.utils import ( - AllowListEventListener, - async_rs_or_single_client, - async_rs_or_single_client_noauth, - async_single_client, - async_single_client_noauth, - delay, - ignore_deprecations, +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + SkipTest, + async_client_context, + unittest, ) +from test.utils import AllowListEventListener, delay, ignore_deprecations from pymongo import AsyncMongoClient, monitoring from pymongo.asynchronous.auth import HAVE_KERBEROS @@ -81,7 +79,7 @@ def run(self): self.success = True -class TestGSSAPI(unittest.IsolatedAsyncioTestCase): +class TestGSSAPI(AsyncPyMongoTestCase): mech_properties: str service_realm_required: bool @@ -138,7 +136,7 @@ async def test_gssapi_simple(self): if not self.service_realm_required: # Without authMechanismProperties. - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -149,11 +147,11 @@ async def test_gssapi_simple(self): await client[GSSAPI_DB].collection.find_one() # Log in using URI, without authMechanismProperties. - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -166,14 +164,14 @@ async def test_gssapi_simple(self): # Log in using URI, with authMechanismProperties. 
mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" - client = AsyncMongoClient(mech_uri) + client = self.simple_client(mech_uri) await client[GSSAPI_DB].collection.find_one() set_name = async_client_context.replica_set_name if set_name: if not self.service_realm_required: # Without authMechanismProperties - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -185,11 +183,11 @@ async def test_gssapi_simple(self): await client[GSSAPI_DB].list_collection_names() uri = uri + f"&replicaSet={set_name!s}" - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client[GSSAPI_DB].list_collection_names() # With authMechanismProperties - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -202,13 +200,13 @@ async def test_gssapi_simple(self): await client[GSSAPI_DB].list_collection_names() mech_uri = mech_uri + f"&replicaSet={set_name!s}" - client = AsyncMongoClient(mech_uri) + client = self.simple_client(mech_uri) await client[GSSAPI_DB].list_collection_names() @ignore_deprecations @async_client_context.require_sync async def test_gssapi_threaded(self): - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -244,7 +242,7 @@ async def test_gssapi_threaded(self): set_name = async_client_context.replica_set_name if set_name: - client = AsyncMongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -267,14 +265,14 @@ async def test_gssapi_threaded(self): self.assertTrue(thread.success) -class TestSASLPlain(unittest.IsolatedAsyncioTestCase): +class TestSASLPlain(AsyncPyMongoTestCase): @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL") async def test_sasl_plain(self): - client = AsyncMongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, username=SASL_USER, @@ -293,12 +291,12 @@ async def test_sasl_plain(self): SASL_PORT, SASL_DB, ) - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client.ldap.test.find_one() set_name = async_client_context.replica_set_name if set_name: - client = AsyncMongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, replicaSet=set_name, @@ -317,7 +315,7 @@ async def test_sasl_plain(self): SASL_DB, str(set_name), ) - client = AsyncMongoClient(uri) + client = self.simple_client(uri) await client.ldap.test.find_one() async def test_sasl_plain_bad_credentials(self): @@ -331,8 +329,8 @@ def auth_string(user, password): ) return uri - bad_user = AsyncMongoClient(auth_string("not-user", SASL_PASS)) - bad_pwd = AsyncMongoClient(auth_string(SASL_USER, "not-pwd")) + bad_user = self.simple_client(auth_string("not-user", SASL_PASS)) + bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. 
with self.assertRaises(OperationFailure): await bad_user.admin.command("ping") @@ -356,7 +354,7 @@ async def asyncTearDown(self): async def test_scram_sha1(self): host, port = await async_client_context.host, await async_client_context.port - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port) ) await client.pymongo_test.command("dbstats") @@ -367,7 +365,7 @@ async def test_scram_sha1(self): "@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" "&replicaSet=%s" % (host, port, async_client_context.replica_set_name) ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) await client.pymongo_test.command("dbstats") db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) await db.command("dbstats") @@ -395,7 +393,7 @@ async def test_scram_skip_empty_exchange(self): "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", event_listeners=[listener] ) await client.testscram.command("dbstats") @@ -432,38 +430,38 @@ async def test_scram(self): ) # Step 2: verify auth success cases - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) await client.testscram.command("dbstats") # Step 2: SCRAM-SHA-1 and SCRAM-SHA-256 - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) await client.testscram.command("dbstats") self.listener.reset() - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", event_listeners=[self.listener] ) await client.testscram.command("dbstats") @@ -476,19 +474,19 @@ async def test_scram(self): self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256") # Step 3: verify auth failure conditions - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) with self.assertRaises(OperationFailure): await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = 
await self.async_rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) with self.assertRaises(OperationFailure): await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="not-a-user", password="pwd", authSource="testscram" ) with self.assertRaises(OperationFailure): @@ -501,7 +499,7 @@ async def test_scram(self): port, async_client_context.replica_set_name, ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) await client.testscram.command("dbstats") db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY) await db.command("dbstats") @@ -521,12 +519,12 @@ async def test_scram_saslprep(self): "testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram", @@ -534,17 +532,17 @@ async def test_scram_saslprep(self): ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="\u2168", password="IV", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram", @@ -552,31 +550,31 @@ async def test_scram_saslprep(self): ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256" ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://\u2168:IV@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") - client = await async_rs_or_single_client_noauth( + client = await self.async_rs_or_single_client_noauth( "mongodb://IX:IX@%s:%d/testscram" % (host, port) ) await client.testscram.command("dbstats") async def test_cache(self): - client = await async_single_client() + client = await self.async_single_client() credentials = client.options.pool_options._credentials cache = credentials.cache self.assertIsNotNone(cache) @@ -601,8 +599,7 @@ async def test_scram_threaded(self): await coll.insert_one({"_id": 1}) # The first thread to call find() will authenticate - client = await async_rs_or_single_client() - 
self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client() coll = client.db.test threads = [] for _ in range(4): @@ -631,7 +628,9 @@ async def asyncTearDown(self): async def test_uri_options(self): # Test default to admin host, port = await async_client_context.host, await async_client_context.port - client = await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + client = await self.async_rs_or_single_client_noauth( + "mongodb://admin:pass@%s:%d" % (host, port) + ) self.assertTrue(await client.admin.command("dbstats")) if async_client_context.is_rs: @@ -640,14 +639,14 @@ async def test_uri_options(self): port, async_client_context.replica_set_name, ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) self.assertTrue(await client.admin.command("dbstats")) db = client.get_database("admin", read_preference=ReadPreference.SECONDARY) self.assertTrue(await db.command("dbstats")) # Test explicit database uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) - client = await async_rs_or_single_client_noauth(uri) + client = await self.async_rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.admin.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) @@ -658,7 +657,7 @@ async def test_uri_options(self): port, async_client_context.replica_set_name, ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.admin.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) @@ -667,7 +666,7 @@ async def test_uri_options(self): # Test authSource uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port) - client = await async_rs_or_single_client_noauth(uri) + client = await self.async_rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.pymongo_test2.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) @@ -677,7 +676,7 @@ async def test_uri_options(self): "mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=" "%s;authSource=pymongo_test" % (host, port, async_client_context.replica_set_name) ) - client = await async_single_client_noauth(uri) + client = await self.async_single_client_noauth(uri) with self.assertRaises(OperationFailure): await client.pymongo_test2.command("dbstats") self.assertTrue(await client.pymongo_test.command("dbstats")) diff --git a/test/asynchronous/test_auth_spec.py b/test/asynchronous/test_auth_spec.py index 329b3eec62..a6ab1cb331 100644 --- a/test/asynchronous/test_auth_spec.py +++ b/test/asynchronous/test_auth_spec.py @@ -20,6 +20,7 @@ import os import sys import warnings +from test.asynchronous import AsyncPyMongoTestCase sys.path[0:0] = [""] @@ -34,7 +35,7 @@ _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") -class TestAuthSpec(unittest.IsolatedAsyncioTestCase): +class TestAuthSpec(AsyncPyMongoTestCase): pass @@ -54,7 +55,7 @@ def run_test(self): warnings.simplefilter("default") self.assertRaises(Exception, AsyncMongoClient, uri, connect=False) else: - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) diff --git a/test/asynchronous/test_bulk.py b/test/asynchronous/test_bulk.py index 79d8e1a0f1..42a3311072 100644 
--- a/test/asynchronous/test_bulk.py +++ b/test/asynchronous/test_bulk.py @@ -24,23 +24,14 @@ sys.path[0:0] = [""] from test.asynchronous import AsyncIntegrationTest, async_client_context, remove_all_users, unittest -from test.utils import ( - async_rs_or_single_client_noauth, - async_single_client, - async_wait_until, -) +from test.utils import async_wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.objectid import ObjectId from pymongo.asynchronous.collection import AsyncCollection from pymongo.common import partition_node -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure, -) +from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import * from pymongo.write_concern import WriteConcern @@ -915,7 +906,7 @@ class AsyncTestBulkAuthorization(AsyncBulkAuthorizationTestBase): async def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = await async_rs_or_single_client_noauth( + cli = await self.async_rs_or_single_client_noauth( username="readonly", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -926,7 +917,7 @@ async def test_readonly(self): async def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = await async_rs_or_single_client_noauth( + cli = await self.async_rs_or_single_client_noauth( username="noremove", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -954,7 +945,7 @@ async def _setup_class(cls): if cls.w is not None and cls.w > 1: for member in (await async_client_context.hello)["hosts"]: if member != (await async_client_context.hello)["primary"]: - cls.secondary = await async_single_client(*partition_node(member)) + cls.secondary = await cls.unmanaged_async_single_client(*partition_node(member)) break @classmethod diff --git a/test/asynchronous/test_change_stream.py b/test/asynchronous/test_change_stream.py index 1b89c43bb7..883ed72c4c 100644 --- a/test/asynchronous/test_change_stream.py +++ b/test/asynchronous/test_change_stream.py @@ -28,12 +28,17 @@ sys.path[0:0] = [""] -from test.asynchronous import AsyncIntegrationTest, Version, async_client_context, unittest +from test.asynchronous import ( + AsyncIntegrationTest, + AsyncPyMongoTestCase, + Version, + async_client_context, + unittest, +) from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, EventListener, - async_rs_or_single_client, async_wait_until, ) @@ -69,8 +74,7 @@ async def change_stream(self, *args, **kwargs): async def client_with_listener(self, *commands): """Return a client with a AllowListEventListener.""" listener = AllowListEventListener(*commands) - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) return client, listener def watched_collection(self, *args, **kwargs): @@ -176,7 +180,7 @@ async def _wait_until(): @no_type_check async def test_try_next_runs_one_getmore(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. 
await client.admin.command("ping") listener.reset() @@ -234,7 +238,7 @@ async def _wait_until(): @no_type_check async def test_batch_size_is_honored(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. await client.admin.command("ping") listener.reset() @@ -481,7 +485,9 @@ class ProseSpecTestsMixin: @no_type_check async def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await AsyncPyMongoTestCase.unmanaged_async_rs_or_single_client( + event_listeners=[listener] + ) self.addAsyncCleanup(client.close) return client, listener @@ -1131,7 +1137,7 @@ class TestAllLegacyScenarios(AsyncIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) @classmethod async def _tearDown_class(cls): diff --git a/test/asynchronous/test_client.py b/test/asynchronous/test_client.py index 97cbdf6dbd..f610f32779 100644 --- a/test/asynchronous/test_client.py +++ b/test/asynchronous/test_client.py @@ -36,6 +36,7 @@ from unittest.mock import patch import pytest +import pytest_asyncio from pymongo.operations import _Op @@ -61,10 +62,6 @@ CMAPListener, FunctionCallRecorder, async_get_pool, - async_rs_client, - async_rs_or_single_client, - async_rs_or_single_client_noauth, - async_single_client, async_wait_until, asyncAssertRaisesExactly, delay, @@ -72,7 +69,6 @@ is_greenthread_patched, lazy_client_trial, one, - rs_or_single_client, wait_until, ) @@ -133,7 +129,9 @@ class AsyncClientUnitTest(AsyncUnitTest): @classmethod async def _setup_class(cls): - cls.client = await async_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + cls.client = await cls.unmanaged_async_rs_or_single_client( + connect=False, serverSelectionTimeoutMS=100 + ) @classmethod async def _tearDown_class(cls): @@ -143,8 +141,8 @@ async def _tearDown_class(cls): def inject_fixtures(self, caplog): self._caplog = caplog - def test_keyword_arg_defaults(self): - client = AsyncMongoClient( + async def test_keyword_arg_defaults(self): + client = self.simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -169,16 +167,18 @@ def test_keyword_arg_defaults(self): self.assertEqual(ReadPreference.PRIMARY, client.read_preference) self.assertAlmostEqual(12, client.options.server_selection_timeout) - def test_connect_timeout(self): - client = AsyncMongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + async def test_connect_timeout(self): + client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = AsyncMongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + + client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = AsyncMongoClient( + + client = self.simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", 
connect=False ) pool_opts = client.options.pool_options @@ -194,8 +194,8 @@ def test_types(self): self.assertRaises(ConfigurationError, AsyncMongoClient, []) - def test_max_pool_size_zero(self): - AsyncMongoClient(maxPoolSize=0) + async def test_max_pool_size_zero(self): + self.simple_client(maxPoolSize=0) def test_uri_detection(self): self.assertRaises(ConfigurationError, AsyncMongoClient, "/foo") @@ -260,7 +260,7 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) async def test_get_default_database(self): - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, @@ -277,7 +277,7 @@ async def test_get_default_database(self): self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) @@ -285,7 +285,7 @@ async def test_get_default_database(self): async def test_get_default_database_error(self): # URI with no database. - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) @@ -297,11 +297,11 @@ async def test_get_default_database_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await async_rs_or_single_client(uri, connect=False) + c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_default_database()) async def test_get_database_default(self): - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/foo" % (await async_client_context.host, await async_client_context.port), connect=False, @@ -310,7 +310,7 @@ async def test_get_database_default(self): async def test_get_database_default_error(self): # URI with no database. - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://%s:%d/" % (await async_client_context.host, await async_client_context.port), connect=False, ) @@ -322,47 +322,53 @@ async def test_get_database_default_with_authsource(self): await async_client_context.host, await async_client_context.port, ) - c = await async_rs_or_single_client(uri, connect=False) + c = await self.async_rs_or_single_client(uri, connect=False) self.assertEqual(AsyncDatabase(c, "foo"), c.get_database()) - def test_primary_read_pref_with_tags(self): + async def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". 
with self.assertRaises(ConfigurationError): - AsyncMongoClient("mongodb://host/?readpreferencetags=dc:east") + await self.async_single_client("mongodb://host/?readpreferencetags=dc:east") with self.assertRaises(ConfigurationError): - AsyncMongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + await self.async_single_client( + "mongodb://host/?readpreference=primary&readpreferencetags=dc:east" + ) async def test_read_preference(self): - c = await async_rs_or_single_client( + c = await self.async_rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) self.assertEqual(c.read_preference, ReadPreference.NEAREST) - def test_metadata(self): + async def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo|async" metadata["application"] = {"name": "foobar"} - client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata, metadata) - client = AsyncMongoClient("foo", 27017, appname="foobar", connect=False) + client = self.simple_client("foo", 27017, appname="foobar", connect=False) options = client.options self.assertEqual(options.pool_options.metadata, metadata) # No error - AsyncMongoClient(appname="x" * 128) - self.assertRaises(ValueError, AsyncMongoClient, appname="x" * 129) + self.simple_client(appname="x" * 128) + with self.assertRaises(ValueError): + self.simple_client(appname="x" * 129) # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) - self.assertRaises(TypeError, AsyncMongoClient, driver=1) - self.assertRaises(TypeError, AsyncMongoClient, driver="abc") - self.assertRaises(TypeError, AsyncMongoClient, driver=("Foo", "1", "a")) + with self.assertRaises(TypeError): + self.simple_client(driver=1) + with self.assertRaises(TypeError): + self.simple_client(driver="abc") + with self.assertRaises(TypeError): + self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. metadata["driver"]["name"] = "PyMongo|async|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = AsyncMongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -372,7 +378,7 @@ def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = AsyncMongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -382,7 +388,7 @@ def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. 
- client = AsyncMongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -391,7 +397,7 @@ def test_metadata(self): len(bson.encode(options.pool_options.metadata)), _MAX_METADATA_SIZE, ) - client = AsyncMongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -407,11 +413,11 @@ def test_container_metadata(self): metadata["driver"]["name"] = "PyMongo|async" metadata["env"] = {} metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = AsyncMongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) - def test_kwargs_codec_options(self): + async def test_kwargs_codec_options(self): class MyFloatType: def __init__(self, x): self.__x = x @@ -433,7 +439,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = AsyncMongoClient( + c = self.simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -442,12 +448,12 @@ def transform_python(self, value): tzinfo=tzinfo, connect=False, ) - self.assertEqual(c.codec_options.document_class, document_class) self.assertEqual(c.codec_options.type_registry, type_registry) self.assertEqual(c.codec_options.tz_aware, tz_aware) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual(c.codec_options.tzinfo, tzinfo) @@ -469,11 +475,11 @@ async def test_uri_codec_options(self): datetime_conversion, ) ) - c = AsyncMongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual(c.codec_options.tz_aware, True) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual( @@ -482,16 +488,15 @@ async def test_uri_codec_options(self): # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = AsyncMongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual( c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] ) - def test_uri_option_precedence(self): + async def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = AsyncMongoClient( + c = self.simple_client( uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" ) clopts = c.options @@ -501,7 +506,7 @@ def test_uri_option_precedence(self): self.assertEqual(clopts.replica_set_name, "newname") self.assertEqual(clopts.read_preference, ReadPreference.SECONDARY_PREFERRED) - def test_connection_timeout_ms_propagates_to_DNS_resolver(self): + async def test_connection_timeout_ms_propagates_to_DNS_resolver(self): # Patch the resolver. 
from pymongo.srv_resolver import _resolve @@ -520,37 +525,37 @@ def reset_resolver(): uri_with_timeout = base_uri + "/?connectTimeoutMS=6000" expected_uri_value = 6.0 - def test_scenario(args, kwargs, expected_value): + async def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - AsyncMongoClient(*args, **kwargs) + self.simple_client(*args, **kwargs) for _, kw in patched_resolver.call_list(): self.assertAlmostEqual(kw["lifetime"], expected_value) # No timeout specified. - test_scenario((base_uri,), {}, CONNECT_TIMEOUT) + await test_scenario((base_uri,), {}, CONNECT_TIMEOUT) # Timeout only specified in connection string. - test_scenario((uri_with_timeout,), {}, expected_uri_value) + await test_scenario((uri_with_timeout,), {}, expected_uri_value) # Timeout only specified in keyword arguments. kwarg = {"connectTimeoutMS": connectTimeoutMS} - test_scenario((base_uri,), kwarg, expected_kw_value) + await test_scenario((base_uri,), kwarg, expected_kw_value) # Timeout specified in both kwargs and connection string. - test_scenario((uri_with_timeout,), kwarg, expected_kw_value) + await test_scenario((uri_with_timeout,), kwarg, expected_kw_value) - def test_uri_security_options(self): + async def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - AsyncMongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) + self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - c = AsyncMongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) + c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - AsyncMongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, @@ -558,7 +563,7 @@ def test_uri_security_options(self): # Conflicting legacy tlsInsecure options should also raise an error. 
with self.assertRaises(InvalidURI): - AsyncMongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, @@ -566,10 +571,10 @@ def test_uri_security_options(self): # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - AsyncMongoClient(ssl=True, tls=False) + self.simple_client(ssl=True, tls=False) - def test_event_listeners(self): - c = AsyncMongoClient(event_listeners=[], connect=False) + async def test_event_listeners(self): + c = self.simple_client(event_listeners=[], connect=False) self.assertEqual(c.options.event_listeners, []) listeners = [ event_loggers.CommandLogger(), @@ -578,11 +583,11 @@ def test_event_listeners(self): event_loggers.TopologyLogger(), event_loggers.ConnectionPoolLogger(), ] - c = AsyncMongoClient(event_listeners=listeners, connect=False) + c = self.simple_client(event_listeners=listeners, connect=False) self.assertEqual(c.options.event_listeners, listeners) - def test_client_options(self): - c = AsyncMongoClient(connect=False) + async def test_client_options(self): + c = self.simple_client(connect=False) self.assertIsInstance(c.options, ClientOptions) self.assertIsInstance(c.options.pool_options, PoolOptions) self.assertEqual(c.options.server_selection_timeout, 30) @@ -612,16 +617,16 @@ def test_detected_environment_logging(self, mock_get_hosts): ) with self.assertLogs("pymongo", level="INFO") as cm: for host in normal_hosts: - AsyncMongoClient(host) + AsyncMongoClient(host, connect=False) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] - AsyncMongoClient(host) - AsyncMongoClient(multi_host) + AsyncMongoClient(host, connect=False) + AsyncMongoClient(multi_host, connect=False) logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) @patch("pymongo.srv_resolver._SrvResolver.get_hosts") - def test_detected_environment_warning(self, mock_get_hosts): + async def test_detected_environment_warning(self, mock_get_hosts): with self._caplog.at_level(logging.WARN): normal_hosts = [ "host.cosmos.azure.com", @@ -634,13 +639,13 @@ def test_detected_environment_warning(self, mock_get_hosts): ) for host in normal_hosts: with self.assertWarns(UserWarning): - AsyncMongoClient(host) + self.simple_client(host) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] with self.assertWarns(UserWarning): - AsyncMongoClient(host) + self.simple_client(host) with self.assertWarns(UserWarning): - AsyncMongoClient(multi_host) + self.simple_client(multi_host) class TestClient(AsyncIntegrationTest): @@ -657,7 +662,7 @@ def test_multiple_uris(self): async def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -665,12 +670,11 @@ async def test_max_idle_time_reaper_default(self): pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - await client.close() async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = await async_rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500, 
minPoolSize=1) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -681,12 +685,11 @@ async def test_max_idle_time_reaper_removes_stale_minPoolSize(self): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - await client.close() async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1 ) server = await (await client._get_topology()).select_server( @@ -699,12 +702,11 @@ async def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - await client.close() async def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it - client = await async_rs_or_single_client(maxIdleTimeMS=500) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -719,18 +721,17 @@ async def test_max_idle_time_reaper_removes_stale(self): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - await client.close() async def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize - client = await async_rs_or_single_client(minPoolSize=10) + client = await self.async_rs_or_single_client(minPoolSize=10) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -751,7 +752,7 @@ async def test_min_pool_size(self): async def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. with client_knobs(kill_cursor_frequency=99999999): - client = await async_rs_or_single_client(maxIdleTimeMS=500) + client = await self.async_rs_or_single_client(maxIdleTimeMS=500) server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -767,7 +768,7 @@ async def test_max_idle_time_checkout(self): self.assertTrue(new_con in server._pool.conns) # Test that connections are reused if maxIdleTimeMS is not set. - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() server = await (await client._get_topology()).select_server( readable_server_selector, _Op.TEST ) @@ -793,36 +794,38 @@ async def test_constants(self): AsyncMongoClient.HOST = "somedomainthatdoesntexist.org" AsyncMongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - await connected(AsyncMongoClient(serverSelectionTimeoutMS=10, **kwargs)) + c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) + await connected(c) + c = self.simple_client(host, port, **kwargs) # Override the defaults. No error. 
- await connected(AsyncMongoClient(host, port, **kwargs)) + await connected(c) # Set good defaults. AsyncMongoClient.HOST = host AsyncMongoClient.PORT = port # No error. - await connected(AsyncMongoClient(**kwargs)) + c = self.simple_client(**kwargs) + await connected(c) async def test_init_disconnected(self): host, port = await async_client_context.host, await async_client_context.port - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) # is_primary causes client to block until connected self.assertIsInstance(await c.is_primary, bool) - - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(await c.is_mongos, bool) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertFalse(await c.primary) self.assertFalse(await c.secondaries) - c = await async_rs_or_single_client(connect=False) + c = await self.async_rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(await c.address) # PYTHON-2981 @@ -834,45 +837,44 @@ async def test_init_disconnected(self): self.assertEqual(await c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" - c = AsyncMongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = AsyncMongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): await c.pymongo_test.test.find_one() async def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client(seed, connect=False) self.assertEqual(async_client_context.client, c) # Explicitly test inequality self.assertFalse(async_client_context.client != c) - c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client("invalid.com", connect=False) self.assertNotEqual(async_client_context.client, c) self.assertTrue(async_client_context.client != c) + + c1 = self.simple_client("a", connect=False) + c2 = self.simple_client("b", connect=False) + # Seeds differ: - self.assertNotEqual( - AsyncMongoClient("a", connect=False), AsyncMongoClient("b", connect=False) - ) + self.assertNotEqual(c1, c2) + + c1 = self.simple_client(["a", "b", "c"], connect=False) + c2 = self.simple_client(["c", "a", "b"], connect=False) + # Same seeds but out of order still compares equal: - self.assertEqual( - AsyncMongoClient(["a", "b", "c"], connect=False), - AsyncMongoClient(["c", "a", "b"], 
connect=False), - ) + self.assertEqual(c1, c2) async def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = await async_rs_or_single_client(seed, connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client(seed, connect=False) self.assertIn(c, {async_client_context.client}) - c = await async_rs_or_single_client("invalid.com", connect=False) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client("invalid.com", connect=False) self.assertNotIn(c, {async_client_context.client}) async def test_host_w_port(self): @@ -886,7 +888,7 @@ async def test_host_w_port(self): ) ) - def test_repr(self): + async def test_repr(self): # Used to test 'eval' below. import bson @@ -905,9 +907,10 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + async with eval(the_repr) as client_two: + self.assertEqual(client_two, client) - client = AsyncMongoClient( + client = self.simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -925,7 +928,8 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + async with eval(the_repr) as client_two: + self.assertEqual(client_two, client) def test_getters(self): wait_until(lambda: async_client_context.nodes == self.client.nodes, "find all nodes") @@ -941,8 +945,7 @@ async def test_list_databases(self): for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(document_class=SON) async for doc in await client.list_databases(): self.assertIs(type(doc), dict) @@ -981,7 +984,7 @@ async def test_drop_database(self): await self.client.drop_database("pymongo_test") if async_client_context.is_rs: - wc_client = await async_rs_or_single_client(w=len(async_client_context.nodes) + 1) + wc_client = await self.async_rs_or_single_client(w=len(async_client_context.nodes) + 1) with self.assertRaises(WriteConcernError): await wc_client.drop_database("pymongo_test2") @@ -991,7 +994,7 @@ async def test_drop_database(self): self.assertNotIn("pymongo_test2", dbs) async def test_close(self): - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() coll = test_client.pymongo_test.bar await test_client.close() with self.assertRaises(InvalidOperation): @@ -1001,7 +1004,7 @@ async def test_close_kills_cursors(self): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() # Kill any cursors possibly queued up by previous tests. gc.collect() await test_client._process_periodic_tasks() @@ -1028,13 +1031,13 @@ async def test_close_kills_cursors(self): self.assertTrue(test_client._topology._opened) await test_client.close() self.assertFalse(test_client._topology._opened) - test_client = await async_rs_or_single_client() + test_client = await self.async_rs_or_single_client() # The killCursors task should not need to re-open the topology. 
await test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) async def test_close_stops_kill_cursors_thread(self): - client = await async_rs_client() + client = await self.async_rs_client() await client.test.test.find_one() self.assertFalse(client._kill_cursors_executor._stopped) @@ -1050,7 +1053,7 @@ async def test_close_stops_kill_cursors_thread(self): async def test_uri_connect_option(self): # Ensure that topology is not opened if connect=False. - client = await async_rs_client(connect=False) + client = await self.async_rs_client(connect=False) self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. @@ -1063,19 +1066,15 @@ async def test_uri_connect_option(self): kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) - # Tear down. - await client.close() - async def test_close_does_not_open_servers(self): - client = await async_rs_client(connect=False) + client = await self.async_rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) await client.close() self.assertEqual(topology._servers, {}) async def test_close_closes_sockets(self): - client = await async_rs_client() - self.addAsyncCleanup(client.close) + client = await self.async_rs_client() await client.test.test.find_one() topology = client._topology await client.close() @@ -1104,35 +1103,35 @@ async def test_auth_from_uri(self): with self.assertRaises(OperationFailure): await connected( - await async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) + await self.async_rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port)) ) # No error. await connected( - await async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + await self.async_rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) ) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) with self.assertRaises(OperationFailure): - await connected(await async_rs_or_single_client_noauth(uri)) + await connected(await self.async_rs_or_single_client_noauth(uri)) # No error. await connected( - await async_rs_or_single_client_noauth( + await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) ) ) # Auth with lazy connection. await ( - await async_rs_or_single_client_noauth( + await self.async_rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = await async_rs_or_single_client_noauth( + bad_client = await self.async_rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) @@ -1144,7 +1143,7 @@ async def test_username_and_password(self): await async_client_context.create_user("admin", "ad min", "pa/ss") self.addAsyncCleanup(async_client_context.drop_user, "admin", "ad min") - c = await async_rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = await self.async_rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. 
self.assertNotIn("ad min", repr(c)) @@ -1157,14 +1156,14 @@ async def test_username_and_password(self): with self.assertRaises(OperationFailure): await ( - await async_rs_or_single_client_noauth(username="ad min", password="foo") + await self.async_rs_or_single_client_noauth(username="ad min", password="foo") ).server_info() @async_client_context.require_auth @async_client_context.require_no_fips async def test_lazy_auth_raises_operation_failure(self): host = await async_client_context.host - lazy_client = await async_rs_or_single_client_noauth( + lazy_client = await self.async_rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) @@ -1182,8 +1181,7 @@ async def test_unix_socket(self): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = await async_rs_or_single_client(uri) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(uri) await client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = await client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1192,11 +1190,10 @@ async def test_unix_socket(self): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - await connected( - AsyncMongoClient( - "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 - ), + c = self.simple_client( + "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 ) + await connected(c) async def test_document_class(self): c = self.client @@ -1207,15 +1204,15 @@ async def test_document_class(self): self.assertTrue(isinstance(await db.test.find_one(), dict)) self.assertFalse(isinstance(await db.test.find_one(), SON)) - c = await async_rs_or_single_client(document_class=SON) - self.addAsyncCleanup(c.close) + c = await self.async_rs_or_single_client(document_class=SON) + db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) self.assertTrue(isinstance(await db.test.find_one(), SON)) async def test_timeouts(self): - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, @@ -1228,28 +1225,31 @@ async def test_timeouts(self): self.assertEqual(10.5, client.options.server_selection_timeout) async def test_socket_timeout_ms_validation(self): - c = await async_rs_or_single_client(socketTimeoutMS=10 * 1000) + c = await self.async_rs_or_single_client(socketTimeoutMS=10 * 1000) self.assertEqual(10, (await async_get_pool(c)).opts.socket_timeout) - c = await connected(await async_rs_or_single_client(socketTimeoutMS=None)) + c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=None)) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) - c = await connected(await async_rs_or_single_client(socketTimeoutMS=0)) + c = await connected(await self.async_rs_or_single_client(socketTimeoutMS=0)) self.assertEqual(None, (await async_get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS=-1) + async with await self.async_rs_or_single_client(socketTimeoutMS=-1): + pass with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS=1e10) + async with await self.async_rs_or_single_client(socketTimeoutMS=1e10): + pass with self.assertRaises(ValueError): - await async_rs_or_single_client(socketTimeoutMS="foo") + async with await self.async_rs_or_single_client(socketTimeoutMS="foo"): + pass async def 
test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 - timeout = await async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) + timeout = await self.async_rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) self.addAsyncCleanup(timeout.close) await no_timeout.pymongo_test.drop_collection("test") @@ -1266,7 +1266,7 @@ async def get_x(db): with self.assertRaises(NetworkTimeout): await get_x(timeout.pymongo_test) - def test_server_selection_timeout(self): + async def test_server_selection_timeout(self): client = AsyncMongoClient(serverSelectionTimeoutMS=100, connect=False) self.assertAlmostEqual(0.1, client.options.server_selection_timeout) @@ -1298,7 +1298,7 @@ def test_server_selection_timeout(self): self.assertAlmostEqual(30, client.options.server_selection_timeout) async def test_waitQueueTimeoutMS(self): - client = await async_rs_or_single_client(waitQueueTimeoutMS=2000) + client = await self.async_rs_or_single_client(waitQueueTimeoutMS=2000) self.assertEqual((await async_get_pool(client)).opts.wait_queue_timeout, 2) async def test_socketKeepAlive(self): @@ -1311,7 +1311,7 @@ async def test_socketKeepAlive(self): async def test_tz_aware(self): self.assertRaises(ValueError, AsyncMongoClient, tz_aware="foo") - aware = await async_rs_or_single_client(tz_aware=True) + aware = await self.async_rs_or_single_client(tz_aware=True) self.addAsyncCleanup(aware.close) naive = self.client await aware.pymongo_test.drop_collection("test") @@ -1341,8 +1341,7 @@ async def test_ipv6(self): if async_client_context.is_rs: uri += "/?replicaSet=" + (async_client_context.replica_set_name or "") - client = await async_rs_or_single_client_noauth(uri) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client_noauth(uri) await client.pymongo_test.test.insert_one({"dummy": "object"}) await client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1351,7 +1350,7 @@ async def test_ipv6(self): self.assertTrue("pymongo_test_bernie" in dbs) async def test_contextlib(self): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() await client.pymongo_test.drop_collection("test") await client.pymongo_test.test.insert_one({"foo": "bar"}) @@ -1365,7 +1364,7 @@ async def test_contextlib(self): self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): await client.pymongo_test.test.find_one() - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() async with client as client: self.assertEqual("bar", (await client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): @@ -1443,8 +1442,7 @@ async def test_operation_failure(self): # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. 
- client = await async_single_client() - self.addAsyncCleanup(client.close) + client = await self.async_single_client() await client.pymongo_test.test.find_one() pool = await async_get_pool(client) socket_count = len(pool.conns) @@ -1468,8 +1466,7 @@ async def test_lazy_connect_w0(self): await async_client_context.client.drop_database("test_lazy_connect_w0") self.addAsyncCleanup(async_client_context.client.drop_database, "test_lazy_connect_w0") - client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.insert_one({}) async def predicate(): @@ -1477,8 +1474,7 @@ async def predicate(): await async_wait_until(predicate, "find one document") - client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) async def predicate(): @@ -1486,8 +1482,7 @@ async def predicate(): await async_wait_until(predicate, "update one document") - client = await async_rs_or_single_client(connect=False, w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(connect=False, w=0) await client.test_lazy_connect_w0.test.delete_one({}) async def predicate(): @@ -1499,8 +1494,7 @@ async def predicate(): async def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -1527,7 +1521,9 @@ async def test_auth_network_error(self): # Get a client with one socket so we detect if it's leaked. c = await connected( - await async_rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) + await self.async_rs_or_single_client( + maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False + ) ) # Cause a network error on the actual socket. @@ -1545,8 +1541,7 @@ async def test_auth_network_error(self): @async_client_context.require_no_replica_set async def test_connect_to_standalone_using_replica_set_name(self): - client = await async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - + client = await self.async_single_client(replicaSet="anything", serverSelectionTimeoutMS=100) with self.assertRaises(AutoReconnect): await client.test.test.find_one() @@ -1556,7 +1551,7 @@ async def test_stale_getmore(self): # the topology before the getMore message is sent. Test that # AsyncMongoClient._run_operation_with_response handles the error. 
with self.assertRaises(AutoReconnect): - client = await async_rs_client(connect=False, serverSelectionTimeoutMS=100) + client = await self.async_rs_client(connect=False, serverSelectionTimeoutMS=100) await client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1604,7 +1599,7 @@ def init(self, *args): await async_client_context.host, await async_client_context.port, ) - client = await async_single_client(uri, event_listeners=[listener]) + await self.async_single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1613,7 +1608,6 @@ def init(self, *args): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - await client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1630,31 +1624,31 @@ def compression_settings(client): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) @@ -1662,56 +1656,55 @@ def compression_settings(client): # According to the connection string spec, unsupported values # just raise a warning and are ignored. 
uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = AsyncMongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = async_client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = await async_single_client(zlibcompressionlevel=level) + client = await self.async_single_client(zlibcompressionlevel=level) # No error await client.pymongo_test.test.find_one() async def test_reset_during_update_pool(self): - client = await async_rs_or_single_client(minPoolSize=10) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(minPoolSize=10) await client.admin.command("ping") pool = await async_get_pool(client) generation = pool.gen.get_overall() @@ -1757,11 +1750,9 @@ def run(self): async def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addAsyncCleanup(client.close) - # Create a single connection in the pool. await client.admin.command("ping") @@ -1791,21 +1782,19 @@ def stall_connect(*args, **kwargs): @async_client_context.require_replica_set async def test_direct_connection(self): # direct_connection=True should result in Single topology. 
- client = await async_rs_or_single_client(directConnection=True) + client = await self.async_rs_or_single_client(directConnection=True) await client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - await client.close() # direct_connection=False should result in RS topology. - client = await async_rs_or_single_client(directConnection=False) + client = await self.async_rs_or_single_client(directConnection=False) await client.admin.command("ping") self.assertGreaterEqual(len(client.nodes), 1) self.assertIn( client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - await client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -1825,11 +1814,10 @@ def server_description_count(): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = AsyncMongoClient( + client = self.simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addAsyncCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): await client.test.test.find_one() gc.collect() @@ -1842,8 +1830,7 @@ def server_description_count(): @async_client_context.require_failCommand_fail_point async def test_network_error_message(self): - client = await async_single_client(retryReads=False) - self.addAsyncCleanup(client.close) + client = await self.async_single_client(retryReads=False) await client.admin.command("ping") # connect async with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1855,7 +1842,7 @@ async def test_network_error_message(self): @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") async def test_process_periodic_tasks(self): - client = await async_rs_or_single_client() + client = await self.async_rs_or_single_client() coll = client.db.collection await coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) @@ -1873,7 +1860,7 @@ async def test_process_periodic_tasks(self): with self.assertRaises(InvalidOperation): await coll.insert_many([{} for _ in range(5)]) - def test_service_name_from_kwargs(self): + async def test_service_name_from_kwargs(self): client = AsyncMongoClient( "mongodb+srv://user:password@test22.test.build.10gen.cc", srvServiceName="customname", @@ -1893,12 +1880,12 @@ def test_service_name_from_kwargs(self): ) self.assertEqual(client._topology_settings.srv_service_name, "customname") - def test_srv_max_hosts_kwarg(self): - client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/") + async def test_srv_max_hosts_kwarg(self): + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = AsyncMongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = AsyncMongoClient( + client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @@ -1946,10 +1933,10 @@ async def _test_handshake(self, env_vars, expected_env): if "AWS_REGION" not in env_vars: 
os.environ["AWS_REGION"] = "" - async with await async_rs_or_single_client(serverSelectionTimeoutMS=10000) as client: - await client.admin.command("ping") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = await self.async_rs_or_single_client(serverSelectionTimeoutMS=10000) + await client.admin.command("ping") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) async def test_handshake_01_aws(self): await self._test_handshake( @@ -2045,7 +2032,7 @@ def setUp(self): async def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected(await async_rs_or_single_client(maxPoolSize=1)) + client = await connected(await self.async_rs_or_single_client(maxPoolSize=1)) collection = client.pymongo_test.test pool = await async_get_pool(client) @@ -2068,7 +2055,7 @@ async def test_exhaust_query_server_error(self): async def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test await collection.drop() @@ -2107,7 +2094,9 @@ async def receive_message(request_id): async def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = await connected(await async_rs_or_single_client(maxPoolSize=1, retryReads=False)) + client = await connected( + await self.async_rs_or_single_client(maxPoolSize=1, retryReads=False) + ) collection = client.pymongo_test.test pool = await async_get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2128,7 +2117,7 @@ async def test_exhaust_query_network_error(self): async def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test await collection.drop() await collection.insert_many([{} for _ in range(200)]) # More than one batch. 
@@ -2177,7 +2166,7 @@ def test_gevent_timeout(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client(maxPoolSize=1) + client = self.async_rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2209,7 +2198,7 @@ def test_gevent_timeout_when_creating_connection(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client() + client = self.async_rs_or_single_client() self.addCleanup(client.close) coll = client.pymongo_test.test pool = async_get_pool(client) @@ -2246,7 +2235,7 @@ class TestClientLazyConnect(AsyncIntegrationTest): """Test concurrent operations on a lazily-connecting MongoClient.""" def _get_client(self): - return rs_or_single_client(connect=False) + return self.async_rs_or_single_client(connect=False) @async_client_context.require_sync def test_insert_one(self): @@ -2380,6 +2369,7 @@ async def _test_network_error(self, operation_callback): retryReads=False, serverSelectionTimeoutMS=1000, ) + self.addAsyncCleanup(c.close) # Set host-specific information so we can test whether it is reset. diff --git a/test/asynchronous/test_client_bulk_write.py b/test/asynchronous/test_client_bulk_write.py index c35e823d03..3a17299453 100644 --- a/test/asynchronous/test_client_bulk_write.py +++ b/test/asynchronous/test_client_bulk_write.py @@ -27,7 +27,6 @@ ) from test.utils import ( OvertCommandListener, - async_rs_or_single_client, ) from unittest.mock import patch @@ -39,7 +38,6 @@ InvalidOperation, NetworkTimeout, ) -from pymongo.monitoring import * from pymongo.operations import * from pymongo.write_concern import WriteConcern @@ -97,8 +95,7 @@ async def asyncSetUp(self): @async_client_context.require_no_serverless async def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) models = [] for _ in range(self.max_write_batch_size + 1): @@ -123,8 +120,7 @@ async def test_batch_splits_if_num_operations_too_large(self): @async_client_context.require_no_serverless async def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -157,11 +153,10 @@ async def test_batch_splits_if_ops_payload_too_large(self): @async_client_context.require_failCommand_fail_point async def test_collects_write_concern_errors_across_batches(self): listener = OvertCommandListener() - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], retryWrites=False, ) - self.addAsyncCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -200,8 +195,7 @@ async def test_collects_write_concern_errors_across_batches(self): @async_client_context.require_no_serverless async def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) 
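        # Sketch in comments, following the pattern of the surrounding tests:
        # the class helper above is assumed to register client.close() with
        # addAsyncCleanup(), which is why the explicit cleanup lines are removed
        # throughout this file.  After the operation under test, the listener is
        # inspected roughly like:
        #     started = [e for e in listener.started_events
        #                if e.command_name == "bulkWrite"]
        #     self.assertGreaterEqual(len(started), 1)
        # ("bulkWrite" is the command these client-bulk-write tests exercise;
        # the exact expectations vary per test).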
collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -231,8 +225,7 @@ async def test_collects_write_errors_across_batches_unordered(self): @async_client_context.require_no_serverless async def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -262,8 +255,7 @@ async def test_collects_write_errors_across_batches_ordered(self): @async_client_context.require_no_serverless async def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -304,8 +296,7 @@ async def test_handles_cursor_requiring_getMore(self): @async_client_context.require_no_standalone async def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -348,8 +339,7 @@ async def test_handles_cursor_requiring_getMore_within_transaction(self): @async_client_context.require_failCommand_fail_point async def test_handles_getMore_error(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addAsyncCleanup(collection.drop) @@ -403,8 +393,7 @@ async def test_handles_getMore_error(self): @async_client_context.require_no_serverless async def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) b_repeated = "b" * self.max_bson_object_size @@ -460,8 +449,7 @@ async def _setup_namespace_test_models(self): @async_client_context.require_no_serverless async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) num_models, models = await self._setup_namespace_test_models() models.append( @@ -492,8 +480,7 @@ async def test_no_batch_splits_if_new_namespace_is_not_too_large(self): @async_client_context.require_no_serverless async def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) num_models, models = await self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -530,8 +517,7 @@ async def test_batch_splits_if_new_namespace_is_too_large(self): @async_client_context.require_version_min(8, 0, 0, -24) 
@async_client_context.require_no_serverless async def test_returns_error_if_no_writes_can_be_added_to_ops(self): - client = await async_rs_or_single_client() - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client() # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -554,8 +540,7 @@ async def test_returns_error_if_auto_encryption_configured(self): key_vault_namespace="db.coll", kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -580,7 +565,7 @@ async def asyncSetUp(self): async def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 - internal_client = await async_rs_or_single_client(timeoutMS=None) + internal_client = await self.async_rs_or_single_client(timeoutMS=None) self.addAsyncCleanup(internal_client.close) collection = internal_client.db["coll"] @@ -605,14 +590,13 @@ async def test_timeout_in_multi_batch_bulk_write(self): ) listener = OvertCommandListener() - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], readConcernLevel="majority", readPreference="primary", timeoutMS=2000, w="majority", ) - self.addAsyncCleanup(client.close) await client.admin.command("ping") # Init the client first. with self.assertRaises(ClientBulkWriteException) as context: await client.bulk_write(models=models) diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 10d64a525c..74a4a5151d 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -30,6 +30,7 @@ from test import unittest from test.asynchronous import ( # TODO: fix sync imports in PYTHON-4528 AsyncIntegrationTest, + AsyncUnitTest, async_client_context, ) from test.utils import ( @@ -37,8 +38,6 @@ EventListener, async_get_pool, async_is_mongos, - async_rs_or_single_client, - async_single_client, async_wait_until, wait_until, ) @@ -82,14 +81,20 @@ _IS_SYNC = False -class TestCollectionNoConnect(unittest.TestCase): +class TestCollectionNoConnect(AsyncUnitTest): """Test Collection features on a client that does not connect.""" db: AsyncDatabase + client: AsyncMongoClient @classmethod - def setUpClass(cls): - cls.db = AsyncMongoClient(connect=False).pymongo_test + async def _setup_class(cls): + cls.client = AsyncMongoClient(connect=False) + cls.db = cls.client.pymongo_test + + @classmethod + async def _tearDown_class(cls): + await cls.client.close() def test_collection(self): self.assertRaises(TypeError, AsyncCollection, self.db, 5) @@ -1819,8 +1824,7 @@ async def test_exhaust(self): # Insert enough documents to require more than one batch await self.db.test.insert_many([{"i": i} for i in range(150)]) - client = await async_rs_or_single_client(maxPoolSize=1) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(maxPoolSize=1) pool = await async_get_pool(client) # Make sure the socket is returned after exhaustion. 
@@ -2100,7 +2104,7 @@ async def test_find_one_and(self): async def test_find_one_and_write_concern(self): listener = EventListener() - db = (await async_single_client(event_listeners=[listener]))[self.db.name] + db = (await self.async_single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. diff --git a/test/asynchronous/test_cursor.py b/test/asynchronous/test_cursor.py index 6967205fe3..d6773d832e 100644 --- a/test/asynchronous/test_cursor.py +++ b/test/asynchronous/test_cursor.py @@ -34,7 +34,6 @@ AllowListEventListener, EventListener, OvertCommandListener, - async_rs_or_single_client, ignore_deprecations, wait_until, ) @@ -232,7 +231,7 @@ async def test_max_await_time_ms(self): self.assertEqual(90, cursor._max_await_time_ms) listener = AllowListEventListener("find", "getMore") - coll = (await async_rs_or_single_client(event_listeners=[listener]))[ + coll = (await self.async_rs_or_single_client(event_listeners=[listener]))[ self.db.name ].pymongo_test @@ -353,8 +352,7 @@ async def test_explain(self): async def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(await coll.find().explain()) started = listener.started_events @@ -1261,8 +1259,7 @@ async def test_close_kills_cursor_synchronously(self): await self.client._process_periodic_tasks() listener = AllowListEventListener("killCursors") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1300,8 +1297,7 @@ def assertCursorKilled(): @async_client_context.require_failCommand_appName async def test_timeout_kills_cursor_asynchronously(self): listener = AllowListEventListener("killCursors") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1358,8 +1354,7 @@ def test_delete_not_initialized(self): async def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # We never send primary read preference so override the default. 
coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1463,7 +1458,7 @@ async def test_find_raw_transaction(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) async with client.start_session() as session: async with await session.start_transaction(): batches = await ( @@ -1493,7 +1488,7 @@ async def test_find_raw_retryable_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): @@ -1514,7 +1509,7 @@ async def test_find_raw_snapshot_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) @@ -1577,7 +1572,7 @@ async def test_read_concern(self): async def test_monitoring(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() await c.insert_many([{"_id": i} for i in range(10)]) @@ -1643,7 +1638,7 @@ async def test_aggregate_raw_transaction(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) async with client.start_session() as session: async with await session.start_transaction(): batches = await ( @@ -1674,7 +1669,7 @@ async def test_aggregate_raw_retryable_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) async with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}} ): @@ -1698,7 +1693,7 @@ async def test_aggregate_raw_snapshot_reads(self): await c.insert_many(docs) listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener], retryReads=True) + client = await self.async_rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] async with client.start_session(snapshot=True) as session: await db.test.distinct("x", {}, session=session) @@ -1744,7 +1739,7 @@ async def test_collation(self): async def test_monitoring(self): listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.drop() await c.insert_many([{"_id": i} for i in range(10)]) @@ -1788,8 +1783,7 @@ async def test_monitoring(self): @async_client_context.require_no_mongos async def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() - client = await 
async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test await c.delete_many({}) await c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/asynchronous/test_database.py b/test/asynchronous/test_database.py index 8f6886a2a7..c5d62323df 100644 --- a/test/asynchronous/test_database.py +++ b/test/asynchronous/test_database.py @@ -29,7 +29,6 @@ from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, - async_rs_or_single_client, async_wait_until, ) @@ -208,7 +207,7 @@ async def test_list_collection_names(self): async def test_list_collection_names_filter(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] await db.capped.drop() await db.create_collection("capped", capped=True, size=4096) @@ -235,8 +234,7 @@ async def test_list_collection_names_filter(self): async def test_check_exists(self): listener = OvertCommandListener() - client = await async_rs_or_single_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] await db.drop_collection("unique") await db.create_collection("unique", check_exists=True) @@ -326,7 +324,7 @@ async def test_list_collections(self): await self.client.drop_database("pymongo_test") async def test_list_collection_names_single_socket(self): - client = await async_rs_or_single_client(maxPoolSize=1) + client = await self.async_rs_or_single_client(maxPoolSize=1) await client.drop_database("test_collection_names_single_socket") db = client.test_collection_names_single_socket for i in range(200): diff --git a/test/asynchronous/test_encryption.py b/test/asynchronous/test_encryption.py index 030f468db2..3f3714eeb4 100644 --- a/test/asynchronous/test_encryption.py +++ b/test/asynchronous/test_encryption.py @@ -31,7 +31,7 @@ from test.asynchronous import AsyncIntegrationTest, AsyncPyMongoTestCase, async_client_context from test.asynchronous.test_bulk import AsyncBulkTestBase from threading import Thread -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional import pytest @@ -44,6 +44,8 @@ from test import ( unittest, ) +from test.asynchronous.test_bulk import AsyncBulkTestBase +from test.asynchronous.utils_spec_runner import AsyncSpecRunner from test.helpers import ( AWS_CREDS, AZURE_CREDS, @@ -59,12 +61,10 @@ OvertCommandListener, SpecTestCreator, TopologyEventListener, - async_rs_or_single_client, async_wait_until, camel_to_snake_args, is_greenthread_patched, ) -from test.utils_spec_runner import SpecRunner from bson import DatetimeMS, Decimal128, encode, json_util from bson.binary import UUID_SUBTYPE, Binary, UuidRepresentation @@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(AsyncPyMongoTestCase): @unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed") async def test_crypt_shared(self): # Test that we can pick up crypt_shared lib automatically - client = AsyncMongoClient( + self.simple_client( auto_encryption_opts=AutoEncryptionOpts( KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True ), connect=False, ) - self.addAsyncCleanup(client.aclose) @unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed") def test_init_requires_pymongocrypt(self): 
@@ -196,19 +195,16 @@ def test_init_kms_tls_options(self): class TestClientOptions(AsyncPyMongoTestCase): async def test_default(self): - client = AsyncMongoClient(connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) - client = AsyncMongoClient(auto_encryption_opts=None, connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(auto_encryption_opts=None, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") async def test_kwargs(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = AsyncMongoClient(auto_encryption_opts=opts, connect=False) - self.addAsyncCleanup(client.aclose) + client = self.simple_client(auto_encryption_opts=opts, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) @@ -229,6 +225,34 @@ def assertBinaryUUID(self, val): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, UUID_SUBTYPE) + def create_client_encryption( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: AsyncMongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = AsyncClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + self.addAsyncCleanup(client_encryption.close) + return client_encryption + + @classmethod + def unmanaged_create_client_encryption( + cls, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: AsyncMongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = AsyncClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + return client_encryption + # Location of JSON test files. if _IS_SYNC: @@ -260,8 +284,7 @@ def bson_data(*paths): class TestClientSimple(AsyncEncryptionIntegrationTest): async def _test_auto_encrypt(self, opts): - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) # Create the encrypted field's data key. 
key_vault = await create_key_vault( @@ -342,8 +365,7 @@ async def test_auto_encrypt_local_schema_map(self): async def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client.admin.command("ping") await client.aclose() @@ -360,8 +382,7 @@ async def test_use_after_close(self): ) async def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) async def target(): with warnings.catch_warnings(): @@ -375,8 +396,7 @@ async def target(): class TestEncryptedBulkWrite(AsyncBulkTestBase, AsyncEncryptionIntegrationTest): async def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) encrypted_coll = client.pymongo_test.test @@ -416,8 +436,7 @@ async def _setup_class(cls): @async_client_context.require_version_max(4, 0, 99) async def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): await client.test.test.insert_one({}) @@ -430,8 +449,7 @@ async def test_raise_max_wire_version_error(self): async def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client.aclose) + client = await self.async_rs_or_single_client(auto_encryption_opts=opts) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): await client.test.test.find_raw_batches({}) @@ -450,10 +468,9 @@ async def test_raise_unsupported_error(self): class TestExplicitSimple(AsyncEncryptionIntegrationTest): async def test_encrypt_decrypt(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) # Use standard UUID representation. 
key_vault = async_client_context.client.keyvault.get_collection( "datakeys", codec_options=OPTS @@ -495,10 +512,9 @@ async def test_encrypt_decrypt(self): self.assertEqual(decrypted_ssn, doc["ssn"]) async def test_validation(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) msg = "value to decrypt must be a bson.binary.Binary with subtype 6" with self.assertRaisesRegex(TypeError, msg): @@ -512,10 +528,9 @@ async def test_validation(self): await client_encryption.encrypt("str", algo, key_id=Binary(b"123")) async def test_bson_errors(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) - self.addAsyncCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. unencodable_value = object() @@ -528,7 +543,7 @@ async def test_bson_errors(self): async def test_codec_options(self): with self.assertRaisesRegex(TypeError, "codec_options must be"): - AsyncClientEncryption( + self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, @@ -536,10 +551,9 @@ async def test_codec_options(self): ) opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY) - client_encryption_legacy = AsyncClientEncryption( + client_encryption_legacy = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts ) - self.addAsyncCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. key_id = await client_encryption_legacy.create_data_key("local") @@ -554,10 +568,9 @@ async def test_codec_options(self): # Encrypt the same UUID with STANDARD codec options. 
opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, opts ) - self.addAsyncCleanup(client_encryption.close) encrypted_standard = await client_encryption.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id ) @@ -573,7 +586,7 @@ async def test_codec_options(self): self.assertNotEqual(await client_encryption.decrypt(encrypted_legacy), value) async def test_close(self): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) await client_encryption.close() @@ -589,7 +602,7 @@ async def test_close(self): await client_encryption.decrypt(Binary(b"", 6)) async def test_with_statement(self): - async with AsyncClientEncryption( + async with self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", async_client_context.client, OPTS ) as client_encryption: pass @@ -613,7 +626,7 @@ async def test_with_statement(self): if _IS_SYNC: # TODO: Add asynchronous SpecRunner (https://jira.mongodb.org/browse/PYTHON-4700) - class TestSpec(SpecRunner): + class TestSpec(AsyncSpecRunner): @classmethod @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def setUpClass(cls): @@ -811,7 +824,7 @@ class TestDataKeyDoubleEncryption(AsyncEncryptionIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = OvertCommandListener() - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) await cls.client.db.coll.drop() cls.vault = await create_key_vault(cls.client.keyvault.datakeys) @@ -833,10 +846,10 @@ async def _setup_class(cls): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = await async_rs_or_single_client( + cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = AsyncClientEncryption( + cls.client_encryption = cls.unmanaged_create_client_encryption( cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) @@ -923,10 +936,9 @@ async def _test_external_key_vault(self, with_external_key_vault): # Configure the encrypted field via the local schema_map option. 
schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: - key_vault_client = await async_rs_or_single_client( + key_vault_client = await self.async_rs_or_single_client( username="fake-user", password="fake-pwd" ) - self.addAsyncCleanup(key_vault_client.close) else: key_vault_client = async_client_context.client opts = AutoEncryptionOpts( @@ -936,15 +948,13 @@ async def _test_external_key_vault(self, with_external_key_vault): key_vault_client=key_vault_client, ) - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addAsyncCleanup(client_encrypted.close) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS ) - self.addAsyncCleanup(client_encryption.close) if with_external_key_vault: # Authentication error. @@ -990,10 +1000,9 @@ async def test_views_are_prohibited(self): self.addAsyncCleanup(self.client.db.view.drop) opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys") - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addAsyncCleanup(client_encrypted.aclose) with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"): await client_encrypted.db.view.insert_one({}) @@ -1050,17 +1059,15 @@ async def _test_corpus(self, opts): ) self.addAsyncCleanup(vault.drop) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.close) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", async_client_context.client, OPTS, kms_tls_options=KMS_TLS_OPTS, ) - self.addAsyncCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data("corpus", "corpus.json")) corpus_copied: SON = SON() @@ -1203,7 +1210,7 @@ async def _setup_class(cls): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = await async_rs_or_single_client( + cls.client_encrypted = await cls.unmanaged_async_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll @@ -1291,7 +1298,7 @@ def setUp(self): "gcp": GCP_CREDS, "kmip": KMIP_CREDS, } - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers=kms_providers, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -1303,7 +1310,7 @@ def setUp(self): kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698" - self.client_encryption_invalid = AsyncClientEncryption( + self.client_encryption_invalid = self.create_client_encryption( kms_providers=kms_providers_invalid, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -1484,7 +1491,7 @@ async def test_12_kmip_master_key_invalid_endpoint(self): await self.client_encryption.create_data_key("kmip", key) -class 
AzureGCPEncryptionTestMixin: +class AzureGCPEncryptionTestMixin(AsyncEncryptionIntegrationTest): DEK = None KMS_PROVIDER_MAP = None KEYVAULT_DB = "keyvault" @@ -1496,7 +1503,7 @@ async def asyncSetUp(self): await create_key_vault(keyvault, self.DEK) async def _test_explicit(self, expectation): - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), async_client_context.client, @@ -1525,7 +1532,7 @@ async def _test_automatic(self, expectation_extjson, payload): ) insert_listener = AllowListEventListener("insert") - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) self.addAsyncCleanup(client.aclose) @@ -1604,19 +1611,17 @@ async def test_automatic(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests class TestDeadlockProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): - self.client_test = await async_rs_or_single_client( + self.client_test = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) - self.addAsyncCleanup(self.client_test.aclose) self.client_keyvault_listener = OvertCommandListener() - self.client_keyvault = await async_rs_or_single_client( + self.client_keyvault = await self.async_rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", event_listeners=[self.client_keyvault_listener], ) - self.addAsyncCleanup(self.client_keyvault.aclose) await self.client_test.keyvault.datakeys.drop() await self.client_test.db.coll.drop() @@ -1629,7 +1634,7 @@ async def asyncSetUp(self): codec_options=OPTS, ) - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, key_vault_namespace="keyvault.datakeys", key_vault_client=self.client_test, @@ -1645,7 +1650,7 @@ async def asyncSetUp(self): self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") async def _run_test(self, max_pool_size, auto_encryption_opts): - client_encrypted = await async_rs_or_single_client( + client_encrypted = await self.async_rs_or_single_client( readConcernLevel="majority", w="majority", maxPoolSize=max_pool_size, @@ -1663,8 +1668,6 @@ async def _run_test(self, max_pool_size, auto_encryption_opts): result = await client_encrypted.db.coll.find_one({"_id": 0}) self.assertEqual(result, {"_id": 0, "encrypted": "string0"}) - self.addAsyncCleanup(client_encrypted.close) - async def test_case_1(self): await self._run_test( max_pool_size=1, @@ -1840,7 +1843,7 @@ async def asyncSetUp(self): await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) keyID = await self.client_encryption.create_data_key("local") @@ -1855,10 +1858,9 @@ async def asyncSetUp(self): key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map ) self.listener = AllowListEventListener("aggregate") - self.encrypted_client = await async_rs_or_single_client( + self.encrypted_client = await self.async_rs_or_single_client( auto_encryption_opts=opts, retryReads=False, 
event_listeners=[self.listener] ) - self.addAsyncCleanup(self.encrypted_client.close) async def test_01_command_error(self): async with self.fail_point( @@ -1935,8 +1937,7 @@ def reset_timeout(): "--port=27027", ], ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.close) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) with self.assertRaisesRegex(EncryptionError, "Timeout"): await client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1950,11 +1951,10 @@ async def test_bypassAutoEncryption(self): "--port=27027", ], ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: - mongocryptd_client = AsyncMongoClient( + mongocryptd_client = self.simple_client( "mongodb://localhost:27027/?serverSelectionTimeoutMS=500" ) with self.assertRaises(ServerSelectionTimeoutError): @@ -1978,15 +1978,13 @@ async def test_via_loading_shared_library(self): ], crypt_shared_lib_required=True, ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) self.assertEncrypted((await async_client_context.client.db.coll.find_one({}))["encrypted"]) - no_mongocryptd_client = AsyncMongoClient( + no_mongocryptd_client = self.simple_client( host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000" ) - self.addAsyncCleanup(no_mongocryptd_client.aclose) with self.assertRaises(ServerSelectionTimeoutError): await no_mongocryptd_client.db.command("ping") @@ -2020,8 +2018,7 @@ def listener(): mongocryptd_uri="mongodb://localhost:47021", crypt_shared_lib_required=False, ) - client_encrypted = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(client_encrypted.aclose) + client_encrypted = await self.async_rs_or_single_client(auto_encryption_opts=opts) await client_encrypted.db.coll.drop() await client_encrypted.db.coll.insert_one({"encrypted": "test"}) server.shutdown() @@ -2035,10 +2032,9 @@ class TestKmsTLSProse(AsyncEncryptionIntegrationTest): async def asyncSetUp(self): await super().asyncSetUp() self.patch_system_certs(CA_PEM) - self.client_encrypted = AsyncClientEncryption( + self.client_encrypted = self.create_client_encryption( {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS ) - self.addAsyncCleanup(self.client_encrypted.close) async def test_invalid_kms_certificate_expired(self): key = { @@ -2083,36 +2079,32 @@ async def asyncSetUp(self): "gcp": {"tlsCAFile": CA_PEM}, "kmip": {"tlsCAFile": CA_PEM}, } - self.client_encryption_no_client_cert = AsyncClientEncryption( + self.client_encryption_no_client_cert = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_no_client_cert.close) # 2, same providers as above but with tlsCertificateKeyFile. 
kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only) for p in kms_tls_opts: kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM - self.client_encryption_with_tls = AsyncClientEncryption( + self.client_encryption_with_tls = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts ) - self.addAsyncCleanup(self.client_encryption_with_tls.close) # 3, update endpoints to expired host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000" providers["gcp"]["endpoint"] = "127.0.0.1:9000" providers["kmip"]["endpoint"] = "127.0.0.1:9000" - self.client_encryption_expired = AsyncClientEncryption( + self.client_encryption_expired = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_expired.close) # 3, update endpoints to invalid host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001" providers["gcp"]["endpoint"] = "127.0.0.1:9001" providers["kmip"]["endpoint"] = "127.0.0.1:9001" - self.client_encryption_invalid_hostname = AsyncClientEncryption( + self.client_encryption_invalid_hostname = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addAsyncCleanup(self.client_encryption_invalid_hostname.close) # Errors when client has no cert, some examples: # [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623) self.cert_error = ( @@ -2150,7 +2142,7 @@ async def asyncSetUp(self): "gcp:with_tls": with_cert, "kmip:with_tls": with_cert, } - self.client_encryption_with_names = AsyncClientEncryption( + self.client_encryption_with_names = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4 ) @@ -2232,10 +2224,9 @@ async def test_04_kmip(self): async def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}} options = {"aws": {"tlsDisableOCSPEndpointCheck": True}} - encryption = AsyncClientEncryption( + encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - self.addAsyncCleanup(encryption.close) ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") @@ -2285,7 +2276,7 @@ async def asyncSetUp(self): self.client = async_client_context.client await create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) self.def_key_id = await self.client_encryption.create_data_key( @@ -2327,17 +2318,15 @@ async def asyncSetUp(self): key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, 
) - self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts) - self.addAsyncCleanup(self.encrypted_client.aclose) + self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) async def test_01_insert_encrypted_indexed_and_find(self): val = "encrypted indexed value" @@ -2464,14 +2453,13 @@ async def run_test(self, src_provider, dst_provider): await self.client.keyvault.drop_collection("datakeys") # Step 2. Create a ``AsyncClientEncryption`` object named ``client_encryption1`` - client_encryption1 = AsyncClientEncryption( + client_encryption1 = self.create_client_encryption( key_vault_client=self.client, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addAsyncCleanup(client_encryption1.close) # Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``. key_id = await client_encryption1.create_data_key( @@ -2484,16 +2472,14 @@ async def run_test(self, src_provider, dst_provider): ) # Step 5. Create a ``AsyncClientEncryption`` object named ``client_encryption2`` - client2 = await async_rs_or_single_client() - self.addAsyncCleanup(client2.aclose) - client_encryption2 = AsyncClientEncryption( + client2 = await self.async_rs_or_single_client() + client_encryption2 = self.create_client_encryption( key_vault_client=client2, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addAsyncCleanup(client_encryption2.close) # Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``. rewrap_many_data_key_result = await client_encryption2.rewrap_many_data_key( @@ -2528,7 +2514,7 @@ async def asyncSetUp(self): @unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set") async def test_01_failure(self): - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -2539,7 +2525,7 @@ async def test_01_failure(self): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") async def test_02_success(self): - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=async_client_context.client, @@ -2559,8 +2545,7 @@ async def test_queryable_encryption(self): # AsyncMongoClient to use in testing that handles auth/tls/etc, # and cleanup. async def AsyncMongoClient(**kwargs): - c = await async_rs_or_single_client(**kwargs) - self.addAsyncCleanup(c.aclose) + c = await self.async_rs_or_single_client(**kwargs) return c # Drop data from prior test runs. @@ -2571,7 +2556,7 @@ async def AsyncMongoClient(**kwargs): # Create two data keys. 
key_vault_client = await AsyncMongoClient() - client_encryption = AsyncClientEncryption( + client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions() ) key1_id = await client_encryption.create_data_key("local") @@ -2652,18 +2637,16 @@ async def asyncSetUp(self): key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = await async_rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = await self.async_rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db - self.addAsyncCleanup(self.encrypted_client.aclose) async def run_expression_find( self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None @@ -2860,10 +2843,9 @@ async def asyncSetUp(self): await super().asyncSetUp() await self.client.drop_database(self.db) self.key_vault_client = self.client - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS ) - self.addAsyncCleanup(self.client_encryption.close) self.key_id = await self.client_encryption.create_data_key("local") opts = RangeOpts(min=0, max=1000) self.payload_defaults = await self.client_encryption.encrypt( @@ -2896,13 +2878,12 @@ async def asyncSetUp(self): await self.client.drop_database(self.db) self.key_vault = await create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addAsyncCleanup(self.key_vault.drop) - self.client_encryption = AsyncClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, self.key_vault.full_name, self.client, OPTS, ) - self.addAsyncCleanup(self.client_encryption.close) async def test_01_simple_create(self): coll, _ = await self.client_encryption.create_encrypted_collection( @@ -3118,10 +3099,9 @@ async def _tearDown_class(cls): async def asyncSetUp(self) -> None: self.listener = OvertCommandListener() - self.mongocryptd_client = AsyncMongoClient( + self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] ) - self.addAsyncCleanup(self.mongocryptd_client.aclose) hello = await self.mongocryptd_client.db.command("hello") self.assertNotIn("logicalSessionTimeoutMinutes", hello) diff --git a/test/asynchronous/test_grid_file.py b/test/asynchronous/test_grid_file.py index 6d589dc01c..9c57c15c5a 100644 --- a/test/asynchronous/test_grid_file.py +++ b/test/asynchronous/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import EventListener, async_rs_or_single_client +from test.utils import EventListener from bson.objectid import ObjectId from gridfs.asynchronous.grid_file import ( @@ -792,7 +792,7 @@ async def test_grid_out_lazy_connect(self): await outfile.readchunk() async def test_grid_in_lazy_connect(self): - client = AsyncMongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + client = self.simple_client("badhost", connect=False, 
serverSelectionTimeoutMS=10) fs = client.db.fs infile = AsyncGridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): @@ -803,7 +803,7 @@ async def test_grid_in_lazy_connect(self): async def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - AsyncGridIn((await async_rs_or_single_client(w=0)).pymongo_test.fs) + AsyncGridIn((await self.async_rs_or_single_client(w=0)).pymongo_test.fs) async def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. @@ -811,7 +811,7 @@ async def test_survive_cursor_not_found(self): chunk_size = 1024 data = b"d" * (102 * chunk_size) listener = EventListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test async with AsyncGridIn(db.fs, chunk_size=chunk_size) as infile: await infile.write(data) diff --git a/test/asynchronous/test_logger.py b/test/asynchronous/test_logger.py index b219d530e7..a2e8b35c5f 100644 --- a/test/asynchronous/test_logger.py +++ b/test/asynchronous/test_logger.py @@ -16,7 +16,6 @@ import os from test import unittest from test.asynchronous import AsyncIntegrationTest -from test.utils import async_single_client from unittest.mock import patch from bson import json_util @@ -86,7 +85,7 @@ async def test_truncation_multi_byte_codepoints(self): self.assertEqual(last_3_bytes, str_to_repeat) async def test_logging_without_listeners(self): - c = await async_single_client() + c = await self.async_single_client() self.assertEqual(len(c._event_listeners.event_listeners()), 0) with self.assertLogs("pymongo.connection", level="DEBUG") as cm: await c.db.test.insert_one({"x": "1"}) diff --git a/test/asynchronous/test_monitoring.py b/test/asynchronous/test_monitoring.py index 3f6563ee56..b5d8708dc3 100644 --- a/test/asynchronous/test_monitoring.py +++ b/test/asynchronous/test_monitoring.py @@ -31,8 +31,6 @@ ) from test.utils import ( EventListener, - async_rs_or_single_client, - async_single_client, async_wait_until, ) @@ -57,7 +55,7 @@ class AsyncTestCommandMonitoring(AsyncIntegrationTest): async def _setup_class(cls): await super()._setup_class() cls.listener = EventListener() - cls.client = await async_rs_or_single_client( + cls.client = await cls.unmanaged_async_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False ) @@ -407,7 +405,7 @@ async def test_get_more_failure(self): @async_client_context.require_secondaries_count(1) async def test_not_primary_error(self): address = next(iter(await async_client_context.client.secondaries)) - client = await async_single_client(*address, event_listeners=[self.listener]) + client = await self.async_single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. await client.admin.command("ping") self.listener.reset() @@ -1146,7 +1144,7 @@ async def _setup_class(cls): # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = await async_single_client() + cls.client = await cls.unmanaged_async_single_client() # Get one (authenticated) socket in the pool. 
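Class-level clients built with the unmanaged_* helpers, as in the setup above, have to be closed by the class itself. A rough synchronous sketch of that shape, assuming only a MongoClient created with connect=False so no running server is needed:

    import unittest

    from pymongo import MongoClient


    class ClassLevelClientSketch(unittest.TestCase):
        client: MongoClient

        @classmethod
        def setUpClass(cls) -> None:
            # Unmanaged client shared by every test in the class; nothing
            # registers cleanup for it automatically.
            cls.client = MongoClient("mongodb://localhost:27017", connect=False)

        @classmethod
        def tearDownClass(cls) -> None:
            # The class created the client, so the class must close it.
            cls.client.close()

        def test_shared_client(self) -> None:
            self.assertIsInstance(self.client, MongoClient)
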
await cls.client.pymongo_test.command("ping") diff --git a/test/asynchronous/test_session.py b/test/asynchronous/test_session.py index 1e1f5659ba..d264b5ecb0 100644 --- a/test/asynchronous/test_session.py +++ b/test/asynchronous/test_session.py @@ -36,9 +36,7 @@ from test.utils import ( EventListener, ExceptionCatchingThread, - async_rs_or_single_client, async_wait_until, - rs_or_single_client, wait_until, ) @@ -90,7 +88,7 @@ async def _setup_class(cls): await super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = await async_rs_or_single_client() + cls.client2 = await cls.unmanaged_async_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -105,7 +103,7 @@ async def _tearDown_class(cls): async def asyncSetUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() - self.client = await async_rs_or_single_client( + self.client = await self.async_rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) self.addAsyncCleanup(self.client.close) @@ -202,7 +200,7 @@ def test_implicit_sessions_checkout(self): failures = 0 for _ in range(5): listener = EventListener() - client = async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = self.async_rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -285,7 +283,7 @@ async def test_end_session(self): async def test_end_sessions(self): # Use a new client so that the tearDown hook does not error. listener = SessionTestListener() - client = await async_rs_or_single_client(event_listeners=[listener]) + client = await self.async_rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: @@ -789,8 +787,7 @@ async def _test_unacknowledged_ops(self, client, *ops): async def test_unacknowledged_writes(self): # Ensure the collection exists. await self.client.pymongo_test.test_unacked_writes.insert_one({}) - client = await async_rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_or_single_client(w=0, event_listeners=[self.listener]) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -838,7 +835,7 @@ class TestCausalConsistency(AsyncUnitTest): @classmethod async def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = await async_rs_or_single_client(event_listeners=[cls.listener]) + cls.client = await cls.unmanaged_async_rs_or_single_client(event_listeners=[cls.listener]) @classmethod async def _tearDown_class(cls): @@ -1153,10 +1150,9 @@ async def asyncSetUp(self): async def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. - client = await async_rs_or_single_client( + client = await self.async_rs_or_single_client( event_listeners=[listener], heartbeatFrequencyMS=999999 ) - self.addAsyncCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). 
await collection.insert_many([{} for _ in range(10)]) diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 4034c8e2c4..b5d0686417 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -17,6 +17,7 @@ import sys from io import BytesIO +from test.asynchronous.utils_spec_runner import AsyncSpecRunner from gridfs.asynchronous.grid_file import AsyncGridFS, AsyncGridFSBucket @@ -25,8 +26,6 @@ from test.asynchronous import AsyncIntegrationTest, async_client_context, unittest from test.utils import ( OvertCommandListener, - async_rs_client, - async_single_client, wait_until, ) from typing import List @@ -59,7 +58,18 @@ UNPIN_TEST_MAX_ATTEMPTS = 50 -class TestTransactions(AsyncIntegrationTest): +class AsyncTransactionsBase(AsyncSpecRunner): + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + if ( + "secondary" in self.id() + and not async_client_context.is_mongos + and not async_client_context.has_secondaries + ): + raise unittest.SkipTest("No secondaries") + + +class TestTransactions(AsyncTransactionsBase): RUN_ON_SERVERLESS = True @async_client_context.require_transactions @@ -92,8 +102,7 @@ def test_transaction_options_validation(self): @async_client_context.require_transactions async def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" - client = await async_rs_client(w=0) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(w=0) db = client.test coll = db.test await coll.insert_one({}) @@ -150,12 +159,13 @@ async def test_transaction_write_concern_override(self): async def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000) + client = await self.async_rs_client( + async_client_context.mongos_seeds(), localThresholdMS=1000 + ) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -178,12 +188,13 @@ async def test_unpin_for_next_transaction(self): async def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = await async_rs_client(async_client_context.mongos_seeds(), localThresholdMS=1000) + client = await self.async_rs_client( + async_client_context.mongos_seeds(), localThresholdMS=1000 + ) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. await coll.insert_one({}) - self.addAsyncCleanup(client.close) async with client.start_session() as s: # Session is pinned to Mongos. async with await s.start_transaction(): @@ -307,11 +318,10 @@ async def test_transaction_starts_with_batched_write(self): # Start a transaction with a batch of operations that needs to be # split. 
listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test await coll.delete_many({}) listener.reset() - self.addAsyncCleanup(client.close) self.addAsyncCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -336,8 +346,7 @@ async def test_transaction_starts_with_batched_write(self): @async_client_context.require_transactions async def test_transaction_direct_connection(self): - client = await async_single_client() - self.addAsyncCleanup(client.close) + client = await self.async_single_client() coll = client.pymongo_test.test # Make sure the collection exists. @@ -393,14 +402,16 @@ def __exit__(self, exc_type, exc_val, exc_tb): client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout -class TestTransactionsConvenientAPI(AsyncIntegrationTest): +class TestTransactionsConvenientAPI(AsyncTransactionsBase): @classmethod async def _setup_class(cls): await super()._setup_class() cls.mongos_clients = [] if async_client_context.supports_transactions(): for address in async_client_context.mongoses: - cls.mongos_clients.append(await async_single_client("{}:{}".format(*address))) + cls.mongos_clients.append( + await cls.unmanaged_async_single_client("{}:{}".format(*address)) + ) @classmethod async def _tearDown_class(cls): @@ -450,8 +461,7 @@ async def callback2(session): @async_client_context.require_transactions async def test_callback_not_retried_after_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): @@ -479,8 +489,7 @@ async def callback(session): @async_client_context.require_transactions async def test_callback_not_retried_after_commit_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): @@ -514,8 +523,7 @@ async def callback(session): @async_client_context.require_transactions async def test_commit_not_retried_after_timeout(self): listener = OvertCommandListener() - client = await async_rs_client(event_listeners=[listener]) - self.addAsyncCleanup(client.close) + client = await self.async_rs_client(event_listeners=[listener]) coll = client[self.db.name].test async def callback(session): diff --git a/test/asynchronous/utils_spec_runner.py b/test/asynchronous/utils_spec_runner.py index 71044d1530..12cb13c2cd 100644 --- a/test/asynchronous/utils_spec_runner.py +++ b/test/asynchronous/utils_spec_runner.py @@ -25,7 +25,6 @@ EventListener, OvertCommandListener, ServerAndTopologyEventListener, - async_rs_client, camel_to_snake, camel_to_snake_args, parse_spec_options, @@ -101,6 +100,8 @@ async def _setup_class(cls): @classmethod async def _tearDown_class(cls): cls.knobs.disable() + for client in cls.mongos_clients: + await client.close() await super()._tearDown_class() def setUp(self): @@ -527,7 +528,7 @@ async def run_scenario(self, scenario_def, test): host = async_client_context.MULTI_MONGOS_LB_URI elif async_client_context.is_mongos: host = async_client_context.mongos_seeds() - client = await async_rs_client( + client = await self.async_rs_client( h=host, 
event_listeners=[listener, pool_listener, server_listener], **client_options ) self.scenario_client = client diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index 10416ae5fe..a7660f2f67 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -18,6 +18,7 @@ import os import sys import unittest +from test import PyMongoTestCase from unittest.mock import patch import pytest @@ -36,7 +37,7 @@ pytestmark = pytest.mark.auth_aws -class TestAuthAWS(unittest.TestCase): +class TestAuthAWS(PyMongoTestCase): uri: str @classmethod @@ -69,7 +70,7 @@ def setup_cache(self): self.skipTest("Not testing cached credentials") # Make a connection to ensure that we enable caching. - client = MongoClient(self.uri) + client = self.simple_client(self.uri) client.get_database().test.find_one() client.close() @@ -79,7 +80,7 @@ def setup_cache(self): auth.set_cached_credentials(None) self.assertEqual(auth.get_cached_credentials(), None) - client = MongoClient(self.uri) + client = self.simple_client(self.uri) client.get_database().test.find_one() client.close() return auth.get_cached_credentials() @@ -90,8 +91,7 @@ def test_cache_credentials(self): def test_cache_about_to_expire(self): creds = self.setup_cache() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) # Make the creds about to expire. creds = auth.get_cached_credentials() @@ -107,8 +107,7 @@ def test_cache_about_to_expire(self): def test_poisoned_cache(self): creds = self.setup_cache() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) # Poison the creds with invalid password. assert creds is not None @@ -130,8 +129,7 @@ def test_environment_variables_ignored(self): self.assertIsNotNone(creds) os.environ.copy() - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) client.get_database().test.find_one() @@ -149,8 +147,7 @@ def test_environment_variables_ignored(self): auth.set_cached_credentials(None) - client2 = MongoClient(self.uri) - self.addCleanup(client2.close) + client2 = self.simple_client(self.uri) with patch.dict("os.environ", mock_env): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") @@ -166,8 +163,7 @@ def test_no_cache_environment_variables(self): if creds.token: mock_env["AWS_SESSION_TOKEN"] = creds.token - client = MongoClient(self.uri) - self.addCleanup(client.close) + client = self.simple_client(self.uri) with patch.dict(os.environ, mock_env): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], creds.username) @@ -177,22 +173,19 @@ def test_no_cache_environment_variables(self): mock_env["AWS_ACCESS_KEY_ID"] = "foo" - client2 = MongoClient(self.uri) - self.addCleanup(client2.close) + client2 = self.simple_client(self.uri) with patch.dict("os.environ", mock_env), self.assertRaises(OperationFailure): self.assertEqual(os.environ["AWS_ACCESS_KEY_ID"], "foo") client2.get_database().test.find_one() -class TestAWSLambdaExamples(unittest.TestCase): +class TestAWSLambdaExamples(PyMongoTestCase): def test_shared_client(self): # Start AWS Lambda Example 1 import os - from pymongo import MongoClient - - client = MongoClient(host=os.environ["MONGODB_URI"]) + client = self.simple_client(host=os.environ["MONGODB_URI"]) def lambda_handler(event, context): return client.db.command("ping") @@ -203,9 +196,7 @@ def test_IAM_auth(self): # Start AWS Lambda Example 2 import os - from pymongo import MongoClient - - client = MongoClient( + client = 
self.simple_client( host=os.environ["MONGODB_URI"], authSource="$external", authMechanism="MONGODB-AWS", diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index fa4b7d6697..6d31f3db4e 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -23,6 +23,7 @@ import warnings from contextlib import contextmanager from pathlib import Path +from test import PyMongoTestCase from typing import Dict import pytest @@ -56,7 +57,7 @@ pytestmark = pytest.mark.auth_oidc -class OIDCTestBase(unittest.TestCase): +class OIDCTestBase(PyMongoTestCase): @classmethod def setUpClass(cls): cls.uri_single = os.environ["MONGODB_URI_SINGLE"] @@ -94,6 +95,7 @@ def fail_point(self, command_args): yield finally: client.admin.command("configureFailPoint", cmd_on["configureFailPoint"], mode="off") + client.close() @pytest.mark.auth_oidc @@ -149,7 +151,9 @@ def create_client(self, *args, **kwargs): if not len(args): args = [self.uri_single] - return MongoClient(*args, authmechanismproperties=props, **kwargs) + client = self.simple_client(*args, authmechanismproperties=props, **kwargs) + + return client def test_1_1_single_principal_implicit_username(self): # Create default OIDC client with authMechanism=MONGODB-OIDC. diff --git a/test/mockupdb/test_cursor.py b/test/mockupdb/test_cursor.py index 46af39c7b9..2300297218 100644 --- a/test/mockupdb/test_cursor.py +++ b/test/mockupdb/test_cursor.py @@ -29,13 +29,12 @@ from bson.objectid import ObjectId -from pymongo import MongoClient from pymongo.errors import OperationFailure pytestmark = pytest.mark.mockupdb -class TestCursor(unittest.TestCase): +class TestCursor(PyMongoTestCase): def test_getmore_load_balanced(self): server = MockupDB() server.autoresponds( @@ -50,7 +49,7 @@ def test_getmore_load_balanced(self): server.run() self.addCleanup(server.stop) - client = MongoClient(server.uri, loadBalanced=True) + client = self.simple_client(server.uri, loadBalanced=True) self.addCleanup(client.close) collection = client.db.coll cursor = collection.find() @@ -77,7 +76,7 @@ def _test_fail_on_operation_failure_with_code(self, code): self.addCleanup(server.stop) server.autoresponds("ismaster", maxWireVersion=6) - client = MongoClient(server.uri) + client = self.simple_client(server.uri) with going(lambda: server.receives(OpMsg({"find": "collection"})).command_err(code=code)): cursor = client.db.collection.find() diff --git a/test/ocsp/test_ocsp.py b/test/ocsp/test_ocsp.py index fe7f21160e..a42b3a34ee 100644 --- a/test/ocsp/test_ocsp.py +++ b/test/ocsp/test_ocsp.py @@ -48,8 +48,11 @@ def _connect(options): uri = f"mongodb://localhost:27017/?serverSelectionTimeoutMS={TIMEOUT_MS}&tlsCAFile={CA_FILE}&{options}" print(uri) - client = pymongo.MongoClient(uri) - client.admin.command("ping") + try: + client = pymongo.MongoClient(uri) + client.admin.command("ping") + finally: + client.close() class TestOCSP(unittest.TestCase): diff --git a/test/test_auth.py b/test/test_auth.py index fa3d0905bb..b311d330bc 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -23,16 +23,14 @@ sys.path[0:0] = [""] -from test import IntegrationTest, SkipTest, client_context, unittest -from test.utils import ( - AllowListEventListener, - delay, - ignore_deprecations, - rs_or_single_client, - rs_or_single_client_noauth, - single_client, - single_client_noauth, +from test import ( + IntegrationTest, + PyMongoTestCase, + SkipTest, + client_context, + unittest, ) +from test.utils import AllowListEventListener, delay, ignore_deprecations from pymongo 
import MongoClient, monitoring from pymongo.auth_shared import _build_credentials_tuple @@ -81,7 +79,7 @@ def run(self): self.success = True -class TestGSSAPI(unittest.TestCase): +class TestGSSAPI(PyMongoTestCase): mech_properties: str service_realm_required: bool @@ -138,7 +136,7 @@ def test_gssapi_simple(self): if not self.service_realm_required: # Without authMechanismProperties. - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -149,11 +147,11 @@ def test_gssapi_simple(self): client[GSSAPI_DB].collection.find_one() # Log in using URI, without authMechanismProperties. - client = MongoClient(uri) + client = self.simple_client(uri) client[GSSAPI_DB].collection.find_one() # Authenticate with authMechanismProperties. - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -166,14 +164,14 @@ def test_gssapi_simple(self): # Log in using URI, with authMechanismProperties. mech_uri = uri + f"&authMechanismProperties={self.mech_properties}" - client = MongoClient(mech_uri) + client = self.simple_client(mech_uri) client[GSSAPI_DB].collection.find_one() set_name = client_context.replica_set_name if set_name: if not self.service_realm_required: # Without authMechanismProperties - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -185,11 +183,11 @@ def test_gssapi_simple(self): client[GSSAPI_DB].list_collection_names() uri = uri + f"&replicaSet={set_name!s}" - client = MongoClient(uri) + client = self.simple_client(uri) client[GSSAPI_DB].list_collection_names() # With authMechanismProperties - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -202,13 +200,13 @@ def test_gssapi_simple(self): client[GSSAPI_DB].list_collection_names() mech_uri = mech_uri + f"&replicaSet={set_name!s}" - client = MongoClient(mech_uri) + client = self.simple_client(mech_uri) client[GSSAPI_DB].list_collection_names() @ignore_deprecations @client_context.require_sync def test_gssapi_threaded(self): - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -244,7 +242,7 @@ def test_gssapi_threaded(self): set_name = client_context.replica_set_name if set_name: - client = MongoClient( + client = self.simple_client( GSSAPI_HOST, GSSAPI_PORT, username=GSSAPI_PRINCIPAL, @@ -267,14 +265,14 @@ def test_gssapi_threaded(self): self.assertTrue(thread.success) -class TestSASLPlain(unittest.TestCase): +class TestSASLPlain(PyMongoTestCase): @classmethod def setUpClass(cls): if not SASL_HOST or not SASL_USER or not SASL_PASS: raise SkipTest("Must set SASL_HOST, SASL_USER, and SASL_PASS to test SASL") def test_sasl_plain(self): - client = MongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, username=SASL_USER, @@ -293,12 +291,12 @@ def test_sasl_plain(self): SASL_PORT, SASL_DB, ) - client = MongoClient(uri) + client = self.simple_client(uri) client.ldap.test.find_one() set_name = client_context.replica_set_name if set_name: - client = MongoClient( + client = self.simple_client( SASL_HOST, SASL_PORT, replicaSet=set_name, @@ -317,7 +315,7 @@ def test_sasl_plain(self): SASL_DB, str(set_name), ) - client = MongoClient(uri) + client = self.simple_client(uri) client.ldap.test.find_one() def test_sasl_plain_bad_credentials(self): @@ -331,8 +329,8 @@ def auth_string(user, password): ) return uri - bad_user = MongoClient(auth_string("not-user", 
SASL_PASS)) - bad_pwd = MongoClient(auth_string(SASL_USER, "not-pwd")) + bad_user = self.simple_client(auth_string("not-user", SASL_PASS)) + bad_pwd = self.simple_client(auth_string(SASL_USER, "not-pwd")) # OperationFailure raised upon connecting. with self.assertRaises(OperationFailure): bad_user.admin.command("ping") @@ -354,7 +352,7 @@ def tearDown(self): def test_scram_sha1(self): host, port = client_context.host, client_context.port - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" % (host, port) ) client.pymongo_test.command("dbstats") @@ -365,7 +363,7 @@ def test_scram_sha1(self): "@%s:%d/pymongo_test?authMechanism=SCRAM-SHA-1" "&replicaSet=%s" % (host, port, client_context.replica_set_name) ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) client.pymongo_test.command("dbstats") db = client.get_database("pymongo_test", read_preference=ReadPreference.SECONDARY) db.command("dbstats") @@ -393,7 +391,7 @@ def test_scram_skip_empty_exchange(self): "testscram", "sha256", "pwd", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", event_listeners=[listener] ) client.testscram.command("dbstats") @@ -430,36 +428,38 @@ def test_scram(self): ) # Step 2: verify auth success cases - client = rs_or_single_client_noauth(username="sha1", password="pwd", authSource="testscram") + client = self.rs_or_single_client_noauth( + username="sha1", password="pwd", authSource="testscram" + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") # Step 2: SCRAM-SHA-1 and SCRAM-SHA-256 - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") self.listener.reset() - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="both", password="pwd", authSource="testscram", event_listeners=[self.listener] ) client.testscram.command("dbstats") @@ -472,19 +472,19 @@ def test_scram(self): self.assertEqual(started.command.get("mechanism"), "SCRAM-SHA-256") # Step 3: verify auth failure conditions - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha1", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-256" ) with self.assertRaises(OperationFailure): client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="sha256", password="pwd", authSource="testscram", authMechanism="SCRAM-SHA-1" 
) with self.assertRaises(OperationFailure): client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="not-a-user", password="pwd", authSource="testscram" ) with self.assertRaises(OperationFailure): @@ -497,7 +497,7 @@ def test_scram(self): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) client.testscram.command("dbstats") db = client.get_database("testscram", read_preference=ReadPreference.SECONDARY) db.command("dbstats") @@ -517,12 +517,12 @@ def test_scram_saslprep(self): "testscram", "IX", "IX", roles=["dbOwner"], mechanisms=["SCRAM-SHA-256"] ) - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="\u2163", authSource="testscram", @@ -530,17 +530,17 @@ def test_scram_saslprep(self): ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="\u2168", password="IV", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="I\u00ADX", authSource="testscram", @@ -548,25 +548,29 @@ def test_scram_saslprep(self): ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( username="IX", password="IX", authSource="testscram", authMechanism="SCRAM-SHA-256" ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth( + client = self.rs_or_single_client_noauth( "mongodb://\u2168:\u2163@%s:%d/testscram" % (host, port) ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://\u2168:IV@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth( + "mongodb://\u2168:IV@%s:%d/testscram" % (host, port) + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth( + "mongodb://IX:I\u00ADX@%s:%d/testscram" % (host, port) + ) client.testscram.command("dbstats") - client = rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port)) + client = self.rs_or_single_client_noauth("mongodb://IX:IX@%s:%d/testscram" % (host, port)) client.testscram.command("dbstats") def test_cache(self): - client = single_client() + client = self.single_client() credentials = client.options.pool_options._credentials cache = credentials.cache self.assertIsNotNone(cache) @@ -591,8 +595,7 @@ def test_scram_threaded(self): coll.insert_one({"_id": 1}) # The first thread to call find() will authenticate - client = rs_or_single_client() - self.addCleanup(client.close) + client = self.rs_or_single_client() coll = client.db.test threads = [] for _ in range(4): @@ -619,7 +622,7 @@ def tearDown(self): def test_uri_options(self): # Test default to admin host, port = client_context.host, client_context.port - client = rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port)) + client = self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % 
(host, port)) self.assertTrue(client.admin.command("dbstats")) if client_context.is_rs: @@ -628,14 +631,14 @@ def test_uri_options(self): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) self.assertTrue(client.admin.command("dbstats")) db = client.get_database("admin", read_preference=ReadPreference.SECONDARY) self.assertTrue(db.command("dbstats")) # Test explicit database uri = "mongodb://user:pass@%s:%d/pymongo_test" % (host, port) - client = rs_or_single_client_noauth(uri) + client = self.rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): client.admin.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -646,7 +649,7 @@ def test_uri_options(self): port, client_context.replica_set_name, ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) with self.assertRaises(OperationFailure): client.admin.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -655,7 +658,7 @@ def test_uri_options(self): # Test authSource uri = "mongodb://user:pass@%s:%d/pymongo_test2?authSource=pymongo_test" % (host, port) - client = rs_or_single_client_noauth(uri) + client = self.rs_or_single_client_noauth(uri) with self.assertRaises(OperationFailure): client.pymongo_test2.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) @@ -665,7 +668,7 @@ def test_uri_options(self): "mongodb://user:pass@%s:%d/pymongo_test2?replicaSet=" "%s;authSource=pymongo_test" % (host, port, client_context.replica_set_name) ) - client = single_client_noauth(uri) + client = self.single_client_noauth(uri) with self.assertRaises(OperationFailure): client.pymongo_test2.command("dbstats") self.assertTrue(client.pymongo_test.command("dbstats")) diff --git a/test/test_auth_spec.py b/test/test_auth_spec.py index 38e5f19bf8..3c3a1a67ae 100644 --- a/test/test_auth_spec.py +++ b/test/test_auth_spec.py @@ -20,6 +20,7 @@ import os import sys import warnings +from test import PyMongoTestCase sys.path[0:0] = [""] @@ -34,7 +35,7 @@ _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "auth") -class TestAuthSpec(unittest.TestCase): +class TestAuthSpec(PyMongoTestCase): pass @@ -54,7 +55,7 @@ def run_test(self): warnings.simplefilter("default") self.assertRaises(Exception, MongoClient, uri, connect=False) else: - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) credentials = client.options.pool_options._credentials if credential is None: self.assertIsNone(credentials) diff --git a/test/test_bulk.py b/test/test_bulk.py index 63b8c7790a..64fd48e8cd 100644 --- a/test/test_bulk.py +++ b/test/test_bulk.py @@ -24,22 +24,13 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, remove_all_users, unittest -from test.utils import ( - rs_or_single_client_noauth, - single_client, - wait_until, -) +from test.utils import wait_until from bson.binary import Binary, UuidRepresentation from bson.codec_options import CodecOptions from bson.objectid import ObjectId from pymongo.common import partition_node -from pymongo.errors import ( - BulkWriteError, - ConfigurationError, - InvalidOperation, - OperationFailure, -) +from pymongo.errors import BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure from pymongo.operations import * from pymongo.synchronous.collection import Collection from pymongo.write_concern import WriteConcern @@ -913,7 +904,7 @@ class 
TestBulkAuthorization(BulkAuthorizationTestBase): def test_readonly(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth( + cli = self.rs_or_single_client_noauth( username="readonly", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -924,7 +915,7 @@ def test_readonly(self): def test_no_remove(self): # We test that an authorization failure aborts the batch and is raised # as OperationFailure. - cli = rs_or_single_client_noauth( + cli = self.rs_or_single_client_noauth( username="noremove", password="pw", authSource="pymongo_test" ) coll = cli.pymongo_test.test @@ -952,7 +943,7 @@ def _setup_class(cls): if cls.w is not None and cls.w > 1: for member in (client_context.hello)["hosts"]: if member != (client_context.hello)["primary"]: - cls.secondary = single_client(*partition_node(member)) + cls.secondary = cls.unmanaged_single_client(*partition_node(member)) break @classmethod diff --git a/test/test_change_stream.py b/test/test_change_stream.py index cb19452aec..dae224c5e0 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -28,12 +28,17 @@ sys.path[0:0] = [""] -from test import IntegrationTest, Version, client_context, unittest +from test import ( + IntegrationTest, + PyMongoTestCase, + Version, + client_context, + unittest, +) from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, EventListener, - rs_or_single_client, wait_until, ) @@ -69,8 +74,7 @@ def change_stream(self, *args, **kwargs): def client_with_listener(self, *commands): """Return a client with a AllowListEventListener.""" listener = AllowListEventListener(*commands) - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) return client, listener def watched_collection(self, *args, **kwargs): @@ -174,7 +178,7 @@ def _wait_until(): @no_type_check def test_try_next_runs_one_getmore(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. client.admin.command("ping") listener.reset() @@ -232,7 +236,7 @@ def _wait_until(): @no_type_check def test_batch_size_is_honored(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Connect to the cluster. 
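For readers following the listener plumbing in these change stream tests, here is a minimal sketch of a command listener wired in through the event_listeners option. The pymongo.monitoring.CommandListener interface is the real driver API; the AllowListCommandLogger name and its filtering logic are illustrative stand-ins for the test utility used above, not part of this patch.

from pymongo import MongoClient
from pymongo import monitoring


class AllowListCommandLogger(monitoring.CommandListener):
    """Record started events only for an allow-listed set of command names."""

    def __init__(self, *command_names):
        self.command_names = set(command_names)
        self.started_events = []

    def started(self, event):
        if event.command_name in self.command_names:
            self.started_events.append(event)

    def succeeded(self, event):
        pass

    def failed(self, event):
        pass


listener = AllowListCommandLogger("aggregate", "getMore")
client = MongoClient(event_listeners=[listener], connect=False)
client.close()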
client.admin.command("ping") listener.reset() @@ -473,7 +477,7 @@ class ProseSpecTestsMixin: @no_type_check def _client_with_listener(self, *commands): listener = AllowListEventListener(*commands) - client = rs_or_single_client(event_listeners=[listener]) + client = PyMongoTestCase.unmanaged_rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) return client, listener @@ -1111,7 +1115,7 @@ class TestAllLegacyScenarios(IntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = AllowListEventListener("aggregate", "getMore") - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) @classmethod def _tearDown_class(cls): diff --git a/test/test_client.py b/test/test_client.py index 785139d6a8..bc45325f0b 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -67,10 +67,6 @@ is_greenthread_patched, lazy_client_trial, one, - rs_client, - rs_or_single_client, - rs_or_single_client_noauth, - single_client, wait_until, ) @@ -131,7 +127,7 @@ class ClientUnitTest(UnitTest): @classmethod def _setup_class(cls): - cls.client = rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) + cls.client = cls.unmanaged_rs_or_single_client(connect=False, serverSelectionTimeoutMS=100) @classmethod def _tearDown_class(cls): @@ -142,7 +138,7 @@ def inject_fixtures(self, caplog): self._caplog = caplog def test_keyword_arg_defaults(self): - client = MongoClient( + client = self.simple_client( socketTimeoutMS=None, connectTimeoutMS=20000, waitQueueTimeoutMS=None, @@ -168,15 +164,17 @@ def test_keyword_arg_defaults(self): self.assertAlmostEqual(12, client.options.server_selection_timeout) def test_connect_timeout(self): - client = MongoClient(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) + client = self.simple_client(connect=False, connectTimeoutMS=None, socketTimeoutMS=None) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = MongoClient(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) + + client = self.simple_client(connect=False, connectTimeoutMS=0, socketTimeoutMS=0) pool_opts = client.options.pool_options self.assertEqual(None, pool_opts.socket_timeout) self.assertEqual(None, pool_opts.connect_timeout) - client = MongoClient( + + client = self.simple_client( "mongodb://localhost/?connectTimeoutMS=0&socketTimeoutMS=0", connect=False ) pool_opts = client.options.pool_options @@ -193,7 +191,7 @@ def test_types(self): self.assertRaises(ConfigurationError, MongoClient, []) def test_max_pool_size_zero(self): - MongoClient(maxPoolSize=0) + self.simple_client(maxPoolSize=0) def test_uri_detection(self): self.assertRaises(ConfigurationError, MongoClient, "/foo") @@ -258,7 +256,7 @@ def test_iteration(self): self.assertNotIsInstance(client, Iterable) def test_get_default_database(self): - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, ) @@ -274,7 +272,7 @@ def test_get_default_database(self): self.assertEqual(ReadPreference.SECONDARY, db.read_preference) self.assertEqual(write_concern, db.write_concern) - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) @@ -282,7 +280,7 @@ def test_get_default_database(self): def test_get_default_database_error(self): # URI with no database. 
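As a quick reminder of the default-database rules these assertions exercise, the sketch below uses an assumed localhost URI; get_default_database() and ConfigurationError are the real driver APIs under test.

from pymongo import MongoClient
from pymongo.errors import ConfigurationError

c1 = MongoClient("mongodb://localhost:27017/foo", connect=False)
assert c1.get_default_database().name == "foo"   # database name taken from the URI path

c2 = MongoClient("mongodb://localhost:27017/", connect=False)
assert c2.get_default_database("fallback").name == "fallback"  # explicit fallback wins
try:
    c2.get_default_database()                    # no URI database and no fallback given
except ConfigurationError:
    pass
c1.close()
c2.close()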
- c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) @@ -294,11 +292,11 @@ def test_get_default_database_with_authsource(self): client_context.host, client_context.port, ) - c = rs_or_single_client(uri, connect=False) + c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_default_database()) def test_get_database_default(self): - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/foo" % (client_context.host, client_context.port), connect=False, ) @@ -306,7 +304,7 @@ def test_get_database_default(self): def test_get_database_default_error(self): # URI with no database. - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://%s:%d/" % (client_context.host, client_context.port), connect=False, ) @@ -318,19 +316,19 @@ def test_get_database_default_with_authsource(self): client_context.host, client_context.port, ) - c = rs_or_single_client(uri, connect=False) + c = self.rs_or_single_client(uri, connect=False) self.assertEqual(Database(c, "foo"), c.get_database()) def test_primary_read_pref_with_tags(self): # No tags allowed with "primary". with self.assertRaises(ConfigurationError): - MongoClient("mongodb://host/?readpreferencetags=dc:east") + self.single_client("mongodb://host/?readpreferencetags=dc:east") with self.assertRaises(ConfigurationError): - MongoClient("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") + self.single_client("mongodb://host/?readpreference=primary&readpreferencetags=dc:east") def test_read_preference(self): - c = rs_or_single_client( + c = self.rs_or_single_client( "mongodb://host", connect=False, readpreference=ReadPreference.NEAREST.mongos_mode ) self.assertEqual(c.read_preference, ReadPreference.NEAREST) @@ -339,26 +337,30 @@ def test_metadata(self): metadata = copy.deepcopy(_METADATA) metadata["driver"]["name"] = "PyMongo" metadata["application"] = {"name": "foobar"} - client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata, metadata) - client = MongoClient("foo", 27017, appname="foobar", connect=False) + client = self.simple_client("foo", 27017, appname="foobar", connect=False) options = client.options self.assertEqual(options.pool_options.metadata, metadata) # No error - MongoClient(appname="x" * 128) - self.assertRaises(ValueError, MongoClient, appname="x" * 129) + self.simple_client(appname="x" * 128) + with self.assertRaises(ValueError): + self.simple_client(appname="x" * 129) # Bad "driver" options. self.assertRaises(TypeError, DriverInfo, "Foo", 1, "a") self.assertRaises(TypeError, DriverInfo, version="1", platform="a") self.assertRaises(TypeError, DriverInfo) - self.assertRaises(TypeError, MongoClient, driver=1) - self.assertRaises(TypeError, MongoClient, driver="abc") - self.assertRaises(TypeError, MongoClient, driver=("Foo", "1", "a")) + with self.assertRaises(TypeError): + self.simple_client(driver=1) + with self.assertRaises(TypeError): + self.simple_client(driver="abc") + with self.assertRaises(TypeError): + self.simple_client(driver=("Foo", "1", "a")) # Test appending to driver info. 
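The next hunks assert how extra driver metadata is appended to the handshake document; a rough sketch of that behavior follows, reusing the same illustrative FooDriver values the test does. Note that pool_options.metadata is an internal attribute, so treat these assertions as descriptive rather than supported API.

from pymongo import MongoClient
from pymongo.driver_info import DriverInfo

client = MongoClient(
    "localhost",
    27017,
    appname="foobar",
    driver=DriverInfo(name="FooDriver", version="1.2.3", platform="FooPlatform"),
    connect=False,
)
meta = client.options.pool_options.metadata
assert meta["driver"]["name"].endswith("|FooDriver")    # appended after "PyMongo", not replaced
assert meta["driver"]["version"].endswith("|1.2.3")
assert meta["platform"].endswith("|FooPlatform")
client.close()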
metadata["driver"]["name"] = "PyMongo|FooDriver" metadata["driver"]["version"] = "{}|1.2.3".format(_METADATA["driver"]["version"]) - client = MongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -368,7 +370,7 @@ def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) metadata["platform"] = "{}|FooPlatform".format(_METADATA["platform"]) - client = MongoClient( + client = self.simple_client( "foo", 27017, appname="foobar", @@ -378,7 +380,7 @@ def test_metadata(self): options = client.options self.assertEqual(options.pool_options.metadata, metadata) # Test truncating driver info metadata. - client = MongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -387,7 +389,7 @@ def test_metadata(self): len(bson.encode(options.pool_options.metadata)), _MAX_METADATA_SIZE, ) - client = MongoClient( + client = self.simple_client( driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), connect=False, ) @@ -403,7 +405,7 @@ def test_container_metadata(self): metadata["driver"]["name"] = "PyMongo" metadata["env"] = {} metadata["env"]["container"] = {"orchestrator": "kubernetes"} - client = MongoClient("mongodb://foo:27017/?appname=foobar&connect=false") + client = self.simple_client("mongodb://foo:27017/?appname=foobar&connect=false") options = client.options self.assertEqual(options.pool_options.metadata["env"], metadata["env"]) @@ -429,7 +431,7 @@ def transform_python(self, value): uuid_representation_label = "javaLegacy" unicode_decode_error_handler = "ignore" tzinfo = utc - c = MongoClient( + c = self.simple_client( document_class=document_class, type_registry=type_registry, tz_aware=tz_aware, @@ -438,12 +440,12 @@ def transform_python(self, value): tzinfo=tzinfo, connect=False, ) - self.assertEqual(c.codec_options.document_class, document_class) self.assertEqual(c.codec_options.type_registry, type_registry) self.assertEqual(c.codec_options.tz_aware, tz_aware) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual(c.codec_options.tzinfo, tzinfo) @@ -465,11 +467,11 @@ def test_uri_codec_options(self): datetime_conversion, ) ) - c = MongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual(c.codec_options.tz_aware, True) self.assertEqual( - c.codec_options.uuid_representation, _UUID_REPRESENTATIONS[uuid_representation_label] + c.codec_options.uuid_representation, + _UUID_REPRESENTATIONS[uuid_representation_label], ) self.assertEqual(c.codec_options.unicode_decode_error_handler, unicode_decode_error_handler) self.assertEqual( @@ -478,8 +480,7 @@ def test_uri_codec_options(self): # Change the passed datetime_conversion to a number and re-assert. uri = uri.replace(datetime_conversion, f"{int(DatetimeConversion[datetime_conversion])}") - c = MongoClient(uri, connect=False) - + c = self.simple_client(uri, connect=False) self.assertEqual( c.codec_options.datetime_conversion, DatetimeConversion[datetime_conversion] ) @@ -487,7 +488,9 @@ def test_uri_codec_options(self): def test_uri_option_precedence(self): # Ensure kwarg options override connection string options. 
uri = "mongodb://localhost/?ssl=true&replicaSet=name&readPreference=primary" - c = MongoClient(uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred") + c = self.simple_client( + uri, ssl=False, replicaSet="newname", readPreference="secondaryPreferred" + ) clopts = c.options opts = clopts._options @@ -516,7 +519,7 @@ def reset_resolver(): def test_scenario(args, kwargs, expected_value): patched_resolver.reset() - MongoClient(*args, **kwargs) + self.simple_client(*args, **kwargs) for _, kw in patched_resolver.call_list(): self.assertAlmostEqual(kw["lifetime"], expected_value) @@ -536,15 +539,15 @@ def test_scenario(args, kwargs, expected_value): def test_uri_security_options(self): # Ensure that we don't silently override security-related options. with self.assertRaises(InvalidURI): - MongoClient("mongodb://localhost/?ssl=true", tls=False, connect=False) + self.simple_client("mongodb://localhost/?ssl=true", tls=False, connect=False) # Matching SSL and TLS options should not cause errors. - c = MongoClient("mongodb://localhost/?ssl=false", tls=False, connect=False) + c = self.simple_client("mongodb://localhost/?ssl=false", tls=False, connect=False) self.assertEqual(c.options._options["tls"], False) # Conflicting tlsInsecure options should raise an error. with self.assertRaises(InvalidURI): - MongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidHostnames=True, @@ -552,7 +555,7 @@ def test_uri_security_options(self): # Conflicting legacy tlsInsecure options should also raise an error. with self.assertRaises(InvalidURI): - MongoClient( + self.simple_client( "mongodb://localhost/?tlsInsecure=true", connect=False, tlsAllowInvalidCertificates=False, @@ -560,10 +563,10 @@ def test_uri_security_options(self): # Conflicting kwargs should raise InvalidURI with self.assertRaises(InvalidURI): - MongoClient(ssl=True, tls=False) + self.simple_client(ssl=True, tls=False) def test_event_listeners(self): - c = MongoClient(event_listeners=[], connect=False) + c = self.simple_client(event_listeners=[], connect=False) self.assertEqual(c.options.event_listeners, []) listeners = [ event_loggers.CommandLogger(), @@ -572,11 +575,11 @@ def test_event_listeners(self): event_loggers.TopologyLogger(), event_loggers.ConnectionPoolLogger(), ] - c = MongoClient(event_listeners=listeners, connect=False) + c = self.simple_client(event_listeners=listeners, connect=False) self.assertEqual(c.options.event_listeners, listeners) def test_client_options(self): - c = MongoClient(connect=False) + c = self.simple_client(connect=False) self.assertIsInstance(c.options, ClientOptions) self.assertIsInstance(c.options.pool_options, PoolOptions) self.assertEqual(c.options.server_selection_timeout, 30) @@ -606,11 +609,11 @@ def test_detected_environment_logging(self, mock_get_hosts): ) with self.assertLogs("pymongo", level="INFO") as cm: for host in normal_hosts: - MongoClient(host) + MongoClient(host, connect=False) for host in srv_hosts: mock_get_hosts.return_value = [(host, 1)] - MongoClient(host) - MongoClient(multi_host) + MongoClient(host, connect=False) + MongoClient(multi_host, connect=False) logs = [record.getMessage() for record in cm.records if record.name == "pymongo.client"] self.assertEqual(len(logs), 7) @@ -628,13 +631,13 @@ def test_detected_environment_warning(self, mock_get_hosts): ) for host in normal_hosts: with self.assertWarns(UserWarning): - MongoClient(host) + self.simple_client(host) for host in srv_hosts: mock_get_hosts.return_value = [(host, 
1)] with self.assertWarns(UserWarning): - MongoClient(host) + self.simple_client(host) with self.assertWarns(UserWarning): - MongoClient(multi_host) + self.simple_client(multi_host) class TestClient(IntegrationTest): @@ -651,18 +654,17 @@ def test_multiple_uris(self): def test_max_idle_time_reaper_default(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper doesn't remove connections when maxIdleTimeMS not set - client = rs_or_single_client() + client = self.rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass self.assertEqual(1, len(server._pool.conns)) self.assertTrue(conn in server._pool.conns) - client.close() def test_max_idle_time_reaper_removes_stale_minPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper removes idle socket and replaces it with a new one - client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) + client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -671,12 +673,11 @@ def test_max_idle_time_reaper_removes_stale_minPoolSize(self): self.assertGreaterEqual(len(server._pool.conns), 1) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) >= 1, "replace stale socket") - client.close() def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper respects maxPoolSize when adding new connections. - client = rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) + client = self.rs_or_single_client(maxIdleTimeMS=500, minPoolSize=1, maxPoolSize=1) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -685,12 +686,11 @@ def test_max_idle_time_reaper_does_not_exceed_maxPoolSize(self): self.assertEqual(1, len(server._pool.conns)) wait_until(lambda: conn not in server._pool.conns, "remove stale socket") wait_until(lambda: len(server._pool.conns) == 1, "replace stale socket") - client.close() def test_max_idle_time_reaper_removes_stale(self): with client_knobs(kill_cursor_frequency=0.1): # Assert reaper has removed idle socket and NOT replaced it - client = rs_or_single_client(maxIdleTimeMS=500) + client = self.rs_or_single_client(maxIdleTimeMS=500) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn_one: pass @@ -703,16 +703,15 @@ def test_max_idle_time_reaper_removes_stale(self): lambda: len(server._pool.conns) == 0, "stale socket reaped and new one NOT added to the pool", ) - client.close() def test_min_pool_size(self): with client_knobs(kill_cursor_frequency=0.1): - client = rs_or_single_client() + client = self.rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) self.assertEqual(0, len(server._pool.conns)) # Assert that pool started up at minPoolSize - client = rs_or_single_client(minPoolSize=10) + client = self.rs_or_single_client(minPoolSize=10) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) wait_until( lambda: len(server._pool.conns) == 10, @@ -731,7 +730,7 @@ def test_min_pool_size(self): def test_max_idle_time_checkout(self): # Use high frequency to test _get_socket_no_auth. 
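The reaper tests above all tune the same three pool options; a compact sketch of what they mean on a client (the values here are illustrative, matching the ones the tests use):

from pymongo import MongoClient

client = MongoClient(
    "localhost",
    27017,
    maxIdleTimeMS=500,  # connections idle longer than 0.5s are closed by the background reaper
    minPoolSize=1,      # ...but the pool is topped back up to at least one connection
    maxPoolSize=1,      # and is never grown past this bound while doing so
    connect=False,
)
client.close()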
with client_knobs(kill_cursor_frequency=99999999): - client = rs_or_single_client(maxIdleTimeMS=500) + client = self.rs_or_single_client(maxIdleTimeMS=500) server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -745,7 +744,7 @@ def test_max_idle_time_checkout(self): self.assertTrue(new_con in server._pool.conns) # Test that connections are reused if maxIdleTimeMS is not set. - client = rs_or_single_client() + client = self.rs_or_single_client() server = (client._get_topology()).select_server(readable_server_selector, _Op.TEST) with server._pool.checkout() as conn: pass @@ -769,36 +768,38 @@ def test_constants(self): MongoClient.HOST = "somedomainthatdoesntexist.org" MongoClient.PORT = 123456789 with self.assertRaises(AutoReconnect): - connected(MongoClient(serverSelectionTimeoutMS=10, **kwargs)) + c = self.simple_client(serverSelectionTimeoutMS=10, **kwargs) + connected(c) + c = self.simple_client(host, port, **kwargs) # Override the defaults. No error. - connected(MongoClient(host, port, **kwargs)) + connected(c) # Set good defaults. MongoClient.HOST = host MongoClient.PORT = port # No error. - connected(MongoClient(**kwargs)) + c = self.simple_client(**kwargs) + connected(c) def test_init_disconnected(self): host, port = client_context.host, client_context.port - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) # is_primary causes client to block until connected self.assertIsInstance(c.is_primary, bool) - - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.is_mongos, bool) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.options.pool_options.max_pool_size, int) self.assertIsInstance(c.nodes, frozenset) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertEqual(c.codec_options, CodecOptions()) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertFalse(c.primary) self.assertFalse(c.secondaries) - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertIsInstance(c.topology_description, TopologyDescription) self.assertEqual(c.topology_description, c._topology._description) self.assertIsNone(c.address) # PYTHON-2981 @@ -810,43 +811,44 @@ def test_init_disconnected(self): self.assertEqual(c.address, (host, port)) bad_host = "somedomainthatdoesntexist.org" - c = MongoClient(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(bad_host, port, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_init_disconnected_with_auth(self): uri = "mongodb://user:pass@somedomainthatdoesntexist" - c = MongoClient(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) + c = self.simple_client(uri, connectTimeoutMS=1, serverSelectionTimeoutMS=10) with self.assertRaises(ConnectionFailure): c.pymongo_test.test.find_one() def test_equality(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = rs_or_single_client(seed, connect=False) - self.addCleanup(c.close) + c = self.rs_or_single_client(seed, connect=False) self.assertEqual(client_context.client, c) # Explicitly test inequality self.assertFalse(client_context.client != c) - c = rs_or_single_client("invalid.com", connect=False) - self.addCleanup(c.close) + c = 
self.rs_or_single_client("invalid.com", connect=False) self.assertNotEqual(client_context.client, c) self.assertTrue(client_context.client != c) + + c1 = self.simple_client("a", connect=False) + c2 = self.simple_client("b", connect=False) + # Seeds differ: - self.assertNotEqual(MongoClient("a", connect=False), MongoClient("b", connect=False)) + self.assertNotEqual(c1, c2) + + c1 = self.simple_client(["a", "b", "c"], connect=False) + c2 = self.simple_client(["c", "a", "b"], connect=False) + # Same seeds but out of order still compares equal: - self.assertEqual( - MongoClient(["a", "b", "c"], connect=False), - MongoClient(["c", "a", "b"], connect=False), - ) + self.assertEqual(c1, c2) def test_hashable(self): seed = "{}:{}".format(*list(self.client._topology_settings.seeds)[0]) - c = rs_or_single_client(seed, connect=False) - self.addCleanup(c.close) + c = self.rs_or_single_client(seed, connect=False) self.assertIn(c, {client_context.client}) - c = rs_or_single_client("invalid.com", connect=False) - self.addCleanup(c.close) + c = self.rs_or_single_client("invalid.com", connect=False) self.assertNotIn(c, {client_context.client}) def test_host_w_port(self): @@ -879,9 +881,10 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + with eval(the_repr) as client_two: + self.assertEqual(client_two, client) - client = MongoClient( + client = self.simple_client( "localhost:27017,localhost:27018", replicaSet="replset", connectTimeoutMS=12345, @@ -899,7 +902,8 @@ def test_repr(self): self.assertIn("w=1", the_repr) self.assertIn("wtimeoutms=100", the_repr) - self.assertEqual(eval(the_repr), client) + with eval(the_repr) as client_two: + self.assertEqual(client_two, client) def test_getters(self): wait_until(lambda: client_context.nodes == self.client.nodes, "find all nodes") @@ -915,8 +919,7 @@ def test_list_databases(self): for helper_doc, cmd_doc in zip(helper_docs, cmd_docs): self.assertIs(type(helper_doc), dict) self.assertEqual(helper_doc.keys(), cmd_doc.keys()) - client = rs_or_single_client(document_class=SON) - self.addCleanup(client.close) + client = self.rs_or_single_client(document_class=SON) for doc in client.list_databases(): self.assertIs(type(doc), dict) @@ -955,7 +958,7 @@ def test_drop_database(self): self.client.drop_database("pymongo_test") if client_context.is_rs: - wc_client = rs_or_single_client(w=len(client_context.nodes) + 1) + wc_client = self.rs_or_single_client(w=len(client_context.nodes) + 1) with self.assertRaises(WriteConcernError): wc_client.drop_database("pymongo_test2") @@ -965,7 +968,7 @@ def test_drop_database(self): self.assertNotIn("pymongo_test2", dbs) def test_close(self): - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() coll = test_client.pymongo_test.bar test_client.close() with self.assertRaises(InvalidOperation): @@ -975,7 +978,7 @@ def test_close_kills_cursors(self): if sys.platform.startswith("java"): # We can't figure out how to make this test reliable with Jython. raise SkipTest("Can't test with Jython") - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() # Kill any cursors possibly queued up by previous tests. 
gc.collect() test_client._process_periodic_tasks() @@ -1002,13 +1005,13 @@ def test_close_kills_cursors(self): self.assertTrue(test_client._topology._opened) test_client.close() self.assertFalse(test_client._topology._opened) - test_client = rs_or_single_client() + test_client = self.rs_or_single_client() # The killCursors task should not need to re-open the topology. test_client._process_periodic_tasks() self.assertTrue(test_client._topology._opened) def test_close_stops_kill_cursors_thread(self): - client = rs_client() + client = self.rs_client() client.test.test.find_one() self.assertFalse(client._kill_cursors_executor._stopped) @@ -1024,7 +1027,7 @@ def test_close_stops_kill_cursors_thread(self): def test_uri_connect_option(self): # Ensure that topology is not opened if connect=False. - client = rs_client(connect=False) + client = self.rs_client(connect=False) self.assertFalse(client._topology._opened) # Ensure kill cursors thread has not been started. @@ -1037,19 +1040,15 @@ def test_uri_connect_option(self): kc_thread = client._kill_cursors_executor._thread self.assertTrue(kc_thread and kc_thread.is_alive()) - # Tear down. - client.close() - def test_close_does_not_open_servers(self): - client = rs_client(connect=False) + client = self.rs_client(connect=False) topology = client._topology self.assertEqual(topology._servers, {}) client.close() self.assertEqual(topology._servers, {}) def test_close_closes_sockets(self): - client = rs_client() - self.addCleanup(client.close) + client = self.rs_client() client.test.test.find_one() topology = client._topology client.close() @@ -1075,30 +1074,30 @@ def test_auth_from_uri(self): client_context.create_user("pymongo_test", "user", "pass", roles=["userAdmin", "readWrite"]) with self.assertRaises(OperationFailure): - connected(rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) + connected(self.rs_or_single_client_noauth("mongodb://a:b@%s:%d" % (host, port))) # No error. - connected(rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) + connected(self.rs_or_single_client_noauth("mongodb://admin:pass@%s:%d" % (host, port))) # Wrong database. uri = "mongodb://admin:pass@%s:%d/pymongo_test" % (host, port) with self.assertRaises(OperationFailure): - connected(rs_or_single_client_noauth(uri)) + connected(self.rs_or_single_client_noauth(uri)) # No error. connected( - rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) + self.rs_or_single_client_noauth("mongodb://user:pass@%s:%d/pymongo_test" % (host, port)) ) # Auth with lazy connection. ( - rs_or_single_client_noauth( + self.rs_or_single_client_noauth( "mongodb://user:pass@%s:%d/pymongo_test" % (host, port), connect=False ) ).pymongo_test.test.find_one() # Wrong password. - bad_client = rs_or_single_client_noauth( + bad_client = self.rs_or_single_client_noauth( "mongodb://user:wrong@%s:%d/pymongo_test" % (host, port), connect=False ) @@ -1110,7 +1109,7 @@ def test_username_and_password(self): client_context.create_user("admin", "ad min", "pa/ss") self.addCleanup(client_context.drop_user, "admin", "ad min") - c = rs_or_single_client_noauth(username="ad min", password="pa/ss") + c = self.rs_or_single_client_noauth(username="ad min", password="pa/ss") # Username and password aren't in strings that will likely be logged. 
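The "ad min"/"pa/ss" credentials above are passed as keyword arguments precisely because they contain reserved characters; here is a sketch of the two equivalent spellings, assuming a local server and using urllib's quote_plus for the URI form.

from urllib.parse import quote_plus
from pymongo import MongoClient

user, pwd = "ad min", "pa/ss"
uri = "mongodb://%s:%s@localhost:27017" % (quote_plus(user), quote_plus(pwd))
c1 = MongoClient(uri, connect=False)  # credentials percent-escaped in the URI
c2 = MongoClient("localhost", 27017, username=user, password=pwd, connect=False)  # taken verbatim
c1.close()
c2.close()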
self.assertNotIn("ad min", repr(c)) @@ -1122,13 +1121,13 @@ def test_username_and_password(self): c.server_info() with self.assertRaises(OperationFailure): - (rs_or_single_client_noauth(username="ad min", password="foo")).server_info() + (self.rs_or_single_client_noauth(username="ad min", password="foo")).server_info() @client_context.require_auth @client_context.require_no_fips def test_lazy_auth_raises_operation_failure(self): host = client_context.host - lazy_client = rs_or_single_client_noauth( + lazy_client = self.rs_or_single_client_noauth( f"mongodb://user:wrong@{host}/pymongo_test", connect=False ) @@ -1146,8 +1145,7 @@ def test_unix_socket(self): uri = "mongodb://%s" % encoded_socket # Confirm we can do operations via the socket. - client = rs_or_single_client(uri) - self.addCleanup(client.close) + client = self.rs_or_single_client(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) dbs = client.list_database_names() self.assertTrue("pymongo_test" in dbs) @@ -1156,9 +1154,10 @@ def test_unix_socket(self): # Confirm it fails with a missing socket. with self.assertRaises(ConnectionFailure): - connected( - MongoClient("mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100), + c = self.simple_client( + "mongodb://%2Ftmp%2Fnon-existent.sock", serverSelectionTimeoutMS=100 ) + connected(c) def test_document_class(self): c = self.client @@ -1169,15 +1168,15 @@ def test_document_class(self): self.assertTrue(isinstance(db.test.find_one(), dict)) self.assertFalse(isinstance(db.test.find_one(), SON)) - c = rs_or_single_client(document_class=SON) - self.addCleanup(c.close) + c = self.rs_or_single_client(document_class=SON) + db = c.pymongo_test self.assertEqual(SON, c.codec_options.document_class) self.assertTrue(isinstance(db.test.find_one(), SON)) def test_timeouts(self): - client = rs_or_single_client( + client = self.rs_or_single_client( connectTimeoutMS=10500, socketTimeoutMS=10500, maxIdleTimeMS=10500, @@ -1190,28 +1189,31 @@ def test_timeouts(self): self.assertEqual(10.5, client.options.server_selection_timeout) def test_socket_timeout_ms_validation(self): - c = rs_or_single_client(socketTimeoutMS=10 * 1000) + c = self.rs_or_single_client(socketTimeoutMS=10 * 1000) self.assertEqual(10, (get_pool(c)).opts.socket_timeout) - c = connected(rs_or_single_client(socketTimeoutMS=None)) + c = connected(self.rs_or_single_client(socketTimeoutMS=None)) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) - c = connected(rs_or_single_client(socketTimeoutMS=0)) + c = connected(self.rs_or_single_client(socketTimeoutMS=0)) self.assertEqual(None, (get_pool(c)).opts.socket_timeout) with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS=-1) + with self.rs_or_single_client(socketTimeoutMS=-1): + pass with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS=1e10) + with self.rs_or_single_client(socketTimeoutMS=1e10): + pass with self.assertRaises(ValueError): - rs_or_single_client(socketTimeoutMS="foo") + with self.rs_or_single_client(socketTimeoutMS="foo"): + pass def test_socket_timeout(self): no_timeout = self.client timeout_sec = 1 - timeout = rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) + timeout = self.rs_or_single_client(socketTimeoutMS=1000 * timeout_sec) self.addCleanup(timeout.close) no_timeout.pymongo_test.drop_collection("test") @@ -1256,7 +1258,7 @@ def test_server_selection_timeout(self): self.assertAlmostEqual(30, client.options.server_selection_timeout) def test_waitQueueTimeoutMS(self): - client = 
rs_or_single_client(waitQueueTimeoutMS=2000) + client = self.rs_or_single_client(waitQueueTimeoutMS=2000) self.assertEqual((get_pool(client)).opts.wait_queue_timeout, 2) def test_socketKeepAlive(self): @@ -1269,7 +1271,7 @@ def test_socketKeepAlive(self): def test_tz_aware(self): self.assertRaises(ValueError, MongoClient, tz_aware="foo") - aware = rs_or_single_client(tz_aware=True) + aware = self.rs_or_single_client(tz_aware=True) self.addCleanup(aware.close) naive = self.client aware.pymongo_test.drop_collection("test") @@ -1299,8 +1301,7 @@ def test_ipv6(self): if client_context.is_rs: uri += "/?replicaSet=" + (client_context.replica_set_name or "") - client = rs_or_single_client_noauth(uri) - self.addCleanup(client.close) + client = self.rs_or_single_client_noauth(uri) client.pymongo_test.test.insert_one({"dummy": "object"}) client.pymongo_test_bernie.test.insert_one({"dummy": "object"}) @@ -1309,7 +1310,7 @@ def test_ipv6(self): self.assertTrue("pymongo_test_bernie" in dbs) def test_contextlib(self): - client = rs_or_single_client() + client = self.rs_or_single_client() client.pymongo_test.drop_collection("test") client.pymongo_test.test.insert_one({"foo": "bar"}) @@ -1323,7 +1324,7 @@ def test_contextlib(self): self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): client.pymongo_test.test.find_one() - client = rs_or_single_client() + client = self.rs_or_single_client() with client as client: self.assertEqual("bar", (client.pymongo_test.test.find_one())["foo"]) with self.assertRaises(InvalidOperation): @@ -1401,8 +1402,7 @@ def test_operation_failure(self): # response to getLastError. PYTHON-395. We need a new client here # to avoid race conditions caused by replica set failover or idle # socket reaping. - client = single_client() - self.addCleanup(client.close) + client = self.single_client() client.pymongo_test.test.find_one() pool = get_pool(client) socket_count = len(pool.conns) @@ -1426,8 +1426,7 @@ def test_lazy_connect_w0(self): client_context.client.drop_database("test_lazy_connect_w0") self.addCleanup(client_context.client.drop_database, "test_lazy_connect_w0") - client = rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) + client = self.rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.insert_one({}) def predicate(): @@ -1435,8 +1434,7 @@ def predicate(): wait_until(predicate, "find one document") - client = rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) + client = self.rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.update_one({}, {"$set": {"x": 1}}) def predicate(): @@ -1444,8 +1442,7 @@ def predicate(): wait_until(predicate, "update one document") - client = rs_or_single_client(connect=False, w=0) - self.addCleanup(client.close) + client = self.rs_or_single_client(connect=False, w=0) client.test_lazy_connect_w0.test.delete_one({}) def predicate(): @@ -1457,8 +1454,7 @@ def predicate(): def test_exhaust_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1, retryReads=False) - self.addCleanup(client.close) + client = self.rs_or_single_client(maxPoolSize=1, retryReads=False) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. 
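A short sketch of the context-manager contract that test_contextlib above relies on: leaving the with-block closes the client, and any further operation raises InvalidOperation (assumes a locally running server).

from pymongo import MongoClient
from pymongo.errors import InvalidOperation

with MongoClient("localhost", 27017) as client:
    client.pymongo_test.test.insert_one({"foo": "bar"})

try:
    client.pymongo_test.test.find_one()
except InvalidOperation:
    pass  # the client was closed by __exit__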
@@ -1484,7 +1480,9 @@ def test_auth_network_error(self): # when authenticating a new socket with cached credentials. # Get a client with one socket so we detect if it's leaked. - c = connected(rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False)) + c = connected( + self.rs_or_single_client(maxPoolSize=1, waitQueueTimeoutMS=1, retryReads=False) + ) # Cause a network error on the actual socket. pool = get_pool(c) @@ -1501,8 +1499,7 @@ def test_auth_network_error(self): @client_context.require_no_replica_set def test_connect_to_standalone_using_replica_set_name(self): - client = single_client(replicaSet="anything", serverSelectionTimeoutMS=100) - + client = self.single_client(replicaSet="anything", serverSelectionTimeoutMS=100) with self.assertRaises(AutoReconnect): client.test.test.find_one() @@ -1512,7 +1509,7 @@ def test_stale_getmore(self): # the topology before the getMore message is sent. Test that # MongoClient._run_operation_with_response handles the error. with self.assertRaises(AutoReconnect): - client = rs_client(connect=False, serverSelectionTimeoutMS=100) + client = self.rs_client(connect=False, serverSelectionTimeoutMS=100) client._run_operation( operation=message._GetMore( "pymongo_test", @@ -1560,7 +1557,7 @@ def init(self, *args): client_context.host, client_context.port, ) - client = single_client(uri, event_listeners=[listener]) + self.single_client(uri, event_listeners=[listener]) wait_until( lambda: len(listener.results) >= 2, "record two ServerHeartbeatStartedEvents" ) @@ -1569,7 +1566,6 @@ def init(self, *args): # closer to 0.5 sec with heartbeatFrequencyMS configured. self.assertAlmostEqual(heartbeat_times[1] - heartbeat_times[0], 0.5, delta=2) - client.close() finally: ServerHeartbeatStartedEvent.__init__ = old_init # type: ignore @@ -1586,31 +1582,31 @@ def compression_settings(client): return pool_options._compression_settings uri = "mongodb://localhost:27017/?compressors=zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, 4) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-1" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=foobar,zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) @@ -1618,56 
+1614,55 @@ def compression_settings(client): # According to the connection string spec, unsupported values # just raise a warning and are ignored. uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=10" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) uri = "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=-2" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=snappy" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy"]) uri = "mongodb://localhost:27017/?compressors=snappy,zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy", "zlib"]) if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, []) else: uri = "mongodb://localhost:27017/?compressors=zstd" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd"]) uri = "mongodb://localhost:27017/?compressors=zstd,zlib" - client = MongoClient(uri, connect=False) + client = self.simple_client(uri, connect=False) opts = compression_settings(client) self.assertEqual(opts.compressors, ["zstd", "zlib"]) options = client_context.default_client_options if "compressors" in options and "zlib" in options["compressors"]: for level in range(-1, 10): - client = single_client(zlibcompressionlevel=level) + client = self.single_client(zlibcompressionlevel=level) # No error client.pymongo_test.test.find_one() def test_reset_during_update_pool(self): - client = rs_or_single_client(minPoolSize=10) - self.addCleanup(client.close) + client = self.rs_or_single_client(minPoolSize=10) client.admin.command("ping") pool = get_pool(client) generation = pool.gen.get_overall() @@ -1713,11 +1708,9 @@ def run(self): def test_background_connections_do_not_hold_locks(self): min_pool_size = 10 - client = rs_or_single_client( + client = self.rs_or_single_client( serverSelectionTimeoutMS=3000, minPoolSize=min_pool_size, connect=False ) - self.addCleanup(client.close) - # Create a single connection in the pool. client.admin.command("ping") @@ -1747,21 +1740,19 @@ def stall_connect(*args, **kwargs): @client_context.require_replica_set def test_direct_connection(self): # direct_connection=True should result in Single topology. 
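A condensed sketch of the compressors configuration walked through above: zlib ships with the standard library, snappy and zstd need optional extras, and unrecognized names are ignored with a warning rather than raising. The hosts and level are illustrative.

from pymongo import MongoClient

c1 = MongoClient(
    "mongodb://localhost:27017/?compressors=zlib&zlibCompressionLevel=4", connect=False
)
c2 = MongoClient("localhost", 27017, compressors="zstd,zlib", connect=False)  # keyword form
c1.close()
c2.close()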
- client = rs_or_single_client(directConnection=True) + client = self.rs_or_single_client(directConnection=True) client.admin.command("ping") self.assertEqual(len(client.nodes), 1) self.assertEqual(client._topology_settings.get_topology_type(), TOPOLOGY_TYPE.Single) - client.close() # direct_connection=False should result in RS topology. - client = rs_or_single_client(directConnection=False) + client = self.rs_or_single_client(directConnection=False) client.admin.command("ping") self.assertGreaterEqual(len(client.nodes), 1) self.assertIn( client._topology_settings.get_topology_type(), [TOPOLOGY_TYPE.ReplicaSetNoPrimary, TOPOLOGY_TYPE.ReplicaSetWithPrimary], ) - client.close() # directConnection=True, should error with multiple hosts as a list. with self.assertRaises(ConfigurationError): @@ -1781,11 +1772,10 @@ def server_description_count(): gc.collect() with client_knobs(min_heartbeat_interval=0.003): - client = MongoClient( + client = self.simple_client( "invalid:27017", heartbeatFrequencyMS=3, serverSelectionTimeoutMS=150 ) initial_count = server_description_count() - self.addCleanup(client.close) with self.assertRaises(ServerSelectionTimeoutError): client.test.test.find_one() gc.collect() @@ -1798,8 +1788,7 @@ def server_description_count(): @client_context.require_failCommand_fail_point def test_network_error_message(self): - client = single_client(retryReads=False) - self.addCleanup(client.close) + client = self.single_client(retryReads=False) client.admin.command("ping") # connect with self.fail_point( {"mode": {"times": 1}, "data": {"closeConnection": True, "failCommands": ["find"]}} @@ -1811,7 +1800,7 @@ def test_network_error_message(self): @unittest.skipIf("PyPy" in sys.version, "PYTHON-2938 could fail on PyPy") def test_process_periodic_tasks(self): - client = rs_or_single_client() + client = self.rs_or_single_client() coll = client.db.collection coll.insert_many([{} for _ in range(5)]) cursor = coll.find(batch_size=2) @@ -1850,11 +1839,11 @@ def test_service_name_from_kwargs(self): self.assertEqual(client._topology_settings.srv_service_name, "customname") def test_srv_max_hosts_kwarg(self): - client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/") + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/") self.assertGreater(len(client.topology_description.server_descriptions()), 1) - client = MongoClient("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) + client = self.simple_client("mongodb+srv://test1.test.build.10gen.cc/", srvmaxhosts=1) self.assertEqual(len(client.topology_description.server_descriptions()), 1) - client = MongoClient( + client = self.simple_client( "mongodb+srv://test1.test.build.10gen.cc/?srvMaxHosts=1", srvmaxhosts=2 ) self.assertEqual(len(client.topology_description.server_descriptions()), 2) @@ -1902,10 +1891,10 @@ def _test_handshake(self, env_vars, expected_env): if "AWS_REGION" not in env_vars: os.environ["AWS_REGION"] = "" - with rs_or_single_client(serverSelectionTimeoutMS=10000) as client: - client.admin.command("ping") - options = client.options - self.assertEqual(options.pool_options.metadata, metadata) + client = self.rs_or_single_client(serverSelectionTimeoutMS=10000) + client.admin.command("ping") + options = client.options + self.assertEqual(options.pool_options.metadata, metadata) def test_handshake_01_aws(self): self._test_handshake( @@ -2001,7 +1990,7 @@ def setUp(self): def test_exhaust_query_server_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked 
in on error to avoid semaphore leaks. - client = connected(rs_or_single_client(maxPoolSize=1)) + client = connected(self.rs_or_single_client(maxPoolSize=1)) collection = client.pymongo_test.test pool = get_pool(client) @@ -2024,7 +2013,7 @@ def test_exhaust_query_server_error(self): def test_exhaust_getmore_server_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test collection.drop() @@ -2063,7 +2052,7 @@ def receive_message(request_id): def test_exhaust_query_network_error(self): # When doing an exhaust query, the socket stays checked out on success # but must be checked in on error to avoid semaphore leaks. - client = connected(rs_or_single_client(maxPoolSize=1, retryReads=False)) + client = connected(self.rs_or_single_client(maxPoolSize=1, retryReads=False)) collection = client.pymongo_test.test pool = get_pool(client) pool._check_interval_seconds = None # Never check. @@ -2084,7 +2073,7 @@ def test_exhaust_query_network_error(self): def test_exhaust_getmore_network_error(self): # When doing a getmore on an exhaust cursor, the socket stays checked # out on success but it's checked in on error to avoid semaphore leaks. - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) collection = client.pymongo_test.test collection.drop() collection.insert_many([{} for _ in range(200)]) # More than one batch. @@ -2133,7 +2122,7 @@ def test_gevent_timeout(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) coll = client.pymongo_test.test coll.insert_one({}) @@ -2165,7 +2154,7 @@ def test_gevent_timeout_when_creating_connection(self): raise SkipTest("Must be running monkey patched by gevent") from gevent import Timeout, spawn - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) coll = client.pymongo_test.test pool = get_pool(client) @@ -2202,7 +2191,7 @@ class TestClientLazyConnect(IntegrationTest): """Test concurrent operations on a lazily-connecting MongoClient.""" def _get_client(self): - return rs_or_single_client(connect=False) + return self.rs_or_single_client(connect=False) @client_context.require_sync def test_insert_one(self): @@ -2336,6 +2325,7 @@ def _test_network_error(self, operation_callback): retryReads=False, serverSelectionTimeoutMS=1000, ) + self.addCleanup(c.close) # Set host-specific information so we can test whether it is reset. 
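The exhaust-cursor tests above keep maxPoolSize=1 so that a connection leaked on error would block the next operation; here is a sketch of the cursor type they exercise, assuming a local server with some documents in pymongo_test.test.

from pymongo import MongoClient
from pymongo.cursor import CursorType

client = MongoClient("localhost", 27017, maxPoolSize=1)
coll = client.pymongo_test.test
# The server streams all batches over the one connection instead of waiting for getMore.
for doc in coll.find({}, cursor_type=CursorType.EXHAUST):
    pass
client.close()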
diff --git a/test/test_client_bulk_write.py b/test/test_client_bulk_write.py index ee19a04176..ebbdc74c1c 100644 --- a/test/test_client_bulk_write.py +++ b/test/test_client_bulk_write.py @@ -27,7 +27,6 @@ ) from test.utils import ( OvertCommandListener, - rs_or_single_client, ) from unittest.mock import patch @@ -38,7 +37,6 @@ InvalidOperation, NetworkTimeout, ) -from pymongo.monitoring import * from pymongo.operations import * from pymongo.synchronous.client_bulk import _ClientBulk from pymongo.write_concern import WriteConcern @@ -97,8 +95,7 @@ def setUp(self): @client_context.require_no_serverless def test_batch_splits_if_num_operations_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) models = [] for _ in range(self.max_write_batch_size + 1): @@ -123,8 +120,7 @@ def test_batch_splits_if_num_operations_too_large(self): @client_context.require_no_serverless def test_batch_splits_if_ops_payload_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) models = [] num_models = int(self.max_message_size_bytes / self.max_bson_object_size + 1) @@ -157,11 +153,10 @@ def test_batch_splits_if_ops_payload_too_large(self): @client_context.require_failCommand_fail_point def test_collects_write_concern_errors_across_batches(self): listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener], retryWrites=False, ) - self.addCleanup(client.close) fail_command = { "configureFailPoint": "failCommand", @@ -200,8 +195,7 @@ def test_collects_write_concern_errors_across_batches(self): @client_context.require_no_serverless def test_collects_write_errors_across_batches_unordered(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -231,8 +225,7 @@ def test_collects_write_errors_across_batches_unordered(self): @client_context.require_no_serverless def test_collects_write_errors_across_batches_ordered(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -262,8 +255,7 @@ def test_collects_write_errors_across_batches_ordered(self): @client_context.require_no_serverless def test_handles_cursor_requiring_getMore(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -304,8 +296,7 @@ def test_handles_cursor_requiring_getMore(self): @client_context.require_no_standalone def test_handles_cursor_requiring_getMore_within_transaction(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -348,8 +339,7 @@ def 
test_handles_cursor_requiring_getMore_within_transaction(self): @client_context.require_failCommand_fail_point def test_handles_getMore_error(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) collection = client.db["coll"] self.addCleanup(collection.drop) @@ -403,8 +393,7 @@ def test_handles_getMore_error(self): @client_context.require_no_serverless def test_returns_error_if_unacknowledged_too_large_insert(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) b_repeated = "b" * self.max_bson_object_size @@ -460,8 +449,7 @@ def _setup_namespace_test_models(self): @client_context.require_no_serverless def test_no_batch_splits_if_new_namespace_is_not_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) num_models, models = self._setup_namespace_test_models() models.append( @@ -492,8 +480,7 @@ def test_no_batch_splits_if_new_namespace_is_not_too_large(self): @client_context.require_no_serverless def test_batch_splits_if_new_namespace_is_too_large(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) num_models, models = self._setup_namespace_test_models() c_repeated = "c" * 200 @@ -530,8 +517,7 @@ def test_batch_splits_if_new_namespace_is_too_large(self): @client_context.require_version_min(8, 0, 0, -24) @client_context.require_no_serverless def test_returns_error_if_no_writes_can_be_added_to_ops(self): - client = rs_or_single_client() - self.addCleanup(client.close) + client = self.rs_or_single_client() # Document too large. b_repeated = "b" * self.max_message_size_bytes @@ -554,8 +540,7 @@ def test_returns_error_if_auto_encryption_configured(self): key_vault_namespace="db.coll", kms_providers={"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}}, ) - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) models = [InsertOne(namespace="db.coll", document={"a": "b"})] with self.assertRaises(InvalidOperation) as context: @@ -580,7 +565,7 @@ def setUp(self): def test_timeout_in_multi_batch_bulk_write(self): _OVERHEAD = 500 - internal_client = rs_or_single_client(timeoutMS=None) + internal_client = self.rs_or_single_client(timeoutMS=None) self.addCleanup(internal_client.close) collection = internal_client.db["coll"] @@ -605,14 +590,13 @@ def test_timeout_in_multi_batch_bulk_write(self): ) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener], readConcernLevel="majority", readPreference="primary", timeoutMS=2000, w="majority", ) - self.addCleanup(client.close) client.admin.command("ping") # Init the client first. 
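As an aside on the API under test here: MongoClient.bulk_write takes operation models that each carry their own namespace, which is what lets a single call mix collections and what the batch-splitting tests above are probing. A minimal sketch, assuming a local server:

from pymongo import MongoClient
from pymongo.operations import InsertOne

client = MongoClient("localhost", 27017)
models = [
    InsertOne(namespace="db.coll", document={"a": "b"}),
    InsertOne(namespace="db.other_coll", document={"a": "b"}),
]
result = client.bulk_write(models=models)
print(result.inserted_count)
client.close()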
with self.assertRaises(ClientBulkWriteException) as context: client.bulk_write(models=models) diff --git a/test/test_collation.py b/test/test_collation.py index bedf0a2eaa..19df25c1c0 100644 --- a/test/test_collation.py +++ b/test/test_collation.py @@ -18,7 +18,7 @@ import functools import warnings from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from typing import Any from pymongo.collation import ( @@ -99,7 +99,7 @@ class TestCollation(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = EventListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test cls.collation = Collation("en_US") cls.warn_context = warnings.catch_warnings() diff --git a/test/test_collection.py b/test/test_collection.py index b68aa74f73..dab59cf1b2 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -29,6 +29,7 @@ from test import ( # TODO: fix sync imports in PYTHON-4528 IntegrationTest, + UnitTest, client_context, unittest, ) @@ -37,8 +38,6 @@ EventListener, get_pool, is_mongos, - rs_or_single_client, - single_client, wait_until, ) @@ -81,14 +80,20 @@ _IS_SYNC = True -class TestCollectionNoConnect(unittest.TestCase): +class TestCollectionNoConnect(UnitTest): """Test Collection features on a client that does not connect.""" db: Database + client: MongoClient @classmethod - def setUpClass(cls): - cls.db = MongoClient(connect=False).pymongo_test + def _setup_class(cls): + cls.client = MongoClient(connect=False) + cls.db = cls.client.pymongo_test + + @classmethod + def _tearDown_class(cls): + cls.client.close() def test_collection(self): self.assertRaises(TypeError, Collection, self.db, 5) @@ -1800,8 +1805,7 @@ def test_exhaust(self): # Insert enough documents to require more than one batch self.db.test.insert_many([{"i": i} for i in range(150)]) - client = rs_or_single_client(maxPoolSize=1) - self.addCleanup(client.close) + client = self.rs_or_single_client(maxPoolSize=1) pool = get_pool(client) # Make sure the socket is returned after exhaustion. @@ -2077,7 +2081,7 @@ def test_find_one_and(self): def test_find_one_and_write_concern(self): listener = EventListener() - db = (single_client(event_listeners=[listener]))[self.db.name] + db = (self.single_client(event_listeners=[listener]))[self.db.name] # non-default WriteConcern. c_w0 = db.get_collection("test", write_concern=WriteConcern(w=0)) # default WriteConcern. 
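TestCollectionNoConnect above now derives from UnitTest and moves the connect=False client into class-level hooks so it can be closed explicitly at class teardown. A minimal sketch of that lifecycle, reusing the _setup_class/_tearDown_class hook names from the hunk (the test method is illustrative):

    from test import UnitTest
    from pymongo import MongoClient

    class ExampleNoConnect(UnitTest):
        client: MongoClient

        @classmethod
        def _setup_class(cls):
            # connect=False performs no I/O, but the client still needs an
            # explicit close when the class is torn down.
            cls.client = MongoClient(connect=False)
            cls.db = cls.client.pymongo_test

        @classmethod
        def _tearDown_class(cls):
            cls.client.close()

        def test_database_name(self):
            self.assertEqual(self.db.name, "pymongo_test")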
diff --git a/test/test_comment.py b/test/test_comment.py index 931446ef3a..c0f037ea44 100644 --- a/test/test_comment.py +++ b/test/test_comment.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from bson.dbref import DBRef from pymongo.operations import IndexModel @@ -109,7 +109,7 @@ def _test_ops( @client_context.require_replica_set def test_database_helpers(self): listener = EventListener() - db = rs_or_single_client(event_listeners=[listener]).db + db = self.rs_or_single_client(event_listeners=[listener]).db helpers = [ (db.watch, []), (db.command, ["hello"]), @@ -126,7 +126,7 @@ def test_database_helpers(self): @client_context.require_replica_set def test_client_helpers(self): listener = EventListener() - cli = rs_or_single_client(event_listeners=[listener]) + cli = self.rs_or_single_client(event_listeners=[listener]) helpers = [ (cli.watch, []), (cli.list_databases, []), @@ -141,7 +141,7 @@ def test_client_helpers(self): @client_context.require_version_min(4, 7, -1) def test_collection_helpers(self): listener = EventListener() - db = rs_or_single_client(event_listeners=[listener])[self.db.name] + db = self.rs_or_single_client(event_listeners=[listener])[self.db.name] coll = db.get_collection("test") helpers = [ diff --git a/test/test_common.py b/test/test_common.py index 358cd29b81..3228dc97fb 100644 --- a/test/test_common.py +++ b/test/test_common.py @@ -21,7 +21,6 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, connected, unittest -from test.utils import rs_or_single_client, single_client from bson.binary import PYTHON_LEGACY, STANDARD, Binary, UuidRepresentation from bson.codec_options import CodecOptions @@ -111,10 +110,10 @@ def test_uuid_representation(self): ) def test_write_concern(self): - c = rs_or_single_client(connect=False) + c = self.rs_or_single_client(connect=False) self.assertEqual(WriteConcern(), c.write_concern) - c = rs_or_single_client(connect=False, w=2, wTimeoutMS=1000) + c = self.rs_or_single_client(connect=False, w=2, wTimeoutMS=1000) wc = WriteConcern(w=2, wtimeout=1000) self.assertEqual(wc, c.write_concern) @@ -134,7 +133,7 @@ def test_write_concern(self): def test_mongo_client(self): pair = client_context.pair - m = rs_or_single_client(w=0) + m = self.rs_or_single_client(w=0) coll = m.pymongo_test.write_concern_test coll.drop() doc = {"_id": ObjectId()} @@ -143,17 +142,19 @@ def test_mongo_client(self): coll = coll.with_options(write_concern=WriteConcern(w=1)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client() + m = self.rs_or_single_client() coll = m.pymongo_test.write_concern_test new_coll = coll.with_options(write_concern=WriteConcern(w=0)) self.assertTrue(new_coll.insert_one(doc)) self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client(f"mongodb://{pair}/", replicaSet=client_context.replica_set_name) + m = self.rs_or_single_client( + f"mongodb://{pair}/", replicaSet=client_context.replica_set_name + ) coll = m.pymongo_test.write_concern_test self.assertRaises(OperationFailure, coll.insert_one, doc) - m = rs_or_single_client( + m = self.rs_or_single_client( f"mongodb://{pair}/?w=0", replicaSet=client_context.replica_set_name ) @@ -161,8 +162,8 @@ def test_mongo_client(self): coll.insert_one(doc) # Equality tests - direct = connected(single_client(w=0)) - direct2 = connected(single_client(f"mongodb://{pair}/?w=0", 
**self.credentials)) + direct = connected(self.single_client(w=0)) + direct2 = connected(self.single_client(f"mongodb://{pair}/?w=0", **self.credentials)) self.assertEqual(direct, direct2) self.assertFalse(direct != direct2) diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py index 9ee3202e13..142af0f9a7 100644 --- a/test/test_connection_monitoring.py +++ b/test/test_connection_monitoring.py @@ -30,9 +30,6 @@ client_context, get_pool, get_pools, - rs_or_single_client, - single_client, - single_client_noauth, wait_until, ) from test.utils_spec_runner import SpecRunnerThread @@ -250,7 +247,7 @@ def run_scenario(self, scenario_def, test): else: kill_cursor_frequency = interval / 1000.0 with client_knobs(kill_cursor_frequency=kill_cursor_frequency, min_heartbeat_interval=0.05): - client = single_client(**opts) + client = self.single_client(**opts) # Update the SD to a known type because the DummyMonitor will not. # Note we cannot simply call topology.on_change because that would # internally call pool.ready() which introduces unexpected @@ -323,13 +320,13 @@ def cleanup(): # Prose tests. Numbers correspond to the prose test number in the spec. # def test_1_client_connection_pool_options(self): - client = rs_or_single_client(**self.POOL_OPTIONS) + client = self.rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_2_all_client_pools_have_same_options(self): - client = rs_or_single_client(**self.POOL_OPTIONS) + client = self.rs_or_single_client(**self.POOL_OPTIONS) self.addCleanup(client.close) client.admin.command("ping") # Discover at least one secondary. @@ -345,14 +342,14 @@ def test_2_all_client_pools_have_same_options(self): def test_3_uri_connection_pool_options(self): opts = "&".join([f"{k}={v}" for k, v in self.POOL_OPTIONS.items()]) uri = f"mongodb://{client_context.pair}/?{opts}" - client = rs_or_single_client(uri) + client = self.rs_or_single_client(uri) self.addCleanup(client.close) pool_opts = get_pool(client).opts self.assertEqual(pool_opts.non_default_options, self.POOL_OPTIONS) def test_4_subscribe_to_events(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) self.addCleanup(client.close) self.assertEqual(listener.event_count(PoolCreatedEvent), 1) @@ -376,7 +373,7 @@ def test_4_subscribe_to_events(self): def test_5_check_out_fails_connection_error(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) self.addCleanup(client.close) pool = get_pool(client) @@ -403,7 +400,7 @@ def mock_connect(*args, **kwargs): @client_context.require_no_fips def test_5_check_out_fails_auth_error(self): listener = CMAPListener() - client = single_client_noauth( + client = self.single_client_noauth( username="notauser", password="fail", event_listeners=[listener] ) self.addCleanup(client.close) @@ -449,7 +446,7 @@ def test_events_repr(self): def test_close_leaves_pool_unpaused(self): listener = CMAPListener() - client = single_client(event_listeners=[listener]) + client = self.single_client(event_listeners=[listener]) client.admin.command("ping") pool = get_pool(client) client.close() diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py index 674612693c..fba7675743 100644 --- 
a/test/test_connections_survive_primary_stepdown_spec.py +++ b/test/test_connections_survive_primary_stepdown_spec.py @@ -24,7 +24,6 @@ CMAPListener, ensure_all_connected, repl_set_step_down, - rs_or_single_client, ) from bson import SON @@ -43,7 +42,7 @@ class TestConnectionsSurvivePrimaryStepDown(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = CMAPListener() - cls.client = rs_or_single_client( + cls.client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=False, heartbeatFrequencyMS=500 ) diff --git a/test/test_cursor.py b/test/test_cursor.py index 8e6fade1ec..9bc22aca3c 100644 --- a/test/test_cursor.py +++ b/test/test_cursor.py @@ -35,7 +35,6 @@ EventListener, OvertCommandListener, ignore_deprecations, - rs_or_single_client, wait_until, ) @@ -230,7 +229,7 @@ def test_max_await_time_ms(self): self.assertEqual(90, cursor._max_await_time_ms) listener = AllowListEventListener("find", "getMore") - coll = (rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test + coll = (self.rs_or_single_client(event_listeners=[listener]))[self.db.name].pymongo_test # Tailable_defaults. coll.find(cursor_type=CursorType.TAILABLE_AWAIT).to_list() @@ -345,8 +344,7 @@ def test_explain(self): def test_explain_with_read_concern(self): # Do not add readConcern level to explain. listener = AllowListEventListener("explain") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) coll = client.pymongo_test.test.with_options(read_concern=ReadConcern(level="local")) self.assertTrue(coll.find().explain()) started = listener.started_events @@ -1252,8 +1250,7 @@ def test_close_kills_cursor_synchronously(self): self.client._process_periodic_tasks() listener = AllowListEventListener("killCursors") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_close_kills_cursors # Add some test data. @@ -1291,8 +1288,7 @@ def assertCursorKilled(): @client_context.require_failCommand_appName def test_timeout_kills_cursor_synchronously(self): listener = AllowListEventListener("killCursors") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) coll = client[self.db.name].test_timeout_kills_cursor # Add some test data. @@ -1349,8 +1345,7 @@ def test_delete_not_initialized(self): def test_getMore_does_not_send_readPreference(self): listener = AllowListEventListener("find", "getMore") - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) # We never send primary read preference so override the default. 
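Clients created in setUpClass (as in TestConnectionsSurvivePrimaryStepDown above and TestCollation earlier) go through the unmanaged_* variants because addCleanup is not available at class scope; closing them remains the caller's responsibility, presumably in the class teardown, which these hunks do not show. A sketch under that assumption:

    from test import IntegrationTest
    from test.utils import CMAPListener

    class ExampleClassScopedClient(IntegrationTest):
        @classmethod
        def setUpClass(cls):
            super().setUpClass()
            cls.listener = CMAPListener()
            # unmanaged_*: no cleanup is registered on the caller's behalf.
            cls.client = cls.unmanaged_rs_or_single_client(
                event_listeners=[cls.listener], retryWrites=False
            )

        @classmethod
        def tearDownClass(cls):
            # The caller owns the close for unmanaged clients.
            cls.client.close()
            super().tearDownClass()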
coll = client[self.db.name].get_collection( "test", read_preference=ReadPreference.PRIMARY_PREFERRED @@ -1454,7 +1449,7 @@ def test_find_raw_transaction(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): batches = ( @@ -1484,7 +1479,7 @@ def test_find_raw_retryable_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["find"], "closeConnection": True}} ): @@ -1505,7 +1500,7 @@ def test_find_raw_snapshot_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) @@ -1566,7 +1561,7 @@ def test_read_concern(self): def test_monitoring(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{"_id": i} for i in range(10)]) @@ -1632,7 +1627,7 @@ def test_aggregate_raw_transaction(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) with client.start_session() as session: with session.start_transaction(): batches = ( @@ -1663,7 +1658,7 @@ def test_aggregate_raw_retryable_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) with self.fail_point( {"mode": {"times": 1}, "data": {"failCommands": ["aggregate"], "closeConnection": True}} ): @@ -1687,7 +1682,7 @@ def test_aggregate_raw_snapshot_reads(self): c.insert_many(docs) listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener], retryReads=True) + client = self.rs_or_single_client(event_listeners=[listener], retryReads=True) db = client[self.db.name] with client.start_session(snapshot=True) as session: db.test.distinct("x", {}, session=session) @@ -1733,7 +1728,7 @@ def test_collation(self): def test_monitoring(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.drop() c.insert_many([{"_id": i} for i in range(10)]) @@ -1777,8 +1772,7 @@ def test_monitoring(self): @client_context.require_no_mongos def test_exhaust_cursor_db_set(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) c = client.pymongo_test.test c.delete_many({}) c.insert_many([{"_id": i} for i in range(3)]) diff --git a/test/test_custom_types.py b/test/test_custom_types.py index c30c62b1b1..abaa820cb7 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -27,7 +27,6 @@ from test import 
client_context, unittest from test.test_client import IntegrationTest -from test.utils import rs_client from bson import ( _BUILT_IN_TYPES, @@ -971,7 +970,7 @@ def create_targets(self, *args, **kwargs): if codec_options: kwargs["type_registry"] = codec_options.type_registry kwargs["document_class"] = codec_options.document_class - self.watched_target = rs_client(*args, **kwargs) + self.watched_target = self.rs_client(*args, **kwargs) self.addCleanup(self.watched_target.close) self.input_target = self.watched_target[self.db.name].test # Insert a record to ensure db, coll are created. diff --git a/test/test_data_lake.py b/test/test_data_lake.py index 8ba83ab190..a374db550e 100644 --- a/test/test_data_lake.py +++ b/test/test_data_lake.py @@ -27,8 +27,6 @@ from test.unified_format import generate_test_classes from test.utils import ( OvertCommandListener, - rs_client_noauth, - rs_or_single_client, ) pytestmark = pytest.mark.data_lake @@ -65,7 +63,7 @@ def setUpClass(cls): # Test killCursors def test_1(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) cursor = client[self.TEST_DB][self.TEST_COLLECTION].find({}, batch_size=2) next(cursor) @@ -90,13 +88,13 @@ def test_1(self): # Test no auth def test_2(self): - client = rs_client_noauth() + client = self.rs_client_noauth() client.admin.command("ping") # Test with auth def test_3(self): for mechanism in ["SCRAM-SHA-1", "SCRAM-SHA-256"]: - client = rs_or_single_client(authMechanism=mechanism) + client = self.rs_or_single_client(authMechanism=mechanism) client[self.TEST_DB][self.TEST_COLLECTION].find_one() diff --git a/test/test_database.py b/test/test_database.py index 12d4eb666a..fe07f343c5 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -28,7 +28,6 @@ from test.utils import ( IMPOSSIBLE_WRITE_CONCERN, OvertCommandListener, - rs_or_single_client, wait_until, ) @@ -207,7 +206,7 @@ def test_list_collection_names(self): def test_list_collection_names_filter(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] db.capped.drop() db.create_collection("capped", capped=True, size=4096) @@ -234,8 +233,7 @@ def test_list_collection_names_filter(self): def test_check_exists(self): listener = OvertCommandListener() - client = rs_or_single_client(event_listeners=[listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener]) db = client[self.db.name] db.drop_collection("unique") db.create_collection("unique", check_exists=True) @@ -323,7 +321,7 @@ def test_list_collections(self): self.client.drop_database("pymongo_test") def test_list_collection_names_single_socket(self): - client = rs_or_single_client(maxPoolSize=1) + client = self.rs_or_single_client(maxPoolSize=1) client.drop_database("test_collection_names_single_socket") db = client.test_collection_names_single_socket for i in range(200): diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index ef32afbcd4..3554619f12 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] -from test import IntegrationTest, unittest +from test import IntegrationTest, PyMongoTestCase, unittest from test.pymongo_mocks import DummyMonitor from test.unified_format import generate_test_classes from 
test.utils import ( @@ -32,9 +32,7 @@ assertion_context, client_context, get_pool, - rs_or_single_client, server_name_to_type, - single_client, wait_until, ) from unittest.mock import patch @@ -272,7 +270,7 @@ class TestIgnoreStaleErrors(IntegrationTest): def test_ignore_stale_connection_errors(self): N_THREADS = 5 barrier = threading.Barrier(N_THREADS, timeout=30) - client = rs_or_single_client(minPoolSize=N_THREADS) + client = self.rs_or_single_client(minPoolSize=N_THREADS) self.addCleanup(client.close) # Wait for initial discovery. @@ -319,7 +317,7 @@ class TestPoolManagement(IntegrationTest): def test_pool_unpause(self): # This test implements the prose test "Connection Pool Management" listener = CMAPHeartbeatListener() - client = single_client( + client = self.single_client( appName="SDAMPoolManagementTest", heartbeatFrequencyMS=500, event_listeners=[listener] ) self.addCleanup(client.close) @@ -353,7 +351,7 @@ def setUp(self): super().setUp() def test_rtt_connection_is_enabled_stream(self): - client = rs_or_single_client(serverMonitoringMode="stream") + client = self.rs_or_single_client(serverMonitoringMode="stream") self.addCleanup(client.close) client.admin.command("ping") @@ -373,7 +371,7 @@ def predicate(): wait_until(predicate, "find all RTT monitors") def test_rtt_connection_is_disabled_poll(self): - client = rs_or_single_client(serverMonitoringMode="poll") + client = self.rs_or_single_client(serverMonitoringMode="poll") self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) @@ -387,7 +385,7 @@ def test_rtt_connection_is_disabled_auto(self): ] for env in envs: with patch.dict("os.environ", env): - client = rs_or_single_client(serverMonitoringMode="auto") + client = self.rs_or_single_client(serverMonitoringMode="auto") self.addCleanup(client.close) self.assert_rtt_connection_is_disabled(client) @@ -415,7 +413,7 @@ def handle_request_and_shutdown(self): self.server_close() -class TestHeartbeatStartOrdering(unittest.TestCase): +class TestHeartbeatStartOrdering(PyMongoTestCase): def test_heartbeat_start_ordering(self): events = [] listener = HeartbeatEventsListListener(events) @@ -423,7 +421,7 @@ def test_heartbeat_start_ordering(self): server.events = events server_thread = threading.Thread(target=server.handle_request_and_shutdown) server_thread.start() - _c = MongoClient( + _c = self.simple_client( "mongodb://localhost:9999", serverSelectionTimeoutMS=500, event_listeners=(listener,) ) server_thread.join() diff --git a/test/test_dns.py b/test/test_dns.py index b4c5e3684c..f2185efb1b 100644 --- a/test/test_dns.py +++ b/test/test_dns.py @@ -22,16 +22,15 @@ sys.path[0:0] = [""] -from test import IntegrationTest, client_context, unittest +from test import IntegrationTest, PyMongoTestCase, client_context, unittest from test.utils import wait_until from pymongo.common import validate_read_preference_tags from pymongo.errors import ConfigurationError -from pymongo.synchronous.mongo_client import MongoClient from pymongo.uri_parser import parse_uri, split_hosts -class TestDNSRepl(unittest.TestCase): +class TestDNSRepl(PyMongoTestCase): TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "replica-set" ) @@ -42,7 +41,7 @@ def setUp(self): pass -class TestDNSLoadBalanced(unittest.TestCase): +class TestDNSLoadBalanced(PyMongoTestCase): TEST_PATH = os.path.join( os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "load-balanced" ) @@ -53,7 +52,7 @@ def setUp(self): pass -class TestDNSSharded(unittest.TestCase): +class 
TestDNSSharded(PyMongoTestCase): TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "srv_seedlist", "sharded") load_balanced = False @@ -120,7 +119,7 @@ def run_test(self): # tests. copts["tlsAllowInvalidHostnames"] = True - client = MongoClient(uri, **copts) + client = PyMongoTestCase.unmanaged_simple_client(uri, **copts) if num_seeds is not None: self.assertEqual(len(client._topology_settings.seeds), num_seeds) if hosts is not None: @@ -133,6 +132,7 @@ def run_test(self): client.admin.command("ping") # XXX: we should block until SRV poller runs at least once # and re-run these assertions. + client.close() else: try: parse_uri(uri) @@ -157,37 +157,37 @@ def create_tests(cls): create_tests(TestDNSSharded) -class TestParsingErrors(unittest.TestCase): +class TestParsingErrors(PyMongoTestCase): def test_invalid_host(self): self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb is not", - MongoClient, + self.simple_client, "mongodb+srv://mongodb", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: mongodb.com is not", - MongoClient, + self.simple_client, "mongodb+srv://mongodb.com", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, + self.simple_client, "mongodb+srv://127.0.0.1", ) self.assertRaisesRegex( ConfigurationError, "Invalid URI host: an IP address is not", - MongoClient, + self.simple_client, "mongodb+srv://[::1]", ) class TestCaseInsensitive(IntegrationTest): def test_connect_case_insensitive(self): - client = MongoClient("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") + client = self.simple_client("mongodb+srv://TEST1.TEST.BUILD.10GEN.cc/") self.addCleanup(client.close) self.assertGreater(len(client.topology_description.server_descriptions()), 1) diff --git a/test/test_encryption.py b/test/test_encryption.py index 5e02e4d628..96d40c4a34 100644 --- a/test/test_encryption.py +++ b/test/test_encryption.py @@ -31,7 +31,7 @@ from test import IntegrationTest, PyMongoTestCase, client_context from test.test_bulk import BulkTestBase from threading import Thread -from typing import Any, Dict, Mapping +from typing import Any, Dict, Mapping, Optional import pytest @@ -53,6 +53,7 @@ KMIP_CREDS, LOCAL_MASTER_KEY, ) +from test.test_bulk import BulkTestBase from test.unified_format import generate_test_classes from test.utils import ( AllowListEventListener, @@ -61,7 +62,6 @@ TopologyEventListener, camel_to_snake_args, is_greenthread_patched, - rs_or_single_client, wait_until, ) from test.utils_spec_runner import SpecRunner @@ -109,13 +109,12 @@ class TestAutoEncryptionOpts(PyMongoTestCase): @unittest.skipUnless(os.environ.get("TEST_CRYPT_SHARED"), "crypt_shared lib is not installed") def test_crypt_shared(self): # Test that we can pick up crypt_shared lib automatically - client = MongoClient( + self.simple_client( auto_encryption_opts=AutoEncryptionOpts( KMS_PROVIDERS, "keyvault.datakeys", crypt_shared_lib_required=True ), connect=False, ) - self.addCleanup(client.close) @unittest.skipIf(_HAVE_PYMONGOCRYPT, "pymongocrypt is installed") def test_init_requires_pymongocrypt(self): @@ -196,19 +195,16 @@ def test_init_kms_tls_options(self): class TestClientOptions(PyMongoTestCase): def test_default(self): - client = MongoClient(connect=False) - self.addCleanup(client.close) + client = self.simple_client(connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) - client = MongoClient(auto_encryption_opts=None, connect=False) - self.addCleanup(client.close) + client = 
self.simple_client(auto_encryption_opts=None, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, None) @unittest.skipUnless(_HAVE_PYMONGOCRYPT, "pymongocrypt is not installed") def test_kwargs(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = MongoClient(auto_encryption_opts=opts, connect=False) - self.addCleanup(client.close) + client = self.simple_client(auto_encryption_opts=opts, connect=False) self.assertEqual(get_client_opts(client).auto_encryption_opts, opts) @@ -229,6 +225,34 @@ def assertBinaryUUID(self, val): self.assertIsInstance(val, Binary) self.assertEqual(val.subtype, UUID_SUBTYPE) + def create_client_encryption( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = ClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + self.addCleanup(client_encryption.close) + return client_encryption + + @classmethod + def unmanaged_create_client_encryption( + cls, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: MongoClient, + codec_options: CodecOptions, + kms_tls_options: Optional[Mapping[str, Any]] = None, + ): + client_encryption = ClientEncryption( + kms_providers, key_vault_namespace, key_vault_client, codec_options, kms_tls_options + ) + return client_encryption + # Location of JSON test files. if _IS_SYNC: @@ -260,8 +284,7 @@ def bson_data(*paths): class TestClientSimple(EncryptionIntegrationTest): def _test_auto_encrypt(self, opts): - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) # Create the encrypted field's data key. 
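The two helpers added above give ClientEncryption the same treatment as the client helpers: create_client_encryption registers close via addCleanup, while the unmanaged_ classmethod leaves closing to the caller. A short usage sketch, assumed to run inside a test method of EncryptionIntegrationTest where KMS_PROVIDERS, OPTS, and client_context are already in scope:

    # Per-test instance: closed automatically when the test finishes.
    client_encryption = self.create_client_encryption(
        KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
    )
    key_id = client_encryption.create_data_key("local")

    # Unmanaged variant: the caller must close it, e.g. in a class teardown.
    shared_encryption = self.unmanaged_create_client_encryption(
        KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS
    )
    try:
        shared_encryption.create_data_key("local")
    finally:
        shared_encryption.close()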
key_vault = create_key_vault( @@ -342,8 +365,7 @@ def test_auto_encrypt_local_schema_map(self): def test_use_after_close(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) client.admin.command("ping") client.close() @@ -360,8 +382,7 @@ def test_use_after_close(self): ) def test_fork(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) def target(): with warnings.catch_warnings(): @@ -375,8 +396,7 @@ def target(): class TestEncryptedBulkWrite(BulkTestBase, EncryptionIntegrationTest): def test_upsert_uuid_standard_encrypt(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) options = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) encrypted_coll = client.pymongo_test.test @@ -416,8 +436,7 @@ def _setup_class(cls): @client_context.require_version_max(4, 0, 99) def test_raise_max_wire_version_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) msg = "Auto-encryption requires a minimum MongoDB version of 4.2" with self.assertRaisesRegex(ConfigurationError, msg): client.test.test.insert_one({}) @@ -430,8 +449,7 @@ def test_raise_max_wire_version_error(self): def test_raise_unsupported_error(self): opts = AutoEncryptionOpts(KMS_PROVIDERS, "keyvault.datakeys") - client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client.close) + client = self.rs_or_single_client(auto_encryption_opts=opts) msg = "find_raw_batches does not support auto encryption" with self.assertRaisesRegex(InvalidOperation, msg): client.test.test.find_raw_batches({}) @@ -450,10 +468,9 @@ def test_raise_unsupported_error(self): class TestExplicitSimple(EncryptionIntegrationTest): def test_encrypt_decrypt(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) # Use standard UUID representation. key_vault = client_context.client.keyvault.get_collection("datakeys", codec_options=OPTS) self.addCleanup(key_vault.drop) @@ -493,10 +510,9 @@ def test_encrypt_decrypt(self): self.assertEqual(decrypted_ssn, doc["ssn"]) def test_validation(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) msg = "value to decrypt must be a bson.binary.Binary with subtype 6" with self.assertRaisesRegex(TypeError, msg): @@ -510,10 +526,9 @@ def test_validation(self): client_encryption.encrypt("str", algo, key_id=Binary(b"123")) def test_bson_errors(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) - self.addCleanup(client_encryption.close) # Attempt to encrypt an unencodable object. 
unencodable_value = object() @@ -526,7 +541,7 @@ def test_bson_errors(self): def test_codec_options(self): with self.assertRaisesRegex(TypeError, "codec_options must be"): - ClientEncryption( + self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, @@ -534,10 +549,9 @@ def test_codec_options(self): ) opts = CodecOptions(uuid_representation=UuidRepresentation.JAVA_LEGACY) - client_encryption_legacy = ClientEncryption( + client_encryption_legacy = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts ) - self.addCleanup(client_encryption_legacy.close) # Create the encrypted field's data key. key_id = client_encryption_legacy.create_data_key("local") @@ -552,10 +566,9 @@ def test_codec_options(self): # Encrypt the same UUID with STANDARD codec options. opts = CodecOptions(uuid_representation=UuidRepresentation.STANDARD) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, opts ) - self.addCleanup(client_encryption.close) encrypted_standard = client_encryption.encrypt( value, Algorithm.AEAD_AES_256_CBC_HMAC_SHA_512_Deterministic, key_id=key_id ) @@ -571,7 +584,7 @@ def test_codec_options(self): self.assertNotEqual(client_encryption.decrypt(encrypted_legacy), value) def test_close(self): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) client_encryption.close() @@ -587,7 +600,7 @@ def test_close(self): client_encryption.decrypt(Binary(b"", 6)) def test_with_statement(self): - with ClientEncryption( + with self.create_client_encryption( KMS_PROVIDERS, "keyvault.datakeys", client_context.client, OPTS ) as client_encryption: pass @@ -807,7 +820,7 @@ class TestDataKeyDoubleEncryption(EncryptionIntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.client.db.coll.drop() cls.vault = create_key_vault(cls.client.keyvault.datakeys) @@ -829,10 +842,10 @@ def _setup_class(cls): opts = AutoEncryptionOpts( cls.KMS_PROVIDERS, "keyvault.datakeys", schema_map=schemas, kms_tls_options=KMS_TLS_OPTS ) - cls.client_encrypted = rs_or_single_client( + cls.client_encrypted = cls.unmanaged_rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - cls.client_encryption = ClientEncryption( + cls.client_encryption = cls.unmanaged_create_client_encryption( cls.KMS_PROVIDERS, "keyvault.datakeys", cls.client, OPTS, kms_tls_options=KMS_TLS_OPTS ) @@ -919,8 +932,7 @@ def _test_external_key_vault(self, with_external_key_vault): # Configure the encrypted field via the local schema_map option. 
schemas = {"db.coll": json_data("external", "external-schema.json")} if with_external_key_vault: - key_vault_client = rs_or_single_client(username="fake-user", password="fake-pwd") - self.addCleanup(key_vault_client.close) + key_vault_client = self.rs_or_single_client(username="fake-user", password="fake-pwd") else: key_vault_client = client_context.client opts = AutoEncryptionOpts( @@ -930,15 +942,13 @@ def _test_external_key_vault(self, with_external_key_vault): key_vault_client=key_vault_client, ) - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addCleanup(client_encrypted.close) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", key_vault_client, OPTS ) - self.addCleanup(client_encryption.close) if with_external_key_vault: # Authentication error. @@ -984,10 +994,9 @@ def test_views_are_prohibited(self): self.addCleanup(self.client.db.view.drop) opts = AutoEncryptionOpts(self.kms_providers(), "keyvault.datakeys") - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( auto_encryption_opts=opts, uuidRepresentation="standard" ) - self.addCleanup(client_encrypted.close) with self.assertRaisesRegex(EncryptionError, "cannot auto encrypt a view"): client_encrypted.db.view.insert_one({}) @@ -1044,17 +1053,15 @@ def _test_corpus(self, opts): ) self.addCleanup(vault.drop) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.kms_providers(), "keyvault.datakeys", client_context.client, OPTS, kms_tls_options=KMS_TLS_OPTS, ) - self.addCleanup(client_encryption.close) corpus = self.fix_up_curpus(json_data("corpus", "corpus.json")) corpus_copied: SON = SON() @@ -1197,7 +1204,7 @@ def _setup_class(cls): opts = AutoEncryptionOpts({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") cls.listener = OvertCommandListener() - cls.client_encrypted = rs_or_single_client( + cls.client_encrypted = cls.unmanaged_rs_or_single_client( auto_encryption_opts=opts, event_listeners=[cls.listener] ) cls.coll_encrypted = cls.client_encrypted.db.coll @@ -1285,7 +1292,7 @@ def setUp(self): "gcp": GCP_CREDS, "kmip": KMIP_CREDS, } - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers=kms_providers, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -1297,7 +1304,7 @@ def setUp(self): kms_providers_invalid["azure"]["identityPlatformEndpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["gcp"]["endpoint"] = "doesnotexist.invalid:443" kms_providers_invalid["kmip"]["endpoint"] = "doesnotexist.local:5698" - self.client_encryption_invalid = ClientEncryption( + self.client_encryption_invalid = self.create_client_encryption( kms_providers=kms_providers_invalid, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -1476,7 +1483,7 @@ def test_12_kmip_master_key_invalid_endpoint(self): self.client_encryption.create_data_key("kmip", key) -class AzureGCPEncryptionTestMixin: +class AzureGCPEncryptionTestMixin(EncryptionIntegrationTest): DEK = None KMS_PROVIDER_MAP = None KEYVAULT_DB = "keyvault" @@ -1488,7 +1495,7 @@ def setUp(self): 
create_key_vault(keyvault, self.DEK) def _test_explicit(self, expectation): - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( self.KMS_PROVIDER_MAP, # type: ignore[arg-type] ".".join([self.KEYVAULT_DB, self.KEYVAULT_COLL]), client_context.client, @@ -1517,7 +1524,7 @@ def _test_automatic(self, expectation_extjson, payload): ) insert_listener = AllowListEventListener("insert") - client = rs_or_single_client( + client = self.rs_or_single_client( auto_encryption_opts=encryption_opts, event_listeners=[insert_listener] ) self.addCleanup(client.close) @@ -1596,19 +1603,17 @@ def test_automatic(self): # https://github.com/mongodb/specifications/blob/master/source/client-side-encryption/tests/README.rst#deadlock-tests class TestDeadlockProse(EncryptionIntegrationTest): def setUp(self): - self.client_test = rs_or_single_client( + self.client_test = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", uuidRepresentation="standard" ) - self.addCleanup(self.client_test.close) self.client_keyvault_listener = OvertCommandListener() - self.client_keyvault = rs_or_single_client( + self.client_keyvault = self.rs_or_single_client( maxPoolSize=1, readConcernLevel="majority", w="majority", event_listeners=[self.client_keyvault_listener], ) - self.addCleanup(self.client_keyvault.close) self.client_test.keyvault.datakeys.drop() self.client_test.db.coll.drop() @@ -1619,7 +1624,7 @@ def setUp(self): codec_options=OPTS, ) - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( kms_providers={"local": {"key": LOCAL_MASTER_KEY}}, key_vault_namespace="keyvault.datakeys", key_vault_client=self.client_test, @@ -1635,7 +1640,7 @@ def setUp(self): self.optargs = ({"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys") def _run_test(self, max_pool_size, auto_encryption_opts): - client_encrypted = rs_or_single_client( + client_encrypted = self.rs_or_single_client( readConcernLevel="majority", w="majority", maxPoolSize=max_pool_size, @@ -1653,8 +1658,6 @@ def _run_test(self, max_pool_size, auto_encryption_opts): result = client_encrypted.db.coll.find_one({"_id": 0}) self.assertEqual(result, {"_id": 0, "encrypted": "string0"}) - self.addCleanup(client_encrypted.close) - def test_case_1(self): self._run_test( max_pool_size=1, @@ -1830,7 +1833,7 @@ def setUp(self): create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) keyID = self.client_encryption.create_data_key("local") @@ -1845,10 +1848,9 @@ def setUp(self): key_vault_namespace="keyvault.datakeys", kms_providers=kms_providers_map ) self.listener = AllowListEventListener("aggregate") - self.encrypted_client = rs_or_single_client( + self.encrypted_client = self.rs_or_single_client( auto_encryption_opts=opts, retryReads=False, event_listeners=[self.listener] ) - self.addCleanup(self.encrypted_client.close) def test_01_command_error(self): with self.fail_point( @@ -1925,8 +1927,7 @@ def reset_timeout(): "--port=27027", ], ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) with self.assertRaisesRegex(EncryptionError, "Timeout"): client_encrypted.db.coll.insert_one({"encrypted": "test"}) @@ -1940,11 +1941,12 @@ def 
test_bypassAutoEncryption(self): "--port=27027", ], ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) client_encrypted.db.coll.insert_one({"unencrypted": "test"}) # Validate that mongocryptd was not spawned: - mongocryptd_client = MongoClient("mongodb://localhost:27027/?serverSelectionTimeoutMS=500") + mongocryptd_client = self.simple_client( + "mongodb://localhost:27027/?serverSelectionTimeoutMS=500" + ) with self.assertRaises(ServerSelectionTimeoutError): mongocryptd_client.admin.command("ping") @@ -1966,15 +1968,13 @@ def test_via_loading_shared_library(self): ], crypt_shared_lib_required=True, ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) self.assertEncrypted((client_context.client.db.coll.find_one({}))["encrypted"]) - no_mongocryptd_client = MongoClient( + no_mongocryptd_client = self.simple_client( host="mongodb://localhost:47021/db?serverSelectionTimeoutMS=1000" ) - self.addCleanup(no_mongocryptd_client.close) with self.assertRaises(ServerSelectionTimeoutError): no_mongocryptd_client.db.command("ping") @@ -2008,8 +2008,7 @@ def listener(): mongocryptd_uri="mongodb://localhost:47021", crypt_shared_lib_required=False, ) - client_encrypted = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(client_encrypted.close) + client_encrypted = self.rs_or_single_client(auto_encryption_opts=opts) client_encrypted.db.coll.drop() client_encrypted.db.coll.insert_one({"encrypted": "test"}) server.shutdown() @@ -2023,10 +2022,9 @@ class TestKmsTLSProse(EncryptionIntegrationTest): def setUp(self): super().setUp() self.patch_system_certs(CA_PEM) - self.client_encrypted = ClientEncryption( + self.client_encrypted = self.create_client_encryption( {"aws": AWS_CREDS}, "keyvault.datakeys", self.client, OPTS ) - self.addCleanup(self.client_encrypted.close) def test_invalid_kms_certificate_expired(self): key = { @@ -2071,36 +2069,32 @@ def setUp(self): "gcp": {"tlsCAFile": CA_PEM}, "kmip": {"tlsCAFile": CA_PEM}, } - self.client_encryption_no_client_cert = ClientEncryption( + self.client_encryption_no_client_cert = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_no_client_cert.close) # 2, same providers as above but with tlsCertificateKeyFile. kms_tls_opts = copy.deepcopy(kms_tls_opts_ca_only) for p in kms_tls_opts: kms_tls_opts[p]["tlsCertificateKeyFile"] = CLIENT_PEM - self.client_encryption_with_tls = ClientEncryption( + self.client_encryption_with_tls = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts ) - self.addCleanup(self.client_encryption_with_tls.close) # 3, update endpoints to expired host. 
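Raw-URI clients such as the mongocryptd probe above now go through self.simple_client, which by context constructs a plain MongoClient and registers its close; test_dns.py uses the class-level unmanaged_simple_client and closes the client itself. A sketch of both, with the mongocryptd URI taken from the hunk above and uri/copts standing in for the values built in test_dns's run_test:

    # Managed: suitable inside a test method.
    mongocryptd_client = self.simple_client(
        "mongodb://localhost:27027/?serverSelectionTimeoutMS=500"
    )
    with self.assertRaises(ServerSelectionTimeoutError):
        mongocryptd_client.admin.command("ping")

    # Unmanaged: usable outside a test instance, closed explicitly.
    client = PyMongoTestCase.unmanaged_simple_client(uri, **copts)
    try:
        client.admin.command("ping")
    finally:
        client.close()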
providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9000" providers["gcp"]["endpoint"] = "127.0.0.1:9000" providers["kmip"]["endpoint"] = "127.0.0.1:9000" - self.client_encryption_expired = ClientEncryption( + self.client_encryption_expired = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_expired.close) # 3, update endpoints to invalid host. providers: dict = copy.deepcopy(providers) providers["azure"]["identityPlatformEndpoint"] = "127.0.0.1:9001" providers["gcp"]["endpoint"] = "127.0.0.1:9001" providers["kmip"]["endpoint"] = "127.0.0.1:9001" - self.client_encryption_invalid_hostname = ClientEncryption( + self.client_encryption_invalid_hostname = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_ca_only ) - self.addCleanup(self.client_encryption_invalid_hostname.close) # Errors when client has no cert, some examples: # [SSL: TLSV13_ALERT_CERTIFICATE_REQUIRED] tlsv13 alert certificate required (_ssl.c:2623) self.cert_error = ( @@ -2138,7 +2132,7 @@ def setUp(self): "gcp:with_tls": with_cert, "kmip:with_tls": with_cert, } - self.client_encryption_with_names = ClientEncryption( + self.client_encryption_with_names = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=kms_tls_opts_4 ) @@ -2220,10 +2214,9 @@ def test_04_kmip(self): def test_05_tlsDisableOCSPEndpointCheck_is_permitted(self): providers = {"aws": {"accessKeyId": "foo", "secretAccessKey": "bar"}} options = {"aws": {"tlsDisableOCSPEndpointCheck": True}} - encryption = ClientEncryption( + encryption = self.create_client_encryption( providers, "keyvault.datakeys", self.client, OPTS, kms_tls_options=options ) - self.addCleanup(encryption.close) ctx = encryption._io_callbacks.opts._kms_ssl_contexts["aws"] if not hasattr(ctx, "check_ocsp_endpoint"): raise self.skipTest("OCSP not enabled") @@ -2273,7 +2266,7 @@ def setUp(self): self.client = client_context.client create_key_vault(self.client.keyvault.datakeys) kms_providers_map = {"local": {"key": LOCAL_MASTER_KEY}} - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", self.client, CodecOptions() ) self.def_key_id = self.client_encryption.create_data_key("local", key_alt_names=["def"]) @@ -2311,17 +2304,15 @@ def setUp(self): key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts) - self.addCleanup(self.encrypted_client.close) + self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) def test_01_insert_encrypted_indexed_and_find(self): val = "encrypted indexed value" @@ -2444,14 +2435,13 @@ def run_test(self, src_provider, dst_provider): self.client.keyvault.drop_collection("datakeys") # Step 2. 
Create a ``ClientEncryption`` object named ``client_encryption1`` - client_encryption1 = ClientEncryption( + client_encryption1 = self.create_client_encryption( key_vault_client=self.client, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addCleanup(client_encryption1.close) # Step 3. Call ``client_encryption1.create_data_key`` with ``src_provider``. key_id = client_encryption1.create_data_key( @@ -2464,16 +2454,14 @@ def run_test(self, src_provider, dst_provider): ) # Step 5. Create a ``ClientEncryption`` object named ``client_encryption2`` - client2 = rs_or_single_client() - self.addCleanup(client2.close) - client_encryption2 = ClientEncryption( + client2 = self.rs_or_single_client() + client_encryption2 = self.create_client_encryption( key_vault_client=client2, key_vault_namespace="keyvault.datakeys", kms_providers=ALL_KMS_PROVIDERS, kms_tls_options=KMS_TLS_OPTS, codec_options=OPTS, ) - self.addCleanup(client_encryption2.close) # Step 6. Call ``client_encryption2.rewrap_many_data_key`` with an empty ``filter``. rewrap_many_data_key_result = client_encryption2.rewrap_many_data_key( @@ -2508,7 +2496,7 @@ def setUp(self): @unittest.skipIf(any(AWS_CREDS.values()), "AWS environment credentials are set") def test_01_failure(self): - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -2519,7 +2507,7 @@ def test_01_failure(self): @unittest.skipUnless(any(AWS_CREDS.values()), "AWS environment credentials are not set") def test_02_success(self): - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( kms_providers={"aws": {}}, key_vault_namespace="keyvault.datakeys", key_vault_client=client_context.client, @@ -2539,8 +2527,7 @@ def test_queryable_encryption(self): # MongoClient to use in testing that handles auth/tls/etc, # and cleanup. def MongoClient(**kwargs): - c = rs_or_single_client(**kwargs) - self.addCleanup(c.close) + c = self.rs_or_single_client(**kwargs) return c # Drop data from prior test runs. @@ -2551,7 +2538,7 @@ def MongoClient(**kwargs): # Create two data keys. 
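Where documentation or spec example code constructs MongoClient directly, the patch keeps the copied example text intact by shadowing the name with a small local factory that delegates to the managed helper, as in the queryable-encryption prose test above and the transaction examples in test_examples.py. Condensed sketch:

    def MongoClient(**kwargs):
        # Shadow the real constructor inside the test method so the example
        # code reads unchanged while cleanup stays managed by the test case.
        return self.rs_or_single_client(**kwargs)

    key_vault_client = MongoClient()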
key_vault_client = MongoClient() - client_encryption = ClientEncryption( + client_encryption = self.create_client_encryption( kms_providers_map, "keyvault.datakeys", key_vault_client, CodecOptions() ) key1_id = client_encryption.create_data_key("local") @@ -2632,18 +2619,16 @@ def setUp(self): key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(key_vault.drop) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) opts = AutoEncryptionOpts( {"local": {"key": LOCAL_MASTER_KEY}}, key_vault.full_name, bypass_query_analysis=True, ) - self.encrypted_client = rs_or_single_client(auto_encryption_opts=opts) + self.encrypted_client = self.rs_or_single_client(auto_encryption_opts=opts) self.db = self.encrypted_client.db - self.addCleanup(self.encrypted_client.close) def run_expression_find( self, name, expression, expected_elems, range_opts, use_expr=False, key_id=None @@ -2838,10 +2823,9 @@ def setUp(self): super().setUp() self.client.drop_database(self.db) self.key_vault_client = self.client - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, "keyvault.datakeys", self.key_vault_client, OPTS ) - self.addCleanup(self.client_encryption.close) self.key_id = self.client_encryption.create_data_key("local") opts = RangeOpts(min=0, max=1000) self.payload_defaults = self.client_encryption.encrypt( @@ -2874,13 +2858,12 @@ def setUp(self): self.client.drop_database(self.db) self.key_vault = create_key_vault(self.client.keyvault.datakeys, self.key1_document) self.addCleanup(self.key_vault.drop) - self.client_encryption = ClientEncryption( + self.client_encryption = self.create_client_encryption( {"local": {"key": LOCAL_MASTER_KEY}}, self.key_vault.full_name, self.client, OPTS, ) - self.addCleanup(self.client_encryption.close) def test_01_simple_create(self): coll, _ = self.client_encryption.create_encrypted_collection( @@ -3096,10 +3079,9 @@ def _tearDown_class(cls): def setUp(self) -> None: self.listener = OvertCommandListener() - self.mongocryptd_client = MongoClient( + self.mongocryptd_client = self.simple_client( f"mongodb://localhost:{self.MONGOCRYPTD_PORT}", event_listeners=[self.listener] ) - self.addCleanup(self.mongocryptd_client.close) hello = self.mongocryptd_client.db.command("hello") self.assertNotIn("logicalSessionTimeoutMinutes", hello) diff --git a/test/test_examples.py b/test/test_examples.py index 296283db28..ebf1d784a3 100644 --- a/test/test_examples.py +++ b/test/test_examples.py @@ -22,7 +22,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import rs_client, wait_until +from test.utils import wait_until import pymongo from pymongo.errors import ConnectionFailure, OperationFailure @@ -1128,7 +1128,7 @@ def update_employee_info(session): self.assertEqual(employee["status"], "Inactive") def MongoClient(_): - return rs_client() + return self.rs_client() uriString = None @@ -1220,7 +1220,7 @@ class TestVersionedApiExamples(IntegrationTest): def test_versioned_api(self): # Versioned API examples def MongoClient(_, server_api): - return rs_client(server_api=server_api, connect=False) + return self.rs_client(server_api=server_api, connect=False) uri = None @@ -1251,7 +1251,7 @@ def 
test_versioned_api_migration(self): ): self.skipTest("This test needs MongoDB 5.0.2 or newer") - client = rs_client(server_api=ServerApi("1", strict=True)) + client = self.rs_client(server_api=ServerApi("1", strict=True)) client.db.sales.drop() # Start Versioned API Example 5 diff --git a/test/test_grid_file.py b/test/test_grid_file.py index bd89235b73..fe88aec5ff 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -33,7 +33,7 @@ sys.path[0:0] = [""] -from test.utils import EventListener, rs_or_single_client +from test.utils import EventListener from bson.objectid import ObjectId from gridfs.errors import NoFile @@ -790,7 +790,7 @@ def test_grid_out_lazy_connect(self): outfile.readchunk() def test_grid_in_lazy_connect(self): - client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10) fs = client.db.fs infile = GridIn(fs, file_id=-1, chunk_size=1) with self.assertRaises(ServerSelectionTimeoutError): @@ -801,7 +801,7 @@ def test_grid_in_lazy_connect(self): def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - GridIn((rs_or_single_client(w=0)).pymongo_test.fs) + GridIn((self.rs_or_single_client(w=0)).pymongo_test.fs) def test_survive_cursor_not_found(self): # By default the find command returns 101 documents in the first batch. @@ -809,7 +809,7 @@ def test_survive_cursor_not_found(self): chunk_size = 1024 data = b"d" * (102 * chunk_size) listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) db = client.pymongo_test with GridIn(db.fs, chunk_size=chunk_size) as infile: infile.write(data) diff --git a/test/test_gridfs.py b/test/test_gridfs.py index 19ec152bd1..549dc0b204 100644 --- a/test/test_gridfs.py +++ b/test/test_gridfs.py @@ -26,7 +26,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one, rs_client, rs_or_single_client, single_client +from test.utils import joinall, one import gridfs from bson.binary import Binary @@ -411,7 +411,7 @@ def iterate_file(grid_file): self.assertTrue(iterate_file(f)) def test_gridfs_lazy_connect(self): - client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=10) + client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=10) db = client.db gfs = gridfs.GridFS(db) self.assertRaises(ServerSelectionTimeoutError, gfs.list) @@ -492,7 +492,7 @@ def test_grid_in_non_int_chunksize(self): def test_unacknowledged(self): # w=0 is prohibited. 
with self.assertRaises(ConfigurationError): - gridfs.GridFS(rs_or_single_client(w=0).pymongo_test) + gridfs.GridFS(self.rs_or_single_client(w=0).pymongo_test) def test_md5(self): gin = self.fs.new_file() @@ -519,7 +519,7 @@ def tearDownClass(cls): client_context.client.drop_database("gfsreplica") def test_gridfs_replica_set(self): - rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) + rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) fs = gridfs.GridFS(rsc.gfsreplica, "gfsreplicatest") @@ -532,7 +532,7 @@ def test_gridfs_replica_set(self): def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) - secondary_connection = single_client( + secondary_connection = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY ) @@ -547,7 +547,7 @@ def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. secondary_host, secondary_port = one(self.client.secondaries) - client = single_client( + client = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False ) diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py index c3945d1053..28adb7051a 100644 --- a/test/test_gridfs_bucket.py +++ b/test/test_gridfs_bucket.py @@ -27,7 +27,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import joinall, one, rs_client, rs_or_single_client, single_client +from test.utils import joinall, one import gridfs from bson.binary import Binary @@ -345,7 +345,7 @@ def iterate_file(grid_file): self.assertTrue(iterate_file(fstr)) def test_gridfs_lazy_connect(self): - client = MongoClient("badhost", connect=False, serverSelectionTimeoutMS=0) + client = self.single_client("badhost", connect=False, serverSelectionTimeoutMS=0) cdb = client.db gfs = gridfs.GridFSBucket(cdb) self.assertRaises(ServerSelectionTimeoutError, gfs.delete, 0) @@ -391,7 +391,7 @@ def test_grid_in_non_int_chunksize(self): def test_unacknowledged(self): # w=0 is prohibited. with self.assertRaises(ConfigurationError): - gridfs.GridFSBucket(rs_or_single_client(w=0).pymongo_test) + gridfs.GridFSBucket(self.rs_or_single_client(w=0).pymongo_test) def test_rename(self): _id = self.fs.upload_from_stream("first_name", b"testing") @@ -489,7 +489,7 @@ def tearDownClass(cls): client_context.client.drop_database("gfsbucketreplica") def test_gridfs_replica_set(self): - rsc = rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) + rsc = self.rs_client(w=client_context.w, read_preference=ReadPreference.SECONDARY) gfs = gridfs.GridFSBucket(rsc.gfsbucketreplica, "gfsbucketreplicatest") oid = gfs.upload_from_stream("test_filename", b"foo") @@ -498,7 +498,7 @@ def test_gridfs_replica_set(self): def test_gridfs_secondary(self): secondary_host, secondary_port = one(self.client.secondaries) - secondary_connection = single_client( + secondary_connection = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY ) @@ -513,7 +513,7 @@ def test_gridfs_secondary_lazy(self): # Should detect it's connected to secondary and not attempt to # create index. 
secondary_host, secondary_port = one(self.client.secondaries) - client = single_client( + client = self.single_client( secondary_host, secondary_port, read_preference=ReadPreference.SECONDARY, connect=False ) diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py index 1302df8fde..5e203a33b3 100644 --- a/test/test_heartbeat_monitoring.py +++ b/test/test_heartbeat_monitoring.py @@ -20,7 +20,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_knobs, unittest -from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until +from test.utils import HeartbeatEventListener, MockPool, wait_until from pymongo.errors import ConnectionFailure from pymongo.hello import Hello, HelloCompat @@ -40,7 +40,7 @@ def _check_with_socket(self, *args, **kwargs): raise responses[1] return Hello(responses[1]), 99 - m = single_client( + m = self.single_client( h=uri, event_listeners=(listener,), _monitor_class=MockMonitor, _pool_class=MockPool ) diff --git a/test/test_load_balancer.py b/test/test_load_balancer.py index a4db7395f1..23bea4d984 100644 --- a/test/test_load_balancer.py +++ b/test/test_load_balancer.py @@ -26,7 +26,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ExceptionCatchingThread, get_pool, rs_client, wait_until +from test.utils import ExceptionCatchingThread, get_pool, wait_until pytestmark = pytest.mark.load_balancer @@ -54,7 +54,7 @@ def test_connections_are_only_returned_once(self): @client_context.require_load_balancer def test_unpin_committed_transaction(self): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test @@ -85,7 +85,7 @@ def create_resource(coll): self._test_no_gc_deadlock(create_resource) def _test_no_gc_deadlock(self, create_resource): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) coll = client[self.db.name].test @@ -124,7 +124,7 @@ def _test_no_gc_deadlock(self, create_resource): @client_context.require_transactions def test_session_gc(self): - client = rs_client() + client = self.rs_client() self.addCleanup(client.close) pool = get_pool(client) session = client.start_session() diff --git a/test/test_logger.py b/test/test_logger.py index c0011ec3a5..b3c8e6d176 100644 --- a/test/test_logger.py +++ b/test/test_logger.py @@ -15,7 +15,6 @@ import os from test import IntegrationTest, unittest -from test.utils import single_client from unittest.mock import patch from bson import json_util @@ -85,7 +84,7 @@ def test_truncation_multi_byte_codepoints(self): self.assertEqual(last_3_bytes, str_to_repeat) def test_logging_without_listeners(self): - c = single_client() + c = self.single_client() self.assertEqual(len(c._event_listeners.event_listeners()), 0) with self.assertLogs("pymongo.connection", level="DEBUG") as cm: c.db.test.insert_one({"x": "1"}) diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py index 1b0130f7d8..32d09ada9a 100644 --- a/test/test_max_staleness.py +++ b/test/test_max_staleness.py @@ -20,15 +20,14 @@ import time import warnings +from pymongo import MongoClient from pymongo.operations import _Op sys.path[0:0] = [""] -from test import client_context, unittest -from test.utils import rs_or_single_client +from test import PyMongoTestCase, client_context, unittest from test.utils_selection_tests import create_selection_tests -from pymongo import MongoClient from 
pymongo.errors import ConfigurationError from pymongo.server_selectors import writable_server_selector @@ -40,54 +39,58 @@ class TestAllScenarios(create_selection_tests(_TEST_PATH)): # type: ignore pass -class TestMaxStaleness(unittest.TestCase): +class TestMaxStaleness(PyMongoTestCase): def test_max_staleness(self): - client = MongoClient() + client = self.simple_client() self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary") + client = self.simple_client("mongodb://a/?readPreference=secondary") self.assertEqual(-1, client.read_preference.max_staleness) # These tests are specified in max-staleness-tests.rst. with self.assertRaises(ConfigurationError): # Default read pref "primary" can't be used with max staleness. - MongoClient("mongodb://a/?maxStalenessSeconds=120") + self.simple_client("mongodb://a/?maxStalenessSeconds=120") with self.assertRaises(ConfigurationError): # Read pref "primary" can't be used with max staleness. - MongoClient("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") + self.simple_client("mongodb://a/?readPreference=primary&maxStalenessSeconds=120") - client = MongoClient("mongodb://host/?maxStalenessSeconds=-1") + client = self.simple_client("mongodb://host/?maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") + client = self.simple_client("mongodb://host/?readPreference=primary&maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient("mongodb://host/?readPreference=secondary&maxStalenessSeconds=120") + client = self.simple_client( + "mongodb://host/?readPreference=secondary&maxStalenessSeconds=120" + ) self.assertEqual(120, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=1") self.assertEqual(1, client.read_preference.max_staleness) - client = MongoClient("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") + client = self.simple_client("mongodb://a/?readPreference=secondary&maxStalenessSeconds=-1") self.assertEqual(-1, client.read_preference.max_staleness) - client = MongoClient(maxStalenessSeconds=-1, readPreference="nearest") + client = self.simple_client(maxStalenessSeconds=-1, readPreference="nearest") self.assertEqual(-1, client.read_preference.max_staleness) with self.assertRaises(TypeError): # Prohibit None. - MongoClient(maxStalenessSeconds=None, readPreference="nearest") + self.simple_client(maxStalenessSeconds=None, readPreference="nearest") def test_max_staleness_float(self): with self.assertRaises(TypeError) as ctx: - rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") + self.rs_or_single_client(maxStalenessSeconds=1.5, readPreference="nearest") self.assertIn("must be an integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = MongoClient("mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest") + client = self.simple_client( + "mongodb://host/?maxStalenessSeconds=1.5&readPreference=nearest" + ) # Option was ignored. self.assertEqual(-1, client.read_preference.max_staleness) @@ -96,13 +99,15 @@ def test_max_staleness_float(self): def test_max_staleness_zero(self): # Zero is too small. 
with self.assertRaises(ValueError) as ctx: - rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") + self.rs_or_single_client(maxStalenessSeconds=0, readPreference="nearest") self.assertIn("must be a positive integer", str(ctx.exception)) with warnings.catch_warnings(record=True) as ctx: warnings.simplefilter("always") - client = MongoClient("mongodb://host/?maxStalenessSeconds=0&readPreference=nearest") + client = self.simple_client( + "mongodb://host/?maxStalenessSeconds=0&readPreference=nearest" + ) # Option was ignored. self.assertEqual(-1, client.read_preference.max_staleness) @@ -111,7 +116,7 @@ def test_max_staleness_zero(self): @client_context.require_replica_set def test_last_write_date(self): # From max-staleness-tests.rst, "Parse lastWriteDate". - client = rs_or_single_client(heartbeatFrequencyMS=500) + client = self.rs_or_single_client(heartbeatFrequencyMS=500) client.pymongo_test.test.insert_one({}) # Wait for the server description to be updated. time.sleep(1) diff --git a/test/test_monitor.py b/test/test_monitor.py index fd82fc1ca4..f8e9443fae 100644 --- a/test/test_monitor.py +++ b/test/test_monitor.py @@ -18,6 +18,7 @@ import gc import subprocess import sys +import warnings from functools import partial sys.path[0:0] = [""] @@ -25,7 +26,6 @@ from test import IntegrationTest, connected, unittest from test.utils import ( ServerAndTopologyEventListener, - single_client, wait_until, ) @@ -47,30 +47,31 @@ def get_executors(client): return [e for e in executors if e is not None] -def create_client(): - listener = ServerAndTopologyEventListener() - client = single_client(event_listeners=[listener]) - connected(client) - return client - - class TestMonitor(IntegrationTest): + def create_client(self): + listener = ServerAndTopologyEventListener() + client = self.unmanaged_single_client(event_listeners=[listener]) + connected(client) + return client + def test_cleanup_executors_on_client_del(self): - client = create_client() - executors = get_executors(client) - self.assertEqual(len(executors), 4) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + client = self.create_client() + executors = get_executors(client) + self.assertEqual(len(executors), 4) - # Each executor stores a weakref to itself in _EXECUTORS. - executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] + # Each executor stores a weakref to itself in _EXECUTORS. 
+ executor_refs = [(r, r()._name) for r in _EXECUTORS.copy() if r() in executors] - del executors - del client + del executors + del client - for ref, name in executor_refs: - wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) + for ref, name in executor_refs: + wait_until(partial(unregistered, ref), f"unregister executor: {name}", timeout=5) def test_cleanup_executors_on_client_close(self): - client = create_client() + client = self.create_client() executors = get_executors(client) self.assertEqual(len(executors), 4) diff --git a/test/test_monitoring.py b/test/test_monitoring.py index 8322e29918..a0c520ed27 100644 --- a/test/test_monitoring.py +++ b/test/test_monitoring.py @@ -31,8 +31,6 @@ ) from test.utils import ( EventListener, - rs_or_single_client, - single_client, wait_until, ) @@ -57,7 +55,9 @@ class TestCommandMonitoring(IntegrationTest): def _setup_class(cls): super()._setup_class() cls.listener = EventListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener], retryWrites=False) + cls.client = cls.unmanaged_rs_or_single_client( + event_listeners=[cls.listener], retryWrites=False + ) @classmethod def _tearDown_class(cls): @@ -405,7 +405,7 @@ def test_get_more_failure(self): @client_context.require_secondaries_count(1) def test_not_primary_error(self): address = next(iter(client_context.client.secondaries)) - client = single_client(*address, event_listeners=[self.listener]) + client = self.single_client(*address, event_listeners=[self.listener]) # Clear authentication command results from the listener. client.admin.command("ping") self.listener.reset() @@ -1144,7 +1144,7 @@ def _setup_class(cls): # We plan to call register(), which internally modifies _LISTENERS. cls.saved_listeners = copy.deepcopy(monitoring._LISTENERS) monitoring.register(cls.listener) - cls.client = single_client() + cls.client = cls.unmanaged_single_client() # Get one (authenticated) socket in the pool. cls.client.pymongo_test.command("ping") diff --git a/test/test_pooling.py b/test/test_pooling.py index 31259d7b3a..3b867965bd 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -31,7 +31,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context, unittest -from test.utils import delay, get_pool, joinall, rs_or_single_client +from test.utils import delay, get_pool, joinall from pymongo.socket_checker import SocketChecker from pymongo.synchronous.pool import Pool, PoolOptions @@ -151,7 +151,7 @@ class _TestPoolingBase(IntegrationTest): def setUp(self): super().setUp() - self.c = rs_or_single_client() + self.c = self.rs_or_single_client() db = self.c[DB] db.unique.drop() db.test.drop() @@ -378,7 +378,7 @@ def test_checkout_more_than_max_pool_size(self): socket_info.close_conn(None) def test_maxConnecting(self): - client = rs_or_single_client() + client = self.rs_or_single_client() self.addCleanup(client.close) self.client.test.test.insert_one({}) self.addCleanup(self.client.test.test.delete_many, {}) @@ -415,7 +415,7 @@ def find_one(): @client_context.require_failCommand_appName def test_csot_timeout_message(self): - client = rs_or_single_client(appName="connectionTimeoutApp") + client = self.rs_or_single_client(appName="connectionTimeoutApp") self.addCleanup(client.close) # Mock an operation failing due to pymongo.timeout(). 
mock_connection_timeout = { @@ -440,7 +440,7 @@ def test_csot_timeout_message(self): @client_context.require_failCommand_appName def test_socket_timeout_message(self): - client = rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") + client = self.rs_or_single_client(socketTimeoutMS=500, appName="connectionTimeoutApp") self.addCleanup(client.close) # Mock an operation failing due to socketTimeoutMS. mock_connection_timeout = { @@ -479,7 +479,7 @@ def test_connection_timeout_message(self): }, } - client = rs_or_single_client( + client = self.rs_or_single_client( connectTimeoutMS=500, socketTimeoutMS=500, appName="connectionTimeoutApp", @@ -502,7 +502,7 @@ def test_connection_timeout_message(self): class TestPoolMaxSize(_TestPoolingBase): def test_max_pool_size(self): max_pool_size = 4 - c = rs_or_single_client(maxPoolSize=max_pool_size) + c = self.rs_or_single_client(maxPoolSize=max_pool_size) self.addCleanup(c.close) collection = c[DB].test @@ -538,7 +538,7 @@ def f(): self.assertEqual(0, cx_pool.requests) def test_max_pool_size_none(self): - c = rs_or_single_client(maxPoolSize=None) + c = self.rs_or_single_client(maxPoolSize=None) self.addCleanup(c.close) collection = c[DB].test @@ -570,7 +570,7 @@ def f(): self.assertEqual(cx_pool.max_pool_size, float("inf")) def test_max_pool_size_zero(self): - c = rs_or_single_client(maxPoolSize=0) + c = self.rs_or_single_client(maxPoolSize=0) self.addCleanup(c.close) pool = get_pool(c) self.assertEqual(pool.max_pool_size, float("inf")) diff --git a/test/test_read_concern.py b/test/test_read_concern.py index 97855872cf..ea9ce49a30 100644 --- a/test/test_read_concern.py +++ b/test/test_read_concern.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] from test import IntegrationTest, client_context -from test.utils import OvertCommandListener, rs_or_single_client +from test.utils import OvertCommandListener from bson.son import SON from pymongo.errors import OperationFailure @@ -36,7 +36,7 @@ class TestReadConcern(IntegrationTest): def setUpClass(cls): super().setUpClass() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) cls.db = cls.client.pymongo_test client_context.client.pymongo_test.create_collection("coll") @@ -67,7 +67,7 @@ def test_read_concern(self): def test_read_concern_uri(self): uri = f"mongodb://{client_context.pair}/?readConcernLevel=majority" - client = rs_or_single_client(uri, connect=False) + client = self.rs_or_single_client(uri, connect=False) self.assertEqual(ReadConcern("majority"), client.read_concern) def test_invalid_read_concern(self): diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 2cd3195f40..32883399e1 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -30,8 +30,6 @@ from test.utils import ( OvertCommandListener, one, - rs_client, - single_client, wait_until, ) from test.version import Version @@ -58,7 +56,7 @@ class TestSelections(IntegrationTest): @client_context.require_connection def test_bool(self): - client = single_client() + client = self.single_client() wait_until(lambda: client.address, "discover primary") selection = Selection.from_topology_description(client._topology.description) @@ -128,7 +126,7 @@ def read_from_which_kind(self, client): return None def assertReadsFrom(self, expected, **kwargs): - c = rs_client(**kwargs) + c = self.rs_client(**kwargs) wait_until(lambda: len(c.nodes - c.arbiters) == 
client_context.w, "discovered all nodes") used = self.read_from_which_kind(c) @@ -139,7 +137,7 @@ class TestSingleSecondaryOk(TestReadPreferencesBase): def test_reads_from_secondary(self): host, port = next(iter(self.client.secondaries)) # Direct connection to a secondary. - client = single_client(host, port) + client = self.single_client(host, port) self.assertFalse(client.is_primary) # Regardless of read preference, we should be able to do @@ -175,19 +173,21 @@ def test_mode_validation(self): ReadPreference.SECONDARY_PREFERRED, ReadPreference.NEAREST, ): - self.assertEqual(mode, rs_client(read_preference=mode).read_preference) + self.assertEqual(mode, self.rs_client(read_preference=mode).read_preference) - self.assertRaises(TypeError, rs_client, read_preference="foo") + self.assertRaises(TypeError, self.rs_client, read_preference="foo") def test_tag_sets_validation(self): S = Secondary(tag_sets=[{}]) - self.assertEqual([{}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{}], self.rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}]) - self.assertEqual([{"k": "v"}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual([{"k": "v"}], self.rs_client(read_preference=S).read_preference.tag_sets) S = Secondary(tag_sets=[{"k": "v"}, {}]) - self.assertEqual([{"k": "v"}, {}], rs_client(read_preference=S).read_preference.tag_sets) + self.assertEqual( + [{"k": "v"}, {}], self.rs_client(read_preference=S).read_preference.tag_sets + ) self.assertRaises(ValueError, Secondary, tag_sets=[]) @@ -200,20 +200,22 @@ def test_tag_sets_validation(self): def test_threshold_validation(self): self.assertEqual( - 17, rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms + 17, self.rs_client(localThresholdMS=17, connect=False).options.local_threshold_ms ) self.assertEqual( - 42, rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms + 42, self.rs_client(localThresholdMS=42, connect=False).options.local_threshold_ms ) self.assertEqual( - 666, rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms + 666, self.rs_client(localThresholdMS=666, connect=False).options.local_threshold_ms ) - self.assertEqual(0, rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms) + self.assertEqual( + 0, self.rs_client(localThresholdMS=0, connect=False).options.local_threshold_ms + ) - self.assertRaises(ValueError, rs_client, localthresholdms=-1) + self.assertRaises(ValueError, self.rs_client, localthresholdms=-1) def test_zero_latency(self): ping_times: set = set() @@ -223,7 +225,7 @@ def test_zero_latency(self): for ping_time, host in zip(ping_times, self.client.nodes): ServerDescription._host_to_round_trip_time[host] = ping_time try: - client = connected(rs_client(readPreference="nearest", localThresholdMS=0)) + client = connected(self.rs_client(readPreference="nearest", localThresholdMS=0)) wait_until(lambda: client.nodes == self.client.nodes, "discovered all nodes") host = self.read_from_which_host(client) for _ in range(5): @@ -236,7 +238,7 @@ def test_primary(self): def test_primary_with_tags(self): # Tags not allowed with PRIMARY - self.assertRaises(ConfigurationError, rs_client, tag_sets=[{"dc": "ny"}]) + self.assertRaises(ConfigurationError, self.rs_client, tag_sets=[{"dc": "ny"}]) def test_primary_preferred(self): self.assertReadsFrom("primary", read_preference=ReadPreference.PRIMARY_PREFERRED) @@ -250,7 +252,9 @@ def test_secondary_preferred(self): def 
test_nearest(self): # With high localThresholdMS, expect to read from any # member - c = rs_client(read_preference=ReadPreference.NEAREST, localThresholdMS=10000) # 10 seconds + c = self.rs_client( + read_preference=ReadPreference.NEAREST, localThresholdMS=10000 + ) # 10 seconds data_members = {self.client.primary} | self.client.secondaries @@ -540,7 +544,7 @@ def test_send_hedge(self): if client_context.supports_secondary_read_pref: cases["secondary"] = Secondary listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) self.addCleanup(client.close) client.admin.command("ping") for _mode, cls in cases.items(): @@ -667,13 +671,13 @@ def test_mongos_max_staleness(self): else: self.fail("mongos accepted invalid staleness") - coll = single_client( + coll = self.single_client( readPreference="secondaryPreferred", maxStalenessSeconds=120 ).pymongo_test.test # No error coll.find_one() - coll = single_client( + coll = self.single_client( readPreference="secondaryPreferred", maxStalenessSeconds=10 ).pymongo_test.test try: diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py index 3e37e8f9a5..67943d495d 100644 --- a/test/test_read_write_concern_spec.py +++ b/test/test_read_write_concern_spec.py @@ -24,12 +24,7 @@ from test import IntegrationTest, client_context, unittest from test.unified_format import generate_test_classes -from test.utils import ( - EventListener, - disable_replication, - enable_replication, - rs_or_single_client, -) +from test.utils import EventListener from pymongo import DESCENDING from pymongo.errors import ( @@ -51,7 +46,7 @@ class TestReadWriteConcernSpec(IntegrationTest): def test_omit_default_read_write_concern(self): listener = EventListener() # Client with default readConcern and writeConcern - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). @@ -104,7 +99,9 @@ def insert_command_default_write_concern(): def assertWriteOpsRaise(self, write_concern, expected_exception): wc = write_concern.document # Set socket timeout to avoid indefinite stalls - client = rs_or_single_client(w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000) + client = self.rs_or_single_client( + w=wc["w"], wTimeoutMS=wc["wtimeout"], socketTimeoutMS=30000 + ) db = client.get_database("pymongo_test") coll = db.test @@ -167,9 +164,9 @@ def test_raise_write_concern_error(self): @client_context.require_test_commands def test_raise_wtimeout(self): self.addCleanup(client_context.client.drop_database, "pymongo_test") - self.addCleanup(enable_replication, client_context.client) + self.addCleanup(self.enable_replication, client_context.client) # Disable replication to guarantee a wtimeout error. 
- disable_replication(client_context.client) + self.disable_replication(client_context.client) self.assertWriteOpsRaise(WriteConcern(w=client_context.w, wtimeout=1), WTimeoutError) @client_context.require_failCommand_fail_point @@ -209,7 +206,7 @@ def test_error_includes_errInfo(self): @client_context.require_version_min(4, 9) def test_write_error_details_exposes_errinfo(self): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) self.addCleanup(client.close) db = client.errinfotest self.addCleanup(client.drop_database, "errinfotest") diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py index b0fa42a0c9..a1c72bb7b6 100644 --- a/test/test_retryable_reads.py +++ b/test/test_retryable_reads.py @@ -34,7 +34,6 @@ from test.utils import ( CMAPListener, OvertCommandListener, - rs_or_single_client, set_fail_point, ) @@ -93,7 +92,9 @@ def test_pool_paused_error_is_retryable(self): self.skipTest("Test is flakey on PyPy") cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) + client = self.rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() @@ -163,13 +164,13 @@ def test_retryable_reads_in_sharded_cluster_multiple_available(self): mongos_clients = [] for mongos in client_context.mongos_seeds().split(","): - client = rs_or_single_client(mongos) + client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( client_context.mongos_seeds(), appName="retryableReadTest", event_listeners=[listener], diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py index 2938b7efaf..89454ad236 100644 --- a/test/test_retryable_writes.py +++ b/test/test_retryable_writes.py @@ -28,7 +28,6 @@ DeprecationFilter, EventListener, OvertCommandListener, - rs_or_single_client, set_fail_point, ) from test.version import Version @@ -145,7 +144,7 @@ def setUpClass(cls): # Speed up the tests by decreasing the heartbeat frequency. 
cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() - cls.client = rs_or_single_client(retryWrites=True) + cls.client = cls.unmanaged_rs_or_single_client(retryWrites=True) cls.db = cls.client.pymongo_test @classmethod @@ -181,7 +180,9 @@ def setUpClass(cls): cls.knobs = client_knobs(heartbeat_frequency=0.1, min_heartbeat_interval=0.1) cls.knobs.enable() cls.listener = OvertCommandListener() - cls.client = rs_or_single_client(retryWrites=True, event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client( + retryWrites=True, event_listeners=[cls.listener] + ) cls.db = cls.client.pymongo_test @classmethod @@ -204,7 +205,7 @@ def tearDown(self): def test_supported_single_statement_no_retry(self): listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=False, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=False, event_listeners=[listener]) self.addCleanup(client.close) for method, args, kwargs in retryable_single_statement_ops(client.db.retryable_write_test): msg = f"{method.__name__}(*{args!r}, **{kwargs!r})" @@ -297,7 +298,7 @@ def test_unsupported_single_statement(self): def test_server_selection_timeout_not_retried(self): """A ServerSelectionTimeoutError is not retried.""" listener = OvertCommandListener() - client = MongoClient( + client = self.simple_client( "somedomainthatdoesntexist.org", serverSelectionTimeoutMS=1, retryWrites=True, @@ -317,7 +318,7 @@ def test_retry_timeout_raises_original_error(self): original error. """ listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server @@ -443,13 +444,13 @@ def test_retryable_writes_in_sharded_cluster_multiple_available(self): mongos_clients = [] for mongos in client_context.mongos_seeds().split(","): - client = rs_or_single_client(mongos) + client = self.rs_or_single_client(mongos) set_fail_point(client, fail_command) self.addCleanup(client.close) mongos_clients.append(client) listener = OvertCommandListener() - client = rs_or_single_client( + client = self.rs_or_single_client( client_context.mongos_seeds(), appName="retryableWriteTest", event_listeners=[listener], @@ -492,7 +493,7 @@ def setUpClass(cls): @client_knobs(heartbeat_frequency=0.05, min_heartbeat_interval=0.05) def test_RetryableWriteError_error_label(self): listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) # Ensure collection exists. 
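The retryable-write hunks above follow the pattern used throughout this patch: per-test clients switch from module-level factories plus an explicit self.addCleanup(client.close) to instance helpers that register the cleanup themselves, while class-scoped fixtures switch to unmanaged_* classmethods whose callers close them explicitly. A minimal sketch of that managed/unmanaged split, assuming a unittest.TestCase base; the helper names match the call sites above but the bodies are illustrative, not pymongo's actual test code:

import unittest

from pymongo import MongoClient


class PyMongoTestCase(unittest.TestCase):
    @classmethod
    def unmanaged_rs_or_single_client(cls, *args, **kwargs) -> MongoClient:
        # Class-scoped client: no cleanup is registered, so the caller
        # (typically the class teardown) must close it explicitly.
        return MongoClient(*args, **kwargs)

    def rs_or_single_client(self, *args, **kwargs) -> MongoClient:
        # Per-test client: close() is registered with addCleanup, which is
        # why the explicit self.addCleanup(client.close) lines disappear.
        client = self.unmanaged_rs_or_single_client(*args, **kwargs)
        self.addCleanup(client.close)
        return client


class TestManagedClientPattern(PyMongoTestCase):
    def test_retry_writes_option(self):
        # connect=False keeps the sketch free of any network I/O.
        client = self.rs_or_single_client(retryWrites=True, connect=False)
        self.assertTrue(client.options.retry_writes)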
@@ -551,7 +552,9 @@ class TestPoolPausedError(IntegrationTest): def test_pool_paused_error_is_retryable(self): cmap_listener = CMAPListener() cmd_listener = OvertCommandListener() - client = rs_or_single_client(maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener]) + client = self.rs_or_single_client( + maxPoolSize=1, event_listeners=[cmap_listener, cmd_listener] + ) self.addCleanup(client.close) for _ in range(10): cmap_listener.reset() @@ -613,7 +616,7 @@ def test_returns_original_error_code( self, ): cmd_listener = InsertEventListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[cmd_listener]) client.test.test.drop() self.addCleanup(client.close) cmd_listener.reset() @@ -650,7 +653,7 @@ def test_increment_transaction_id_without_sending_command(self): the first attempt fails before sending the command. """ listener = OvertCommandListener() - client = rs_or_single_client(retryWrites=True, event_listeners=[listener]) + client = self.rs_or_single_client(retryWrites=True, event_listeners=[listener]) self.addCleanup(client.close) topology = client._topology select_server = topology.select_server diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py index 8e0a3cbbb4..81b208d511 100644 --- a/test/test_sdam_monitoring_spec.py +++ b/test/test_sdam_monitoring_spec.py @@ -25,7 +25,6 @@ from test import IntegrationTest, client_context, client_knobs, unittest from test.utils import ( ServerAndTopologyEventListener, - rs_or_single_client, server_name_to_type, wait_until, ) @@ -279,7 +278,7 @@ def setUpClass(cls): cls.knobs.enable() cls.listener = ServerAndTopologyEventListener() retry_writes = client_context.supports_transactions() - cls.test_client = rs_or_single_client( + cls.test_client = cls.unmanaged_rs_or_single_client( event_listeners=[cls.listener], retryWrites=retry_writes ) cls.coll = cls.test_client[cls.client.db.name].test diff --git a/test/test_server_selection.py b/test/test_server_selection.py index d3526617f6..67e9716bf4 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -33,7 +33,6 @@ from test.utils import ( EventListener, FunctionCallRecorder, - rs_or_single_client, wait_until, ) from test.utils_selection_tests import ( @@ -76,7 +75,9 @@ def custom_selector(servers): # Initialize client with appropriate listeners. listener = EventListener() - client = rs_or_single_client(server_selector=custom_selector, event_listeners=[listener]) + client = self.rs_or_single_client( + server_selector=custom_selector, event_listeners=[listener] + ) self.addCleanup(client.close) coll = client.get_database("testdb", read_preference=ReadPreference.NEAREST).coll self.addCleanup(client.drop_database, "testdb") @@ -117,7 +118,7 @@ def test_selector_called(self): selector = FunctionCallRecorder(lambda x: x) # Client setup. 
- mongo_client = rs_or_single_client(server_selector=selector) + mongo_client = self.rs_or_single_client(server_selector=selector) test_collection = mongo_client.testdb.test_collection self.addCleanup(mongo_client.close) self.addCleanup(mongo_client.drop_database, "testdb") diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py index 9dced595c9..8e030f61e8 100644 --- a/test/test_server_selection_in_window.py +++ b/test/test_server_selection_in_window.py @@ -22,7 +22,6 @@ OvertCommandListener, SpecTestCreator, get_pool, - rs_client, wait_until, ) from test.utils_selection_tests import create_topology @@ -134,7 +133,7 @@ def test_load_balancing(self): listener = OvertCommandListener() # PYTHON-2584: Use a large localThresholdMS to avoid the impact of # varying RTTs. - client = rs_client( + client = self.rs_client( client_context.mongos_seeds(), appName="loadBalancingTest", event_listeners=[listener], diff --git a/test/test_session.py b/test/test_session.py index 563b33c70e..9f94ded927 100644 --- a/test/test_session.py +++ b/test/test_session.py @@ -36,7 +36,6 @@ from test.utils import ( EventListener, ExceptionCatchingThread, - rs_or_single_client, wait_until, ) @@ -88,7 +87,7 @@ def _setup_class(cls): super()._setup_class() # Create a second client so we can make sure clients cannot share # sessions. - cls.client2 = rs_or_single_client() + cls.client2 = cls.unmanaged_rs_or_single_client() # Redact no commands, so we can test user-admin commands have "lsid". cls.sensitive_commands = monitoring._SENSITIVE_COMMANDS.copy() @@ -103,7 +102,7 @@ def _tearDown_class(cls): def setUp(self): self.listener = SessionTestListener() self.session_checker_listener = SessionTestListener() - self.client = rs_or_single_client( + self.client = self.rs_or_single_client( event_listeners=[self.listener, self.session_checker_listener] ) self.addCleanup(self.client.close) @@ -200,7 +199,7 @@ def test_implicit_sessions_checkout(self): failures = 0 for _ in range(5): listener = EventListener() - client = rs_or_single_client(event_listeners=[listener], maxPoolSize=1) + client = self.rs_or_single_client(event_listeners=[listener], maxPoolSize=1) cursor = client.db.test.find({}) ops: List[Tuple[Callable, List[Any]]] = [ (client.db.test.find_one, [{"_id": 1}]), @@ -283,7 +282,7 @@ def test_end_session(self): def test_end_sessions(self): # Use a new client so that the tearDown hook does not error. listener = SessionTestListener() - client = rs_or_single_client(event_listeners=[listener]) + client = self.rs_or_single_client(event_listeners=[listener]) # Start many sessions. sessions = [client.start_session() for _ in range(_MAX_END_SESSIONS + 1)] for s in sessions: @@ -787,8 +786,7 @@ def _test_unacknowledged_ops(self, client, *ops): def test_unacknowledged_writes(self): # Ensure the collection exists. 
self.client.pymongo_test.test_unacked_writes.insert_one({}) - client = rs_or_single_client(w=0, event_listeners=[self.listener]) - self.addCleanup(client.close) + client = self.rs_or_single_client(w=0, event_listeners=[self.listener]) db = client.pymongo_test coll = db.test_unacked_writes ops: list = [ @@ -836,7 +834,7 @@ class TestCausalConsistency(UnitTest): @classmethod def _setup_class(cls): cls.listener = SessionTestListener() - cls.client = rs_or_single_client(event_listeners=[cls.listener]) + cls.client = cls.unmanaged_rs_or_single_client(event_listeners=[cls.listener]) @classmethod def _tearDown_class(cls): @@ -1137,8 +1135,7 @@ def setUp(self): def test_cluster_time(self): listener = SessionTestListener() # Prevent heartbeats from updating $clusterTime between operations. - client = rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) - self.addCleanup(client.close) + client = self.rs_or_single_client(event_listeners=[listener], heartbeatFrequencyMS=999999) collection = client.pymongo_test.collection # Prepare for tests of find() and aggregate(). collection.insert_many([{} for _ in range(10)]) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 405db14ac6..e01552bf7d 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -21,7 +21,7 @@ sys.path[0:0] = [""] -from test import client_knobs, unittest +from test import PyMongoTestCase, client_knobs, unittest from test.utils import FunctionCallRecorder, wait_until import pymongo @@ -86,7 +86,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): self.disable() -class TestSrvPolling(unittest.TestCase): +class TestSrvPolling(PyMongoTestCase): BASE_SRV_RESPONSE = [ ("localhost.test.build.10gen.cc", 27017), ("localhost.test.build.10gen.cc", 27018), @@ -167,7 +167,7 @@ def dns_resolver_response(): # Patch timeouts to ensure short test running times. with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) self.assert_nodelist_change(self.BASE_SRV_RESPONSE, client) # Patch list of hosts returned by DNS query. with SrvPollingKnobs( @@ -231,7 +231,7 @@ def final_callback(): count_resolver_calls=True, ): # Client uses unpatched method to get initial nodelist - client = MongoClient(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) # Invalid DNS resolver response should not change nodelist. 
self.assert_nodelist_nochange(self.BASE_SRV_RESPONSE, client) @@ -264,8 +264,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=0) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=0) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -279,8 +278,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) with SrvPollingKnobs(nodelist_callback=nodelist_callback): self.assert_nodelist_change(response, client) @@ -295,8 +293,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=2) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=2) with SrvPollingKnobs(nodelist_callback=nodelist_callback): sleep(2 * common.MIN_SRV_RESCAN_INTERVAL) final_topology = set(client.topology_description.server_descriptions()) @@ -305,8 +302,7 @@ def nodelist_callback(): def test_does_not_flipflop(self): with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient(self.CONNECTION_STRING, srvMaxHosts=1) - self.addCleanup(client.close) + client = self.simple_client(self.CONNECTION_STRING, srvMaxHosts=1) old = set(client.topology_description.server_descriptions()) sleep(4 * WAIT_TIME) new = set(client.topology_description.server_descriptions()) @@ -323,7 +319,7 @@ def nodelist_callback(): return response with SrvPollingKnobs(ttl_time=WAIT_TIME, min_srv_rescan_interval=WAIT_TIME): - client = MongoClient( + client = self.simple_client( "mongodb+srv://test22.test.build.10gen.cc/?srvServiceName=customname" ) with SrvPollingKnobs(nodelist_callback=nodelist_callback): @@ -340,7 +336,7 @@ def resolver_response(): min_srv_rescan_interval=WAIT_TIME, nodelist_callback=resolver_response, ): - client = MongoClient(self.CONNECTION_STRING) + client = self.simple_client(self.CONNECTION_STRING) self.assertRaises( AssertionError, self.assert_nodelist_change, modified, client, timeout=WAIT_TIME / 2 ) diff --git a/test/test_ssl.py b/test/test_ssl.py index 5b3855a82a..36d7ba12b6 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -24,6 +24,7 @@ from test import ( HAVE_IPADDRESS, IntegrationTest, + PyMongoTestCase, SkipTest, client_context, connected, @@ -82,45 +83,45 @@ # use 'localhost' for the hostname of all hosts. 
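The SRV-polling hunks above (and the SSL hunks that follow) build clients from raw URIs or host strings via self.simple_client(...), often with connect=False, where only lifecycle management is wanted rather than the suite's default replica-set or auth options. A rough sketch of what such a helper could look like, under that assumption; the body below is illustrative only:

import unittest

from pymongo import MongoClient


class PyMongoTestCase(unittest.TestCase):
    def simple_client(self, h=None, **kwargs) -> MongoClient:
        # Pass the URI or host string through untouched; only register cleanup.
        client = MongoClient(h, **kwargs)
        self.addCleanup(client.close)
        return client


class TestSimpleClientPattern(PyMongoTestCase):
    def test_lazy_client(self):
        # Mirrors the connect=False call sites above: no I/O happens here.
        client = self.simple_client("badhost", connect=False, serverSelectionTimeoutMS=10)
        self.assertFalse(client.options.connect)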
-class TestClientSSL(unittest.TestCase): +class TestClientSSL(PyMongoTestCase): @unittest.skipIf(HAVE_SSL, "The ssl module is available, can't test what happens without it.") def test_no_ssl_module(self): # Explicit - self.assertRaises(ConfigurationError, MongoClient, ssl=True) + self.assertRaises(ConfigurationError, self.simple_client, ssl=True) # Implied - self.assertRaises(ConfigurationError, MongoClient, tlsCertificateKeyFile=CLIENT_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tlsCertificateKeyFile=CLIENT_PEM) @unittest.skipUnless(HAVE_SSL, "The ssl module is not available.") @ignore_deprecations def test_config_ssl(self): # Tests various ssl configurations - self.assertRaises(ValueError, MongoClient, ssl="foo") + self.assertRaises(ValueError, self.simple_client, ssl="foo") self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM ) - self.assertRaises(TypeError, MongoClient, ssl=0) - self.assertRaises(TypeError, MongoClient, ssl=5.5) - self.assertRaises(TypeError, MongoClient, ssl=[]) + self.assertRaises(TypeError, self.simple_client, ssl=0) + self.assertRaises(TypeError, self.simple_client, ssl=5.5) + self.assertRaises(TypeError, self.simple_client, ssl=[]) - self.assertRaises(IOError, MongoClient, tlsCertificateKeyFile="NoSuchFile") - self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=True) - self.assertRaises(TypeError, MongoClient, tlsCertificateKeyFile=[]) + self.assertRaises(IOError, self.simple_client, tlsCertificateKeyFile="NoSuchFile") + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=True) + self.assertRaises(TypeError, self.simple_client, tlsCertificateKeyFile=[]) # Test invalid combinations self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsCertificateKeyFile=CLIENT_PEM + ConfigurationError, self.simple_client, tls=False, tlsCertificateKeyFile=CLIENT_PEM ) - self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCAFile=CA_PEM) - self.assertRaises(ConfigurationError, MongoClient, tls=False, tlsCRLFile=CRL_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCAFile=CA_PEM) + self.assertRaises(ConfigurationError, self.simple_client, tls=False, tlsCRLFile=CRL_PEM) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsAllowInvalidCertificates=False + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidCertificates=False ) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsAllowInvalidHostnames=False + ConfigurationError, self.simple_client, tls=False, tlsAllowInvalidHostnames=False ) self.assertRaises( - ConfigurationError, MongoClient, tls=False, tlsDisableOCSPEndpointCheck=False + ConfigurationError, self.simple_client, tls=False, tlsDisableOCSPEndpointCheck=False ) @unittest.skipUnless(_HAVE_PYOPENSSL, "PyOpenSSL is not available.") @@ -174,7 +175,7 @@ def test_tlsCertificateKeyFilePassword(self): if not hasattr(ssl, "SSLContext") and not _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, - MongoClient, + self.simple_client, "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, @@ -184,7 +185,7 @@ def test_tlsCertificateKeyFilePassword(self): ) else: connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_ENCRYPTED_PEM, @@ -201,7 +202,7 @@ def test_tlsCertificateKeyFilePassword(self): "&tlsCAFile=%s&serverSelectionTimeoutMS=5000" ) 
connected( - MongoClient(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + self.simple_client(uri_fmt % (CLIENT_ENCRYPTED_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] ) @client_context.require_tlsCertificateKeyFile @@ -215,7 +216,7 @@ def test_cert_ssl_implicitly_set(self): # # test that setting tlsCertificateKeyFile causes ssl to be set to True - client = MongoClient( + client = self.simple_client( client_context.host, client_context.port, tlsAllowInvalidCertificates=True, @@ -223,7 +224,7 @@ def test_cert_ssl_implicitly_set(self): ) response = client.admin.command(HelloCompat.LEGACY_CMD) if "setName" in response: - client = MongoClient( + client = self.simple_client( client_context.pair, replicaSet=response["setName"], w=len(response["hosts"]), @@ -242,7 +243,7 @@ def test_cert_ssl_validation(self): # --sslPEMKeyFile=/path/to/pymongo/test/certificates/server.pem # --sslCAFile=/path/to/pymongo/test/certificates/ca.pem # - client = MongoClient( + client = self.simple_client( "localhost", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -257,7 +258,7 @@ def test_cert_ssl_validation(self): "Cannot validate hostname in the certificate" ) - client = MongoClient( + client = self.simple_client( "localhost", replicaSet=response["setName"], w=len(response["hosts"]), @@ -270,7 +271,7 @@ def test_cert_ssl_validation(self): self.assertClientWorks(client) if HAVE_IPADDRESS: - client = MongoClient( + client = self.simple_client( "127.0.0.1", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -292,7 +293,7 @@ def test_cert_ssl_uri_support(self): "mongodb://localhost/?ssl=true&tlsCertificateKeyFile=%s&tlsAllowInvalidCertificates" "=%s&tlsCAFile=%s&tlsAllowInvalidHostnames=false" ) - client = MongoClient(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) + client = self.simple_client(uri_fmt % (CLIENT_PEM, "true", CA_PEM)) self.assertClientWorks(client) @client_context.require_tlsCertificateKeyFile @@ -316,7 +317,7 @@ def test_cert_ssl_validation_hostname_matching(self): with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -328,7 +329,7 @@ def test_cert_ssl_validation_hostname_matching(self): ) connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsCertificateKeyFile=CLIENT_PEM, @@ -343,7 +344,7 @@ def test_cert_ssl_validation_hostname_matching(self): if "setName" in response: with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "server", replicaSet=response["setName"], ssl=True, @@ -356,7 +357,7 @@ def test_cert_ssl_validation_hostname_matching(self): ) connected( - MongoClient( + self.simple_client( "server", replicaSet=response["setName"], ssl=True, @@ -375,7 +376,7 @@ def test_tlsCRLFile_support(self): if not hasattr(ssl, "VERIFY_CRL_CHECK_LEAF") or _ssl.IS_PYOPENSSL: self.assertRaises( ConfigurationError, - MongoClient, + self.simple_client, "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -384,7 +385,7 @@ def test_tlsCRLFile_support(self): ) else: connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -395,7 +396,7 @@ def test_tlsCRLFile_support(self): with self.assertRaises(ConnectionFailure): connected( - MongoClient( + self.simple_client( "localhost", ssl=True, tlsCAFile=CA_PEM, @@ -406,7 +407,7 @@ def test_tlsCRLFile_support(self): ) uri_fmt = "mongodb://localhost/?ssl=true&tlsCAFile=%s&serverSelectionTimeoutMS=1000" - connected(MongoClient(uri_fmt % (CA_PEM,), **self.credentials)) 
# type: ignore + connected(self.simple_client(uri_fmt % (CA_PEM,), **self.credentials)) # type: ignore uri_fmt = ( "mongodb://localhost/?ssl=true&tlsCRLFile=%s" @@ -414,7 +415,7 @@ def test_tlsCRLFile_support(self): ) with self.assertRaises(ConnectionFailure): connected( - MongoClient(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] + self.simple_client(uri_fmt % (CRL_PEM, CA_PEM), **self.credentials) # type: ignore[arg-type] ) @client_context.require_tlsCertificateKeyFile @@ -431,12 +432,14 @@ def test_validation_with_system_ca_certs(self): with self.assertRaises(ConnectionFailure): # Server cert is verified but hostname matching fails connected( - MongoClient("server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type] + self.simple_client( + "server", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] ) # Server cert is verified. Disable hostname matching. connected( - MongoClient( + self.simple_client( "server", ssl=True, tlsAllowInvalidHostnames=True, @@ -447,12 +450,14 @@ def test_validation_with_system_ca_certs(self): # Server cert and hostname are verified. connected( - MongoClient("localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials) # type: ignore[arg-type] + self.simple_client( + "localhost", ssl=True, serverSelectionTimeoutMS=1000, **self.credentials + ) # type: ignore[arg-type] ) # Server cert and hostname are verified. connected( - MongoClient( + self.simple_client( "mongodb://localhost/?ssl=true&serverSelectionTimeoutMS=1000", **self.credentials, # type: ignore[arg-type] ) @@ -472,7 +477,7 @@ def test_system_certs_config_error(self): ssl_support.HAVE_WINCERTSTORE = False try: with self.assertRaises(ConfigurationError): - MongoClient("mongodb://localhost/?ssl=true") + self.simple_client("mongodb://localhost/?ssl=true") finally: ssl_support.HAVE_CERTIFI = have_certifi ssl_support.HAVE_WINCERTSTORE = have_wincertstore @@ -536,7 +541,7 @@ def test_mongodb_x509_auth(self): ], ) - noauth = MongoClient( + noauth = self.simple_client( client_context.pair, ssl=True, tlsAllowInvalidCertificates=True, @@ -548,7 +553,7 @@ def test_mongodb_x509_auth(self): noauth.pymongo_test.test.find_one() listener = EventListener() - auth = MongoClient( + auth = self.simple_client( client_context.pair, authMechanism="MONGODB-X509", ssl=True, @@ -572,7 +577,7 @@ def test_mongodb_x509_auth(self): host, port, ) - client = MongoClient( + client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(client.close) @@ -580,7 +585,7 @@ def test_mongodb_x509_auth(self): client.pymongo_test.test.find_one() uri = "mongodb://%s:%d/?authMechanism=MONGODB-X509" % (host, port) - client = MongoClient( + client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(client.close) @@ -593,7 +598,7 @@ def test_mongodb_x509_auth(self): port, ) - bad_client = MongoClient( + bad_client = self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, tlsCertificateKeyFile=CLIENT_PEM ) self.addCleanup(bad_client.close) @@ -601,7 +606,7 @@ def test_mongodb_x509_auth(self): with self.assertRaises(OperationFailure): bad_client.pymongo_test.test.find_one() - bad_client = MongoClient( + bad_client = self.simple_client( client_context.pair, username="not the username", authMechanism="MONGODB-X509", @@ -622,7 +627,7 @@ def test_mongodb_x509_auth(self): ) try: connected( - 
MongoClient( + self.simple_client( uri, ssl=True, tlsAllowInvalidCertificates=True, @@ -648,7 +653,7 @@ def remove(path): self.addCleanup(remove, temp_ca_bundle) # Add the CA cert file to the bundle. cat_files(temp_ca_bundle, CA_BUNDLE_PEM, CA_PEM) - with MongoClient( + with self.simple_client( "localhost", tls=True, tlsCertificateKeyFile=CLIENT_PEM, tlsCAFile=temp_ca_bundle ) as client: self.assertTrue(client.admin.command("ping")) diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 9bca899a48..b3b68703a4 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -24,8 +24,6 @@ from test.utils import ( HeartbeatEventListener, ServerEventListener, - rs_or_single_client, - single_client, wait_until, ) @@ -38,7 +36,7 @@ class TestStreamingProtocol(IntegrationTest): def test_failCommand_streaming(self): listener = ServerEventListener() hb_listener = HeartbeatEventListener() - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName="failingHeartbeatTest", @@ -107,7 +105,7 @@ def test_streaming_rtt(self): }, } with self.fail_point(delay_hello): - client = rs_or_single_client( + client = self.rs_or_single_client( event_listeners=[listener, hb_listener], heartbeatFrequencyMS=500, appName=name ) self.addCleanup(client.close) @@ -155,7 +153,7 @@ def test_monitor_waits_after_server_check_error(self): } with self.fail_point(fail_hello): start = time.time() - client = single_client( + client = self.single_client( appName="SDAMMinHeartbeatFrequencyTest", serverSelectionTimeoutMS=5000 ) self.addCleanup(client.close) @@ -180,7 +178,7 @@ def test_monitor_waits_after_server_check_error(self): @client_context.require_failCommand_appName def test_heartbeat_awaited_flag(self): hb_listener = HeartbeatEventListener() - client = single_client( + client = self.single_client( event_listeners=[hb_listener], heartbeatFrequencyMS=500, appName="heartbeatEventAwaitedFlag", diff --git a/test/test_transactions.py b/test/test_transactions.py index c8c3c32d5b..3cecbe9d38 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -17,6 +17,7 @@ import sys from io import BytesIO +from test.utils_spec_runner import SpecRunner from gridfs.synchronous.grid_file import GridFS, GridFSBucket @@ -25,8 +26,6 @@ from test import IntegrationTest, client_context, unittest from test.utils import ( OvertCommandListener, - rs_client, - single_client, wait_until, ) from typing import List @@ -59,7 +58,18 @@ UNPIN_TEST_MAX_ATTEMPTS = 50 -class TestTransactions(IntegrationTest): +class TransactionsBase(SpecRunner): + def maybe_skip_scenario(self, test): + super().maybe_skip_scenario(test) + if ( + "secondary" in self.id() + and not client_context.is_mongos + and not client_context.has_secondaries + ): + raise unittest.SkipTest("No secondaries") + + +class TestTransactions(TransactionsBase): RUN_ON_SERVERLESS = True @client_context.require_transactions @@ -92,8 +102,7 @@ def test_transaction_options_validation(self): @client_context.require_transactions def test_transaction_write_concern_override(self): """Test txn overrides Client/Database/Collection write_concern.""" - client = rs_client(w=0) - self.addCleanup(client.close) + client = self.rs_client(w=0) db = client.test coll = db.test coll.insert_one({}) @@ -146,12 +155,11 @@ def test_transaction_write_concern_override(self): def test_unpin_for_next_transaction(self): # Increase localThresholdMS and wait until both nodes are 
discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) + client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. coll.insert_one({}) - self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. with s.start_transaction(): @@ -174,12 +182,11 @@ def test_unpin_for_next_transaction(self): def test_unpin_for_non_transaction_operation(self): # Increase localThresholdMS and wait until both nodes are discovered # to avoid false positives. - client = rs_client(client_context.mongos_seeds(), localThresholdMS=1000) + client = self.rs_client(client_context.mongos_seeds(), localThresholdMS=1000) wait_until(lambda: len(client.nodes) > 1, "discover both mongoses") coll = client.test.test # Create the collection. coll.insert_one({}) - self.addCleanup(client.close) with client.start_session() as s: # Session is pinned to Mongos. with s.start_transaction(): @@ -303,11 +310,10 @@ def test_transaction_starts_with_batched_write(self): # Start a transaction with a batch of operations that needs to be # split. listener = OvertCommandListener() - client = rs_client(event_listeners=[listener]) + client = self.rs_client(event_listeners=[listener]) coll = client[self.db.name].test coll.delete_many({}) listener.reset() - self.addCleanup(client.close) self.addCleanup(coll.drop) large_str = "\0" * (1 * 1024 * 1024) ops: List[InsertOne[RawBSONDocument]] = [ @@ -332,8 +338,7 @@ def test_transaction_starts_with_batched_write(self): @client_context.require_transactions def test_transaction_direct_connection(self): - client = single_client() - self.addCleanup(client.close) + client = self.single_client() coll = client.pymongo_test.test # Make sure the collection exists. 
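The new TransactionsBase.maybe_skip_scenario hook above skips secondary-read scenarios when the deployment has no secondaries. For readability, the same skip-by-test-id technique as a self-contained snippet, with SpecRunner and the client context stubbed out since neither is shown in this patch; only the skip condition mirrors the hunk above:

import unittest


class _StubContext:
    # Stand-in for the suite's client_context; the values are placeholders.
    is_mongos = False
    has_secondaries = False


client_context = _StubContext()


class SpecRunner(unittest.TestCase):
    # Stubbed base; the real runner drives spec scenarios.
    def maybe_skip_scenario(self, test):
        pass


class TransactionsBase(SpecRunner):
    def maybe_skip_scenario(self, test):
        super().maybe_skip_scenario(test)
        # Skip scenarios whose id mentions "secondary" when the topology
        # cannot serve secondary reads.
        if (
            "secondary" in self.id()
            and not client_context.is_mongos
            and not client_context.has_secondaries
        ):
            raise unittest.SkipTest("No secondaries")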
@@ -389,14 +394,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
         client_session._WITH_TRANSACTION_RETRY_TIME_LIMIT = self.real_timeout
 
 
-class TestTransactionsConvenientAPI(IntegrationTest):
+class TestTransactionsConvenientAPI(TransactionsBase):
     @classmethod
     def _setup_class(cls):
         super()._setup_class()
         cls.mongos_clients = []
         if client_context.supports_transactions():
             for address in client_context.mongoses:
-                cls.mongos_clients.append(single_client("{}:{}".format(*address)))
+                cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
 
     @classmethod
     def _tearDown_class(cls):
@@ -446,8 +451,7 @@ def callback2(session):
     @client_context.require_transactions
     def test_callback_not_retried_after_timeout(self):
         listener = OvertCommandListener()
-        client = rs_client(event_listeners=[listener])
-        self.addCleanup(client.close)
+        client = self.rs_client(event_listeners=[listener])
         coll = client[self.db.name].test
 
         def callback(session):
@@ -475,8 +479,7 @@ def callback(session):
     @client_context.require_transactions
     def test_callback_not_retried_after_commit_timeout(self):
         listener = OvertCommandListener()
-        client = rs_client(event_listeners=[listener])
-        self.addCleanup(client.close)
+        client = self.rs_client(event_listeners=[listener])
         coll = client[self.db.name].test
 
         def callback(session):
@@ -508,8 +511,7 @@ def callback(session):
     @client_context.require_transactions
     def test_commit_not_retried_after_timeout(self):
         listener = OvertCommandListener()
-        client = rs_client(event_listeners=[listener])
-        self.addCleanup(client.close)
+        client = self.rs_client(event_listeners=[listener])
         coll = client[self.db.name].test
 
         def callback(session):
diff --git a/test/test_typing.py b/test/test_typing.py
index f423b70a3e..6cfe40537b 100644
--- a/test/test_typing.py
+++ b/test/test_typing.py
@@ -68,8 +68,7 @@ class ImplicitMovie(TypedDict):
 sys.path[0:0] = [""]
 
-from test import IntegrationTest, client_context
-from test.utils import rs_or_single_client
+from test import IntegrationTest, PyMongoTestCase, client_context
 
 from bson import CodecOptions, decode, decode_all, decode_file_iter, decode_iter, encode
 from bson.raw_bson import RawBSONDocument
@@ -194,7 +193,7 @@ def test_list_databases(self) -> None:
             value.items()
 
     def test_default_document_type(self) -> None:
-        client = rs_or_single_client()
+        client = self.rs_or_single_client()
         self.addCleanup(client.close)
         coll = client.test.test
         doc = {"my": "doc"}
@@ -366,7 +365,7 @@ def test_bson_decode_file_iter_none_codec_option(self) -> None:
             doc["a"] = 2
 
 
-class TestDocumentType(unittest.TestCase):
+class TestDocumentType(PyMongoTestCase):
     @only_type_check
     def test_default(self) -> None:
         client: MongoClient = MongoClient()
@@ -480,7 +479,7 @@ def test_typeddict_empty_document_type(self) -> None:
     def test_typeddict_find_notrequired(self):
         if NotRequired is None or ImplicitMovie is None:
             raise unittest.SkipTest("Python 3.11+ is required to use NotRequired.")
-        client: MongoClient[ImplicitMovie] = rs_or_single_client()
+        client: MongoClient[ImplicitMovie] = self.rs_or_single_client()
         coll = client.test.test
         coll.insert_one(ImplicitMovie(name="THX-1138", year=1971))
         out = coll.find_one({})
diff --git a/test/test_versioned_api.py b/test/test_versioned_api.py
index 7fe8ebd76f..7a25a507dc 100644
--- a/test/test_versioned_api.py
+++ b/test/test_versioned_api.py
@@ -20,7 +20,7 @@
 from test import IntegrationTest, client_context, unittest
 from test.unified_format import generate_test_classes
-from test.utils import OvertCommandListener, rs_or_single_client
+from test.utils import OvertCommandListener
 
 from pymongo.server_api import ServerApi, ServerApiVersion
 from pymongo.synchronous.mongo_client import MongoClient
@@ -77,7 +77,7 @@ def assertServerApiInAllCommands(self, events):
     @client_context.require_version_min(4, 7)
     def test_command_options(self):
         listener = OvertCommandListener()
-        client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
+        client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
         self.addCleanup(client.close)
         coll = client.test.test
         coll.insert_many([{} for _ in range(100)])
@@ -90,7 +90,7 @@ def test_command_options(self):
     @client_context.require_transactions
     def test_command_options_txn(self):
         listener = OvertCommandListener()
-        client = rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
+        client = self.rs_or_single_client(server_api=ServerApi("1"), event_listeners=[listener])
         self.addCleanup(client.close)
         coll = client.test.test
         coll.insert_many([{} for _ in range(100)])
diff --git a/test/unified_format.py b/test/unified_format.py
index 78fc638787..62211d3d25 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -55,8 +55,6 @@
     parse_collection_options,
     parse_spec_options,
     prepare_spec_arguments,
-    rs_or_single_client,
-    single_client,
     snake_to_camel,
     wait_until,
 )
@@ -574,7 +572,7 @@ def _create_entity(self, entity_spec, uri=None):
                 )
             if uri:
                 kwargs["h"] = uri
-            client = rs_or_single_client(**kwargs)
+            client = self.test.rs_or_single_client(**kwargs)
             self[spec["id"]] = client
             self.test.addCleanup(client.close)
             return
@@ -1115,7 +1113,7 @@ def setUpClass(cls):
             and not client_context.serverless
         ):
             for address in client_context.mongoses:
-                cls.mongos_clients.append(single_client("{}:{}".format(*address)))
+                cls.mongos_clients.append(cls.unmanaged_single_client("{}:{}".format(*address)))
 
         # Speed up the tests by decreasing the heartbeat frequency.
         cls.knobs = client_knobs(
@@ -1646,7 +1644,7 @@ def _testOperation_targetedFailPoint(self, spec):
                 )
             )
 
-        client = single_client("{}:{}".format(*session._pinned_address))
+        client = self.single_client("{}:{}".format(*session._pinned_address))
         self.addCleanup(client.close)
 
         self.__set_fail_point(client=client, command_args=spec["failPoint"])
diff --git a/test/utils.py b/test/utils.py
index fa198b1c64..6eefd1c7ea 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -565,151 +565,6 @@ def create_tests(self):
             setattr(self._test_class, new_test.__name__, new_test)
 
 
-def _connection_string(h):
-    if h.startswith(("mongodb://", "mongodb+srv://")):
-        return h
-    return f"mongodb://{h!s}"
-
-
-def _mongo_client(host, port, authenticate=True, directConnection=None, **kwargs):
-    """Create a new client over SSL/TLS if necessary."""
-    host = host or client_context.host
-    port = port or client_context.port
-    client_options: dict = client_context.default_client_options.copy()
-    if client_context.replica_set_name and not directConnection:
-        client_options["replicaSet"] = client_context.replica_set_name
-    if directConnection is not None:
-        client_options["directConnection"] = directConnection
-    client_options.update(kwargs)
-
-    uri = _connection_string(host)
-    auth_mech = kwargs.get("authMechanism", "")
-    if client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
-        # Only add the default username or password if one is not provided.
-        res = parse_uri(uri)
-        if (
-            not res["username"]
-            and not res["password"]
-            and "username" not in client_options
-            and "password" not in client_options
-        ):
-            client_options["username"] = db_user
-            client_options["password"] = db_pwd
-    return MongoClient(uri, port, **client_options)
-
-
-async def _async_mongo_client(host, port, authenticate=True, directConnection=None, **kwargs):
-    """Create a new client over SSL/TLS if necessary."""
-    host = host or await async_client_context.host
-    port = port or await async_client_context.port
-    client_options: dict = async_client_context.default_client_options.copy()
-    if async_client_context.replica_set_name and not directConnection:
-        client_options["replicaSet"] = async_client_context.replica_set_name
-    if directConnection is not None:
-        client_options["directConnection"] = directConnection
-    client_options.update(kwargs)
-
-    uri = _connection_string(host)
-    auth_mech = kwargs.get("authMechanism", "")
-    if async_client_context.auth_enabled and authenticate and auth_mech != "MONGODB-OIDC":
-        # Only add the default username or password if one is not provided.
-        res = parse_uri(uri)
-        if (
-            not res["username"]
-            and not res["password"]
-            and "username" not in client_options
-            and "password" not in client_options
-        ):
-            client_options["username"] = db_user
-            client_options["password"] = db_pwd
-    client = AsyncMongoClient(uri, port, **client_options)
-    if client._options.connect:
-        await client.aconnect()
-    return client
-
-
-def single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
-    """Make a direct connection. Don't authenticate."""
-    return _mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
-
-
-def single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
-    """Make a direct connection, and authenticate if necessary."""
-    return _mongo_client(h, p, directConnection=True, **kwargs)
-
-
-def rs_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
-    """Connect to the replica set. Don't authenticate."""
-    return _mongo_client(h, p, authenticate=False, **kwargs)
-
-
-def rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
-    """Connect to the replica set and authenticate if necessary."""
-    return _mongo_client(h, p, **kwargs)
-
-
-def rs_or_single_client_noauth(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[dict]:
-    """Connect to the replica set if there is one, otherwise the standalone.
-
-    Like rs_or_single_client, but does not authenticate.
-    """
-    return _mongo_client(h, p, authenticate=False, **kwargs)
-
-
-def rs_or_single_client(h: Any = None, p: Any = None, **kwargs: Any) -> MongoClient[Any]:
-    """Connect to the replica set if there is one, otherwise the standalone.
-
-    Authenticates if necessary.
-    """
-    return _mongo_client(h, p, **kwargs)
-
-
-async def async_single_client_noauth(
-    h: Any = None, p: Any = None, **kwargs: Any
-) -> AsyncMongoClient[dict]:
-    """Make a direct connection. Don't authenticate."""
-    return await _async_mongo_client(h, p, authenticate=False, directConnection=True, **kwargs)
-
-
-async def async_single_client(
-    h: Any = None, p: Any = None, **kwargs: Any
-) -> AsyncMongoClient[dict]:
-    """Make a direct connection, and authenticate if necessary."""
-    return await _async_mongo_client(h, p, directConnection=True, **kwargs)
-
-
-async def async_rs_client_noauth(
-    h: Any = None, p: Any = None, **kwargs: Any
-) -> AsyncMongoClient[dict]:
-    """Connect to the replica set. Don't authenticate."""
-    return await _async_mongo_client(h, p, authenticate=False, **kwargs)
-
-
-async def async_rs_client(h: Any = None, p: Any = None, **kwargs: Any) -> AsyncMongoClient[dict]:
-    """Connect to the replica set and authenticate if necessary."""
-    return await _async_mongo_client(h, p, **kwargs)
-
-
-async def async_rs_or_single_client_noauth(
-    h: Any = None, p: Any = None, **kwargs: Any
-) -> AsyncMongoClient[dict]:
-    """Connect to the replica set if there is one, otherwise the standalone.
-
-    Like rs_or_single_client, but does not authenticate.
-    """
-    return await _async_mongo_client(h, p, authenticate=False, **kwargs)
-
-
-async def async_rs_or_single_client(
-    h: Any = None, p: Any = None, **kwargs: Any
-) -> AsyncMongoClient[Any]:
-    """Connect to the replica set if there is one, otherwise the standalone.
-
-    Authenticates if necessary.
-    """
-    return await _async_mongo_client(h, p, **kwargs)
-
-
 def ensure_all_connected(client: MongoClient) -> None:
     """Ensure that the client's connection pool has socket connections to all
     members of a replica set. Raises ConfigurationError when called with a
@@ -1108,20 +963,6 @@ def is_greenthread_patched():
     return gevent_monkey_patched() or eventlet_monkey_patched()
 
 
-def disable_replication(client):
-    """Disable replication on all secondaries."""
-    for host, port in client.secondaries:
-        secondary = single_client(host, port)
-        secondary.admin.command("configureFailPoint", "stopReplProducer", mode="alwaysOn")
-
-
-def enable_replication(client):
-    """Enable replication on all secondaries."""
-    for host, port in client.secondaries:
-        secondary = single_client(host, port)
-        secondary.admin.command("configureFailPoint", "stopReplProducer", mode="off")
-
-
 class ExceptionCatchingThread(threading.Thread):
     """A thread that stores any exception encountered from run()."""
 
diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py
index 0b882a8bc3..06a40351cd 100644
--- a/test/utils_spec_runner.py
+++ b/test/utils_spec_runner.py
@@ -29,7 +29,6 @@
     camel_to_snake_args,
     parse_spec_options,
     prepare_spec_arguments,
-    rs_client,
 )
 from typing import List
@@ -101,6 +100,8 @@ def _setup_class(cls):
     @classmethod
     def _tearDown_class(cls):
         cls.knobs.disable()
+        for client in cls.mongos_clients:
+            client.close()
         super()._tearDown_class()
 
     def setUp(self):
@@ -524,7 +525,7 @@ def run_scenario(self, scenario_def, test):
             host = client_context.MULTI_MONGOS_LB_URI
         elif client_context.is_mongos:
             host = client_context.mongos_seeds()
-        client = rs_client(
+        client = self.rs_client(
             h=host, event_listeners=[listener, pool_listener, server_listener], **client_options
         )
         self.scenario_client = client
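
Reviewer note (illustrative, not part of the patch): the hunks above delete the module-level client helpers from test/utils.py and switch every call site to test-case methods such as self.rs_client(), self.single_client(), and self.rs_or_single_client(), which register client.close() via addCleanup(); the unmanaged_* variants are used where a client must outlive a single test, as in the mongos_clients setup. A minimal before/after sketch of the migration pattern, assuming the helper methods are available on IntegrationTest as the call sites suggest; the TestExample class, the test_ping method, and the w=1 option are hypothetical:

# Sketch only; not part of this patch.
from test import IntegrationTest


class TestExample(IntegrationTest):
    def test_ping(self):
        # Before this patch:
        #     client = rs_or_single_client(w=1)
        #     self.addCleanup(client.close)
        # After this patch the helper is a test-case method that registers
        # close() through addCleanup() itself, so the client is reliably
        # closed when the test finishes and no manual cleanup is needed.
        client = self.rs_or_single_client(w=1)
        client.admin.command("ping")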