From 0dc245137c387769d6b4430b063efb0142b7dd34 Mon Sep 17 00:00:00 2001 From: Steven Silvester Date: Thu, 30 May 2024 16:40:23 -0500 Subject: [PATCH] PYTHON-4441 Use deferred imports instead of lazy module loading (#1648) (cherry picked from commit 49987e6a8a5217de19a425838e16d5c9674a8841) --- doc/changelog.rst | 16 ++++++++ pymongo/_gcp_helpers.py | 3 +- pymongo/_lazy_import.py | 43 ---------------------- pymongo/auth_aws.py | 15 ++------ pymongo/client_options.py | 4 +- pymongo/common.py | 6 ++- pymongo/compression_support.py | 67 +++++++++++++++++++++------------- pymongo/helpers.py | 15 ++++++++ pymongo/monitoring.py | 18 +-------- pymongo/pool.py | 6 ++- pymongo/pyopenssl_context.py | 25 ++++++------- pymongo/srv_resolver.py | 19 +++++++--- pymongo/uri_parser.py | 4 +- test/test_client.py | 6 +-- test/test_srv_polling.py | 4 +- test/test_uri_spec.py | 4 +- test/utils.py | 2 +- 17 files changed, 128 insertions(+), 129 deletions(-) delete mode 100644 pymongo/_lazy_import.py diff --git a/doc/changelog.rst b/doc/changelog.rst index 4df0ba6321..1654acf862 100644 --- a/doc/changelog.rst +++ b/doc/changelog.rst @@ -1,6 +1,22 @@ Changelog ========= +Changes in Version 4.7.3 +------------------------- + +Version 4.7.3 has further fixes for lazily loading modules. + +- Use deferred imports instead of importlib lazy module loading. +- Improve import time on Windows. + +Issues Resolved +............... + +See the `PyMongo 4.7.3 release notes in JIRA`_ for the list of resolved issues +in this release. + +.. _PyMongo 4.7.3 release notes in JIRA: https://jira.mongodb.org/secure/ReleaseNote.jspa?projectId=10004&version=39865 + Changes in Version 4.7.2 ------------------------- diff --git a/pymongo/_gcp_helpers.py b/pymongo/_gcp_helpers.py index 46f02ba1e5..d90f3cc217 100644 --- a/pymongo/_gcp_helpers.py +++ b/pymongo/_gcp_helpers.py @@ -16,10 +16,11 @@ from __future__ import annotations from typing import Any -from urllib.request import Request, urlopen def _get_gcp_response(resource: str, timeout: float = 5) -> dict[str, Any]: + from urllib.request import Request, urlopen + url = "http://metadata/computeMetadata/v1/instance/service-accounts/default/identity" url += f"?audience={resource}" headers = {"Metadata-Flavor": "Google"} diff --git a/pymongo/_lazy_import.py b/pymongo/_lazy_import.py deleted file mode 100644 index 888339d034..0000000000 --- a/pymongo/_lazy_import.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. -from __future__ import annotations - -import importlib.util -import sys -from types import ModuleType - - -def lazy_import(name: str) -> ModuleType: - """Lazily import a module by name - - From https://docs.python.org/3/library/importlib.html#implementing-lazy-imports - """ - # Workaround for PYTHON-4424. - if "__compiled__" in globals(): - return importlib.import_module(name) - try: - spec = importlib.util.find_spec(name) - except ValueError: - # Note: this cannot be ModuleNotFoundError, see PYTHON-4424. - raise ImportError(name=name) from None - if spec is None: - # Note: this cannot be ModuleNotFoundError, see PYTHON-4424. - raise ImportError(name=name) - assert spec is not None - loader = importlib.util.LazyLoader(spec.loader) # type:ignore[arg-type] - spec.loader = loader - module = importlib.util.module_from_spec(spec) - sys.modules[name] = module - loader.exec_module(module) - return module diff --git a/pymongo/auth_aws.py b/pymongo/auth_aws.py index 0d253cea13..042eee5a73 100644 --- a/pymongo/auth_aws.py +++ b/pymongo/auth_aws.py @@ -15,15 +15,6 @@ """MONGODB-AWS Authentication helpers.""" from __future__ import annotations -from pymongo._lazy_import import lazy_import - -try: - pymongo_auth_aws = lazy_import("pymongo_auth_aws") - _HAVE_MONGODB_AWS = True -except ImportError: - _HAVE_MONGODB_AWS = False - - from typing import TYPE_CHECKING, Any, Mapping, Type import bson @@ -38,11 +29,13 @@ def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None: """Authenticate using MONGODB-AWS.""" - if not _HAVE_MONGODB_AWS: + try: + import pymongo_auth_aws # type:ignore[import] + except ImportError as e: raise ConfigurationError( "MONGODB-AWS authentication requires pymongo-auth-aws: " "install with: python -m pip install 'pymongo[aws]'" - ) + ) from e # Delayed import. from pymongo_auth_aws.auth import ( # type:ignore[import] diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 60332605a3..9c745b11ef 100644 --- a/pymongo/client_options.py +++ b/pymongo/client_options.py @@ -19,7 +19,6 @@ from bson.codec_options import _parse_codec_options from pymongo import common -from pymongo.auth import MongoCredential, _build_credentials_tuple from pymongo.compression_support import CompressionSettings from pymongo.errors import ConfigurationError from pymongo.monitoring import _EventListener, _EventListeners @@ -36,6 +35,7 @@ if TYPE_CHECKING: from bson.codec_options import CodecOptions + from pymongo.auth import MongoCredential from pymongo.encryption_options import AutoEncryptionOpts from pymongo.pyopenssl_context import SSLContext from pymongo.topology_description import _ServerSelector @@ -48,6 +48,8 @@ def _parse_credentials( mechanism = options.get("authmechanism", "DEFAULT" if username else None) source = options.get("authsource") if username or mechanism: + from pymongo.auth import _build_credentials_tuple + return _build_credentials_tuple(mechanism, source, username, password, options, database) return None diff --git a/pymongo/common.py b/pymongo/common.py index 7f1245b7d3..217bb3465e 100644 --- a/pymongo/common.py +++ b/pymongo/common.py @@ -40,8 +40,6 @@ from bson.binary import UuidRepresentation from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry from bson.raw_bson import RawBSONDocument -from pymongo.auth import MECHANISMS -from pymongo.auth_oidc import OIDCCallback from pymongo.compression_support import ( validate_compressors, validate_zlib_compression_level, @@ -380,6 +378,8 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: def validate_auth_mechanism(option: str, value: Any) -> str: """Validate the authMechanism URI option.""" + from pymongo.auth import MECHANISMS + if value not in MECHANISMS: raise ValueError(f"{option} must be in {tuple(MECHANISMS)}") return value @@ -444,6 +444,8 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni elif key in ["ALLOWED_HOSTS"] and isinstance(value, list): props[key] = value elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]: + from pymongo.auth_oidc import OIDCCallback + if not isinstance(value, OIDCCallback): raise ValueError("callback must be an OIDCCallback object") props[key] = value diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index 7daad21046..2f155352d2 100644 --- a/pymongo/compression_support.py +++ b/pymongo/compression_support.py @@ -16,34 +16,39 @@ import warnings from typing import Any, Iterable, Optional, Union -from pymongo._lazy_import import lazy_import from pymongo.hello import HelloCompat -from pymongo.monitoring import _SENSITIVE_COMMANDS +from pymongo.helpers import _SENSITIVE_COMMANDS -try: - snappy = lazy_import("snappy") - _HAVE_SNAPPY = True -except ImportError: - # python-snappy isn't available. - _HAVE_SNAPPY = False +_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} +_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} +_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) -try: - zlib = lazy_import("zlib") - _HAVE_ZLIB = True -except ImportError: - # Python built without zlib support. - _HAVE_ZLIB = False +def _have_snappy() -> bool: + try: + import snappy # type:ignore[import] # noqa: F401 -try: - zstandard = lazy_import("zstandard") - _HAVE_ZSTD = True -except ImportError: - _HAVE_ZSTD = False + return True + except ImportError: + return False -_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} -_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} -_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) + +def _have_zlib() -> bool: + try: + import zlib # noqa: F401 + + return True + except ImportError: + return False + + +def _have_zstd() -> bool: + try: + import zstandard # noqa: F401 + + return True + except ImportError: + return False def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]: @@ -58,21 +63,21 @@ def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[s if compressor not in _SUPPORTED_COMPRESSORS: compressors.remove(compressor) warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2) - elif compressor == "snappy" and not _HAVE_SNAPPY: + elif compressor == "snappy" and not _have_snappy(): compressors.remove(compressor) warnings.warn( "Wire protocol compression with snappy is not available. " "You must install the python-snappy module for snappy support.", stacklevel=2, ) - elif compressor == "zlib" and not _HAVE_ZLIB: + elif compressor == "zlib" and not _have_zlib(): compressors.remove(compressor) warnings.warn( "Wire protocol compression with zlib is not available. " "The zlib module is not available.", stacklevel=2, ) - elif compressor == "zstd" and not _HAVE_ZSTD: + elif compressor == "zstd" and not _have_zstd(): compressors.remove(compressor) warnings.warn( "Wire protocol compression with zstandard is not available. " @@ -117,6 +122,8 @@ class SnappyContext: @staticmethod def compress(data: bytes) -> bytes: + import snappy + return snappy.compress(data) @@ -127,6 +134,8 @@ def __init__(self, level: int): self.level = level def compress(self, data: bytes) -> bytes: + import zlib + return zlib.compress(data, self.level) @@ -137,6 +146,8 @@ class ZstdContext: def compress(data: bytes) -> bytes: # ZstdCompressor is not thread safe. # TODO: Use a pool? + import zstandard + return zstandard.ZstdCompressor().compress(data) @@ -146,12 +157,18 @@ def decompress(data: bytes, compressor_id: int) -> bytes: # https://github.com/andrix/python-snappy/issues/65 # This only matters when data is a memoryview since # id(bytes(data)) == id(data) when data is a bytes. + import snappy + return snappy.uncompress(bytes(data)) elif compressor_id == ZlibContext.compressor_id: + import zlib + return zlib.decompress(data) elif compressor_id == ZstdContext.compressor_id: # ZstdDecompressor is not thread safe. # TODO: Use a pool? + import zstandard + return zstandard.ZstdDecompressor().decompress(data) else: raise ValueError("Unknown compressorId %d" % (compressor_id,)) diff --git a/pymongo/helpers.py b/pymongo/helpers.py index 916d78a33b..080c3204a4 100644 --- a/pymongo/helpers.py +++ b/pymongo/helpers.py @@ -93,6 +93,21 @@ # Server code raised when authentication fails. _AUTHENTICATION_FAILURE_CODE: int = 18 +# Note - to avoid bugs from forgetting which if these is all lowercase and +# which are camelCase, and at the same time avoid having to add a test for +# every command, use all lowercase here and test against command_name.lower(). +_SENSITIVE_COMMANDS: set = { + "authenticate", + "saslstart", + "saslcontinue", + "getnonce", + "createuser", + "updateuser", + "copydbgetnonce", + "copydbsaslstart", + "copydb", +} + def _gen_index_name(keys: _IndexList) -> str: """Generate an index name from the set of fields it is over.""" diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index aff11a9f42..896a747e72 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -191,7 +191,7 @@ def connection_checked_in(self, event): from bson.objectid import ObjectId from pymongo.hello import Hello, HelloCompat -from pymongo.helpers import _handle_exception +from pymongo.helpers import _SENSITIVE_COMMANDS, _handle_exception from pymongo.typings import _Address, _DocumentOut if TYPE_CHECKING: @@ -507,22 +507,6 @@ def register(listener: _EventListener) -> None: _LISTENERS.cmap_listeners.append(listener) -# Note - to avoid bugs from forgetting which if these is all lowercase and -# which are camelCase, and at the same time avoid having to add a test for -# every command, use all lowercase here and test against command_name.lower(). -_SENSITIVE_COMMANDS: set = { - "authenticate", - "saslstart", - "saslcontinue", - "getnonce", - "createuser", - "updateuser", - "copydbgetnonce", - "copydbsaslstart", - "copydb", -} - - # The "hello" command is also deemed sensitive when attempting speculative # authentication. def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: diff --git a/pymongo/pool.py b/pymongo/pool.py index 50c2a6d3e1..a57ae1a10b 100644 --- a/pymongo/pool.py +++ b/pymongo/pool.py @@ -41,7 +41,7 @@ import bson from bson import DEFAULT_CODEC_OPTIONS -from pymongo import __version__, _csot, auth, helpers +from pymongo import __version__, _csot, helpers from pymongo.client_session import _validate_session_write_concern from pymongo.common import ( MAX_BSON_SIZE, @@ -860,6 +860,8 @@ def _hello( if creds: if creds.mechanism == "DEFAULT" and creds.username: cmd["saslSupportedMechs"] = creds.source + "." + creds.username + from pymongo import auth + auth_ctx = auth._AuthContext.from_credentials(creds, self.address) if auth_ctx: speculative_authenticate = auth_ctx.speculate_command() @@ -1091,6 +1093,8 @@ def authenticate(self, reauthenticate: bool = False) -> None: if not self.ready: creds = self.opts._credentials if creds: + from pymongo import auth + auth.authenticate(creds, self, reauthenticate=reauthenticate) self.ready = True if self.enabled_for_cmap: diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index fb00713553..b08588daff 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -25,10 +25,11 @@ from ipaddress import ip_address as _ip_address from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, Union +import cryptography.x509 as x509 +import service_identity from OpenSSL import SSL as _SSL from OpenSSL import crypto as _crypto -from pymongo._lazy_import import lazy_import from pymongo.errors import ConfigurationError as _ConfigurationError from pymongo.errors import _CertificateError # type:ignore[attr-defined] from pymongo.ocsp_cache import _OCSPCache @@ -37,14 +38,9 @@ from pymongo.socket_checker import _errno_from_exception from pymongo.write_concern import validate_boolean -_x509 = lazy_import("cryptography.x509") -_service_identity = lazy_import("service_identity") -_service_identity_pyopenssl = lazy_import("service_identity.pyopenssl") - if TYPE_CHECKING: from ssl import VerifyMode - from cryptography.x509 import Certificate _T = TypeVar("_T") @@ -184,7 +180,7 @@ class _CallbackData: """Data class which is passed to the OCSP callback.""" def __init__(self) -> None: - self.trusted_ca_certs: Optional[list[Certificate]] = None + self.trusted_ca_certs: Optional[list[x509.Certificate]] = None self.check_ocsp_endpoint: Optional[bool] = None self.ocsp_response_cache = _OCSPCache() @@ -336,11 +332,12 @@ def _load_wincerts(self, store: str) -> None: """Attempt to load CA certs from Windows trust store.""" cert_store = self._ctx.get_cert_store() oid = _stdlibssl.Purpose.SERVER_AUTH.oid + for cert, encoding, trust in _stdlibssl.enum_certificates(store): # type: ignore if encoding == "x509_asn": if trust is True or oid in trust: cert_store.add_cert( - _crypto.X509.from_cryptography(_x509.load_der_x509_certificate(cert)) + _crypto.X509.from_cryptography(x509.load_der_x509_certificate(cert)) ) def load_default_certs(self) -> None: @@ -404,14 +401,16 @@ def wrap_socket( # XXX: Do this in a callback registered with # SSLContext.set_info_callback? See Twisted for an example. if self.check_hostname and server_hostname is not None: + from service_identity import pyopenssl + try: if _is_ip_address(server_hostname): - _service_identity_pyopenssl.verify_ip_address(ssl_conn, server_hostname) + pyopenssl.verify_ip_address(ssl_conn, server_hostname) else: - _service_identity_pyopenssl.verify_hostname(ssl_conn, server_hostname) - except ( - _service_identity.SICertificateError, - _service_identity.SIVerificationError, + pyopenssl.verify_hostname(ssl_conn, server_hostname) + except ( # type:ignore[misc] + service_identity.SICertificateError, + service_identity.SIVerificationError, ) as exc: raise _CertificateError(str(exc)) from None return ssl_conn diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py index 4ee1b1f5b6..6f6cc285fa 100644 --- a/pymongo/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -17,17 +17,22 @@ import ipaddress import random -from typing import Any, Optional, Union +from typing import TYPE_CHECKING, Any, Optional, Union from pymongo.common import CONNECT_TIMEOUT from pymongo.errors import ConfigurationError -try: +if TYPE_CHECKING: from dns import resolver - _HAVE_DNSPYTHON = True -except ImportError: - _HAVE_DNSPYTHON = False + +def _have_dnspython() -> bool: + try: + import dns # noqa: F401 + + return True + except ImportError: + return False # dnspython can return bytes or str from various parts @@ -40,6 +45,8 @@ def maybe_decode(text: Union[str, bytes]) -> str: # PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: + from dns import resolver + if hasattr(resolver, "resolve"): # dnspython >= 2 return resolver.resolve(*args, **kwargs) @@ -81,6 +88,8 @@ def __init__( raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) def get_options(self) -> Optional[str]: + from dns import resolver + try: results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) except (resolver.NoAnswer, resolver.NXDOMAIN): diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index 7f4ef57f9c..8e6e8298d8 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -40,7 +40,7 @@ get_validated_options, ) from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.srv_resolver import _HAVE_DNSPYTHON, _SrvResolver +from pymongo.srv_resolver import _have_dnspython, _SrvResolver from pymongo.typings import _Address if TYPE_CHECKING: @@ -472,7 +472,7 @@ def parse_uri( is_srv = False scheme_free = uri[SCHEME_LEN:] elif uri.startswith(SRV_SCHEME): - if not _HAVE_DNSPYTHON: + if not _have_dnspython(): python_path = sys.executable or "python" raise ConfigurationError( 'The "dnspython" module must be ' diff --git a/test/test_client.py b/test/test_client.py index 4679e563b9..4377d410a9 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -86,7 +86,7 @@ from pymongo.client_options import ClientOptions from pymongo.command_cursor import CommandCursor from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT -from pymongo.compression_support import _HAVE_SNAPPY, _HAVE_ZSTD +from pymongo.compression_support import _have_snappy, _have_zstd from pymongo.cursor import Cursor, CursorType from pymongo.database import Database from pymongo.driver_info import DriverInfo @@ -1558,7 +1558,7 @@ def compression_settings(client): self.assertEqual(opts.compressors, ["zlib"]) self.assertEqual(opts.zlib_compression_level, -1) - if not _HAVE_SNAPPY: + if not _have_snappy(): uri = "mongodb://localhost:27017/?compressors=snappy" client = MongoClient(uri, connect=False) opts = compression_settings(client) @@ -1573,7 +1573,7 @@ def compression_settings(client): opts = compression_settings(client) self.assertEqual(opts.compressors, ["snappy", "zlib"]) - if not _HAVE_ZSTD: + if not _have_zstd(): uri = "mongodb://localhost:27017/?compressors=zstd" client = MongoClient(uri, connect=False) opts = compression_settings(client) diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py index 32646b9946..29283f0ff2 100644 --- a/test/test_srv_polling.py +++ b/test/test_srv_polling.py @@ -28,7 +28,7 @@ from pymongo import common from pymongo.errors import ConfigurationError from pymongo.mongo_client import MongoClient -from pymongo.srv_resolver import _HAVE_DNSPYTHON +from pymongo.srv_resolver import _have_dnspython WAIT_TIME = 0.1 @@ -148,7 +148,7 @@ def predicate(): return True def run_scenario(self, dns_response, expect_change): - self.assertEqual(_HAVE_DNSPYTHON, True) + self.assertEqual(_have_dnspython(), True) if callable(dns_response): dns_resolver_response = dns_response else: diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py index 33a22330fa..f483a03842 100644 --- a/test/test_uri_spec.py +++ b/test/test_uri_spec.py @@ -27,7 +27,7 @@ from test import clear_warning_registry, unittest from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate -from pymongo.compression_support import _HAVE_SNAPPY +from pymongo.compression_support import _have_snappy from pymongo.uri_parser import SRV_SCHEME, parse_uri CONN_STRING_TEST_PATH = os.path.join( @@ -95,7 +95,7 @@ def modified_test_scenario(*args, **kwargs): def create_test(test, test_workdir): def run_scenario(self): compressors = (test.get("options") or {}).get("compressors", []) - if "snappy" in compressors and not _HAVE_SNAPPY: + if "snappy" in compressors and not _have_snappy(): self.skipTest("This test needs the snappy module.") valid = True warning = False diff --git a/test/utils.py b/test/utils.py index 2caaa6fd99..15480dc440 100644 --- a/test/utils.py +++ b/test/utils.py @@ -39,9 +39,9 @@ from pymongo.cursor import CursorType from pymongo.errors import ConfigurationError, OperationFailure from pymongo.hello import HelloCompat +from pymongo.helpers import _SENSITIVE_COMMANDS from pymongo.lock import _create_lock from pymongo.monitoring import ( - _SENSITIVE_COMMANDS, ConnectionCheckedInEvent, ConnectionCheckedOutEvent, ConnectionCheckOutFailedEvent,