From cf425181e2ce0fbbd21533ff24e8fa7f1e9d5292 Mon Sep 17 00:00:00 2001
From: Noah Stapp
Date: Thu, 6 Jun 2024 16:58:48 -0700
Subject: [PATCH 01/11] PYTHON-4476 Separate data and IO classes more
 effectively

---
 gridfs/asynchronous/grid_file.py | 104 +-
 gridfs/grid_file_shared.py | 2 +-
 gridfs/synchronous/grid_file.py | 54 +-
 pymongo/__init__.py | 10 +-
 pymongo/asynchronous/aggregation.py | 22 +-
 pymongo/asynchronous/auth.py | 236 +-
 pymongo/asynchronous/auth_aws.py | 6 +-
 pymongo/asynchronous/auth_oidc.py | 132 +-
 pymongo/asynchronous/bulk.py | 50 +-
 pymongo/asynchronous/change_stream.py | 19 +-
 pymongo/asynchronous/client_options.py | 334 ---
 pymongo/asynchronous/client_session.py | 76 +-
 pymongo/asynchronous/collation.py | 226 --
 pymongo/asynchronous/collection.py | 258 +--
 pymongo/asynchronous/command_cursor.py | 18 +-
 pymongo/asynchronous/cursor.py | 46 +-
 pymongo/asynchronous/database.py | 95 +-
 pymongo/asynchronous/encryption.py | 18 +-
 pymongo/asynchronous/encryption_options.py | 270 ---
 pymongo/asynchronous/event_loggers.py | 225 --
 pymongo/asynchronous/helpers.py | 257 +--
 pymongo/asynchronous/message.py | 58 +-
 pymongo/asynchronous/mongo_client.py | 197 +-
 pymongo/asynchronous/monitor.py | 17 +-
 pymongo/asynchronous/monitoring.py | 1903 -----------------
 pymongo/asynchronous/network.py | 57 +-
 pymongo/asynchronous/operations.py | 625 ------
 pymongo/asynchronous/pool.py | 548 +----
 pymongo/asynchronous/read_preferences.py | 624 ------
 pymongo/asynchronous/server.py | 23 +-
 pymongo/asynchronous/server_description.py | 301 ---
 pymongo/asynchronous/settings.py | 12 +-
 pymongo/asynchronous/topology.py | 49 +-
 pymongo/asynchronous/topology_description.py | 678 ------
 pymongo/asynchronous/uri_parser.py | 624 ------
 pymongo/auth.py | 1 +
 pymongo/auth_oidc_shared.py | 118 +
 pymongo/auth_shared.py | 236 ++
 pymongo/client_options.py | 335 ++-
 pymongo/collation.py | 215 +-
 pymongo/{asynchronous => }/common.py | 18 +-
 .../{asynchronous => }/compression_support.py | 4 +-
 pymongo/encryption_options.py | 258 ++-
 pymongo/errors.py | 2 +-
 pymongo/event_loggers.py | 214 +-
 pymongo/{asynchronous => }/hello.py | 6 +-
 pymongo/{asynchronous => }/hello_compat.py | 0
 pymongo/helpers_constants.py | 72 -
 pymongo/helpers_shared.py | 328 +++
 pymongo/{asynchronous => }/logger.py | 2 +-
 .../max_staleness_selectors.py | 2 +-
 pymongo/monitoring.py | 1903 ++++++++++++++++-
 pymongo/operations.py | 614 +++++-
 pymongo/pool_options.py | 484 +++++
 pymongo/read_preferences.py | 615 +++++-
 pymongo/{asynchronous => }/response.py | 11 +-
 pymongo/server_description.py | 290 ++-
 .../{asynchronous => }/server_selectors.py | 4 +-
 pymongo/{asynchronous => }/srv_resolver.py | 2 +-
 pymongo/synchronous/aggregation.py | 10 +-
 pymongo/synchronous/auth.py | 220 +-
 pymongo/synchronous/auth_aws.py | 2 +-
 pymongo/synchronous/auth_oidc.py | 112 +-
 pymongo/synchronous/bulk.py | 26 +-
 pymongo/synchronous/change_stream.py | 9 +-
 pymongo/synchronous/client_options.py | 334 ---
 pymongo/synchronous/client_session.py | 8 +-
 pymongo/synchronous/collation.py | 226 --
 pymongo/synchronous/collection.py | 56 +-
 pymongo/synchronous/command_cursor.py | 4 +-
 pymongo/synchronous/common.py | 1062 ---------
 pymongo/synchronous/compression_support.py | 178 --
 pymongo/synchronous/cursor.py | 28 +-
 pymongo/synchronous/database.py | 13 +-
 pymongo/synchronous/encryption.py | 18 +-
 pymongo/synchronous/encryption_options.py | 270 ---
 pymongo/synchronous/event_loggers.py | 225 --
 pymongo/synchronous/hello.py | 219 --
 pymongo/synchronous/hello_compat.py | 26 -
 pymongo/synchronous/helpers.py | 253 +--
 pymongo/synchronous/logger.py | 171 --
 .../synchronous/max_staleness_selectors.py | 125 --
 pymongo/synchronous/message.py | 20 +-
 pymongo/synchronous/mongo_client.py | 49 +-
 pymongo/synchronous/monitor.py | 13 +-
 pymongo/synchronous/monitoring.py | 1903 -----------------
 pymongo/synchronous/network.py | 39 +-
 pymongo/synchronous/operations.py | 625 ------
 pymongo/synchronous/pool.py | 516 +----
 pymongo/synchronous/read_preferences.py | 624 ------
 pymongo/synchronous/response.py | 6 +-
 pymongo/synchronous/server.py | 17 +-
 pymongo/synchronous/server_description.py | 301 ---
 pymongo/synchronous/server_selectors.py | 175 --
 pymongo/synchronous/settings.py | 12 +-
 pymongo/synchronous/srv_resolver.py | 149 --
 pymongo/synchronous/topology.py | 29 +-
 pymongo/synchronous/topology_description.py | 678 ------
 pymongo/synchronous/typings.py | 61 -
 pymongo/synchronous/uri_parser.py | 624 ------
 pymongo/topology_description.py | 679 +++++-
 pymongo/{asynchronous => }/typings.py | 19 +-
 pymongo/uri_parser.py | 625 +++++-
 test/__init__.py | 9 +-
 test/asynchronous/__init__.py | 11 +-
 test/asynchronous/test_collection.py | 9 +-
 test/auth_oidc/test_auth_oidc.py | 2 +-
 test/lambda/mongodb/app.py | 2 +-
 .../mockupdb/test_mongos_command_read_mode.py | 2 +-
 .../test_network_disconnect_primary.py | 2 +-
 test/mockupdb/test_op_msg.py | 2 +-
 test/mockupdb/test_op_msg_read_preference.py | 2 +-
 test/mockupdb/test_query_read_pref_sharded.py | 2 +-
 test/mockupdb/test_reset_and_request_check.py | 2 +-
 test/mockupdb/test_slave_okay_sharded.py | 2 +-
 test/mockupdb/test_slave_okay_single.py | 4 +-
 test/pymongo_mocks.py | 7 +-
 test/sigstop_sigcont.py | 2 +-
 test/synchronous/__init__.py | 11 +-
 test/synchronous/test_collection.py | 9 +-
 test/test_auth.py | 10 +-
 test/test_binary.py | 2 +-
 test/test_bulk.py | 4 +-
 test/test_client.py | 35 +-
 test/test_collation.py | 6 +-
 test/test_collection.py | 4 +-
 test/test_comment.py | 2 +-
 test/test_connection_monitoring.py | 6 +-
 ...nnections_survive_primary_stepdown_spec.py | 2 +-
 test/test_crud_v1.py | 13 +-
 test/test_cursor.py | 7 +-
 test/test_database.py | 14 +-
 test/test_default_exports.py | 2 +-
 test/test_discovery_and_monitoring.py | 16 +-
 test/test_dns.py | 4 +-
 test/test_encryption.py | 4 +-
 test/test_examples.py | 2 +-
 test/test_gridfs.py | 2 +-
 test/test_gridfs_bucket.py | 2 +-
 test/test_heartbeat_monitoring.py | 2 +-
 test/test_index_management.py | 2 +-
 test/test_logger.py | 2 +-
 test/test_max_staleness.py | 4 +-
 test/test_mongos_load_balancing.py | 6 +-
 test/test_monitoring.py | 5 +-
 test/test_pooling.py | 2 +-
 test/test_read_preferences.py | 14 +-
 test/test_read_write_concern_spec.py | 2 +-
 test/test_retryable_reads.py | 4 +-
 test/test_retryable_writes.py | 6 +-
 test/test_sdam_monitoring_spec.py | 11 +-
 test/test_server.py | 4 +-
 test/test_server_description.py | 4 +-
 test/test_server_selection.py | 8 +-
 test/test_server_selection_in_window.py | 6 +-
 test/test_server_selection_rtt.py | 2 +-
 test/test_session.py | 7 +-
 test/test_srv_polling.py | 19 +-
 test/test_ssl.py | 2 +-
 test/test_streaming_protocol.py | 4 +-
 test/test_topology.py | 14 +-
 test/test_transactions.py | 4 +-
 test/test_typing.py | 4 +-
 test/test_uri_parser.py | 2 +-
 test/test_uri_spec.py | 6 +-
 test/unified_format.py | 38 +-
 test/utils.py | 25 +-
 test/utils_selection_tests.py | 10 +-
 test/utils_spec_runner.py | 2 +-
 tools/synchro.py | 3 +
 170 files changed, 8017 insertions(+), 17091 deletions(-)
 delete mode 100644 pymongo/asynchronous/client_options.py
 delete mode 100644 pymongo/asynchronous/collation.py
 delete mode 100644 pymongo/asynchronous/encryption_options.py
 delete mode 100644 pymongo/asynchronous/event_loggers.py
 delete mode 100644 pymongo/asynchronous/monitoring.py
 delete mode 100644 pymongo/asynchronous/operations.py
 delete mode 100644 pymongo/asynchronous/read_preferences.py
 delete mode 100644 pymongo/asynchronous/server_description.py
 delete mode 100644 pymongo/asynchronous/topology_description.py
 delete mode 100644 pymongo/asynchronous/uri_parser.py
 create mode 100644 pymongo/auth_oidc_shared.py
 create mode 100644 pymongo/auth_shared.py
 rename pymongo/{asynchronous => }/common.py (98%)
 rename pymongo/{asynchronous => }/compression_support.py (97%)
 rename pymongo/{asynchronous => }/hello.py (97%)
 rename pymongo/{asynchronous => }/hello_compat.py (100%)
 delete mode 100644 pymongo/helpers_constants.py
 create mode 100644 pymongo/helpers_shared.py
 rename pymongo/{asynchronous => }/logger.py (98%)
 rename pymongo/{asynchronous => }/max_staleness_selectors.py (98%)
 create mode 100644 pymongo/pool_options.py
 rename pymongo/{asynchronous => }/response.py (93%)
 rename pymongo/{asynchronous => }/server_selectors.py (97%)
 rename pymongo/{asynchronous => }/srv_resolver.py (98%)
 delete mode 100644 pymongo/synchronous/client_options.py
 delete mode 100644 pymongo/synchronous/collation.py
 delete mode 100644 pymongo/synchronous/common.py
 delete mode 100644 pymongo/synchronous/compression_support.py
 delete mode 100644 pymongo/synchronous/encryption_options.py
 delete mode 100644 pymongo/synchronous/event_loggers.py
 delete mode 100644 pymongo/synchronous/hello.py
 delete mode 100644 pymongo/synchronous/hello_compat.py
 delete mode 100644 pymongo/synchronous/logger.py
 delete mode 100644 pymongo/synchronous/max_staleness_selectors.py
 delete mode 100644 pymongo/synchronous/monitoring.py
 delete mode 100644 pymongo/synchronous/operations.py
 delete mode 100644 pymongo/synchronous/read_preferences.py
 delete mode 100644 pymongo/synchronous/server_description.py
 delete mode 100644 pymongo/synchronous/server_selectors.py
 delete mode 100644 pymongo/synchronous/srv_resolver.py
 delete mode 100644 pymongo/synchronous/topology_description.py
 delete mode 100644 pymongo/synchronous/typings.py
 delete mode 100644 pymongo/synchronous/uri_parser.py
 rename pymongo/{asynchronous => }/typings.py (60%)

diff --git a/gridfs/asynchronous/grid_file.py b/gridfs/asynchronous/grid_file.py
index 08174fd9d4..9546429a39 100644
--- a/gridfs/asynchronous/grid_file.py
+++ b/gridfs/asynchronous/grid_file.py
@@ -42,13 +42,12 @@
     _clear_entity_type_registry,
 )
 from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot
-from pymongo.asynchronous.client_session import ClientSession
+from pymongo.asynchronous.client_session import AsyncClientSession
 from pymongo.asynchronous.collection import AsyncCollection
-from pymongo.asynchronous.common import validate_string
 from pymongo.asynchronous.cursor import AsyncCursor
 from pymongo.asynchronous.database import AsyncDatabase
-from pymongo.asynchronous.helpers import _check_write_command_response, anext
-from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode
+from pymongo.asynchronous.helpers import anext
+from pymongo.common import validate_string
 from pymongo.errors import (
     BulkWriteError,
     ConfigurationError,
@@ -57,11 +56,13 @@
     InvalidOperation,
     OperationFailure,
 )
+from pymongo.helpers_shared import _check_write_command_response
+from pymongo.read_preferences import ReadPreference, _ServerMode

 _IS_SYNC = False


-def _disallow_transactions(session: Optional[ClientSession]) -> None:
+def _disallow_transactions(session: Optional[AsyncClientSession]) -> None:
     if session and session.in_transaction:
         raise InvalidOperation("GridFS does not support multi-document transactions")

@@ -155,7 +156,7 @@ async def put(self, data: Any, **kwargs: Any) -> Any:
         await grid_file.write(data)
         return await grid_file._id

-    async def get(self, file_id: Any, session: Optional[ClientSession] = None) -> AsyncGridOut:
+    async def get(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> AsyncGridOut:
         """Get a file from GridFS by ``"_id"``.

         Returns an instance of :class:`~gridfs.grid_file.GridOut`,
@@ -163,7 +164,7 @@ async def get(self, file_id: Any, session: Optional[ClientSession] = None) -> As

         :param file_id: ``"_id"`` of the file to get
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -178,7 +179,7 @@ async def get_version(
         self,
         filename: Optional[str] = None,
         version: Optional[int] = -1,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
         **kwargs: Any,
     ) -> AsyncGridOut:
         """Get a file from GridFS by ``"filename"`` or metadata fields.
@@ -205,7 +206,7 @@ async def get_version(
         :param version: version of the file to get (defaults to -1, the
             most recent version uploaded)
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: find files by custom metadata.

         .. versionchanged:: 3.6
@@ -234,7 +235,10 @@ async def get_version(
         raise NoFile("no version %d for filename %r" % (version, filename)) from None

     async def get_last_version(
-        self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
+        self,
+        filename: Optional[str] = None,
+        session: Optional[AsyncClientSession] = None,
+        **kwargs: Any,
     ) -> AsyncGridOut:
         """Get the most recent version of a file in GridFS by ``"filename"``
         or metadata fields.
@@ -244,7 +248,7 @@ async def get_last_version(

         :param filename: ``"filename"`` of the file to get, or `None`
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: find files by custom metadata.

         .. versionchanged:: 3.6
@@ -253,7 +257,7 @@ async def get_last_version(
         return await self.get_version(filename=filename, session=session, **kwargs)

     # TODO add optional safe mode for chunk removal?
-    async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
+    async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None:
         """Delete a file from GridFS by ``"_id"``.

         Deletes all data belonging to the file with ``"_id"``:
@@ -269,7 +273,7 @@ async def delete(self, file_id: Any, session: Optional[ClientSession] = None) ->

         :param file_id: ``"_id"`` of the file to delete
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -281,12 +285,12 @@ async def delete(self, file_id: Any, session: Optional[ClientSession] = None) ->
         await self._files.delete_one({"_id": file_id}, session=session)
         await self._chunks.delete_many({"files_id": file_id}, session=session)

-    async def list(self, session: Optional[ClientSession] = None) -> list[str]:
+    async def list(self, session: Optional[AsyncClientSession] = None) -> list[str]:
         """List the names of all files stored in this instance of
         :class:`GridFS`.

         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -306,7 +310,7 @@ async def list(self, session: Optional[ClientSession] = None) -> list[str]:
     async def find_one(
         self,
         filter: Optional[Any] = None,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
         *args: Any,
         **kwargs: Any,
     ) -> Optional[AsyncGridOut]:
@@ -327,7 +331,7 @@ async def find_one(
         :param args: any additional positional arguments are the
             same as the arguments to :meth:`find`.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: any additional keyword arguments
             are the same as the arguments to :meth:`find`.
@@ -370,7 +374,7 @@ def find(self, *args: Any, **kwargs: Any) -> AsyncGridOutCursor:
         :meth:`~pymongo.collection.Collection.find`
         in :class:`~pymongo.collection.Collection`.

-        If a :class:`~pymongo.client_session.ClientSession` is passed to
+        If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
         :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
         are associated with that session.
@@ -406,7 +410,7 @@ def find(self, *args: Any, **kwargs: Any) -> AsyncGridOutCursor:
     async def exists(
         self,
         document_or_id: Optional[Any] = None,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
         **kwargs: Any,
     ) -> bool:
         """Check if a file exists in this instance of :class:`GridFS`.
@@ -438,7 +442,7 @@ async def exists(
         :param document_or_id: query document, or _id of the
             document to check for
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: keyword arguments are used as a
             query document, if they're present.
@@ -525,7 +529,7 @@ def open_upload_stream(
         filename: str,
         chunk_size_bytes: Optional[int] = None,
         metadata: Optional[Mapping[str, Any]] = None,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
     ) -> AsyncGridIn:
         """Opens a Stream that the application can write the contents of the
         file to.
@@ -556,7 +560,7 @@ def open_upload_stream(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -580,7 +584,7 @@ def open_upload_stream_with_id(
         filename: str,
         chunk_size_bytes: Optional[int] = None,
         metadata: Optional[Mapping[str, Any]] = None,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
     ) -> AsyncGridIn:
         """Opens a Stream that the application can write the contents of the
         file to.
@@ -615,7 +619,7 @@ def open_upload_stream_with_id(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -641,7 +645,7 @@ async def upload_from_stream(
         source: Any,
         chunk_size_bytes: Optional[int] = None,
         metadata: Optional[Mapping[str, Any]] = None,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
     ) -> ObjectId:
         """Uploads a user file to a GridFS bucket.

@@ -672,7 +676,7 @@ async def upload_from_stream(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -692,7 +696,7 @@ async def upload_from_stream_with_id(
         source: Any,
         chunk_size_bytes: Optional[int] = None,
         metadata: Optional[Mapping[str, Any]] = None,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
     ) -> None:
         """Uploads a user file to a GridFS bucket with a custom file id.

@@ -724,7 +728,7 @@ async def upload_from_stream_with_id(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -735,7 +739,7 @@ async def upload_from_stream_with_id(
         await gin.write(source)

     async def open_download_stream(
-        self, file_id: Any, session: Optional[ClientSession] = None
+        self, file_id: Any, session: Optional[AsyncClientSession] = None
     ) -> AsyncGridOut:
         """Opens a Stream from which the application can read the contents of
         the stored file specified by file_id.
@@ -755,7 +759,7 @@ async def open_download_stream(

         :param file_id: The _id of the file to be downloaded.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -768,7 +772,7 @@ async def open_download_stream(

     @_csot.apply
     async def download_to_stream(
-        self, file_id: Any, destination: Any, session: Optional[ClientSession] = None
+        self, file_id: Any, destination: Any, session: Optional[AsyncClientSession] = None
     ) -> None:
         """Downloads the contents of the stored file specified by file_id and
         writes the contents to `destination`.
@@ -790,7 +794,7 @@ async def download_to_stream(

         :param file_id: The _id of the file to be downloaded.
         :param destination: a file-like object implementing :meth:`write`.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -803,7 +807,7 @@ async def download_to_stream(
             destination.write(chunk)

     @_csot.apply
-    async def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:
+    async def delete(self, file_id: Any, session: Optional[AsyncClientSession] = None) -> None:
         """Given an file_id, delete this stored file's files collection document
         and associated chunks from a GridFS bucket.

@@ -819,7 +823,7 @@ async def delete(self, file_id: Any, session: Optional[ClientSession] = None) ->

         :param file_id: The _id of the file to be deleted.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -859,7 +863,7 @@ def find(self, *args: Any, **kwargs: Any) -> AsyncGridOutCursor:
         :meth:`~pymongo.collection.Collection.find`
         in :class:`~pymongo.collection.Collection`.

-        If a :class:`~pymongo.client_session.ClientSession` is passed to
+        If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
         :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
         are associated with that session.
@@ -878,7 +882,7 @@ def find(self, *args: Any, **kwargs: Any) -> AsyncGridOutCursor:
         return AsyncGridOutCursor(self._collection, *args, **kwargs)

     async def open_download_stream_by_name(
-        self, filename: str, revision: int = -1, session: Optional[ClientSession] = None
+        self, filename: str, revision: int = -1, session: Optional[AsyncClientSession] = None
     ) -> AsyncGridOut:
         """Opens a Stream from which the application can read the contents of
         `filename` and optional `revision`.
@@ -902,7 +906,7 @@ async def open_download_stream_by_name(
             filename and different uploadDate) of the file to retrieve.
             Defaults to -1 (the most recent revision).
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         :Note: Revision numbers are defined as follows:
@@ -937,7 +941,7 @@ async def download_to_stream_by_name(
         filename: str,
         destination: Any,
         revision: int = -1,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
     ) -> None:
         """Write the contents of `filename` (with optional `revision`) to
         `destination`.
@@ -961,7 +965,7 @@ async def download_to_stream_by_name(
             filename and different uploadDate) of the file to retrieve.
             Defaults to -1 (the most recent revision).
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         :Note: Revision numbers are defined as follows:
@@ -985,7 +989,7 @@ async def download_to_stream_by_name(
             destination.write(chunk)

     async def rename(
-        self, file_id: Any, new_filename: str, session: Optional[ClientSession] = None
+        self, file_id: Any, new_filename: str, session: Optional[AsyncClientSession] = None
     ) -> None:
         """Renames the stored file with the specified file_id.

@@ -1002,7 +1006,7 @@ async def rename(
         :param file_id: The _id of the file to be renamed.
         :param new_filename: The new name of the file.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -1024,7 +1028,7 @@ class AsyncGridIn:
     def __init__(
         self,
         root_collection: AsyncCollection,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
         **kwargs: Any,
     ) -> None:
         """Write a file to GridFS
@@ -1059,7 +1063,7 @@ def __init__(

         :param root_collection: root collection to write to
         :param session: a
-            :class:`~pymongo.client_session.ClientSession` to use for all
+            :class:`~pymongo.client_session.AsyncClientSession` to use for all
             commands
         :param kwargs: Any: file level options (see above)
@@ -1402,7 +1406,7 @@ def __init__(
         root_collection: AsyncCollection,
         file_id: Optional[int] = None,
         file_document: Optional[Any] = None,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
     ) -> None:
         """Read a file from GridFS

@@ -1420,7 +1424,7 @@ def __init__(
         :param file_document: file document from
             `root_collection.files`
         :param session: a
-            :class:`~pymongo.client_session.ClientSession` to use for all
+            :class:`~pymongo.client_session.AsyncClientSession` to use for all
             commands

         .. versionchanged:: 3.8
@@ -1734,7 +1738,7 @@ def __init__(
         self,
         grid_out: AsyncGridOut,
         chunks: AsyncCollection,
-        session: Optional[ClientSession],
+        session: Optional[AsyncClientSession],
         next_chunk: Any,
     ) -> None:
         self._id = grid_out._id
@@ -1824,7 +1828,9 @@ async def close(self) -> None:


 class AsyncGridOutIterator:
-    def __init__(self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: ClientSession):
+    def __init__(
+        self, grid_out: AsyncGridOut, chunks: AsyncCollection, session: AsyncClientSession
+    ):
         self._chunk_iter = _AsyncGridOutChunkIterator(grid_out, chunks, session, 0)

     def __aiter__(self) -> AsyncGridOutIterator:
@@ -1851,7 +1857,7 @@ def __init__(
         no_cursor_timeout: bool = False,
         sort: Optional[Any] = None,
         batch_size: int = 0,
-        session: Optional[ClientSession] = None,
+        session: Optional[AsyncClientSession] = None,
     ) -> None:
         """Create a new cursor, similar to the normal
         :class:`~pymongo.cursor.Cursor`.
@@ -1894,6 +1900,6 @@ def add_option(self, *args: Any, **kwargs: Any) -> NoReturn:
     def remove_option(self, *args: Any, **kwargs: Any) -> NoReturn:
         raise NotImplementedError("Method does not exist for GridOutCursor")

-    def _clone_base(self, session: Optional[ClientSession]) -> AsyncGridOutCursor:
+    def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncGridOutCursor:
         """Creates an empty GridOutCursor for information to be copied into."""
         return AsyncGridOutCursor(self._root_collection, session=session)
diff --git a/gridfs/grid_file_shared.py b/gridfs/grid_file_shared.py
index f6c37b9f33..b6f02a53df 100644
--- a/gridfs/grid_file_shared.py
+++ b/gridfs/grid_file_shared.py
@@ -5,7 +5,7 @@
 from typing import Any, Optional

 from pymongo import ASCENDING
-from pymongo.asynchronous.common import MAX_MESSAGE_SIZE
+from pymongo.common import MAX_MESSAGE_SIZE
 from pymongo.errors import InvalidOperation

 _SEEK_SET = os.SEEK_SET
diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py
index 0e98429920..ee43f01897 100644
--- a/gridfs/synchronous/grid_file.py
+++ b/gridfs/synchronous/grid_file.py
@@ -42,6 +42,7 @@
     _grid_out_property,
 )
 from pymongo import ASCENDING, DESCENDING, WriteConcern, _csot
+from pymongo.common import validate_string
 from pymongo.errors import (
     BulkWriteError,
     ConfigurationError,
@@ -50,13 +51,13 @@
     InvalidOperation,
     OperationFailure,
 )
+from pymongo.helpers_shared import _check_write_command_response
+from pymongo.read_preferences import ReadPreference, _ServerMode
 from pymongo.synchronous.client_session import ClientSession
 from pymongo.synchronous.collection import Collection
-from pymongo.synchronous.common import validate_string
 from pymongo.synchronous.cursor import Cursor
 from pymongo.synchronous.database import Database
-from pymongo.synchronous.helpers import _check_write_command_response, next
-from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode
+from pymongo.synchronous.helpers import next

 _IS_SYNC = True

@@ -163,7 +164,7 @@ def get(self, file_id: Any, session: Optional[ClientSession] = None) -> GridOut:

         :param file_id: ``"_id"`` of the file to get
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -205,7 +206,7 @@ def get_version(
         :param version: version of the file to get (defaults to -1, the
             most recent version uploaded)
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: find files by custom metadata.

         .. versionchanged:: 3.6
@@ -234,7 +235,10 @@ def get_version(
         raise NoFile("no version %d for filename %r" % (version, filename)) from None

     def get_last_version(
-        self, filename: Optional[str] = None, session: Optional[ClientSession] = None, **kwargs: Any
+        self,
+        filename: Optional[str] = None,
+        session: Optional[ClientSession] = None,
+        **kwargs: Any,
     ) -> GridOut:
         """Get the most recent version of a file in GridFS by ``"filename"``
         or metadata fields.
@@ -244,7 +248,7 @@ def get_last_version(

         :param filename: ``"filename"`` of the file to get, or `None`
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: find files by custom metadata.

         .. versionchanged:: 3.6
@@ -269,7 +273,7 @@ def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:

         :param file_id: ``"_id"`` of the file to delete
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -286,7 +290,7 @@ def list(self, session: Optional[ClientSession] = None) -> list[str]:
         :class:`GridFS`.

         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -325,7 +329,7 @@ def find_one(
         :param args: any additional positional arguments are the
             same as the arguments to :meth:`find`.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: any additional keyword arguments
             are the same as the arguments to :meth:`find`.
@@ -368,7 +372,7 @@ def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
         :meth:`~pymongo.collection.Collection.find`
         in :class:`~pymongo.collection.Collection`.

-        If a :class:`~pymongo.client_session.ClientSession` is passed to
+        If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
         :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
         are associated with that session.
@@ -436,7 +440,7 @@ def exists(
         :param document_or_id: query document, or _id of the
             document to check for
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`
         :param kwargs: keyword arguments are used as a
             query document, if they're present.
@@ -554,7 +558,7 @@ def open_upload_stream(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -613,7 +617,7 @@ def open_upload_stream_with_id(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -670,7 +674,7 @@ def upload_from_stream(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -720,7 +724,7 @@ def upload_from_stream_with_id(
             files collection document. If not provided the metadata field will
             be omitted from the files collection document.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -751,7 +755,7 @@ def open_download_stream(

         :param file_id: The _id of the file to be downloaded.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -786,7 +790,7 @@ def download_to_stream(
         :param file_id: The _id of the file to be downloaded.
         :param destination: a file-like object implementing :meth:`write`.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -815,7 +819,7 @@ def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None:

         :param file_id: The _id of the file to be deleted.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -855,7 +859,7 @@ def find(self, *args: Any, **kwargs: Any) -> GridOutCursor:
         :meth:`~pymongo.collection.Collection.find`
         in :class:`~pymongo.collection.Collection`.

-        If a :class:`~pymongo.client_session.ClientSession` is passed to
+        If a :class:`~pymongo.client_session.AsyncClientSession` is passed to
         :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances
         are associated with that session.
@@ -898,7 +902,7 @@ def open_download_stream_by_name(
             filename and different uploadDate) of the file to retrieve.
             Defaults to -1 (the most recent revision).
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         :Note: Revision numbers are defined as follows:
@@ -957,7 +961,7 @@ def download_to_stream_by_name(
             filename and different uploadDate) of the file to retrieve.
             Defaults to -1 (the most recent revision).
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         :Note: Revision numbers are defined as follows:
@@ -996,7 +1000,7 @@ def rename(
         :param file_id: The _id of the file to be renamed.
         :param new_filename: The new name of the file.
         :param session: a
-            :class:`~pymongo.client_session.ClientSession`
+            :class:`~pymongo.client_session.AsyncClientSession`

         .. versionchanged:: 3.6
            Added ``session`` parameter.
@@ -1053,7 +1057,7 @@ def __init__(

         :param root_collection: root collection to write to
         :param session: a
-            :class:`~pymongo.client_session.ClientSession` to use for all
+            :class:`~pymongo.client_session.AsyncClientSession` to use for all
             commands
         :param kwargs: Any: file level options (see above)
@@ -1408,7 +1412,7 @@ def __init__(
         :param file_document: file document from
             `root_collection.files`
         :param session: a
-            :class:`~pymongo.client_session.ClientSession` to use for all
+            :class:`~pymongo.client_session.AsyncClientSession` to use for all
             commands

         .. versionchanged:: 3.8
diff --git a/pymongo/__init__.py b/pymongo/__init__.py
index 8992281db8..7ee177bdae 100644
--- a/pymongo/__init__.py
+++ b/pymongo/__init__.py
@@ -89,11 +89,9 @@
 from pymongo import _csot
 from pymongo._version import __version__, get_version_string, version_tuple
 from pymongo.asynchronous.mongo_client import AsyncMongoClient
+from pymongo.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
 from pymongo.cursor import CursorType
-from pymongo.synchronous.collection import ReturnDocument
-from pymongo.synchronous.common import MAX_SUPPORTED_WIRE_VERSION, MIN_SUPPORTED_WIRE_VERSION
-from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.operations import (
+from pymongo.operations import (
     DeleteMany,
     DeleteOne,
     IndexModel,
@@ -102,7 +100,9 @@
     UpdateMany,
     UpdateOne,
 )
-from pymongo.synchronous.read_preferences import ReadPreference
+from pymongo.read_preferences import ReadPreference
+from pymongo.synchronous.collection import ReturnDocument
+from pymongo.synchronous.mongo_client import MongoClient
 from pymongo.write_concern import WriteConcern

 version = __version__
diff --git a/pymongo/asynchronous/aggregation.py b/pymongo/asynchronous/aggregation.py
index 9fc2dae3c4..fa6cefd53a 100644
--- a/pymongo/asynchronous/aggregation.py
+++ b/pymongo/asynchronous/aggregation.py
@@ -18,20 +18,20 @@
 from collections.abc import Callable, Mapping, MutableMapping
 from typing import TYPE_CHECKING, Any, Optional, Union

-from pymongo.asynchronous import common
-from pymongo.asynchronous.collation import validate_collation_or_none
-from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref
+from pymongo import common
+from pymongo.collation import validate_collation_or_none
 from pymongo.errors import ConfigurationError
+from pymongo.read_preferences import ReadPreference, _AggWritePref

 if TYPE_CHECKING:
-    from pymongo.asynchronous.client_session import ClientSession
+    from pymongo.asynchronous.client_session import AsyncClientSession
     from pymongo.asynchronous.collection import AsyncCollection
     from pymongo.asynchronous.command_cursor import AsyncCommandCursor
     from pymongo.asynchronous.database import AsyncDatabase
-    from pymongo.asynchronous.pool import Connection
-    from pymongo.asynchronous.read_preferences import _ServerMode
+    from pymongo.asynchronous.pool import AsyncConnection
     from pymongo.asynchronous.server import Server
-    from pymongo.asynchronous.typings import _DocumentType, _Pipeline
+    from pymongo.read_preferences import _ServerMode
+    from pymongo.typings import _DocumentType, _Pipeline

 _IS_SYNC = False

@@ -53,7 +53,7 @@ def __init__(
         explicit_session: bool,
         let: Optional[Mapping[str, Any]] = None,
         user_fields: Optional[MutableMapping[str, Any]] = None,
-        result_processor: Optional[Callable[[Mapping[str, Any], Connection], None]] = None,
+        result_processor: Optional[Callable[[Mapping[str, Any], AsyncConnection], None]] = None,
         comment: Any = None,
     ) -> None:
         if "explain" in options:
@@ -121,7 +121,7 @@ def _database(self) -> AsyncDatabase:
         raise NotImplementedError

     def get_read_preference(
-        self, session: Optional[ClientSession]
+        self, session: Optional[AsyncClientSession]
     ) -> Union[_AggWritePref, _ServerMode]:
         if self._write_preference:
             return self._write_preference
@@ -132,9 +132,9 @@ def get_read_preference(

     async def get_cursor(
         self,
-        session: Optional[ClientSession],
+        session: Optional[AsyncClientSession],
         server: Server,
-        conn: Connection,
+        conn: AsyncConnection,
         read_preference: _ServerMode,
     ) -> AsyncCommandCursor[_DocumentType]:
         # Serialize command.
diff --git a/pymongo/asynchronous/auth.py b/pymongo/asynchronous/auth.py
index 41e012022f..1fb28f6c49 100644
--- a/pymongo/asynchronous/auth.py
+++ b/pymongo/asynchronous/auth.py
@@ -18,17 +18,13 @@
 import functools
 import hashlib
 import hmac
-import os
 import socket
-import typing
 from base64 import standard_b64decode, standard_b64encode
-from collections import namedtuple
 from typing import (
     TYPE_CHECKING,
     Any,
     Callable,
     Coroutine,
-    Dict,
     Mapping,
     MutableMapping,
     Optional,
@@ -41,17 +37,19 @@
 from pymongo.asynchronous.auth_oidc import (
     _authenticate_oidc,
     _get_authenticator,
-    _OIDCAzureCallback,
-    _OIDCGCPCallback,
-    _OIDCProperties,
-    _OIDCTestCallback,
+)
+from pymongo.auth_shared import (
+    MongoCredential,
+    _authenticate_scram_start,
+    _parse_scram_response,
+    _xor,
 )
 from pymongo.errors import ConfigurationError, OperationFailure
 from pymongo.saslprep import saslprep

 if TYPE_CHECKING:
-    from pymongo.asynchronous.hello import Hello
-    from pymongo.asynchronous.pool import Connection
+    from pymongo.asynchronous.pool import AsyncConnection
+    from pymongo.hello import Hello

 HAVE_KERBEROS = True
 _USE_PRINCIPAL = False
@@ -69,213 +67,9 @@

 _IS_SYNC = False

-MECHANISMS = frozenset(
-    [
-        "GSSAPI",
-        "MONGODB-CR",
-        "MONGODB-OIDC",
-        "MONGODB-X509",
-        "MONGODB-AWS",
-        "PLAIN",
-        "SCRAM-SHA-1",
-        "SCRAM-SHA-256",
-        "DEFAULT",
-    ]
-)
-"""The authentication mechanisms supported by PyMongo."""
-
-
-class _Cache:
-    __slots__ = ("data",)
-
-    _hash_val = hash("_Cache")
-
-    def __init__(self) -> None:
-        self.data = None
-
-    def __eq__(self, other: object) -> bool:
-        # Two instances must always compare equal.
-        if isinstance(other, _Cache):
-            return True
-        return NotImplemented
-
-    def __ne__(self, other: object) -> bool:
-        if isinstance(other, _Cache):
-            return False
-        return NotImplemented
-
-    def __hash__(self) -> int:
-        return self._hash_val
-
-
-MongoCredential = namedtuple(
-    "MongoCredential",
-    ["mechanism", "source", "username", "password", "mechanism_properties", "cache"],
-)
-"""A hashable namedtuple of values used for authentication."""
-
-
-GSSAPIProperties = namedtuple(
-    "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"]
-)
-"""Mechanism properties for GSSAPI authentication."""
-
-
-_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"])
-"""Mechanism properties for MONGODB-AWS authentication."""
-
-
-def _build_credentials_tuple(
-    mech: str,
-    source: Optional[str],
-    user: str,
-    passwd: str,
-    extra: Mapping[str, Any],
-    database: Optional[str],
-) -> MongoCredential:
-    """Build and return a mechanism specific credentials tuple."""
-    if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None:
-        raise ConfigurationError(f"{mech} requires a username.")
-    if mech == "GSSAPI":
-        if source is not None and source != "$external":
-            raise ValueError("authentication source must be $external or None for GSSAPI")
-        properties = extra.get("authmechanismproperties", {})
-        service_name = properties.get("SERVICE_NAME", "mongodb")
-        canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False))
-        service_realm = properties.get("SERVICE_REALM")
-        props = GSSAPIProperties(
-            service_name=service_name,
-            canonicalize_host_name=canonicalize,
-            service_realm=service_realm,
-        )
-        # Source is always $external.
-        return MongoCredential(mech, "$external", user, passwd, props, None)
-    elif mech == "MONGODB-X509":
-        if passwd is not None:
-            raise ConfigurationError("Passwords are not supported by MONGODB-X509")
-        if source is not None and source != "$external":
-            raise ValueError("authentication source must be $external or None for MONGODB-X509")
-        # Source is always $external, user can be None.
-        return MongoCredential(mech, "$external", user, None, None, None)
-    elif mech == "MONGODB-AWS":
-        if user is not None and passwd is None:
-            raise ConfigurationError("username without a password is not supported by MONGODB-AWS")
-        if source is not None and source != "$external":
-            raise ConfigurationError(
-                "authentication source must be $external or None for MONGODB-AWS"
-            )
-
-        properties = extra.get("authmechanismproperties", {})
-        aws_session_token = properties.get("AWS_SESSION_TOKEN")
-        aws_props = _AWSProperties(aws_session_token=aws_session_token)
-        # user can be None for temporary link-local EC2 credentials.
-        return MongoCredential(mech, "$external", user, passwd, aws_props, None)
-    elif mech == "MONGODB-OIDC":
-        properties = extra.get("authmechanismproperties", {})
-        callback = properties.get("OIDC_CALLBACK")
-        human_callback = properties.get("OIDC_HUMAN_CALLBACK")
-        environ = properties.get("ENVIRONMENT")
-        token_resource = properties.get("TOKEN_RESOURCE", "")
-        default_allowed = [
-            "*.mongodb.net",
-            "*.mongodb-dev.net",
-            "*.mongodb-qa.net",
-            "*.mongodbgov.net",
-            "localhost",
-            "127.0.0.1",
-            "::1",
-        ]
-        allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed)
-        msg = (
-            "authentication with MONGODB-OIDC requires providing either a callback or a environment"
-        )
-        if passwd is not None:
-            msg = "password is not supported by MONGODB-OIDC"
-            raise ConfigurationError(msg)
-        if callback or human_callback:
-            if environ is not None:
-                raise ConfigurationError(msg)
-            if callback and human_callback:
-                msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK"
-                raise ConfigurationError(msg)
-        elif environ is not None:
-            if environ == "test":
-                if user is not None:
-                    msg = "test environment for MONGODB-OIDC does not support username"
-                    raise ConfigurationError(msg)
-                callback = _OIDCTestCallback()
-            elif environ == "azure":
-                passwd = None
-                if not token_resource:
-                    raise ConfigurationError(
-                        "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
-                    )
-                callback = _OIDCAzureCallback(token_resource)
-            elif environ == "gcp":
-                passwd = None
-                if not token_resource:
-                    raise ConfigurationError(
-                        "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property"
-                    )
-                callback = _OIDCGCPCallback(token_resource)
-            else:
-                raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}")
-        else:
-            raise ConfigurationError(msg)
-
-        oidc_props = _OIDCProperties(
-            callback=callback,
-            human_callback=human_callback,
-            environment=environ,
-            allowed_hosts=allowed_hosts,
-            token_resource=token_resource,
-            username=user,
-        )
-        return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache())
-
-    elif mech == "PLAIN":
-        source_database = source or database or "$external"
-        return MongoCredential(mech, source_database, user, passwd, None, None)
-    else:
-        source_database = source or database or "admin"
-        if passwd is None:
-            raise ConfigurationError("A password is required.")
-        return MongoCredential(mech, source_database, user, passwd, None, _Cache())
-
-
-def _xor(fir: bytes, sec: bytes) -> bytes:
-    """XOR two byte strings together."""
-    return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)])
-
-
-def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]:
-    """Split a scram response into key, value pairs."""
-    return dict(
-        typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1))
-        for item in response.split(b",")
-    )
-
-
-def _authenticate_scram_start(
-    credentials: MongoCredential, mechanism: str
-) -> tuple[bytes, bytes, MutableMapping[str, Any]]:
-    username = credentials.username
-    user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C")
-    nonce = standard_b64encode(os.urandom(32))
-    first_bare = b"n=" + user + b",r=" + nonce
-
-    cmd = {
-        "saslStart": 1,
-        "mechanism": mechanism,
-        "payload": Binary(b"n,," + first_bare),
-        "autoAuthorize": 1,
-        "options": {"skipEmptyExchange": True},
-    }
-    return nonce, first_bare, cmd
-

 async def _authenticate_scram(
-    credentials: MongoCredential, conn: Connection, mechanism: str
+    credentials: MongoCredential, conn: AsyncConnection, mechanism: str
 ) -> None:
     """Authenticate using SCRAM."""
     username = credentials.username
@@ -398,7 +192,7 @@ def _canonicalize_hostname(hostname: str) -> str:
     return name[0].lower()


-async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -> None:
+async def _authenticate_gssapi(credentials: MongoCredential, conn: AsyncConnection) -> None:
     """Authenticate using GSSAPI."""
     if not HAVE_KERBEROS:
         raise ConfigurationError(
@@ -509,7 +303,7 @@ async def _authenticate_gssapi(credentials: MongoCredential, conn: Connection) -
         raise OperationFailure(str(exc)) from None


-async def _authenticate_plain(credentials: MongoCredential, conn: Connection) -> None:
+async def _authenticate_plain(credentials: MongoCredential, conn: AsyncConnection) -> None:
     """Authenticate using SASL PLAIN (RFC 4616)"""
     source = credentials.source
     username = credentials.username
@@ -524,7 +318,7 @@ async def _authenticate_plain(credentials: MongoCredential, conn: Connection) ->
     await conn.command(source, cmd)


-async def _authenticate_x509(credentials: MongoCredential, conn: Connection) -> None:
+async def _authenticate_x509(credentials: MongoCredential, conn: AsyncConnection) -> None:
     """Authenticate using MONGODB-X509."""
     ctx = conn.auth_ctx
     if ctx and ctx.speculate_succeeded():
@@ -535,7 +329,7 @@ async def _authenticate_x509(credentials: MongoCredential, conn: Connection) ->
     await conn.command("$external", cmd)


-async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection) -> None:
+async def _authenticate_mongo_cr(credentials: MongoCredential, conn: AsyncConnection) -> None:
     """Authenticate using MONGODB-CR."""
     source = credentials.source
     username = credentials.username
@@ -550,7 +344,7 @@ async def _authenticate_mongo_cr(credentials: MongoCredential, conn: Connection)
     await conn.command(source, query)


-async def _authenticate_default(credentials: MongoCredential, conn: Connection) -> None:
+async def _authenticate_default(credentials: MongoCredential, conn: AsyncConnection) -> None:
     if conn.max_wire_version >= 7:
         if conn.negotiated_mechs:
             mechs = conn.negotiated_mechs
@@ -652,7 +446,7 @@ def speculate_command(self) -> Optional[MutableMapping[str, Any]]:


 async def authenticate(
-    credentials: MongoCredential, conn: Connection, reauthenticate: bool = False
+    credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool = False
 ) -> None:
     """Authenticate connection."""
     mechanism = credentials.mechanism
diff --git a/pymongo/asynchronous/auth_aws.py b/pymongo/asynchronous/auth_aws.py
index 7cab111b30..9dcc625d19 100644
--- a/pymongo/asynchronous/auth_aws.py
+++ b/pymongo/asynchronous/auth_aws.py
@@ -23,13 +23,13 @@

 if TYPE_CHECKING:
     from bson.typings import _ReadableBuffer
-    from pymongo.asynchronous.auth import MongoCredential
-    from pymongo.asynchronous.pool import Connection
+    from pymongo.asynchronous.pool import AsyncConnection
+    from pymongo.auth_shared import MongoCredential

 _IS_SYNC = False


-async def _authenticate_aws(credentials: MongoCredential, conn: Connection) -> None:
+async def _authenticate_aws(credentials: MongoCredential, conn: AsyncConnection) -> None:
     """Authenticate using MONGODB-AWS."""
     try:
         import pymongo_auth_aws  # type:ignore[import]
diff --git a/pymongo/asynchronous/auth_oidc.py b/pymongo/asynchronous/auth_oidc.py
index 022a173dc0..f5801b85d4 100644
--- a/pymongo/asynchronous/auth_oidc.py
+++ b/pymongo/asynchronous/auth_oidc.py
@@ -15,79 +15,35 @@
 """MONGODB-OIDC Authentication helpers."""
 from __future__ import annotations

-import abc
-import os
 import threading
 import time
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union
-from urllib.parse import quote

 import bson
 from bson.binary import Binary
-from pymongo._azure_helpers import _get_azure_response
 from pymongo._csot import remaining
-from pymongo._gcp_helpers import _get_gcp_response
+from pymongo.auth_oidc_shared import (
+    CALLBACK_VERSION,
+    HUMAN_CALLBACK_TIMEOUT_SECONDS,
+    MACHINE_CALLBACK_TIMEOUT_SECONDS,
+    TIME_BETWEEN_CALLS_SECONDS,
+    OIDCCallback,
+    OIDCCallbackContext,
+    OIDCCallbackResult,
+    OIDCIdPInfo,
+    _OIDCProperties,
+)
 from pymongo.errors import ConfigurationError, OperationFailure
-from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE
+from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE

 if TYPE_CHECKING:
-    from pymongo.asynchronous.auth import MongoCredential
-    from pymongo.asynchronous.pool import Connection
+    from pymongo.asynchronous.pool import AsyncConnection
+    from pymongo.auth_shared import MongoCredential

 _IS_SYNC = False

-
-@dataclass
-class OIDCIdPInfo:
-    issuer: str
-    clientId: Optional[str] = field(default=None)
-    requestScopes: Optional[list[str]] = field(default=None)
-
-
-@dataclass
-class OIDCCallbackContext:
-    timeout_seconds: float
-    username: str
-    version: int
-    refresh_token: Optional[str] = field(default=None)
-    idp_info: Optional[OIDCIdPInfo] = field(default=None)
-
-
-@dataclass
-class OIDCCallbackResult:
-    access_token: str
-    expires_in_seconds: Optional[float] = field(default=None)
-    refresh_token: Optional[str] = field(default=None)
-
-
-class OIDCCallback(abc.ABC):
-    """A base class for defining OIDC callbacks."""
-
-    @abc.abstractmethod
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        """Convert the given BSON value into our own type."""
-
-
-@dataclass
-class _OIDCProperties:
-    callback: Optional[OIDCCallback] = field(default=None)
-    human_callback: Optional[OIDCCallback] = field(default=None)
-    environment: Optional[str] = field(default=None)
-    allowed_hosts: list[str] = field(default_factory=list)
-    token_resource: Optional[str] = field(default=None)
-    username: str = ""
-
-
-"""Mechanism properties for MONGODB-OIDC authentication."""
-
-TOKEN_BUFFER_MINUTES = 5
-HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
-CALLBACK_VERSION = 1
-MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
-TIME_BETWEEN_CALLS_SECONDS = 0.1
-

 def _get_authenticator(
     credentials: MongoCredential, address: tuple[str, int]
 ) -> _OIDCAuthenticator:
@@ -117,48 +73,6 @@ def _get_authenticator(
     return credentials.cache.data


-class _OIDCTestCallback(OIDCCallback):
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        token_file = os.environ.get("OIDC_TOKEN_FILE")
-        if not token_file:
-            raise RuntimeError(
-                'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to be set'
-            )
-        with open(token_file) as fid:
-            return OIDCCallbackResult(access_token=fid.read().strip())
-
-
-class _OIDCAWSCallback(OIDCCallback):
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
-        if not token_file:
-            raise RuntimeError(
-                'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
-            )
-        with open(token_file) as fid:
-            return OIDCCallbackResult(access_token=fid.read().strip())
-
-
-class _OIDCAzureCallback(OIDCCallback):
-    def __init__(self, token_resource: str) -> None:
-        self.token_resource = quote(token_resource)
-
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
-        return OIDCCallbackResult(
-            access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
-        )
-
-
-class _OIDCGCPCallback(OIDCCallback):
-    def __init__(self, token_resource: str) -> None:
-        self.token_resource = quote(token_resource)
-
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
-        return OIDCCallbackResult(access_token=resp["access_token"])
-
-
 @dataclass
 class _OIDCAuthenticator:
     username: str
@@ -170,7 +84,7 @@ class _OIDCAuthenticator:
     lock: threading.Lock = field(default_factory=threading.Lock)
     last_call_time: float = field(default=0)

-    async def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
+    async def reauthenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
         """Handle a reauthenticate from the server."""
         # Invalidate the token for the connection.
         self._invalidate(conn)
@@ -179,7 +93,7 @@ async def reauthenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
             return await self._authenticate_machine(conn)
         return await self._authenticate_human(conn)

-    async def authenticate(self, conn: Connection) -> Optional[Mapping[str, Any]]:
+    async def authenticate(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
         """Handle an initial authenticate request."""
         # First handle speculative auth.
         # If it succeeded, we are done.
@@ -203,7 +117,7 @@ def get_spec_auth_cmd(self) -> Optional[MutableMapping[str, Any]]:
             return None
         return self._get_start_command({"jwt": self.access_token})

-    async def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
+    async def _authenticate_machine(self, conn: AsyncConnection) -> Mapping[str, Any]:
         # If there is a cached access token, try to authenticate with it. If
         # authentication fails with error code 18, invalidate the access token,
         # fetch a new access token, and try to authenticate again. If authentication
@@ -217,7 +131,7 @@ async def _authenticate_machine(self, conn: Connection) -> Mapping[str, Any]:
             raise
         return await self._sasl_start_jwt(conn)

-    async def _authenticate_human(self, conn: Connection) -> Optional[Mapping[str, Any]]:
+    async def _authenticate_human(self, conn: AsyncConnection) -> Optional[Mapping[str, Any]]:
         # If we have a cached access token, try a JwtStepRequest.
# authentication fails with error code 18, invalidate the access token, # and try to authenticate again. If authentication fails for any other @@ -307,7 +221,7 @@ def _get_access_token(self) -> Optional[str]: return self.access_token async def _run_command( - self, conn: Connection, cmd: MutableMapping[str, Any] + self, conn: AsyncConnection, cmd: MutableMapping[str, Any] ) -> Mapping[str, Any]: try: return await conn.command("$external", cmd, no_reauth=True) # type: ignore[call-arg] @@ -321,7 +235,7 @@ def _is_auth_error(self, err: Exception) -> bool: return False return err.code == _AUTHENTICATION_FAILURE_CODE - def _invalidate(self, conn: Connection) -> None: + def _invalidate(self, conn: AsyncConnection) -> None: # Ignore the invalidation if a token gen id is given and is less than our # current token gen id. token_gen_id = conn.oidc_token_gen_id or 0 @@ -330,7 +244,7 @@ def _invalidate(self, conn: Connection) -> None: self.access_token = None async def _sasl_continue_jwt( - self, conn: Connection, start_resp: Mapping[str, Any] + self, conn: AsyncConnection, start_resp: Mapping[str, Any] ) -> Mapping[str, Any]: self.access_token = None self.refresh_token = None @@ -342,7 +256,7 @@ async def _sasl_continue_jwt( cmd = self._get_continue_command({"jwt": access_token}, start_resp) return await self._run_command(conn, cmd) - async def _sasl_start_jwt(self, conn: Connection) -> Mapping[str, Any]: + async def _sasl_start_jwt(self, conn: AsyncConnection) -> Mapping[str, Any]: access_token = self._get_access_token() conn.oidc_token_gen_id = self.token_gen_id cmd = self._get_start_command({"jwt": access_token}) @@ -370,7 +284,7 @@ def _get_continue_command( async def _authenticate_oidc( - credentials: MongoCredential, conn: Connection, reauthenticate: bool + credentials: MongoCredential, conn: AsyncConnection, reauthenticate: bool ) -> Optional[Mapping[str, Any]]: """Authenticate using MONGODB-OIDC.""" authenticator = _get_authenticator(credentials, conn.address) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 4205fceac9..66ed994142 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -34,15 +34,8 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument -from pymongo import _csot -from pymongo.asynchronous import common -from pymongo.asynchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.asynchronous.common import ( - validate_is_document_type, - validate_ok_for_replace, - validate_ok_for_update, -) -from pymongo.asynchronous.helpers import _get_wce_doc +from pymongo import _csot, common +from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern from pymongo.asynchronous.message import ( _DELETE, _INSERT, @@ -51,20 +44,25 @@ _EncryptedBulkWriteContext, _randint, ) -from pymongo.asynchronous.read_preferences import ReadPreference +from pymongo.common import ( + validate_is_document_type, + validate_ok_for_replace, + validate_ok_for_update, +) from pymongo.errors import ( BulkWriteError, ConfigurationError, InvalidOperation, OperationFailure, ) -from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES, _get_wce_doc +from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from pymongo.asynchronous.collection import AsyncCollection - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.typings 
import _DocumentOut, _DocumentType, _Pipeline + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline _IS_SYNC = False @@ -169,7 +167,7 @@ def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: raise BulkWriteError(full_result) -class _Bulk: +class _AsyncBulk: """The private guts of the bulk write API.""" def __init__( @@ -180,7 +178,7 @@ def __init__( comment: Optional[str] = None, let: Optional[Any] = None, ) -> None: - """Initialize a _Bulk instance.""" + """Initialize an _AsyncBulk instance.""" self.collection = collection.with_options( codec_options=collection.codec_options._replace( unicode_decode_error_handler="replace", document_class=dict ) @@ -319,8 +317,8 @@ async def _execute_command( self, generator: Iterator[Any], write_concern: WriteConcern, - session: Optional[ClientSession], - conn: Connection, + session: Optional[AsyncClientSession], + conn: AsyncConnection, op_id: int, retryable: bool, full_result: MutableMapping[str, Any], @@ -335,8 +333,8 @@ async def _execute_command( self.next_run = None run = self.current_run - # Connection.command validates the session, but we use - # Connection.write_command + # AsyncConnection.command validates the session, but we use + # AsyncConnection.write_command conn.validate_session(client, session) last_run = False @@ -422,7 +420,7 @@ async def execute_command( self, generator: Iterator[Any], write_concern: WriteConcern, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, ) -> dict[str, Any]: """Execute using write commands.""" @@ -440,7 +438,7 @@ async def execute_command( op_id = _randint() async def retryable_bulk( - session: Optional[ClientSession], conn: Connection, retryable: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable: bool ) -> None: await self._execute_command( generator, @@ -466,7 +464,9 @@ async def retryable_bulk( _raise_bulk_write_error(full_result) return full_result - async def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) -> None: + async def execute_op_msg_no_results( + self, conn: AsyncConnection, generator: Iterator[Any] + ) -> None: """Execute write commands with OP_MSG and w=0 writeConcern, unordered.""" db_name = self.collection.database.name client = self.collection.database.client @@ -505,7 +505,7 @@ async def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[ async def execute_command_no_results( self, - conn: Connection, + conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -541,7 +541,7 @@ async def execute_command_no_results( async def execute_no_results( self, - conn: Connection, + conn: AsyncConnection, generator: Iterator[Any], write_concern: WriteConcern, ) -> None: @@ -573,7 +573,7 @@ async def execute_no_results( async def execute( self, write_concern: WriteConcern, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, ) -> Any: """Execute operations.""" diff --git a/pymongo/asynchronous/change_stream.py b/pymongo/asynchronous/change_stream.py index b910767c5f..e298df43ad 100644 --- a/pymongo/asynchronous/change_stream.py +++ b/pymongo/asynchronous/change_stream.py @@ -21,17 +21,14 @@ from bson import CodecOptions, _bson_to_dict from bson.raw_bson import RawBSONDocument from bson.timestamp import Timestamp -from pymongo import _csot -from pymongo.asynchronous import common +from pymongo import _csot, common from
pymongo.asynchronous.aggregation import ( _AggregationCommand, _CollectionAggregationCommand, _DatabaseAggregationCommand, ) -from pymongo.asynchronous.collation import validate_collation_or_none from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.collation import validate_collation_or_none from pymongo.errors import ( ConnectionFailure, CursorNotFound, @@ -39,6 +36,8 @@ OperationFailure, PyMongoError, ) +from pymongo.operations import _Op +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline _IS_SYNC = False @@ -68,11 +67,11 @@ if TYPE_CHECKING: - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection def _resumable(exc: PyMongoError) -> bool: @@ -114,7 +113,7 @@ def __init__( batch_size: Optional[int], collation: Optional[_CollationIn], start_at_operation_time: Optional[Timestamp], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], start_after: Optional[Mapping[str, Any]], comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -211,7 +210,7 @@ def _aggregation_pipeline(self) -> list[dict[str, Any]]: full_pipeline.extend(self._pipeline) return full_pipeline - def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None: + def _process_result(self, result: Mapping[str, Any], conn: AsyncConnection) -> None: """Callback that caches the postBatchResumeToken or startAtOperationTime from a changeStream aggregate command response containing an empty batch of change documents. @@ -237,7 +236,7 @@ def _process_result(self, result: Mapping[str, Any], conn: Connection) -> None: ) async def _run_aggregation_cmd( - self, session: Optional[ClientSession], explicit_session: bool + self, session: Optional[AsyncClientSession], explicit_session: bool ) -> AsyncCommandCursor: """Run the full aggregation pipeline for this ChangeStream and return the corresponding AsyncCommandCursor. diff --git a/pymongo/asynchronous/client_options.py b/pymongo/asynchronous/client_options.py deleted file mode 100644 index 834b61ceb9..0000000000 --- a/pymongo/asynchronous/client_options.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. 
- -"""Tools to parse mongo client options.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast - -from bson.codec_options import _parse_codec_options -from pymongo.asynchronous import common -from pymongo.asynchronous.compression_support import CompressionSettings -from pymongo.asynchronous.monitoring import _EventListener, _EventListeners -from pymongo.asynchronous.pool import PoolOptions -from pymongo.asynchronous.read_preferences import ( - _ServerMode, - make_read_preference, - read_pref_mode_from_name, -) -from pymongo.asynchronous.server_selectors import any_server_selector -from pymongo.errors import ConfigurationError -from pymongo.read_concern import ReadConcern -from pymongo.ssl_support import get_ssl_context -from pymongo.write_concern import WriteConcern, validate_boolean - -if TYPE_CHECKING: - from bson.codec_options import CodecOptions - from pymongo.asynchronous.auth import MongoCredential - from pymongo.asynchronous.encryption_options import AutoEncryptionOpts - from pymongo.asynchronous.topology_description import _ServerSelector - from pymongo.pyopenssl_context import SSLContext - -_IS_SYNC = False - - -def _parse_credentials( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> Optional[MongoCredential]: - """Parse authentication credentials.""" - mechanism = options.get("authmechanism", "DEFAULT" if username else None) - source = options.get("authsource") - if username or mechanism: - from pymongo.asynchronous.auth import _build_credentials_tuple - - return _build_credentials_tuple(mechanism, source, username, password, options, database) - return None - - -def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode: - """Parse read preference options.""" - if "read_preference" in options: - return options["read_preference"] - - name = options.get("readpreference", "primary") - mode = read_pref_mode_from_name(name) - tags = options.get("readpreferencetags") - max_staleness = options.get("maxstalenessseconds", -1) - return make_read_preference(mode, tags, max_staleness) - - -def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern: - """Parse write concern options.""" - concern = options.get("w") - wtimeout = options.get("wtimeoutms") - j = options.get("journal") - fsync = options.get("fsync") - return WriteConcern(concern, wtimeout, j, fsync) - - -def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: - """Parse read concern options.""" - concern = options.get("readconcernlevel") - return ReadConcern(concern) - - -def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: - """Parse ssl options.""" - use_tls = options.get("tls") - if use_tls is not None: - validate_boolean("tls", use_tls) - - certfile = options.get("tlscertificatekeyfile") - passphrase = options.get("tlscertificatekeyfilepassword") - ca_certs = options.get("tlscafile") - crlfile = options.get("tlscrlfile") - allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) - allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) - disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) - - enabled_tls_opts = [] - for opt in ( - "tlscertificatekeyfile", - "tlscertificatekeyfilepassword", - "tlscafile", - "tlscrlfile", - ): - # Any non-null value of these options implies tls=True. 
- if opt in options and options[opt]: - enabled_tls_opts.append(opt) - for opt in ( - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", - ): - # A value of False for these options implies tls=True. - if opt in options and not options[opt]: - enabled_tls_opts.append(opt) - - if enabled_tls_opts: - if use_tls is None: - # Implicitly enable TLS when one of the tls* options is set. - use_tls = True - elif not use_tls: - # Error since tls is explicitly disabled but a tls option is set. - raise ConfigurationError( - "TLS has not been enabled but the " - "following tls parameters have been set: " - "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) - ) - - if use_tls: - ctx = get_ssl_context( - certfile, - passphrase, - ca_certs, - crlfile, - allow_invalid_certificates, - allow_invalid_hostnames, - disable_ocsp_endpoint_check, - ) - return ctx, allow_invalid_hostnames - return None, allow_invalid_hostnames - - -def _parse_pool_options( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> PoolOptions: - """Parse connection pool options.""" - credentials = _parse_credentials(username, password, database, options) - max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) - min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) - max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) - if max_pool_size is not None and min_pool_size > max_pool_size: - raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") - connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) - socket_timeout = options.get("sockettimeoutms") - wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) - event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) - appname = options.get("appname") - driver = options.get("driver") - server_api = options.get("server_api") - compression_settings = CompressionSettings( - options.get("compressors", []), options.get("zlibcompressionlevel", -1) - ) - ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) - load_balanced = options.get("loadbalanced") - max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) - return PoolOptions( - max_pool_size, - min_pool_size, - max_idle_time_seconds, - connect_timeout, - socket_timeout, - wait_queue_timeout, - ssl_context, - tls_allow_invalid_hostnames, - _EventListeners(event_listeners), - appname, - driver, - compression_settings, - max_connecting=max_connecting, - server_api=server_api, - load_balanced=load_balanced, - credentials=credentials, - ) - - -class ClientOptions: - """Read only configuration options for an AsyncMongoClient. - - Should not be instantiated directly by application developers. Access - a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options` - instead. - """ - - def __init__( - self, username: str, password: str, database: Optional[str], options: Mapping[str, Any] - ): - self.__options = options - self.__codec_options = _parse_codec_options(options) - self.__direct_connection = options.get("directconnection") - self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) - # self.__server_selection_timeout is in seconds. Must use full name for - # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. 
- self.__server_selection_timeout = options.get( - "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT - ) - self.__pool_options = _parse_pool_options(username, password, database, options) - self.__read_preference = _parse_read_preference(options) - self.__replica_set_name = options.get("replicaset") - self.__write_concern = _parse_write_concern(options) - self.__read_concern = _parse_read_concern(options) - self.__connect = options.get("connect") - self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) - self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) - self.__retry_reads = options.get("retryreads", common.RETRY_READS) - self.__server_selector = options.get("server_selector", any_server_selector) - self.__auto_encryption_opts = options.get("auto_encryption_opts") - self.__load_balanced = options.get("loadbalanced") - self.__timeout = options.get("timeoutms") - self.__server_monitoring_mode = options.get( - "servermonitoringmode", common.SERVER_MONITORING_MODE - ) - - @property - def _options(self) -> Mapping[str, Any]: - """The original options used to create this ClientOptions.""" - return self.__options - - @property - def connect(self) -> Optional[bool]: - """Whether to begin discovering a MongoDB topology automatically.""" - return self.__connect - - @property - def codec_options(self) -> CodecOptions: - """A :class:`~bson.codec_options.CodecOptions` instance.""" - return self.__codec_options - - @property - def direct_connection(self) -> Optional[bool]: - """Whether to connect to the deployment in 'Single' topology.""" - return self.__direct_connection - - @property - def local_threshold_ms(self) -> int: - """The local threshold for this instance.""" - return self.__local_threshold_ms - - @property - def server_selection_timeout(self) -> int: - """The server selection timeout for this instance in seconds.""" - return self.__server_selection_timeout - - @property - def server_selector(self) -> _ServerSelector: - return self.__server_selector - - @property - def heartbeat_frequency(self) -> int: - """The monitoring frequency in seconds.""" - return self.__heartbeat_frequency - - @property - def pool_options(self) -> PoolOptions: - """A :class:`~pymongo.pool.PoolOptions` instance.""" - return self.__pool_options - - @property - def read_preference(self) -> _ServerMode: - """A read preference instance.""" - return self.__read_preference - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self.__replica_set_name - - @property - def write_concern(self) -> WriteConcern: - """A :class:`~pymongo.write_concern.WriteConcern` instance.""" - return self.__write_concern - - @property - def read_concern(self) -> ReadConcern: - """A :class:`~pymongo.read_concern.ReadConcern` instance.""" - return self.__read_concern - - @property - def timeout(self) -> Optional[float]: - """The configured timeoutMS converted to seconds, or None. - - .. 
versionadded:: 4.2 - """ - return self.__timeout - - @property - def retry_writes(self) -> bool: - """If this instance should retry supported write operations.""" - return self.__retry_writes - - @property - def retry_reads(self) -> bool: - """If this instance should retry supported read operations.""" - return self.__retry_reads - - @property - def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]: - """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" - return self.__auto_encryption_opts - - @property - def load_balanced(self) -> Optional[bool]: - """True if the client was configured to connect to a load balancer.""" - return self.__load_balanced - - @property - def event_listeners(self) -> list[_EventListeners]: - """The event listeners registered for this client. - - See :mod:`~pymongo.monitoring` for details. - - .. versionadded:: 4.0 - """ - assert self.__pool_options._event_listeners is not None - return self.__pool_options._event_listeners.event_listeners() - - @property - def server_monitoring_mode(self) -> str: - """The configured serverMonitoringMode option. - - .. versionadded:: 4.5 - """ - return self.__server_monitoring_mode diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index fcaf26a872..f5d1eaea95 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -44,8 +44,8 @@ .. versionadded:: 3.7 MongoDB 4.0 adds support for transactions on replica set primaries. A -transaction is associated with a :class:`ClientSession`. To start a transaction -on a session, use :meth:`ClientSession.start_transaction` in a with-statement. +transaction is associated with an :class:`AsyncClientSession`. To start a transaction +on a session, use :meth:`AsyncClientSession.start_transaction` in a with-statement. Then, execute an operation within the transaction by passing the session to the operation: @@ -63,9 +63,9 @@ ) Upon normal completion of ``async with session.start_transaction()`` block, the -transaction automatically calls :meth:`ClientSession.commit_transaction`. +transaction automatically calls :meth:`AsyncClientSession.commit_transaction`. If the block exits with an exception, the transaction automatically calls -:meth:`ClientSession.abort_transaction`. +:meth:`AsyncClientSession.abort_transaction`. In general, multi-document transactions only support read/write (CRUD) operations on existing collections.
However, MongoDB 4.4 adds support for @@ -157,8 +157,6 @@ from bson.timestamp import Timestamp from pymongo import _csot from pymongo.asynchronous.cursor import _ConnectionManager -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode from pymongo.errors import ( ConfigurationError, ConnectionFailure, @@ -167,23 +165,25 @@ PyMongoError, WTimeoutError, ) -from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.operations import _Op from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.server_type import SERVER_TYPE from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from types import TracebackType - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server - from pymongo.asynchronous.typings import ClusterTime, _Address + from pymongo.typings import ClusterTime, _Address _IS_SYNC = False class SessionOptions: - """Options for a new :class:`ClientSession`. + """Options for a new :class:`AsyncClientSession`. :param causal_consistency: If True, read operations are causally ordered within the session. Defaults to True when the ``snapshot`` @@ -246,7 +246,7 @@ def snapshot(self) -> Optional[bool]: class TransactionOptions: - """Options for :meth:`ClientSession.start_transaction`. + """Options for :meth:`AsyncClientSession.start_transaction`. :param read_concern: The :class:`~pymongo.read_concern.ReadConcern` to use for this transaction. @@ -336,8 +336,8 @@ def max_commit_time_ms(self) -> Optional[int]: def _validate_session_write_concern( - session: Optional[ClientSession], write_concern: Optional[WriteConcern] -) -> Optional[ClientSession]: + session: Optional[AsyncClientSession], write_concern: Optional[WriteConcern] +) -> Optional[AsyncClientSession]: """Validate that an explicit session is not used with an unack'ed write. Returns the session to use for the next operation. @@ -362,7 +362,7 @@ def _validate_session_write_concern( class _TransactionContext: """Internal transaction context manager for start_transaction.""" - def __init__(self, session: ClientSession): + def __init__(self, session: AsyncClientSession): self.__session = session async def __aenter__(self) -> _TransactionContext: @@ -391,7 +391,7 @@ class _TxnState: class _Transaction: - """Internal class to hold transaction information in a ClientSession.""" + """Internal class to hold transaction information in an AsyncClientSession.""" def __init__(self, opts: Optional[TransactionOptions], client: AsyncMongoClient): self.opts = opts @@ -410,12 +410,12 @@ def starting(self) -> bool: return self.state == _TxnState.STARTING @property - def pinned_conn(self) -> Optional[Connection]: + def pinned_conn(self) -> Optional[AsyncConnection]: if self.active() and self.conn_mgr: return self.conn_mgr.conn return None - def pin(self, server: Server, conn: Connection) -> None: + def pin(self, server: Server, conn: AsyncConnection) -> None: self.sharded = True self.pinned_address = server.description.address if server.description.server_type == SERVER_TYPE.LoadBalancer: @@ -481,16 +481,16 @@ def _within_time_limit(start_time: float) -> bool: from pymongo.asynchronous.mongo_client import AsyncMongoClient -class ClientSession: +class AsyncClientSession: """A session for ordering sequential operations.
- :class:`ClientSession` instances are **not thread-safe or fork-safe**. + :class:`AsyncClientSession` instances are **not thread-safe or fork-safe**. They can only be used by one thread or process at a time. A single - :class:`ClientSession` cannot be used to run multiple operations + :class:`AsyncClientSession` cannot be used to run multiple operations concurrently. Should not be initialized directly by application developers - to create a - :class:`ClientSession`, call + :class:`AsyncClientSession`, call :meth:`~pymongo.mongo_client.AsyncMongoClient.start_session`. """ @@ -535,7 +535,7 @@ def _check_ended(self) -> None: if self._server_session is None: raise InvalidOperation("Cannot use ended session") - async def __aenter__(self) -> ClientSession: + async def __aenter__(self) -> AsyncClientSession: return self async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: @@ -592,7 +592,7 @@ def _inherit_option(self, name: str, val: _T) -> _T: async def with_transaction( self, - callback: Callable[[ClientSession], _T], + callback: Callable[[AsyncClientSession], _T], read_concern: Optional[ReadConcern] = None, write_concern: Optional[WriteConcern] = None, read_preference: Optional[_ServerMode] = None, @@ -640,25 +640,25 @@ async def callback(session, custom_arg, custom_kwarg=None): however, ``with_transaction`` will return without taking further action. - :class:`ClientSession` instances are **not thread-safe or fork-safe**. + :class:`AsyncClientSession` instances are **not thread-safe or fork-safe**. Consequently, the ``callback`` must not attempt to execute multiple operations concurrently. When ``callback`` raises an exception, ``with_transaction`` automatically aborts the current transaction. When ``callback`` or - :meth:`~ClientSession.commit_transaction` raises an exception that + :meth:`~AsyncClientSession.commit_transaction` raises an exception that includes the ``"TransientTransactionError"`` error label, ``with_transaction`` starts a new transaction and re-executes the ``callback``. - When :meth:`~ClientSession.commit_transaction` raises an exception with + When :meth:`~AsyncClientSession.commit_transaction` raises an exception with the ``"UnknownTransactionCommitResult"`` error label, ``with_transaction`` retries the commit until the result of the transaction is known. This method will cease retrying after 120 seconds has elapsed. This timeout is not configurable and any exception raised by the - ``callback`` or by :meth:`ClientSession.commit_transaction` after the + ``callback`` or by :meth:`AsyncClientSession.commit_transaction` after the timeout is reached will be re-raised. Applications that desire a different timeout duration should not use this method. 
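A minimal sketch of driving the renamed transaction API from application code, under assumptions not stated in this patch: a transaction-capable deployment (replica set or sharded cluster) at the default URI, a placeholder ``test.docs`` namespace, and ``start_session``/``close`` being awaitable at this point in the refactor::

    import asyncio

    from pymongo.asynchronous.mongo_client import AsyncMongoClient


    async def main() -> None:
        client = AsyncMongoClient()  # default URI assumed
        coll = client.test.docs  # placeholder namespace

        async def callback(session):
            # Every operation inside the callback must be given the session.
            await coll.insert_one({"sku": "abc123", "qty": 100}, session=session)

        # with_transaction commits on success, retries the callback on
        # "TransientTransactionError", and aborts if the callback raises.
        async with await client.start_session() as session:  # awaitable: assumption
            await session.with_transaction(callback)

        await client.close()


    asyncio.run(main())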
@@ -844,7 +844,7 @@ async def _finish_transaction_with_retry(self, command_name: str) -> dict[str, A """ async def func( - _session: Optional[ClientSession], conn: Connection, _retryable: bool + _session: Optional[AsyncClientSession], conn: AsyncConnection, _retryable: bool ) -> dict[str, Any]: return await self._finish_transaction(conn, command_name) @@ -852,7 +852,7 @@ async def func( func, self, None, retryable=True, operation=_Op.ABORT ) - async def _finish_transaction(self, conn: Connection, command_name: str) -> dict[str, Any]: + async def _finish_transaction(self, conn: AsyncConnection, command_name: str) -> dict[str, Any]: self._transaction.attempt += 1 opts = self._transaction.opts assert opts @@ -891,8 +891,8 @@ def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None: """Update the cluster time for this session. :param cluster_time: The - :data:`~pymongo.client_session.ClientSession.cluster_time` from - another `ClientSession` instance. + :data:`~pymongo.client_session.AsyncClientSession.cluster_time` from + another `AsyncClientSession` instance. """ if not isinstance(cluster_time, _Mapping): raise TypeError("cluster_time must be a subclass of collections.Mapping") @@ -912,8 +912,8 @@ def advance_operation_time(self, operation_time: Timestamp) -> None: """Update the operation time for this session. :param operation_time: The - :data:`~pymongo.client_session.ClientSession.operation_time` from - another `ClientSession` instance. + :data:`~pymongo.client_session.AsyncClientSession.operation_time` from + another `AsyncClientSession` instance. """ if not isinstance(operation_time, Timestamp): raise TypeError("operation_time must be an instance of bson.timestamp.Timestamp") @@ -960,11 +960,11 @@ def _pinned_address(self) -> Optional[_Address]: return None @property - def _pinned_connection(self) -> Optional[Connection]: + def _pinned_connection(self) -> Optional[AsyncConnection]: """The connection this transaction was started on.""" return self._transaction.pinned_conn - def _pin(self, server: Server, conn: Connection) -> None: + def _pin(self, server: Server, conn: AsyncConnection) -> None: """Pin this session to the given Server or to the given connection.""" self._transaction.pin(server, conn) @@ -993,7 +993,7 @@ async def _apply_to( command: MutableMapping[str, Any], is_retryable: bool, read_preference: _ServerMode, - conn: Connection, + conn: AsyncConnection, ) -> None: if not conn.supports_sessions: if not self._implicit: @@ -1036,7 +1036,7 @@ def _start_retryable_write(self) -> None: self._check_ended() self._server_session.inc_transaction_id() - def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) -> None: + def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: AsyncConnection) -> None: if self.options.causal_consistency and self.operation_time is not None: cmd.setdefault("readConcern", {})["afterClusterTime"] = self.operation_time if self.options.snapshot: @@ -1048,7 +1048,7 @@ def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Connection) rc["atClusterTime"] = self._snapshot_time def __copy__(self) -> NoReturn: - raise TypeError("A ClientSession cannot be copied, create a new session instead") + raise TypeError("An AsyncClientSession cannot be copied, create a new session instead") class _EmptyServerSession: diff --git a/pymongo/asynchronous/collation.py b/pymongo/asynchronous/collation.py deleted file mode 100644 index 26d5a68d7d..0000000000 --- a/pymongo/asynchronous/collation.py +++ /dev/null @@
-1,226 +0,0 @@ -# Copyright 2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for working with `collations`_. - -.. _collations: https://www.mongodb.com/docs/manual/reference/collation/ -""" -from __future__ import annotations - -from typing import Any, Mapping, Optional, Union - -from pymongo.asynchronous import common -from pymongo.write_concern import validate_boolean - -_IS_SYNC = False - - -class CollationStrength: - """ - An enum that defines values for `strength` on a - :class:`~pymongo.collation.Collation`. - """ - - PRIMARY = 1 - """Differentiate base (unadorned) characters.""" - - SECONDARY = 2 - """Differentiate character accents.""" - - TERTIARY = 3 - """Differentiate character case.""" - - QUATERNARY = 4 - """Differentiate words with and without punctuation.""" - - IDENTICAL = 5 - """Differentiate unicode code point (characters are exactly identical).""" - - -class CollationAlternate: - """ - An enum that defines values for `alternate` on a - :class:`~pymongo.collation.Collation`. - """ - - NON_IGNORABLE = "non-ignorable" - """Spaces and punctuation are treated as base characters.""" - - SHIFTED = "shifted" - """Spaces and punctuation are *not* considered base characters. - - Spaces and punctuation are distinguished regardless when the - :class:`~pymongo.collation.Collation` strength is at least - :data:`~pymongo.collation.CollationStrength.QUATERNARY`. - - """ - - -class CollationMaxVariable: - """ - An enum that defines values for `max_variable` on a - :class:`~pymongo.collation.Collation`. - """ - - PUNCT = "punct" - """Both punctuation and spaces are ignored.""" - - SPACE = "space" - """Spaces alone are ignored.""" - - -class CollationCaseFirst: - """ - An enum that defines values for `case_first` on a - :class:`~pymongo.collation.Collation`. - """ - - UPPER = "upper" - """Sort uppercase characters first.""" - - LOWER = "lower" - """Sort lowercase characters first.""" - - OFF = "off" - """Default for locale or collation strength.""" - - -class Collation: - """Collation - - :param locale: (string) The locale of the collation. This should be a string - that identifies an `ICU locale ID` exactly. For example, ``en_US`` is - valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB - documentation for a list of supported locales. - :param caseLevel: (optional) If ``True``, turn on case sensitivity if - `strength` is 1 or 2 (case sensitivity is implied if `strength` is - greater than 2). Defaults to ``False``. - :param caseFirst: (optional) Specify that either uppercase or lowercase - characters take precedence. Must be one of the following values: - - * :data:`~CollationCaseFirst.UPPER` - * :data:`~CollationCaseFirst.LOWER` - * :data:`~CollationCaseFirst.OFF` (the default) - - :param strength: Specify the comparison strength. This is also - known as the ICU comparison level. 
This must be one of the following - values: - - * :data:`~CollationStrength.PRIMARY` - * :data:`~CollationStrength.SECONDARY` - * :data:`~CollationStrength.TERTIARY` (the default) - * :data:`~CollationStrength.QUATERNARY` - * :data:`~CollationStrength.IDENTICAL` - - Each successive level builds upon the previous. For example, a - `strength` of :data:`~CollationStrength.SECONDARY` differentiates - characters based both on the unadorned base character and its accents. - - :param numericOrdering: If ``True``, order numbers numerically - instead of in collation order (defaults to ``False``). - :param alternate: Specify whether spaces and punctuation are - considered base characters. This must be one of the following values: - - * :data:`~CollationAlternate.NON_IGNORABLE` (the default) - * :data:`~CollationAlternate.SHIFTED` - - :param maxVariable: When `alternate` is - :data:`~CollationAlternate.SHIFTED`, this option specifies what - characters may be ignored. This must be one of the following values: - - * :data:`~CollationMaxVariable.PUNCT` (the default) - * :data:`~CollationMaxVariable.SPACE` - - :param normalization: If ``True``, normalizes text into Unicode - NFD. Defaults to ``False``. - :param backwards: If ``True``, accents on characters are - considered from the back of the word to the front, as it is done in some - French dictionary ordering traditions. Defaults to ``False``. - :param kwargs: Keyword arguments supplying any additional options - to be sent with this Collation object. - - .. versionadded: 3.4 - - """ - - __slots__ = ("__document",) - - def __init__( - self, - locale: str, - caseLevel: Optional[bool] = None, - caseFirst: Optional[str] = None, - strength: Optional[int] = None, - numericOrdering: Optional[bool] = None, - alternate: Optional[str] = None, - maxVariable: Optional[str] = None, - normalization: Optional[bool] = None, - backwards: Optional[bool] = None, - **kwargs: Any, - ) -> None: - locale = common.validate_string("locale", locale) - self.__document: dict[str, Any] = {"locale": locale} - if caseLevel is not None: - self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel) - if caseFirst is not None: - self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) - if strength is not None: - self.__document["strength"] = common.validate_integer("strength", strength) - if numericOrdering is not None: - self.__document["numericOrdering"] = validate_boolean( - "numericOrdering", numericOrdering - ) - if alternate is not None: - self.__document["alternate"] = common.validate_string("alternate", alternate) - if maxVariable is not None: - self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) - if normalization is not None: - self.__document["normalization"] = validate_boolean("normalization", normalization) - if backwards is not None: - self.__document["backwards"] = validate_boolean("backwards", backwards) - self.__document.update(kwargs) - - @property - def document(self) -> dict[str, Any]: - """The document representation of this collation. - - .. note:: - :class:`Collation` is immutable. Mutating the value of - :attr:`document` does not mutate this :class:`Collation`. 
- """ - return self.__document.copy() - - def __repr__(self) -> str: - document = self.document - return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, Collation): - return self.document == other.document - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -def validate_collation_or_none( - value: Optional[Union[Mapping[str, Any], Collation]] -) -> Optional[dict[str, Any]]: - if value is None: - return None - if isinstance(value, Collation): - return value.document - if isinstance(value, dict): - return value - raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py index ed396fb9ce..836d4c61e3 100644 --- a/pymongo/asynchronous/collection.py +++ b/pymongo/asynchronous/collection.py @@ -41,27 +41,33 @@ from bson.raw_bson import RawBSONDocument from bson.son import SON from bson.timestamp import Timestamp -from pymongo import ASCENDING, _csot -from pymongo.asynchronous import common, helpers, message +from pymongo import ASCENDING, _csot, common, helpers_shared +from pymongo.asynchronous import message from pymongo.asynchronous.aggregation import ( _CollectionAggregationCommand, _CollectionRawAggregationCommand, ) -from pymongo.asynchronous.bulk import _Bulk +from pymongo.asynchronous.bulk import _AsyncBulk from pymongo.asynchronous.change_stream import CollectionChangeStream -from pymongo.asynchronous.collation import validate_collation_or_none from pymongo.asynchronous.command_cursor import ( AsyncCommandCursor, AsyncRawBatchCommandCursor, ) -from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name from pymongo.asynchronous.cursor import ( AsyncCursor, AsyncRawBatchCursor, ) -from pymongo.asynchronous.helpers import _check_write_command_response from pymongo.asynchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS -from pymongo.asynchronous.operations import ( +from pymongo.collation import validate_collation_or_none +from pymongo.common import _ecoc_coll_name, _esc_coll_name +from pymongo.errors import ( + ConfigurationError, + InvalidName, + InvalidOperation, + OperationFailure, +) +from pymongo.helpers_shared import _check_write_command_response +from pymongo.operations import ( DeleteMany, DeleteOne, IndexModel, @@ -74,15 +80,8 @@ _IndexList, _Op, ) -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline -from pymongo.errors import ( - ConfigurationError, - InvalidName, - InvalidOperation, - OperationFailure, -) from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ( BulkWriteResult, DeleteResult, @@ -90,6 +89,7 @@ InsertOneResult, UpdateResult, ) +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean _IS_SYNC = False @@ -126,11 +126,11 @@ class ReturnDocument: if TYPE_CHECKING: import bson from pymongo.asynchronous.aggregation import _AggregationCommand - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.collation import Collation + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.database import 
AsyncDatabase - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server + from pymongo.collation import Collation from pymongo.read_concern import ReadConcern @@ -146,7 +146,7 @@ def __init__( read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> None: """Get / create an asynchronous Mongo collection. @@ -367,7 +367,7 @@ def with_options( ) def _write_concern_for_cmd( - self, cmd: Mapping[str, Any], session: Optional[ClientSession] + self, cmd: Mapping[str, Any], session: Optional[AsyncClientSession] ) -> WriteConcern: raw_wc = cmd.get("writeConcern") if raw_wc is not None: @@ -407,7 +407,7 @@ async def watch( batch_size: Optional[int] = None, collation: Optional[_CollationIn] = None, start_at_operation_time: Optional[Timestamp] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, start_after: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -488,7 +488,7 @@ async def watch( the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param start_after: The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. @@ -540,13 +540,13 @@ async def watch( return change_stream async def _conn_for_writes( - self, session: Optional[ClientSession], operation: str - ) -> AsyncContextManager[Connection]: + self, session: Optional[AsyncClientSession], operation: str + ) -> AsyncContextManager[AsyncConnection]: return await self._database.client._conn_for_writes(session, operation) async def _command( self, - conn: Connection, + conn: AsyncConnection, command: MutableMapping[str, Any], read_preference: Optional[_ServerMode] = None, codec_options: Optional[CodecOptions] = None, @@ -555,13 +555,13 @@ async def _command( read_concern: Optional[ReadConcern] = None, write_concern: Optional[WriteConcern] = None, collation: Optional[_CollationIn] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, retryable_write: bool = False, user_fields: Optional[Any] = None, ) -> Mapping[str, Any]: """Internal command helper. - :param conn` - A Connection instance. + :param conn` - An AsyncConnection instance. :param command` - The command itself, as a :class:`~bson.son.SON` instance. :param read_preference` (optional) - The read preference to use. :param codec_options` (optional) - An instance of @@ -575,7 +575,7 @@ async def _command( :param collation` (optional) - An instance of :class:`~pymongo.collation.Collation`. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param retryable_write: True if this command is a retryable write.
:param user_fields: Response fields that should be decoded @@ -607,7 +607,7 @@ async def _create_helper( name: str, options: MutableMapping[str, Any], collation: Optional[_CollationIn], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], encrypted_fields: Optional[Mapping[str, Any]] = None, qev2_required: bool = False, ) -> None: @@ -640,7 +640,7 @@ async def _create_helper( async def _create( self, options: MutableMapping[str, Any], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], ) -> None: collation = validate_collation_or_none(options.pop("collation", None)) encrypted_fields = options.pop("encryptedFields", None) @@ -670,7 +670,7 @@ async def bulk_write( requests: Sequence[_WriteOp[_DocumentType]], ordered: bool = True, bypass_document_validation: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, let: Optional[Mapping] = None, ) -> BulkWriteResult: @@ -720,7 +720,7 @@ async def bulk_write( write to opt-out of document level validation. Default is ``False``. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param let: Map of parameter names and values. Values must be @@ -749,7 +749,7 @@ async def bulk_write( """ common.validate_list("requests", requests) - blk = _Bulk(self, ordered, bypass_document_validation, comment=comment, let=let) + blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment, let=let) for request in requests: try: request._add_to_bulk(blk) @@ -769,7 +769,7 @@ async def _insert_one( write_concern: WriteConcern, op_id: Optional[int], bypass_doc_val: bool, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], comment: Optional[Any] = None, ) -> Any: """Internal helper for inserting a single document.""" @@ -780,7 +780,7 @@ async def _insert_one( command["comment"] = comment async def _insert_command( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> None: if bypass_doc_val: command["bypassDocumentValidation"] = True @@ -809,7 +809,7 @@ async def insert_one( self, document: Union[_DocumentType, RawBSONDocument], bypass_document_validation: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> InsertOneResult: """Insert a single document. @@ -829,7 +829,7 @@ async def insert_one( write to opt-out of document level validation. Default is ``False``. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -875,7 +875,7 @@ async def insert_many( documents: Iterable[Union[_DocumentType, RawBSONDocument]], ordered: bool = True, bypass_document_validation: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> InsertManyResult: """Insert an iterable of documents. @@ -898,7 +898,7 @@ async def insert_many( write to opt-out of document level validation. Default is ``False``. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. 
:param comment: A user-provided comment to attach to this command. @@ -939,14 +939,14 @@ def gen() -> Iterator[tuple[int, Mapping[str, Any]]]: yield (message._INSERT, document) write_concern = self._write_concern_for(session) - blk = _Bulk(self, ordered, bypass_document_validation, comment=comment) + blk = _AsyncBulk(self, ordered, bypass_document_validation, comment=comment) blk.ops = list(gen()) await blk.execute(write_concern, session, _Op.INSERT) return InsertManyResult(inserted_ids, write_concern.acknowledged) async def _update( self, - conn: Connection, + conn: AsyncConnection, criteria: Mapping[str, Any], document: Union[Mapping[str, Any], _Pipeline], upsert: bool = False, @@ -958,7 +958,7 @@ async def _update( collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -990,7 +990,7 @@ async def _update( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) update_doc["hint"] = hint command = {"update": self.name, "ordered": ordered, "updates": [update_doc]} if let is not None: @@ -1045,14 +1045,14 @@ async def _update_retryable( collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Optional[Mapping[str, Any]]: """Internal update / replace helper.""" async def _update( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Optional[Mapping[str, Any]]: return await self._update( conn, @@ -1088,7 +1088,7 @@ async def replace_one( bypass_document_validation: bool = False, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: @@ -1137,7 +1137,7 @@ async def replace_one( ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -1191,7 +1191,7 @@ async def update_one( collation: Optional[_CollationIn] = None, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: @@ -1246,7 +1246,7 @@ async def update_one( ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. 
Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -1304,7 +1304,7 @@ async def update_many( bypass_document_validation: Optional[bool] = None, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> UpdateResult: @@ -1346,7 +1346,7 @@ async def update_many( ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.2 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -1398,14 +1398,14 @@ async def update_many( async def drop( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, encrypted_fields: Optional[Mapping[str, Any]] = None, ) -> None: """Alias for :meth:`~pymongo.database.AsyncDatabase.drop_collection`. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param encrypted_fields: **(BETA)** Document that describes the encrypted fields for @@ -1441,7 +1441,7 @@ async def drop( async def _delete( self, - conn: Connection, + conn: AsyncConnection, criteria: Mapping[str, Any], multi: bool, write_concern: Optional[WriteConcern] = None, @@ -1449,7 +1449,7 @@ async def _delete( ordered: bool = True, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, retryable_write: bool = False, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -1471,7 +1471,7 @@ async def _delete( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." ) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) delete_doc["hint"] = hint command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]} @@ -1504,14 +1504,14 @@ async def _delete_retryable( ordered: bool = True, collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> Mapping[str, Any]: """Internal delete helper.""" async def _delete( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Mapping[str, Any]: return await self._delete( conn, @@ -1540,7 +1540,7 @@ async def delete_one( filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> DeleteResult: @@ -1564,7 +1564,7 @@ async def delete_one( ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. 
+ :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -1605,7 +1605,7 @@ async def delete_many( filter: Mapping[str, Any], collation: Optional[_CollationIn] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, ) -> DeleteResult: @@ -1629,7 +1629,7 @@ async def delete_many( ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -1731,7 +1731,7 @@ async def find(self, *args: Any, **kwargs: Any) -> AsyncCursor[_DocumentType]: always be returned. Use a dict to exclude fields from the result (e.g. projection={'_id': False}). :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param skip: the number of documents to omit (from the start of the result set) when returning the results :param limit: the maximum number of results to @@ -1925,8 +1925,8 @@ async def find_raw_batches( async def _count_cmd( self, - session: Optional[ClientSession], - conn: Connection, + session: Optional[AsyncClientSession], + conn: AsyncConnection, read_preference: Optional[_ServerMode], cmd: dict[str, Any], collation: Optional[Collation], @@ -1950,11 +1950,11 @@ async def _count_cmd( async def _aggregate_one_result( self, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], cmd: dict[str, Any], collation: Optional[_CollationIn], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], ) -> Optional[Mapping[str, Any]]: """Internal helper to run an aggregate that returns a single result.""" result = await self._command( @@ -2006,9 +2006,9 @@ async def estimated_document_count(self, comment: Optional[Any] = None, **kwargs kwargs["comment"] = comment async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], ) -> int: cmd: dict[str, Any] = {"count": self._name} @@ -2020,7 +2020,7 @@ async def _cmd( async def count_documents( self, filter: Mapping[str, Any], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> int: @@ -2066,7 +2066,7 @@ async def count_documents( to count in the collection. Can be an empty document to count all documents. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: See list of options above. 
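The ``count_documents`` hunk below now routes list-form hints through ``pymongo.helpers_shared._index_document``; a small sketch of that conversion, assuming the helper keeps its pre-move behavior of turning ``(key, direction)`` pairs into an index document::

    from pymongo import ASCENDING, DESCENDING
    from pymongo.helpers_shared import _index_document

    # A list-form hint, e.g. count_documents(..., hint=[("category", 1)]),
    # becomes the index document sent in the command's "hint" field.
    print(_index_document([("category", ASCENDING), ("ts", DESCENDING)]))
    # expected output shape: SON([('category', 1), ('ts', -1)])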
@@ -2089,14 +2089,14 @@ async def count_documents( pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}}) cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}} if "hint" in kwargs and not isinstance(kwargs["hint"], str): - kwargs["hint"] = helpers._index_document(kwargs["hint"]) + kwargs["hint"] = helpers_shared._index_document(kwargs["hint"]) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd.update(kwargs) async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], ) -> int: result = await self._aggregate_one_result( @@ -2111,10 +2111,10 @@ async def _cmd( async def _retryable_non_cursor_read( self, func: Callable[ - [Optional[ClientSession], Server, Connection, Optional[_ServerMode]], + [Optional[AsyncClientSession], Server, AsyncConnection, Optional[_ServerMode]], Coroutine[Any, Any, T], ], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, ) -> T: """Non-cursor read helper to handle implicit session creation.""" @@ -2125,7 +2125,7 @@ async def _retryable_non_cursor_read( async def create_indexes( self, indexes: Sequence[IndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list[str]: @@ -2141,7 +2141,7 @@ async def create_indexes( :param indexes: A list of :class:`~pymongo.operations.IndexModel` instances. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createIndexes @@ -2171,14 +2171,14 @@ async def create_indexes( @_csot.apply async def _create_indexes( - self, indexes: Sequence[IndexModel], session: Optional[ClientSession], **kwargs: Any + self, indexes: Sequence[IndexModel], session: Optional[AsyncClientSession], **kwargs: Any ) -> list[str]: """Internal createIndexes helper. :param indexes: A list of :class:`~pymongo.operations.IndexModel` instances. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param kwargs: optional arguments to the createIndexes command (like maxTimeMS) can be passed as keyword arguments. """ @@ -2217,7 +2217,7 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]: async def create_index( self, keys: _IndexKeyHint, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> str: @@ -2293,7 +2293,7 @@ async def create_index( :param keys: a single key or a list of (key, direction) pairs specifying the index to create :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: any additional index creation @@ -2334,7 +2334,7 @@ async def create_index( async def drop_indexes( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2344,7 +2344,7 @@ async def drop_indexes( Raises OperationFailure on an error. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. 
:param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createIndexes @@ -2369,7 +2369,7 @@ async def drop_indexes( async def drop_index( self, index_or_name: _IndexKeyHint, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2391,7 +2391,7 @@ async def drop_index( :param index_or_name: index (or name of index) to drop :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createIndexes @@ -2418,13 +2418,13 @@ async def drop_index( async def _drop_index( self, index_or_name: _IndexKeyHint, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: name = index_or_name if isinstance(index_or_name, list): - name = helpers._gen_index_name(index_or_name) + name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): raise TypeError("index_or_name must be an instance of str or list") @@ -2445,7 +2445,7 @@ async def _drop_index( async def list_indexes( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: """Get a cursor over the index documents for this collection. @@ -2456,7 +2456,7 @@ async def list_indexes( SON([('v', 2), ('key', SON([('_id', 1)])), ('name', '_id_')]) :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2474,7 +2474,7 @@ async def list_indexes( async def _list_indexes( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: codec_options: CodecOptions = CodecOptions(SON) @@ -2486,9 +2486,9 @@ async def _list_indexes( explicit_session = session is not None async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: cmd = {"listIndexes": self._name, "cursor": {}} @@ -2523,7 +2523,7 @@ async def _cmd( async def index_information( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> MutableMapping[str, Any]: """Get information on this collection's indexes. @@ -2545,7 +2545,7 @@ async def index_information( 'x_1': {'unique': True, 'key': [('x', 1)]}} :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2566,7 +2566,7 @@ async def index_information( async def list_search_indexes( self, name: Optional[str] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncCommandCursor[Mapping[str, Any]]: @@ -2576,7 +2576,7 @@ async def list_search_indexes( for. Only indexes with matching index names will be returned. If not given, all search indexes for the current collection will be returned. 
- :param session: a :class:`~pymongo.client_session.ClientSession`. + :param session: a :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2619,7 +2619,7 @@ async def list_search_indexes( async def create_search_index( self, model: Union[Mapping[str, Any], SearchIndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Any = None, **kwargs: Any, ) -> str: @@ -2630,7 +2630,7 @@ async def create_search_index( instance or a dictionary with a model "definition" and optional "name". :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createSearchIndexes @@ -2649,14 +2649,14 @@ async def create_search_index( async def create_search_indexes( self, models: list[SearchIndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list[str]: """Create multiple search indexes for the current collection. :param models: A list of :class:`~pymongo.operations.SearchIndexModel` instances. - :param session: a :class:`~pymongo.client_session.ClientSession`. + :param session: a :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the createSearchIndexes @@ -2673,7 +2673,7 @@ async def create_search_indexes( async def _create_search_indexes( self, models: list[SearchIndexModel], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list[str]: @@ -2705,7 +2705,7 @@ def gen_indexes() -> Iterator[Mapping[str, Any]]: async def drop_search_index( self, name: str, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2713,7 +2713,7 @@ async def drop_search_index( :param name: The name of the search index to be deleted. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the dropSearchIndexes @@ -2740,7 +2740,7 @@ async def update_search_index( self, name: str, definition: Mapping[str, Any], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> None: @@ -2749,7 +2749,7 @@ async def update_search_index( :param name: The name of the search index to be updated. :param definition: The new search index definition. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: optional arguments to the updateSearchIndexes @@ -2774,7 +2774,7 @@ async def update_search_index( async def options( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, ) -> MutableMapping[str, Any]: """Get the options set on this collection. @@ -2785,7 +2785,7 @@ async def options( dictionary if the collection has not been created yet. 
:param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2824,7 +2824,7 @@ async def _aggregate( aggregation_command: Type[_AggregationCommand], pipeline: _Pipeline, cursor_class: Type[AsyncCommandCursor], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], explicit_session: bool, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, @@ -2853,7 +2853,7 @@ async def _aggregate( async def aggregate( self, pipeline: _Pipeline, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -2876,7 +2876,7 @@ async def aggregate( :param pipeline: a list of aggregation pipeline stages :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: A dict of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -2948,7 +2948,7 @@ async def aggregate( async def aggregate_raw_batches( self, pipeline: _Pipeline, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncRawBatchCursor[_DocumentType]: @@ -2997,7 +2997,7 @@ async def aggregate_raw_batches( async def rename( self, new_name: str, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> MutableMapping[str, Any]: @@ -3011,7 +3011,7 @@ async def rename( :param new_name: new name for this collection :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: additional arguments to the rename command @@ -3061,7 +3061,7 @@ async def distinct( self, key: str, filter: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> list: @@ -3087,7 +3087,7 @@ async def distinct( :param filter: A query document that specifies the documents from which to retrieve the distinct values. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: See list of options above. 
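The ``let`` option documented above binds external values that the pipeline reads back through ``$$`` variables, keeping caller-supplied values out of the pipeline body itself. A short sketch against the async API (field and collection names are illustrative)::

    pipeline = [
        # $$cutoff resolves to the value bound via let= below.
        {"$match": {"$expr": {"$gte": ["$qty", "$$cutoff"]}}},
    ]
    cursor = await coll.aggregate(pipeline, let={"cutoff": 10})
    async for doc in cursor:
        print(doc)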
@@ -3112,9 +3112,9 @@ async def distinct( cmd["comment"] = comment async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: Optional[_ServerMode], ) -> list: return ( @@ -3140,7 +3140,7 @@ async def _find_and_modify( return_document: bool = ReturnDocument.BEFORE, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping] = None, **kwargs: Any, ) -> Any: @@ -3157,20 +3157,20 @@ async def _find_and_modify( cmd["let"] = let cmd.update(kwargs) if projection is not None: - cmd["fields"] = helpers._fields_list_to_dict(projection, "projection") + cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection") if sort is not None: - cmd["sort"] = helpers._index_document(sort) + cmd["sort"] = helpers_shared._index_document(sort) if upsert is not None: validate_boolean("upsert", upsert) cmd["upsert"] = upsert if hint is not None: if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) write_concern = self._write_concern_for_cmd(cmd, session) async def _find_and_modify_helper( - session: Optional[ClientSession], conn: Connection, retryable_write: bool + session: Optional[AsyncClientSession], conn: AsyncConnection, retryable_write: bool ) -> Any: acknowledged = write_concern.acknowledged if array_filters is not None: @@ -3216,7 +3216,7 @@ async def find_one_and_delete( projection: Optional[Union[Mapping[str, Any], Iterable[str]]] = None, sort: Optional[_IndexList] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -3262,7 +3262,7 @@ async def find_one_and_delete( (e.g. ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an @@ -3308,7 +3308,7 @@ async def find_one_and_replace( upsert: bool = False, return_document: bool = ReturnDocument.BEFORE, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -3360,7 +3360,7 @@ async def find_one_and_replace( ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. 
Parameters can then be accessed as variables in an @@ -3416,7 +3416,7 @@ async def find_one_and_update( return_document: bool = ReturnDocument.BEFORE, array_filters: Optional[Sequence[Mapping[str, Any]]] = None, hint: Optional[_IndexKeyHint] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, let: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -3507,7 +3507,7 @@ async def find_one_and_update( ``[('field', ASCENDING)]``). This option is only supported on MongoDB 4.4 and above. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param let: Map of parameter names and values. Values must be constant or closed expressions that do not reference document fields. Parameters can then be accessed as variables in an diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py index 0412264e20..4dbc52802a 100644 --- a/pymongo/asynchronous/command_cursor.py +++ b/pymongo/asynchronous/command_cursor.py @@ -37,15 +37,15 @@ _OpReply, _RawBatchGetMore, ) -from pymongo.asynchronous.response import PinnedResponse -from pymongo.asynchronous.typings import _Address, _DocumentOut, _DocumentType from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.response import PinnedResponse +from pymongo.typings import _Address, _DocumentOut, _DocumentType if TYPE_CHECKING: - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection _IS_SYNC = False @@ -62,7 +62,7 @@ def __init__( address: Optional[_Address], batch_size: int = 0, max_await_time_ms: Optional[int] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, explicit_session: bool = False, comment: Any = None, ) -> None: @@ -134,7 +134,7 @@ def _post_batch_resume_token(self) -> Optional[Mapping[str, Any]]: """ return self._postbatchresumetoken - async def _maybe_pin_connection(self, conn: Connection) -> None: + async def _maybe_pin_connection(self, conn: AsyncConnection) -> None: client = self._collection.database.client if not client._should_pin_cursor(self._session): return @@ -189,8 +189,8 @@ def address(self) -> Optional[_Address]: return self._address @property - def session(self) -> Optional[ClientSession]: - """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. + def session(self) -> Optional[AsyncClientSession]: + """The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None. .. 
versionadded:: 3.6 """ @@ -372,7 +372,7 @@ def __init__( address: Optional[_Address], batch_size: int = 0, max_await_time_ms: Optional[int] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, explicit_session: bool = False, comment: Any = None, ) -> None: diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py index 4edd2103fd..8213e9e64e 100644 --- a/pymongo/asynchronous/cursor.py +++ b/pymongo/asynchronous/cursor.py @@ -36,12 +36,7 @@ from bson import RE_TYPE, _convert_raw_document_lists_to_streams from bson.code import Code from bson.son import SON -from pymongo.asynchronous import helpers -from pymongo.asynchronous.collation import validate_collation_or_none -from pymongo.asynchronous.common import ( - validate_is_document_type, - validate_is_mapping, -) +from pymongo import helpers_shared from pymongo.asynchronous.helpers import anext from pymongo.asynchronous.message import ( _CursorAddress, @@ -52,21 +47,26 @@ _RawBatchGetMore, _RawBatchQuery, ) -from pymongo.asynchronous.response import PinnedResponse -from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType +from pymongo.collation import validate_collation_or_none +from pymongo.common import ( + validate_is_document_type, + validate_is_mapping, +) from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure from pymongo.lock import _ALock, _create_lock +from pymongo.response import PinnedResponse +from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import validate_boolean if TYPE_CHECKING: from _typeshed import SupportsItems from bson.codec_options import CodecOptions - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.collection import AsyncCollection - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.read_preferences import _ServerMode + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.read_preferences import _ServerMode _IS_SYNC = False @@ -74,8 +74,8 @@ class _ConnectionManager: """Used with exhaust cursors to ensure the connection is returned.""" - def __init__(self, conn: Connection, more_to_come: bool): - self.conn: Optional[Connection] = conn + def __init__(self, conn: AsyncConnection, more_to_come: bool): + self.conn: Optional[AsyncConnection] = conn self.more_to_come = more_to_come self._alock = _ALock(_create_lock()) @@ -116,7 +116,7 @@ def __init__( show_record_id: Optional[bool] = None, snapshot: Optional[bool] = None, comment: Optional[Any] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, allow_disk_use: Optional[bool] = None, let: Optional[bool] = None, ) -> None: @@ -134,7 +134,7 @@ def __init__( self._exhaust = False self._sock_mgr: Any = None self._killed = False - self._session: Optional[ClientSession] + self._session: Optional[AsyncClientSession] if session: self._session = session @@ -179,7 +179,7 @@ def __init__( allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use) if projection is not None: - projection = helpers._fields_list_to_dict(projection, "projection") + projection = helpers_shared._fields_list_to_dict(projection, "projection") if let is not None: validate_is_document_type("let", let) @@ -191,7 +191,7 @@ def 
__init__( self._skip = skip self._limit = limit self._batch_size = batch_size - self._ordering = sort and helpers._index_document(sort) or None + self._ordering = sort and helpers_shared._index_document(sort) or None self._max_scan = max_scan self._explain = False self._comment = comment @@ -314,7 +314,7 @@ def _clone(self, deepcopy: bool = True, base: Optional[AsyncCursor] = None) -> A base.__dict__.update(data) return base - def _clone_base(self, session: Optional[ClientSession]) -> AsyncCursor: + def _clone_base(self, session: Optional[AsyncClientSession]) -> AsyncCursor: """Creates an empty Cursor object for information to be copied into.""" return self.__class__(self._collection, session=session) @@ -743,8 +743,8 @@ def sort( key, if not given :data:`~pymongo.ASCENDING` is assumed """ self._check_okay_to_chain() - keys = helpers._index_list(key_or_list, direction) - self._ordering = helpers._index_document(keys) + keys = helpers_shared._index_list(key_or_list, direction) + self._ordering = helpers_shared._index_document(keys) return self async def explain(self) -> _DocumentType: @@ -775,7 +775,7 @@ def _set_hint(self, index: Optional[_Hint]) -> None: if isinstance(index, str): self._hint = index else: - self._hint = helpers._index_document(index) + self._hint = helpers_shared._index_document(index) def hint(self, index: Optional[_Hint]) -> AsyncCursor[_DocumentType]: """Adds a 'hint', telling Mongo the proper index to use for the query. @@ -928,8 +928,8 @@ def address(self) -> Optional[tuple[str, Any]]: return self._address @property - def session(self) -> Optional[ClientSession]: - """The cursor's :class:`~pymongo.client_session.ClientSession`, or None. + def session(self) -> Optional[AsyncClientSession]: + """The cursor's :class:`~pymongo.client_session.AsyncClientSession`, or None. .. 
versionadded:: 3.6 """ diff --git a/pymongo/asynchronous/database.py b/pymongo/asynchronous/database.py index 57ad71ece3..f7a07027c8 100644 --- a/pymongo/asynchronous/database.py +++ b/pymongo/asynchronous/database.py @@ -33,25 +33,24 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions from bson.dbref import DBRef from bson.timestamp import Timestamp -from pymongo import _csot -from pymongo.asynchronous import common +from pymongo import _csot, common from pymongo.asynchronous.aggregation import _DatabaseAggregationCommand from pymongo.asynchronous.change_stream import DatabaseChangeStream from pymongo.asynchronous.collection import AsyncCollection from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.common import _ecoc_coll_name, _esc_coll_name -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.database_shared import _check_name, _CodecDocumentType from pymongo.errors import CollectionInvalid, InvalidOperation +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: import bson import bson.codec_options - from pymongo.asynchronous.client_session import ClientSession + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server from pymongo.read_concern import ReadConcern from pymongo.write_concern import WriteConcern @@ -151,7 +150,7 @@ def with_options( >>> db1.read_preference Primary() - >>> from pymongo.asynchronous.read_preferences import Secondary + >>> from pymongo.read_preferences import Secondary >>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}])) >>> db1.read_preference Primary() @@ -328,7 +327,7 @@ async def watch( batch_size: Optional[int] = None, collation: Optional[_CollationIn] = None, start_at_operation_time: Optional[Timestamp] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, start_after: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -402,7 +401,7 @@ async def watch( the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param start_after: The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. @@ -458,7 +457,7 @@ async def create_collection( read_preference: Optional[_ServerMode] = None, write_concern: Optional[WriteConcern] = None, read_concern: Optional[ReadConcern] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, check_exists: Optional[bool] = True, **kwargs: Any, ) -> AsyncCollection[_DocumentType]: @@ -489,7 +488,7 @@ async def create_collection( :param collation: An instance of :class:`~pymongo.collation.Collation`. 
:param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param `check_exists`: if True (the default), send a listCollections command to check if the collection already exists before creation. :param kwargs: additional keyword arguments will @@ -607,7 +606,7 @@ async def create_collection( return coll async def aggregate( - self, pipeline: _Pipeline, session: Optional[ClientSession] = None, **kwargs: Any + self, pipeline: _Pipeline, session: Optional[AsyncClientSession] = None, **kwargs: Any ) -> AsyncCommandCursor[_DocumentType]: """Perform a database-level aggregation. @@ -634,7 +633,7 @@ async def aggregate( :param pipeline: a list of aggregation pipeline stages :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param kwargs: extra `aggregate command`_ parameters. All optional `aggregate command`_ parameters should be passed as @@ -688,7 +687,7 @@ async def aggregate( @overload async def _command( self, - conn: Connection, + conn: AsyncConnection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -697,7 +696,7 @@ async def _command( codec_options: CodecOptions[dict[str, Any]] = DEFAULT_CODEC_OPTIONS, write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> dict[str, Any]: ... @@ -705,7 +704,7 @@ async def _command( @overload async def _command( self, - conn: Connection, + conn: AsyncConnection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -714,14 +713,14 @@ async def _command( codec_options: CodecOptions[_CodecDocumentType] = ..., write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> _CodecDocumentType: ... 
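The paired ``@overload`` stubs above (and on ``command`` below) exist so that a caller-supplied ``codec_options`` narrows the static return type from ``dict[str, Any]`` to the codec's document class. A caller-side sketch; the ``RawBSONDocument`` choice is just one example of a typed codec::

    from bson.codec_options import CodecOptions
    from bson.raw_bson import RawBSONDocument

    opts: CodecOptions[RawBSONDocument] = CodecOptions(document_class=RawBSONDocument)
    raw = await db.command("ping", codec_options=opts)  # typed as RawBSONDocument
    plain = await db.command("ping")                    # typed as dict[str, Any]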
async def _command( self, - conn: Connection, + conn: AsyncConnection, command: Union[str, MutableMapping[str, Any]], value: int = 1, check: bool = True, @@ -732,7 +731,7 @@ async def _command( ] = DEFAULT_CODEC_OPTIONS, write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, **kwargs: Any, ) -> Union[dict[str, Any], _CodecDocumentType]: """Internal command helper.""" @@ -763,7 +762,7 @@ async def command( allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: Optional[_ServerMode] = None, codec_options: None = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> dict[str, Any]: @@ -778,7 +777,7 @@ async def command( allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: Optional[_ServerMode] = None, codec_options: CodecOptions[_CodecDocumentType] = ..., - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> _CodecDocumentType: @@ -793,7 +792,7 @@ async def command( allowable_errors: Optional[Sequence[Union[str, int]]] = None, read_preference: Optional[_ServerMode] = None, codec_options: Optional[bson.codec_options.CodecOptions[_CodecDocumentType]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> Union[dict[str, Any], _CodecDocumentType]: @@ -852,7 +851,7 @@ async def command( :param codec_options: A :class:`~bson.codec_options.CodecOptions` instance. :param session: A - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: additional keyword arguments will @@ -922,7 +921,7 @@ async def cursor_command( value: Any = 1, read_preference: Optional[_ServerMode] = None, codec_options: Optional[CodecOptions[_CodecDocumentType]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, max_await_time_ms: Optional[int] = None, **kwargs: Any, @@ -953,7 +952,7 @@ async def cursor_command( :param codec_options`: A :class:`~bson.codec_options.CodecOptions` instance. :param session: A - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to future getMores for this command. :param max_await_time_ms: The number of ms to wait for more data on future getMores for this command. 
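``cursor_command`` is for commands whose reply embeds a cursor document; it returns an ``AsyncCommandCursor`` that fetches the remaining batches. A usage sketch (the command and collection names are illustrative)::

    cursor = await db.cursor_command("find", "coll", filter={"qty": {"$gte": 10}})
    async for doc in cursor:
        print(doc)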
@@ -1024,15 +1023,15 @@ async def _retryable_read_command( self, command: Union[str, MutableMapping[str, Any]], operation: str, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, ) -> dict[str, Any]: """Same as command but used for retryable read commands.""" read_preference = (session and session._txn_read_preference()) or ReadPreference.PRIMARY async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> dict[str, Any]: return await self._command( @@ -1046,8 +1045,8 @@ async def _cmd( async def _list_collections( self, - conn: Connection, - session: Optional[ClientSession], + conn: AsyncConnection, + session: Optional[AsyncClientSession], read_preference: _ServerMode, **kwargs: Any, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: @@ -1075,7 +1074,7 @@ async def _list_collections( async def _list_collections_helper( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1083,7 +1082,7 @@ async def _list_collections_helper( """Get a cursor over the collections of this database. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param filter: A query document to filter the list of collections returned from the listCollections command. :param comment: A user-provided comment to attach to this @@ -1106,9 +1105,9 @@ async def _list_collections_helper( kwargs["comment"] = comment async def _cmd( - session: Optional[ClientSession], + session: Optional[AsyncClientSession], _server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> AsyncCommandCursor[MutableMapping[str, Any]]: return await self._list_collections( @@ -1121,7 +1120,7 @@ async def _cmd( async def list_collections( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1129,7 +1128,7 @@ async def list_collections( """Get a cursor over the collections of this database. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param filter: A query document to filter the list of collections returned from the listCollections command. :param comment: A user-provided comment to attach to this @@ -1149,7 +1148,7 @@ async def list_collections( async def _list_collection_names( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1174,7 +1173,7 @@ async def _list_collection_names( async def list_collection_names( self, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, filter: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, **kwargs: Any, @@ -1187,7 +1186,7 @@ async def list_collection_names( db.list_collection_names(filter=filter) :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param filter: A query document to filter the list of collections returned from the listCollections command. 
:param comment: A user-provided comment to attach to this @@ -1207,7 +1206,7 @@ async def list_collection_names( return await self._list_collection_names(session, filter, comment, **kwargs) async def _drop_helper( - self, name: str, session: Optional[ClientSession] = None, comment: Optional[Any] = None + self, name: str, session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None ) -> dict[str, Any]: command = {"drop": name} if comment is not None: @@ -1227,7 +1226,7 @@ async def _drop_helper( async def drop_collection( self, name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, encrypted_fields: Optional[Mapping[str, Any]] = None, ) -> dict[str, Any]: @@ -1236,7 +1235,7 @@ async def drop_collection( :param name_or_collection: the name of a collection to drop or the collection object itself :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param encrypted_fields: **(BETA)** Document that describes the encrypted fields for @@ -1306,7 +1305,7 @@ async def validate_collection( name_or_collection: Union[str, AsyncCollection[_DocumentTypeArg]], scandata: bool = False, full: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, background: Optional[bool] = None, comment: Optional[Any] = None, ) -> dict[str, Any]: @@ -1326,7 +1325,7 @@ async def validate_collection( of the structure of the collection and the individual documents. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param background: A boolean flag that determines whether the command runs in the background. Requires MongoDB 4.4+. :param comment: A user-provided comment to attach to this @@ -1386,7 +1385,7 @@ async def validate_collection( async def dereference( self, dbref: DBRef, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> Optional[_DocumentType]: @@ -1401,7 +1400,7 @@ async def dereference( :param dbref: the reference :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
:param kwargs: any additional keyword arguments diff --git a/pymongo/asynchronous/encryption.py b/pymongo/asynchronous/encryption.py index cc9c30f988..fbd0e719e8 100644 --- a/pymongo/asynchronous/encryption.py +++ b/pymongo/asynchronous/encryption.py @@ -59,16 +59,13 @@ from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson from pymongo import _csot from pymongo.asynchronous.collection import AsyncCollection -from pymongo.asynchronous.common import CONNECT_TIMEOUT from pymongo.asynchronous.cursor import AsyncCursor from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.operations import UpdateOne -from pymongo.asynchronous.pool import PoolOptions, _configured_socket, _raise_connection_failure -from pymongo.asynchronous.typings import _DocumentType, _DocumentTypeArg -from pymongo.asynchronous.uri_parser import parse_host +from pymongo.asynchronous.pool import _configured_socket, _raise_connection_failure +from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon +from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.errors import ( ConfigurationError, EncryptedCollectionError, @@ -78,9 +75,13 @@ ServerSelectionTimeoutError, ) from pymongo.network_layer import BLOCKING_IO_ERRORS, async_sendall +from pymongo.operations import UpdateOne +from pymongo.pool_options import PoolOptions from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult from pymongo.ssl_support import get_ssl_context +from pymongo.typings import _DocumentType, _DocumentTypeArg +from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -381,7 +382,10 @@ def _get_internal_client( ) io_callbacks = _EncryptionIO( # type:ignore[misc] - metadata_client, key_vault_coll, mongocryptd_client, opts + metadata_client, + key_vault_coll, # type:ignore[arg-type] + mongocryptd_client, + opts, ) self._auto_encrypter = AsyncAutoEncrypter( io_callbacks, diff --git a/pymongo/asynchronous/encryption_options.py b/pymongo/asynchronous/encryption_options.py deleted file mode 100644 index 73d1932c6a..0000000000 --- a/pymongo/asynchronous/encryption_options.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2019-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Support for automatic client-side field level encryption.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional - -try: - import pymongocrypt # type:ignore[import] # noqa: F401 - - _HAVE_PYMONGOCRYPT = True -except ImportError: - _HAVE_PYMONGOCRYPT = False -from bson import int64 -from pymongo.asynchronous.common import validate_is_mapping -from pymongo.asynchronous.uri_parser import _parse_kms_tls_options -from pymongo.errors import ConfigurationError - -if TYPE_CHECKING: - from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.typings import _DocumentTypeArg - -_IS_SYNC = False - - -class AutoEncryptionOpts: - """Options to configure automatic client-side field level encryption.""" - - def __init__( - self, - kms_providers: Mapping[str, Any], - key_vault_namespace: str, - key_vault_client: Optional[AsyncMongoClient[_DocumentTypeArg]] = None, - schema_map: Optional[Mapping[str, Any]] = None, - bypass_auto_encryption: bool = False, - mongocryptd_uri: str = "mongodb://localhost:27020", - mongocryptd_bypass_spawn: bool = False, - mongocryptd_spawn_path: str = "mongocryptd", - mongocryptd_spawn_args: Optional[list[str]] = None, - kms_tls_options: Optional[Mapping[str, Any]] = None, - crypt_shared_lib_path: Optional[str] = None, - crypt_shared_lib_required: bool = False, - bypass_query_analysis: bool = False, - encrypted_fields_map: Optional[Mapping[str, Any]] = None, - ) -> None: - """Options to configure automatic client-side field level encryption. - - Automatic client-side field level encryption requires MongoDB >=4.2 - enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not - supported for operations on a database or view and will result in - error. - - Although automatic encryption requires MongoDB >=4.2 enterprise or a - MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all - users. To configure automatic *decryption* without automatic - *encryption* set ``bypass_auto_encryption=True``. Explicit - encryption and explicit decryption is also supported for all users - with the :class:`~pymongo.encryption.ClientEncryption` class. - - See :ref:`automatic-client-side-encryption` for an example. - - :param kms_providers: Map of KMS provider options. The `kms_providers` - map values differ by provider: - - - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. - These are the AWS access key ID and AWS secret access key used - to generate KMS messages. An optional "sessionToken" may be - included to support temporary AWS credentials. - - `azure`: Map with "tenantId", "clientId", and "clientSecret" as - strings. Additionally, "identityPlatformEndpoint" may also be - specified as a string (defaults to 'login.microsoftonline.com'). - These are the Azure Active Directory credentials used to - generate Azure Key Vault messages. - - `gcp`: Map with "email" as a string and "privateKey" - as `bytes` or a base64 encoded string. - Additionally, "endpoint" may also be specified as a string - (defaults to 'oauth2.googleapis.com'). These are the - credentials used to generate Google Cloud KMS messages. - - `kmip`: Map with "endpoint" as a host with required port. - For example: ``{"endpoint": "example.com:443"}``. - - `local`: Map with "key" as `bytes` (96 bytes in length) or - a base64 encoded string which decodes - to 96 bytes. "key" is the master key used to encrypt/decrypt - data keys. This key should be generated and stored as securely - as possible. 
- - KMS providers may be specified with an optional name suffix - separated by a colon, for example "kmip:name" or "aws:name". - Named KMS providers do not support :ref:`CSFLE on-demand credentials`. - Named KMS providers enables more than one of each KMS provider type to be configured. - For example, to configure multiple local KMS providers:: - - kms_providers = { - "local": {"key": local_kek1}, # Unnamed KMS provider. - "local:myname": {"key": local_kek2}, # Named KMS provider with name "myname". - } - - :param key_vault_namespace: The namespace for the key vault collection. - The key vault collection contains all data keys used for encryption - and decryption. Data keys are stored as documents in this MongoDB - collection. Data keys are protected with encryption by a KMS - provider. - :param key_vault_client: By default, the key vault collection - is assumed to reside in the same MongoDB cluster as the encrypted - AsyncMongoClient. Use this option to route data key queries to a - separate MongoDB cluster. - :param schema_map: Map of collection namespace ("db.coll") to - JSON Schema. By default, a collection's JSONSchema is periodically - polled with the listCollections command. But a JSONSchema may be - specified locally with the schemaMap option. - - **Supplying a `schema_map` provides more security than relying on - JSON Schemas obtained from the server. It protects against a - malicious server advertising a false JSON Schema, which could trick - the client into sending unencrypted data that should be - encrypted.** - - Schemas supplied in the schemaMap only apply to configuring - automatic encryption for client side encryption. Other validation - rules in the JSON schema will not be enforced by the driver and - will result in an error. - :param bypass_auto_encryption: If ``True``, automatic - encryption will be disabled but automatic decryption will still be - enabled. Defaults to ``False``. - :param mongocryptd_uri: The MongoDB URI used to connect - to the *local* mongocryptd process. Defaults to - ``'mongodb://localhost:27020'``. - :param mongocryptd_bypass_spawn: If ``True``, the encrypted - AsyncMongoClient will not attempt to spawn the mongocryptd process. - Defaults to ``False``. - :param mongocryptd_spawn_path: Used for spawning the - mongocryptd process. Defaults to ``'mongocryptd'`` and spawns - mongocryptd from the system path. - :param mongocryptd_spawn_args: A list of string arguments to - use when spawning the mongocryptd process. Defaults to - ``['--idleShutdownTimeoutSecs=60']``. If the list does not include - the ``idleShutdownTimeoutSecs`` option then - ``'--idleShutdownTimeoutSecs=60'`` will be added. - :param kms_tls_options: A map of KMS provider names to TLS - options to use when creating secure connections to KMS providers. - Accepts the same TLS options as - :class:`pymongo.mongo_client.AsyncMongoClient`. For example, to - override the system default CA file:: - - kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}} - - Or to supply a client certificate:: - - kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} - :param crypt_shared_lib_path: Override the path to load the crypt_shared library. - :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is - unable to load the crypt_shared library. - :param bypass_query_analysis: If ``True``, disable automatic analysis - of outgoing commands. 
Set `bypass_query_analysis` to use explicit - encryption on indexed fields without the MongoDB Enterprise Advanced - licensed crypt_shared library. - :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents - that described the encrypted fields for Queryable Encryption. For example:: - - { - "db.encryptedCollection": { - "escCollection": "enxcol_.encryptedCollection.esc", - "ecocCollection": "enxcol_.encryptedCollection.ecoc", - "fields": [ - { - "path": "firstName", - "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), - "bsonType": "string", - "queries": {"queryType": "equality"} - }, - { - "path": "ssn", - "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), - "bsonType": "string" - } - ] - } - } - - .. versionchanged:: 4.2 - Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`, - and `bypass_query_analysis` parameters. - - .. versionchanged:: 4.0 - Added the `kms_tls_options` parameter and the "kmip" KMS provider. - - .. versionadded:: 3.9 - """ - if not _HAVE_PYMONGOCRYPT: - raise ConfigurationError( - "client side encryption requires the pymongocrypt library: " - "install a compatible version with: " - "python -m pip install 'pymongo[encryption]'" - ) - if encrypted_fields_map: - validate_is_mapping("encrypted_fields_map", encrypted_fields_map) - self._encrypted_fields_map = encrypted_fields_map - self._bypass_query_analysis = bypass_query_analysis - self._crypt_shared_lib_path = crypt_shared_lib_path - self._crypt_shared_lib_required = crypt_shared_lib_required - self._kms_providers = kms_providers - self._key_vault_namespace = key_vault_namespace - self._key_vault_client = key_vault_client - self._schema_map = schema_map - self._bypass_auto_encryption = bypass_auto_encryption - self._mongocryptd_uri = mongocryptd_uri - self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn - self._mongocryptd_spawn_path = mongocryptd_spawn_path - if mongocryptd_spawn_args is None: - mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] - self._mongocryptd_spawn_args = mongocryptd_spawn_args - if not isinstance(self._mongocryptd_spawn_args, list): - raise TypeError("mongocryptd_spawn_args must be a list") - if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): - self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") - # Maps KMS provider name to a SSLContext. - self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) - self._bypass_query_analysis = bypass_query_analysis - - -class RangeOpts: - """Options to configure encrypted queries using the rangePreview algorithm.""" - - def __init__( - self, - sparsity: int, - min: Optional[Any] = None, - max: Optional[Any] = None, - precision: Optional[int] = None, - ) -> None: - """Options to configure encrypted queries using the rangePreview algorithm. - - .. note:: This feature is experimental only, and not intended for public use. - - :param sparsity: An integer. - :param min: A BSON scalar value corresponding to the type being queried. - :param max: A BSON scalar value corresponding to the type being queried. - :param precision: An integer, may only be set for double or decimal128 types. - - .. 
versionadded:: 4.4 - """ - self.min = min - self.max = max - self.sparsity = sparsity - self.precision = precision - - @property - def document(self) -> dict[str, Any]: - doc = {} - for k, v in [ - ("sparsity", int64.Int64(self.sparsity)), - ("precision", self.precision), - ("min", self.min), - ("max", self.max), - ]: - if v is not None: - doc[k] = v - return doc diff --git a/pymongo/asynchronous/event_loggers.py b/pymongo/asynchronous/event_loggers.py deleted file mode 100644 index 9bb8bb36bc..0000000000 --- a/pymongo/asynchronous/event_loggers.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2020-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Example event logger classes. - -.. versionadded:: 3.11 - -These loggers can be registered using :func:`register` or -:class:`~pymongo.mongo_client.MongoClient`. - -``monitoring.register(CommandLogger())`` - -or - -``MongoClient(event_listeners=[CommandLogger()])`` -""" -from __future__ import annotations - -import logging - -from pymongo.asynchronous import monitoring - -_IS_SYNC = False - - -class CommandLogger(monitoring.CommandListener): - """A simple listener that logs command events. - - Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, - :class:`~pymongo.monitoring.CommandSucceededEvent` and - :class:`~pymongo.monitoring.CommandFailedEvent` events and - logs them at the `INFO` severity level using :mod:`logging`. - .. versionadded:: 3.11 - """ - - def started(self, event: monitoring.CommandStartedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} started on server " - f"{event.connection_id}" - ) - - def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"succeeded in {event.duration_micros} " - "microseconds" - ) - - def failed(self, event: monitoring.CommandFailedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"failed in {event.duration_micros} " - "microseconds" - ) - - -class ServerLogger(monitoring.ServerListener): - """A simple listener that logs server discovery events. - - Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, - :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.ServerClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. 
versionadded:: 3.11 - """ - - def opened(self, event: monitoring.ServerOpeningEvent) -> None: - logging.info(f"Server {event.server_address} added to topology {event.topology_id}") - - def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - f"Server {event.server_address} changed type from " - f"{event.previous_description.server_type_name} to " - f"{event.new_description.server_type_name}" - ) - - def closed(self, event: monitoring.ServerClosedEvent) -> None: - logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}") - - -class HeartbeatLogger(monitoring.ServerHeartbeatListener): - """A simple listener that logs server heartbeat events. - - Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, - :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, - and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: - logging.info(f"Heartbeat sent to server {event.connection_id}") - - def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: - # The reply.document attribute was added in PyMongo 3.4. - logging.info( - f"Heartbeat to server {event.connection_id} " - "succeeded with reply " - f"{event.reply.document}" - ) - - def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: - logging.warning( - f"Heartbeat to server {event.connection_id} failed with error {event.reply}" - ) - - -class TopologyLogger(monitoring.TopologyListener): - """A simple listener that logs server topology events. - - Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, - :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.TopologyClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def opened(self, event: monitoring.TopologyOpenedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} opened") - - def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None: - logging.info(f"Topology description updated for topology id {event.topology_id}") - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - f"Topology {event.topology_id} changed type from " - f"{event.previous_description.topology_type_name} to " - f"{event.new_description.topology_type_name}" - ) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event: monitoring.TopologyClosedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} closed") - - -class ConnectionPoolLogger(monitoring.ConnectionPoolListener): - """A simple listener that logs server connection pool events. 
- - Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`, - :class:`~pymongo.monitoring.PoolClearedEvent`, - :class:`~pymongo.monitoring.PoolClosedEvent`, - :~pymongo.monitoring.class:`ConnectionCreatedEvent`, - :class:`~pymongo.monitoring.ConnectionReadyEvent`, - :class:`~pymongo.monitoring.ConnectionClosedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`, - and :class:`~pymongo.monitoring.ConnectionCheckedInEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: - logging.info(f"[pool {event.address}] pool created") - - def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: - logging.info(f"[pool {event.address}] pool ready") - - def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: - logging.info(f"[pool {event.address}] pool cleared") - - def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: - logging.info(f"[pool {event.address}] pool closed") - - def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: - logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created") - - def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" - ) - - def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] " - f'connection closed, reason: "{event.reason}"' - ) - - def connection_check_out_started( - self, event: monitoring.ConnectionCheckOutStartedEvent - ) -> None: - logging.info(f"[pool {event.address}] connection check out started") - - def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None: - logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}") - - def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" - ) - - def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" - ) diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index 2b7420bbce..c939bfabe1 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2009-present MongoDB, Inc. +# Copyright 2024-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,270 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
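The deletions that follow strip these generic helpers out of the async package; per the new imports above they now live in ``pymongo.helpers_shared``. Code that previously reached into ``pymongo.asynchronous.helpers`` would migrate along these lines (a sketch, not part of the patch)::

    # Before: from pymongo.asynchronous.helpers import _index_document
    from pymongo.helpers_shared import _index_document

    # Builds the index-spec document {"qty": 1, "ts": -1}.
    spec = _index_document([("qty", 1), ("ts", -1)])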
-"""Bits and pieces used by the driver that don't really fit elsewhere.""" +"""Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations import builtins import sys -import traceback -from collections import abc from typing import ( - TYPE_CHECKING, Any, Callable, - Container, - Iterable, - Mapping, - NoReturn, - Optional, - Sequence, TypeVar, - Union, cast, ) -from pymongo import ASCENDING -from pymongo.asynchronous.hello_compat import HelloCompat from pymongo.errors import ( - CursorNotFound, - DuplicateKeyError, - ExecutionTimeout, - NotPrimaryError, OperationFailure, - WriteConcernError, - WriteError, - WTimeoutError, - _wtimeout_error, ) -from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE - -if TYPE_CHECKING: - from pymongo.asynchronous.operations import _IndexList - from pymongo.asynchronous.typings import _DocumentOut - from pymongo.cursor_shared import _Hint +from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE _IS_SYNC = False - -def _gen_index_name(keys: _IndexList) -> str: - """Generate an index name from the set of fields it is over.""" - return "_".join(["{}_{}".format(*item) for item in keys]) - - -def _index_list( - key_or_list: _Hint, direction: Optional[Union[int, str]] = None -) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: - """Helper to generate a list of (key, direction) pairs. - - Takes such a list, or a single key, or a single key and direction. - """ - if direction is not None: - if not isinstance(key_or_list, str): - raise TypeError("Expected a string and a direction") - return [(key_or_list, direction)] - else: - if isinstance(key_or_list, str): - return [(key_or_list, ASCENDING)] - elif isinstance(key_or_list, abc.ItemsView): - return list(key_or_list) # type: ignore[arg-type] - elif isinstance(key_or_list, abc.Mapping): - return list(key_or_list.items()) - elif not isinstance(key_or_list, (list, tuple)): - raise TypeError("if no direction is specified, key_or_list must be an instance of list") - values: list[tuple[str, int]] = [] - for item in key_or_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - values.append(item) - return values - - -def _index_document(index_list: _IndexList) -> dict[str, Any]: - """Helper to generate an index specifying document. - - Takes a list of (key, direction) pairs. - """ - if not isinstance(index_list, (list, tuple, abc.Mapping)): - raise TypeError( - "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) - ) - if not len(index_list): - raise ValueError("key_or_list must not be empty") - - index: dict[str, Any] = {} - - if isinstance(index_list, abc.Mapping): - for key in index_list: - value = index_list[key] - _validate_index_key_pair(key, value) - index[key] = value - else: - for item in index_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - key, value = item - _validate_index_key_pair(key, value) - index[key] = value - return index - - -def _validate_index_key_pair(key: Any, value: Any) -> None: - if not isinstance(key, str): - raise TypeError("first item in each key pair must be an instance of str") - if not isinstance(value, (str, int, abc.Mapping)): - raise TypeError( - "second item in each key pair must be 1, -1, " - "'2d', or another valid MongoDB index specifier." 
- ) - - -def _check_command_response( - response: _DocumentOut, - max_wire_version: Optional[int], - allowable_errors: Optional[Container[Union[int, str]]] = None, - parse_write_concern_error: bool = False, -) -> None: - """Check the response to a command for errors.""" - if "ok" not in response: - # Server didn't recognize our message as a command. - raise OperationFailure( - response.get("$err"), # type: ignore[arg-type] - response.get("code"), - response, - max_wire_version, - ) - - if parse_write_concern_error and "writeConcernError" in response: - _error = response["writeConcernError"] - _labels = response.get("errorLabels") - if _labels: - _error.update({"errorLabels": _labels}) - _raise_write_concern_error(_error) - - if response["ok"]: - return - - details = response - # Mongos returns the error details in a 'raw' object - # for some errors. - if "raw" in response: - for shard in response["raw"].values(): - # Grab the first non-empty raw error from a shard. - if shard.get("errmsg") and not shard.get("ok"): - details = shard - break - - errmsg = details["errmsg"] - code = details.get("code") - - # For allowable errors, only check for error messages when the code is not - # included. - if allowable_errors: - if code is not None: - if code in allowable_errors: - return - elif errmsg in allowable_errors: - return - - # Server is "not primary" or "recovering" - if code is not None: - if code in _NOT_PRIMARY_CODES: - raise NotPrimaryError(errmsg, response) - elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg: - raise NotPrimaryError(errmsg, response) - - # Other errors - # findAndModify with upsert can raise duplicate key error - if code in (11000, 11001, 12582): - raise DuplicateKeyError(errmsg, code, response, max_wire_version) - elif code == 50: - raise ExecutionTimeout(errmsg, code, response, max_wire_version) - elif code == 43: - raise CursorNotFound(errmsg, code, response, max_wire_version) - - raise OperationFailure(errmsg, code, response, max_wire_version) - - -def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: - # If the last batch had multiple errors only report - # the last error to emulate continue_on_error. - error = write_errors[-1] - if error.get("code") == 11000: - raise DuplicateKeyError(error.get("errmsg"), 11000, error) - raise WriteError(error.get("errmsg"), error.get("code"), error) - - -def _raise_write_concern_error(error: Any) -> NoReturn: - if _wtimeout_error(error): - # Make sure we raise WTimeoutError - raise WTimeoutError(error.get("errmsg"), error.get("code"), error) - raise WriteConcernError(error.get("errmsg"), error.get("code"), error) - - -def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: - """Return the writeConcernError or None.""" - wce = result.get("writeConcernError") - if wce: - # The server reports errorLabels at the top level but it's more - # convenient to attach it to the writeConcernError doc itself. - error_labels = result.get("errorLabels") - if error_labels: - # Copy to avoid changing the original document. 
- wce = wce.copy() - wce["errorLabels"] = error_labels - return wce - - -def _check_write_command_response(result: Mapping[str, Any]) -> None: - """Backward compatibility helper for write command error handling.""" - # Prefer write errors over write concern errors - write_errors = result.get("writeErrors") - if write_errors: - _raise_last_write_error(write_errors) - - wce = _get_wce_doc(result) - if wce: - _raise_write_concern_error(wce) - - -def _fields_list_to_dict( - fields: Union[Mapping[str, Any], Iterable[str]], option_name: str -) -> Mapping[str, Any]: - """Takes a sequence of field names and returns a matching dictionary. - - ["a", "b"] becomes {"a": 1, "b": 1} - - and - - ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1} - """ - if isinstance(fields, abc.Mapping): - return fields - - if isinstance(fields, (abc.Sequence, abc.Set)): - if not all(isinstance(field, str) for field in fields): - raise TypeError(f"{option_name} must be a list of key names, each an instance of str") - return dict.fromkeys(fields, 1) - - raise TypeError(f"{option_name} must be a mapping or list of key names") - - -def _handle_exception() -> None: - """Print exceptions raised by subscribers to stderr.""" - # Heavily influenced by logging.Handler.handleError. - - # See note here: - # https://docs.python.org/3.4/library/sys.html#sys.__stderr__ - if sys.stderr: - einfo = sys.exc_info() - try: - traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) - except OSError: - pass - finally: - del einfo - - # See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories F = TypeVar("F", bound=Callable[..., Any]) @@ -284,7 +39,7 @@ def _handle_reauth(func: F) -> F: async def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) from pymongo.asynchronous.message import _BulkWriteContext - from pymongo.asynchronous.pool import Connection + from pymongo.asynchronous.pool import AsyncConnection try: return await func(*args, **kwargs) @@ -292,12 +47,12 @@ async def inner(*args: Any, **kwargs: Any) -> Any: if no_reauth: raise if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - # Look for an argument that either is a Connection + # Look for an argument that either is a AsyncConnection # or has a connection attribute, so we can trigger # a reauth. 
conn = None for arg in args: - if isinstance(arg, Connection): + if isinstance(arg, AsyncConnection): conn = arg break if isinstance(arg, _BulkWriteContext): diff --git a/pymongo/asynchronous/message.py b/pymongo/asynchronous/message.py index 0815d33536..0973677d3a 100644 --- a/pymongo/asynchronous/message.py +++ b/pymongo/asynchronous/message.py @@ -54,14 +54,7 @@ _use_c = True except ImportError: _use_c = False -from pymongo.asynchronous.hello_compat import HelloCompat from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.logger import ( - _COMMAND_LOGGER, - _CommandStatusMessage, - _debug_log, -) -from pymongo.asynchronous.read_preferences import ReadPreference from pymongo.errors import ( ConfigurationError, CursorNotFound, @@ -72,19 +65,26 @@ OperationFailure, ProtocolError, ) +from pymongo.hello_compat import HelloCompat +from pymongo.logger import ( + _COMMAND_LOGGER, + _CommandStatusMessage, + _debug_log, +) +from pymongo.read_preferences import ReadPreference from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from datetime import timedelta - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.monitoring import _EventListeners - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.typings import _Address, _DocumentOut + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.typings import _Address, _DocumentOut _IS_SYNC = False @@ -217,7 +217,7 @@ def _gen_find_command( options: Optional[int], read_concern: ReadConcern, collation: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, allow_disk_use: Optional[bool] = None, ) -> dict[str, Any]: """Generate a find command document.""" @@ -264,7 +264,7 @@ def _gen_get_more_command( batch_size: Optional[int], max_await_time_ms: Optional[int], comment: Optional[Any], - conn: Connection, + conn: AsyncConnection, ) -> dict[str, Any]: """Generate a getMore command document.""" cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll} @@ -319,7 +319,7 @@ def __init__( batch_size: int, read_concern: ReadConcern, collation: Optional[Mapping[str, Any]], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], client: AsyncMongoClient, allow_disk_use: Optional[bool], exhaust: bool, @@ -349,7 +349,7 @@ def reset(self) -> None: def namespace(self) -> str: return f"{self.db}.{self.coll}" - def use_command(self, conn: Connection) -> bool: + def use_command(self, conn: AsyncConnection) -> bool: use_find_cmd = False if not self.exhaust: use_find_cmd = True @@ -366,7 +366,7 @@ def use_command(self, conn: Connection) -> bool: return use_find_cmd async def as_command( - self, conn: Connection, apply_timeout: bool = False + self, conn: AsyncConnection, apply_timeout: bool = False ) -> tuple[dict[str, Any], str]: """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. 
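The pymongo/asynchronous/helpers.py hunk above trims that module down to the retry-once-on-reauthentication decorator. As a standalone sketch of the pattern, with illustrative stand-ins (``CommandError``, ``conn.reauthenticate()``, and the literal 391) in place of the driver's own types::

    import functools
    from typing import Any, Callable, Coroutine, TypeVar, cast

    F = TypeVar("F", bound=Callable[..., Coroutine[Any, Any, Any]])

    REAUTH_CODE = 391  # stand-in for _REAUTHENTICATION_REQUIRED_CODE

    class CommandError(Exception):
        """Illustrative stand-in for OperationFailure."""

        def __init__(self, code: int) -> None:
            self.code = code

    def handle_reauth(func: F) -> F:
        @functools.wraps(func)
        async def inner(conn: Any, *args: Any, **kwargs: Any) -> Any:
            try:
                return await func(conn, *args, **kwargs)
            except CommandError as exc:
                if exc.code != REAUTH_CODE:
                    raise
                # Redo authentication on the same connection, then retry once.
                await conn.reauthenticate()
                return await func(conn, *args, **kwargs)

        return cast(F, inner)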
@@ -410,7 +410,7 @@ async def as_command( return self._as_command async def get_message( - self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False + self, read_preference: _ServerMode, conn: AsyncConnection, use_cmd: bool = False ) -> tuple[int, bytes, int]: """Get a query message, possibly setting the secondaryOk bit.""" # Use the read_preference decided by _socket_from_server. @@ -491,7 +491,7 @@ def __init__( cursor_id: int, codec_options: CodecOptions, read_preference: _ServerMode, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], client: AsyncMongoClient, max_await_time_ms: Optional[int], conn_mgr: Any, @@ -518,7 +518,7 @@ def reset(self) -> None: def namespace(self) -> str: return f"{self.db}.{self.coll}" - def use_command(self, conn: Connection) -> bool: + def use_command(self, conn: AsyncConnection) -> bool: use_cmd = False if not self.exhaust: use_cmd = True @@ -530,7 +530,7 @@ def use_command(self, conn: Connection) -> bool: return use_cmd async def as_command( - self, conn: Connection, apply_timeout: bool = False + self, conn: AsyncConnection, apply_timeout: bool = False ) -> tuple[dict[str, Any], str]: """Return a getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. @@ -560,7 +560,7 @@ async def as_command( return self._as_command async def get_message( - self, dummy0: Any, conn: Connection, use_cmd: bool = False + self, dummy0: Any, conn: AsyncConnection, use_cmd: bool = False ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: """Get a getmore message.""" ns = self.namespace() @@ -581,7 +581,7 @@ async def get_message( class _RawBatchQuery(_Query): - def use_command(self, conn: Connection) -> bool: + def use_command(self, conn: AsyncConnection) -> bool: # Compatibility checks. super().use_command(conn) if conn.max_wire_version >= 8: @@ -593,7 +593,7 @@ def use_command(self, conn: Connection) -> bool: class _RawBatchGetMore(_GetMore): - def use_command(self, conn: Connection) -> bool: + def use_command(self, conn: AsyncConnection) -> bool: # Compatibility checks. 
super().use_command(conn) if conn.max_wire_version >= 8: @@ -908,7 +908,7 @@ def _get_more( class _BulkWriteContext: - """A wrapper around Connection for use with write splitting functions.""" + """A wrapper around AsyncConnection for use with write splitting functions.""" __slots__ = ( "db_name", @@ -929,10 +929,10 @@ def __init__( self, database_name: str, cmd_name: str, - conn: Connection, + conn: AsyncConnection, operation_id: int, listeners: _EventListeners, - session: ClientSession, + session: AsyncClientSession, op_type: int, codec: CodecOptions, ): @@ -1012,7 +1012,7 @@ async def unack_write( docs: list[Mapping[str, Any]], client: AsyncMongoClient, ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" + """A proxy for AsyncConnection.unack_write that handles event publishing.""" if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 5eedd5ba07..acba1c1e32 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -58,42 +58,14 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, helpers_constants -from pymongo.asynchronous import ( - client_session, - common, - database, - helpers, - message, - periodic_executor, - uri_parser, -) +from pymongo import _csot, common, helpers_shared, uri_parser +from pymongo.asynchronous import client_session, database, message, periodic_executor from pymongo.asynchronous.change_stream import ChangeStream, ClusterChangeStream -from pymongo.asynchronous.client_options import ClientOptions from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor -from pymongo.asynchronous.logger import _CLIENT_LOGGER, _log_or_warn -from pymongo.asynchronous.monitoring import ConnectionClosedReason -from pymongo.asynchronous.operations import _Op -from pymongo.asynchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.asynchronous.server_selectors import writable_server_selector from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology, _ErrorContext -from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, TopologyDescription -from pymongo.asynchronous.typings import ( - ClusterTime, - _Address, - _CollationIn, - _DocumentType, - _DocumentTypeArg, - _Pipeline, -) -from pymongo.asynchronous.uri_parser import ( - _check_options, - _handle_option_deprecations, - _handle_security_options, - _normalize_options, -) +from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, BulkWriteError, @@ -108,7 +80,27 @@ WriteConcernError, ) from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks +from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.monitoring import ConnectionClosedReason +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE +from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription +from pymongo.typings import ( + ClusterTime, + _Address, + _CollationIn, + _DocumentType, + _DocumentTypeArg, + _Pipeline, +) +from pymongo.uri_parser import ( + _check_options, + 
_handle_option_deprecations, + _handle_security_options, + _normalize_options, +) from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern if TYPE_CHECKING: @@ -116,15 +108,15 @@ from types import TracebackType from bson.objectid import ObjectId - from pymongo.asynchronous.bulk import _Bulk - from pymongo.asynchronous.client_session import ClientSession, _ServerSession + from pymongo.asynchronous.bulk import _AsyncBulk + from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession from pymongo.asynchronous.cursor import _ConnectionManager from pymongo.asynchronous.message import _CursorAddress, _GetMore, _Query - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.response import Response + from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server - from pymongo.asynchronous.server_selectors import Selection from pymongo.read_concern import ReadConcern + from pymongo.response import Response + from pymongo.server_selectors import Selection if sys.version_info[:2] >= (3, 9): pass @@ -134,9 +126,12 @@ T = TypeVar("T") -_WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], Coroutine[Any, Any, T]] +_WriteCall = Callable[ + [Optional["AsyncClientSession"], "AsyncConnection", bool], Coroutine[Any, Any, T] +] _ReadCall = Callable[ - [Optional["ClientSession"], "Server", "Connection", _ServerMode], Coroutine[Any, Any, T] + [Optional["AsyncClientSession"], "Server", "AsyncConnection", _ServerMode], + Coroutine[Any, Any, T], ] _IS_SYNC = False @@ -896,7 +891,7 @@ async def target() -> bool: self_ref: Any = weakref.ref(self, executor.close) self._kill_cursors_executor = executor - def _should_pin_cursor(self, session: Optional[ClientSession]) -> Optional[bool]: + def _should_pin_cursor(self, session: Optional[AsyncClientSession]) -> Optional[bool]: return self._options.load_balanced and not (session and session.in_transaction) def _after_fork(self) -> None: @@ -917,7 +912,7 @@ async def watch( batch_size: Optional[int] = None, collation: Optional[_CollationIn] = None, start_at_operation_time: Optional[Timestamp] = None, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, start_after: Optional[Mapping[str, Any]] = None, comment: Optional[Any] = None, full_document_before_change: Optional[str] = None, @@ -991,7 +986,7 @@ async def watch( the specified :class:`~bson.timestamp.Timestamp`. Requires MongoDB >= 4.0. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param start_after: The same as `resume_after` except that `start_after` can resume notifications after an invalidate event. This option and `resume_after` are mutually exclusive. 
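Per the hunk above, watch() now accepts an optional AsyncClientSession and opens a cluster change stream. A hedged usage sketch (the URI is a placeholder, a replica set is assumed, and the async change stream is assumed to support ``async with``/``async for`` the way its synchronous counterpart supports ``with``/``for``)::

    import asyncio

    from pymongo.asynchronous.mongo_client import AsyncMongoClient

    async def main() -> None:
        client = AsyncMongoClient("mongodb://localhost:27017")  # placeholder URI
        # watch() is a coroutine here, hence the extra await.
        async with await client.watch() as stream:
            async for change in stream:
                print(change["operationType"])

    asyncio.run(main())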
@@ -1165,30 +1160,30 @@ def _close_cursor_soon( """Request that a cursor and/or connection be cleaned up soon.""" self._kill_cursors_queue.append((address, cursor_id, conn_mgr)) - def _start_session(self, implicit: bool, **kwargs: Any) -> ClientSession: + def _start_session(self, implicit: bool, **kwargs: Any) -> AsyncClientSession: server_session = _EmptyServerSession() opts = client_session.SessionOptions(**kwargs) - return client_session.ClientSession(self, server_session, opts, implicit) + return client_session.AsyncClientSession(self, server_session, opts, implicit) def start_session( self, causal_consistency: Optional[bool] = None, default_transaction_options: Optional[client_session.TransactionOptions] = None, snapshot: Optional[bool] = False, - ) -> client_session.ClientSession: + ) -> client_session.AsyncClientSession: """Start a logical session. This method takes the same parameters as :class:`~pymongo.client_session.SessionOptions`. See the :mod:`~pymongo.client_session` module for details and examples. - A :class:`~pymongo.client_session.ClientSession` may only be used with - the MongoClient that started it. :class:`ClientSession` instances are + A :class:`~pymongo.client_session.AsyncClientSession` may only be used with + the MongoClient that started it. :class:`AsyncClientSession` instances are **not thread-safe or fork-safe**. They can only be used by one thread - or process at a time. A single :class:`ClientSession` cannot be used + or process at a time. A single :class:`AsyncClientSession` cannot be used to run multiple operations concurrently. - :return: An instance of :class:`~pymongo.client_session.ClientSession`. + :return: An instance of :class:`~pymongo.client_session.AsyncClientSession`. .. versionadded:: 3.6 """ @@ -1199,7 +1194,9 @@ def start_session( snapshot=snapshot, ) - def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[ClientSession]: + def _ensure_session( + self, session: Optional[AsyncClientSession] = None + ) -> Optional[AsyncClientSession]: """If provided session is None, lend a temporary session.""" if session: return session @@ -1213,7 +1210,7 @@ def _ensure_session(self, session: Optional[ClientSession] = None) -> Optional[C return None def _send_cluster_time( - self, command: MutableMapping[str, Any], session: Optional[ClientSession] + self, command: MutableMapping[str, Any], session: Optional[AsyncClientSession] ) -> None: topology_time = self._topology.max_cluster_time() session_time = session.cluster_time if session else None @@ -1474,7 +1471,7 @@ async def is_mongos(self) -> bool: async def _end_sessions(self, session_ids: list[_ServerSession]) -> None: """Send endSessions command(s) with the given session ids.""" try: - # Use Connection.command directly to avoid implicitly creating + # Use AsyncConnection.command directly to avoid implicitly creating # another session. async with await self._conn_for_reads( ReadPreference.PRIMARY_PREFERRED, None, operation=_Op.END_SESSIONS @@ -1533,8 +1530,8 @@ async def _get_topology(self) -> Topology: @contextlib.asynccontextmanager async def _checkout( - self, server: Server, session: Optional[ClientSession] - ) -> AsyncGenerator[Connection, None]: + self, server: Server, session: Optional[AsyncClientSession] + ) -> AsyncGenerator[AsyncConnection, None]: in_txn = session and session.in_transaction async with _MongoClientErrorHandler(self, server, session) as err_handler: # Reuse the pinned connection, if it exists. 
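start_session() above now hands back an AsyncClientSession. A hedged end-to-end sketch (placeholder URI; ``command()``, ``end_session()``, and ``close()`` are assumed to be awaitable on the async classes, mirroring the synchronous surface)::

    import asyncio

    from pymongo.asynchronous.mongo_client import AsyncMongoClient

    async def main() -> None:
        client = AsyncMongoClient("mongodb://localhost:27017")  # placeholder URI
        session = client.start_session()  # an AsyncClientSession, per this patch
        try:
            # A session may only be used with the client that started it,
            # and only by one task at a time.
            await client.admin.command("ping", session=session)
        finally:
            await session.end_session()
            await client.close()

    asyncio.run(main())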
@@ -1568,7 +1565,7 @@ async def _checkout( async def _select_server( self, server_selector: Callable[[Selection], Selection], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, address: Optional[_Address] = None, deprioritized_servers: Optional[list[Server]] = None, @@ -1579,7 +1576,7 @@ async def _select_server( :Parameters: - `server_selector`: The server selector to use if the session is not pinned and no address is given. - - `session`: The ClientSession for the next operation, or None. May + - `session`: The AsyncClientSession for the next operation, or None. May be pinned to a mongos server address. - `address` (optional): Address when sending a message to a specific server, used for getMore. @@ -1613,15 +1610,15 @@ async def _select_server( raise async def _conn_for_writes( - self, session: Optional[ClientSession], operation: str - ) -> AsyncContextManager[Connection]: + self, session: Optional[AsyncClientSession], operation: str + ) -> AsyncContextManager[AsyncConnection]: server = await self._select_server(writable_server_selector, session, operation) return self._checkout(server, session) @contextlib.asynccontextmanager async def _conn_from_server( - self, read_preference: _ServerMode, server: Server, session: Optional[ClientSession] - ) -> AsyncGenerator[tuple[Connection, _ServerMode], None]: + self, read_preference: _ServerMode, server: Server, session: Optional[AsyncClientSession] + ) -> AsyncGenerator[tuple[AsyncConnection, _ServerMode], None]: assert read_preference is not None, "read_preference must not be None" # Get a connection for a server matching the read preference, and yield # conn with the effective read preference. The Server Selection @@ -1646,9 +1643,9 @@ async def _conn_from_server( async def _conn_for_reads( self, read_preference: _ServerMode, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, - ) -> AsyncContextManager[tuple[Connection, _ServerMode]]: + ) -> AsyncContextManager[tuple[AsyncConnection, _ServerMode]]: assert read_preference is not None, "read_preference must not be None" _ = await self._get_topology() server = await self._select_server(read_preference, session, operation) @@ -1689,9 +1686,9 @@ async def _run_operation( ) async def _cmd( - _session: Optional[ClientSession], + _session: Optional[AsyncClientSession], server: Server, - conn: Connection, + conn: AsyncConnection, read_preference: _ServerMode, ) -> Response: operation.reset() # Reset op in case of retry. 
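The _cmd callback above is what the reworked _ReadCall alias from earlier in this file describes. As a type-level illustration, a coroutine matching the new _ReadCall shape would look like the following (the function name and body are hypothetical; only the signature is taken from the alias)::

    from typing import TYPE_CHECKING, Any, Optional

    if TYPE_CHECKING:
        from pymongo.asynchronous.client_session import AsyncClientSession
        from pymongo.asynchronous.pool import AsyncConnection
        from pymongo.asynchronous.server import Server
        from pymongo.read_preferences import _ServerMode

    async def example_read(
        session: "Optional[AsyncClientSession]",
        server: "Server",
        conn: "AsyncConnection",
        read_preference: "_ServerMode",
    ) -> Any:
        # Receives the (possibly None) session, the selected server, the
        # checked-out connection, and the effective read preference.
        ...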
@@ -1717,8 +1714,8 @@ async def _retry_with_session( self, retryable: bool, func: _WriteCall[T], - session: Optional[ClientSession], - bulk: Optional[_Bulk], + session: Optional[AsyncClientSession], + bulk: Optional[_AsyncBulk], operation: str, operation_id: Optional[int] = None, ) -> T: @@ -1747,8 +1744,8 @@ async def _retry_with_session( async def _retry_internal( self, func: _WriteCall[T] | _ReadCall[T], - session: Optional[ClientSession], - bulk: Optional[_Bulk], + session: Optional[AsyncClientSession], + bulk: Optional[_AsyncBulk], operation: str, is_read: bool = False, address: Optional[_Address] = None, @@ -1786,7 +1783,7 @@ async def _retryable_read( self, func: _ReadCall[T], read_pref: _ServerMode, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, address: Optional[_Address] = None, retryable: bool = True, @@ -1829,9 +1826,9 @@ async def _retryable_write( self, retryable: bool, func: _WriteCall[T], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], operation: str, - bulk: Optional[_Bulk] = None, + bulk: Optional[_AsyncBulk] = None, operation_id: Optional[int] = None, ) -> T: """Execute an operation with consecutive retries if possible @@ -1856,7 +1853,7 @@ async def _cleanup_cursor( cursor_id: int, address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], explicit_session: bool, ) -> None: """Cleanup a cursor from cursor.close() or __del__. @@ -1897,7 +1894,7 @@ async def _close_cursor_now( self, cursor_id: int, address: Optional[_CursorAddress], - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, conn_mgr: Optional[_ConnectionManager] = None, ) -> None: """Send a kill cursors message with the given id. @@ -1925,7 +1922,7 @@ async def _kill_cursors( cursor_ids: Sequence[int], address: Optional[_CursorAddress], topology: Topology, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], ) -> None: """Send a kill cursors message with the given ids.""" if address: @@ -1944,8 +1941,8 @@ async def _kill_cursor_impl( self, cursor_ids: Sequence[int], address: _CursorAddress, - session: Optional[ClientSession], - conn: Connection, + session: Optional[AsyncClientSession], + conn: AsyncConnection, ) -> None: namespace = address.namespace db, coll = namespace.split(".", 1) @@ -1978,7 +1975,7 @@ async def _process_kill_cursors(self) -> None: # can be caught in _process_periodic_tasks raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # Don't re-open topology if it's closed and there's no pending cursors. if address_to_cursor_ids: @@ -1990,7 +1987,7 @@ async def _process_kill_cursors(self) -> None: if isinstance(exc, InvalidOperation) and self._topology._closed: raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # This method is run periodically by a background thread. 
async def _process_periodic_tasks(self) -> None: @@ -2004,7 +2001,7 @@ async def _process_periodic_tasks(self) -> None: if isinstance(exc, InvalidOperation) and self._topology._closed: return else: - helpers._handle_exception() + helpers_shared._handle_exception() async def _return_server_session( self, server_session: Union[_ServerSession, _EmptyServerSession], lock: bool @@ -2016,12 +2013,12 @@ async def _return_server_session( @contextlib.asynccontextmanager async def _tmp_session( - self, session: Optional[client_session.ClientSession], close: bool = True - ) -> AsyncGenerator[Optional[client_session.ClientSession], None, None]: + self, session: Optional[client_session.AsyncClientSession], close: bool = True + ) -> AsyncGenerator[Optional[client_session.AsyncClientSession], None, None]: """If provided session is None, lend a temporary session.""" if session is not None: - if not isinstance(session, client_session.ClientSession): - raise ValueError("'session' argument must be a ClientSession or None.") + if not isinstance(session, client_session.AsyncClientSession): + raise ValueError("'session' argument must be a AsyncClientSession or None.") # Don't call end_session. yield session return @@ -2045,19 +2042,19 @@ async def _tmp_session( yield None async def _process_response( - self, reply: Mapping[str, Any], session: Optional[ClientSession] + self, reply: Mapping[str, Any], session: Optional[AsyncClientSession] ) -> None: await self._topology.receive_cluster_time(reply.get("$clusterTime")) if session is not None: session._process_response(reply) async def server_info( - self, session: Optional[client_session.ClientSession] = None + self, session: Optional[client_session.AsyncClientSession] = None ) -> dict[str, Any]: """Get information about the MongoDB server we're connected to. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. .. versionchanged:: 3.6 Added ``session`` parameter. @@ -2071,7 +2068,7 @@ async def server_info( async def _list_databases( self, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncCommandCursor[dict[str, Any]]: @@ -2093,14 +2090,14 @@ async def _list_databases( async def list_databases( self, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, **kwargs: Any, ) -> AsyncCommandCursor[dict[str, Any]]: """Get a cursor over the databases of the connected server. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. :param kwargs: Optional parameters of the @@ -2118,13 +2115,13 @@ async def list_databases( async def list_database_names( self, - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, ) -> list[str]: """Get a list of the names of all databases on the connected server. :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. 
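A hedged usage sketch for the listing helpers above (placeholder URI; list_databases() is assumed to yield its documents via an async-iterable AsyncCommandCursor, as its annotation indicates)::

    import asyncio

    from pymongo.asynchronous.mongo_client import AsyncMongoClient

    async def main() -> None:
        client = AsyncMongoClient("mongodb://localhost:27017")  # placeholder URI
        print(await client.list_database_names())
        cursor = await client.list_databases()
        async for spec in cursor:
            print(spec["name"])
        await client.close()

    asyncio.run(main())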
@@ -2140,7 +2137,7 @@ async def list_database_names( async def drop_database( self, name_or_database: Union[str, database.AsyncDatabase[_DocumentTypeArg]], - session: Optional[client_session.ClientSession] = None, + session: Optional[client_session.AsyncClientSession] = None, comment: Optional[Any] = None, ) -> None: """Drop a database. @@ -2152,7 +2149,7 @@ async def drop_database( :class:`~pymongo.database.Database` instance representing the database to drop :param session: a - :class:`~pymongo.client_session.ClientSession`. + :class:`~pymongo.client_session.AsyncClientSession`. :param comment: A user-provided comment to attach to this command. @@ -2220,10 +2217,10 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong # Do not consult writeConcernError for pre-4.4 mongos. if isinstance(exc, WriteConcernError) and is_mongos: pass - elif code in helpers_constants._RETRYABLE_ERROR_CODES: + elif code in helpers_shared._RETRYABLE_ERROR_CODES: exc._add_error_label("RetryableWriteError") - # Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is + # AsyncConnection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is # handled above. if isinstance(exc, ConnectionFailure) and not isinstance( exc, (NotPrimaryError, WaitQueueTimeoutError) @@ -2245,7 +2242,9 @@ class _MongoClientErrorHandler: "handled", ) - def __init__(self, client: AsyncMongoClient, server: Server, session: Optional[ClientSession]): + def __init__( + self, client: AsyncMongoClient, server: Server, session: Optional[AsyncClientSession] + ): self.client = client self.server_address = server.description.address self.session = session @@ -2259,7 +2258,7 @@ def __init__(self, client: AsyncMongoClient, server: Server, session: Optional[C self.service_id: Optional[ObjectId] = None self.handled = False - def contribute_socket(self, conn: Connection, completed_handshake: bool = True) -> None: + def contribute_socket(self, conn: AsyncConnection, completed_handshake: bool = True) -> None: """Provide socket information to the error handler.""" self.max_wire_version = conn.max_wire_version self.sock_generation = conn.generation @@ -2311,10 +2310,10 @@ def __init__( self, mongo_client: AsyncMongoClient, func: _WriteCall[T] | _ReadCall[T], - bulk: Optional[_Bulk], + bulk: Optional[_AsyncBulk], operation: str, is_read: bool = False, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, read_pref: Optional[_ServerMode] = None, address: Optional[_Address] = None, retryable: bool = False, @@ -2376,7 +2375,7 @@ async def run(self) -> T: exc_code = getattr(exc, "code", None) if self._is_not_eligible_for_retry() or ( isinstance(exc, OperationFailure) - and exc_code not in helpers_constants._RETRYABLE_ERROR_CODES + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES ): raise self._retrying = True diff --git a/pymongo/asynchronous/monitor.py b/pymongo/asynchronous/monitor.py index 6bd8061081..a5f7435128 100644 --- a/pymongo/asynchronous/monitor.py +++ b/pymongo/asynchronous/monitor.py @@ -21,19 +21,20 @@ import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from pymongo import common from pymongo._csot import MovingMinimum -from pymongo.asynchronous import common, periodic_executor -from pymongo.asynchronous.hello import Hello +from pymongo.asynchronous import periodic_executor from pymongo.asynchronous.periodic_executor import _shutdown_executors -from pymongo.asynchronous.pool 
import _is_faas -from pymongo.asynchronous.read_preferences import MovingAverage -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.srv_resolver import _SrvResolver from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.hello import Hello from pymongo.lock import _create_lock +from pymongo.pool_options import _is_faas +from pymongo.read_preferences import MovingAverage +from pymongo.server_description import ServerDescription +from pymongo.srv_resolver import _SrvResolver if TYPE_CHECKING: - from pymongo.asynchronous.pool import Connection, Pool, _CancellationContext + from pymongo.asynchronous.pool import AsyncConnection, Pool, _CancellationContext from pymongo.asynchronous.settings import TopologySettings from pymongo.asynchronous.topology import Topology @@ -294,7 +295,7 @@ async def _check_once(self) -> ServerDescription: ) return sd - async def _check_with_socket(self, conn: Connection) -> tuple[Hello, float]: + async def _check_with_socket(self, conn: AsyncConnection) -> tuple[Hello, float]: """Return (Hello, round_trip_time). Can raise ConnectionFailure or OperationFailure. diff --git a/pymongo/asynchronous/monitoring.py b/pymongo/asynchronous/monitoring.py deleted file mode 100644 index 36d015fe29..0000000000 --- a/pymongo/asynchronous/monitoring.py +++ /dev/null @@ -1,1903 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Tools to monitor driver events. - -.. versionadded:: 3.1 - -.. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below - are included in the PyMongo distribution under the - :mod:`~pymongo.event_loggers` submodule. - -Use :func:`register` to register global listeners for specific events. -Listeners must inherit from one of the abstract classes below and implement -the correct functions for that class. - -For example, a simple command logger might be implemented like this:: - - import logging - - from pymongo import monitoring - - class CommandLogger(monitoring.CommandListener): - - def started(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} started on server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "succeeded in {0.duration_micros} " - "microseconds".format(event)) - - def failed(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "failed in {0.duration_micros} " - "microseconds".format(event)) - - monitoring.register(CommandLogger()) - -Server discovery and monitoring events are also available. 
For example:: - - class ServerLogger(monitoring.ServerListener): - - def opened(self, event): - logging.info("Server {0.server_address} added to topology " - "{0.topology_id}".format(event)) - - def description_changed(self, event): - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - "Server {0.server_address} changed type from " - "{0.previous_description.server_type_name} to " - "{0.new_description.server_type_name}".format(event)) - - def closed(self, event): - logging.warning("Server {0.server_address} removed from topology " - "{0.topology_id}".format(event)) - - - class HeartbeatLogger(monitoring.ServerHeartbeatListener): - - def started(self, event): - logging.info("Heartbeat sent to server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - # The reply.document attribute was added in PyMongo 3.4. - logging.info("Heartbeat to server {0.connection_id} " - "succeeded with reply " - "{0.reply.document}".format(event)) - - def failed(self, event): - logging.warning("Heartbeat to server {0.connection_id} " - "failed with error {0.reply}".format(event)) - - class TopologyLogger(monitoring.TopologyListener): - - def opened(self, event): - logging.info("Topology with id {0.topology_id} " - "opened".format(event)) - - def description_changed(self, event): - logging.info("Topology description updated for " - "topology id {0.topology_id}".format(event)) - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - "Topology {0.topology_id} changed type from " - "{0.previous_description.topology_type_name} to " - "{0.new_description.topology_type_name}".format(event)) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event): - logging.info("Topology with id {0.topology_id} " - "closed".format(event)) - -Connection monitoring and pooling events are also available. 
For example:: - - class ConnectionPoolLogger(ConnectionPoolListener): - - def pool_created(self, event): - logging.info("[pool {0.address}] pool created".format(event)) - - def pool_ready(self, event): - logging.info("[pool {0.address}] pool is ready".format(event)) - - def pool_cleared(self, event): - logging.info("[pool {0.address}] pool cleared".format(event)) - - def pool_closed(self, event): - logging.info("[pool {0.address}] pool closed".format(event)) - - def connection_created(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection created".format(event)) - - def connection_ready(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection setup succeeded".format(event)) - - def connection_closed(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection closed, reason: " - "{0.reason}".format(event)) - - def connection_check_out_started(self, event): - logging.info("[pool {0.address}] connection check out " - "started".format(event)) - - def connection_check_out_failed(self, event): - logging.info("[pool {0.address}] connection check out " - "failed, reason: {0.reason}".format(event)) - - def connection_checked_out(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked out of pool".format(event)) - - def connection_checked_in(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked into pool".format(event)) - - -Event listeners can also be registered per instance of -:class:`~pymongo.mongo_client.MongoClient`:: - - client = MongoClient(event_listeners=[CommandLogger()]) - -Note that previously registered global listeners are automatically included -when configuring per client event listeners. Registering a new global listener -will not add that listener to existing client instances. - -.. note:: Events are delivered **synchronously**. Application threads block - waiting for event handlers (e.g. :meth:`~CommandListener.started`) to - return. Care must be taken to ensure that your event handlers are efficient - enough to not adversely affect overall application performance. - -.. warning:: The command documents published through this API are *not* copies. - If you intend to modify them in any way you must copy them in your event - handler first. -""" - -from __future__ import annotations - -import datetime -from collections import abc, namedtuple -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from bson.objectid import ObjectId -from pymongo.asynchronous.hello import Hello -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.asynchronous.helpers import _handle_exception -from pymongo.asynchronous.typings import _Address, _DocumentOut -from pymongo.helpers_constants import _SENSITIVE_COMMANDS - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.asynchronous.server_description import ServerDescription - from pymongo.asynchronous.topology_description import TopologyDescription - -_IS_SYNC = False - -_Listeners = namedtuple( - "_Listeners", - ( - "command_listeners", - "server_listeners", - "server_heartbeat_listeners", - "topology_listeners", - "cmap_listeners", - ), -) - -_LISTENERS = _Listeners([], [], [], [], []) - - -class _EventListener: - """Abstract base class for all event listeners.""" - - -class CommandListener(_EventListener): - """Abstract base class for command listeners. 
- - Handles `CommandStartedEvent`, `CommandSucceededEvent`, - and `CommandFailedEvent`. - """ - - def started(self, event: CommandStartedEvent) -> None: - """Abstract method to handle a `CommandStartedEvent`. - - :param event: An instance of :class:`CommandStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: CommandSucceededEvent) -> None: - """Abstract method to handle a `CommandSucceededEvent`. - - :param event: An instance of :class:`CommandSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: CommandFailedEvent) -> None: - """Abstract method to handle a `CommandFailedEvent`. - - :param event: An instance of :class:`CommandFailedEvent`. - """ - raise NotImplementedError - - -class ConnectionPoolListener(_EventListener): - """Abstract base class for connection pool listeners. - - Handles all of the connection pool events defined in the Connection - Monitoring and Pooling Specification: - :class:`PoolCreatedEvent`, :class:`PoolClearedEvent`, - :class:`PoolClosedEvent`, :class:`ConnectionCreatedEvent`, - :class:`ConnectionReadyEvent`, :class:`ConnectionClosedEvent`, - :class:`ConnectionCheckOutStartedEvent`, - :class:`ConnectionCheckOutFailedEvent`, - :class:`ConnectionCheckedOutEvent`, - and :class:`ConnectionCheckedInEvent`. - - .. versionadded:: 3.9 - """ - - def pool_created(self, event: PoolCreatedEvent) -> None: - """Abstract method to handle a :class:`PoolCreatedEvent`. - - Emitted when a connection Pool is created. - - :param event: An instance of :class:`PoolCreatedEvent`. - """ - raise NotImplementedError - - def pool_ready(self, event: PoolReadyEvent) -> None: - """Abstract method to handle a :class:`PoolReadyEvent`. - - Emitted when a connection Pool is marked ready. - - :param event: An instance of :class:`PoolReadyEvent`. - - .. versionadded:: 4.0 - """ - raise NotImplementedError - - def pool_cleared(self, event: PoolClearedEvent) -> None: - """Abstract method to handle a `PoolClearedEvent`. - - Emitted when a connection Pool is cleared. - - :param event: An instance of :class:`PoolClearedEvent`. - """ - raise NotImplementedError - - def pool_closed(self, event: PoolClosedEvent) -> None: - """Abstract method to handle a `PoolClosedEvent`. - - Emitted when a connection Pool is closed. - - :param event: An instance of :class:`PoolClosedEvent`. - """ - raise NotImplementedError - - def connection_created(self, event: ConnectionCreatedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCreatedEvent`. - - Emitted when a connection Pool creates a Connection object. - - :param event: An instance of :class:`ConnectionCreatedEvent`. - """ - raise NotImplementedError - - def connection_ready(self, event: ConnectionReadyEvent) -> None: - """Abstract method to handle a :class:`ConnectionReadyEvent`. - - Emitted when a connection has finished its setup, and is now ready to - use. - - :param event: An instance of :class:`ConnectionReadyEvent`. - """ - raise NotImplementedError - - def connection_closed(self, event: ConnectionClosedEvent) -> None: - """Abstract method to handle a :class:`ConnectionClosedEvent`. - - Emitted when a connection Pool closes a connection. - - :param event: An instance of :class:`ConnectionClosedEvent`. - """ - raise NotImplementedError - - def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. - - Emitted when the driver starts attempting to check out a connection. 
- - :param event: An instance of :class:`ConnectionCheckOutStartedEvent`. - """ - raise NotImplementedError - - def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. - - Emitted when the driver's attempt to check out a connection fails. - - :param event: An instance of :class:`ConnectionCheckOutFailedEvent`. - """ - raise NotImplementedError - - def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. - - Emitted when the driver successfully checks out a connection. - - :param event: An instance of :class:`ConnectionCheckedOutEvent`. - """ - raise NotImplementedError - - def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedInEvent`. - - Emitted when the driver checks in a connection back to the connection - Pool. - - :param event: An instance of :class:`ConnectionCheckedInEvent`. - """ - raise NotImplementedError - - -class ServerHeartbeatListener(_EventListener): - """Abstract base class for server heartbeat listeners. - - Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, - and `ServerHeartbeatFailedEvent`. - - .. versionadded:: 3.3 - """ - - def started(self, event: ServerHeartbeatStartedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatStartedEvent`. - - :param event: An instance of :class:`ServerHeartbeatStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: - """Abstract method to handle a `ServerHeartbeatSucceededEvent`. - - :param event: An instance of :class:`ServerHeartbeatSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: ServerHeartbeatFailedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatFailedEvent`. - - :param event: An instance of :class:`ServerHeartbeatFailedEvent`. - """ - raise NotImplementedError - - -class TopologyListener(_EventListener): - """Abstract base class for topology monitoring listeners. - Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and - `TopologyClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: TopologyOpenedEvent) -> None: - """Abstract method to handle a `TopologyOpenedEvent`. - - :param event: An instance of :class:`TopologyOpenedEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: TopologyDescriptionChangedEvent) -> None: - """Abstract method to handle a `TopologyDescriptionChangedEvent`. - - :param event: An instance of :class:`TopologyDescriptionChangedEvent`. - """ - raise NotImplementedError - - def closed(self, event: TopologyClosedEvent) -> None: - """Abstract method to handle a `TopologyClosedEvent`. - - :param event: An instance of :class:`TopologyClosedEvent`. - """ - raise NotImplementedError - - -class ServerListener(_EventListener): - """Abstract base class for server listeners. - Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and - `ServerClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: ServerOpeningEvent) -> None: - """Abstract method to handle a `ServerOpeningEvent`. - - :param event: An instance of :class:`ServerOpeningEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: ServerDescriptionChangedEvent) -> None: - """Abstract method to handle a `ServerDescriptionChangedEvent`. 
- - :param event: An instance of :class:`ServerDescriptionChangedEvent`. - """ - raise NotImplementedError - - def closed(self, event: ServerClosedEvent) -> None: - """Abstract method to handle a `ServerClosedEvent`. - - :param event: An instance of :class:`ServerClosedEvent`. - """ - raise NotImplementedError - - -def _to_micros(dur: timedelta) -> int: - """Convert duration 'dur' to microseconds.""" - return int(dur.total_seconds() * 10e5) - - -def _validate_event_listeners( - option: str, listeners: Sequence[_EventListeners] -) -> Sequence[_EventListeners]: - """Validate event listeners""" - if not isinstance(listeners, abc.Sequence): - raise TypeError(f"{option} must be a list or tuple") - for listener in listeners: - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {option} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - return listeners - - -def register(listener: _EventListener) -> None: - """Register a global event listener. - - :param listener: A subclasses of :class:`CommandListener`, - :class:`ServerHeartbeatListener`, :class:`ServerListener`, - :class:`TopologyListener`, or :class:`ConnectionPoolListener`. - """ - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {listener} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - if isinstance(listener, CommandListener): - _LISTENERS.command_listeners.append(listener) - if isinstance(listener, ServerHeartbeatListener): - _LISTENERS.server_heartbeat_listeners.append(listener) - if isinstance(listener, ServerListener): - _LISTENERS.server_listeners.append(listener) - if isinstance(listener, TopologyListener): - _LISTENERS.topology_listeners.append(listener) - if isinstance(listener, ConnectionPoolListener): - _LISTENERS.cmap_listeners.append(listener) - - -# The "hello" command is also deemed sensitive when attempting speculative -# authentication. -def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: - if ( - command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) - and "speculativeAuthenticate" in doc - ): - return True - return False - - -class _CommandEvent: - """Base class for command events.""" - - __slots__ = ( - "__cmd_name", - "__rqst_id", - "__conn_id", - "__op_id", - "__service_id", - "__db", - "__server_conn_id", - ) - - def __init__( - self, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - self.__cmd_name = command_name - self.__rqst_id = request_id - self.__conn_id = connection_id - self.__op_id = operation_id - self.__service_id = service_id - self.__db = database_name - self.__server_conn_id = server_connection_id - - @property - def command_name(self) -> str: - """The command name.""" - return self.__cmd_name - - @property - def request_id(self) -> int: - """The request id for this operation.""" - return self.__rqst_id - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this command was sent to.""" - return self.__conn_id - - @property - def service_id(self) -> Optional[ObjectId]: - """The service_id this command was sent to, or ``None``. - - .. 
versionadded:: 3.12 - """ - return self.__service_id - - @property - def operation_id(self) -> Optional[int]: - """An id for this series of events or None.""" - return self.__op_id - - @property - def database_name(self) -> str: - """The database_name this command was sent to, or ``""``. - - .. versionadded:: 4.6 - """ - return self.__db - - @property - def server_connection_id(self) -> Optional[int]: - """The server-side connection id for the connection this command was sent on, or ``None``. - - .. versionadded:: 4.7 - """ - return self.__server_conn_id - - -class CommandStartedEvent(_CommandEvent): - """Event published when a command starts. - - :param command: The command document. - :param database_name: The name of the database this command was run against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - """ - - __slots__ = ("__cmd",) - - def __init__( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - server_connection_id: Optional[int] = None, - ) -> None: - if not command: - raise ValueError(f"{command!r} is not a valid command") - # Command name must be first key. - command_name = next(iter(command)) - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command): - self.__cmd: _DocumentOut = {} - else: - self.__cmd = command - - @property - def command(self) -> _DocumentOut: - """The command document.""" - return self.__cmd - - @property - def database_name(self) -> str: - """The name of the database this command was run against.""" - return super().database_name - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.service_id, - self.server_connection_id, - ) - - -class CommandSucceededEvent(_CommandEvent): - """Event published when a command succeeds. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. 
- """ - - __slots__ = ("__duration_micros", "__reply") - - def __init__( - self, - duration: datetime.timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply): - self.__reply: _DocumentOut = {} - else: - self.__reply = reply - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def reply(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__reply - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.service_id, - self.server_connection_id, - ) - - -class CommandFailedEvent(_CommandEvent): - """Event published when a command fails. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. 
- """ - - __slots__ = ("__duration_micros", "__failure") - - def __init__( - self, - duration: datetime.timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - self.__failure = failure - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def failure(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__failure - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, " - "failure: {!r}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.failure, - self.service_id, - self.server_connection_id, - ) - - -class _PoolEvent: - """Base class for pool events.""" - - __slots__ = ("__address",) - - def __init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server the pool is attempting - to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class PoolCreatedEvent(_PoolEvent): - """Published when a Connection Pool is created. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__options",) - - def __init__(self, address: _Address, options: dict[str, Any]) -> None: - super().__init__(address) - self.__options = options - - @property - def options(self) -> dict[str, Any]: - """Any non-default pool options that were set on this Connection Pool.""" - return self.__options - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" - - -class PoolReadyEvent(_PoolEvent): - """Published when a Connection Pool is marked ready. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 4.0 - """ - - __slots__ = () - - -class PoolClearedEvent(_PoolEvent): - """Published when a Connection Pool is cleared. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - :param service_id: The service_id this command was sent to, or ``None``. - :param interrupt_connections: True if all active connections were interrupted by the Pool during clearing. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__service_id", "__interrupt_connections") - - def __init__( - self, - address: _Address, - service_id: Optional[ObjectId] = None, - interrupt_connections: bool = False, - ) -> None: - super().__init__(address) - self.__service_id = service_id - self.__interrupt_connections = interrupt_connections - - @property - def service_id(self) -> Optional[ObjectId]: - """Connections with this service_id are cleared. - - When service_id is ``None``, all connections in the pool are cleared. - - .. 
versionadded:: 3.12 - """ - return self.__service_id - - @property - def interrupt_connections(self) -> bool: - """If True, active connections are interrupted during clearing. - - .. versionadded:: 4.7 - """ - return self.__interrupt_connections - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r}, {self.__interrupt_connections!r})" - - -class PoolClosedEvent(_PoolEvent): - """Published when a Connection Pool is closed. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionClosedEvent`. - - .. versionadded:: 3.9 - """ - - STALE = "stale" - """The pool was cleared, making the connection no longer valid.""" - - IDLE = "idle" - """The connection became stale by being idle for too long (maxIdleTimeMS). - """ - - ERROR = "error" - """The connection experienced an error, making it no longer valid.""" - - POOL_CLOSED = "poolClosed" - """The pool was closed, making the connection no longer valid.""" - - -class ConnectionCheckOutFailedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionCheckOutFailedEvent`. - - .. versionadded:: 3.9 - """ - - TIMEOUT = "timeout" - """The connection check out attempt exceeded the specified timeout.""" - - POOL_CLOSED = "poolClosed" - """The pool was previously closed, and cannot provide new connections.""" - - CONN_ERROR = "connectionError" - """The connection check out attempt experienced an error while setting up - a new connection. - """ - - -class _ConnectionEvent: - """Private base class for connection events.""" - - __slots__ = ("__address",) - - def __init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server this connection is - attempting to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class _ConnectionIdEvent(_ConnectionEvent): - """Private base class for connection events with an id.""" - - __slots__ = ("__connection_id",) - - def __init__(self, address: _Address, connection_id: int) -> None: - super().__init__(address) - self.__connection_id = connection_id - - @property - def connection_id(self) -> int: - """The ID of the connection.""" - return self.__connection_id - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" - - -class _ConnectionDurationEvent(_ConnectionIdEvent): - """Private base class for connection events with a duration.""" - - __slots__ = ("__duration",) - - def __init__(self, address: _Address, connection_id: int, duration: Optional[float]) -> None: - super().__init__(address, connection_id) - self.__duration = duration - - @property - def duration(self) -> Optional[float]: - """The duration of the connection event. - - .. versionadded:: 4.7 - """ - return self.__duration - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.connection_id!r}, {self.__duration!r})" - - -class ConnectionCreatedEvent(_ConnectionIdEvent): - """Published when a Connection Pool creates a Connection object. - - NOTE: This connection is not ready for use until the - :class:`ConnectionReadyEvent` is published. 
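# A sketch of a connection-pool listener for the lifecycle events above;
# method names and event attributes are taken from this module, while the
# class name PoolLogger is illustrative.
import logging

from pymongo import monitoring


class PoolLogger(monitoring.ConnectionPoolListener):
    def connection_created(self, event):
        # Not usable yet: wait for the matching ConnectionReadyEvent.
        logging.info("connection %s created for %s", event.connection_id, event.address)

    def connection_ready(self, event):
        logging.info("connection %s finished setup", event.connection_id)

    def connection_closed(self, event):
        # event.reason is one of the ConnectionClosedReason strings.
        logging.info("connection %s closed (%s)", event.connection_id, event.reason)

    # The base class also defines pool_created, pool_ready, pool_cleared,
    # pool_closed and the check-out/check-in hooks; hooks left unimplemented
    # raise NotImplementedError, which the driver swallows via
    # _handle_exception() instead of failing the operation.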
- - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionReadyEvent(_ConnectionDurationEvent): - """Published when a Connection has finished its setup, and is ready to use. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedEvent(_ConnectionIdEvent): - """Published when a Connection is closed. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - :param reason: A reason explaining why this connection was closed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, connection_id: int, reason: str): - super().__init__(address, connection_id) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why this connection was closed. - - The reason must be one of the strings from the - :class:`ConnectionClosedReason` enum. - """ - return self.__reason - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r})".format( - self.__class__.__name__, - self.address, - self.connection_id, - self.__reason, - ) - - -class ConnectionCheckOutStartedEvent(_ConnectionEvent): - """Published when the driver starts attempting to check out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckOutFailedEvent(_ConnectionDurationEvent): - """Published when the driver's attempt to check out a connection fails. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param reason: A reason explaining why connection check out failed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, reason: str, duration: Optional[float]) -> None: - super().__init__(address=address, connection_id=0, duration=duration) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why connection check out failed. - - The reason must be one of the strings from the - :class:`ConnectionCheckOutFailedReason` enum. - """ - return self.__reason - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r}, {self.duration!r})" - - -class ConnectionCheckedOutEvent(_ConnectionDurationEvent): - """Published when the driver successfully checks out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckedInEvent(_ConnectionIdEvent): - """Published when the driver checks in a Connection into the Pool. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. 
versionadded:: 3.9 - """ - - __slots__ = () - - -class _ServerEvent: - """Base class for server events.""" - - __slots__ = ("__server_address", "__topology_id") - - def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: - self.__server_address = server_address - self.__topology_id = topology_id - - @property - def server_address(self) -> _Address: - """The address (host, port) pair of the server""" - return self.__server_address - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" - - -class ServerDescriptionChangedEvent(_ServerEvent): - """Published when server description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> ServerDescription: - """The previous - :class:`~pymongo.server_description.ServerDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> ServerDescription: - """The new - :class:`~pymongo.server_description.ServerDescription`. - """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.server_address, - self.previous_description, - self.new_description, - ) - - -class ServerOpeningEvent(_ServerEvent): - """Published when server is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerClosedEvent(_ServerEvent): - """Published when server is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyEvent: - """Base class for topology description events.""" - - __slots__ = ("__topology_id",) - - def __init__(self, topology_id: ObjectId) -> None: - self.__topology_id = topology_id - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" - - -class TopologyDescriptionChangedEvent(TopologyEvent): - """Published when the topology description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> TopologyDescription: - """The previous - :class:`~pymongo.topology_description.TopologyDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> TopologyDescription: - """The new - :class:`~pymongo.topology_description.TopologyDescription`. 
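# A sketch of a server listener reacting to the events above; ServerLogger is
# an illustrative name, and server_type/server_type_name are attributes of
# ServerDescription.
import logging

from pymongo import monitoring


class ServerLogger(monitoring.ServerListener):
    def opened(self, event):
        logging.info("server %s added to topology %s", event.server_address, event.topology_id)

    def description_changed(self, event):
        # Detect, e.g., an election changing which member is primary.
        if event.new_description.server_type != event.previous_description.server_type:
            logging.info(
                "server %s changed type from %s to %s",
                event.server_address,
                event.previous_description.server_type_name,
                event.new_description.server_type_name,
            )

    def closed(self, event):
        logging.warning("server %s removed from topology %s", event.server_address, event.topology_id)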
- """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} topology_id: {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.topology_id, - self.previous_description, - self.new_description, - ) - - -class TopologyOpenedEvent(TopologyEvent): - """Published when the topology is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyClosedEvent(TopologyEvent): - """Published when the topology is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class _ServerHeartbeatEvent: - """Base class for server heartbeat events.""" - - __slots__ = ("__connection_id", "__awaited") - - def __init__(self, connection_id: _Address, awaited: bool = False) -> None: - self.__connection_id = connection_id - self.__awaited = awaited - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this heartbeat was sent - to. - """ - return self.__connection_id - - @property - def awaited(self) -> bool: - """Whether the heartbeat was issued as an awaitable hello command. - - .. versionadded:: 4.6 - """ - return self.__awaited - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.connection_id} awaited: {self.awaited}>" - - -class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): - """Published when a heartbeat is started. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat succeeds. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Hello: - """An instance of :class:`~pymongo.hello.Hello`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat fails, either with an "ok: 0" - or a socket exception. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Exception: - """A subclass of :exc:`Exception`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. 
versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class _EventListeners: - """Configure event listeners for a client instance. - - Any event listeners registered globally are included by default. - - :param listeners: A list of event listeners. - """ - - def __init__(self, listeners: Optional[Sequence[_EventListener]]): - self.__command_listeners = _LISTENERS.command_listeners[:] - self.__server_listeners = _LISTENERS.server_listeners[:] - lst = _LISTENERS.server_heartbeat_listeners - self.__server_heartbeat_listeners = lst[:] - self.__topology_listeners = _LISTENERS.topology_listeners[:] - self.__cmap_listeners = _LISTENERS.cmap_listeners[:] - if listeners is not None: - for lst in listeners: - if isinstance(lst, CommandListener): - self.__command_listeners.append(lst) - if isinstance(lst, ServerListener): - self.__server_listeners.append(lst) - if isinstance(lst, ServerHeartbeatListener): - self.__server_heartbeat_listeners.append(lst) - if isinstance(lst, TopologyListener): - self.__topology_listeners.append(lst) - if isinstance(lst, ConnectionPoolListener): - self.__cmap_listeners.append(lst) - self.__enabled_for_commands = bool(self.__command_listeners) - self.__enabled_for_server = bool(self.__server_listeners) - self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners) - self.__enabled_for_topology = bool(self.__topology_listeners) - self.__enabled_for_cmap = bool(self.__cmap_listeners) - - @property - def enabled_for_commands(self) -> bool: - """Are any CommandListener instances registered?""" - return self.__enabled_for_commands - - @property - def enabled_for_server(self) -> bool: - """Are any ServerListener instances registered?""" - return self.__enabled_for_server - - @property - def enabled_for_server_heartbeat(self) -> bool: - """Are any ServerHeartbeatListener instances registered?""" - return self.__enabled_for_server_heartbeat - - @property - def enabled_for_topology(self) -> bool: - """Are any TopologyListener instances registered?""" - return self.__enabled_for_topology - - @property - def enabled_for_cmap(self) -> bool: - """Are any ConnectionPoolListener instances registered?""" - return self.__enabled_for_cmap - - def event_listeners(self) -> list[_EventListeners]: - """List of registered event listeners.""" - return ( - self.__command_listeners - + self.__server_heartbeat_listeners - + self.__server_listeners - + self.__topology_listeners - + self.__cmap_listeners - ) - - def publish_command_start( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - ) -> None: - """Publish a CommandStartedEvent to all command listeners. - - :param command: The command document. - :param database_name: The name of the database this command was run - against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. 
- """ - if op_id is None: - op_id = request_id - event = CommandStartedEvent( - command, - database_name, - request_id, - connection_id, - op_id, - service_id=service_id, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_command_success( - self, - duration: timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - speculative_hello: bool = False, - database_name: str = "", - ) -> None: - """Publish a CommandSucceededEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param speculative_hello: Was the command sent with speculative auth? - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - if speculative_hello: - # Redact entire response when the command started contained - # speculativeAuthenticate. - reply = {} - event = CommandSucceededEvent( - duration, - reply, - command_name, - request_id, - connection_id, - op_id, - service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_command_failure( - self, - duration: timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - database_name: str = "", - ) -> None: - """Publish a CommandFailedEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document or failure description - document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - event = CommandFailedEvent( - duration, - failure, - command_name, - request_id, - connection_id, - op_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_started(self, connection_id: _Address, awaited: bool) -> None: - """Publish a ServerHeartbeatStartedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param awaited: True if this heartbeat is part of an awaitable hello command. 
- """ - event = ServerHeartbeatStartedEvent(connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_succeeded( - self, connection_id: _Address, duration: float, reply: Hello, awaited: bool - ) -> None: - """Publish a ServerHeartbeatSucceededEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_failed( - self, connection_id: _Address, duration: float, reply: Exception, awaited: bool - ) -> None: - """Publish a ServerHeartbeatFailedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerOpeningEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerOpeningEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerClosedEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerClosedEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_server_description_changed( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - server_address: _Address, - topology_id: ObjectId, - ) -> None: - """Publish a ServerDescriptionChangedEvent to all server listeners. - - :param previous_description: The previous server description. - :param server_address: The address (host, port) pair of the server. - :param new_description: The new server description. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerDescriptionChangedEvent( - previous_description, new_description, server_address, topology_id - ) - for subscriber in self.__server_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_topology_opened(self, topology_id: ObjectId) -> None: - """Publish a TopologyOpenedEvent to all topology listeners. 
- - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyOpenedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_topology_closed(self, topology_id: ObjectId) -> None: - """Publish a TopologyClosedEvent to all topology listeners. - - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyClosedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_topology_description_changed( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - topology_id: ObjectId, - ) -> None: - """Publish a TopologyDescriptionChangedEvent to all topology listeners. - - :param previous_description: The previous topology description. - :param new_description: The new topology description. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: - """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" - event = PoolCreatedEvent(address, options) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_created(event) - except Exception: - _handle_exception() - - def publish_pool_ready(self, address: _Address) -> None: - """Publish a :class:`PoolReadyEvent` to all pool listeners.""" - event = PoolReadyEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_ready(event) - except Exception: - _handle_exception() - - def publish_pool_cleared( - self, - address: _Address, - service_id: Optional[ObjectId], - interrupt_connections: bool = False, - ) -> None: - """Publish a :class:`PoolClearedEvent` to all pool listeners.""" - event = PoolClearedEvent(address, service_id, interrupt_connections) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_cleared(event) - except Exception: - _handle_exception() - - def publish_pool_closed(self, address: _Address) -> None: - """Publish a :class:`PoolClosedEvent` to all pool listeners.""" - event = PoolClosedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_closed(event) - except Exception: - _handle_exception() - - def publish_connection_created(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCreatedEvent` to all connection - listeners. 
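# Rough order of pool events for a single operation on a fresh pool, using
# the PoolLogger sketch above (the exact sequence depends on pool state):
#
#   ConnectionCheckOutStartedEvent
#   ConnectionCreatedEvent      (only when no idle connection is available)
#   ConnectionReadyEvent        (after handshake and authentication)
#   ConnectionCheckedOutEvent
#   ...command started/succeeded events...
#   ConnectionCheckedInEvent
from pymongo import MongoClient

client = MongoClient(event_listeners=[PoolLogger()])
client.admin.command("ping")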
- """ - event = ConnectionCreatedEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_created(event) - except Exception: - _handle_exception() - - def publish_connection_ready( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" - event = ConnectionReadyEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_ready(event) - except Exception: - _handle_exception() - - def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: - """Publish a :class:`ConnectionClosedEvent` to all connection - listeners. - """ - event = ConnectionClosedEvent(address, connection_id, reason) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_closed(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_started(self, address: _Address) -> None: - """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutStartedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_started(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_failed( - self, address: _Address, reason: str, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutFailedEvent(address, reason, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_failed(event) - except Exception: - _handle_exception() - - def publish_connection_checked_out( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckedOutEvent` to all connection - listeners. - """ - event = ConnectionCheckedOutEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_out(event) - except Exception: - _handle_exception() - - def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCheckedInEvent` to all connection - listeners. 
- """ - event = ConnectionCheckedInEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_in(event) - except Exception: - _handle_exception() diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 25fffaca19..7c3444e071 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -33,20 +33,19 @@ ) from bson import _decode_all_selective -from pymongo import _csot -from pymongo.asynchronous import helpers as _async_helpers -from pymongo.asynchronous import message as _async_message -from pymongo.asynchronous.common import MAX_MESSAGE_SIZE -from pymongo.asynchronous.compression_support import _NO_COMPRESSION, decompress -from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo import _csot, helpers_shared +from pymongo.asynchronous import message from pymongo.asynchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply -from pymongo.asynchronous.monitoring import _is_speculative_authenticate +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.compression_support import _NO_COMPRESSION, decompress from pymongo.errors import ( NotPrimaryError, OperationFailure, ProtocolError, _OperationCancelled, ) +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, @@ -58,27 +57,27 @@ if TYPE_CHECKING: from bson import CodecOptions - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.monitoring import _EventListeners - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = False async def command( - conn: Connection, + conn: AsyncConnection, dbname: str, spec: MutableMapping[str, Any], is_mongos: bool, read_preference: Optional[_ServerMode], codec_options: CodecOptions[_DocumentType], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], client: Optional[AsyncMongoClient], check: bool = True, allowable_errors: Optional[Sequence[Union[str, int]]] = None, @@ -97,13 +96,13 @@ async def command( ) -> _DocumentType: """Execute a command over the socket, or raise socket.error. - :param conn: a Connection instance + :param conn: a AsyncConnection instance :param dbname: name of the database on which to run the command :param spec: a command document as an ordered dict type, eg SON. :param is_mongos: are we connected to a mongos? :param read_preference: a read preference :param codec_options: a CodecOptions instance - :param session: optional ClientSession instance. + :param session: optional AsyncClientSession instance. 
:param client: optional AsyncMongoClient instance for updating $clusterTime. :param check: raise OperationFailure if there are errors :param allowable_errors: errors to ignore if `check` is True @@ -130,7 +129,7 @@ async def command( orig = spec if is_mongos and not use_op_msg: assert read_preference is not None - spec = _async_message._maybe_add_read_preference(spec, read_preference) + spec = message._maybe_add_read_preference(spec, read_preference) if read_concern and not (session and session.in_transaction): if read_concern.level: spec["readConcern"] = read_concern.document @@ -158,22 +157,20 @@ async def command( if use_op_msg: flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 - request_id, msg, size, max_doc_size = _async_message._op_msg( + request_id, msg, size, max_doc_size = message._op_msg( flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx ) # If this is an unacknowledged write then make sure the encoded doc(s) # are small enough, otherwise rely on the server to return an error. if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: - _async_message._raise_document_too_large(name, size, max_bson_size) + message._raise_document_too_large(name, size, max_bson_size) else: - request_id, msg, size = _async_message._query( + request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) - if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD: - _async_message._raise_document_too_large( - name, size, max_bson_size + _async_message._COMMAND_OVERHEAD - ) + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -220,7 +217,7 @@ async def command( if client: await client._process_response(response_doc, session) if check: - _async_helpers._check_command_response( + helpers_shared._check_command_response( response_doc, conn.max_wire_version, allowable_errors, @@ -231,7 +228,7 @@ async def command( if isinstance(exc, (NotPrimaryError, OperationFailure)): failure: _DocumentOut = exc.details # type: ignore[assignment] else: - failure = _async_message._convert_exception(exc) + failure = message._convert_exception(exc) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -310,7 +307,7 @@ async def command( async def receive_message( - conn: Connection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE + conn: AsyncConnection, request_id: Optional[int], max_message_size: int = MAX_MESSAGE_SIZE ) -> Union[_OpReply, _OpMsg]: """Receive a raw BSON message or raise socket.error.""" if _csot.get_timeout(): @@ -355,7 +352,7 @@ async def receive_message( return unpack_reply(data) -async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: +async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None: """Block until at least one byte is read, or a timeout, or a cancel.""" sock = conn.conn timed_out = False @@ -390,7 +387,7 @@ async def wait_for_read(conn: Connection, deadline: Optional[float]) -> None: async def _receive_data_on_socket( - conn: Connection, length: int, deadline: Optional[float] + conn: AsyncConnection, length: int, deadline: Optional[float] ) -> memoryview: buf = bytearray(length) mv = memoryview(buf) diff --git 
a/pymongo/asynchronous/operations.py b/pymongo/asynchronous/operations.py deleted file mode 100644 index d4beff759d..0000000000 --- a/pymongo/asynchronous/operations.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Operation class definitions.""" -from __future__ import annotations - -import enum -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) - -from bson.raw_bson import RawBSONDocument -from pymongo.asynchronous import helpers -from pymongo.asynchronous.collation import validate_collation_or_none -from pymongo.asynchronous.common import validate_is_mapping, validate_list -from pymongo.asynchronous.helpers import _gen_index_name, _index_document, _index_list -from pymongo.asynchronous.typings import _CollationIn, _DocumentType, _Pipeline -from pymongo.write_concern import validate_boolean - -if TYPE_CHECKING: - from pymongo.asynchronous.bulk import _Bulk - -_IS_SYNC = False - -# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary -_IndexList = Union[ - Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] -] -_IndexKeyHint = Union[str, _IndexList] - - -class _Op(str, enum.Enum): - ABORT = "abortTransaction" - AGGREGATE = "aggregate" - COMMIT = "commitTransaction" - COUNT = "count" - CREATE = "create" - CREATE_INDEXES = "createIndexes" - CREATE_SEARCH_INDEXES = "createSearchIndexes" - DELETE = "delete" - DISTINCT = "distinct" - DROP = "drop" - DROP_DATABASE = "dropDatabase" - DROP_INDEXES = "dropIndexes" - DROP_SEARCH_INDEXES = "dropSearchIndexes" - END_SESSIONS = "endSessions" - FIND_AND_MODIFY = "findAndModify" - FIND = "find" - INSERT = "insert" - LIST_COLLECTIONS = "listCollections" - LIST_INDEXES = "listIndexes" - LIST_SEARCH_INDEX = "listSearchIndexes" - LIST_DATABASES = "listDatabases" - UPDATE = "update" - UPDATE_INDEX = "updateIndex" - UPDATE_SEARCH_INDEX = "updateSearchIndex" - RENAME = "rename" - GETMORE = "getMore" - KILL_CURSORS = "killCursors" - TEST = "testOperation" - - -class InsertOne(Generic[_DocumentType]): - """Represents an insert_one operation.""" - - __slots__ = ("_doc",) - - def __init__(self, document: _DocumentType) -> None: - """Create an InsertOne instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param document: The document to insert. If the document is missing an - _id field one will be added. 
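# A sketch of these operation classes driving bulk_write; all three are
# importable from the top-level pymongo package.
from pymongo import DeleteOne, InsertOne, MongoClient, ReplaceOne

coll = MongoClient().test.coll
result = coll.bulk_write(
    [
        InsertOne({"_id": 1, "x": 1}),
        ReplaceOne({"_id": 1}, {"_id": 1, "x": 2}, upsert=True),
        DeleteOne({"x": 2}),
    ]
)
# Operations run in order by default; pass ordered=False to allow reordering.
print(result.inserted_count, result.modified_count, result.deleted_count)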
- """ - self._doc = document - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_insert(self._doc) # type: ignore[arg-type] - - def __repr__(self) -> str: - return f"InsertOne({self._doc!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return other._doc == self._doc - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteOne: - """Represents a delete_one operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteOne instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the document to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 1, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteMany: - """Represents a delete_many operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteMany instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the documents to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the `collation` option. 
- """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 0, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class ReplaceOne(Generic[_DocumentType]): - """Represents a replace_one operation.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - replacement: Union[_DocumentType, RawBSONDocument], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a ReplaceOne instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the document to replace. - :param replacement: The new document. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the ``collation`` option. 
- """ - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._doc = replacement - self._upsert = upsert - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_replace( - self._filter, - self._doc, - self._upsert, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - other._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._hint, - ) - - -class _UpdateOp: - """Private base class for update operations.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - doc: Union[Mapping[str, Any], _Pipeline], - upsert: bool, - collation: Optional[_CollationIn], - array_filters: Optional[list[Mapping[str, Any]]], - hint: Optional[_IndexKeyHint], - ): - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if array_filters is not None: - validate_list("array_filters", array_filters) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - - self._filter = filter - self._doc = doc - self._upsert = upsert - self._collation = collation - self._array_filters = array_filters - - def __eq__(self, other: object) -> bool: - if isinstance(other, type(self)): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._array_filters, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - return NotImplemented - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - - -class UpdateOne(_UpdateOp): - """Represents an update_one operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Represents an update_one operation. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the document to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. 
- :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - False, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class UpdateMany(_UpdateOp): - """Represents an update_many operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create an UpdateMany instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.bulk_write`. - - :param filter: A query that matches the documents to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.AsyncCollection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - True, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class IndexModel: - """Represents an index to create.""" - - __slots__ = ("__document",) - - def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: - """Create an Index instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.create_indexes`. - - Takes either a single key or a list containing (key, direction) pairs - or keys. If no direction is given, :data:`~pymongo.ASCENDING` will - be assumed. - The key(s) must be an instance of :class:`str`, and the direction(s) must - be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, - :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`, - :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). 
- - Valid options include, but are not limited to: - - - `name`: custom name to use for this index - if none is - given, a name will be generated. - - `unique`: if ``True``, creates a uniqueness constraint on the index. - - `background`: if ``True``, this index should be created in the - background. - - `sparse`: if ``True``, omit from the index any documents that lack - the indexed field. - - `bucketSize`: for use with geoHaystack indexes. - Number of documents to group together within a certain proximity - to a given longitude and latitude. - - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` - index. - - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` - index. - - `expireAfterSeconds`: Used to create an expiring (TTL) - collection. MongoDB will automatically delete documents from - this collection after seconds. The indexed field must - be a UTC datetime or the data will not expire. - - `partialFilterExpression`: A document that specifies a filter for - a partial index. - - `collation`: An instance of :class:`~pymongo.collation.Collation` - that specifies the collation to use. - - `wildcardProjection`: Allows users to include or exclude specific - field paths from a `wildcard index`_ using the { "$**" : 1} key - pattern. Requires MongoDB >= 4.2. - - `hidden`: if ``True``, this index will be hidden from the query - planner and will not be evaluated as part of query plan - selection. Requires MongoDB >= 4.4. - - See the MongoDB documentation for a full list of supported options by - server version. - - :param keys: a single key or a list containing (key, direction) pairs - or keys specifying the index to create. - :param kwargs: any additional index creation - options (see the above list) should be passed as keyword - arguments. - - .. versionchanged:: 3.11 - Added the ``hidden`` option. - .. versionchanged:: 3.2 - Added the ``partialFilterExpression`` option to support partial - indexes. - - .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/ - """ - keys = _index_list(keys) - if kwargs.get("name") is None: - kwargs["name"] = _gen_index_name(keys) - kwargs["key"] = _index_document(keys) - collation = validate_collation_or_none(kwargs.pop("collation", None)) - self.__document = kwargs - if collation is not None: - self.__document["collation"] = collation - - @property - def document(self) -> dict[str, Any]: - """An index document suitable for passing to the createIndexes - command. - """ - return self.__document - - -class SearchIndexModel: - """Represents a search index to create.""" - - __slots__ = ("__document",) - - def __init__( - self, - definition: Mapping[str, Any], - name: Optional[str] = None, - type: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Create a Search Index instance. - - For use with :meth:`~pymongo.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.collection.AsyncCollection.create_search_indexes`. - - :param definition: The definition for this index. - :param name: The name for this index, if present. - :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch". - :param kwargs: Keyword arguments supplying any additional options. - - .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. - .. versionadded:: 4.5 - .. versionchanged:: 4.7 - Added the type and kwargs arguments. 
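A quick sketch of how these index models are consumed, for review context (the namespace is invented, and the search-index call needs a MongoDB 7.0+ Atlas cluster, so it is left commented)::

    from pymongo import ASCENDING, DESCENDING, MongoClient
    from pymongo.operations import IndexModel, SearchIndexModel

    coll = MongoClient().test.events  # hypothetical namespace

    # The keys become the "key" sub-document; a name is generated when none is given.
    coll.create_indexes([IndexModel([("user_id", ASCENDING), ("ts", DESCENDING)], unique=True)])

    model = SearchIndexModel(definition={"mappings": {"dynamic": True}}, name="default")
    # coll.create_search_index(model)  # uncomment when pointed at an Atlas cluster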
- """ - self.__document: dict[str, Any] = {} - if name is not None: - self.__document["name"] = name - self.__document["definition"] = definition - if type is not None: - self.__document["type"] = type - self.__document.update(kwargs) - - @property - def document(self) -> Mapping[str, Any]: - """The document for this index.""" - return self.__document diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index a4d3c50645..265da13187 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -16,17 +16,14 @@ import collections import contextlib -import copy import logging import os -import platform import socket import ssl import sys import threading import time import weakref -from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -39,39 +36,18 @@ Union, ) -import bson from bson import DEFAULT_CODEC_OPTIONS -from pymongo import __version__, _csot -from pymongo.asynchronous import helpers +from pymongo import _csot, helpers_shared from pymongo.asynchronous.client_session import _validate_session_write_concern -from pymongo.asynchronous.common import ( +from pymongo.asynchronous.helpers import _handle_reauth +from pymongo.asynchronous.network import command, receive_message +from pymongo.common import ( MAX_BSON_SIZE, - MAX_CONNECTING, - MAX_IDLE_TIME_SEC, MAX_MESSAGE_SIZE, - MAX_POOL_SIZE, MAX_WIRE_VERSION, MAX_WRITE_BATCH_SIZE, - MIN_POOL_SIZE, ORDERED_TYPES, - WAIT_QUEUE_TIMEOUT, -) -from pymongo.asynchronous.hello import Hello -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.logger import ( - _CONNECTION_LOGGER, - _ConnectionStatusMessage, - _debug_log, - _verbose_connection_error_reason, ) -from pymongo.asynchronous.monitoring import ( - ConnectionCheckOutFailedReason, - ConnectionClosedReason, - _EventListeners, -) -from pymongo.asynchronous.network import command, receive_message -from pymongo.asynchronous.read_preferences import ReadPreference from pymongo.errors import ( # type:ignore[attr-defined] AutoReconnect, ConfigurationError, @@ -86,8 +62,22 @@ WaitQueueTimeoutError, _CertificateError, ) +from pymongo.hello import Hello +from pymongo.hello_compat import HelloCompat from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.logger import ( + _CONNECTION_LOGGER, + _ConnectionStatusMessage, + _debug_log, + _verbose_connection_error_reason, +) +from pymongo.monitoring import ( + ConnectionCheckOutFailedReason, + ConnectionClosedReason, +) from pymongo.network_layer import async_sendall +from pymongo.pool_options import PoolOptions +from pymongo.read_preferences import ReadPreference from pymongo.server_api import _add_to_command from pymongo.server_type import SERVER_TYPE from pymongo.socket_checker import SocketChecker @@ -96,22 +86,19 @@ if TYPE_CHECKING: from bson import CodecOptions from bson.objectid import ObjectId - from pymongo.asynchronous.auth import MongoCredential, _AuthContext - from pymongo.asynchronous.client_session import ClientSession - from pymongo.asynchronous.compression_support import ( - CompressionSettings, + from pymongo.asynchronous.auth import _AuthContext + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.message import _OpMsg, _OpReply + from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler + from pymongo.compression_support import ( SnappyContext, ZlibContext, ZstdContext, ) - from pymongo.asynchronous.message 
import _OpMsg, _OpReply - from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.typings import ClusterTime, _Address, _CollationIn - from pymongo.driver_info import DriverInfo - from pymongo.pyopenssl_context import SSLContext, _sslConn + from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern - from pymongo.server_api import ServerApi + from pymongo.read_preferences import _ServerMode + from pymongo.typings import ClusterTime, _Address, _CollationIn from pymongo.write_concern import WriteConcern try: @@ -191,217 +178,6 @@ def _set_keepalive_times(sock: socket.socket) -> None: _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) -_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}} - -if sys.platform.startswith("linux"): - # platform.linux_distribution was deprecated in Python 3.5 - # and removed in Python 3.8. Starting in Python 3.5 it - # raises DeprecationWarning - # DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5 - _name = platform.system() - _METADATA["os"] = { - "type": _name, - "name": _name, - "architecture": platform.machine(), - # Kernel version (e.g. 4.4.0-17-generic). - "version": platform.release(), - } -elif sys.platform == "darwin": - _METADATA["os"] = { - "type": platform.system(), - "name": platform.system(), - "architecture": platform.machine(), - # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin - # kernel version. - "version": platform.mac_ver()[0], - } -elif sys.platform == "win32": - _METADATA["os"] = { - "type": platform.system(), - # "Windows XP", "Windows 7", "Windows 10", etc. - "name": " ".join((platform.system(), platform.release())), - "architecture": platform.machine(), - # Windows patch level (e.g. 5.1.2600-SP3) - "version": "-".join(platform.win32_ver()[1:3]), - } -elif sys.platform.startswith("java"): - _name, _ver, _arch = platform.java_ver()[-1] - _METADATA["os"] = { - # Linux, Windows 7, Mac OS X, etc. - "type": _name, - "name": _name, - # x86, x86_64, AMD64, etc. - "architecture": _arch, - # Linux kernel version, OSX version, etc. - "version": _ver, - } -else: - # Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11) - _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) - _METADATA["os"] = { - "type": platform.system(), - "name": " ".join([part for part in _aliased[:2] if part]), - "architecture": platform.machine(), - "version": _aliased[2], - } - -if platform.python_implementation().startswith("PyPy"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.pypy_version_info)), # type: ignore - "(Python %s)" % ".".join(map(str, sys.version_info)), - ) - ) -elif sys.platform.startswith("java"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.version_info)), - "(%s)" % " ".join((platform.system(), platform.release())), - ) - ) -else: - _METADATA["platform"] = " ".join( - (platform.python_implementation(), ".".join(map(str, sys.version_info))) - ) - -DOCKER_ENV_PATH = "/.dockerenv" -ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST" - -RUNTIME_NAME_DOCKER = "docker" -ORCHESTRATOR_NAME_K8S = "kubernetes" - - -def get_container_env_info() -> dict[str, str]: - """Returns the runtime and orchestrator of a container. 
- If neither value is present, the metadata client.env.container field will be omitted.""" - container = {} - - if Path(DOCKER_ENV_PATH).exists(): - container["runtime"] = RUNTIME_NAME_DOCKER - if os.getenv(ENV_VAR_K8S): - container["orchestrator"] = ORCHESTRATOR_NAME_K8S - - return container - - -def _is_lambda() -> bool: - if os.getenv("AWS_LAMBDA_RUNTIME_API"): - return True - env = os.getenv("AWS_EXECUTION_ENV") - if env: - return env.startswith("AWS_Lambda_") - return False - - -def _is_azure_func() -> bool: - return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME")) - - -def _is_gcp_func() -> bool: - return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME")) - - -def _is_vercel() -> bool: - return bool(os.getenv("VERCEL")) - - -def _is_faas() -> bool: - return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel() - - -def _getenv_int(key: str) -> Optional[int]: - """Like os.getenv but returns an int, or None if the value is missing/malformed.""" - val = os.getenv(key) - if not val: - return None - try: - return int(val) - except ValueError: - return None - - -def _metadata_env() -> dict[str, Any]: - env: dict[str, Any] = {} - container = get_container_env_info() - if container: - env["container"] = container - # Skip if multiple (or no) envs are matched. - if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: - return env - if _is_lambda(): - env["name"] = "aws.lambda" - region = os.getenv("AWS_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") - if memory_mb is not None: - env["memory_mb"] = memory_mb - elif _is_azure_func(): - env["name"] = "azure.func" - elif _is_gcp_func(): - env["name"] = "gcp.func" - region = os.getenv("FUNCTION_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("FUNCTION_MEMORY_MB") - if memory_mb is not None: - env["memory_mb"] = memory_mb - timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC") - if timeout_sec is not None: - env["timeout_sec"] = timeout_sec - elif _is_vercel(): - env["name"] = "vercel" - region = os.getenv("VERCEL_REGION") - if region: - env["region"] = region - return env - - -_MAX_METADATA_SIZE = 512 - - -# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations -def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: - """Perform metadata truncation.""" - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 1. Omit fields from env except env.name. - env_name = metadata.get("env", {}).get("name") - if env_name: - metadata["env"] = {"name": env_name} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 2. Omit fields from os except os.type. - os_type = metadata.get("os", {}).get("type") - if os_type: - metadata["os"] = {"type": os_type} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 3. Omit the env document entirely. - metadata.pop("env", None) - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # 4. Truncate platform. - overflow = encoded_size - _MAX_METADATA_SIZE - plat = metadata.get("platform", "") - if plat: - plat = plat[:-overflow] - if plat: - metadata["platform"] = plat - else: - metadata.pop("platform", None) - - -# If the first getaddrinfo call of this interpreter's life is on a thread, -# while the main thread holds the import lock, getaddrinfo deadlocks trying -# to import the IDNA codec. 
Import it here, where presumably we're on the -# main thread, to avoid the deadlock. See PYTHON-607. -"foo".encode("idna") - - def _raise_connection_failure( address: Any, error: Exception, @@ -462,238 +238,6 @@ def format_timeout_details(details: Optional[dict[str, float]]) -> str: return result -class PoolOptions: - """Read only connection pool options for an AsyncMongoClient. - - Should not be instantiated directly by application developers. Access - a client's pool options via - :attr:`~pymongo.client_options.ClientOptions.pool_options` instead:: - - pool_opts = client.options.pool_options - pool_opts.max_pool_size - pool_opts.min_pool_size - - """ - - __slots__ = ( - "__max_pool_size", - "__min_pool_size", - "__max_idle_time_seconds", - "__connect_timeout", - "__socket_timeout", - "__wait_queue_timeout", - "__ssl_context", - "__tls_allow_invalid_hostnames", - "__event_listeners", - "__appname", - "__driver", - "__metadata", - "__compression_settings", - "__max_connecting", - "__pause_enabled", - "__server_api", - "__load_balanced", - "__credentials", - ) - - def __init__( - self, - max_pool_size: int = MAX_POOL_SIZE, - min_pool_size: int = MIN_POOL_SIZE, - max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, - connect_timeout: Optional[float] = None, - socket_timeout: Optional[float] = None, - wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, - ssl_context: Optional[SSLContext] = None, - tls_allow_invalid_hostnames: bool = False, - event_listeners: Optional[_EventListeners] = None, - appname: Optional[str] = None, - driver: Optional[DriverInfo] = None, - compression_settings: Optional[CompressionSettings] = None, - max_connecting: int = MAX_CONNECTING, - pause_enabled: bool = True, - server_api: Optional[ServerApi] = None, - load_balanced: Optional[bool] = None, - credentials: Optional[MongoCredential] = None, - ): - self.__max_pool_size = max_pool_size - self.__min_pool_size = min_pool_size - self.__max_idle_time_seconds = max_idle_time_seconds - self.__connect_timeout = connect_timeout - self.__socket_timeout = socket_timeout - self.__wait_queue_timeout = wait_queue_timeout - self.__ssl_context = ssl_context - self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames - self.__event_listeners = event_listeners - self.__appname = appname - self.__driver = driver - self.__compression_settings = compression_settings - self.__max_connecting = max_connecting - self.__pause_enabled = pause_enabled - self.__server_api = server_api - self.__load_balanced = load_balanced - self.__credentials = credentials - self.__metadata = copy.deepcopy(_METADATA) - if appname: - self.__metadata["application"] = {"name": appname} - - # Combine the "driver" AsyncMongoClient option with PyMongo's info, like: - # { - # 'driver': { - # 'name': 'PyMongo|MyDriver', - # 'version': '4.2.0|1.2.3', - # }, - # 'platform': 'CPython 3.8.0|MyPlatform' - # } - if driver: - if driver.name: - self.__metadata["driver"]["name"] = "{}|{}".format( - _METADATA["driver"]["name"], - driver.name, - ) - if driver.version: - self.__metadata["driver"]["version"] = "{}|{}".format( - _METADATA["driver"]["version"], - driver.version, - ) - if driver.platform: - self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) - - env = _metadata_env() - if env: - self.__metadata["env"] = env - - _truncate_metadata(self.__metadata) - - @property - def _credentials(self) -> Optional[MongoCredential]: - """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" - return self.__credentials - - 
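To make the metadata plumbing above concrete: wrapping libraries append to PyMongo's handshake identity rather than replacing it, via the ``driver`` option. A sketch — the driver name, version, and platform are invented, and no server is needed because clients connect lazily::

    from pymongo import MongoClient
    from pymongo.driver_info import DriverInfo

    client = MongoClient(driver=DriverInfo(name="MyDriver", version="1.2.3", platform="MyPlatform"))
    meta = client.options.pool_options.metadata
    print(meta["driver"])    # e.g. {'name': 'PyMongo|MyDriver', 'version': '4.7.0|1.2.3'}
    print(meta["platform"])  # e.g. 'CPython 3.12.0|MyPlatform'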
@property - def non_default_options(self) -> dict[str, Any]: - """The non-default options this pool was created with. - - Added for CMAP's :class:`PoolCreatedEvent`. - """ - opts = {} - if self.__max_pool_size != MAX_POOL_SIZE: - opts["maxPoolSize"] = self.__max_pool_size - if self.__min_pool_size != MIN_POOL_SIZE: - opts["minPoolSize"] = self.__min_pool_size - if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: - assert self.__max_idle_time_seconds is not None - opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 - if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: - assert self.__wait_queue_timeout is not None - opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 - if self.__max_connecting != MAX_CONNECTING: - opts["maxConnecting"] = self.__max_connecting - return opts - - @property - def max_pool_size(self) -> float: - """The maximum allowable number of concurrent connections to each - connected server. Requests to a server will block if there are - `maxPoolSize` outstanding connections to the requested server. - Defaults to 100. Cannot be 0. - - When a server's pool has reached `max_pool_size`, operations for that - server block waiting for a socket to be returned to the pool. If - ``waitQueueTimeoutMS`` is set, a blocked operation will raise - :exc:`~pymongo.errors.ConnectionFailure` after a timeout. - By default ``waitQueueTimeoutMS`` is not set. - """ - return self.__max_pool_size - - @property - def min_pool_size(self) -> int: - """The minimum required number of concurrent connections that the pool - will maintain to each connected server. Default is 0. - """ - return self.__min_pool_size - - @property - def max_connecting(self) -> int: - """The maximum number of concurrent connection creation attempts per - pool. Defaults to 2. - """ - return self.__max_connecting - - @property - def pause_enabled(self) -> bool: - return self.__pause_enabled - - @property - def max_idle_time_seconds(self) -> Optional[int]: - """The maximum number of seconds that a connection can remain - idle in the pool before being removed and replaced. Defaults to - `None` (no limit). - """ - return self.__max_idle_time_seconds - - @property - def connect_timeout(self) -> Optional[float]: - """How long a connection can take to be opened before timing out.""" - return self.__connect_timeout - - @property - def socket_timeout(self) -> Optional[float]: - """How long a send or receive on a socket can take before timing out.""" - return self.__socket_timeout - - @property - def wait_queue_timeout(self) -> Optional[int]: - """How long a thread will wait for a socket from the pool if the pool - has no free sockets. 
- """ - return self.__wait_queue_timeout - - @property - def _ssl_context(self) -> Optional[SSLContext]: - """An SSLContext instance or None.""" - return self.__ssl_context - - @property - def tls_allow_invalid_hostnames(self) -> bool: - """If True skip ssl.match_hostname.""" - return self.__tls_allow_invalid_hostnames - - @property - def _event_listeners(self) -> Optional[_EventListeners]: - """An instance of pymongo.monitoring._EventListeners.""" - return self.__event_listeners - - @property - def appname(self) -> Optional[str]: - """The application name, for sending with hello in server handshake.""" - return self.__appname - - @property - def driver(self) -> Optional[DriverInfo]: - """Driver name and version, for sending with hello in handshake.""" - return self.__driver - - @property - def _compression_settings(self) -> Optional[CompressionSettings]: - return self.__compression_settings - - @property - def metadata(self) -> dict[str, Any]: - """A dict of metadata about the application, driver, os, and platform.""" - return self.__metadata.copy() - - @property - def server_api(self) -> Optional[ServerApi]: - """A pymongo.server_api.ServerApi or None.""" - return self.__server_api - - @property - def load_balanced(self) -> Optional[bool]: - """True if this Pool is configured in load balanced mode.""" - return self.__load_balanced - - class _CancellationContext: def __init__(self) -> None: self._cancelled = False @@ -708,7 +252,7 @@ def cancelled(self) -> bool: return self._cancelled -class Connection: +class AsyncConnection: """Store a connection with some metadata. :param conn: a raw connection object @@ -926,7 +470,7 @@ async def _next_reply(self) -> dict[str, Any]: self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() response_doc = unpacked_docs[0] - helpers._check_command_response(response_doc, self.max_wire_version) + helpers_shared._check_command_response(response_doc, self.max_wire_version) return response_doc @_handle_reauth @@ -942,7 +486,7 @@ async def command( write_concern: Optional[WriteConcern] = None, parse_write_concern_error: bool = False, collation: Optional[_CollationIn] = None, - session: Optional[ClientSession] = None, + session: Optional[AsyncClientSession] = None, client: Optional[AsyncMongoClient] = None, retryable_write: bool = False, publish_events: bool = True, @@ -962,7 +506,7 @@ async def command( :param parse_write_concern_error: Whether to parse the ``writeConcernError`` field in the command response. :param collation: The collation for this command. - :param session: optional ClientSession instance. + :param session: optional AsyncClientSession instance. :param client: optional AsyncMongoClient for gossipping $clusterTime. :param retryable_write: True if this command is a retryable write. :param publish_events: Should we publish events for this command? @@ -1079,7 +623,7 @@ async def write_command( result = reply.command_response(codec_options) # Raises NotPrimaryError or OperationFailure. - helpers._check_command_response(result, self.max_wire_version) + helpers_shared._check_command_response(result, self.max_wire_version) return result async def authenticate(self, reauthenticate: bool = False) -> None: @@ -1117,7 +661,7 @@ async def authenticate(self, reauthenticate: bool = False) -> None: ) def validate_session( - self, client: Optional[AsyncMongoClient], session: Optional[ClientSession] + self, client: Optional[AsyncMongoClient], session: Optional[AsyncClientSession] ) -> None: """Validate this session before use with client. 
@@ -1169,7 +713,7 @@ def conn_closed(self) -> bool: def send_cluster_time( self, command: MutableMapping[str, Any], - session: Optional[ClientSession], + session: Optional[AsyncClientSession], client: Optional[AsyncMongoClient], ) -> None: """Add $clusterTime.""" @@ -1229,7 +773,7 @@ def __hash__(self) -> int: return hash(self.conn) def __repr__(self) -> str: - return "Connection({}){} at {}".format( + return "AsyncConnection({}){} at {}".format( repr(self.conn), self.closed and " CLOSED" or "", id(self), @@ -1420,7 +964,7 @@ def __init__( """ :param address: a (hostname, port) tuple :param options: a PoolOptions instance - :param handshake: whether to call hello for each new Connection + :param handshake: whether to call hello for each new AsyncConnection """ if options.pause_enabled: self.state = PoolState.PAUSED @@ -1490,7 +1034,7 @@ def __init__( # Retain references to pinned connections to prevent the CPython GC # from thinking that a cursor's pinned connection can be GC'd when the # cursor is GC'd (see PYTHON-2751). - self.__pinned_sockets: set[Connection] = set() + self.__pinned_sockets: set[AsyncConnection] = set() self.ncursors = 0 self.ntxns = 0 @@ -1677,8 +1221,8 @@ async def remove_stale_sockets(self, reference_generation: int) -> None: self.requests -= 1 self.size_cond.notify() - async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> Connection: - """Connect to Mongo and return a new Connection. + async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> AsyncConnection: + """Connect to Mongo and return a new AsyncConnection. Can raise ConnectionFailure. @@ -1728,7 +1272,7 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> C raise - conn = Connection(sock, self, self.address, conn_id) # type: ignore[arg-type] + conn = AsyncConnection(sock, self, self.address, conn_id) # type: ignore[arg-type] async with self.lock: self.active_contexts.add(conn.cancel_context) try: @@ -1748,10 +1292,10 @@ async def connect(self, handler: Optional[_MongoClientErrorHandler] = None) -> C @contextlib.asynccontextmanager async def checkout( self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AsyncGenerator[Connection, None]: + ) -> AsyncGenerator[AsyncConnection, None]: """Get a connection from the pool. Use with a "with" statement. - Returns a :class:`Connection` object wrapping a connected + Returns an :class:`AsyncConnection` object wrapping a connected :class:`socket.socket`. This method should always be used in a with-statement:: @@ -1850,8 +1394,8 @@ def _raise_if_not_ready(self, checkout_started_time: float, emit_event: bool) -> async def _get_conn( self, checkout_started_time: float, handler: Optional[_MongoClientErrorHandler] = None - ) -> Connection: - """Get or create a Connection. Can raise ConnectionFailure.""" + ) -> AsyncConnection: + """Get or create an AsyncConnection. Can raise ConnectionFailure.""" # We use the pid here to avoid issues with fork / multiprocessing.
# See test.test_client:TestClient.test_fork for an example of # what could go wrong otherwise @@ -1872,7 +1416,7 @@ async def _get_conn( message=_ConnectionStatusMessage.CHECKOUT_FAILED, serverHost=self.address[0], serverPort=self.address[1], reason="Connection pool was closed", error=ConnectionCheckOutFailedReason.POOL_CLOSED, durationMS=duration, ) @@ -1973,7 +1517,7 @@ async def _get_conn( conn.active = True return conn - async def checkin(self, conn: Connection) -> None: + async def checkin(self, conn: AsyncConnection) -> None: """Return the connection to the pool, or if it's closed discard it. :param conn: The connection to check into the pool. @@ -2045,7 +1589,7 @@ async def checkin(self, conn: Connection) -> None: self.operation_count -= 1 self.size_cond.notify() - def _perished(self, conn: Connection) -> bool: + def _perished(self, conn: AsyncConnection) -> bool: """Return True and close the connection if it is "perished". This side-effecty function checks if this socket has been idle for diff --git a/pymongo/asynchronous/read_preferences.py b/pymongo/asynchronous/read_preferences.py deleted file mode 100644 index 8b6fb60753..0000000000 --- a/pymongo/asynchronous/read_preferences.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2012-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License.
- -"""Utilities for choosing which member of a replica set to read from.""" - -from __future__ import annotations - -from collections import abc -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from pymongo.asynchronous import max_staleness_selectors -from pymongo.asynchronous.server_selectors import ( - member_with_tags_server_selector, - secondary_with_tags_server_selector, -) -from pymongo.errors import ConfigurationError - -if TYPE_CHECKING: - from pymongo.asynchronous.server_selectors import Selection - from pymongo.asynchronous.topology_description import TopologyDescription - -_IS_SYNC = False - -_PRIMARY = 0 -_PRIMARY_PREFERRED = 1 -_SECONDARY = 2 -_SECONDARY_PREFERRED = 3 -_NEAREST = 4 - - -_MONGOS_MODES = ( - "primary", - "primaryPreferred", - "secondary", - "secondaryPreferred", - "nearest", -) - -_Hedge = Mapping[str, Any] -_TagSets = Sequence[Mapping[str, Any]] - - -def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]: - """Validate tag sets for a MongoClient.""" - if tag_sets is None: - return tag_sets - - if not isinstance(tag_sets, (list, tuple)): - raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence") - if len(tag_sets) == 0: - raise ValueError( - f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags" - ) - - for tags in tag_sets: - if not isinstance(tags, abc.Mapping): - raise TypeError( - f"Tag set {tags!r} invalid, must be an instance of dict, " - "bson.son.SON or other type that inherits from " - "collection.Mapping" - ) - - return list(tag_sets) - - -def _invalid_max_staleness_msg(max_staleness: Any) -> str: - return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness - - -# Some duplication with common.py to avoid import cycle. 
-def _validate_max_staleness(max_staleness: Any) -> int: - """Validate max_staleness.""" - if max_staleness == -1: - return -1 - - if not isinstance(max_staleness, int): - raise TypeError(_invalid_max_staleness_msg(max_staleness)) - - if max_staleness <= 0: - raise ValueError(_invalid_max_staleness_msg(max_staleness)) - - return max_staleness - - -def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]: - """Validate hedge.""" - if hedge is None: - return None - - if not isinstance(hedge, dict): - raise TypeError(f"hedge must be a dictionary, not {hedge!r}") - - return hedge - - -class _ServerMode: - """Base class for all read preferences.""" - - __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") - - def __init__( - self, - mode: int, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - self.__mongos_mode = _MONGOS_MODES[mode] - self.__mode = mode - self.__tag_sets = _validate_tag_sets(tag_sets) - self.__max_staleness = _validate_max_staleness(max_staleness) - self.__hedge = _validate_hedge(hedge) - - @property - def name(self) -> str: - """The name of this read preference.""" - return self.__class__.__name__ - - @property - def mongos_mode(self) -> str: - """The mongos mode of this read preference.""" - return self.__mongos_mode - - @property - def document(self) -> dict[str, Any]: - """Read preference as a document.""" - doc: dict[str, Any] = {"mode": self.__mongos_mode} - if self.__tag_sets not in (None, [{}]): - doc["tags"] = self.__tag_sets - if self.__max_staleness != -1: - doc["maxStalenessSeconds"] = self.__max_staleness - if self.__hedge not in (None, {}): - doc["hedge"] = self.__hedge - return doc - - @property - def mode(self) -> int: - """The mode of this read preference instance.""" - return self.__mode - - @property - def tag_sets(self) -> _TagSets: - """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to - read only from members whose ``dc`` tag has the value ``"ny"``. - To specify a priority-order for tag sets, provide a list of - tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag - set, ``{}``, means "read from any member that matches the mode, - ignoring tags." MongoClient tries each set of tags in turn - until it finds a set of tags with at least one matching member. - For example, to only send a query to an analytic node:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - Or using :class:`SecondaryPreferred`:: - - SecondaryPreferred(tag_sets=[{"node":"analytics"}]) - - .. seealso:: `Data-Center Awareness - `_ - """ - return list(self.__tag_sets) if self.__tag_sets else [{}] - - @property - def max_staleness(self) -> int: - """The maximum estimated length of time (in seconds) a replica set - secondary can fall behind the primary in replication before it will - no longer be selected for operations, or -1 for no maximum. - """ - return self.__max_staleness - - @property - def hedge(self) -> Optional[_Hedge]: - """The read preference ``hedge`` parameter. - - A dictionary that configures how the server will perform hedged reads. - It consists of the following keys: - - - ``enabled``: Enables or disables hedged reads in sharded clusters. - - Hedged reads are automatically enabled in MongoDB 4.4+ when using a - ``nearest`` read preference. 
To explicitly enable hedged reads, set - the ``enabled`` key to ``true``:: - - >>> Nearest(hedge={'enabled': True}) - - To explicitly disable hedged reads, set the ``enabled`` key to - ``False``:: - - >>> Nearest(hedge={'enabled': False}) - - .. versionadded:: 3.11 - """ - return self.__hedge - - @property - def min_wire_version(self) -> int: - """The wire protocol version the server must support. - - Some read preferences impose version requirements on all servers (e.g. - maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5). - - All servers' maxWireVersion must be at least this read preference's - `min_wire_version`, or the driver raises - :exc:`~pymongo.errors.ConfigurationError`. - """ - return 0 if self.__max_staleness == -1 else 5 - - def __repr__(self) -> str: - return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format( - self.name, - self.__tag_sets, - self.__max_staleness, - self.__hedge, - ) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return ( - self.mode == other.mode - and self.tag_sets == other.tag_sets - and self.max_staleness == other.max_staleness - and self.hedge == other.hedge - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __getstate__(self) -> dict[str, Any]: - """Return value of object for pickling. - - Needed explicitly because __slots__() defined. - """ - return { - "mode": self.__mode, - "tag_sets": self.__tag_sets, - "max_staleness": self.__max_staleness, - "hedge": self.__hedge, - } - - def __setstate__(self, value: Mapping[str, Any]) -> None: - """Restore from pickling.""" - self.__mode = value["mode"] - self.__mongos_mode = _MONGOS_MODES[self.__mode] - self.__tag_sets = _validate_tag_sets(value["tag_sets"]) - self.__max_staleness = _validate_max_staleness(value["max_staleness"]) - self.__hedge = _validate_hedge(value["hedge"]) - - def __call__(self, selection: Selection) -> Selection: - return selection - - -class Primary(_ServerMode): - """Primary read preference. - - * When directly connected to one mongod queries are allowed if the server - is standalone or a replica set primary. - * When connected to a mongos queries are sent to the primary of a shard. - * When connected to a replica set queries are sent to the primary of - the replica set. - """ - - __slots__ = () - - def __init__(self) -> None: - super().__init__(_PRIMARY) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return selection.primary_selection - - def __repr__(self) -> str: - return "Primary()" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return other.mode == _PRIMARY - return NotImplemented - - -class PrimaryPreferred(_ServerMode): - """PrimaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are sent to the primary of a shard if - available, otherwise a shard secondary. - * When connected to a replica set queries are sent to the primary if - available, otherwise a secondary. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to an available secondary until the - primary of the replica set is discovered. - - :param tag_sets: The :attr:`~tag_sets` to use if the primary is not - available. 
- :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` to use if the primary is not available. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - if selection.primary: - return selection.primary_selection - else: - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class Secondary(_ServerMode): - """Secondary read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries. An error is raised if no secondaries are available. - * When connected to a replica set queries are distributed among - secondaries. An error is raised if no secondaries are available. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class SecondaryPreferred(_ServerMode): - """SecondaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries, or the shard primary if no secondary is available. - * When connected to a replica set queries are distributed among - secondaries, or the primary if no secondary is available. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to the primary of the replica set until - an available secondary is discovered. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. 
- """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - secondaries = secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - if secondaries: - return secondaries - else: - return selection.primary_selection - - -class Nearest(_ServerMode): - """Nearest read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among all members of - a shard. - * When connected to a replica set queries are distributed among all - members. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_NEAREST, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return member_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class _AggWritePref: - """Agg $out/$merge write preference. - - * If there are readable servers and there is any pre-5.0 server, use - primary read preference. - * Otherwise use `pref` read preference. - - :param pref: The read preference to use on MongoDB 5.0+. - """ - - __slots__ = ("pref", "effective_pref") - - def __init__(self, pref: _ServerMode): - self.pref = pref - self.effective_pref: _ServerMode = ReadPreference.PRIMARY - - def selection_hook(self, topology_description: TopologyDescription) -> None: - common_wv = topology_description.common_wire_version - if ( - topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED) - and common_wv - and common_wv < 13 - ): - self.effective_pref = ReadPreference.PRIMARY - else: - self.effective_pref = self.pref - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return self.effective_pref(selection) - - def __repr__(self) -> str: - return f"_AggWritePref(pref={self.pref!r})" - - # Proxy other calls to the effective_pref so that _AggWritePref can be - # used in place of an actual read preference. 
- def __getattr__(self, name: str) -> Any: - return getattr(self.effective_pref, name) - - -_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) - - -def make_read_preference( - mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1 -) -> _ServerMode: - if mode == _PRIMARY: - if tag_sets not in (None, [{}]): - raise ConfigurationError("Read preference primary cannot be combined with tags") - if max_staleness != -1: - raise ConfigurationError( - "Read preference primary cannot be combined with maxStalenessSeconds" - ) - return Primary() - return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore - - -_MODES = ( - "PRIMARY", - "PRIMARY_PREFERRED", - "SECONDARY", - "SECONDARY_PREFERRED", - "NEAREST", -) - - -class ReadPreference: - """An enum that defines some commonly used read preference modes. - - Apps can also create a custom read preference, for example:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - See :doc:`/examples/high_availability` for code examples. - - A read preference is used in three cases: - - :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: - - - ``PRIMARY``: Queries are allowed if the server is standalone or a replica - set primary. - - All other modes allow queries to standalone servers, to a replica set - primary, or to replica set secondaries. - - :class:`~pymongo.mongo_client.MongoClient` initialized with the - ``replicaSet`` option: - - - ``PRIMARY``: Read from the primary. This is the default, and provides the - strongest consistency. If no primary is available, raise - :class:`~pymongo.errors.AutoReconnect`. - - - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is - none, read from a secondary. - - - ``SECONDARY``: Read from a secondary. If no secondary is available, - raise :class:`~pymongo.errors.AutoReconnect`. - - - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise - from the primary. - - - ``NEAREST``: Read from any member. - - :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a - sharded cluster of replica sets: - - - ``PRIMARY``: Read from the primary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - This is the default. - - - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is - none, read from a secondary of the shard. - - - ``SECONDARY``: Read from a secondary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - - - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, - otherwise from the shard primary. - - - ``NEAREST``: Read from any shard member. - """ - - PRIMARY = Primary() - PRIMARY_PREFERRED = PrimaryPreferred() - SECONDARY = Secondary() - SECONDARY_PREFERRED = SecondaryPreferred() - NEAREST = Nearest() - - -def read_pref_mode_from_name(name: str) -> int: - """Get the read preference mode from mongos/uri name.""" - return _MONGOS_MODES.index(name) - - -class MovingAverage: - """Tracks an exponentially-weighted moving average.""" - - average: Optional[float] - - def __init__(self) -> None: - self.average = None - - def add_sample(self, sample: float) -> None: - if sample < 0: - # Likely system time change while waiting for hello response - # and not using time.monotonic. Ignore it, the next one will - # probably be valid. - return - if self.average is None: - self.average = sample - else: - # The Server Selection Spec requires an exponentially weighted - # average with alpha = 0.2. 
- self.average = 0.8 * self.average + 0.2 * sample - - def get(self) -> Optional[float]: - """Get the calculated average, or None if no samples yet.""" - return self.average - - def reset(self) -> None: - self.average = None diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index cf812d05c7..0e6ae1574d 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -27,11 +27,12 @@ ) from bson import _decode_all_selective -from pymongo.asynchronous.helpers import _check_command_response, _handle_reauth -from pymongo.asynchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.asynchronous.helpers import _handle_reauth from pymongo.asynchronous.message import _convert_exception, _GetMore, _OpMsg, _Query -from pymongo.asynchronous.response import PinnedResponse, Response from pymongo.errors import NotPrimaryError, OperationFailure +from pymongo.helpers_shared import _check_command_response +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: from queue import Queue from bson.objectid import ObjectId from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler from pymongo.asynchronous.monitor import Monitor - from pymongo.asynchronous.monitoring import _EventListeners - from pymongo.asynchronous.pool import Connection, Pool - from pymongo.asynchronous.read_preferences import _ServerMode - from pymongo.asynchronous.server_description import ServerDescription - from pymongo.asynchronous.typings import _DocumentOut + from pymongo.asynchronous.pool import AsyncConnection, Pool + from pymongo.monitoring import _EventListeners + from pymongo.read_preferences import _ServerMode + from pymongo.server_description import ServerDescription + from pymongo.typings import _DocumentOut _IS_SYNC = False @@ -108,7 +109,7 @@ def request_check(self) -> None: @_handle_reauth async def run_operation( self, - conn: Connection, + conn: AsyncConnection, operation: Union[_Query, _GetMore], read_preference: _ServerMode, listeners: Optional[_EventListeners], @@ -121,7 +122,7 @@ async def run_operation( cursors. Can raise ConnectionFailure, OperationFailure, etc. - :param conn: A Connection instance. + :param conn: An AsyncConnection instance. :param operation: A _Query or _GetMore object. :param read_preference: The read preference to use. :param listeners: Instance of _EventListeners or None. @@ -321,7 +322,7 @@ async def run_operation( async def checkout( self, handler: Optional[_MongoClientErrorHandler] = None - ) -> AsyncContextManager[Connection]: + ) -> AsyncContextManager[AsyncConnection]: return self.pool.checkout(handler) @property diff --git a/pymongo/asynchronous/server_description.py b/pymongo/asynchronous/server_description.py deleted file mode 100644 index 8e15c34006..0000000000 --- a/pymongo/asynchronous/server_description.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and -# limitations under the License. - -"""Represent one server the driver is connected to.""" -from __future__ import annotations - -import time -import warnings -from typing import Any, Mapping, Optional - -from bson import EPOCH_NAIVE -from bson.objectid import ObjectId -from pymongo.asynchronous.hello import Hello -from pymongo.asynchronous.typings import ClusterTime, _Address -from pymongo.server_type import SERVER_TYPE - -_IS_SYNC = False - - -class ServerDescription: - """Immutable representation of one server. - - :param address: A (host, port) pair - :param hello: Optional Hello instance - :param round_trip_time: Optional float - :param error: Optional, the last error attempting to connect to the server - :param min_round_trip_time: Optional float, the min latency from the most recent samples - """ - - __slots__ = ( - "_address", - "_server_type", - "_all_hosts", - "_tags", - "_replica_set_name", - "_primary", - "_max_bson_size", - "_max_message_size", - "_max_write_batch_size", - "_min_wire_version", - "_max_wire_version", - "_round_trip_time", - "_min_round_trip_time", - "_me", - "_is_writable", - "_is_readable", - "_ls_timeout_minutes", - "_error", - "_set_version", - "_election_id", - "_cluster_time", - "_last_write_date", - "_last_update_time", - "_topology_version", - ) - - def __init__( - self, - address: _Address, - hello: Optional[Hello] = None, - round_trip_time: Optional[float] = None, - error: Optional[Exception] = None, - min_round_trip_time: float = 0.0, - ) -> None: - self._address = address - if not hello: - hello = Hello({}) - - self._server_type = hello.server_type - self._all_hosts = hello.all_hosts - self._tags = hello.tags - self._replica_set_name = hello.replica_set_name - self._primary = hello.primary - self._max_bson_size = hello.max_bson_size - self._max_message_size = hello.max_message_size - self._max_write_batch_size = hello.max_write_batch_size - self._min_wire_version = hello.min_wire_version - self._max_wire_version = hello.max_wire_version - self._set_version = hello.set_version - self._election_id = hello.election_id - self._cluster_time = hello.cluster_time - self._is_writable = hello.is_writable - self._is_readable = hello.is_readable - self._ls_timeout_minutes = hello.logical_session_timeout_minutes - self._round_trip_time = round_trip_time - self._min_round_trip_time = min_round_trip_time - self._me = hello.me - self._last_update_time = time.monotonic() - self._error = error - self._topology_version = hello.topology_version - if error: - details = getattr(error, "details", None) - if isinstance(details, dict): - self._topology_version = details.get("topologyVersion") - - self._last_write_date: Optional[float] - if hello.last_write_date: - # Convert from datetime to seconds. - delta = hello.last_write_date - EPOCH_NAIVE - self._last_write_date = delta.total_seconds() - else: - self._last_write_date = None - - @property - def address(self) -> _Address: - """The address (host, port) of this server.""" - return self._address - - @property - def server_type(self) -> int: - """The type of this server.""" - return self._server_type - - @property - def server_type_name(self) -> str: - """The server type as a human readable string. - - ..
versionadded:: 3.4 - """ - return SERVER_TYPE._fields[self._server_type] - - @property - def all_hosts(self) -> set[tuple[str, int]]: - """List of hosts, passives, and arbiters known to this server.""" - return self._all_hosts - - @property - def tags(self) -> Mapping[str, Any]: - return self._tags - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self._replica_set_name - - @property - def primary(self) -> Optional[tuple[str, int]]: - """This server's opinion about who the primary is, or None.""" - return self._primary - - @property - def max_bson_size(self) -> int: - return self._max_bson_size - - @property - def max_message_size(self) -> int: - return self._max_message_size - - @property - def max_write_batch_size(self) -> int: - return self._max_write_batch_size - - @property - def min_wire_version(self) -> int: - return self._min_wire_version - - @property - def max_wire_version(self) -> int: - return self._max_wire_version - - @property - def set_version(self) -> Optional[int]: - return self._set_version - - @property - def election_id(self) -> Optional[ObjectId]: - return self._election_id - - @property - def cluster_time(self) -> Optional[ClusterTime]: - return self._cluster_time - - @property - def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: - warnings.warn( - "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", - DeprecationWarning, - stacklevel=2, - ) - return self._set_version, self._election_id - - @property - def me(self) -> Optional[tuple[str, int]]: - return self._me - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - return self._ls_timeout_minutes - - @property - def last_write_date(self) -> Optional[float]: - return self._last_write_date - - @property - def last_update_time(self) -> float: - return self._last_update_time - - @property - def round_trip_time(self) -> Optional[float]: - """The current average latency or None.""" - # This override is for unittesting only! 
- if self._address in self._host_to_round_trip_time: - return self._host_to_round_trip_time[self._address] - - return self._round_trip_time - - @property - def min_round_trip_time(self) -> float: - """The min latency from the most recent samples.""" - return self._min_round_trip_time - - @property - def error(self) -> Optional[Exception]: - """The last error attempting to connect to the server, or None.""" - return self._error - - @property - def is_writable(self) -> bool: - return self._is_writable - - @property - def is_readable(self) -> bool: - return self._is_readable - - @property - def mongos(self) -> bool: - return self._server_type == SERVER_TYPE.Mongos - - @property - def is_server_type_known(self) -> bool: - return self.server_type != SERVER_TYPE.Unknown - - @property - def retryable_writes_supported(self) -> bool: - """Checks if this server supports retryable writes.""" - return ( - self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) - ) or self._server_type == SERVER_TYPE.LoadBalancer - - @property - def retryable_reads_supported(self) -> bool: - """Checks if this server supports retryable reads.""" - return self._max_wire_version >= 6 - - @property - def topology_version(self) -> Optional[Mapping[str, Any]]: - return self._topology_version - - def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription: - unknown = ServerDescription(self.address, error=error) - unknown._topology_version = self.topology_version - return unknown - - def __eq__(self, other: Any) -> bool: - if isinstance(other, ServerDescription): - return ( - (self._address == other.address) - and (self._server_type == other.server_type) - and (self._min_wire_version == other.min_wire_version) - and (self._max_wire_version == other.max_wire_version) - and (self._me == other.me) - and (self._all_hosts == other.all_hosts) - and (self._tags == other.tags) - and (self._replica_set_name == other.replica_set_name) - and (self._set_version == other.set_version) - and (self._election_id == other.election_id) - and (self._primary == other.primary) - and (self._ls_timeout_minutes == other.logical_session_timeout_minutes) - and (self._error == other.error) - ) - - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - errmsg = "" - if self.error: - errmsg = f", error={self.error!r}" - return "<{} {} server_type: {}, rtt: {}{}>".format( - self.__class__.__name__, - self.address, - self.server_type_name, - self.round_trip_time, - errmsg, - ) - - # For unittesting only. Use under no circumstances!
- _host_to_round_trip_time: dict = {} diff --git a/pymongo/asynchronous/settings.py b/pymongo/asynchronous/settings.py index f88235cf59..c41c638e6c 100644 --- a/pymongo/asynchronous/settings.py +++ b/pymongo/asynchronous/settings.py @@ -20,12 +20,14 @@ from typing import Any, Collection, Optional, Type, Union from bson.objectid import ObjectId -from pymongo.asynchronous import common, monitor, pool -from pymongo.asynchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT -from pymongo.asynchronous.pool import Pool, PoolOptions -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.topology_description import TOPOLOGY_TYPE, _ServerSelector +from pymongo import common +from pymongo.asynchronous import monitor, pool +from pymongo.asynchronous.pool import Pool +from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT from pymongo.errors import ConfigurationError +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector _IS_SYNC = False diff --git a/pymongo/asynchronous/topology.py b/pymongo/asynchronous/topology.py index df6dd903a7..487f8de116 100644 --- a/pymongo/asynchronous/topology.py +++ b/pymongo/asynchronous/topology.py @@ -27,49 +27,50 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, helpers_constants -from pymongo.asynchronous import common, periodic_executor +from pymongo import _csot, common, helpers_shared +from pymongo.asynchronous import periodic_executor from pymongo.asynchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.asynchronous.hello import Hello -from pymongo.asynchronous.logger import ( +from pymongo.asynchronous.monitor import SrvMonitor +from pymongo.asynchronous.pool import Pool +from pymongo.asynchronous.server import Server +from pymongo.errors import ( + ConnectionFailure, + InvalidOperation, + NetworkTimeout, + NotPrimaryError, + OperationFailure, + PyMongoError, + ServerSelectionTimeoutError, + WriteError, +) +from pymongo.hello import Hello +from pymongo.lock import _ACondition, _ALock, _create_lock +from pymongo.logger import ( _SERVER_SELECTION_LOGGER, _debug_log, _ServerSelectionStatusMessage, ) -from pymongo.asynchronous.monitor import SrvMonitor -from pymongo.asynchronous.pool import Pool, PoolOptions -from pymongo.asynchronous.server import Server -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.server_selectors import ( +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import ( Selection, any_server_selector, arbiter_server_selector, secondary_server_selector, writable_server_selector, ) -from pymongo.asynchronous.topology_description import ( +from pymongo.topology_description import ( SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE, TopologyDescription, _updated_topology_description_srv_polling, updated_topology_description, ) -from pymongo.errors import ( - ConnectionFailure, - InvalidOperation, - NetworkTimeout, - NotPrimaryError, - OperationFailure, - PyMongoError, - ServerSelectionTimeoutError, - WriteError, -) -from pymongo.lock import _ACondition, _ALock, _create_lock if TYPE_CHECKING: from bson import ObjectId from pymongo.asynchronous.settings import TopologySettings - from pymongo.asynchronous.typings import ClusterTime, _Address 
+ from pymongo.typings import ClusterTime, _Address _IS_SYNC = False @@ -790,8 +791,8 @@ async def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None # Default error code if one does not exist. default = 10107 if isinstance(error, NotPrimaryError) else None err_code = error.details.get("code", default) # type: ignore[union-attr] - if err_code in helpers_constants._NOT_PRIMARY_CODES: - is_shutting_down = err_code in helpers_constants._SHUTDOWN_CODES + if err_code in helpers_shared._NOT_PRIMARY_CODES: + is_shutting_down = err_code in helpers_shared._SHUTDOWN_CODES # Mark server Unknown, clear the pool, and request check. if not self._settings.load_balanced: await self._process_change(ServerDescription(address, error=error)) diff --git a/pymongo/asynchronous/topology_description.py b/pymongo/asynchronous/topology_description.py deleted file mode 100644 index ce7aff7f51..0000000000 --- a/pymongo/asynchronous/topology_description.py +++ /dev/null @@ -1,678 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Represent a deployment of MongoDB servers.""" -from __future__ import annotations - -from random import sample -from typing import ( - Any, - Callable, - List, - Mapping, - MutableMapping, - NamedTuple, - Optional, - cast, -) - -from bson.min_key import MinKey -from bson.objectid import ObjectId -from pymongo.asynchronous import common -from pymongo.asynchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode -from pymongo.asynchronous.server_description import ServerDescription -from pymongo.asynchronous.server_selectors import Selection -from pymongo.asynchronous.typings import _Address -from pymongo.errors import ConfigurationError -from pymongo.server_type import SERVER_TYPE - -_IS_SYNC = False - - -# Enumeration for various kinds of MongoDB cluster topologies. -class _TopologyType(NamedTuple): - Single: int - ReplicaSetNoPrimary: int - ReplicaSetWithPrimary: int - Sharded: int - Unknown: int - LoadBalanced: int - - -TOPOLOGY_TYPE = _TopologyType(*range(6)) - -# Topologies compatible with SRV record polling. -SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) - - -_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] - - -class TopologyDescription: - def __init__( - self, - topology_type: int, - server_descriptions: dict[_Address, ServerDescription], - replica_set_name: Optional[str], - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], - topology_settings: Any, - ) -> None: - """Representation of a deployment of MongoDB servers. 
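-
-        Instances are immutable snapshots: SDAM updates produce new
-        TopologyDescription objects rather than mutating an existing one.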
-
-        :param topology_type: initial type
-        :param server_descriptions: dict of (address, ServerDescription) for
-            all seeds
-        :param replica_set_name: replica set name or None
-        :param max_set_version: greatest setVersion seen from a primary, or None
-        :param max_election_id: greatest electionId seen from a primary, or None
-        :param topology_settings: a TopologySettings
-        """
-        self._topology_type = topology_type
-        self._replica_set_name = replica_set_name
-        self._server_descriptions = server_descriptions
-        self._max_set_version = max_set_version
-        self._max_election_id = max_election_id
-
-        # The heartbeat_frequency is used in staleness estimates.
-        self._topology_settings = topology_settings
-
-        # Is PyMongo compatible with all servers' wire protocols?
-        self._incompatible_err = None
-        if self._topology_type != TOPOLOGY_TYPE.LoadBalanced:
-            self._init_incompatible_err()
-
-        # Server Discovery And Monitoring Spec: Whenever a client updates the
-        # TopologyDescription from a hello response, it MUST set
-        # TopologyDescription.logicalSessionTimeoutMinutes to the smallest
-        # logicalSessionTimeoutMinutes value among ServerDescriptions of all
-        # data-bearing server types. If any have a null
-        # logicalSessionTimeoutMinutes, then
-        # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null.
-        readable_servers = self.readable_servers
-        if not readable_servers:
-            self._ls_timeout_minutes = None
-        elif any(s.logical_session_timeout_minutes is None for s in readable_servers):
-            self._ls_timeout_minutes = None
-        else:
-            self._ls_timeout_minutes = min(  # type: ignore[type-var]
-                s.logical_session_timeout_minutes for s in readable_servers
-            )
-
-    def _init_incompatible_err(self) -> None:
-        """Internal compatibility check for non-load balanced topologies."""
-        for s in self._server_descriptions.values():
-            if not s.is_server_type_known:
-                continue
-
-            # s.min/max_wire_version is the server's wire protocol.
-            # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports.
-            server_too_new = (
-                # Server too new.
-                s.min_wire_version is not None
-                and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION
-            )
-
-            server_too_old = (
-                # Server too old.
-                s.max_wire_version is not None
-                and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION
-            )
-
-            if server_too_new:
-                self._incompatible_err = (
-                    "Server at %s:%d requires wire version %d, but this "  # type: ignore
-                    "version of PyMongo only supports up to %d."
-                    % (
-                        s.address[0],
-                        s.address[1] or 0,
-                        s.min_wire_version,
-                        common.MAX_SUPPORTED_WIRE_VERSION,
-                    )
-                )
-
-            elif server_too_old:
-                self._incompatible_err = (
-                    "Server at %s:%d reports wire version %d, but this "  # type: ignore
-                    "version of PyMongo requires at least %d (MongoDB %s)."
-                    % (
-                        s.address[0],
-                        s.address[1] or 0,
-                        s.max_wire_version,
-                        common.MIN_SUPPORTED_WIRE_VERSION,
-                        common.MIN_SUPPORTED_SERVER_VERSION,
-                    )
-                )
-
-                break
-
-    def check_compatible(self) -> None:
-        """Raise ConfigurationError if any server is incompatible.
-
-        A server is incompatible if its wire protocol version range does not
-        overlap with PyMongo's.
- """ - if self._incompatible_err: - raise ConfigurationError(self._incompatible_err) - - def has_server(self, address: _Address) -> bool: - return address in self._server_descriptions - - def reset_server(self, address: _Address) -> TopologyDescription: - """A copy of this description, with one server marked Unknown.""" - unknown_sd = self._server_descriptions[address].to_unknown() - return updated_topology_description(self, unknown_sd) - - def reset(self) -> TopologyDescription: - """A copy of this description, with all servers marked Unknown.""" - if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - else: - topology_type = self._topology_type - - # The default ServerDescription's type is Unknown. - sds = {address: ServerDescription(address) for address in self._server_descriptions} - - return TopologyDescription( - topology_type, - sds, - self._replica_set_name, - self._max_set_version, - self._max_election_id, - self._topology_settings, - ) - - def server_descriptions(self) -> dict[_Address, ServerDescription]: - """dict of (address, - :class:`~pymongo.server_description.ServerDescription`). - """ - return self._server_descriptions.copy() - - @property - def topology_type(self) -> int: - """The type of this topology.""" - return self._topology_type - - @property - def topology_type_name(self) -> str: - """The topology type as a human readable string. - - .. versionadded:: 3.4 - """ - return TOPOLOGY_TYPE._fields[self._topology_type] - - @property - def replica_set_name(self) -> Optional[str]: - """The replica set name.""" - return self._replica_set_name - - @property - def max_set_version(self) -> Optional[int]: - """Greatest setVersion seen from a primary, or None.""" - return self._max_set_version - - @property - def max_election_id(self) -> Optional[ObjectId]: - """Greatest electionId seen from a primary, or None.""" - return self._max_election_id - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - """Minimum logical session timeout, or None.""" - return self._ls_timeout_minutes - - @property - def known_servers(self) -> list[ServerDescription]: - """List of Servers of types besides Unknown.""" - return [s for s in self._server_descriptions.values() if s.is_server_type_known] - - @property - def has_known_servers(self) -> bool: - """Whether there are any Servers of types besides Unknown.""" - return any(s for s in self._server_descriptions.values() if s.is_server_type_known) - - @property - def readable_servers(self) -> list[ServerDescription]: - """List of readable Servers.""" - return [s for s in self._server_descriptions.values() if s.is_readable] - - @property - def common_wire_version(self) -> Optional[int]: - """Minimum of all servers' max wire versions, or None.""" - servers = self.known_servers - if servers: - return min(s.max_wire_version for s in self.known_servers) - - return None - - @property - def heartbeat_frequency(self) -> int: - return self._topology_settings.heartbeat_frequency - - @property - def srv_max_hosts(self) -> int: - return self._topology_settings._srv_max_hosts - - def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]: - if not selection: - return [] - round_trip_times: list[float] = [] - for server in selection.server_descriptions: - if server.round_trip_time is None: - config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}" - raise 
ConfigurationError(config_err_msg)
-            round_trip_times.append(server.round_trip_time)
-        # Round trip time in seconds.
-        fastest = min(round_trip_times)
-        threshold = self._topology_settings.local_threshold_ms / 1000.0
-        return [
-            s
-            for s in selection.server_descriptions
-            if (cast(float, s.round_trip_time) - fastest) <= threshold
-        ]
-
-    def apply_selector(
-        self,
-        selector: Any,
-        address: Optional[_Address] = None,
-        custom_selector: Optional[_ServerSelector] = None,
-    ) -> list[ServerDescription]:
-        """List of servers matching the provided selector(s).
-
-        :param selector: a callable that takes a Selection as input and returns
-            a Selection as output. For example, an instance of a read
-            preference from :mod:`~pymongo.read_preferences`.
-        :param address: A server address to select.
-        :param custom_selector: A callable that augments server
-            selection rules. Accepts a list of
-            :class:`~pymongo.server_description.ServerDescription` objects and
-            returns a list of server descriptions that should be considered
-            suitable for the desired operation.
-
-        .. versionadded:: 3.4
-        """
-        if getattr(selector, "min_wire_version", 0):
-            common_wv = self.common_wire_version
-            if common_wv and common_wv < selector.min_wire_version:
-                raise ConfigurationError(
-                    "%s requires min wire version %d, but topology's min"
-                    " wire version is %d" % (selector, selector.min_wire_version, common_wv)
-                )
-
-        if isinstance(selector, _AggWritePref):
-            selector.selection_hook(self)
-
-        if self.topology_type == TOPOLOGY_TYPE.Unknown:
-            return []
-        elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced):
-            # Ignore selectors for standalone and load balancer mode.
-            return self.known_servers
-        if address:
-            # Ignore selectors when explicit address is requested.
-            description = self.server_descriptions().get(address)
-            return [description] if description else []
-
-        selection = Selection.from_topology_description(self)
-        # Ignore read preference for sharded clusters.
-        if self.topology_type != TOPOLOGY_TYPE.Sharded:
-            selection = selector(selection)
-
-        # Apply custom selector followed by localThresholdMS.
-        if custom_selector is not None and selection:
-            selection = selection.with_server_descriptions(
-                custom_selector(selection.server_descriptions)
-            )
-        return self._apply_local_threshold(selection)
-
-    def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool:
-        """Does this topology have any readable servers available matching the
-        given read preference?
-
-        :param read_preference: an instance of a read preference from
-            :mod:`~pymongo.read_preferences`. Defaults to
-            :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`.
-
-        .. note:: When connected directly to a single server this method
-            always returns ``True``.
-
-        .. versionadded:: 3.4
-        """
-        common.validate_read_preference("read_preference", read_preference)
-        return any(self.apply_selector(read_preference))
-
-    def has_writable_server(self) -> bool:
-        """Does this topology have a writable server available?
-
-        .. note:: When connected directly to a single server this method
-            always returns ``True``.
-
-        .. versionadded:: 3.4
-        """
-        return self.has_readable_server(ReadPreference.PRIMARY)
-
-    def __repr__(self) -> str:
-        # Sort the servers by address.
- servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) - return "<{} id: {}, topology_type: {}, servers: {!r}>".format( - self.__class__.__name__, - self._topology_settings._topology_id, - self.topology_type_name, - servers, - ) - - -# If topology type is Unknown and we receive a hello response, what should -# the new topology type be? -_SERVER_TYPE_TO_TOPOLOGY_TYPE = { - SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, - SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, - SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - # Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out. -} - - -def updated_topology_description( - topology_description: TopologyDescription, server_description: ServerDescription -) -> TopologyDescription: - """Return an updated copy of a TopologyDescription. - - :param topology_description: the current TopologyDescription - :param server_description: a new ServerDescription that resulted from - a hello call - - Called after attempting (successfully or not) to call hello on the - server at server_description.address. Does not modify topology_description. - """ - address = server_description.address - - # These values will be updated, if necessary, to form the new - # TopologyDescription. - topology_type = topology_description.topology_type - set_name = topology_description.replica_set_name - max_set_version = topology_description.max_set_version - max_election_id = topology_description.max_election_id - server_type = server_description.server_type - - # Don't mutate the original dict of server descriptions; copy it. - sds = topology_description.server_descriptions() - - # Replace this server's description with the new one. - sds[address] = server_description - - if topology_type == TOPOLOGY_TYPE.Single: - # Set server type to Unknown if replica set name does not match. - if set_name is not None and set_name != server_description.replica_set_name: - error = ConfigurationError( - "client is configured to connect to a replica set named " - "'{}' but this node belongs to a set named '{}'".format( - set_name, server_description.replica_set_name - ) - ) - sds[address] = server_description.to_unknown(error=error) - # Single type never changes. - return TopologyDescription( - TOPOLOGY_TYPE.Single, - sds, - set_name, - max_set_version, - max_election_id, - topology_description._topology_settings, - ) - - if topology_type == TOPOLOGY_TYPE.Unknown: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer): - if len(topology_description._topology_settings.seeds) == 1: - topology_type = TOPOLOGY_TYPE.Single - else: - # Remove standalone from Topology when given multiple seeds. 
-                sds.pop(address)
-        elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost):
-            topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type]
-
-    if topology_type == TOPOLOGY_TYPE.Sharded:
-        if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown):
-            sds.pop(address)
-
-    elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary:
-        if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
-            sds.pop(address)
-
-        elif server_type == SERVER_TYPE.RSPrimary:
-            (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
-                sds, set_name, server_description, max_set_version, max_election_id
-            )
-
-        elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
-            topology_type, set_name = _update_rs_no_primary_from_member(
-                sds, set_name, server_description
-            )
-
-    elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary:
-        if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos):
-            sds.pop(address)
-            topology_type = _check_has_primary(sds)
-
-        elif server_type == SERVER_TYPE.RSPrimary:
-            (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary(
-                sds, set_name, server_description, max_set_version, max_election_id
-            )
-
-        elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther):
-            topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description)
-
-        else:
-            # Server type is Unknown or RSGhost: did we just lose the primary?
-            topology_type = _check_has_primary(sds)
-
-    # Return updated copy.
-    return TopologyDescription(
-        topology_type,
-        sds,
-        set_name,
-        max_set_version,
-        max_election_id,
-        topology_description._topology_settings,
-    )
-
-
-def _updated_topology_description_srv_polling(
-    topology_description: TopologyDescription, seedlist: list[tuple[str, Any]]
-) -> TopologyDescription:
-    """Return an updated copy of a TopologyDescription.
-
-    :param topology_description: the current TopologyDescription
-    :param seedlist: a list of new seeds obtained from polling the SRV record
-    """
-    assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES
-    # Create a copy of the server descriptions.
-    sds = topology_description.server_descriptions()
-
-    # If seeds haven't changed, don't do anything.
-    if set(sds.keys()) == set(seedlist):
-        return topology_description
-
-    # Remove SDs corresponding to servers no longer part of the SRV record.
-    for address in list(sds.keys()):
-        if address not in seedlist:
-            sds.pop(address)
-
-    if topology_description.srv_max_hosts != 0:
-        new_hosts = set(seedlist) - set(sds.keys())
-        n_to_add = topology_description.srv_max_hosts - len(sds)
-        if n_to_add > 0:
-            seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts)))
-        else:
-            seedlist = []
-    # Add SDs corresponding to servers recently added to the SRV record.
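-    # When srvMaxHosts caps the pool, seedlist was reduced above to a random
-    # sample of the newly discovered hosts, at most the remaining capacity.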
- for address in seedlist: - if address not in sds: - sds[address] = ServerDescription(address) - return TopologyDescription( - topology_description.topology_type, - sds, - topology_description.replica_set_name, - topology_description.max_set_version, - topology_description.max_election_id, - topology_description._topology_settings, - ) - - -def _update_rs_from_primary( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], -) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]: - """Update topology description from a primary's hello response. - - Pass in a dict of ServerDescriptions, current replica set name, the - ServerDescription we are processing, and the TopologyDescription's - max_set_version and max_election_id if any. - - Returns (new topology type, new replica_set_name, new max_set_version, - new max_election_id). - """ - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - # We found a primary but it doesn't have the replica_set_name - # provided by the user. - sds.pop(server_description.address) - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - - if server_description.max_wire_version is None or server_description.max_wire_version < 17: - new_election_tuple: tuple = (server_description.set_version, server_description.election_id) - max_election_tuple: tuple = (max_set_version, max_election_id) - if None not in new_election_tuple: - if None not in max_election_tuple and new_election_tuple < max_election_tuple: - # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - max_election_id = server_description.election_id - - if server_description.set_version is not None and ( - max_set_version is None or server_description.set_version > max_set_version - ): - max_set_version = server_description.set_version - else: - new_election_tuple = server_description.election_id, server_description.set_version - max_election_tuple = max_election_id, max_set_version - new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple) - max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple) - if new_election_safe < max_election_safe: - # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - else: - max_election_id = server_description.election_id - max_set_version = server_description.set_version - - # We've heard from the primary. Is it the same primary as before? - for server in sds.values(): - if ( - server.server_type is SERVER_TYPE.RSPrimary - and server.address != server_description.address - ): - # Reset old primary's type to Unknown. - sds[server.address] = server.to_unknown() - - # There can be only one prior primary. - break - - # Discover new hosts from this primary's response. - for new_address in server_description.all_hosts: - if new_address not in sds: - sds[new_address] = ServerDescription(new_address) - - # Remove hosts not in the response. 
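-    # The primary's host list is authoritative here: members it no longer
-    # reports are dropped from the topology entirely.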
- for addr in set(sds) - server_description.all_hosts: - sds.pop(addr) - - # If the host list differs from the seed list, we may not have a primary - # after all. - return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) - - -def _update_rs_with_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> int: - """RS with known primary. Process a response from a non-primary. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns new topology type. - """ - assert replica_set_name is not None - - if replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - elif server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - # Had this member been the primary? - return _check_has_primary(sds) - - -def _update_rs_no_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> tuple[int, Optional[str]]: - """RS without known primary. Update from a non-primary's response. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns (new topology type, new replica_set_name). - """ - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - return topology_type, replica_set_name - - # This isn't the primary's response, so don't remove any servers - # it doesn't report. Only add new servers. - for address in server_description.all_hosts: - if address not in sds: - sds[address] = ServerDescription(address) - - if server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - return topology_type, replica_set_name - - -def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: - """Current topology type is ReplicaSetWithPrimary. Is primary still known? - - Pass in a dict of ServerDescriptions. - - Returns new topology type. - """ - for s in sds.values(): - if s.server_type == SERVER_TYPE.RSPrimary: - return TOPOLOGY_TYPE.ReplicaSetWithPrimary - else: # noqa: PLW0120 - return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/pymongo/asynchronous/uri_parser.py b/pymongo/asynchronous/uri_parser.py deleted file mode 100644 index b5fde6c30c..0000000000 --- a/pymongo/asynchronous/uri_parser.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2011-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. 
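Taken together, the replica-set helpers above reduce to a single invariant: a topology stays ReplicaSetWithPrimary only while some known member still reports itself as RSPrimary. A minimal, self-contained sketch of that rule, for illustration only (the helper name check_has_primary here is hypothetical; only the public SERVER_TYPE and TOPOLOGY_TYPE constants are assumed):

    from pymongo.server_type import SERVER_TYPE
    from pymongo.topology_description import TOPOLOGY_TYPE

    def check_has_primary(server_types):
        # Scan the known members; any primary keeps the WithPrimary type.
        if any(t == SERVER_TYPE.RSPrimary for t in server_types):
            return TOPOLOGY_TYPE.ReplicaSetWithPrimary
        return TOPOLOGY_TYPE.ReplicaSetNoPrimary

    # A three-member set that just lost its primary:
    members = [SERVER_TYPE.RSSecondary, SERVER_TYPE.RSSecondary, SERVER_TYPE.Unknown]
    assert check_has_primary(members) == TOPOLOGY_TYPE.ReplicaSetNoPrimary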
- - -"""Tools to parse and validate a MongoDB URI.""" -from __future__ import annotations - -import re -import sys -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sized, - Union, - cast, -) -from urllib.parse import unquote_plus - -from pymongo.asynchronous.client_options import _parse_ssl_options -from pymongo.asynchronous.common import ( - INTERNAL_URI_OPTION_NAME_MAP, - SRV_SERVICE_NAME, - URI_OPTIONS_DEPRECATION_MAP, - _CaseInsensitiveDictionary, - get_validated_options, -) -from pymongo.asynchronous.srv_resolver import _have_dnspython, _SrvResolver -from pymongo.asynchronous.typings import _Address -from pymongo.errors import ConfigurationError, InvalidURI - -if TYPE_CHECKING: - from pymongo.pyopenssl_context import SSLContext - -_IS_SYNC = False -SCHEME = "mongodb://" -SCHEME_LEN = len(SCHEME) -SRV_SCHEME = "mongodb+srv://" -SRV_SCHEME_LEN = len(SRV_SCHEME) -DEFAULT_PORT = 27017 - - -def _unquoted_percent(s: str) -> bool: - """Check for unescaped percent signs. - - :param s: A string. `s` can have things like '%25', '%2525', - and '%E2%85%A8' but cannot have unquoted percent like '%foo'. - """ - for i in range(len(s)): - if s[i] == "%": - sub = s[i : i + 3] - # If unquoting yields the same string this means there was an - # unquoted %. - if unquote_plus(sub) == sub: - return True - return False - - -def parse_userinfo(userinfo: str) -> tuple[str, str]: - """Validates the format of user information in a MongoDB URI. - Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", - "]", "@") as per RFC 3986 must be escaped. - - Returns a 2-tuple containing the unescaped username followed - by the unescaped password. - - :param userinfo: A string of the form : - """ - if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): - raise InvalidURI( - "Username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - - user, _, passwd = userinfo.partition(":") - # No password is expected with GSSAPI authentication. - if not user: - raise InvalidURI("The empty string is not valid username.") - - return unquote_plus(user), unquote_plus(passwd) - - -def parse_ipv6_literal_host( - entity: str, default_port: Optional[int] -) -> tuple[str, Optional[Union[str, int]]]: - """Validates an IPv6 literal host:port string. - - Returns a 2-tuple of IPv6 literal followed by port where - port is default_port if it wasn't specified in entity. - - :param entity: A string that represents an IPv6 literal enclosed - in braces (e.g. '[::1]' or '[::1]:27017'). - :param default_port: The port number to use when one wasn't - specified in entity. - """ - if entity.find("]") == -1: - raise ValueError( - "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." - ) - i = entity.find("]:") - if i == -1: - return entity[1:-1], default_port - return entity[1:i], entity[i + 2 :] - - -def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: - """Validates a host string - - Returns a 2-tuple of host followed by port where port is default_port - if it wasn't specified in the string. - - :param entity: A host or host:port string where host could be a - hostname or IP address. - :param default_port: The port number to use when one wasn't - specified in entity. 
- """ - host = entity - port: Optional[Union[str, int]] = default_port - if entity[0] == "[": - host, port = parse_ipv6_literal_host(entity, default_port) - elif entity.endswith(".sock"): - return entity, default_port - elif entity.find(":") != -1: - if entity.count(":") > 1: - raise ValueError( - "Reserved characters such as ':' must be " - "escaped according RFC 2396. An IPv6 " - "address literal must be enclosed in '[' " - "and ']' according to RFC 2732." - ) - host, port = host.split(":", 1) - if isinstance(port, str): - if not port.isdigit() or int(port) > 65535 or int(port) <= 0: - raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}") - port = int(port) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -# Options whose values are implicitly determined by tlsInsecure. -_IMPLICIT_TLSINSECURE_OPTS = { - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", -} - - -def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: - """Helper method for split_options which creates the options dict. - Also handles the creation of a list for the URI tag_sets/ - readpreferencetags portion, and the use of a unicode options string. - """ - options = _CaseInsensitiveDictionary() - for uriopt in opts.split(delim): - key, value = uriopt.split("=") - if key.lower() == "readpreferencetags": - options.setdefault(key, []).append(value) - else: - if key in options: - warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) - if key.lower() == "authmechanismproperties": - val = value - else: - val = unquote_plus(value) - options[key] = val - - return options - - -def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Raise appropriate errors when conflicting TLS options are present in - the options dictionary. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Implicitly defined options must not be explicitly specified. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - if opt in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) - ) - - # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. - tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") - if tlsallowinvalidcerts is not None: - if "tlsdisableocspendpointcheck" in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg - % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) - ) - if tlsallowinvalidcerts is True: - options["tlsdisableocspendpointcheck"] = True - - # Handle co-occurence of CRL and OCSP-related options. - tlscrlfile = options.get("tlscrlfile") - if tlscrlfile is not None: - for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): - if options.get(opt) is True: - err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." 
-                raise InvalidURI(err_msg % (opt,))
-
-    if "ssl" in options and "tls" in options:
-
-        def truth_value(val: Any) -> Any:
-            if val in ("true", "false"):
-                return val == "true"
-            if isinstance(val, bool):
-                return val
-            return val
-
-        if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
-            err_msg = "Can not specify conflicting values for URI options %s and %s."
-            raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
-
-    return options
-
-
-def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
-    """Issue appropriate warnings when deprecated options are present in the
-    options dictionary. Removes deprecated option key, value pairs if the
-    options dictionary is found to also have the renamed option.
-
-    :param options: Instance of _CaseInsensitiveDictionary containing
-        MongoDB URI options.
-    """
-    for optname in list(options):
-        if optname in URI_OPTIONS_DEPRECATION_MAP:
-            mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
-            if mode == "renamed":
-                newoptname = message
-                if newoptname in options:
-                    warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
-                    warnings.warn(
-                        warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
-                        DeprecationWarning,
-                        stacklevel=2,
-                    )
-                    options.pop(optname)
-                    continue
-                warn_msg = "Option '%s' is deprecated, use '%s' instead."
-                warnings.warn(
-                    warn_msg % (options.cased_key(optname), newoptname),
-                    DeprecationWarning,
-                    stacklevel=2,
-                )
-            elif mode == "removed":
-                warn_msg = "Option '%s' is deprecated. %s."
-                warnings.warn(
-                    warn_msg % (options.cased_key(optname), message),
-                    DeprecationWarning,
-                    stacklevel=2,
-                )
-
-    return options
-
-
-def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
-    """Normalizes option names in the options dictionary by converting them to
-    their internally-used names.
-
-    :param options: Instance of _CaseInsensitiveDictionary containing
-        MongoDB URI options.
-    """
-    # Expand the tlsInsecure option.
-    tlsinsecure = options.get("tlsinsecure")
-    if tlsinsecure is not None:
-        for opt in _IMPLICIT_TLSINSECURE_OPTS:
-            # Implicit options are logically the same as tlsInsecure.
-            options[opt] = tlsinsecure
-
-    for optname in list(options):
-        intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
-        if intname is not None:
-            options[intname] = options.pop(optname)
-
-    return options
-
-
-def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
-    """Validates and normalizes options passed in a MongoDB URI.
-
-    Returns a new dictionary of validated and normalized options. If warn is
-    False then errors will be thrown for invalid options, otherwise they will
-    be ignored and a warning will be issued.
-
-    :param opts: A dict of MongoDB URI options.
-    :param warn: If ``True`` then warnings will be logged and
-        invalid options will be ignored. Otherwise invalid options will
-        cause errors.
-    """
-    return get_validated_options(opts, warn)
-
-
-def split_options(
-    opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
-) -> MutableMapping[str, Any]:
-    """Takes the options portion of a MongoDB URI, validates each option
-    and returns the options in a dictionary.
-
-    :param opts: A string representing MongoDB URI options.
-    :param validate: If ``True`` (the default), validate and normalize all
-        options.
-    :param warn: If ``False`` (the default), suppress all warnings raised
-        during validation of options.
-    :param normalize: If ``True`` (the default), renames all options to their
-        internally-used names.
-    """
-    and_idx = opts.find("&")
-    semi_idx = opts.find(";")
-    try:
-        if and_idx >= 0 and semi_idx >= 0:
-            raise InvalidURI("Can not mix '&' and ';' for option separators.")
-        elif and_idx >= 0:
-            options = _parse_options(opts, "&")
-        elif semi_idx >= 0:
-            options = _parse_options(opts, ";")
-        elif opts.find("=") != -1:
-            options = _parse_options(opts, None)
-        else:
-            raise ValueError
-    except ValueError:
-        raise InvalidURI("MongoDB URI options are key=value pairs.") from None
-
-    options = _handle_security_options(options)
-
-    options = _handle_option_deprecations(options)
-
-    if normalize:
-        options = _normalize_options(options)
-
-    if validate:
-        options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
-        if options.get("authsource") == "":
-            raise InvalidURI("the authSource database cannot be an empty string")
-
-    return options
-
-
-def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
-    """Takes a string of the form host1[:port],host2[:port]... and
-    splits it into (host, port) tuples. If [:port] isn't present the
-    default_port is used.
-
-    Returns a list of 2-tuples containing the host name (or IP) followed by
-    port number.
-
-    :param hosts: A string of the form host1[:port],host2[:port],...
-    :param default_port: The port number to use when one wasn't specified
-        for a host.
-    """
-    nodes = []
-    for entity in hosts.split(","):
-        if not entity:
-            raise ConfigurationError("Empty host (or extra comma in host list).")
-        port = default_port
-        # Unix socket entities don't have ports
-        if entity.endswith(".sock"):
-            port = None
-        nodes.append(parse_host(entity, port))
-    return nodes
-
-
-# Prohibited characters in database name. DB names also can't have ".", but for
-# backward-compat we allow "db.collection" in URI.
-_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
-
-_ALLOWED_TXT_OPTS = frozenset(
-    ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
-)
-
-
-def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
-    # Ensure directConnection was not True if there are multiple seeds.
-    if len(nodes) > 1 and options.get("directconnection"):
-        raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
-
-    if options.get("loadbalanced"):
-        if len(nodes) > 1:
-            raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
-        if options.get("directconnection"):
-            raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
-        if options.get("replicaset"):
-            raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
-
-
-def parse_uri(
-    uri: str,
-    default_port: Optional[int] = DEFAULT_PORT,
-    validate: bool = True,
-    warn: bool = False,
-    normalize: bool = True,
-    connect_timeout: Optional[float] = None,
-    srv_service_name: Optional[str] = None,
-    srv_max_hosts: Optional[int] = None,
-) -> dict[str, Any]:
-    """Parse and validate a MongoDB URI.
-
-    Returns a dict of the form::
-
-        {
-            'nodelist': <list of (host, port) tuples>,
-            'username': <username> or None,
-            'password': <password> or None,
-            'database': <database name> or None,
-            'collection': <collection name> or None,
-            'options': <dict of MongoDB URI options>,
-            'fqdn': <fqdn of the MongoDB+SRV URI> or None
-        }
-
-    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
-    to build nodelist and options.
-
-    :param uri: The MongoDB URI to parse.
-    :param default_port: The port number to use when one wasn't specified
-        for a host in the URI.
-    :param validate: If ``True`` (the default), validate and
-        normalize all options. Default: ``True``.
-    :param warn: When validating, if ``True`` then warn the user and
-        ignore any invalid options or values. If ``False``,
-        validation will error when options are unsupported or values are
-        invalid. Default: ``False``.
-    :param normalize: If ``True``, convert names of URI options
-        to their internally-used names. Default: ``True``.
-    :param connect_timeout: The maximum time in milliseconds to
-        wait for a response from the DNS server.
-    :param srv_service_name: A custom SRV service name
-
-    .. versionchanged:: 4.6
-        The delimiting slash (``/``) between hosts and connection options is now optional.
-        For example, "mongodb://example.com?tls=true" is now a valid URI.
-
-    .. versionchanged:: 4.0
-        To better follow RFC 3986, unquoted percent signs ("%") are no longer
-        supported.
-
-    .. versionchanged:: 3.9
-        Added the ``normalize`` parameter.
-
-    .. versionchanged:: 3.6
-        Added support for mongodb+srv:// URIs.
-
-    .. versionchanged:: 3.5
-        Return the original value of the ``readPreference`` MongoDB URI option
-        instead of the validated read preference mode.
-
-    .. versionchanged:: 3.1
-        ``warn`` added so invalid options can be ignored.
-    """
-    if uri.startswith(SCHEME):
-        is_srv = False
-        scheme_free = uri[SCHEME_LEN:]
-    elif uri.startswith(SRV_SCHEME):
-        if not _have_dnspython():
-            python_path = sys.executable or "python"
-            raise ConfigurationError(
-                'The "dnspython" module must be '
-                "installed to use mongodb+srv:// URIs. "
-                "To fix this error install pymongo again:\n "
-                "%s -m pip install pymongo>=4.3" % (python_path)
-            )
-        is_srv = True
-        scheme_free = uri[SRV_SCHEME_LEN:]
-    else:
-        raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'")
-
-    if not scheme_free:
-        raise InvalidURI("Must provide at least one hostname or IP.")
-
-    user = None
-    passwd = None
-    dbase = None
-    collection = None
-    options = _CaseInsensitiveDictionary()
-
-    host_plus_db_part, _, opts = scheme_free.partition("?")
-    if "/" in host_plus_db_part:
-        host_part, _, dbase = host_plus_db_part.partition("/")
-    else:
-        host_part = host_plus_db_part
-
-    if dbase:
-        dbase = unquote_plus(dbase)
-        if "." in dbase:
-            dbase, collection = dbase.split(".", 1)
-        if _BAD_DB_CHARS.search(dbase):
-            raise InvalidURI('Bad database name "%s"' % dbase)
-    else:
-        dbase = None
-
-    if opts:
-        options.update(split_options(opts, validate, warn, normalize))
-    if srv_service_name is None:
-        srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME)
-    if "@" in host_part:
-        userinfo, _, hosts = host_part.rpartition("@")
-        user, passwd = parse_userinfo(userinfo)
-    else:
-        hosts = host_part
-
-    if "/" in hosts:
-        raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part)
-
-    hosts = unquote_plus(hosts)
-    fqdn = None
-    srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts")
-    if is_srv:
-        if options.get("directConnection"):
-            raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs")
-        nodes = split_hosts(hosts, default_port=None)
-        if len(nodes) != 1:
-            raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname")
-        fqdn, port = nodes[0]
-        if port is not None:
-            raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number")
-
-        # Use the connection timeout. connectTimeoutMS passed as a keyword
-        # argument overrides the same option passed in the connection string.
- connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) - nodes = dns_resolver.get_hosts() - dns_options = dns_resolver.get_options() - if dns_options: - parsed_dns_options = split_options(dns_options, validate, warn, normalize) - if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: - raise ConfigurationError( - "Only authSource, replicaSet, and loadBalanced are supported from DNS" - ) - for opt, val in parsed_dns_options.items(): - if opt not in options: - options[opt] = val - if options.get("loadBalanced") and srv_max_hosts: - raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") - if options.get("replicaSet") and srv_max_hosts: - raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") - if "tls" not in options and "ssl" not in options: - options["tls"] = True if validate else "true" - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - -def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: - """Parse KMS TLS connection options.""" - if not kms_tls_options: - return {} - if not isinstance(kms_tls_options, dict): - raise TypeError("kms_tls_options must be a dict") - contexts = {} - for provider, options in kms_tls_options.items(): - if not isinstance(options, dict): - raise TypeError(f'kms_tls_options["{provider}"] must be a dict') - options.setdefault("tls", True) - opts = _CaseInsensitiveDictionary(options) - opts = _handle_security_options(opts) - opts = _normalize_options(opts) - opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) - ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) - if ssl_context is None: - raise ConfigurationError("TLS is required for KMS providers") - if allow_invalid_hostnames: - raise ConfigurationError("Insecure TLS options prohibited") - - for n in [ - "tlsInsecure", - "tlsAllowInvalidCertificates", - "tlsAllowInvalidHostnames", - "tlsDisableCertificateRevocationCheck", - ]: - if n in opts: - raise ConfigurationError(f"Insecure TLS options prohibited: {n}") - contexts[provider] = ssl_context - return contexts - - -if __name__ == "__main__": - import pprint - - try: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - except InvalidURI as exc: - print(exc) # noqa: T201 - sys.exit(0) diff --git a/pymongo/auth.py b/pymongo/auth.py index 13302ae5db..a65113841d 100644 --- a/pymongo/auth.py +++ b/pymongo/auth.py @@ -15,6 +15,7 @@ """Re-import of synchronous Auth API for compatibility.""" from __future__ import annotations +from pymongo.auth_shared import * # noqa: F403 from pymongo.synchronous.auth import * # noqa: F403 from pymongo.synchronous.auth import __doc__ as original_doc diff --git a/pymongo/auth_oidc_shared.py b/pymongo/auth_oidc_shared.py new file mode 100644 index 0000000000..5e3603fa31 --- /dev/null +++ b/pymongo/auth_oidc_shared.py @@ -0,0 +1,118 @@ +# Copyright 2024-present MongoDB, Inc. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you
+# may not use this file except in compliance with the License. You
+# may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
+
+
+"""Constants, types, and classes shared across OIDC auth implementations."""
+from __future__ import annotations
+
+import abc
+import os
+from dataclasses import dataclass, field
+from typing import Optional
+from urllib.parse import quote
+
+from pymongo._azure_helpers import _get_azure_response
+from pymongo._gcp_helpers import _get_gcp_response
+
+
+@dataclass
+class OIDCIdPInfo:
+    issuer: str
+    clientId: Optional[str] = field(default=None)
+    requestScopes: Optional[list[str]] = field(default=None)
+
+
+@dataclass
+class OIDCCallbackContext:
+    timeout_seconds: float
+    username: str
+    version: int
+    refresh_token: Optional[str] = field(default=None)
+    idp_info: Optional[OIDCIdPInfo] = field(default=None)
+
+
+@dataclass
+class OIDCCallbackResult:
+    access_token: str
+    expires_in_seconds: Optional[float] = field(default=None)
+    refresh_token: Optional[str] = field(default=None)
+
+
+class OIDCCallback(abc.ABC):
+    """A base class for defining OIDC callbacks."""
+
+    @abc.abstractmethod
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        """Fetch an OIDC access token and return it in an OIDCCallbackResult."""
+
+
+@dataclass
+class _OIDCProperties:
+    callback: Optional[OIDCCallback] = field(default=None)
+    human_callback: Optional[OIDCCallback] = field(default=None)
+    environment: Optional[str] = field(default=None)
+    allowed_hosts: list[str] = field(default_factory=list)
+    token_resource: Optional[str] = field(default=None)
+    username: str = ""
+
+
+"""Mechanism properties for MONGODB-OIDC authentication."""
+
+TOKEN_BUFFER_MINUTES = 5
+HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60
+CALLBACK_VERSION = 1
+MACHINE_CALLBACK_TIMEOUT_SECONDS = 60
+TIME_BETWEEN_CALLS_SECONDS = 0.1
+
+
+class _OIDCTestCallback(OIDCCallback):
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        token_file = os.environ.get("OIDC_TOKEN_FILE")
+        if not token_file:
+            raise RuntimeError(
+                'MONGODB-OIDC with a "test" provider requires "OIDC_TOKEN_FILE" to be set'
+            )
+        with open(token_file) as fid:
+            return OIDCCallbackResult(access_token=fid.read().strip())
+
+
+class _OIDCAWSCallback(OIDCCallback):
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
+        if not token_file:
+            raise RuntimeError(
+                'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
+            )
+        with open(token_file) as fid:
+            return OIDCCallbackResult(access_token=fid.read().strip())
+
+
+class _OIDCAzureCallback(OIDCCallback):
+    def __init__(self, token_resource: str) -> None:
+        self.token_resource = quote(token_resource)
+
+    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
+        resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
+        return OIDCCallbackResult(
+            access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
+        )
+
+
+class _OIDCGCPCallback(OIDCCallback):
+    def __init__(self, token_resource: str) -> None:
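+        # Percent-encode the resource so it is safe to embed in the metadata
+        # server's request URL.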
+ self.token_resource = quote(token_resource) + + def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: + resp = _get_gcp_response(self.token_resource, context.timeout_seconds) + return OIDCCallbackResult(access_token=resp["access_token"]) diff --git a/pymongo/auth_shared.py b/pymongo/auth_shared.py new file mode 100644 index 0000000000..7e3acd9dfb --- /dev/null +++ b/pymongo/auth_shared.py @@ -0,0 +1,236 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Constants and types shared across multiple auth types.""" +from __future__ import annotations + +import os +import typing +from base64 import standard_b64encode +from collections import namedtuple +from typing import Any, Dict, Mapping, Optional + +from bson import Binary +from pymongo.auth_oidc_shared import ( + _OIDCAzureCallback, + _OIDCGCPCallback, + _OIDCProperties, + _OIDCTestCallback, +) +from pymongo.errors import ConfigurationError + +MECHANISMS = frozenset( + [ + "GSSAPI", + "MONGODB-CR", + "MONGODB-OIDC", + "MONGODB-X509", + "MONGODB-AWS", + "PLAIN", + "SCRAM-SHA-1", + "SCRAM-SHA-256", + "DEFAULT", + ] +) +"""The authentication mechanisms supported by PyMongo.""" + + +class _Cache: + __slots__ = ("data",) + + _hash_val = hash("_Cache") + + def __init__(self) -> None: + self.data = None + + def __eq__(self, other: object) -> bool: + # Two instances must always compare equal. 
+ if isinstance(other, _Cache): + return True + return NotImplemented + + def __ne__(self, other: object) -> bool: + if isinstance(other, _Cache): + return False + return NotImplemented + + def __hash__(self) -> int: + return self._hash_val + + +MongoCredential = namedtuple( + "MongoCredential", + ["mechanism", "source", "username", "password", "mechanism_properties", "cache"], +) +"""A hashable namedtuple of values used for authentication.""" + + +GSSAPIProperties = namedtuple( + "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"] +) +"""Mechanism properties for GSSAPI authentication.""" + + +_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"]) +"""Mechanism properties for MONGODB-AWS authentication.""" + + +def _build_credentials_tuple( + mech: str, + source: Optional[str], + user: str, + passwd: str, + extra: Mapping[str, Any], + database: Optional[str], +) -> MongoCredential: + """Build and return a mechanism specific credentials tuple.""" + if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: + raise ConfigurationError(f"{mech} requires a username.") + if mech == "GSSAPI": + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for GSSAPI") + properties = extra.get("authmechanismproperties", {}) + service_name = properties.get("SERVICE_NAME", "mongodb") + canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False)) + service_realm = properties.get("SERVICE_REALM") + props = GSSAPIProperties( + service_name=service_name, + canonicalize_host_name=canonicalize, + service_realm=service_realm, + ) + # Source is always $external. + return MongoCredential(mech, "$external", user, passwd, props, None) + elif mech == "MONGODB-X509": + if passwd is not None: + raise ConfigurationError("Passwords are not supported by MONGODB-X509") + if source is not None and source != "$external": + raise ValueError("authentication source must be $external or None for MONGODB-X509") + # Source is always $external, user can be None. + return MongoCredential(mech, "$external", user, None, None, None) + elif mech == "MONGODB-AWS": + if user is not None and passwd is None: + raise ConfigurationError("username without a password is not supported by MONGODB-AWS") + if source is not None and source != "$external": + raise ConfigurationError( + "authentication source must be $external or None for MONGODB-AWS" + ) + + properties = extra.get("authmechanismproperties", {}) + aws_session_token = properties.get("AWS_SESSION_TOKEN") + aws_props = _AWSProperties(aws_session_token=aws_session_token) + # user can be None for temporary link-local EC2 credentials. 
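+        # (The driver then sources temporary credentials from the environment,
+        # e.g. the EC2/ECS instance metadata endpoints, at authentication time.)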
+ return MongoCredential(mech, "$external", user, passwd, aws_props, None) + elif mech == "MONGODB-OIDC": + properties = extra.get("authmechanismproperties", {}) + callback = properties.get("OIDC_CALLBACK") + human_callback = properties.get("OIDC_HUMAN_CALLBACK") + environ = properties.get("ENVIRONMENT") + token_resource = properties.get("TOKEN_RESOURCE", "") + default_allowed = [ + "*.mongodb.net", + "*.mongodb-dev.net", + "*.mongodb-qa.net", + "*.mongodbgov.net", + "localhost", + "127.0.0.1", + "::1", + ] + allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed) + msg = ( + "authentication with MONGODB-OIDC requires providing either a callback or an environment" + ) + if passwd is not None: + msg = "password is not supported by MONGODB-OIDC" + raise ConfigurationError(msg) + if callback or human_callback: + if environ is not None: + raise ConfigurationError(msg) + if callback and human_callback: + msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK" + raise ConfigurationError(msg) + elif environ is not None: + if environ == "test": + if user is not None: + msg = "test environment for MONGODB-OIDC does not support username" + raise ConfigurationError(msg) + callback = _OIDCTestCallback() + elif environ == "azure": + passwd = None + if not token_resource: + raise ConfigurationError( + "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" + ) + callback = _OIDCAzureCallback(token_resource) + elif environ == "gcp": + passwd = None + if not token_resource: + raise ConfigurationError( + "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" + ) + callback = _OIDCGCPCallback(token_resource) + else: + raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}") + else: + raise ConfigurationError(msg) + + oidc_props = _OIDCProperties( + callback=callback, + human_callback=human_callback, + environment=environ, + allowed_hosts=allowed_hosts, + token_resource=token_resource, + username=user, + ) + return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache()) + + elif mech == "PLAIN": + source_database = source or database or "$external" + return MongoCredential(mech, source_database, user, passwd, None, None) + else: + source_database = source or database or "admin" + if passwd is None: + raise ConfigurationError("A password is required.") + return MongoCredential(mech, source_database, user, passwd, None, _Cache()) + + +def _xor(fir: bytes, sec: bytes) -> bytes: + """XOR two byte strings together.""" + return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) + + +def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]: + """Split a scram response into key, value pairs.""" + return dict( + typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1)) + for item in response.split(b",") + ) + + +def _authenticate_scram_start( + credentials: MongoCredential, mechanism: str +) -> tuple[bytes, bytes, typing.MutableMapping[str, Any]]: + username = credentials.username + user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") + nonce = standard_b64encode(os.urandom(32)) + first_bare = b"n=" + user + b",r=" + nonce + + cmd = { + "saslStart": 1, + "mechanism": mechanism, + "payload": Binary(b"n,," + first_bare), + "autoAuthorize": 1, + "options": {"skipEmptyExchange": True}, + } + return nonce, first_bare, cmd diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 7a4e04453d..ddc22c3aff 100644 --- a/pymongo/client_options.py +++ 
b/pymongo/client_options.py @@ -1,21 +1,334 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2014-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. -"""Re-import of synchronous ClientOptions API for compatibility.""" +"""Tools to parse mongo client options.""" from __future__ import annotations -from pymongo.synchronous.client_options import * # noqa: F403 -from pymongo.synchronous.client_options import __doc__ as original_doc +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast -__doc__ = original_doc +from bson.codec_options import _parse_codec_options +from pymongo import common +from pymongo.compression_support import CompressionSettings +from pymongo.errors import ConfigurationError +from pymongo.monitoring import _EventListener, _EventListeners +from pymongo.pool_options import PoolOptions +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ( + _ServerMode, + make_read_preference, + read_pref_mode_from_name, +) +from pymongo.server_selectors import any_server_selector +from pymongo.ssl_support import get_ssl_context +from pymongo.write_concern import WriteConcern, validate_boolean + +if TYPE_CHECKING: + from bson.codec_options import CodecOptions + from pymongo.auth_shared import MongoCredential + from pymongo.encryption_options import AutoEncryptionOpts + from pymongo.pyopenssl_context import SSLContext + from pymongo.topology_description import _ServerSelector + +_IS_SYNC = False + + +def _parse_credentials( + username: str, password: str, database: Optional[str], options: Mapping[str, Any] +) -> Optional[MongoCredential]: + """Parse authentication credentials.""" + mechanism = options.get("authmechanism", "DEFAULT" if username else None) + source = options.get("authsource") + if username or mechanism: + from pymongo.auth_shared import _build_credentials_tuple + + return _build_credentials_tuple(mechanism, source, username, password, options, database) + return None + + +def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode: + """Parse read preference options.""" + if "read_preference" in options: + return options["read_preference"] + + name = options.get("readpreference", "primary") + mode = read_pref_mode_from_name(name) + tags = options.get("readpreferencetags") + max_staleness = options.get("maxstalenessseconds", -1) + return make_read_preference(mode, tags, max_staleness) + + +def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern: + """Parse write concern options.""" + concern = options.get("w") + wtimeout = options.get("wtimeoutms") + j = options.get("journal") + fsync = options.get("fsync") + return WriteConcern(concern, 
wtimeout, j, fsync) + + +def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: + """Parse read concern options.""" + concern = options.get("readconcernlevel") + return ReadConcern(concern) + + +def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: + """Parse ssl options.""" + use_tls = options.get("tls") + if use_tls is not None: + validate_boolean("tls", use_tls) + + certfile = options.get("tlscertificatekeyfile") + passphrase = options.get("tlscertificatekeyfilepassword") + ca_certs = options.get("tlscafile") + crlfile = options.get("tlscrlfile") + allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) + allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) + disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) + + enabled_tls_opts = [] + for opt in ( + "tlscertificatekeyfile", + "tlscertificatekeyfilepassword", + "tlscafile", + "tlscrlfile", + ): + # Any non-null value of these options implies tls=True. + if opt in options and options[opt]: + enabled_tls_opts.append(opt) + for opt in ( + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", + ): + # A value of False for these options implies tls=True. + if opt in options and not options[opt]: + enabled_tls_opts.append(opt) + + if enabled_tls_opts: + if use_tls is None: + # Implicitly enable TLS when one of the tls* options is set. + use_tls = True + elif not use_tls: + # Error since tls is explicitly disabled but a tls option is set. + raise ConfigurationError( + "TLS has not been enabled but the " + "following tls parameters have been set: " + "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) + ) + + if use_tls: + ctx = get_ssl_context( + certfile, + passphrase, + ca_certs, + crlfile, + allow_invalid_certificates, + allow_invalid_hostnames, + disable_ocsp_endpoint_check, + ) + return ctx, allow_invalid_hostnames + return None, allow_invalid_hostnames + + +def _parse_pool_options( + username: str, password: str, database: Optional[str], options: Mapping[str, Any] +) -> PoolOptions: + """Parse connection pool options.""" + credentials = _parse_credentials(username, password, database, options) + max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) + min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) + max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) + if max_pool_size is not None and min_pool_size > max_pool_size: + raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") + connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) + socket_timeout = options.get("sockettimeoutms") + wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) + event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) + appname = options.get("appname") + driver = options.get("driver") + server_api = options.get("server_api") + compression_settings = CompressionSettings( + options.get("compressors", []), options.get("zlibcompressionlevel", -1) + ) + ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) + load_balanced = options.get("loadbalanced") + max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) + return PoolOptions( + max_pool_size, + min_pool_size, + max_idle_time_seconds, + connect_timeout, + socket_timeout, + wait_queue_timeout, + ssl_context, + tls_allow_invalid_hostnames, + 
_EventListeners(event_listeners), + appname, + driver, + compression_settings, + max_connecting=max_connecting, + server_api=server_api, + load_balanced=load_balanced, + credentials=credentials, + ) + + +class ClientOptions: + """Read only configuration options for an AsyncMongoClient/MongoClient. + + Should not be instantiated directly by application developers. Access + a client's options via :attr:`pymongo.mongo_client.AsyncMongoClient.options` or :attr:`pymongo.mongo_client.MongoClient.options` + instead. + """ + + def __init__( + self, username: str, password: str, database: Optional[str], options: Mapping[str, Any] + ): + self.__options = options + self.__codec_options = _parse_codec_options(options) + self.__direct_connection = options.get("directconnection") + self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) + # self.__server_selection_timeout is in seconds. Must use full name for + # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. + self.__server_selection_timeout = options.get( + "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT + ) + self.__pool_options = _parse_pool_options(username, password, database, options) + self.__read_preference = _parse_read_preference(options) + self.__replica_set_name = options.get("replicaset") + self.__write_concern = _parse_write_concern(options) + self.__read_concern = _parse_read_concern(options) + self.__connect = options.get("connect") + self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) + self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) + self.__retry_reads = options.get("retryreads", common.RETRY_READS) + self.__server_selector = options.get("server_selector", any_server_selector) + self.__auto_encryption_opts = options.get("auto_encryption_opts") + self.__load_balanced = options.get("loadbalanced") + self.__timeout = options.get("timeoutms") + self.__server_monitoring_mode = options.get( + "servermonitoringmode", common.SERVER_MONITORING_MODE + ) + + @property + def _options(self) -> Mapping[str, Any]: + """The original options used to create this ClientOptions.""" + return self.__options + + @property + def connect(self) -> Optional[bool]: + """Whether to begin discovering a MongoDB topology automatically.""" + return self.__connect + + @property + def codec_options(self) -> CodecOptions: + """A :class:`~bson.codec_options.CodecOptions` instance.""" + return self.__codec_options + + @property + def direct_connection(self) -> Optional[bool]: + """Whether to connect to the deployment in 'Single' topology.""" + return self.__direct_connection + + @property + def local_threshold_ms(self) -> int: + """The local threshold for this instance.""" + return self.__local_threshold_ms + + @property + def server_selection_timeout(self) -> int: + """The server selection timeout for this instance in seconds.""" + return self.__server_selection_timeout + + @property + def server_selector(self) -> _ServerSelector: + return self.__server_selector + + @property + def heartbeat_frequency(self) -> int: + """The monitoring frequency in seconds.""" + return self.__heartbeat_frequency + + @property + def pool_options(self) -> PoolOptions: + """A :class:`~pymongo.pool.PoolOptions` instance.""" + return self.__pool_options + + @property + def read_preference(self) -> _ServerMode: + """A read preference instance.""" + return self.__read_preference + + @property + def replica_set_name(self) -> Optional[str]: + """Replica set name or 
None.""" + return self.__replica_set_name + + @property + def write_concern(self) -> WriteConcern: + """A :class:`~pymongo.write_concern.WriteConcern` instance.""" + return self.__write_concern + + @property + def read_concern(self) -> ReadConcern: + """A :class:`~pymongo.read_concern.ReadConcern` instance.""" + return self.__read_concern + + @property + def timeout(self) -> Optional[float]: + """The configured timeoutMS converted to seconds, or None. + + .. versionadded:: 4.2 + """ + return self.__timeout + + @property + def retry_writes(self) -> bool: + """If this instance should retry supported write operations.""" + return self.__retry_writes + + @property + def retry_reads(self) -> bool: + """If this instance should retry supported read operations.""" + return self.__retry_reads + + @property + def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]: + """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" + return self.__auto_encryption_opts + + @property + def load_balanced(self) -> Optional[bool]: + """True if the client was configured to connect to a load balancer.""" + return self.__load_balanced + + @property + def event_listeners(self) -> list[_EventListeners]: + """The event listeners registered for this client. + + See :mod:`~pymongo.monitoring` for details. + + .. versionadded:: 4.0 + """ + assert self.__pool_options._event_listeners is not None + return self.__pool_options._event_listeners.event_listeners() + + @property + def server_monitoring_mode(self) -> str: + """The configured serverMonitoringMode option. + + .. versionadded:: 4.5 + """ + return self.__server_monitoring_mode diff --git a/pymongo/collation.py b/pymongo/collation.py index b129a04512..115c8c7e88 100644 --- a/pymongo/collation.py +++ b/pymongo/collation.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2016-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,215 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous Collation API for compatibility.""" +"""Tools for working with `collations`_. + +.. _collations: https://www.mongodb.com/docs/manual/reference/collation/ +""" from __future__ import annotations -from pymongo.synchronous.collation import * # noqa: F403 -from pymongo.synchronous.collation import __doc__ as original_doc +from typing import Any, Mapping, Optional, Union + +from pymongo import common +from pymongo.write_concern import validate_boolean + +_IS_SYNC = False + + +class CollationStrength: + """ + An enum that defines values for `strength` on a + :class:`~pymongo.collation.Collation`. + """ + + PRIMARY = 1 + """Differentiate base (unadorned) characters.""" + + SECONDARY = 2 + """Differentiate character accents.""" + + TERTIARY = 3 + """Differentiate character case.""" + + QUATERNARY = 4 + """Differentiate words with and without punctuation.""" + + IDENTICAL = 5 + """Differentiate unicode code point (characters are exactly identical).""" + + +class CollationAlternate: + """ + An enum that defines values for `alternate` on a + :class:`~pymongo.collation.Collation`. + """ + + NON_IGNORABLE = "non-ignorable" + """Spaces and punctuation are treated as base characters.""" + + SHIFTED = "shifted" + """Spaces and punctuation are *not* considered base characters. 
+ + Spaces and punctuation are distinguished regardless when the + :class:`~pymongo.collation.Collation` strength is at least + :data:`~pymongo.collation.CollationStrength.QUATERNARY`. + + """ + + +class CollationMaxVariable: + """ + An enum that defines values for `max_variable` on a + :class:`~pymongo.collation.Collation`. + """ + + PUNCT = "punct" + """Both punctuation and spaces are ignored.""" + + SPACE = "space" + """Spaces alone are ignored.""" + + +class CollationCaseFirst: + """ + An enum that defines values for `case_first` on a + :class:`~pymongo.collation.Collation`. + """ + + UPPER = "upper" + """Sort uppercase characters first.""" + + LOWER = "lower" + """Sort lowercase characters first.""" + + OFF = "off" + """Default for locale or collation strength.""" + + +class Collation: + """Collation + + :param locale: (string) The locale of the collation. This should be a string + that identifies an `ICU locale ID` exactly. For example, ``en_US`` is + valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB + documentation for a list of supported locales. + :param caseLevel: (optional) If ``True``, turn on case sensitivity if + `strength` is 1 or 2 (case sensitivity is implied if `strength` is + greater than 2). Defaults to ``False``. + :param caseFirst: (optional) Specify that either uppercase or lowercase + characters take precedence. Must be one of the following values: + + * :data:`~CollationCaseFirst.UPPER` + * :data:`~CollationCaseFirst.LOWER` + * :data:`~CollationCaseFirst.OFF` (the default) + + :param strength: Specify the comparison strength. This is also + known as the ICU comparison level. This must be one of the following + values: + + * :data:`~CollationStrength.PRIMARY` + * :data:`~CollationStrength.SECONDARY` + * :data:`~CollationStrength.TERTIARY` (the default) + * :data:`~CollationStrength.QUATERNARY` + * :data:`~CollationStrength.IDENTICAL` + + Each successive level builds upon the previous. For example, a + `strength` of :data:`~CollationStrength.SECONDARY` differentiates + characters based both on the unadorned base character and its accents. + + :param numericOrdering: If ``True``, order numbers numerically + instead of in collation order (defaults to ``False``). + :param alternate: Specify whether spaces and punctuation are + considered base characters. This must be one of the following values: + + * :data:`~CollationAlternate.NON_IGNORABLE` (the default) + * :data:`~CollationAlternate.SHIFTED` + + :param maxVariable: When `alternate` is + :data:`~CollationAlternate.SHIFTED`, this option specifies what + characters may be ignored. This must be one of the following values: + + * :data:`~CollationMaxVariable.PUNCT` (the default) + * :data:`~CollationMaxVariable.SPACE` + + :param normalization: If ``True``, normalizes text into Unicode + NFD. Defaults to ``False``. + :param backwards: If ``True``, accents on characters are + considered from the back of the word to the front, as it is done in some + French dictionary ordering traditions. Defaults to ``False``. + :param kwargs: Keyword arguments supplying any additional options + to be sent with this Collation object. + + .. 
versionadded:: 3.4 + + """ + + __slots__ = ("__document",) + + def __init__( + self, + locale: str, + caseLevel: Optional[bool] = None, + caseFirst: Optional[str] = None, + strength: Optional[int] = None, + numericOrdering: Optional[bool] = None, + alternate: Optional[str] = None, + maxVariable: Optional[str] = None, + normalization: Optional[bool] = None, + backwards: Optional[bool] = None, + **kwargs: Any, + ) -> None: + locale = common.validate_string("locale", locale) + self.__document: dict[str, Any] = {"locale": locale} + if caseLevel is not None: + self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel) + if caseFirst is not None: + self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) + if strength is not None: + self.__document["strength"] = common.validate_integer("strength", strength) + if numericOrdering is not None: + self.__document["numericOrdering"] = validate_boolean( + "numericOrdering", numericOrdering + ) + if alternate is not None: + self.__document["alternate"] = common.validate_string("alternate", alternate) + if maxVariable is not None: + self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) + if normalization is not None: + self.__document["normalization"] = validate_boolean("normalization", normalization) + if backwards is not None: + self.__document["backwards"] = validate_boolean("backwards", backwards) + self.__document.update(kwargs) + + @property + def document(self) -> dict[str, Any]: + """The document representation of this collation. + + .. note:: + :class:`Collation` is immutable. Mutating the value of + :attr:`document` does not mutate this :class:`Collation`. + """ + return self.__document.copy() + + def __repr__(self) -> str: + document = self.document + return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) + + def __eq__(self, other: Any) -> bool: + if isinstance(other, Collation): + return self.document == other.document + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + -__doc__ = original_doc +def validate_collation_or_none( + value: Optional[Union[Mapping[str, Any], Collation]] +) -> Optional[dict[str, Any]]: + if value is None: + return None + if isinstance(value, Collation): + return value.document + if isinstance(value, dict): + return value + raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") diff --git a/pymongo/asynchronous/common.py b/pymongo/common.py similarity index 98% rename from pymongo/asynchronous/common.py rename to pymongo/common.py index 7dcfa29388..16f3ff2580 100644 --- a/pymongo/asynchronous/common.py +++ b/pymongo/common.py @@ -40,20 +40,20 @@ from bson.binary import UuidRepresentation from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry from bson.raw_bson import RawBSONDocument -from pymongo.asynchronous.compression_support import ( +from pymongo.compression_support import ( validate_compressors, validate_zlib_compression_level, ) -from pymongo.asynchronous.monitoring import _validate_event_listeners -from pymongo.asynchronous.read_preferences import _MONGOS_MODES, _ServerMode from pymongo.driver_info import DriverInfo from pymongo.errors import ConfigurationError +from pymongo.monitoring import _validate_event_listeners from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import _MONGOS_MODES, _ServerMode from pymongo.server_api import ServerApi from pymongo.write_concern import 
DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean if TYPE_CHECKING: - from pymongo.asynchronous.client_session import ClientSession + from pymongo.typings import _AgnosticClientSession _IS_SYNC = False @@ -380,7 +380,7 @@ def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: def validate_auth_mechanism(option: str, value: Any) -> str: """Validate the authMechanism URI option.""" - from pymongo.asynchronous.auth import MECHANISMS + from pymongo.auth_shared import MECHANISMS if value not in MECHANISMS: raise ValueError(f"{option} must be in {tuple(MECHANISMS)}") @@ -446,7 +446,7 @@ def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Uni elif key in ["ALLOWED_HOSTS"] and isinstance(value, list): props[key] = value elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]: - from pymongo.asynchronous.auth_oidc import OIDCCallback + from pymongo.auth_oidc_shared import OIDCCallback if not isinstance(value, OIDCCallback): raise ValueError("callback must be an OIDCCallback object") @@ -642,7 +642,7 @@ def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[A """Validate the driver keyword arg.""" if value is None: return value - from pymongo.asynchronous.encryption_options import AutoEncryptionOpts + from pymongo.encryption_options import AutoEncryptionOpts if not isinstance(value, AutoEncryptionOpts): raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") @@ -941,7 +941,7 @@ def write_concern(self) -> WriteConcern: """ return self._write_concern - def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern: + def _write_concern_for(self, session: Optional[_AgnosticClientSession]) -> WriteConcern: """Read only access to the write concern of this instance or session.""" # Override this operation's write concern with the transaction's. if session and session.in_transaction: @@ -957,7 +957,7 @@ def read_preference(self) -> _ServerMode: """ return self._read_preference - def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode: + def _read_preference_for(self, session: Optional[_AgnosticClientSession]) -> _ServerMode: """Read only access to the read preference of this instance or session.""" # Override this operation's read preference with the transaction's. if session: diff --git a/pymongo/asynchronous/compression_support.py b/pymongo/compression_support.py similarity index 97% rename from pymongo/asynchronous/compression_support.py rename to pymongo/compression_support.py index 8a39bfb465..7a0f2a36dd 100644 --- a/pymongo/asynchronous/compression_support.py +++ b/pymongo/compression_support.py @@ -16,8 +16,8 @@ import warnings from typing import Any, Iterable, Optional, Union -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.helpers_constants import _SENSITIVE_COMMANDS +from pymongo.hello_compat import HelloCompat +from pymongo.helpers_shared import _SENSITIVE_COMMANDS _IS_SYNC = False diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py index 350344a6da..45ee0dd95d 100644 --- a/pymongo/encryption_options.py +++ b/pymongo/encryption_options.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2019-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,258 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Re-import of synchronous EncryptionOptions API for compatibility.""" +"""Support for automatic client-side field level encryption.""" from __future__ import annotations -from pymongo.synchronous.encryption_options import * # noqa: F403 -from pymongo.synchronous.encryption_options import __doc__ as original_doc +from typing import TYPE_CHECKING, Any, Mapping, Optional -__doc__ = original_doc +try: + import pymongocrypt # type:ignore[import] # noqa: F401 + + _HAVE_PYMONGOCRYPT = True +except ImportError: + _HAVE_PYMONGOCRYPT = False +from bson import int64 +from pymongo.common import validate_is_mapping +from pymongo.errors import ConfigurationError +from pymongo.uri_parser import _parse_kms_tls_options + +if TYPE_CHECKING: + from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg + +_IS_SYNC = False + + +class AutoEncryptionOpts: + """Options to configure automatic client-side field level encryption.""" + + def __init__( + self, + kms_providers: Mapping[str, Any], + key_vault_namespace: str, + key_vault_client: Optional[_AgnosticMongoClient[_DocumentTypeArg]] = None, + schema_map: Optional[Mapping[str, Any]] = None, + bypass_auto_encryption: bool = False, + mongocryptd_uri: str = "mongodb://localhost:27020", + mongocryptd_bypass_spawn: bool = False, + mongocryptd_spawn_path: str = "mongocryptd", + mongocryptd_spawn_args: Optional[list[str]] = None, + kms_tls_options: Optional[Mapping[str, Any]] = None, + crypt_shared_lib_path: Optional[str] = None, + crypt_shared_lib_required: bool = False, + bypass_query_analysis: bool = False, + encrypted_fields_map: Optional[Mapping[str, Any]] = None, + ) -> None: + """Options to configure automatic client-side field level encryption. + + Automatic client-side field level encryption requires MongoDB >=4.2 + enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not + supported for operations on a database or view and will result in + error. + + Although automatic encryption requires MongoDB >=4.2 enterprise or a + MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all + users. To configure automatic *decryption* without automatic + *encryption* set ``bypass_auto_encryption=True``. Explicit + encryption and explicit decryption is also supported for all users + with the :class:`~pymongo.encryption.ClientEncryption` class. + + See :ref:`automatic-client-side-encryption` for an example. + + :param kms_providers: Map of KMS provider options. The `kms_providers` + map values differ by provider: + + - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. + These are the AWS access key ID and AWS secret access key used + to generate KMS messages. An optional "sessionToken" may be + included to support temporary AWS credentials. + - `azure`: Map with "tenantId", "clientId", and "clientSecret" as + strings. Additionally, "identityPlatformEndpoint" may also be + specified as a string (defaults to 'login.microsoftonline.com'). + These are the Azure Active Directory credentials used to + generate Azure Key Vault messages. + - `gcp`: Map with "email" as a string and "privateKey" + as `bytes` or a base64 encoded string. + Additionally, "endpoint" may also be specified as a string + (defaults to 'oauth2.googleapis.com'). These are the + credentials used to generate Google Cloud KMS messages. + - `kmip`: Map with "endpoint" as a host with required port. + For example: ``{"endpoint": "example.com:443"}``. 
+ - `local`: Map with "key" as `bytes` (96 bytes in length) or + a base64 encoded string which decodes + to 96 bytes. "key" is the master key used to encrypt/decrypt + data keys. This key should be generated and stored as securely + as possible. + + KMS providers may be specified with an optional name suffix + separated by a colon, for example "kmip:name" or "aws:name". + Named KMS providers do not support :ref:`CSFLE on-demand credentials`. + Named KMS providers enable more than one of each KMS provider type to be configured. + For example, to configure multiple local KMS providers:: + + kms_providers = { + "local": {"key": local_kek1}, # Unnamed KMS provider. + "local:myname": {"key": local_kek2}, # Named KMS provider with name "myname". + } + + :param key_vault_namespace: The namespace for the key vault collection. + The key vault collection contains all data keys used for encryption + and decryption. Data keys are stored as documents in this MongoDB + collection. Data keys are protected with encryption by a KMS + provider. + :param key_vault_client: By default, the key vault collection + is assumed to reside in the same MongoDB cluster as the encrypted + AsyncMongoClient/MongoClient. Use this option to route data key queries to a + separate MongoDB cluster. + :param schema_map: Map of collection namespace ("db.coll") to + JSON Schema. By default, a collection's JSONSchema is periodically + polled with the listCollections command. But a JSONSchema may be + specified locally with the schemaMap option. + + **Supplying a `schema_map` provides more security than relying on + JSON Schemas obtained from the server. It protects against a + malicious server advertising a false JSON Schema, which could trick + the client into sending unencrypted data that should be + encrypted.** + + Schemas supplied in the schemaMap only apply to configuring + automatic encryption for client side encryption. Other validation + rules in the JSON schema will not be enforced by the driver and + will result in an error. + :param bypass_auto_encryption: If ``True``, automatic + encryption will be disabled but automatic decryption will still be + enabled. Defaults to ``False``. + :param mongocryptd_uri: The MongoDB URI used to connect + to the *local* mongocryptd process. Defaults to + ``'mongodb://localhost:27020'``. + :param mongocryptd_bypass_spawn: If ``True``, the encrypted + AsyncMongoClient/MongoClient will not attempt to spawn the mongocryptd process. + Defaults to ``False``. + :param mongocryptd_spawn_path: Used for spawning the + mongocryptd process. Defaults to ``'mongocryptd'`` and spawns + mongocryptd from the system path. + :param mongocryptd_spawn_args: A list of string arguments to + use when spawning the mongocryptd process. Defaults to + ``['--idleShutdownTimeoutSecs=60']``. If the list does not include + the ``idleShutdownTimeoutSecs`` option then + ``'--idleShutdownTimeoutSecs=60'`` will be added. + :param kms_tls_options: A map of KMS provider names to TLS + options to use when creating secure connections to KMS providers. + Accepts the same TLS options as + :class:`pymongo.mongo_client.AsyncMongoClient` and :class:`pymongo.mongo_client.MongoClient`. For example, to + override the system default CA file:: + + kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}} + + Or to supply a client certificate:: + + kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} + :param crypt_shared_lib_path: Override the path to load the crypt_shared library. 
+ :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is + unable to load the crypt_shared library. + :param bypass_query_analysis: If ``True``, disable automatic analysis + of outgoing commands. Set `bypass_query_analysis` to use explicit + encryption on indexed fields without the MongoDB Enterprise Advanced + licensed crypt_shared library. + :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents + that described the encrypted fields for Queryable Encryption. For example:: + + { + "db.encryptedCollection": { + "escCollection": "enxcol_.encryptedCollection.esc", + "ecocCollection": "enxcol_.encryptedCollection.ecoc", + "fields": [ + { + "path": "firstName", + "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), + "bsonType": "string", + "queries": {"queryType": "equality"} + }, + { + "path": "ssn", + "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), + "bsonType": "string" + } + ] + } + } + + .. versionchanged:: 4.2 + Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`, + and `bypass_query_analysis` parameters. + + .. versionchanged:: 4.0 + Added the `kms_tls_options` parameter and the "kmip" KMS provider. + + .. versionadded:: 3.9 + """ + if not _HAVE_PYMONGOCRYPT: + raise ConfigurationError( + "client side encryption requires the pymongocrypt library: " + "install a compatible version with: " + "python -m pip install 'pymongo[encryption]'" + ) + if encrypted_fields_map: + validate_is_mapping("encrypted_fields_map", encrypted_fields_map) + self._encrypted_fields_map = encrypted_fields_map + self._bypass_query_analysis = bypass_query_analysis + self._crypt_shared_lib_path = crypt_shared_lib_path + self._crypt_shared_lib_required = crypt_shared_lib_required + self._kms_providers = kms_providers + self._key_vault_namespace = key_vault_namespace + self._key_vault_client = key_vault_client + self._schema_map = schema_map + self._bypass_auto_encryption = bypass_auto_encryption + self._mongocryptd_uri = mongocryptd_uri + self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn + self._mongocryptd_spawn_path = mongocryptd_spawn_path + if mongocryptd_spawn_args is None: + mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] + self._mongocryptd_spawn_args = mongocryptd_spawn_args + if not isinstance(self._mongocryptd_spawn_args, list): + raise TypeError("mongocryptd_spawn_args must be a list") + if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): + self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") + # Maps KMS provider name to a SSLContext. + self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) + self._bypass_query_analysis = bypass_query_analysis + + +class RangeOpts: + """Options to configure encrypted queries using the rangePreview algorithm.""" + + def __init__( + self, + sparsity: int, + min: Optional[Any] = None, + max: Optional[Any] = None, + precision: Optional[int] = None, + ) -> None: + """Options to configure encrypted queries using the rangePreview algorithm. + + .. note:: This feature is experimental only, and not intended for public use. + + :param sparsity: An integer. + :param min: A BSON scalar value corresponding to the type being queried. + :param max: A BSON scalar value corresponding to the type being queried. + :param precision: An integer, may only be set for double or decimal128 types. + + .. 
versionadded:: 4.4 + """ + self.min = min + self.max = max + self.sparsity = sparsity + self.precision = precision + + @property + def document(self) -> dict[str, Any]: + doc = {} + for k, v in [ + ("sparsity", int64.Int64(self.sparsity)), + ("precision", self.precision), + ("min", self.min), + ("max", self.max), + ]: + if v is not None: + doc[k] = v + return doc diff --git a/pymongo/errors.py b/pymongo/errors.py index 7efbc1ff31..a781e4a016 100644 --- a/pymongo/errors.py +++ b/pymongo/errors.py @@ -21,7 +21,7 @@ from bson.errors import InvalidDocument if TYPE_CHECKING: - from pymongo.asynchronous.typings import _DocumentOut + from pymongo.typings import _DocumentOut class PyMongoError(Exception): diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py index 756e90ba23..3a241df52b 100644 --- a/pymongo/event_loggers.py +++ b/pymongo/event_loggers.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2020-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,214 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous EventLoggers API for compatibility.""" + +"""Example event logger classes. + +.. versionadded:: 3.11 + +These loggers can be registered using :func:`register` or +:class:`~pymongo.mongo_client.MongoClient`. + +``monitoring.register(CommandLogger())`` + +or + +``MongoClient(event_listeners=[CommandLogger()])`` +""" from __future__ import annotations -from pymongo.synchronous.event_loggers import * # noqa: F403 -from pymongo.synchronous.event_loggers import __doc__ as original_doc +import logging + +from pymongo import monitoring + +_IS_SYNC = False + + +class CommandLogger(monitoring.CommandListener): + """A simple listener that logs command events. + + Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, + :class:`~pymongo.monitoring.CommandSucceededEvent` and + :class:`~pymongo.monitoring.CommandFailedEvent` events and + logs them at the `INFO` severity level using :mod:`logging`. + .. versionadded:: 3.11 + """ + + def started(self, event: monitoring.CommandStartedEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} started on server " + f"{event.connection_id}" + ) + + def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} on server {event.connection_id} " + f"succeeded in {event.duration_micros} " + "microseconds" + ) + + def failed(self, event: monitoring.CommandFailedEvent) -> None: + logging.info( + f"Command {event.command_name} with request id " + f"{event.request_id} on server {event.connection_id} " + f"failed in {event.duration_micros} " + "microseconds" + ) + + +class ServerLogger(monitoring.ServerListener): + """A simple listener that logs server discovery events. + + Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, + :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, + and :class:`~pymongo.monitoring.ServerClosedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. 
versionadded:: 3.11 + """ + + def opened(self, event: monitoring.ServerOpeningEvent) -> None: + logging.info(f"Server {event.server_address} added to topology {event.topology_id}") + + def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: + previous_server_type = event.previous_description.server_type + new_server_type = event.new_description.server_type + if new_server_type != previous_server_type: + # server_type_name was added in PyMongo 3.4 + logging.info( + f"Server {event.server_address} changed type from " + f"{event.previous_description.server_type_name} to " + f"{event.new_description.server_type_name}" + ) + + def closed(self, event: monitoring.ServerClosedEvent) -> None: + logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}") + + +class HeartbeatLogger(monitoring.ServerHeartbeatListener): + """A simple listener that logs server heartbeat events. + + Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, + :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, + and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: + logging.info(f"Heartbeat sent to server {event.connection_id}") + + def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: + # The reply.document attribute was added in PyMongo 3.4. + logging.info( + f"Heartbeat to server {event.connection_id} " + "succeeded with reply " + f"{event.reply.document}" + ) + + def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: + logging.warning( + f"Heartbeat to server {event.connection_id} failed with error {event.reply}" + ) + + +class TopologyLogger(monitoring.TopologyListener): + """A simple listener that logs server topology events. + + Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, + :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, + and :class:`~pymongo.monitoring.TopologyClosedEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def opened(self, event: monitoring.TopologyOpenedEvent) -> None: + logging.info(f"Topology with id {event.topology_id} opened") + + def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None: + logging.info(f"Topology description updated for topology id {event.topology_id}") + previous_topology_type = event.previous_description.topology_type + new_topology_type = event.new_description.topology_type + if new_topology_type != previous_topology_type: + # topology_type_name was added in PyMongo 3.4 + logging.info( + f"Topology {event.topology_id} changed type from " + f"{event.previous_description.topology_type_name} to " + f"{event.new_description.topology_type_name}" + ) + # The has_writable_server and has_readable_server methods + # were added in PyMongo 3.4. + if not event.new_description.has_writable_server(): + logging.warning("No writable servers available.") + if not event.new_description.has_readable_server(): + logging.warning("No readable servers available.") + + def closed(self, event: monitoring.TopologyClosedEvent) -> None: + logging.info(f"Topology with id {event.topology_id} closed") + + +class ConnectionPoolLogger(monitoring.ConnectionPoolListener): + """A simple listener that logs server connection pool events. 
+ + Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`, + :class:`~pymongo.monitoring.PoolClearedEvent`, + :class:`~pymongo.monitoring.PoolClosedEvent`, + :class:`~pymongo.monitoring.ConnectionCreatedEvent`, + :class:`~pymongo.monitoring.ConnectionReadyEvent`, + :class:`~pymongo.monitoring.ConnectionClosedEvent`, + :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`, + :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`, + :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`, + and :class:`~pymongo.monitoring.ConnectionCheckedInEvent` + events and logs them at the `INFO` severity level using :mod:`logging`. + + .. versionadded:: 3.11 + """ + + def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: + logging.info(f"[pool {event.address}] pool created") + + def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: + logging.info(f"[pool {event.address}] pool ready") + + def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: + logging.info(f"[pool {event.address}] pool cleared") + + def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: + logging.info(f"[pool {event.address}] pool closed") + + def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: + logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created") + + def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" + ) + + def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] " + f'connection closed, reason: "{event.reason}"' + ) + + def connection_check_out_started( + self, event: monitoring.ConnectionCheckOutStartedEvent + ) -> None: + logging.info(f"[pool {event.address}] connection check out started") + + def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None: + logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}") + + def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" + ) -__doc__ = original_doc + def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: + logging.info( + f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" + ) diff --git a/pymongo/asynchronous/hello.py b/pymongo/hello.py similarity index 97% rename from pymongo/asynchronous/hello.py rename to pymongo/hello.py index 3826e8a27f..40bd842c0a 100644 --- a/pymongo/asynchronous/hello.py +++ b/pymongo/hello.py @@ -21,10 +21,10 @@ from typing import Any, Generic, Mapping, Optional from bson.objectid import ObjectId -from pymongo.asynchronous import common -from pymongo.asynchronous.hello_compat import HelloCompat -from pymongo.asynchronous.typings import ClusterTime, _DocumentType +from pymongo import common +from pymongo.hello_compat import HelloCompat from pymongo.server_type import SERVER_TYPE +from pymongo.typings import ClusterTime, _DocumentType _IS_SYNC = False diff --git a/pymongo/asynchronous/hello_compat.py b/pymongo/hello_compat.py similarity index 100% rename from pymongo/asynchronous/hello_compat.py rename to pymongo/hello_compat.py diff --git a/pymongo/helpers_constants.py b/pymongo/helpers_constants.py deleted file mode 100644 index 00b2502701..0000000000 --- 
a/pymongo/helpers_constants.py +++ /dev/null @@ -1,72 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -"""Constants used by the driver that don't really fit elsewhere.""" - -# From the SDAM spec, the "node is shutting down" codes. -from __future__ import annotations - -_SHUTDOWN_CODES: frozenset = frozenset( - [ - 11600, # InterruptedAtShutdown - 91, # ShutdownInProgress - ] -) -# From the SDAM spec, the "not primary" error codes are combined with the -# "node is recovering" error codes (of which the "node is shutting down" -# errors are a subset). -_NOT_PRIMARY_CODES: frozenset = ( - frozenset( - [ - 10058, # LegacyNotPrimary <=3.2 "not primary" error code - 10107, # NotWritablePrimary - 13435, # NotPrimaryNoSecondaryOk - 11602, # InterruptedDueToReplStateChange - 13436, # NotPrimaryOrSecondary - 189, # PrimarySteppedDown - ] - ) - | _SHUTDOWN_CODES -) -# From the retryable writes spec. -_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset( - [ - 7, # HostNotFound - 6, # HostUnreachable - 89, # NetworkTimeout - 9001, # SocketException - 262, # ExceededTimeLimit - 134, # ReadConcernMajorityNotAvailableYet - ] -) - -# Server code raised when re-authentication is required -_REAUTHENTICATION_REQUIRED_CODE: int = 391 - -# Server code raised when authentication fails. -_AUTHENTICATION_FAILURE_CODE: int = 18 - -# Note - to avoid bugs from forgetting which if these is all lowercase and -# which are camelCase, and at the same time avoid having to add a test for -# every command, use all lowercase here and test against command_name.lower(). -_SENSITIVE_COMMANDS: set = { - "authenticate", - "saslstart", - "saslcontinue", - "getnonce", - "createuser", - "updateuser", - "copydbgetnonce", - "copydbsaslstart", - "copydb", -} diff --git a/pymongo/helpers_shared.py b/pymongo/helpers_shared.py new file mode 100644 index 0000000000..884a008385 --- /dev/null +++ b/pymongo/helpers_shared.py @@ -0,0 +1,328 @@ +# Copyright 2009-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Bits and pieces used by the driver that don't really fit elsewhere.""" +from __future__ import annotations + +import sys +import traceback +from collections import abc +from typing import ( + TYPE_CHECKING, + Any, + Container, + Iterable, + Mapping, + NoReturn, + Optional, + Sequence, + Union, +) + +from pymongo import ASCENDING +from pymongo.errors import ( + CursorNotFound, + DuplicateKeyError, + ExecutionTimeout, + NotPrimaryError, + OperationFailure, + WriteConcernError, + WriteError, + WTimeoutError, + _wtimeout_error, +) +from pymongo.hello_compat import HelloCompat + +if TYPE_CHECKING: + from pymongo.cursor_shared import _Hint + from pymongo.operations import _IndexList + from pymongo.typings import _DocumentOut + +_IS_SYNC = False + +# From the SDAM spec, the "node is shutting down" codes. + +_SHUTDOWN_CODES: frozenset = frozenset( + [ + 11600, # InterruptedAtShutdown + 91, # ShutdownInProgress + ] +) +# From the SDAM spec, the "not primary" error codes are combined with the +# "node is recovering" error codes (of which the "node is shutting down" +# errors are a subset). +_NOT_PRIMARY_CODES: frozenset = ( + frozenset( + [ + 10058, # LegacyNotPrimary <=3.2 "not primary" error code + 10107, # NotWritablePrimary + 13435, # NotPrimaryNoSecondaryOk + 11602, # InterruptedDueToReplStateChange + 13436, # NotPrimaryOrSecondary + 189, # PrimarySteppedDown + ] + ) + | _SHUTDOWN_CODES +) +# From the retryable writes spec. +_RETRYABLE_ERROR_CODES: frozenset = _NOT_PRIMARY_CODES | frozenset( + [ + 7, # HostNotFound + 6, # HostUnreachable + 89, # NetworkTimeout + 9001, # SocketException + 262, # ExceededTimeLimit + 134, # ReadConcernMajorityNotAvailableYet + ] +) + +# Server code raised when re-authentication is required +_REAUTHENTICATION_REQUIRED_CODE: int = 391 + +# Server code raised when authentication fails. +_AUTHENTICATION_FAILURE_CODE: int = 18 + +# Note - to avoid bugs from forgetting which if these is all lowercase and +# which are camelCase, and at the same time avoid having to add a test for +# every command, use all lowercase here and test against command_name.lower(). +_SENSITIVE_COMMANDS: set = { + "authenticate", + "saslstart", + "saslcontinue", + "getnonce", + "createuser", + "updateuser", + "copydbgetnonce", + "copydbsaslstart", + "copydb", +} + + +def _gen_index_name(keys: _IndexList) -> str: + """Generate an index name from the set of fields it is over.""" + return "_".join(["{}_{}".format(*item) for item in keys]) + + +def _index_list( + key_or_list: _Hint, direction: Optional[Union[int, str]] = None +) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: + """Helper to generate a list of (key, direction) pairs. + + Takes such a list, or a single key, or a single key and direction. 
+ """ + if direction is not None: + if not isinstance(key_or_list, str): + raise TypeError("Expected a string and a direction") + return [(key_or_list, direction)] + else: + if isinstance(key_or_list, str): + return [(key_or_list, ASCENDING)] + elif isinstance(key_or_list, abc.ItemsView): + return list(key_or_list) # type: ignore[arg-type] + elif isinstance(key_or_list, abc.Mapping): + return list(key_or_list.items()) + elif not isinstance(key_or_list, (list, tuple)): + raise TypeError("if no direction is specified, key_or_list must be an instance of list") + values: list[tuple[str, int]] = [] + for item in key_or_list: + if isinstance(item, str): + item = (item, ASCENDING) # noqa: PLW2901 + values.append(item) + return values + + +def _index_document(index_list: _IndexList) -> dict[str, Any]: + """Helper to generate an index specifying document. + + Takes a list of (key, direction) pairs. + """ + if not isinstance(index_list, (list, tuple, abc.Mapping)): + raise TypeError( + "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) + ) + if not len(index_list): + raise ValueError("key_or_list must not be empty") + + index: dict[str, Any] = {} + + if isinstance(index_list, abc.Mapping): + for key in index_list: + value = index_list[key] + _validate_index_key_pair(key, value) + index[key] = value + else: + for item in index_list: + if isinstance(item, str): + item = (item, ASCENDING) # noqa: PLW2901 + key, value = item + _validate_index_key_pair(key, value) + index[key] = value + return index + + +def _validate_index_key_pair(key: Any, value: Any) -> None: + if not isinstance(key, str): + raise TypeError("first item in each key pair must be an instance of str") + if not isinstance(value, (str, int, abc.Mapping)): + raise TypeError( + "second item in each key pair must be 1, -1, " + "'2d', or another valid MongoDB index specifier." + ) + + +def _check_command_response( + response: _DocumentOut, + max_wire_version: Optional[int], + allowable_errors: Optional[Container[Union[int, str]]] = None, + parse_write_concern_error: bool = False, +) -> None: + """Check the response to a command for errors.""" + if "ok" not in response: + # Server didn't recognize our message as a command. + raise OperationFailure( + response.get("$err"), # type: ignore[arg-type] + response.get("code"), + response, + max_wire_version, + ) + + if parse_write_concern_error and "writeConcernError" in response: + _error = response["writeConcernError"] + _labels = response.get("errorLabels") + if _labels: + _error.update({"errorLabels": _labels}) + _raise_write_concern_error(_error) + + if response["ok"]: + return + + details = response + # Mongos returns the error details in a 'raw' object + # for some errors. + if "raw" in response: + for shard in response["raw"].values(): + # Grab the first non-empty raw error from a shard. + if shard.get("errmsg") and not shard.get("ok"): + details = shard + break + + errmsg = details["errmsg"] + code = details.get("code") + + # For allowable errors, only check for error messages when the code is not + # included. 
+ if allowable_errors: + if code is not None: + if code in allowable_errors: + return + elif errmsg in allowable_errors: + return + + # Server is "not primary" or "recovering" + if code is not None: + if code in _NOT_PRIMARY_CODES: + raise NotPrimaryError(errmsg, response) + elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg: + raise NotPrimaryError(errmsg, response) + + # Other errors + # findAndModify with upsert can raise duplicate key error + if code in (11000, 11001, 12582): + raise DuplicateKeyError(errmsg, code, response, max_wire_version) + elif code == 50: + raise ExecutionTimeout(errmsg, code, response, max_wire_version) + elif code == 43: + raise CursorNotFound(errmsg, code, response, max_wire_version) + + raise OperationFailure(errmsg, code, response, max_wire_version) + + +def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: + # If the last batch had multiple errors only report + # the last error to emulate continue_on_error. + error = write_errors[-1] + if error.get("code") == 11000: + raise DuplicateKeyError(error.get("errmsg"), 11000, error) + raise WriteError(error.get("errmsg"), error.get("code"), error) + + +def _raise_write_concern_error(error: Any) -> NoReturn: + if _wtimeout_error(error): + # Make sure we raise WTimeoutError + raise WTimeoutError(error.get("errmsg"), error.get("code"), error) + raise WriteConcernError(error.get("errmsg"), error.get("code"), error) + + +def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: + """Return the writeConcernError or None.""" + wce = result.get("writeConcernError") + if wce: + # The server reports errorLabels at the top level but it's more + # convenient to attach it to the writeConcernError doc itself. + error_labels = result.get("errorLabels") + if error_labels: + # Copy to avoid changing the original document. + wce = wce.copy() + wce["errorLabels"] = error_labels + return wce + + +def _check_write_command_response(result: Mapping[str, Any]) -> None: + """Backward compatibility helper for write command error handling.""" + # Prefer write errors over write concern errors + write_errors = result.get("writeErrors") + if write_errors: + _raise_last_write_error(write_errors) + + wce = _get_wce_doc(result) + if wce: + _raise_write_concern_error(wce) + + +def _fields_list_to_dict( + fields: Union[Mapping[str, Any], Iterable[str]], option_name: str +) -> Mapping[str, Any]: + """Takes a sequence of field names and returns a matching dictionary. + + ["a", "b"] becomes {"a": 1, "b": 1} + + and + + ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1} + """ + if isinstance(fields, abc.Mapping): + return fields + + if isinstance(fields, (abc.Sequence, abc.Set)): + if not all(isinstance(field, str) for field in fields): + raise TypeError(f"{option_name} must be a list of key names, each an instance of str") + return dict.fromkeys(fields, 1) + + raise TypeError(f"{option_name} must be a mapping or list of key names") + + +def _handle_exception() -> None: + """Print exceptions raised by subscribers to stderr.""" + # Heavily influenced by logging.Handler.handleError. 
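+    # For example, an exception raised by a user's CommandListener.started()
+    # is caught by the monitoring publishers and routed here, so one faulty
+    # subscriber cannot crash the driver.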
+ + # See note here: + # https://docs.python.org/3.4/library/sys.html#sys.__stderr__ + if sys.stderr: + einfo = sys.exc_info() + try: + traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr) + except OSError: + pass + finally: + del einfo diff --git a/pymongo/asynchronous/logger.py b/pymongo/logger.py similarity index 98% rename from pymongo/asynchronous/logger.py rename to pymongo/logger.py index 4fe8201273..ed398c8329 100644 --- a/pymongo/asynchronous/logger.py +++ b/pymongo/logger.py @@ -21,7 +21,7 @@ from bson import UuidRepresentation, json_util from bson.json_util import JSONOptions, _truncate_documents -from pymongo.asynchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason +from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason _IS_SYNC = False diff --git a/pymongo/asynchronous/max_staleness_selectors.py b/pymongo/max_staleness_selectors.py similarity index 98% rename from pymongo/asynchronous/max_staleness_selectors.py rename to pymongo/max_staleness_selectors.py index fadd3b429d..d9b2396a0c 100644 --- a/pymongo/asynchronous/max_staleness_selectors.py +++ b/pymongo/max_staleness_selectors.py @@ -34,7 +34,7 @@ from pymongo.server_type import SERVER_TYPE if TYPE_CHECKING: - from pymongo.asynchronous.server_selectors import Selection + from pymongo.server_selectors import Selection _IS_SYNC = False diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index b9825b4ca3..87451d5180 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -1,21 +1,1902 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2015-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""Tools to monitor driver events. + +.. versionadded:: 3.1 + +.. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below + are included in the PyMongo distribution under the + :mod:`~pymongo.event_loggers` submodule. + +Use :func:`register` to register global listeners for specific events. +Listeners must inherit from one of the abstract classes below and implement +the correct functions for that class. 
+ +For example, a simple command logger might be implemented like this:: + + import logging + + from pymongo import monitoring + + class CommandLogger(monitoring.CommandListener): + + def started(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} started on server " + "{0.connection_id}".format(event)) + + def succeeded(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "succeeded in {0.duration_micros} " + "microseconds".format(event)) + + def failed(self, event): + logging.info("Command {0.command_name} with request id " + "{0.request_id} on server {0.connection_id} " + "failed in {0.duration_micros} " + "microseconds".format(event)) + + monitoring.register(CommandLogger()) + +Server discovery and monitoring events are also available. For example:: + + class ServerLogger(monitoring.ServerListener): + + def opened(self, event): + logging.info("Server {0.server_address} added to topology " + "{0.topology_id}".format(event)) + + def description_changed(self, event): + previous_server_type = event.previous_description.server_type + new_server_type = event.new_description.server_type + if new_server_type != previous_server_type: + # server_type_name was added in PyMongo 3.4 + logging.info( + "Server {0.server_address} changed type from " + "{0.previous_description.server_type_name} to " + "{0.new_description.server_type_name}".format(event)) + + def closed(self, event): + logging.warning("Server {0.server_address} removed from topology " + "{0.topology_id}".format(event)) + + + class HeartbeatLogger(monitoring.ServerHeartbeatListener): + + def started(self, event): + logging.info("Heartbeat sent to server " + "{0.connection_id}".format(event)) + + def succeeded(self, event): + # The reply.document attribute was added in PyMongo 3.4. + logging.info("Heartbeat to server {0.connection_id} " + "succeeded with reply " + "{0.reply.document}".format(event)) + + def failed(self, event): + logging.warning("Heartbeat to server {0.connection_id} " + "failed with error {0.reply}".format(event)) + + class TopologyLogger(monitoring.TopologyListener): + + def opened(self, event): + logging.info("Topology with id {0.topology_id} " + "opened".format(event)) + + def description_changed(self, event): + logging.info("Topology description updated for " + "topology id {0.topology_id}".format(event)) + previous_topology_type = event.previous_description.topology_type + new_topology_type = event.new_description.topology_type + if new_topology_type != previous_topology_type: + # topology_type_name was added in PyMongo 3.4 + logging.info( + "Topology {0.topology_id} changed type from " + "{0.previous_description.topology_type_name} to " + "{0.new_description.topology_type_name}".format(event)) + # The has_writable_server and has_readable_server methods + # were added in PyMongo 3.4. + if not event.new_description.has_writable_server(): + logging.warning("No writable servers available.") + if not event.new_description.has_readable_server(): + logging.warning("No readable servers available.") + + def closed(self, event): + logging.info("Topology with id {0.topology_id} " + "closed".format(event)) + +Connection monitoring and pooling events are also available. 
For example:: + + class ConnectionPoolLogger(ConnectionPoolListener): + + def pool_created(self, event): + logging.info("[pool {0.address}] pool created".format(event)) + + def pool_ready(self, event): + logging.info("[pool {0.address}] pool is ready".format(event)) + + def pool_cleared(self, event): + logging.info("[pool {0.address}] pool cleared".format(event)) + + def pool_closed(self, event): + logging.info("[pool {0.address}] pool closed".format(event)) + + def connection_created(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection created".format(event)) + + def connection_ready(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection setup succeeded".format(event)) + + def connection_closed(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection closed, reason: " + "{0.reason}".format(event)) + + def connection_check_out_started(self, event): + logging.info("[pool {0.address}] connection check out " + "started".format(event)) + + def connection_check_out_failed(self, event): + logging.info("[pool {0.address}] connection check out " + "failed, reason: {0.reason}".format(event)) + + def connection_checked_out(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection checked out of pool".format(event)) + + def connection_checked_in(self, event): + logging.info("[pool {0.address}][connection #{0.connection_id}] " + "connection checked into pool".format(event)) + + +Event listeners can also be registered per instance of +:class:`~pymongo.mongo_client.MongoClient`:: + + client = MongoClient(event_listeners=[CommandLogger()]) + +Note that previously registered global listeners are automatically included +when configuring per client event listeners. Registering a new global listener +will not add that listener to existing client instances. + +.. note:: Events are delivered **synchronously**. Application threads block + waiting for event handlers (e.g. :meth:`~CommandListener.started`) to + return. Care must be taken to ensure that your event handlers are efficient + enough to not adversely affect overall application performance. + +.. warning:: The command documents published through this API are *not* copies. + If you intend to modify them in any way you must copy them in your event + handler first. 
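+
+For example, a listener that retains command documents for later inspection
+might copy them before storing them (a minimal sketch; ``CommandCapture`` is
+an illustrative name, not part of this API)::
+
+    class CommandCapture(monitoring.CommandListener):
+
+        def started(self, event):
+            # event.command is not a copy; copy it before keeping a reference.
+            self.last_command = dict(event.command)
+
+        def succeeded(self, event):
+            pass
+
+        def failed(self, event):
+            pass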
+""" -"""Re-import of synchronous Monitoring API for compatibility.""" from __future__ import annotations -from pymongo.synchronous.monitoring import * # noqa: F403 -from pymongo.synchronous.monitoring import __doc__ as original_doc +import datetime +from collections import abc, namedtuple +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence + +from bson.objectid import ObjectId +from pymongo.hello import Hello +from pymongo.hello_compat import HelloCompat +from pymongo.helpers_shared import _SENSITIVE_COMMANDS, _handle_exception +from pymongo.typings import _Address, _DocumentOut + +if TYPE_CHECKING: + from datetime import timedelta + + from pymongo.server_description import ServerDescription + from pymongo.topology_description import TopologyDescription + +_IS_SYNC = False + +_Listeners = namedtuple( + "_Listeners", + ( + "command_listeners", + "server_listeners", + "server_heartbeat_listeners", + "topology_listeners", + "cmap_listeners", + ), +) + +_LISTENERS = _Listeners([], [], [], [], []) + + +class _EventListener: + """Abstract base class for all event listeners.""" + + +class CommandListener(_EventListener): + """Abstract base class for command listeners. + + Handles `CommandStartedEvent`, `CommandSucceededEvent`, + and `CommandFailedEvent`. + """ + + def started(self, event: CommandStartedEvent) -> None: + """Abstract method to handle a `CommandStartedEvent`. + + :param event: An instance of :class:`CommandStartedEvent`. + """ + raise NotImplementedError + + def succeeded(self, event: CommandSucceededEvent) -> None: + """Abstract method to handle a `CommandSucceededEvent`. + + :param event: An instance of :class:`CommandSucceededEvent`. + """ + raise NotImplementedError + + def failed(self, event: CommandFailedEvent) -> None: + """Abstract method to handle a `CommandFailedEvent`. + + :param event: An instance of :class:`CommandFailedEvent`. + """ + raise NotImplementedError + + +class ConnectionPoolListener(_EventListener): + """Abstract base class for connection pool listeners. + + Handles all of the connection pool events defined in the Connection + Monitoring and Pooling Specification: + :class:`PoolCreatedEvent`, :class:`PoolClearedEvent`, + :class:`PoolClosedEvent`, :class:`ConnectionCreatedEvent`, + :class:`ConnectionReadyEvent`, :class:`ConnectionClosedEvent`, + :class:`ConnectionCheckOutStartedEvent`, + :class:`ConnectionCheckOutFailedEvent`, + :class:`ConnectionCheckedOutEvent`, + and :class:`ConnectionCheckedInEvent`. + + .. versionadded:: 3.9 + """ + + def pool_created(self, event: PoolCreatedEvent) -> None: + """Abstract method to handle a :class:`PoolCreatedEvent`. + + Emitted when a connection Pool is created. + + :param event: An instance of :class:`PoolCreatedEvent`. + """ + raise NotImplementedError + + def pool_ready(self, event: PoolReadyEvent) -> None: + """Abstract method to handle a :class:`PoolReadyEvent`. + + Emitted when a connection Pool is marked ready. + + :param event: An instance of :class:`PoolReadyEvent`. + + .. versionadded:: 4.0 + """ + raise NotImplementedError + + def pool_cleared(self, event: PoolClearedEvent) -> None: + """Abstract method to handle a `PoolClearedEvent`. + + Emitted when a connection Pool is cleared. + + :param event: An instance of :class:`PoolClearedEvent`. + """ + raise NotImplementedError + + def pool_closed(self, event: PoolClosedEvent) -> None: + """Abstract method to handle a `PoolClosedEvent`. + + Emitted when a connection Pool is closed. + + :param event: An instance of :class:`PoolClosedEvent`. 
+ """ + raise NotImplementedError + + def connection_created(self, event: ConnectionCreatedEvent) -> None: + """Abstract method to handle a :class:`ConnectionCreatedEvent`. + + Emitted when a connection Pool creates a Connection object. + + :param event: An instance of :class:`ConnectionCreatedEvent`. + """ + raise NotImplementedError + + def connection_ready(self, event: ConnectionReadyEvent) -> None: + """Abstract method to handle a :class:`ConnectionReadyEvent`. + + Emitted when a connection has finished its setup, and is now ready to + use. + + :param event: An instance of :class:`ConnectionReadyEvent`. + """ + raise NotImplementedError + + def connection_closed(self, event: ConnectionClosedEvent) -> None: + """Abstract method to handle a :class:`ConnectionClosedEvent`. + + Emitted when a connection Pool closes a connection. + + :param event: An instance of :class:`ConnectionClosedEvent`. + """ + raise NotImplementedError + + def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. + + Emitted when the driver starts attempting to check out a connection. + + :param event: An instance of :class:`ConnectionCheckOutStartedEvent`. + """ + raise NotImplementedError + + def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. + + Emitted when the driver's attempt to check out a connection fails. + + :param event: An instance of :class:`ConnectionCheckOutFailedEvent`. + """ + raise NotImplementedError + + def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. + + Emitted when the driver successfully checks out a connection. + + :param event: An instance of :class:`ConnectionCheckedOutEvent`. + """ + raise NotImplementedError + + def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: + """Abstract method to handle a :class:`ConnectionCheckedInEvent`. + + Emitted when the driver checks in a connection back to the connection + Pool. + + :param event: An instance of :class:`ConnectionCheckedInEvent`. + """ + raise NotImplementedError + + +class ServerHeartbeatListener(_EventListener): + """Abstract base class for server heartbeat listeners. + + Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, + and `ServerHeartbeatFailedEvent`. + + .. versionadded:: 3.3 + """ + + def started(self, event: ServerHeartbeatStartedEvent) -> None: + """Abstract method to handle a `ServerHeartbeatStartedEvent`. + + :param event: An instance of :class:`ServerHeartbeatStartedEvent`. + """ + raise NotImplementedError + + def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: + """Abstract method to handle a `ServerHeartbeatSucceededEvent`. + + :param event: An instance of :class:`ServerHeartbeatSucceededEvent`. + """ + raise NotImplementedError + + def failed(self, event: ServerHeartbeatFailedEvent) -> None: + """Abstract method to handle a `ServerHeartbeatFailedEvent`. + + :param event: An instance of :class:`ServerHeartbeatFailedEvent`. + """ + raise NotImplementedError + + +class TopologyListener(_EventListener): + """Abstract base class for topology monitoring listeners. + Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and + `TopologyClosedEvent`. + + .. 
versionadded:: 3.3
+    """
+
+    def opened(self, event: TopologyOpenedEvent) -> None:
+        """Abstract method to handle a `TopologyOpenedEvent`.
+
+        :param event: An instance of :class:`TopologyOpenedEvent`.
+        """
+        raise NotImplementedError
+
+    def description_changed(self, event: TopologyDescriptionChangedEvent) -> None:
+        """Abstract method to handle a `TopologyDescriptionChangedEvent`.
+
+        :param event: An instance of :class:`TopologyDescriptionChangedEvent`.
+        """
+        raise NotImplementedError
+
+    def closed(self, event: TopologyClosedEvent) -> None:
+        """Abstract method to handle a `TopologyClosedEvent`.
+
+        :param event: An instance of :class:`TopologyClosedEvent`.
+        """
+        raise NotImplementedError
+
+
+class ServerListener(_EventListener):
+    """Abstract base class for server listeners.
+    Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and
+    `ServerClosedEvent`.
+
+    .. versionadded:: 3.3
+    """
+
+    def opened(self, event: ServerOpeningEvent) -> None:
+        """Abstract method to handle a `ServerOpeningEvent`.
+
+        :param event: An instance of :class:`ServerOpeningEvent`.
+        """
+        raise NotImplementedError
+
+    def description_changed(self, event: ServerDescriptionChangedEvent) -> None:
+        """Abstract method to handle a `ServerDescriptionChangedEvent`.
+
+        :param event: An instance of :class:`ServerDescriptionChangedEvent`.
+        """
+        raise NotImplementedError
+
+    def closed(self, event: ServerClosedEvent) -> None:
+        """Abstract method to handle a `ServerClosedEvent`.
+
+        :param event: An instance of :class:`ServerClosedEvent`.
+        """
+        raise NotImplementedError
+
+
+def _to_micros(dur: timedelta) -> int:
+    """Convert duration 'dur' to microseconds."""
+    return int(dur.total_seconds() * 10e5)
+
+
+def _validate_event_listeners(
+    option: str, listeners: Sequence[_EventListener]
+) -> Sequence[_EventListener]:
+    """Validate event listeners."""
+    if not isinstance(listeners, abc.Sequence):
+        raise TypeError(f"{option} must be a list or tuple")
+    for listener in listeners:
+        if not isinstance(listener, _EventListener):
+            raise TypeError(
+                f"Listeners for {option} must be either a "
+                "CommandListener, ServerHeartbeatListener, "
+                "ServerListener, TopologyListener, or "
+                "ConnectionPoolListener."
+            )
+    return listeners
+
+
+def register(listener: _EventListener) -> None:
+    """Register a global event listener.
+
+    :param listener: A subclass of :class:`CommandListener`,
+        :class:`ServerHeartbeatListener`, :class:`ServerListener`,
+        :class:`TopologyListener`, or :class:`ConnectionPoolListener`.
+    """
+    if not isinstance(listener, _EventListener):
+        raise TypeError(
+            f"Listeners for {listener} must be either a "
+            "CommandListener, ServerHeartbeatListener, "
+            "ServerListener, TopologyListener, or "
+            "ConnectionPoolListener."
+        )
+    if isinstance(listener, CommandListener):
+        _LISTENERS.command_listeners.append(listener)
+    if isinstance(listener, ServerHeartbeatListener):
+        _LISTENERS.server_heartbeat_listeners.append(listener)
+    if isinstance(listener, ServerListener):
+        _LISTENERS.server_listeners.append(listener)
+    if isinstance(listener, TopologyListener):
+        _LISTENERS.topology_listeners.append(listener)
+    if isinstance(listener, ConnectionPoolListener):
+        _LISTENERS.cmap_listeners.append(listener)
+
+
+# The "hello" command is also deemed sensitive when attempting speculative
+# authentication.
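+# For example, {"hello": 1, "speculativeAuthenticate": {...}} is redacted in
+# command events just like "authenticate" itself, while a plain {"hello": 1}
+# is published unredacted.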
+def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: + if ( + command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) + and "speculativeAuthenticate" in doc + ): + return True + return False + + +class _CommandEvent: + """Base class for command events.""" + + __slots__ = ( + "__cmd_name", + "__rqst_id", + "__conn_id", + "__op_id", + "__service_id", + "__db", + "__server_conn_id", + ) + + def __init__( + self, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + database_name: str = "", + server_connection_id: Optional[int] = None, + ) -> None: + self.__cmd_name = command_name + self.__rqst_id = request_id + self.__conn_id = connection_id + self.__op_id = operation_id + self.__service_id = service_id + self.__db = database_name + self.__server_conn_id = server_connection_id + + @property + def command_name(self) -> str: + """The command name.""" + return self.__cmd_name + + @property + def request_id(self) -> int: + """The request id for this operation.""" + return self.__rqst_id + + @property + def connection_id(self) -> _Address: + """The address (host, port) of the server this command was sent to.""" + return self.__conn_id + + @property + def service_id(self) -> Optional[ObjectId]: + """The service_id this command was sent to, or ``None``. + + .. versionadded:: 3.12 + """ + return self.__service_id + + @property + def operation_id(self) -> Optional[int]: + """An id for this series of events or None.""" + return self.__op_id + + @property + def database_name(self) -> str: + """The database_name this command was sent to, or ``""``. + + .. versionadded:: 4.6 + """ + return self.__db + + @property + def server_connection_id(self) -> Optional[int]: + """The server-side connection id for the connection this command was sent on, or ``None``. + + .. versionadded:: 4.7 + """ + return self.__server_conn_id + + +class CommandStartedEvent(_CommandEvent): + """Event published when a command starts. + + :param command: The command document. + :param database_name: The name of the database this command was run against. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this command + was sent to. + :param operation_id: An optional identifier for a series of related events. + :param service_id: The service_id this command was sent to, or ``None``. + """ + + __slots__ = ("__cmd",) + + def __init__( + self, + command: _DocumentOut, + database_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + server_connection_id: Optional[int] = None, + ) -> None: + if not command: + raise ValueError(f"{command!r} is not a valid command") + # Command name must be first key. 
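+        # e.g. {"find": "coll", "filter": {}} has command_name "find".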
+        command_name = next(iter(command))
+        super().__init__(
+            command_name,
+            request_id,
+            connection_id,
+            operation_id,
+            service_id=service_id,
+            database_name=database_name,
+            server_connection_id=server_connection_id,
+        )
+        cmd_name = command_name.lower()
+        if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command):
+            self.__cmd: _DocumentOut = {}
+        else:
+            self.__cmd = command
+
+    @property
+    def command(self) -> _DocumentOut:
+        """The command document."""
+        return self.__cmd
+
+    @property
+    def database_name(self) -> str:
+        """The name of the database this command was run against."""
+        return super().database_name
+
+    def __repr__(self) -> str:
+        return (
+            "<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}, server_connection_id: {}>"
+        ).format(
+            self.__class__.__name__,
+            self.connection_id,
+            self.database_name,
+            self.command_name,
+            self.operation_id,
+            self.service_id,
+            self.server_connection_id,
+        )
+
+
+class CommandSucceededEvent(_CommandEvent):
+    """Event published when a command succeeds.
+
+    :param duration: The command duration as a datetime.timedelta.
+    :param reply: The server reply document.
+    :param command_name: The command name.
+    :param request_id: The request id for this operation.
+    :param connection_id: The address (host, port) of the server this command
+        was sent to.
+    :param operation_id: An optional identifier for a series of related events.
+    :param service_id: The service_id this command was sent to, or ``None``.
+    :param database_name: The database this command was sent to, or ``""``.
+    """
+
+    __slots__ = ("__duration_micros", "__reply")
+
+    def __init__(
+        self,
+        duration: datetime.timedelta,
+        reply: _DocumentOut,
+        command_name: str,
+        request_id: int,
+        connection_id: _Address,
+        operation_id: Optional[int],
+        service_id: Optional[ObjectId] = None,
+        database_name: str = "",
+        server_connection_id: Optional[int] = None,
+    ) -> None:
+        super().__init__(
+            command_name,
+            request_id,
+            connection_id,
+            operation_id,
+            service_id=service_id,
+            database_name=database_name,
+            server_connection_id=server_connection_id,
+        )
+        self.__duration_micros = _to_micros(duration)
+        cmd_name = command_name.lower()
+        if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply):
+            self.__reply: _DocumentOut = {}
+        else:
+            self.__reply = reply
+
+    @property
+    def duration_micros(self) -> int:
+        """The duration of this operation in microseconds."""
+        return self.__duration_micros
+
+    @property
+    def reply(self) -> _DocumentOut:
+        """The server reply document for this operation."""
+        return self.__reply
+
+    def __repr__(self) -> str:
+        return (
+            "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}, server_connection_id: {}>"
+        ).format(
+            self.__class__.__name__,
+            self.connection_id,
+            self.database_name,
+            self.command_name,
+            self.operation_id,
+            self.duration_micros,
+            self.service_id,
+            self.server_connection_id,
+        )
+
+
+class CommandFailedEvent(_CommandEvent):
+    """Event published when a command fails.
+
+    :param duration: The command duration as a datetime.timedelta.
+    :param failure: The server reply document.
+    :param command_name: The command name.
+    :param request_id: The request id for this operation.
+    :param connection_id: The address (host, port) of the server this command
+        was sent to.
+    :param operation_id: An optional identifier for a series of related events.
+    :param service_id: The service_id this command was sent to, or ``None``.
+ :param database_name: The database this command was sent to, or ``""``. + """ + + __slots__ = ("__duration_micros", "__failure") + + def __init__( + self, + duration: datetime.timedelta, + failure: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + operation_id: Optional[int], + service_id: Optional[ObjectId] = None, + database_name: str = "", + server_connection_id: Optional[int] = None, + ) -> None: + super().__init__( + command_name, + request_id, + connection_id, + operation_id, + service_id=service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + self.__duration_micros = _to_micros(duration) + self.__failure = failure + + @property + def duration_micros(self) -> int: + """The duration of this operation in microseconds.""" + return self.__duration_micros + + @property + def failure(self) -> _DocumentOut: + """The server failure document for this operation.""" + return self.__failure + + def __repr__(self) -> str: + return ( + "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, " + "failure: {!r}, service_id: {}, server_connection_id: {}>" + ).format( + self.__class__.__name__, + self.connection_id, + self.database_name, + self.command_name, + self.operation_id, + self.duration_micros, + self.failure, + self.service_id, + self.server_connection_id, + ) + + +class _PoolEvent: + """Base class for pool events.""" + + __slots__ = ("__address",) + + def __init__(self, address: _Address) -> None: + self.__address = address + + @property + def address(self) -> _Address: + """The address (host, port) pair of the server the pool is attempting + to connect to. + """ + return self.__address + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__address!r})" + + +class PoolCreatedEvent(_PoolEvent): + """Published when a Connection Pool is created. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__options",) + + def __init__(self, address: _Address, options: dict[str, Any]) -> None: + super().__init__(address) + self.__options = options + + @property + def options(self) -> dict[str, Any]: + """Any non-default pool options that were set on this Connection Pool.""" + return self.__options + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" + + +class PoolReadyEvent(_PoolEvent): + """Published when a Connection Pool is marked ready. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + + .. versionadded:: 4.0 + """ + + __slots__ = () + + +class PoolClearedEvent(_PoolEvent): + """Published when a Connection Pool is cleared. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + :param service_id: The service_id this command was sent to, or ``None``. + :param interrupt_connections: True if all active connections were interrupted by the Pool during clearing. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__service_id", "__interrupt_connections") + + def __init__( + self, + address: _Address, + service_id: Optional[ObjectId] = None, + interrupt_connections: bool = False, + ) -> None: + super().__init__(address) + self.__service_id = service_id + self.__interrupt_connections = interrupt_connections + + @property + def service_id(self) -> Optional[ObjectId]: + """Connections with this service_id are cleared. 
+ + When service_id is ``None``, all connections in the pool are cleared. + + .. versionadded:: 3.12 + """ + return self.__service_id + + @property + def interrupt_connections(self) -> bool: + """If True, active connections are interrupted during clearing. + + .. versionadded:: 4.7 + """ + return self.__interrupt_connections + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r}, {self.__interrupt_connections!r})" + + +class PoolClosedEvent(_PoolEvent): + """Published when a Connection Pool is closed. + + :param address: The address (host, port) pair of the server this Pool is + attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionClosedReason: + """An enum that defines values for `reason` on a + :class:`ConnectionClosedEvent`. + + .. versionadded:: 3.9 + """ + + STALE = "stale" + """The pool was cleared, making the connection no longer valid.""" + + IDLE = "idle" + """The connection became stale by being idle for too long (maxIdleTimeMS). + """ + + ERROR = "error" + """The connection experienced an error, making it no longer valid.""" + + POOL_CLOSED = "poolClosed" + """The pool was closed, making the connection no longer valid.""" + + +class ConnectionCheckOutFailedReason: + """An enum that defines values for `reason` on a + :class:`ConnectionCheckOutFailedEvent`. + + .. versionadded:: 3.9 + """ + + TIMEOUT = "timeout" + """The connection check out attempt exceeded the specified timeout.""" + + POOL_CLOSED = "poolClosed" + """The pool was previously closed, and cannot provide new connections.""" + + CONN_ERROR = "connectionError" + """The connection check out attempt experienced an error while setting up + a new connection. + """ + + +class _ConnectionEvent: + """Private base class for connection events.""" + + __slots__ = ("__address",) + + def __init__(self, address: _Address) -> None: + self.__address = address + + @property + def address(self) -> _Address: + """The address (host, port) pair of the server this connection is + attempting to connect to. + """ + return self.__address + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.__address!r})" + + +class _ConnectionIdEvent(_ConnectionEvent): + """Private base class for connection events with an id.""" + + __slots__ = ("__connection_id",) + + def __init__(self, address: _Address, connection_id: int) -> None: + super().__init__(address) + self.__connection_id = connection_id + + @property + def connection_id(self) -> int: + """The ID of the connection.""" + return self.__connection_id + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" + + +class _ConnectionDurationEvent(_ConnectionIdEvent): + """Private base class for connection events with a duration.""" + + __slots__ = ("__duration",) + + def __init__(self, address: _Address, connection_id: int, duration: Optional[float]) -> None: + super().__init__(address, connection_id) + self.__duration = duration + + @property + def duration(self) -> Optional[float]: + """The duration of the connection event. + + .. versionadded:: 4.7 + """ + return self.__duration + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.connection_id!r}, {self.__duration!r})" + + +class ConnectionCreatedEvent(_ConnectionIdEvent): + """Published when a Connection Pool creates a Connection object. + + NOTE: This connection is not ready for use until the + :class:`ConnectionReadyEvent` is published. 
+ + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionReadyEvent(_ConnectionDurationEvent): + """Published when a Connection has finished its setup, and is ready to use. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionClosedEvent(_ConnectionIdEvent): + """Published when a Connection is closed. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + :param reason: A reason explaining why this connection was closed. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__reason",) + + def __init__(self, address: _Address, connection_id: int, reason: str): + super().__init__(address, connection_id) + self.__reason = reason + + @property + def reason(self) -> str: + """A reason explaining why this connection was closed. + + The reason must be one of the strings from the + :class:`ConnectionClosedReason` enum. + """ + return self.__reason + + def __repr__(self) -> str: + return "{}({!r}, {!r}, {!r})".format( + self.__class__.__name__, + self.address, + self.connection_id, + self.__reason, + ) + + +class ConnectionCheckOutStartedEvent(_ConnectionEvent): + """Published when the driver starts attempting to check out a connection. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionCheckOutFailedEvent(_ConnectionDurationEvent): + """Published when the driver's attempt to check out a connection fails. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param reason: A reason explaining why connection check out failed. + + .. versionadded:: 3.9 + """ + + __slots__ = ("__reason",) + + def __init__(self, address: _Address, reason: str, duration: Optional[float]) -> None: + super().__init__(address=address, connection_id=0, duration=duration) + self.__reason = reason + + @property + def reason(self) -> str: + """A reason explaining why connection check out failed. + + The reason must be one of the strings from the + :class:`ConnectionCheckOutFailedReason` enum. + """ + return self.__reason + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r}, {self.duration!r})" + + +class ConnectionCheckedOutEvent(_ConnectionDurationEvent): + """Published when the driver successfully checks out a connection. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. versionadded:: 3.9 + """ + + __slots__ = () + + +class ConnectionCheckedInEvent(_ConnectionIdEvent): + """Published when the driver checks in a Connection into the Pool. + + :param address: The address (host, port) pair of the server this + Connection is attempting to connect to. + :param connection_id: The integer ID of the Connection in this Pool. + + .. 
versionadded:: 3.9 + """ + + __slots__ = () + + +class _ServerEvent: + """Base class for server events.""" + + __slots__ = ("__server_address", "__topology_id") + + def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: + self.__server_address = server_address + self.__topology_id = topology_id + + @property + def server_address(self) -> _Address: + """The address (host, port) pair of the server""" + return self.__server_address + + @property + def topology_id(self) -> ObjectId: + """A unique identifier for the topology this server is a part of.""" + return self.__topology_id + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" + + +class ServerDescriptionChangedEvent(_ServerEvent): + """Published when server description changes. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__previous_description", "__new_description") + + def __init__( + self, + previous_description: ServerDescription, + new_description: ServerDescription, + *args: Any, + ) -> None: + super().__init__(*args) + self.__previous_description = previous_description + self.__new_description = new_description + + @property + def previous_description(self) -> ServerDescription: + """The previous + :class:`~pymongo.server_description.ServerDescription`. + """ + return self.__previous_description + + @property + def new_description(self) -> ServerDescription: + """The new + :class:`~pymongo.server_description.ServerDescription`. + """ + return self.__new_description + + def __repr__(self) -> str: + return "<{} {} changed from: {}, to: {}>".format( + self.__class__.__name__, + self.server_address, + self.previous_description, + self.new_description, + ) + + +class ServerOpeningEvent(_ServerEvent): + """Published when server is initialized. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class ServerClosedEvent(_ServerEvent): + """Published when server is closed. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class TopologyEvent: + """Base class for topology description events.""" + + __slots__ = ("__topology_id",) + + def __init__(self, topology_id: ObjectId) -> None: + self.__topology_id = topology_id + + @property + def topology_id(self) -> ObjectId: + """A unique identifier for the topology this server is a part of.""" + return self.__topology_id + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" + + +class TopologyDescriptionChangedEvent(TopologyEvent): + """Published when the topology description changes. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__previous_description", "__new_description") + + def __init__( + self, + previous_description: TopologyDescription, + new_description: TopologyDescription, + *args: Any, + ) -> None: + super().__init__(*args) + self.__previous_description = previous_description + self.__new_description = new_description + + @property + def previous_description(self) -> TopologyDescription: + """The previous + :class:`~pymongo.topology_description.TopologyDescription`. + """ + return self.__previous_description + + @property + def new_description(self) -> TopologyDescription: + """The new + :class:`~pymongo.topology_description.TopologyDescription`. 
+ """ + return self.__new_description + + def __repr__(self) -> str: + return "<{} topology_id: {} changed from: {}, to: {}>".format( + self.__class__.__name__, + self.topology_id, + self.previous_description, + self.new_description, + ) + + +class TopologyOpenedEvent(TopologyEvent): + """Published when the topology is initialized. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class TopologyClosedEvent(TopologyEvent): + """Published when the topology is closed. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class _ServerHeartbeatEvent: + """Base class for server heartbeat events.""" + + __slots__ = ("__connection_id", "__awaited") + + def __init__(self, connection_id: _Address, awaited: bool = False) -> None: + self.__connection_id = connection_id + self.__awaited = awaited + + @property + def connection_id(self) -> _Address: + """The address (host, port) of the server this heartbeat was sent + to. + """ + return self.__connection_id + + @property + def awaited(self) -> bool: + """Whether the heartbeat was issued as an awaitable hello command. + + .. versionadded:: 4.6 + """ + return self.__awaited + + def __repr__(self) -> str: + return f"<{self.__class__.__name__} {self.connection_id} awaited: {self.awaited}>" + + +class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): + """Published when a heartbeat is started. + + .. versionadded:: 3.3 + """ + + __slots__ = () + + +class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): + """Fired when the server heartbeat succeeds. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__duration", "__reply") + + def __init__( + self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False + ) -> None: + super().__init__(connection_id, awaited) + self.__duration = duration + self.__reply = reply + + @property + def duration(self) -> float: + """The duration of this heartbeat in microseconds.""" + return self.__duration + + @property + def reply(self) -> Hello: + """An instance of :class:`~pymongo.hello.Hello`.""" + return self.__reply + + @property + def awaited(self) -> bool: + """Whether the heartbeat was awaited. + + If true, then :meth:`duration` reflects the sum of the round trip time + to the server and the time that the server waited before sending a + response. + + .. versionadded:: 3.11 + """ + return super().awaited + + def __repr__(self) -> str: + return "<{} {} duration: {}, awaited: {}, reply: {}>".format( + self.__class__.__name__, + self.connection_id, + self.duration, + self.awaited, + self.reply, + ) + + +class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): + """Fired when the server heartbeat fails, either with an "ok: 0" + or a socket exception. + + .. versionadded:: 3.3 + """ + + __slots__ = ("__duration", "__reply") + + def __init__( + self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False + ) -> None: + super().__init__(connection_id, awaited) + self.__duration = duration + self.__reply = reply + + @property + def duration(self) -> float: + """The duration of this heartbeat in microseconds.""" + return self.__duration + + @property + def reply(self) -> Exception: + """A subclass of :exc:`Exception`.""" + return self.__reply + + @property + def awaited(self) -> bool: + """Whether the heartbeat was awaited. + + If true, then :meth:`duration` reflects the sum of the round trip time + to the server and the time that the server waited before sending a + response. + + .. 
versionadded:: 3.11 + """ + return super().awaited + + def __repr__(self) -> str: + return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( + self.__class__.__name__, + self.connection_id, + self.duration, + self.awaited, + self.reply, + ) + + +class _EventListeners: + """Configure event listeners for a client instance. + + Any event listeners registered globally are included by default. + + :param listeners: A list of event listeners. + """ + + def __init__(self, listeners: Optional[Sequence[_EventListener]]): + self.__command_listeners = _LISTENERS.command_listeners[:] + self.__server_listeners = _LISTENERS.server_listeners[:] + lst = _LISTENERS.server_heartbeat_listeners + self.__server_heartbeat_listeners = lst[:] + self.__topology_listeners = _LISTENERS.topology_listeners[:] + self.__cmap_listeners = _LISTENERS.cmap_listeners[:] + if listeners is not None: + for lst in listeners: + if isinstance(lst, CommandListener): + self.__command_listeners.append(lst) + if isinstance(lst, ServerListener): + self.__server_listeners.append(lst) + if isinstance(lst, ServerHeartbeatListener): + self.__server_heartbeat_listeners.append(lst) + if isinstance(lst, TopologyListener): + self.__topology_listeners.append(lst) + if isinstance(lst, ConnectionPoolListener): + self.__cmap_listeners.append(lst) + self.__enabled_for_commands = bool(self.__command_listeners) + self.__enabled_for_server = bool(self.__server_listeners) + self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners) + self.__enabled_for_topology = bool(self.__topology_listeners) + self.__enabled_for_cmap = bool(self.__cmap_listeners) + + @property + def enabled_for_commands(self) -> bool: + """Are any CommandListener instances registered?""" + return self.__enabled_for_commands + + @property + def enabled_for_server(self) -> bool: + """Are any ServerListener instances registered?""" + return self.__enabled_for_server + + @property + def enabled_for_server_heartbeat(self) -> bool: + """Are any ServerHeartbeatListener instances registered?""" + return self.__enabled_for_server_heartbeat + + @property + def enabled_for_topology(self) -> bool: + """Are any TopologyListener instances registered?""" + return self.__enabled_for_topology + + @property + def enabled_for_cmap(self) -> bool: + """Are any ConnectionPoolListener instances registered?""" + return self.__enabled_for_cmap + + def event_listeners(self) -> list[_EventListeners]: + """List of registered event listeners.""" + return ( + self.__command_listeners + + self.__server_heartbeat_listeners + + self.__server_listeners + + self.__topology_listeners + + self.__cmap_listeners + ) + + def publish_command_start( + self, + command: _DocumentOut, + database_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + ) -> None: + """Publish a CommandStartedEvent to all command listeners. + + :param command: The command document. + :param database_name: The name of the database this command was run + against. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. 
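+        :param server_connection_id: The server-side connection id for the
+            connection this command was sent on, or ``None``.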
+ """ + if op_id is None: + op_id = request_id + event = CommandStartedEvent( + command, + database_name, + request_id, + connection_id, + op_id, + service_id=service_id, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.started(event) + except Exception: + _handle_exception() + + def publish_command_success( + self, + duration: timedelta, + reply: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + speculative_hello: bool = False, + database_name: str = "", + ) -> None: + """Publish a CommandSucceededEvent to all command listeners. + + :param duration: The command duration as a datetime.timedelta. + :param reply: The server reply document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. + :param speculative_hello: Was the command sent with speculative auth? + :param database_name: The database this command was sent to, or ``""``. + """ + if op_id is None: + op_id = request_id + if speculative_hello: + # Redact entire response when the command started contained + # speculativeAuthenticate. + reply = {} + event = CommandSucceededEvent( + duration, + reply, + command_name, + request_id, + connection_id, + op_id, + service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.succeeded(event) + except Exception: + _handle_exception() + + def publish_command_failure( + self, + duration: timedelta, + failure: _DocumentOut, + command_name: str, + request_id: int, + connection_id: _Address, + server_connection_id: Optional[int], + op_id: Optional[int] = None, + service_id: Optional[ObjectId] = None, + database_name: str = "", + ) -> None: + """Publish a CommandFailedEvent to all command listeners. + + :param duration: The command duration as a datetime.timedelta. + :param failure: The server reply document or failure description + document. + :param command_name: The command name. + :param request_id: The request id for this operation. + :param connection_id: The address (host, port) of the server this + command was sent to. + :param op_id: The (optional) operation id for this operation. + :param service_id: The service_id this command was sent to, or ``None``. + :param database_name: The database this command was sent to, or ``""``. + """ + if op_id is None: + op_id = request_id + event = CommandFailedEvent( + duration, + failure, + command_name, + request_id, + connection_id, + op_id, + service_id=service_id, + database_name=database_name, + server_connection_id=server_connection_id, + ) + for subscriber in self.__command_listeners: + try: + subscriber.failed(event) + except Exception: + _handle_exception() + + def publish_server_heartbeat_started(self, connection_id: _Address, awaited: bool) -> None: + """Publish a ServerHeartbeatStartedEvent to all server heartbeat + listeners. + + :param connection_id: The address (host, port) pair of the connection. + :param awaited: True if this heartbeat is part of an awaitable hello command. 
+        """
+        event = ServerHeartbeatStartedEvent(connection_id, awaited)
+        for subscriber in self.__server_heartbeat_listeners:
+            try:
+                subscriber.started(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_heartbeat_succeeded(
+        self, connection_id: _Address, duration: float, reply: Hello, awaited: bool
+    ) -> None:
+        """Publish a ServerHeartbeatSucceededEvent to all server heartbeat
+        listeners.
+
+        :param connection_id: The address (host, port) pair of the connection.
+        :param duration: The execution time of the event in the highest possible
+            resolution for the platform.
+        :param reply: The command reply.
+        :param awaited: True if the response was awaited.
+        """
+        event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited)
+        for subscriber in self.__server_heartbeat_listeners:
+            try:
+                subscriber.succeeded(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_heartbeat_failed(
+        self, connection_id: _Address, duration: float, reply: Exception, awaited: bool
+    ) -> None:
+        """Publish a ServerHeartbeatFailedEvent to all server heartbeat
+        listeners.
+
+        :param connection_id: The address (host, port) pair of the connection.
+        :param duration: The execution time of the event in the highest possible
+            resolution for the platform.
+        :param reply: The command reply.
+        :param awaited: True if the response was awaited.
+        """
+        event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited)
+        for subscriber in self.__server_heartbeat_listeners:
+            try:
+                subscriber.failed(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None:
+        """Publish a ServerOpeningEvent to all server listeners.
+
+        :param server_address: The address (host, port) pair of the server.
+        :param topology_id: A unique identifier for the topology this server
+            is a part of.
+        """
+        event = ServerOpeningEvent(server_address, topology_id)
+        for subscriber in self.__server_listeners:
+            try:
+                subscriber.opened(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None:
+        """Publish a ServerClosedEvent to all server listeners.
+
+        :param server_address: The address (host, port) pair of the server.
+        :param topology_id: A unique identifier for the topology this server
+            is a part of.
+        """
+        event = ServerClosedEvent(server_address, topology_id)
+        for subscriber in self.__server_listeners:
+            try:
+                subscriber.closed(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_server_description_changed(
+        self,
+        previous_description: ServerDescription,
+        new_description: ServerDescription,
+        server_address: _Address,
+        topology_id: ObjectId,
+    ) -> None:
+        """Publish a ServerDescriptionChangedEvent to all server listeners.
+
+        :param previous_description: The previous server description.
+        :param new_description: The new server description.
+        :param server_address: The address (host, port) pair of the server.
+        :param topology_id: A unique identifier for the topology this server
+            is a part of.
+        """
+        event = ServerDescriptionChangedEvent(
+            previous_description, new_description, server_address, topology_id
+        )
+        for subscriber in self.__server_listeners:
+            try:
+                subscriber.description_changed(event)
+            except Exception:
+                _handle_exception()
+
+    def publish_topology_opened(self, topology_id: ObjectId) -> None:
+        """Publish a TopologyOpenedEvent to all topology listeners.
+ + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = TopologyOpenedEvent(topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.opened(event) + except Exception: + _handle_exception() + + def publish_topology_closed(self, topology_id: ObjectId) -> None: + """Publish a TopologyClosedEvent to all topology listeners. + + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = TopologyClosedEvent(topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.closed(event) + except Exception: + _handle_exception() + + def publish_topology_description_changed( + self, + previous_description: TopologyDescription, + new_description: TopologyDescription, + topology_id: ObjectId, + ) -> None: + """Publish a TopologyDescriptionChangedEvent to all topology listeners. + + :param previous_description: The previous topology description. + :param new_description: The new topology description. + :param topology_id: A unique identifier for the topology this server + is a part of. + """ + event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) + for subscriber in self.__topology_listeners: + try: + subscriber.description_changed(event) + except Exception: + _handle_exception() + + def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: + """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" + event = PoolCreatedEvent(address, options) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_created(event) + except Exception: + _handle_exception() + + def publish_pool_ready(self, address: _Address) -> None: + """Publish a :class:`PoolReadyEvent` to all pool listeners.""" + event = PoolReadyEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_ready(event) + except Exception: + _handle_exception() + + def publish_pool_cleared( + self, + address: _Address, + service_id: Optional[ObjectId], + interrupt_connections: bool = False, + ) -> None: + """Publish a :class:`PoolClearedEvent` to all pool listeners.""" + event = PoolClearedEvent(address, service_id, interrupt_connections) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_cleared(event) + except Exception: + _handle_exception() + + def publish_pool_closed(self, address: _Address) -> None: + """Publish a :class:`PoolClosedEvent` to all pool listeners.""" + event = PoolClosedEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.pool_closed(event) + except Exception: + _handle_exception() + + def publish_connection_created(self, address: _Address, connection_id: int) -> None: + """Publish a :class:`ConnectionCreatedEvent` to all connection + listeners. 
+ """ + event = ConnectionCreatedEvent(address, connection_id) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_created(event) + except Exception: + _handle_exception() + + def publish_connection_ready( + self, address: _Address, connection_id: int, duration: float + ) -> None: + """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" + event = ConnectionReadyEvent(address, connection_id, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_ready(event) + except Exception: + _handle_exception() + + def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: + """Publish a :class:`ConnectionClosedEvent` to all connection + listeners. + """ + event = ConnectionClosedEvent(address, connection_id, reason) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_closed(event) + except Exception: + _handle_exception() + + def publish_connection_check_out_started(self, address: _Address) -> None: + """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection + listeners. + """ + event = ConnectionCheckOutStartedEvent(address) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_check_out_started(event) + except Exception: + _handle_exception() + + def publish_connection_check_out_failed( + self, address: _Address, reason: str, duration: float + ) -> None: + """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection + listeners. + """ + event = ConnectionCheckOutFailedEvent(address, reason, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_check_out_failed(event) + except Exception: + _handle_exception() + + def publish_connection_checked_out( + self, address: _Address, connection_id: int, duration: float + ) -> None: + """Publish a :class:`ConnectionCheckedOutEvent` to all connection + listeners. + """ + event = ConnectionCheckedOutEvent(address, connection_id, duration) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_checked_out(event) + except Exception: + _handle_exception() -__doc__ = original_doc + def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: + """Publish a :class:`ConnectionCheckedInEvent` to all connection + listeners. + """ + event = ConnectionCheckedInEvent(address, connection_id) + for subscriber in self.__cmap_listeners: + try: + subscriber.connection_checked_in(event) + except Exception: + _handle_exception() diff --git a/pymongo/operations.py b/pymongo/operations.py index dbfc048a60..f43f8bdc8c 100644 --- a/pymongo/operations.py +++ b/pymongo/operations.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2015-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,614 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Re-import of synchronous Operations API for compatibility.""" +"""Operation class definitions.""" from __future__ import annotations -from pymongo.synchronous.operations import * # noqa: F403 -from pymongo.synchronous.operations import __doc__ as original_doc +import enum +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Mapping, + Optional, + Sequence, + Tuple, + Union, +) -__doc__ = original_doc +from bson.raw_bson import RawBSONDocument +from pymongo import helpers_shared +from pymongo.collation import validate_collation_or_none +from pymongo.common import validate_is_mapping, validate_list +from pymongo.helpers_shared import _gen_index_name, _index_document, _index_list +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.write_concern import validate_boolean + +if TYPE_CHECKING: + from pymongo.typings import _AgnosticBulk + +_IS_SYNC = False + +# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary +_IndexList = Union[ + Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] +] +_IndexKeyHint = Union[str, _IndexList] + + +class _Op(str, enum.Enum): + ABORT = "abortTransaction" + AGGREGATE = "aggregate" + COMMIT = "commitTransaction" + COUNT = "count" + CREATE = "create" + CREATE_INDEXES = "createIndexes" + CREATE_SEARCH_INDEXES = "createSearchIndexes" + DELETE = "delete" + DISTINCT = "distinct" + DROP = "drop" + DROP_DATABASE = "dropDatabase" + DROP_INDEXES = "dropIndexes" + DROP_SEARCH_INDEXES = "dropSearchIndexes" + END_SESSIONS = "endSessions" + FIND_AND_MODIFY = "findAndModify" + FIND = "find" + INSERT = "insert" + LIST_COLLECTIONS = "listCollections" + LIST_INDEXES = "listIndexes" + LIST_SEARCH_INDEX = "listSearchIndexes" + LIST_DATABASES = "listDatabases" + UPDATE = "update" + UPDATE_INDEX = "updateIndex" + UPDATE_SEARCH_INDEX = "updateSearchIndex" + RENAME = "rename" + GETMORE = "getMore" + KILL_CURSORS = "killCursors" + TEST = "testOperation" + + +class InsertOne(Generic[_DocumentType]): + """Represents an insert_one operation.""" + + __slots__ = ("_doc",) + + def __init__(self, document: _DocumentType) -> None: + """Create an InsertOne instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param document: The document to insert. If the document is missing an + _id field one will be added. + """ + self._doc = document + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_insert(self._doc) # type: ignore[arg-type] + + def __repr__(self) -> str: + return f"InsertOne({self._doc!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return other._doc == self._doc + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class DeleteOne: + """Represents a delete_one operation.""" + + __slots__ = ("_filter", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a DeleteOne instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to delete. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. 
+ :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + if filter is not None: + validate_is_mapping("filter", filter) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._collation = collation + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_delete( + self._filter, + 1, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __repr__(self) -> str: + return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return (other._filter, other._collation, other._hint) == ( + self._filter, + self._collation, + self._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class DeleteMany: + """Represents a delete_many operation.""" + + __slots__ = ("_filter", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a DeleteMany instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the documents to delete. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.4 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.5 + Added the `collation` option. 
+ """ + if filter is not None: + validate_is_mapping("filter", filter) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._collation = collation + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_delete( + self._filter, + 0, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __repr__(self) -> str: + return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})" + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return (other._filter, other._collation, other._hint) == ( + self._filter, + self._collation, + self._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + +class ReplaceOne(Generic[_DocumentType]): + """Represents a replace_one operation.""" + + __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + replacement: Union[_DocumentType, RawBSONDocument], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create a ReplaceOne instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to replace. + :param replacement: The new document. + :param upsert: If ``True``, perform an insert if no documents + match the filter. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the ``hint`` option. + .. versionchanged:: 3.5 + Added the ``collation`` option. 
+ """ + if filter is not None: + validate_is_mapping("filter", filter) + if upsert is not None: + validate_boolean("upsert", upsert) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + self._filter = filter + self._doc = replacement + self._upsert = upsert + self._collation = collation + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_replace( + self._filter, + self._doc, + self._upsert, + collation=validate_collation_or_none(self._collation), + hint=self._hint, + ) + + def __eq__(self, other: Any) -> bool: + if type(other) == type(self): + return ( + other._filter, + other._doc, + other._upsert, + other._collation, + other._hint, + ) == ( + self._filter, + self._doc, + self._upsert, + self._collation, + other._hint, + ) + return NotImplemented + + def __ne__(self, other: Any) -> bool: + return not self == other + + def __repr__(self) -> str: + return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( + self.__class__.__name__, + self._filter, + self._doc, + self._upsert, + self._collation, + self._hint, + ) + + +class _UpdateOp: + """Private base class for update operations.""" + + __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") + + def __init__( + self, + filter: Mapping[str, Any], + doc: Union[Mapping[str, Any], _Pipeline], + upsert: bool, + collation: Optional[_CollationIn], + array_filters: Optional[list[Mapping[str, Any]]], + hint: Optional[_IndexKeyHint], + ): + if filter is not None: + validate_is_mapping("filter", filter) + if upsert is not None: + validate_boolean("upsert", upsert) + if array_filters is not None: + validate_list("array_filters", array_filters) + if hint is not None and not isinstance(hint, str): + self._hint: Union[str, dict[str, Any], None] = helpers_shared._index_document(hint) + else: + self._hint = hint + + self._filter = filter + self._doc = doc + self._upsert = upsert + self._collation = collation + self._array_filters = array_filters + + def __eq__(self, other: object) -> bool: + if isinstance(other, type(self)): + return ( + other._filter, + other._doc, + other._upsert, + other._collation, + other._array_filters, + other._hint, + ) == ( + self._filter, + self._doc, + self._upsert, + self._collation, + self._array_filters, + self._hint, + ) + return NotImplemented + + def __repr__(self) -> str: + return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( + self.__class__.__name__, + self._filter, + self._doc, + self._upsert, + self._collation, + self._array_filters, + self._hint, + ) + + +class UpdateOne(_UpdateOp): + """Represents an update_one operation.""" + + __slots__ = () + + def __init__( + self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Represents an update_one operation. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the document to update. + :param update: The modifications to apply. + :param upsert: If ``True``, perform an insert if no documents + match the filter. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. 
+ :param array_filters: A list of filters specifying which + array elements an update should apply. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the `hint` option. + .. versionchanged:: 3.9 + Added the ability to accept a pipeline as the `update`. + .. versionchanged:: 3.6 + Added the `array_filters` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + super().__init__(filter, update, upsert, collation, array_filters, hint) + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_update( + self._filter, + self._doc, + False, + self._upsert, + collation=validate_collation_or_none(self._collation), + array_filters=self._array_filters, + hint=self._hint, + ) + + +class UpdateMany(_UpdateOp): + """Represents an update_many operation.""" + + __slots__ = () + + def __init__( + self, + filter: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + upsert: bool = False, + collation: Optional[_CollationIn] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, + hint: Optional[_IndexKeyHint] = None, + ) -> None: + """Create an UpdateMany instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.bulk_write` and :meth:`~pymongo.collection.Collection.bulk_write`. + + :param filter: A query that matches the documents to update. + :param update: The modifications to apply. + :param upsert: If ``True``, perform an insert if no documents + match the filter. + :param collation: An instance of + :class:`~pymongo.collation.Collation`. + :param array_filters: A list of filters specifying which + array elements an update should apply. + :param hint: An index to use to support the query + predicate specified either by its string name, or in the same + format as passed to + :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_index` or :meth:`~pymongo.collection.Collection.create_index` (e.g. + ``[('field', ASCENDING)]``). This option is only supported on + MongoDB 4.2 and above. + + .. versionchanged:: 3.11 + Added the `hint` option. + .. versionchanged:: 3.9 + Added the ability to accept a pipeline as the `update`. + .. versionchanged:: 3.6 + Added the `array_filters` option. + .. versionchanged:: 3.5 + Added the `collation` option. + """ + super().__init__(filter, update, upsert, collation, array_filters, hint) + + def _add_to_bulk(self, bulkobj: _AgnosticBulk) -> None: + """Add this operation to the _AsyncBulk/_Bulk instance `bulkobj`.""" + bulkobj.add_update( + self._filter, + self._doc, + True, + self._upsert, + collation=validate_collation_or_none(self._collation), + array_filters=self._array_filters, + hint=self._hint, + ) + + +class IndexModel: + """Represents an index to create.""" + + __slots__ = ("__document",) + + def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: + """Create an Index instance. + + For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_indexes` and :meth:`~pymongo.collection.Collection.create_indexes`. + + Takes either a single key or a list containing (key, direction) pairs + or keys. 
+        If no direction is given, :data:`~pymongo.ASCENDING` will
+        be assumed.
+        The key(s) must be an instance of :class:`str`, and the direction(s) must
+        be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`,
+        :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`,
+        :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`).
+
+        Valid options include, but are not limited to:
+
+        - `name`: custom name to use for this index - if none is
+          given, a name will be generated.
+        - `unique`: if ``True``, creates a uniqueness constraint on the index.
+        - `background`: if ``True``, this index should be created in the
+          background.
+        - `sparse`: if ``True``, omit from the index any documents that lack
+          the indexed field.
+        - `bucketSize`: for use with geoHaystack indexes.
+          Number of documents to group together within a certain proximity
+          to a given longitude and latitude.
+        - `min`: minimum value for keys in a :data:`~pymongo.GEO2D`
+          index.
+        - `max`: maximum value for keys in a :data:`~pymongo.GEO2D`
+          index.
+        - `expireAfterSeconds`: Used to create an expiring (TTL)
+          collection. MongoDB will automatically delete documents from
+          this collection after ``<int>`` seconds. The indexed field must
+          be a UTC datetime or the data will not expire.
+        - `partialFilterExpression`: A document that specifies a filter for
+          a partial index.
+        - `collation`: An instance of :class:`~pymongo.collation.Collation`
+          that specifies the collation to use.
+        - `wildcardProjection`: Allows users to include or exclude specific
+          field paths from a `wildcard index`_ using the { "$**" : 1} key
+          pattern. Requires MongoDB >= 4.2.
+        - `hidden`: if ``True``, this index will be hidden from the query
+          planner and will not be evaluated as part of query plan
+          selection. Requires MongoDB >= 4.4.
+
+        See the MongoDB documentation for a full list of supported options by
+        server version.
+
+        :param keys: a single key or a list containing (key, direction) pairs
+            or keys specifying the index to create.
+        :param kwargs: any additional index creation
+            options (see the above list) should be passed as keyword
+            arguments.
+
+        .. versionchanged:: 3.11
+            Added the ``hidden`` option.
+        .. versionchanged:: 3.2
+            Added the ``partialFilterExpression`` option to support partial
+            indexes.
+
+        .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/
+        """
+        keys = _index_list(keys)
+        if kwargs.get("name") is None:
+            kwargs["name"] = _gen_index_name(keys)
+        kwargs["key"] = _index_document(keys)
+        collation = validate_collation_or_none(kwargs.pop("collation", None))
+        self.__document = kwargs
+        if collation is not None:
+            self.__document["collation"] = collation
+
+    @property
+    def document(self) -> dict[str, Any]:
+        """An index document suitable for passing to the createIndexes
+        command.
+        """
+        return self.__document
+
+
+class SearchIndexModel:
+    """Represents a search index to create."""
+
+    __slots__ = ("__document",)
+
+    def __init__(
+        self,
+        definition: Mapping[str, Any],
+        name: Optional[str] = None,
+        type: Optional[str] = None,
+        **kwargs: Any,
+    ) -> None:
+        """Create a Search Index instance.
+
+        For use with :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_search_index` and :meth:`~pymongo.asynchronous.collection.AsyncCollection.create_search_indexes`.
+
+        :param definition: The definition for this index.
+        :param name: The name for this index, if present.
+        :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch".
+        :param kwargs: Keyword arguments supplying any additional options.
+ + .. note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. + .. versionadded:: 4.5 + .. versionchanged:: 4.7 + Added the type and kwargs arguments. + """ + self.__document: dict[str, Any] = {} + if name is not None: + self.__document["name"] = name + self.__document["definition"] = definition + if type is not None: + self.__document["type"] = type + self.__document.update(kwargs) + + @property + def document(self) -> Mapping[str, Any]: + """The document for this index.""" + return self.__document diff --git a/pymongo/pool_options.py b/pymongo/pool_options.py new file mode 100644 index 0000000000..4170bb5cb6 --- /dev/null +++ b/pymongo/pool_options.py @@ -0,0 +1,484 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + +"""AsyncConnection pool options for AsyncMongoClient/MongoClient.""" +from __future__ import annotations + +import copy +import os +import platform +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Any, MutableMapping, Optional + +import bson +from pymongo import __version__ +from pymongo.common import ( + MAX_CONNECTING, + MAX_IDLE_TIME_SEC, + MAX_POOL_SIZE, + MIN_POOL_SIZE, + WAIT_QUEUE_TIMEOUT, +) + +if TYPE_CHECKING: + from pymongo.auth_shared import MongoCredential + from pymongo.compression_support import CompressionSettings + from pymongo.driver_info import DriverInfo + from pymongo.monitoring import _EventListeners + from pymongo.pyopenssl_context import SSLContext + from pymongo.server_api import ServerApi + + +_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}} + +if sys.platform.startswith("linux"): + # platform.linux_distribution was deprecated in Python 3.5 + # and removed in Python 3.8. Starting in Python 3.5 it + # raises DeprecationWarning + # DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5 + _name = platform.system() + _METADATA["os"] = { + "type": _name, + "name": _name, + "architecture": platform.machine(), + # Kernel version (e.g. 4.4.0-17-generic). + "version": platform.release(), + } +elif sys.platform == "darwin": + _METADATA["os"] = { + "type": platform.system(), + "name": platform.system(), + "architecture": platform.machine(), + # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin + # kernel version. + "version": platform.mac_ver()[0], + } +elif sys.platform == "win32": + _METADATA["os"] = { + "type": platform.system(), + # "Windows XP", "Windows 7", "Windows 10", etc. + "name": " ".join((platform.system(), platform.release())), + "architecture": platform.machine(), + # Windows patch level (e.g. 5.1.2600-SP3) + "version": "-".join(platform.win32_ver()[1:3]), + } +elif sys.platform.startswith("java"): + _name, _ver, _arch = platform.java_ver()[-1] + _METADATA["os"] = { + # Linux, Windows 7, Mac OS X, etc. + "type": _name, + "name": _name, + # x86, x86_64, AMD64, etc. + "architecture": _arch, + # Linux kernel version, OSX version, etc. 
+ "version": _ver, + } +else: + # Get potential alias (e.g. SunOS 5.11 becomes Solaris 2.11) + _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) + _METADATA["os"] = { + "type": platform.system(), + "name": " ".join([part for part in _aliased[:2] if part]), + "architecture": platform.machine(), + "version": _aliased[2], + } + +if platform.python_implementation().startswith("PyPy"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.pypy_version_info)), # type: ignore + "(Python %s)" % ".".join(map(str, sys.version_info)), + ) + ) +elif sys.platform.startswith("java"): + _METADATA["platform"] = " ".join( + ( + platform.python_implementation(), + ".".join(map(str, sys.version_info)), + "(%s)" % " ".join((platform.system(), platform.release())), + ) + ) +else: + _METADATA["platform"] = " ".join( + (platform.python_implementation(), ".".join(map(str, sys.version_info))) + ) + +DOCKER_ENV_PATH = "/.dockerenv" +ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST" + +RUNTIME_NAME_DOCKER = "docker" +ORCHESTRATOR_NAME_K8S = "kubernetes" + + +def get_container_env_info() -> dict[str, str]: + """Returns the runtime and orchestrator of a container. + If neither value is present, the metadata client.env.container field will be omitted.""" + container = {} + + if Path(DOCKER_ENV_PATH).exists(): + container["runtime"] = RUNTIME_NAME_DOCKER + if os.getenv(ENV_VAR_K8S): + container["orchestrator"] = ORCHESTRATOR_NAME_K8S + + return container + + +def _is_lambda() -> bool: + if os.getenv("AWS_LAMBDA_RUNTIME_API"): + return True + env = os.getenv("AWS_EXECUTION_ENV") + if env: + return env.startswith("AWS_Lambda_") + return False + + +def _is_azure_func() -> bool: + return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME")) + + +def _is_gcp_func() -> bool: + return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME")) + + +def _is_vercel() -> bool: + return bool(os.getenv("VERCEL")) + + +def _is_faas() -> bool: + return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel() + + +def _getenv_int(key: str) -> Optional[int]: + """Like os.getenv but returns an int, or None if the value is missing/malformed.""" + val = os.getenv(key) + if not val: + return None + try: + return int(val) + except ValueError: + return None + + +def _metadata_env() -> dict[str, Any]: + env: dict[str, Any] = {} + container = get_container_env_info() + if container: + env["container"] = container + # Skip if multiple (or no) envs are matched. 
+ if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: + return env + if _is_lambda(): + env["name"] = "aws.lambda" + region = os.getenv("AWS_REGION") + if region: + env["region"] = region + memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") + if memory_mb is not None: + env["memory_mb"] = memory_mb + elif _is_azure_func(): + env["name"] = "azure.func" + elif _is_gcp_func(): + env["name"] = "gcp.func" + region = os.getenv("FUNCTION_REGION") + if region: + env["region"] = region + memory_mb = _getenv_int("FUNCTION_MEMORY_MB") + if memory_mb is not None: + env["memory_mb"] = memory_mb + timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC") + if timeout_sec is not None: + env["timeout_sec"] = timeout_sec + elif _is_vercel(): + env["name"] = "vercel" + region = os.getenv("VERCEL_REGION") + if region: + env["region"] = region + return env + + +_MAX_METADATA_SIZE = 512 + + +# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations +def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: + """Perform metadata truncation.""" + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 1. Omit fields from env except env.name. + env_name = metadata.get("env", {}).get("name") + if env_name: + metadata["env"] = {"name": env_name} + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 2. Omit fields from os except os.type. + os_type = metadata.get("os", {}).get("type") + if os_type: + metadata["os"] = {"type": os_type} + if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: + return + # 3. Omit the env document entirely. + metadata.pop("env", None) + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # 4. Truncate platform. + overflow = encoded_size - _MAX_METADATA_SIZE + plat = metadata.get("platform", "") + if plat: + plat = plat[:-overflow] + if plat: + metadata["platform"] = plat + else: + metadata.pop("platform", None) + + +# If the first getaddrinfo call of this interpreter's life is on a thread, +# while the main thread holds the import lock, getaddrinfo deadlocks trying +# to import the IDNA codec. Import it here, where presumably we're on the +# main thread, to avoid the deadlock. See PYTHON-607. +"foo".encode("idna") + + +class PoolOptions: + """Read only connection pool options for an AsyncMongoClient/MongoClient. + + Should not be instantiated directly by application developers. 
Access + a client's pool options via + :attr:`~pymongo.client_options.ClientOptions.pool_options` instead:: + + pool_opts = client.options.pool_options + pool_opts.max_pool_size + pool_opts.min_pool_size + + """ + + __slots__ = ( + "__max_pool_size", + "__min_pool_size", + "__max_idle_time_seconds", + "__connect_timeout", + "__socket_timeout", + "__wait_queue_timeout", + "__ssl_context", + "__tls_allow_invalid_hostnames", + "__event_listeners", + "__appname", + "__driver", + "__metadata", + "__compression_settings", + "__max_connecting", + "__pause_enabled", + "__server_api", + "__load_balanced", + "__credentials", + ) + + def __init__( + self, + max_pool_size: int = MAX_POOL_SIZE, + min_pool_size: int = MIN_POOL_SIZE, + max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, + connect_timeout: Optional[float] = None, + socket_timeout: Optional[float] = None, + wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, + ssl_context: Optional[SSLContext] = None, + tls_allow_invalid_hostnames: bool = False, + event_listeners: Optional[_EventListeners] = None, + appname: Optional[str] = None, + driver: Optional[DriverInfo] = None, + compression_settings: Optional[CompressionSettings] = None, + max_connecting: int = MAX_CONNECTING, + pause_enabled: bool = True, + server_api: Optional[ServerApi] = None, + load_balanced: Optional[bool] = None, + credentials: Optional[MongoCredential] = None, + ): + self.__max_pool_size = max_pool_size + self.__min_pool_size = min_pool_size + self.__max_idle_time_seconds = max_idle_time_seconds + self.__connect_timeout = connect_timeout + self.__socket_timeout = socket_timeout + self.__wait_queue_timeout = wait_queue_timeout + self.__ssl_context = ssl_context + self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames + self.__event_listeners = event_listeners + self.__appname = appname + self.__driver = driver + self.__compression_settings = compression_settings + self.__max_connecting = max_connecting + self.__pause_enabled = pause_enabled + self.__server_api = server_api + self.__load_balanced = load_balanced + self.__credentials = credentials + self.__metadata = copy.deepcopy(_METADATA) + if appname: + self.__metadata["application"] = {"name": appname} + + # Combine the "driver" AsyncMongoClient option with PyMongo's info, like: + # { + # 'driver': { + # 'name': 'PyMongo|MyDriver', + # 'version': '4.2.0|1.2.3', + # }, + # 'platform': 'CPython 3.8.0|MyPlatform' + # } + if driver: + if driver.name: + self.__metadata["driver"]["name"] = "{}|{}".format( + _METADATA["driver"]["name"], + driver.name, + ) + if driver.version: + self.__metadata["driver"]["version"] = "{}|{}".format( + _METADATA["driver"]["version"], + driver.version, + ) + if driver.platform: + self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) + + env = _metadata_env() + if env: + self.__metadata["env"] = env + + _truncate_metadata(self.__metadata) + + @property + def _credentials(self) -> Optional[MongoCredential]: + """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" + return self.__credentials + + @property + def non_default_options(self) -> dict[str, Any]: + """The non-default options this pool was created with. + + Added for CMAP's :class:`PoolCreatedEvent`. 
+ """ + opts = {} + if self.__max_pool_size != MAX_POOL_SIZE: + opts["maxPoolSize"] = self.__max_pool_size + if self.__min_pool_size != MIN_POOL_SIZE: + opts["minPoolSize"] = self.__min_pool_size + if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: + assert self.__max_idle_time_seconds is not None + opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 + if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: + assert self.__wait_queue_timeout is not None + opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 + if self.__max_connecting != MAX_CONNECTING: + opts["maxConnecting"] = self.__max_connecting + return opts + + @property + def max_pool_size(self) -> float: + """The maximum allowable number of concurrent connections to each + connected server. Requests to a server will block if there are + `maxPoolSize` outstanding connections to the requested server. + Defaults to 100. Cannot be 0. + + When a server's pool has reached `max_pool_size`, operations for that + server block waiting for a socket to be returned to the pool. If + ``waitQueueTimeoutMS`` is set, a blocked operation will raise + :exc:`~pymongo.errors.ConnectionFailure` after a timeout. + By default ``waitQueueTimeoutMS`` is not set. + """ + return self.__max_pool_size + + @property + def min_pool_size(self) -> int: + """The minimum required number of concurrent connections that the pool + will maintain to each connected server. Default is 0. + """ + return self.__min_pool_size + + @property + def max_connecting(self) -> int: + """The maximum number of concurrent connection creation attempts per + pool. Defaults to 2. + """ + return self.__max_connecting + + @property + def pause_enabled(self) -> bool: + return self.__pause_enabled + + @property + def max_idle_time_seconds(self) -> Optional[int]: + """The maximum number of seconds that a connection can remain + idle in the pool before being removed and replaced. Defaults to + `None` (no limit). + """ + return self.__max_idle_time_seconds + + @property + def connect_timeout(self) -> Optional[float]: + """How long a connection can take to be opened before timing out.""" + return self.__connect_timeout + + @property + def socket_timeout(self) -> Optional[float]: + """How long a send or receive on a socket can take before timing out.""" + return self.__socket_timeout + + @property + def wait_queue_timeout(self) -> Optional[int]: + """How long a thread will wait for a socket from the pool if the pool + has no free sockets. 
+ """ + return self.__wait_queue_timeout + + @property + def _ssl_context(self) -> Optional[SSLContext]: + """An SSLContext instance or None.""" + return self.__ssl_context + + @property + def tls_allow_invalid_hostnames(self) -> bool: + """If True skip ssl.match_hostname.""" + return self.__tls_allow_invalid_hostnames + + @property + def _event_listeners(self) -> Optional[_EventListeners]: + """An instance of pymongo.monitoring._EventListeners.""" + return self.__event_listeners + + @property + def appname(self) -> Optional[str]: + """The application name, for sending with hello in server handshake.""" + return self.__appname + + @property + def driver(self) -> Optional[DriverInfo]: + """Driver name and version, for sending with hello in handshake.""" + return self.__driver + + @property + def _compression_settings(self) -> Optional[CompressionSettings]: + return self.__compression_settings + + @property + def metadata(self) -> dict[str, Any]: + """A dict of metadata about the application, driver, os, and platform.""" + return self.__metadata.copy() + + @property + def server_api(self) -> Optional[ServerApi]: + """A pymongo.server_api.ServerApi or None.""" + return self.__server_api + + @property + def load_balanced(self) -> Optional[bool]: + """True if this Pool is configured in load balanced mode.""" + return self.__load_balanced diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py index de15cbfcaf..10deba7bbe 100644 --- a/pymongo/read_preferences.py +++ b/pymongo/read_preferences.py @@ -1,6 +1,6 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2012-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License", # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -12,10 +12,613 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-"""Re-import of synchronous ReadPreferences API for compatibility.""" +"""Utilities for choosing which member of a replica set to read from.""" + from __future__ import annotations -from pymongo.synchronous.read_preferences import * # noqa: F403 -from pymongo.synchronous.read_preferences import __doc__ as original_doc +from collections import abc +from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence + +from pymongo import max_staleness_selectors +from pymongo.errors import ConfigurationError +from pymongo.server_selectors import ( + member_with_tags_server_selector, + secondary_with_tags_server_selector, +) + +if TYPE_CHECKING: + from pymongo.server_selectors import Selection + from pymongo.topology_description import TopologyDescription + +_IS_SYNC = False + +_PRIMARY = 0 +_PRIMARY_PREFERRED = 1 +_SECONDARY = 2 +_SECONDARY_PREFERRED = 3 +_NEAREST = 4 + + +_MONGOS_MODES = ( + "primary", + "primaryPreferred", + "secondary", + "secondaryPreferred", + "nearest", +) + +_Hedge = Mapping[str, Any] +_TagSets = Sequence[Mapping[str, Any]] + + +def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]: + """Validate tag sets for a MongoClient.""" + if tag_sets is None: + return tag_sets + + if not isinstance(tag_sets, (list, tuple)): + raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence") + if len(tag_sets) == 0: + raise ValueError( + f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags" + ) + + for tags in tag_sets: + if not isinstance(tags, abc.Mapping): + raise TypeError( + f"Tag set {tags!r} invalid, must be an instance of dict, " + "bson.son.SON or other type that inherits from " + "collection.Mapping" + ) + + return list(tag_sets) + + +def _invalid_max_staleness_msg(max_staleness: Any) -> str: + return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness + + +# Some duplication with common.py to avoid import cycle. 
+def _validate_max_staleness(max_staleness: Any) -> int:
+    """Validate max_staleness."""
+    if max_staleness == -1:
+        return -1
+
+    if not isinstance(max_staleness, int):
+        raise TypeError(_invalid_max_staleness_msg(max_staleness))
+
+    if max_staleness <= 0:
+        raise ValueError(_invalid_max_staleness_msg(max_staleness))
+
+    return max_staleness
+
+
+def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]:
+    """Validate hedge."""
+    if hedge is None:
+        return None
+
+    if not isinstance(hedge, dict):
+        raise TypeError(f"hedge must be a dictionary, not {hedge!r}")
+
+    return hedge
+
+
+class _ServerMode:
+    """Base class for all read preferences."""
+
+    __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge")
+
+    def __init__(
+        self,
+        mode: int,
+        tag_sets: Optional[_TagSets] = None,
+        max_staleness: int = -1,
+        hedge: Optional[_Hedge] = None,
+    ) -> None:
+        self.__mongos_mode = _MONGOS_MODES[mode]
+        self.__mode = mode
+        self.__tag_sets = _validate_tag_sets(tag_sets)
+        self.__max_staleness = _validate_max_staleness(max_staleness)
+        self.__hedge = _validate_hedge(hedge)
+
+    @property
+    def name(self) -> str:
+        """The name of this read preference."""
+        return self.__class__.__name__
+
+    @property
+    def mongos_mode(self) -> str:
+        """The mongos mode of this read preference."""
+        return self.__mongos_mode
+
+    @property
+    def document(self) -> dict[str, Any]:
+        """Read preference as a document."""
+        doc: dict[str, Any] = {"mode": self.__mongos_mode}
+        if self.__tag_sets not in (None, [{}]):
+            doc["tags"] = self.__tag_sets
+        if self.__max_staleness != -1:
+            doc["maxStalenessSeconds"] = self.__max_staleness
+        if self.__hedge not in (None, {}):
+            doc["hedge"] = self.__hedge
+        return doc
+
+    @property
+    def mode(self) -> int:
+        """The mode of this read preference instance."""
+        return self.__mode
+
+    @property
+    def tag_sets(self) -> _TagSets:
+        """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to
+        read only from members whose ``dc`` tag has the value ``"ny"``.
+        To specify a priority-order for tag sets, provide a list of
+        tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag
+        set, ``{}``, means "read from any member that matches the mode,
+        ignoring tags." MongoClient tries each set of tags in turn
+        until it finds a set of tags with at least one matching member.
+        For example, to only send a query to an analytic node::
+
+           Nearest(tag_sets=[{"node":"analytics"}])
+
+        Or using :class:`SecondaryPreferred`::
+
+           SecondaryPreferred(tag_sets=[{"node":"analytics"}])
+
+        .. seealso:: `Data-Center Awareness
+           <https://www.mongodb.com/docs/manual/data-center-awareness/>`_
+        """
+        return list(self.__tag_sets) if self.__tag_sets else [{}]
+
+    @property
+    def max_staleness(self) -> int:
+        """The maximum estimated length of time (in seconds) a replica set
+        secondary can fall behind the primary in replication before it will
+        no longer be selected for operations, or -1 for no maximum.
+        """
+        return self.__max_staleness
+
+    @property
+    def hedge(self) -> Optional[_Hedge]:
+        """The read preference ``hedge`` parameter.
+
+        A dictionary that configures how the server will perform hedged reads.
+        It consists of the following keys:
+
+        - ``enabled``: Enables or disables hedged reads in sharded clusters.
+
+        Hedged reads are automatically enabled in MongoDB 4.4+ when using a
+        ``nearest`` read preference. To explicitly enable hedged reads, set
+        the ``enabled`` key to ``True``::
+
+            >>> Nearest(hedge={'enabled': True})
+
+        To explicitly disable hedged reads, set the ``enabled`` key to
+        ``False``::
+
+            >>> Nearest(hedge={'enabled': False})
+
+        .. versionadded:: 3.11
+        """
+        return self.__hedge
+
+    @property
+    def min_wire_version(self) -> int:
+        """The wire protocol version the server must support.
+
+        Some read preferences impose version requirements on all servers (e.g.
+        maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5).
+
+        All servers' maxWireVersion must be at least this read preference's
+        `min_wire_version`, or the driver raises
+        :exc:`~pymongo.errors.ConfigurationError`.
+        """
+        return 0 if self.__max_staleness == -1 else 5
+
+    def __repr__(self) -> str:
+        return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format(
+            self.name,
+            self.__tag_sets,
+            self.__max_staleness,
+            self.__hedge,
+        )
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, _ServerMode):
+            return (
+                self.mode == other.mode
+                and self.tag_sets == other.tag_sets
+                and self.max_staleness == other.max_staleness
+                and self.hedge == other.hedge
+            )
+        return NotImplemented
+
+    def __ne__(self, other: Any) -> bool:
+        return not self == other
+
+    def __getstate__(self) -> dict[str, Any]:
+        """Return value of object for pickling.
+
+        Needed explicitly because __slots__ is defined.
+        """
+        return {
+            "mode": self.__mode,
+            "tag_sets": self.__tag_sets,
+            "max_staleness": self.__max_staleness,
+            "hedge": self.__hedge,
+        }
+
+    def __setstate__(self, value: Mapping[str, Any]) -> None:
+        """Restore from pickling."""
+        self.__mode = value["mode"]
+        self.__mongos_mode = _MONGOS_MODES[self.__mode]
+        self.__tag_sets = _validate_tag_sets(value["tag_sets"])
+        self.__max_staleness = _validate_max_staleness(value["max_staleness"])
+        self.__hedge = _validate_hedge(value["hedge"])
+
+    def __call__(self, selection: Selection) -> Selection:
+        return selection
+
+
+class Primary(_ServerMode):
+    """Primary read preference.
+
+    * When directly connected to one mongod queries are allowed if the server
+      is standalone or a replica set primary.
+    * When connected to a mongos queries are sent to the primary of a shard.
+    * When connected to a replica set queries are sent to the primary of
+      the replica set.
+    """
+
+    __slots__ = ()
+
+    def __init__(self) -> None:
+        super().__init__(_PRIMARY)
+
+    def __call__(self, selection: Selection) -> Selection:
+        """Apply this read preference to a Selection."""
+        return selection.primary_selection
+
+    def __repr__(self) -> str:
+        return "Primary()"
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, _ServerMode):
+            return other.mode == _PRIMARY
+        return NotImplemented
+
+
+class PrimaryPreferred(_ServerMode):
+    """PrimaryPreferred read preference.
+
+    * When directly connected to one mongod queries are allowed to standalone
+      servers, to a replica set primary, or to replica set secondaries.
+    * When connected to a mongos queries are sent to the primary of a shard if
+      available, otherwise a shard secondary.
+    * When connected to a replica set queries are sent to the primary if
+      available, otherwise a secondary.
+
+    .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first
+       created reads will be routed to an available secondary until the
+       primary of the replica set is discovered.
+
+    :param tag_sets: The :attr:`~tag_sets` to use if the primary is not
+        available.
+ :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` to use if the primary is not available. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + if selection.primary: + return selection.primary_selection + else: + return secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class Secondary(_ServerMode): + """Secondary read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among shard + secondaries. An error is raised if no secondaries are available. + * When connected to a replica set queries are distributed among + secondaries. An error is raised if no secondaries are available. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + return secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class SecondaryPreferred(_ServerMode): + """SecondaryPreferred read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among shard + secondaries, or the shard primary if no secondary is available. + * When connected to a replica set queries are distributed among + secondaries, or the primary if no secondary is available. + + .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first + created reads will be routed to the primary of the replica set until + an available secondary is discovered. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. 
+ """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + secondaries = secondary_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + if secondaries: + return secondaries + else: + return selection.primary_selection + + +class Nearest(_ServerMode): + """Nearest read preference. + + * When directly connected to one mongod queries are allowed to standalone + servers, to a replica set primary, or to replica set secondaries. + * When connected to a mongos queries are distributed among all members of + a shard. + * When connected to a replica set queries are distributed among all + members. + + :param tag_sets: The :attr:`~tag_sets` for this read preference. + :param max_staleness: (integer, in seconds) The maximum estimated + length of time a replica set secondary can fall behind the primary in + replication before it will no longer be selected for operations. + Default -1, meaning no maximum. If it is set, it must be at least + 90 seconds. + :param hedge: The :attr:`~hedge` for this read preference. + + .. versionchanged:: 3.11 + Added ``hedge`` parameter. + """ + + __slots__ = () + + def __init__( + self, + tag_sets: Optional[_TagSets] = None, + max_staleness: int = -1, + hedge: Optional[_Hedge] = None, + ) -> None: + super().__init__(_NEAREST, tag_sets, max_staleness, hedge) + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to Selection.""" + return member_with_tags_server_selector( + self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) + ) + + +class _AggWritePref: + """Agg $out/$merge write preference. + + * If there are readable servers and there is any pre-5.0 server, use + primary read preference. + * Otherwise use `pref` read preference. + + :param pref: The read preference to use on MongoDB 5.0+. + """ + + __slots__ = ("pref", "effective_pref") + + def __init__(self, pref: _ServerMode): + self.pref = pref + self.effective_pref: _ServerMode = ReadPreference.PRIMARY + + def selection_hook(self, topology_description: TopologyDescription) -> None: + common_wv = topology_description.common_wire_version + if ( + topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED) + and common_wv + and common_wv < 13 + ): + self.effective_pref = ReadPreference.PRIMARY + else: + self.effective_pref = self.pref + + def __call__(self, selection: Selection) -> Selection: + """Apply this read preference to a Selection.""" + return self.effective_pref(selection) + + def __repr__(self) -> str: + return f"_AggWritePref(pref={self.pref!r})" + + # Proxy other calls to the effective_pref so that _AggWritePref can be + # used in place of an actual read preference. 
+ def __getattr__(self, name: str) -> Any: + return getattr(self.effective_pref, name) + + +_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) + + +def make_read_preference( + mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1 +) -> _ServerMode: + if mode == _PRIMARY: + if tag_sets not in (None, [{}]): + raise ConfigurationError("Read preference primary cannot be combined with tags") + if max_staleness != -1: + raise ConfigurationError( + "Read preference primary cannot be combined with maxStalenessSeconds" + ) + return Primary() + return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore + + +_MODES = ( + "PRIMARY", + "PRIMARY_PREFERRED", + "SECONDARY", + "SECONDARY_PREFERRED", + "NEAREST", +) + + +class ReadPreference: + """An enum that defines some commonly used read preference modes. + + Apps can also create a custom read preference, for example:: + + Nearest(tag_sets=[{"node":"analytics"}]) + + See :doc:`/examples/high_availability` for code examples. + + A read preference is used in three cases: + + :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: + + - ``PRIMARY``: Queries are allowed if the server is standalone or a replica + set primary. + - All other modes allow queries to standalone servers, to a replica set + primary, or to replica set secondaries. + + :class:`~pymongo.mongo_client.MongoClient` initialized with the + ``replicaSet`` option: + + - ``PRIMARY``: Read from the primary. This is the default, and provides the + strongest consistency. If no primary is available, raise + :class:`~pymongo.errors.AutoReconnect`. + + - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is + none, read from a secondary. + + - ``SECONDARY``: Read from a secondary. If no secondary is available, + raise :class:`~pymongo.errors.AutoReconnect`. + + - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise + from the primary. + + - ``NEAREST``: Read from any member. + + :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a + sharded cluster of replica sets: + + - ``PRIMARY``: Read from the primary of the shard, or raise + :class:`~pymongo.errors.OperationFailure` if there is none. + This is the default. + + - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is + none, read from a secondary of the shard. + + - ``SECONDARY``: Read from a secondary of the shard, or raise + :class:`~pymongo.errors.OperationFailure` if there is none. + + - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, + otherwise from the shard primary. + + - ``NEAREST``: Read from any shard member. + """ + + PRIMARY = Primary() + PRIMARY_PREFERRED = PrimaryPreferred() + SECONDARY = Secondary() + SECONDARY_PREFERRED = SecondaryPreferred() + NEAREST = Nearest() + + +def read_pref_mode_from_name(name: str) -> int: + """Get the read preference mode from mongos/uri name.""" + return _MONGOS_MODES.index(name) + + +class MovingAverage: + """Tracks an exponentially-weighted moving average.""" + + average: Optional[float] + + def __init__(self) -> None: + self.average = None + + def add_sample(self, sample: float) -> None: + if sample < 0: + # Likely system time change while waiting for hello response + # and not using time.monotonic. Ignore it, the next one will + # probably be valid. + return + if self.average is None: + self.average = sample + else: + # The Server Selection Spec requires an exponentially weighted + # average with alpha = 0.2. 
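+            # Equivalently: new = (1 - alpha) * old + alpha * sample with
+            # alpha = 0.2, so each sample's weight decays by a factor of 0.8
+            # with every subsequent sample.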
+ self.average = 0.8 * self.average + 0.2 * sample + + def get(self) -> Optional[float]: + """Get the calculated average, or None if no samples yet.""" + return self.average -__doc__ = original_doc + def reset(self) -> None: + self.average = None diff --git a/pymongo/asynchronous/response.py b/pymongo/response.py similarity index 93% rename from pymongo/asynchronous/response.py rename to pymongo/response.py index f19328f6ee..99a154efae 100644 --- a/pymongo/asynchronous/response.py +++ b/pymongo/response.py @@ -21,8 +21,7 @@ from datetime import timedelta from pymongo.asynchronous.message import _OpMsg, _OpReply - from pymongo.asynchronous.pool import Connection - from pymongo.asynchronous.typings import _Address, _DocumentOut + from pymongo.typings import _Address, _AgnosticConnection, _DocumentOut _IS_SYNC = False @@ -92,7 +91,7 @@ def __init__( self, data: Union[_OpMsg, _OpReply], address: _Address, - conn: Connection, + conn: _AgnosticConnection, request_id: int, duration: Optional[timedelta], from_command: bool, @@ -103,7 +102,7 @@ def __init__( :param data: A network response message. :param address: (host, port) of the source server. - :param conn: The Connection used for the initial query. + :param conn: The AsyncConnection/Connection used for the initial query. :param request_id: The request id of this operation. :param duration: The duration of the operation. :param from_command: If the response is the result of a db command. @@ -116,8 +115,8 @@ def __init__( self._more_to_come = more_to_come @property - def conn(self) -> Connection: - """The Connection used for the initial query. + def conn(self) -> _AgnosticConnection: + """The AsyncConnection/Connection used for the initial query. The server will send batches on this socket, without waiting for getMores from the client, until the result set is exhausted or there diff --git a/pymongo/server_description.py b/pymongo/server_description.py index 4ee6b340d9..5a2e62837d 100644 --- a/pymongo/server_description.py +++ b/pymongo/server_description.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2014-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,10 +12,290 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Re-import of synchronous ServerDescription API for compatibility.""" +"""Represent one server the driver is connected to.""" from __future__ import annotations -from pymongo.synchronous.server_description import * # noqa: F403 -from pymongo.synchronous.server_description import __doc__ as original_doc +import time +import warnings +from typing import Any, Mapping, Optional -__doc__ = original_doc +from bson import EPOCH_NAIVE +from bson.objectid import ObjectId +from pymongo.hello import Hello +from pymongo.server_type import SERVER_TYPE +from pymongo.typings import ClusterTime, _Address + +_IS_SYNC = False + + +class ServerDescription: + """Immutable representation of one server. 
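+
+    Illustrative construction (editorial sketch, not part of the original
+    change)::
+
+        sd = ServerDescription(("localhost", 27017))
+        sd.server_type_name  # "Unknown" until built from a hello response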
+
+    :param address: A (host, port) pair
+    :param hello: Optional Hello instance
+    :param round_trip_time: Optional float
+    :param error: Optional, the last error attempting to connect to the server
+    :param min_round_trip_time: float, the min latency from the most recent samples
+    """
+
+    __slots__ = (
+        "_address",
+        "_server_type",
+        "_all_hosts",
+        "_tags",
+        "_replica_set_name",
+        "_primary",
+        "_max_bson_size",
+        "_max_message_size",
+        "_max_write_batch_size",
+        "_min_wire_version",
+        "_max_wire_version",
+        "_round_trip_time",
+        "_min_round_trip_time",
+        "_me",
+        "_is_writable",
+        "_is_readable",
+        "_ls_timeout_minutes",
+        "_error",
+        "_set_version",
+        "_election_id",
+        "_cluster_time",
+        "_last_write_date",
+        "_last_update_time",
+        "_topology_version",
+    )
+
+    def __init__(
+        self,
+        address: _Address,
+        hello: Optional[Hello] = None,
+        round_trip_time: Optional[float] = None,
+        error: Optional[Exception] = None,
+        min_round_trip_time: float = 0.0,
+    ) -> None:
+        self._address = address
+        if not hello:
+            hello = Hello({})
+
+        self._server_type = hello.server_type
+        self._all_hosts = hello.all_hosts
+        self._tags = hello.tags
+        self._replica_set_name = hello.replica_set_name
+        self._primary = hello.primary
+        self._max_bson_size = hello.max_bson_size
+        self._max_message_size = hello.max_message_size
+        self._max_write_batch_size = hello.max_write_batch_size
+        self._min_wire_version = hello.min_wire_version
+        self._max_wire_version = hello.max_wire_version
+        self._set_version = hello.set_version
+        self._election_id = hello.election_id
+        self._cluster_time = hello.cluster_time
+        self._is_writable = hello.is_writable
+        self._is_readable = hello.is_readable
+        self._ls_timeout_minutes = hello.logical_session_timeout_minutes
+        self._round_trip_time = round_trip_time
+        self._min_round_trip_time = min_round_trip_time
+        self._me = hello.me
+        self._last_update_time = time.monotonic()
+        self._error = error
+        self._topology_version = hello.topology_version
+        if error:
+            details = getattr(error, "details", None)
+            if isinstance(details, dict):
+                self._topology_version = details.get("topologyVersion")
+
+        self._last_write_date: Optional[float]
+        if hello.last_write_date:
+            # Convert from datetime to seconds.
+            delta = hello.last_write_date - EPOCH_NAIVE
+            self._last_write_date = delta.total_seconds()
+        else:
+            self._last_write_date = None
+
+    @property
+    def address(self) -> _Address:
+        """The address (host, port) of this server."""
+        return self._address
+
+    @property
+    def server_type(self) -> int:
+        """The type of this server."""
+        return self._server_type
+
+    @property
+    def server_type_name(self) -> str:
+        """The server type as a human-readable string.
+
+        ..
versionadded:: 3.4 + """ + return SERVER_TYPE._fields[self._server_type] + + @property + def all_hosts(self) -> set[tuple[str, int]]: + """List of hosts, passives, and arbiters known to this server.""" + return self._all_hosts + + @property + def tags(self) -> Mapping[str, Any]: + return self._tags + + @property + def replica_set_name(self) -> Optional[str]: + """Replica set name or None.""" + return self._replica_set_name + + @property + def primary(self) -> Optional[tuple[str, int]]: + """This server's opinion about who the primary is, or None.""" + return self._primary + + @property + def max_bson_size(self) -> int: + return self._max_bson_size + + @property + def max_message_size(self) -> int: + return self._max_message_size + + @property + def max_write_batch_size(self) -> int: + return self._max_write_batch_size + + @property + def min_wire_version(self) -> int: + return self._min_wire_version + + @property + def max_wire_version(self) -> int: + return self._max_wire_version + + @property + def set_version(self) -> Optional[int]: + return self._set_version + + @property + def election_id(self) -> Optional[ObjectId]: + return self._election_id + + @property + def cluster_time(self) -> Optional[ClusterTime]: + return self._cluster_time + + @property + def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: + warnings.warn( + "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", + DeprecationWarning, + stacklevel=2, + ) + return self._set_version, self._election_id + + @property + def me(self) -> Optional[tuple[str, int]]: + return self._me + + @property + def logical_session_timeout_minutes(self) -> Optional[int]: + return self._ls_timeout_minutes + + @property + def last_write_date(self) -> Optional[float]: + return self._last_write_date + + @property + def last_update_time(self) -> float: + return self._last_update_time + + @property + def round_trip_time(self) -> Optional[float]: + """The current average latency or None.""" + # This override is for unittesting only! 
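+        # Editorial note: a test might pre-seed the class-level map, e.g.
+        #     ServerDescription._host_to_round_trip_time[("a", 27017)] = 0.001
+        # so that this property returns a deterministic value.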
+        if self._address in self._host_to_round_trip_time:
+            return self._host_to_round_trip_time[self._address]
+
+        return self._round_trip_time
+
+    @property
+    def min_round_trip_time(self) -> float:
+        """The min latency from the most recent samples."""
+        return self._min_round_trip_time
+
+    @property
+    def error(self) -> Optional[Exception]:
+        """The last error attempting to connect to the server, or None."""
+        return self._error
+
+    @property
+    def is_writable(self) -> bool:
+        return self._is_writable
+
+    @property
+    def is_readable(self) -> bool:
+        return self._is_readable
+
+    @property
+    def mongos(self) -> bool:
+        return self._server_type == SERVER_TYPE.Mongos
+
+    @property
+    def is_server_type_known(self) -> bool:
+        return self.server_type != SERVER_TYPE.Unknown
+
+    @property
+    def retryable_writes_supported(self) -> bool:
+        """Checks if this server supports retryable writes."""
+        return (
+            self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary)
+        ) or self._server_type == SERVER_TYPE.LoadBalancer
+
+    @property
+    def retryable_reads_supported(self) -> bool:
+        """Checks if this server supports retryable reads."""
+        return self._max_wire_version >= 6
+
+    @property
+    def topology_version(self) -> Optional[Mapping[str, Any]]:
+        return self._topology_version
+
+    def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription:
+        unknown = ServerDescription(self.address, error=error)
+        unknown._topology_version = self.topology_version
+        return unknown
+
+    def __eq__(self, other: Any) -> bool:
+        if isinstance(other, ServerDescription):
+            return (
+                (self._address == other.address)
+                and (self._server_type == other.server_type)
+                and (self._min_wire_version == other.min_wire_version)
+                and (self._max_wire_version == other.max_wire_version)
+                and (self._me == other.me)
+                and (self._all_hosts == other.all_hosts)
+                and (self._tags == other.tags)
+                and (self._replica_set_name == other.replica_set_name)
+                and (self._set_version == other.set_version)
+                and (self._election_id == other.election_id)
+                and (self._primary == other.primary)
+                and (self._ls_timeout_minutes == other.logical_session_timeout_minutes)
+                and (self._error == other.error)
+            )
+
+        return NotImplemented
+
+    def __ne__(self, other: Any) -> bool:
+        return not self == other
+
+    def __repr__(self) -> str:
+        errmsg = ""
+        if self.error:
+            errmsg = f", error={self.error!r}"
+        return "<{} {} server_type: {}, rtt: {}{}>".format(
+            self.__class__.__name__,
+            self.address,
+            self.server_type_name,
+            self.round_trip_time,
+            errmsg,
+        )
+
+    # For unittesting only. Use under no circumstances!
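+    # Editorial note: this map is class-level state shared by every
+    # ServerDescription instance, keyed by (host, port) address tuples.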
+ _host_to_round_trip_time: dict = {} diff --git a/pymongo/asynchronous/server_selectors.py b/pymongo/server_selectors.py similarity index 97% rename from pymongo/asynchronous/server_selectors.py rename to pymongo/server_selectors.py index eeaebadd6e..c0f7ad6ea6 100644 --- a/pymongo/asynchronous/server_selectors.py +++ b/pymongo/server_selectors.py @@ -20,8 +20,8 @@ from pymongo.server_type import SERVER_TYPE if TYPE_CHECKING: - from pymongo.asynchronous.server_description import ServerDescription - from pymongo.asynchronous.topology_description import TopologyDescription + from pymongo.server_description import ServerDescription + from pymongo.topology_description import TopologyDescription _IS_SYNC = False diff --git a/pymongo/asynchronous/srv_resolver.py b/pymongo/srv_resolver.py similarity index 98% rename from pymongo/asynchronous/srv_resolver.py rename to pymongo/srv_resolver.py index 1a37bad966..2d699f9c1f 100644 --- a/pymongo/asynchronous/srv_resolver.py +++ b/pymongo/srv_resolver.py @@ -19,7 +19,7 @@ import random from typing import TYPE_CHECKING, Any, Optional, Union -from pymongo.asynchronous.common import CONNECT_TIMEOUT +from pymongo.common import CONNECT_TIMEOUT from pymongo.errors import ConfigurationError if TYPE_CHECKING: diff --git a/pymongo/synchronous/aggregation.py b/pymongo/synchronous/aggregation.py index a4b5a957cb..7c7e6252f7 100644 --- a/pymongo/synchronous/aggregation.py +++ b/pymongo/synchronous/aggregation.py @@ -18,20 +18,20 @@ from collections.abc import Callable, Mapping, MutableMapping from typing import TYPE_CHECKING, Any, Optional, Union +from pymongo import common +from pymongo.collation import validate_collation_or_none from pymongo.errors import ConfigurationError -from pymongo.synchronous import common -from pymongo.synchronous.collation import validate_collation_or_none -from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref +from pymongo.read_preferences import ReadPreference, _AggWritePref if TYPE_CHECKING: + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.database import Database from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode from pymongo.synchronous.server import Server - from pymongo.synchronous.typings import _DocumentType, _Pipeline + from pymongo.typings import _DocumentType, _Pipeline _IS_SYNC = True diff --git a/pymongo/synchronous/auth.py b/pymongo/synchronous/auth.py index cb1b23d15b..9a3477679d 100644 --- a/pymongo/synchronous/auth.py +++ b/pymongo/synchronous/auth.py @@ -18,16 +18,12 @@ import functools import hashlib import hmac -import os import socket -import typing from base64 import standard_b64decode, standard_b64encode -from collections import namedtuple from typing import ( TYPE_CHECKING, Any, Callable, - Dict, Mapping, MutableMapping, Optional, @@ -36,20 +32,22 @@ from urllib.parse import quote from bson.binary import Binary +from pymongo.auth_shared import ( + MongoCredential, + _authenticate_scram_start, + _parse_scram_response, + _xor, +) from pymongo.errors import ConfigurationError, OperationFailure from pymongo.saslprep import saslprep from pymongo.synchronous.auth_aws import _authenticate_aws from pymongo.synchronous.auth_oidc import ( _authenticate_oidc, _get_authenticator, - _OIDCAzureCallback, - _OIDCGCPCallback, - _OIDCProperties, - 
_OIDCTestCallback, ) if TYPE_CHECKING: - from pymongo.synchronous.hello import Hello + from pymongo.hello import Hello from pymongo.synchronous.pool import Connection HAVE_KERBEROS = True @@ -68,210 +66,6 @@ _IS_SYNC = True -MECHANISMS = frozenset( - [ - "GSSAPI", - "MONGODB-CR", - "MONGODB-OIDC", - "MONGODB-X509", - "MONGODB-AWS", - "PLAIN", - "SCRAM-SHA-1", - "SCRAM-SHA-256", - "DEFAULT", - ] -) -"""The authentication mechanisms supported by PyMongo.""" - - -class _Cache: - __slots__ = ("data",) - - _hash_val = hash("_Cache") - - def __init__(self) -> None: - self.data = None - - def __eq__(self, other: object) -> bool: - # Two instances must always compare equal. - if isinstance(other, _Cache): - return True - return NotImplemented - - def __ne__(self, other: object) -> bool: - if isinstance(other, _Cache): - return False - return NotImplemented - - def __hash__(self) -> int: - return self._hash_val - - -MongoCredential = namedtuple( - "MongoCredential", - ["mechanism", "source", "username", "password", "mechanism_properties", "cache"], -) -"""A hashable namedtuple of values used for authentication.""" - - -GSSAPIProperties = namedtuple( - "GSSAPIProperties", ["service_name", "canonicalize_host_name", "service_realm"] -) -"""Mechanism properties for GSSAPI authentication.""" - - -_AWSProperties = namedtuple("_AWSProperties", ["aws_session_token"]) -"""Mechanism properties for MONGODB-AWS authentication.""" - - -def _build_credentials_tuple( - mech: str, - source: Optional[str], - user: str, - passwd: str, - extra: Mapping[str, Any], - database: Optional[str], -) -> MongoCredential: - """Build and return a mechanism specific credentials tuple.""" - if mech not in ("MONGODB-X509", "MONGODB-AWS", "MONGODB-OIDC") and user is None: - raise ConfigurationError(f"{mech} requires a username.") - if mech == "GSSAPI": - if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for GSSAPI") - properties = extra.get("authmechanismproperties", {}) - service_name = properties.get("SERVICE_NAME", "mongodb") - canonicalize = bool(properties.get("CANONICALIZE_HOST_NAME", False)) - service_realm = properties.get("SERVICE_REALM") - props = GSSAPIProperties( - service_name=service_name, - canonicalize_host_name=canonicalize, - service_realm=service_realm, - ) - # Source is always $external. - return MongoCredential(mech, "$external", user, passwd, props, None) - elif mech == "MONGODB-X509": - if passwd is not None: - raise ConfigurationError("Passwords are not supported by MONGODB-X509") - if source is not None and source != "$external": - raise ValueError("authentication source must be $external or None for MONGODB-X509") - # Source is always $external, user can be None. - return MongoCredential(mech, "$external", user, None, None, None) - elif mech == "MONGODB-AWS": - if user is not None and passwd is None: - raise ConfigurationError("username without a password is not supported by MONGODB-AWS") - if source is not None and source != "$external": - raise ConfigurationError( - "authentication source must be $external or None for MONGODB-AWS" - ) - - properties = extra.get("authmechanismproperties", {}) - aws_session_token = properties.get("AWS_SESSION_TOKEN") - aws_props = _AWSProperties(aws_session_token=aws_session_token) - # user can be None for temporary link-local EC2 credentials. 
- return MongoCredential(mech, "$external", user, passwd, aws_props, None) - elif mech == "MONGODB-OIDC": - properties = extra.get("authmechanismproperties", {}) - callback = properties.get("OIDC_CALLBACK") - human_callback = properties.get("OIDC_HUMAN_CALLBACK") - environ = properties.get("ENVIRONMENT") - token_resource = properties.get("TOKEN_RESOURCE", "") - default_allowed = [ - "*.mongodb.net", - "*.mongodb-dev.net", - "*.mongodb-qa.net", - "*.mongodbgov.net", - "localhost", - "127.0.0.1", - "::1", - ] - allowed_hosts = properties.get("ALLOWED_HOSTS", default_allowed) - msg = ( - "authentication with MONGODB-OIDC requires providing either a callback or a environment" - ) - if passwd is not None: - msg = "password is not supported by MONGODB-OIDC" - raise ConfigurationError(msg) - if callback or human_callback: - if environ is not None: - raise ConfigurationError(msg) - if callback and human_callback: - msg = "cannot set both OIDC_CALLBACK and OIDC_HUMAN_CALLBACK" - raise ConfigurationError(msg) - elif environ is not None: - if environ == "test": - if user is not None: - msg = "test environment for MONGODB-OIDC does not support username" - raise ConfigurationError(msg) - callback = _OIDCTestCallback() - elif environ == "azure": - passwd = None - if not token_resource: - raise ConfigurationError( - "Azure environment for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" - ) - callback = _OIDCAzureCallback(token_resource) - elif environ == "gcp": - passwd = None - if not token_resource: - raise ConfigurationError( - "GCP provider for MONGODB-OIDC requires a TOKEN_RESOURCE auth mechanism property" - ) - callback = _OIDCGCPCallback(token_resource) - else: - raise ConfigurationError(f"unrecognized ENVIRONMENT for MONGODB-OIDC: {environ}") - else: - raise ConfigurationError(msg) - - oidc_props = _OIDCProperties( - callback=callback, - human_callback=human_callback, - environment=environ, - allowed_hosts=allowed_hosts, - token_resource=token_resource, - username=user, - ) - return MongoCredential(mech, "$external", user, passwd, oidc_props, _Cache()) - - elif mech == "PLAIN": - source_database = source or database or "$external" - return MongoCredential(mech, source_database, user, passwd, None, None) - else: - source_database = source or database or "admin" - if passwd is None: - raise ConfigurationError("A password is required.") - return MongoCredential(mech, source_database, user, passwd, None, _Cache()) - - -def _xor(fir: bytes, sec: bytes) -> bytes: - """XOR two byte strings together.""" - return b"".join([bytes([x ^ y]) for x, y in zip(fir, sec)]) - - -def _parse_scram_response(response: bytes) -> Dict[bytes, bytes]: - """Split a scram response into key, value pairs.""" - return dict( - typing.cast(typing.Tuple[bytes, bytes], item.split(b"=", 1)) - for item in response.split(b",") - ) - - -def _authenticate_scram_start( - credentials: MongoCredential, mechanism: str -) -> tuple[bytes, bytes, MutableMapping[str, Any]]: - username = credentials.username - user = username.encode("utf-8").replace(b"=", b"=3D").replace(b",", b"=2C") - nonce = standard_b64encode(os.urandom(32)) - first_bare = b"n=" + user + b",r=" + nonce - - cmd = { - "saslStart": 1, - "mechanism": mechanism, - "payload": Binary(b"n,," + first_bare), - "autoAuthorize": 1, - "options": {"skipEmptyExchange": True}, - } - return nonce, first_bare, cmd - def _authenticate_scram(credentials: MongoCredential, conn: Connection, mechanism: str) -> None: """Authenticate using SCRAM.""" diff --git 
a/pymongo/synchronous/auth_aws.py b/pymongo/synchronous/auth_aws.py index 04ceb95b34..7c0d24f3a1 100644 --- a/pymongo/synchronous/auth_aws.py +++ b/pymongo/synchronous/auth_aws.py @@ -23,7 +23,7 @@ if TYPE_CHECKING: from bson.typings import _ReadableBuffer - from pymongo.synchronous.auth import MongoCredential + from pymongo.auth_shared import MongoCredential from pymongo.synchronous.pool import Connection _IS_SYNC = True diff --git a/pymongo/synchronous/auth_oidc.py b/pymongo/synchronous/auth_oidc.py index f59b4d54a1..6381a408ab 100644 --- a/pymongo/synchronous/auth_oidc.py +++ b/pymongo/synchronous/auth_oidc.py @@ -15,79 +15,35 @@ """MONGODB-OIDC Authentication helpers.""" from __future__ import annotations -import abc -import os import threading import time from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, Optional, Union -from urllib.parse import quote import bson from bson.binary import Binary -from pymongo._azure_helpers import _get_azure_response from pymongo._csot import remaining -from pymongo._gcp_helpers import _get_gcp_response +from pymongo.auth_oidc_shared import ( + CALLBACK_VERSION, + HUMAN_CALLBACK_TIMEOUT_SECONDS, + MACHINE_CALLBACK_TIMEOUT_SECONDS, + TIME_BETWEEN_CALLS_SECONDS, + OIDCCallback, + OIDCCallbackContext, + OIDCCallbackResult, + OIDCIdPInfo, + _OIDCProperties, +) from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.helpers_constants import _AUTHENTICATION_FAILURE_CODE +from pymongo.helpers_shared import _AUTHENTICATION_FAILURE_CODE if TYPE_CHECKING: - from pymongo.synchronous.auth import MongoCredential + from pymongo.auth_shared import MongoCredential from pymongo.synchronous.pool import Connection _IS_SYNC = True -@dataclass -class OIDCIdPInfo: - issuer: str - clientId: Optional[str] = field(default=None) - requestScopes: Optional[list[str]] = field(default=None) - - -@dataclass -class OIDCCallbackContext: - timeout_seconds: float - username: str - version: int - refresh_token: Optional[str] = field(default=None) - idp_info: Optional[OIDCIdPInfo] = field(default=None) - - -@dataclass -class OIDCCallbackResult: - access_token: str - expires_in_seconds: Optional[float] = field(default=None) - refresh_token: Optional[str] = field(default=None) - - -class OIDCCallback(abc.ABC): - """A base class for defining OIDC callbacks.""" - - @abc.abstractmethod - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - """Convert the given BSON value into our own type.""" - - -@dataclass -class _OIDCProperties: - callback: Optional[OIDCCallback] = field(default=None) - human_callback: Optional[OIDCCallback] = field(default=None) - environment: Optional[str] = field(default=None) - allowed_hosts: list[str] = field(default_factory=list) - token_resource: Optional[str] = field(default=None) - username: str = "" - - -"""Mechanism properties for MONGODB-OIDC authentication.""" - -TOKEN_BUFFER_MINUTES = 5 -HUMAN_CALLBACK_TIMEOUT_SECONDS = 5 * 60 -CALLBACK_VERSION = 1 -MACHINE_CALLBACK_TIMEOUT_SECONDS = 60 -TIME_BETWEEN_CALLS_SECONDS = 0.1 - - def _get_authenticator( credentials: MongoCredential, address: tuple[str, int] ) -> _OIDCAuthenticator: @@ -117,48 +73,6 @@ def _get_authenticator( return credentials.cache.data -class _OIDCTestCallback(OIDCCallback): - def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult: - token_file = os.environ.get("OIDC_TOKEN_FILE") - if not token_file: - raise RuntimeError( - 'MONGODB-OIDC with an "test" provider requires "OIDC_TOKEN_FILE" to 
be set'
-            )
-        with open(token_file) as fid:
-            return OIDCCallbackResult(access_token=fid.read().strip())
-
-
-class _OIDCAWSCallback(OIDCCallback):
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        token_file = os.environ.get("AWS_WEB_IDENTITY_TOKEN_FILE")
-        if not token_file:
-            raise RuntimeError(
-                'MONGODB-OIDC with an "aws" provider requires "AWS_WEB_IDENTITY_TOKEN_FILE" to be set'
-            )
-        with open(token_file) as fid:
-            return OIDCCallbackResult(access_token=fid.read().strip())
-
-
-class _OIDCAzureCallback(OIDCCallback):
-    def __init__(self, token_resource: str) -> None:
-        self.token_resource = quote(token_resource)
-
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        resp = _get_azure_response(self.token_resource, context.username, context.timeout_seconds)
-        return OIDCCallbackResult(
-            access_token=resp["access_token"], expires_in_seconds=resp["expires_in"]
-        )
-
-
-class _OIDCGCPCallback(OIDCCallback):
-    def __init__(self, token_resource: str) -> None:
-        self.token_resource = quote(token_resource)
-
-    def fetch(self, context: OIDCCallbackContext) -> OIDCCallbackResult:
-        resp = _get_gcp_response(self.token_resource, context.timeout_seconds)
-        return OIDCCallbackResult(access_token=resp["access_token"])
-
-
 @dataclass
 class _OIDCAuthenticator:
     username: str
diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py
index 781acdb4d8..8d3d0e10fd 100644
--- a/pymongo/synchronous/bulk.py
+++ b/pymongo/synchronous/bulk.py
@@ -34,22 +34,21 @@
 from bson.objectid import ObjectId
 from bson.raw_bson import RawBSONDocument
-from pymongo import _csot
+from pymongo import _csot, common
+from pymongo.common import (
+    validate_is_document_type,
+    validate_ok_for_replace,
+    validate_ok_for_update,
+)
 from pymongo.errors import (
     BulkWriteError,
     ConfigurationError,
     InvalidOperation,
     OperationFailure,
 )
-from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES
-from pymongo.synchronous import common
+from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES, _get_wce_doc
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
-from pymongo.synchronous.common import (
-    validate_is_document_type,
-    validate_ok_for_replace,
-    validate_ok_for_update,
-)
-from pymongo.synchronous.helpers import _get_wce_doc
 from pymongo.synchronous.message import (
     _DELETE,
     _INSERT,
@@ -58,13 +57,12 @@
     _EncryptedBulkWriteContext,
     _randint,
 )
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern

 if TYPE_CHECKING:
     from pymongo.synchronous.collection import Collection
     from pymongo.synchronous.pool import Connection
-    from pymongo.synchronous.typings import _DocumentOut, _DocumentType, _Pipeline
+    from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline

 _IS_SYNC = True
@@ -180,7 +178,7 @@ def __init__(
         comment: Optional[str] = None,
         let: Optional[Any] = None,
     ) -> None:
-        """Initialize a _Bulk instance."""
+        """Initialize an _AsyncBulk instance."""
         self.collection = collection.with_options(
             codec_options=collection.codec_options._replace(
                 unicode_decode_error_handler="replace", document_class=dict
             )
@@ -335,8 +333,8 @@ def _execute_command(
         self.next_run = None
         run = self.current_run

-        # Connection.command validates the session, but we use
-        # Connection.write_command
+        # AsyncConnection.command validates the session, but we use
+        # AsyncConnection.write_command
         conn.validate_session(client, session)
         last_run =
False diff --git a/pymongo/synchronous/change_stream.py b/pymongo/synchronous/change_stream.py index 1b22ed9be1..f7489249d8 100644 --- a/pymongo/synchronous/change_stream.py +++ b/pymongo/synchronous/change_stream.py @@ -21,7 +21,8 @@ from bson import CodecOptions, _bson_to_dict from bson.raw_bson import RawBSONDocument from bson.timestamp import Timestamp -from pymongo import _csot +from pymongo import _csot, common +from pymongo.collation import validate_collation_or_none from pymongo.errors import ( ConnectionFailure, CursorNotFound, @@ -29,16 +30,14 @@ OperationFailure, PyMongoError, ) -from pymongo.synchronous import common +from pymongo.operations import _Op from pymongo.synchronous.aggregation import ( _AggregationCommand, _CollectionAggregationCommand, _DatabaseAggregationCommand, ) -from pymongo.synchronous.collation import validate_collation_or_none from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _Pipeline _IS_SYNC = True diff --git a/pymongo/synchronous/client_options.py b/pymongo/synchronous/client_options.py deleted file mode 100644 index 58042220fb..0000000000 --- a/pymongo/synchronous/client_options.py +++ /dev/null @@ -1,334 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. 
- -"""Tools to parse mongo client options.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, cast - -from bson.codec_options import _parse_codec_options -from pymongo.errors import ConfigurationError -from pymongo.read_concern import ReadConcern -from pymongo.ssl_support import get_ssl_context -from pymongo.synchronous import common -from pymongo.synchronous.compression_support import CompressionSettings -from pymongo.synchronous.monitoring import _EventListener, _EventListeners -from pymongo.synchronous.pool import PoolOptions -from pymongo.synchronous.read_preferences import ( - _ServerMode, - make_read_preference, - read_pref_mode_from_name, -) -from pymongo.synchronous.server_selectors import any_server_selector -from pymongo.write_concern import WriteConcern, validate_boolean - -if TYPE_CHECKING: - from bson.codec_options import CodecOptions - from pymongo.pyopenssl_context import SSLContext - from pymongo.synchronous.auth import MongoCredential - from pymongo.synchronous.encryption_options import AutoEncryptionOpts - from pymongo.synchronous.topology_description import _ServerSelector - -_IS_SYNC = True - - -def _parse_credentials( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> Optional[MongoCredential]: - """Parse authentication credentials.""" - mechanism = options.get("authmechanism", "DEFAULT" if username else None) - source = options.get("authsource") - if username or mechanism: - from pymongo.synchronous.auth import _build_credentials_tuple - - return _build_credentials_tuple(mechanism, source, username, password, options, database) - return None - - -def _parse_read_preference(options: Mapping[str, Any]) -> _ServerMode: - """Parse read preference options.""" - if "read_preference" in options: - return options["read_preference"] - - name = options.get("readpreference", "primary") - mode = read_pref_mode_from_name(name) - tags = options.get("readpreferencetags") - max_staleness = options.get("maxstalenessseconds", -1) - return make_read_preference(mode, tags, max_staleness) - - -def _parse_write_concern(options: Mapping[str, Any]) -> WriteConcern: - """Parse write concern options.""" - concern = options.get("w") - wtimeout = options.get("wtimeoutms") - j = options.get("journal") - fsync = options.get("fsync") - return WriteConcern(concern, wtimeout, j, fsync) - - -def _parse_read_concern(options: Mapping[str, Any]) -> ReadConcern: - """Parse read concern options.""" - concern = options.get("readconcernlevel") - return ReadConcern(concern) - - -def _parse_ssl_options(options: Mapping[str, Any]) -> tuple[Optional[SSLContext], bool]: - """Parse ssl options.""" - use_tls = options.get("tls") - if use_tls is not None: - validate_boolean("tls", use_tls) - - certfile = options.get("tlscertificatekeyfile") - passphrase = options.get("tlscertificatekeyfilepassword") - ca_certs = options.get("tlscafile") - crlfile = options.get("tlscrlfile") - allow_invalid_certificates = options.get("tlsallowinvalidcertificates", False) - allow_invalid_hostnames = options.get("tlsallowinvalidhostnames", False) - disable_ocsp_endpoint_check = options.get("tlsdisableocspendpointcheck", False) - - enabled_tls_opts = [] - for opt in ( - "tlscertificatekeyfile", - "tlscertificatekeyfilepassword", - "tlscafile", - "tlscrlfile", - ): - # Any non-null value of these options implies tls=True. 
- if opt in options and options[opt]: - enabled_tls_opts.append(opt) - for opt in ( - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", - ): - # A value of False for these options implies tls=True. - if opt in options and not options[opt]: - enabled_tls_opts.append(opt) - - if enabled_tls_opts: - if use_tls is None: - # Implicitly enable TLS when one of the tls* options is set. - use_tls = True - elif not use_tls: - # Error since tls is explicitly disabled but a tls option is set. - raise ConfigurationError( - "TLS has not been enabled but the " - "following tls parameters have been set: " - "%s. Please set `tls=True` or remove." % ", ".join(enabled_tls_opts) - ) - - if use_tls: - ctx = get_ssl_context( - certfile, - passphrase, - ca_certs, - crlfile, - allow_invalid_certificates, - allow_invalid_hostnames, - disable_ocsp_endpoint_check, - ) - return ctx, allow_invalid_hostnames - return None, allow_invalid_hostnames - - -def _parse_pool_options( - username: str, password: str, database: Optional[str], options: Mapping[str, Any] -) -> PoolOptions: - """Parse connection pool options.""" - credentials = _parse_credentials(username, password, database, options) - max_pool_size = options.get("maxpoolsize", common.MAX_POOL_SIZE) - min_pool_size = options.get("minpoolsize", common.MIN_POOL_SIZE) - max_idle_time_seconds = options.get("maxidletimems", common.MAX_IDLE_TIME_SEC) - if max_pool_size is not None and min_pool_size > max_pool_size: - raise ValueError("minPoolSize must be smaller or equal to maxPoolSize") - connect_timeout = options.get("connecttimeoutms", common.CONNECT_TIMEOUT) - socket_timeout = options.get("sockettimeoutms") - wait_queue_timeout = options.get("waitqueuetimeoutms", common.WAIT_QUEUE_TIMEOUT) - event_listeners = cast(Optional[Sequence[_EventListener]], options.get("event_listeners")) - appname = options.get("appname") - driver = options.get("driver") - server_api = options.get("server_api") - compression_settings = CompressionSettings( - options.get("compressors", []), options.get("zlibcompressionlevel", -1) - ) - ssl_context, tls_allow_invalid_hostnames = _parse_ssl_options(options) - load_balanced = options.get("loadbalanced") - max_connecting = options.get("maxconnecting", common.MAX_CONNECTING) - return PoolOptions( - max_pool_size, - min_pool_size, - max_idle_time_seconds, - connect_timeout, - socket_timeout, - wait_queue_timeout, - ssl_context, - tls_allow_invalid_hostnames, - _EventListeners(event_listeners), - appname, - driver, - compression_settings, - max_connecting=max_connecting, - server_api=server_api, - load_balanced=load_balanced, - credentials=credentials, - ) - - -class ClientOptions: - """Read only configuration options for a MongoClient. - - Should not be instantiated directly by application developers. Access - a client's options via :attr:`pymongo.mongo_client.MongoClient.options` - instead. - """ - - def __init__( - self, username: str, password: str, database: Optional[str], options: Mapping[str, Any] - ): - self.__options = options - self.__codec_options = _parse_codec_options(options) - self.__direct_connection = options.get("directconnection") - self.__local_threshold_ms = options.get("localthresholdms", common.LOCAL_THRESHOLD_MS) - # self.__server_selection_timeout is in seconds. Must use full name for - # common.SERVER_SELECTION_TIMEOUT because it is set directly by tests. 
- self.__server_selection_timeout = options.get( - "serverselectiontimeoutms", common.SERVER_SELECTION_TIMEOUT - ) - self.__pool_options = _parse_pool_options(username, password, database, options) - self.__read_preference = _parse_read_preference(options) - self.__replica_set_name = options.get("replicaset") - self.__write_concern = _parse_write_concern(options) - self.__read_concern = _parse_read_concern(options) - self.__connect = options.get("connect") - self.__heartbeat_frequency = options.get("heartbeatfrequencyms", common.HEARTBEAT_FREQUENCY) - self.__retry_writes = options.get("retrywrites", common.RETRY_WRITES) - self.__retry_reads = options.get("retryreads", common.RETRY_READS) - self.__server_selector = options.get("server_selector", any_server_selector) - self.__auto_encryption_opts = options.get("auto_encryption_opts") - self.__load_balanced = options.get("loadbalanced") - self.__timeout = options.get("timeoutms") - self.__server_monitoring_mode = options.get( - "servermonitoringmode", common.SERVER_MONITORING_MODE - ) - - @property - def _options(self) -> Mapping[str, Any]: - """The original options used to create this ClientOptions.""" - return self.__options - - @property - def connect(self) -> Optional[bool]: - """Whether to begin discovering a MongoDB topology automatically.""" - return self.__connect - - @property - def codec_options(self) -> CodecOptions: - """A :class:`~bson.codec_options.CodecOptions` instance.""" - return self.__codec_options - - @property - def direct_connection(self) -> Optional[bool]: - """Whether to connect to the deployment in 'Single' topology.""" - return self.__direct_connection - - @property - def local_threshold_ms(self) -> int: - """The local threshold for this instance.""" - return self.__local_threshold_ms - - @property - def server_selection_timeout(self) -> int: - """The server selection timeout for this instance in seconds.""" - return self.__server_selection_timeout - - @property - def server_selector(self) -> _ServerSelector: - return self.__server_selector - - @property - def heartbeat_frequency(self) -> int: - """The monitoring frequency in seconds.""" - return self.__heartbeat_frequency - - @property - def pool_options(self) -> PoolOptions: - """A :class:`~pymongo.pool.PoolOptions` instance.""" - return self.__pool_options - - @property - def read_preference(self) -> _ServerMode: - """A read preference instance.""" - return self.__read_preference - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self.__replica_set_name - - @property - def write_concern(self) -> WriteConcern: - """A :class:`~pymongo.write_concern.WriteConcern` instance.""" - return self.__write_concern - - @property - def read_concern(self) -> ReadConcern: - """A :class:`~pymongo.read_concern.ReadConcern` instance.""" - return self.__read_concern - - @property - def timeout(self) -> Optional[float]: - """The configured timeoutMS converted to seconds, or None. - - .. 
versionadded:: 4.2 - """ - return self.__timeout - - @property - def retry_writes(self) -> bool: - """If this instance should retry supported write operations.""" - return self.__retry_writes - - @property - def retry_reads(self) -> bool: - """If this instance should retry supported read operations.""" - return self.__retry_reads - - @property - def auto_encryption_opts(self) -> Optional[AutoEncryptionOpts]: - """A :class:`~pymongo.encryption.AutoEncryptionOpts` or None.""" - return self.__auto_encryption_opts - - @property - def load_balanced(self) -> Optional[bool]: - """True if the client was configured to connect to a load balancer.""" - return self.__load_balanced - - @property - def event_listeners(self) -> list[_EventListeners]: - """The event listeners registered for this client. - - See :mod:`~pymongo.monitoring` for details. - - .. versionadded:: 4.0 - """ - assert self.__pool_options._event_listeners is not None - return self.__pool_options._event_listeners.event_listeners() - - @property - def server_monitoring_mode(self) -> str: - """The configured serverMonitoringMode option. - - .. versionadded:: 4.5 - """ - return self.__server_monitoring_mode diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index b4339bd122..cf1fc746be 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -164,12 +164,12 @@ PyMongoError, WTimeoutError, ) -from pymongo.helpers_constants import _RETRYABLE_ERROR_CODES +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.operations import _Op from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.server_type import SERVER_TYPE from pymongo.synchronous.cursor import _ConnectionManager -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -177,7 +177,7 @@ from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server - from pymongo.synchronous.typings import ClusterTime, _Address + from pymongo.typings import ClusterTime, _Address _IS_SYNC = True diff --git a/pymongo/synchronous/collation.py b/pymongo/synchronous/collation.py deleted file mode 100644 index 1ce1ee00b1..0000000000 --- a/pymongo/synchronous/collation.py +++ /dev/null @@ -1,226 +0,0 @@ -# Copyright 2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for working with `collations`_. - -.. _collations: https://www.mongodb.com/docs/manual/reference/collation/ -""" -from __future__ import annotations - -from typing import Any, Mapping, Optional, Union - -from pymongo.synchronous import common -from pymongo.write_concern import validate_boolean - -_IS_SYNC = True - - -class CollationStrength: - """ - An enum that defines values for `strength` on a - :class:`~pymongo.collation.Collation`. 
- """ - - PRIMARY = 1 - """Differentiate base (unadorned) characters.""" - - SECONDARY = 2 - """Differentiate character accents.""" - - TERTIARY = 3 - """Differentiate character case.""" - - QUATERNARY = 4 - """Differentiate words with and without punctuation.""" - - IDENTICAL = 5 - """Differentiate unicode code point (characters are exactly identical).""" - - -class CollationAlternate: - """ - An enum that defines values for `alternate` on a - :class:`~pymongo.collation.Collation`. - """ - - NON_IGNORABLE = "non-ignorable" - """Spaces and punctuation are treated as base characters.""" - - SHIFTED = "shifted" - """Spaces and punctuation are *not* considered base characters. - - Spaces and punctuation are distinguished regardless when the - :class:`~pymongo.collation.Collation` strength is at least - :data:`~pymongo.collation.CollationStrength.QUATERNARY`. - - """ - - -class CollationMaxVariable: - """ - An enum that defines values for `max_variable` on a - :class:`~pymongo.collation.Collation`. - """ - - PUNCT = "punct" - """Both punctuation and spaces are ignored.""" - - SPACE = "space" - """Spaces alone are ignored.""" - - -class CollationCaseFirst: - """ - An enum that defines values for `case_first` on a - :class:`~pymongo.collation.Collation`. - """ - - UPPER = "upper" - """Sort uppercase characters first.""" - - LOWER = "lower" - """Sort lowercase characters first.""" - - OFF = "off" - """Default for locale or collation strength.""" - - -class Collation: - """Collation - - :param locale: (string) The locale of the collation. This should be a string - that identifies an `ICU locale ID` exactly. For example, ``en_US`` is - valid, but ``en_us`` and ``en-US`` are not. Consult the MongoDB - documentation for a list of supported locales. - :param caseLevel: (optional) If ``True``, turn on case sensitivity if - `strength` is 1 or 2 (case sensitivity is implied if `strength` is - greater than 2). Defaults to ``False``. - :param caseFirst: (optional) Specify that either uppercase or lowercase - characters take precedence. Must be one of the following values: - - * :data:`~CollationCaseFirst.UPPER` - * :data:`~CollationCaseFirst.LOWER` - * :data:`~CollationCaseFirst.OFF` (the default) - - :param strength: Specify the comparison strength. This is also - known as the ICU comparison level. This must be one of the following - values: - - * :data:`~CollationStrength.PRIMARY` - * :data:`~CollationStrength.SECONDARY` - * :data:`~CollationStrength.TERTIARY` (the default) - * :data:`~CollationStrength.QUATERNARY` - * :data:`~CollationStrength.IDENTICAL` - - Each successive level builds upon the previous. For example, a - `strength` of :data:`~CollationStrength.SECONDARY` differentiates - characters based both on the unadorned base character and its accents. - - :param numericOrdering: If ``True``, order numbers numerically - instead of in collation order (defaults to ``False``). - :param alternate: Specify whether spaces and punctuation are - considered base characters. This must be one of the following values: - - * :data:`~CollationAlternate.NON_IGNORABLE` (the default) - * :data:`~CollationAlternate.SHIFTED` - - :param maxVariable: When `alternate` is - :data:`~CollationAlternate.SHIFTED`, this option specifies what - characters may be ignored. This must be one of the following values: - - * :data:`~CollationMaxVariable.PUNCT` (the default) - * :data:`~CollationMaxVariable.SPACE` - - :param normalization: If ``True``, normalizes text into Unicode - NFD. Defaults to ``False``. 
- :param backwards: If ``True``, accents on characters are - considered from the back of the word to the front, as it is done in some - French dictionary ordering traditions. Defaults to ``False``. - :param kwargs: Keyword arguments supplying any additional options - to be sent with this Collation object. - - .. versionadded: 3.4 - - """ - - __slots__ = ("__document",) - - def __init__( - self, - locale: str, - caseLevel: Optional[bool] = None, - caseFirst: Optional[str] = None, - strength: Optional[int] = None, - numericOrdering: Optional[bool] = None, - alternate: Optional[str] = None, - maxVariable: Optional[str] = None, - normalization: Optional[bool] = None, - backwards: Optional[bool] = None, - **kwargs: Any, - ) -> None: - locale = common.validate_string("locale", locale) - self.__document: dict[str, Any] = {"locale": locale} - if caseLevel is not None: - self.__document["caseLevel"] = validate_boolean("caseLevel", caseLevel) - if caseFirst is not None: - self.__document["caseFirst"] = common.validate_string("caseFirst", caseFirst) - if strength is not None: - self.__document["strength"] = common.validate_integer("strength", strength) - if numericOrdering is not None: - self.__document["numericOrdering"] = validate_boolean( - "numericOrdering", numericOrdering - ) - if alternate is not None: - self.__document["alternate"] = common.validate_string("alternate", alternate) - if maxVariable is not None: - self.__document["maxVariable"] = common.validate_string("maxVariable", maxVariable) - if normalization is not None: - self.__document["normalization"] = validate_boolean("normalization", normalization) - if backwards is not None: - self.__document["backwards"] = validate_boolean("backwards", backwards) - self.__document.update(kwargs) - - @property - def document(self) -> dict[str, Any]: - """The document representation of this collation. - - .. note:: - :class:`Collation` is immutable. Mutating the value of - :attr:`document` does not mutate this :class:`Collation`. 
- """ - return self.__document.copy() - - def __repr__(self) -> str: - document = self.document - return "Collation({})".format(", ".join(f"{key}={document[key]!r}" for key in document)) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, Collation): - return self.document == other.document - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -def validate_collation_or_none( - value: Optional[Union[Mapping[str, Any], Collation]] -) -> Optional[dict[str, Any]]: - if value is None: - return None - if isinstance(value, Collation): - return value.document - if isinstance(value, dict): - return value - raise TypeError("collation must be a dict, an instance of collation.Collation, or None.") diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py index 61bd81fd9b..b8fd39f2d6 100644 --- a/pymongo/synchronous/collection.py +++ b/pymongo/synchronous/collection.py @@ -40,14 +40,31 @@ from bson.raw_bson import RawBSONDocument from bson.son import SON from bson.timestamp import Timestamp -from pymongo import ASCENDING, _csot +from pymongo import ASCENDING, _csot, common, helpers_shared +from pymongo.collation import validate_collation_or_none +from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.errors import ( ConfigurationError, InvalidName, InvalidOperation, OperationFailure, ) +from pymongo.helpers_shared import _check_write_command_response +from pymongo.operations import ( + DeleteMany, + DeleteOne, + IndexModel, + InsertOne, + ReplaceOne, + SearchIndexModel, + UpdateMany, + UpdateOne, + _IndexKeyHint, + _IndexList, + _Op, +) from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.results import ( BulkWriteResult, DeleteResult, @@ -55,40 +72,23 @@ InsertOneResult, UpdateResult, ) -from pymongo.synchronous import common, helpers, message +from pymongo.synchronous import message from pymongo.synchronous.aggregation import ( _CollectionAggregationCommand, _CollectionRawAggregationCommand, ) from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.change_stream import CollectionChangeStream -from pymongo.synchronous.collation import validate_collation_or_none from pymongo.synchronous.command_cursor import ( CommandCursor, RawBatchCommandCursor, ) -from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name from pymongo.synchronous.cursor import ( Cursor, RawBatchCursor, ) -from pymongo.synchronous.helpers import _check_write_command_response from pymongo.synchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS -from pymongo.synchronous.operations import ( - DeleteMany, - DeleteOne, - IndexModel, - InsertOne, - ReplaceOne, - SearchIndexModel, - UpdateMany, - UpdateOne, - _IndexKeyHint, - _IndexList, - _Op, -) -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean _IS_SYNC = True @@ -124,10 +124,10 @@ class ReturnDocument: if TYPE_CHECKING: import bson + from pymongo.collation import Collation from pymongo.read_concern import ReadConcern from pymongo.synchronous.aggregation import _AggregationCommand from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.collation import Collation 
from pymongo.synchronous.database import Database from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server @@ -989,7 +989,7 @@ def _update( "Must be connected to MongoDB 4.2+ to use hint on unacknowledged update commands." ) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) update_doc["hint"] = hint command = {"update": self.name, "ordered": ordered, "updates": [update_doc]} if let is not None: @@ -1470,7 +1470,7 @@ def _delete( "Must be connected to MongoDB 4.4+ to use hint on unacknowledged delete commands." ) if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) delete_doc["hint"] = hint command = {"delete": self.name, "ordered": ordered, "deletes": [delete_doc]} @@ -2086,7 +2086,7 @@ def count_documents( pipeline.append({"$group": {"_id": 1, "n": {"$sum": 1}}}) cmd = {"aggregate": self._name, "pipeline": pipeline, "cursor": {}} if "hint" in kwargs and not isinstance(kwargs["hint"], str): - kwargs["hint"] = helpers._index_document(kwargs["hint"]) + kwargs["hint"] = helpers_shared._index_document(kwargs["hint"]) collation = validate_collation_or_none(kwargs.pop("collation", None)) cmd.update(kwargs) @@ -2419,7 +2419,7 @@ def _drop_index( ) -> None: name = index_or_name if isinstance(index_or_name, list): - name = helpers._gen_index_name(index_or_name) + name = helpers_shared._gen_index_name(index_or_name) if not isinstance(name, str): raise TypeError("index_or_name must be an instance of str or list") @@ -3148,15 +3148,15 @@ def _find_and_modify( cmd["let"] = let cmd.update(kwargs) if projection is not None: - cmd["fields"] = helpers._fields_list_to_dict(projection, "projection") + cmd["fields"] = helpers_shared._fields_list_to_dict(projection, "projection") if sort is not None: - cmd["sort"] = helpers._index_document(sort) + cmd["sort"] = helpers_shared._index_document(sort) if upsert is not None: validate_boolean("upsert", upsert) cmd["upsert"] = upsert if hint is not None: if not isinstance(hint, str): - hint = helpers._index_document(hint) + hint = helpers_shared._index_document(hint) write_concern = self._write_concern_for_cmd(cmd, session) diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py index a2a5d8b192..ba9bf6ef10 100644 --- a/pymongo/synchronous/command_cursor.py +++ b/pymongo/synchronous/command_cursor.py @@ -31,6 +31,7 @@ from bson import CodecOptions, _convert_raw_document_lists_to_streams from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.response import PinnedResponse from pymongo.synchronous.cursor import _ConnectionManager from pymongo.synchronous.message import ( _CursorAddress, @@ -39,8 +40,7 @@ _OpReply, _RawBatchGetMore, ) -from pymongo.synchronous.response import PinnedResponse -from pymongo.synchronous.typings import _Address, _DocumentOut, _DocumentType +from pymongo.typings import _Address, _DocumentOut, _DocumentType if TYPE_CHECKING: from pymongo.synchronous.client_session import ClientSession diff --git a/pymongo/synchronous/common.py b/pymongo/synchronous/common.py deleted file mode 100644 index 13e58adedd..0000000000 --- a/pymongo/synchronous/common.py +++ /dev/null @@ -1,1062 +0,0 @@ -# Copyright 2011-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. 
You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - - -"""Functions and classes common to multiple pymongo modules.""" -from __future__ import annotations - -import datetime -import warnings -from collections import OrderedDict, abc -from difflib import get_close_matches -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterator, - Mapping, - MutableMapping, - NoReturn, - Optional, - Sequence, - Type, - Union, - overload, -) -from urllib.parse import unquote_plus - -from bson import SON -from bson.binary import UuidRepresentation -from bson.codec_options import CodecOptions, DatetimeConversion, TypeRegistry -from bson.raw_bson import RawBSONDocument -from pymongo.driver_info import DriverInfo -from pymongo.errors import ConfigurationError -from pymongo.read_concern import ReadConcern -from pymongo.server_api import ServerApi -from pymongo.synchronous.compression_support import ( - validate_compressors, - validate_zlib_compression_level, -) -from pymongo.synchronous.monitoring import _validate_event_listeners -from pymongo.synchronous.read_preferences import _MONGOS_MODES, _ServerMode -from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean - -if TYPE_CHECKING: - from pymongo.synchronous.client_session import ClientSession - -_IS_SYNC = True - -ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict) - -# Defaults until we connect to a server and get updated limits. -MAX_BSON_SIZE = 16 * (1024**2) -MAX_MESSAGE_SIZE: int = 2 * MAX_BSON_SIZE -MIN_WIRE_VERSION = 0 -MAX_WIRE_VERSION = 0 -MAX_WRITE_BATCH_SIZE = 1000 - -# What this version of PyMongo supports. -MIN_SUPPORTED_SERVER_VERSION = "3.6" -MIN_SUPPORTED_WIRE_VERSION = 6 -MAX_SUPPORTED_WIRE_VERSION = 21 - -# Frequency to call hello on servers, in seconds. -HEARTBEAT_FREQUENCY = 10 - -# Frequency to clean up unclosed cursors, in seconds. -# See MongoClient._process_kill_cursors. -KILL_CURSOR_FREQUENCY = 1 - -# Frequency to process events queue, in seconds. -EVENTS_QUEUE_FREQUENCY = 1 - -# How long to wait, in seconds, for a suitable server to be found before -# aborting an operation. For example, if the client attempts an insert -# during a replica set election, SERVER_SELECTION_TIMEOUT governs the -# longest it is willing to wait for a new primary to be found. -SERVER_SELECTION_TIMEOUT = 30 - -# Spec requires at least 500ms between hello calls. -MIN_HEARTBEAT_INTERVAL = 0.5 - -# Spec requires at least 60s between SRV rescans. -MIN_SRV_RESCAN_INTERVAL = 60 - -# Default connectTimeout in seconds. -CONNECT_TIMEOUT = 20.0 - -# Default value for maxPoolSize. -MAX_POOL_SIZE = 100 - -# Default value for minPoolSize. -MIN_POOL_SIZE = 0 - -# The maximum number of concurrent connection creation attempts per pool. -MAX_CONNECTING = 2 - -# Default value for maxIdleTimeMS. -MAX_IDLE_TIME_MS: Optional[int] = None - -# Default value for maxIdleTimeMS in seconds. -MAX_IDLE_TIME_SEC: Optional[int] = None - -# Default value for waitQueueTimeoutMS in seconds. -WAIT_QUEUE_TIMEOUT: Optional[int] = None - -# Default value for localThresholdMS. -LOCAL_THRESHOLD_MS = 15 - -# Default value for retryWrites. -RETRY_WRITES = True - -# Default value for retryReads. 
-RETRY_READS = True - -# The error code returned when a command doesn't exist. -COMMAND_NOT_FOUND_CODES: Sequence[int] = (59,) - -# Error codes to ignore if GridFS calls createIndex on a secondary -UNAUTHORIZED_CODES: Sequence[int] = (13, 16547, 16548) - -# Maximum number of sessions to send in a single endSessions command. -# From the driver sessions spec. -_MAX_END_SESSIONS = 10000 - -# Default value for srvServiceName -SRV_SERVICE_NAME = "mongodb" - -# Default value for serverMonitoringMode -SERVER_MONITORING_MODE = "auto" # poll/stream/auto - - -def partition_node(node: str) -> tuple[str, int]: - """Split a host:port string into (host, int(port)) pair.""" - host = node - port = 27017 - idx = node.rfind(":") - if idx != -1: - host, port = node[:idx], int(node[idx + 1 :]) - if host.startswith("["): - host = host[1:-1] - return host, port - - -def clean_node(node: str) -> tuple[str, int]: - """Split and normalize a node name from a hello response.""" - host, port = partition_node(node) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -def raise_config_error(key: str, suggestions: Optional[list] = None) -> NoReturn: - """Raise ConfigurationError with the given key name.""" - msg = f"Unknown option: {key}." - if suggestions: - msg += f" Did you mean one of ({', '.join(suggestions)}) or maybe a camelCase version of one? Refer to docstring." - raise ConfigurationError(msg) - - -# Mapping of URI uuid representation options to valid subtypes. -_UUID_REPRESENTATIONS = { - "unspecified": UuidRepresentation.UNSPECIFIED, - "standard": UuidRepresentation.STANDARD, - "pythonLegacy": UuidRepresentation.PYTHON_LEGACY, - "javaLegacy": UuidRepresentation.JAVA_LEGACY, - "csharpLegacy": UuidRepresentation.CSHARP_LEGACY, -} - - -def validate_boolean_or_string(option: str, value: Any) -> bool: - """Validates that value is True, False, 'true', or 'false'.""" - if isinstance(value, str): - if value not in ("true", "false"): - raise ValueError(f"The value of {option} must be 'true' or 'false'") - return value == "true" - return validate_boolean(option, value) - - -def validate_integer(option: str, value: Any) -> int: - """Validates that 'value' is an integer (or basestring representation).""" - if isinstance(value, int): - return value - elif isinstance(value, str): - try: - return int(value) - except ValueError: - raise ValueError(f"The value of {option} must be an integer") from None - raise TypeError(f"Wrong type for {option}, value must be an integer") - - -def validate_positive_integer(option: str, value: Any) -> int: - """Validate that 'value' is a positive integer, which does not include 0.""" - val = validate_integer(option, value) - if val <= 0: - raise ValueError(f"The value of {option} must be a positive integer") - return val - - -def validate_non_negative_integer(option: str, value: Any) -> int: - """Validate that 'value' is a positive integer or 0.""" - val = validate_integer(option, value) - if val < 0: - raise ValueError(f"The value of {option} must be a non negative integer") - return val - - -def validate_readable(option: str, value: Any) -> Optional[str]: - """Validates that 'value' is file-like and readable.""" - if value is None: - return value - # First make sure its a string py3.3 open(True, 'r') succeeds - # Used in ssl cert checking due to poor ssl module error reporting - value = 
validate_string(option, value) - open(value).close() - return value - - -def validate_positive_integer_or_none(option: str, value: Any) -> Optional[int]: - """Validate that 'value' is a positive integer or None.""" - if value is None: - return value - return validate_positive_integer(option, value) - - -def validate_non_negative_integer_or_none(option: str, value: Any) -> Optional[int]: - """Validate that 'value' is a positive integer or 0 or None.""" - if value is None: - return value - return validate_non_negative_integer(option, value) - - -def validate_string(option: str, value: Any) -> str: - """Validates that 'value' is an instance of `str`.""" - if isinstance(value, str): - return value - raise TypeError(f"Wrong type for {option}, value must be an instance of str") - - -def validate_string_or_none(option: str, value: Any) -> Optional[str]: - """Validates that 'value' is an instance of `basestring` or `None`.""" - if value is None: - return value - return validate_string(option, value) - - -def validate_int_or_basestring(option: str, value: Any) -> Union[int, str]: - """Validates that 'value' is an integer or string.""" - if isinstance(value, int): - return value - elif isinstance(value, str): - try: - return int(value) - except ValueError: - return value - raise TypeError(f"Wrong type for {option}, value must be an integer or a string") - - -def validate_non_negative_int_or_basestring(option: Any, value: Any) -> Union[int, str]: - """Validates that 'value' is an integer or string.""" - if isinstance(value, int): - return value - elif isinstance(value, str): - try: - val = int(value) - except ValueError: - return value - return validate_non_negative_integer(option, val) - raise TypeError(f"Wrong type for {option}, value must be an non negative integer or a string") - - -def validate_positive_float(option: str, value: Any) -> float: - """Validates that 'value' is a float, or can be converted to one, and is - positive. - """ - errmsg = f"{option} must be an integer or float" - try: - value = float(value) - except ValueError: - raise ValueError(errmsg) from None - except TypeError: - raise TypeError(errmsg) from None - - # float('inf') doesn't work in 2.4 or 2.5 on Windows, so just cap floats at - # one billion - this is a reasonable approximation for infinity - if not 0 < value < 1e9: - raise ValueError(f"{option} must be greater than 0 and less than one billion") - return value - - -def validate_positive_float_or_zero(option: str, value: Any) -> float: - """Validates that 'value' is 0 or a positive float, or can be converted to - 0 or a positive float. - """ - if value == 0 or value == "0": - return 0 - return validate_positive_float(option, value) - - -def validate_timeout_or_none(option: str, value: Any) -> Optional[float]: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds. - """ - if value is None: - return value - return validate_positive_float(option, value) / 1000.0 - - -def validate_timeout_or_zero(option: str, value: Any) -> float: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds for the case where None is an error - and 0 is valid. Setting the timeout to nothing in the URI string is a - config error. 
- """ - if value is None: - raise ConfigurationError(f"{option} cannot be None") - if value == 0 or value == "0": - return 0 - return validate_positive_float(option, value) / 1000.0 - - -def validate_timeout_or_none_or_zero(option: Any, value: Any) -> Optional[float]: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds. value=0 and value="0" are treated the - same as value=None which means unlimited timeout. - """ - if value is None or value == 0 or value == "0": - return None - return validate_positive_float(option, value) / 1000.0 - - -def validate_timeoutms(option: Any, value: Any) -> Optional[float]: - """Validates a timeout specified in milliseconds returning - a value in floating point seconds. - """ - if value is None: - return None - return validate_positive_float_or_zero(option, value) / 1000.0 - - -def validate_max_staleness(option: str, value: Any) -> int: - """Validates maxStalenessSeconds according to the Max Staleness Spec.""" - if value == -1 or value == "-1": - # Default: No maximum staleness. - return -1 - return validate_positive_integer(option, value) - - -def validate_read_preference(dummy: Any, value: Any) -> _ServerMode: - """Validate a read preference.""" - if not isinstance(value, _ServerMode): - raise TypeError(f"{value!r} is not a read preference.") - return value - - -def validate_read_preference_mode(dummy: Any, value: Any) -> _ServerMode: - """Validate read preference mode for a MongoClient. - - .. versionchanged:: 3.5 - Returns the original ``value`` instead of the validated read preference - mode. - """ - if value not in _MONGOS_MODES: - raise ValueError(f"{value} is not a valid read preference") - return value - - -def validate_auth_mechanism(option: str, value: Any) -> str: - """Validate the authMechanism URI option.""" - from pymongo.synchronous.auth import MECHANISMS - - if value not in MECHANISMS: - raise ValueError(f"{option} must be in {tuple(MECHANISMS)}") - return value - - -def validate_uuid_representation(dummy: Any, value: Any) -> int: - """Validate the uuid representation option selected in the URI.""" - try: - return _UUID_REPRESENTATIONS[value] - except KeyError: - raise ValueError( - f"{value} is an invalid UUID representation. 
" - "Must be one of " - f"{tuple(_UUID_REPRESENTATIONS)}" - ) from None - - -def validate_read_preference_tags(name: str, value: Any) -> list[dict[str, str]]: - """Parse readPreferenceTags if passed as a client kwarg.""" - if not isinstance(value, list): - value = [value] - - tag_sets: list = [] - for tag_set in value: - if tag_set == "": - tag_sets.append({}) - continue - try: - tags = {} - for tag in tag_set.split(","): - key, val = tag.split(":") - tags[unquote_plus(key)] = unquote_plus(val) - tag_sets.append(tags) - except Exception: - raise ValueError(f"{tag_set!r} not a valid value for {name}") from None - return tag_sets - - -_MECHANISM_PROPS = frozenset( - [ - "SERVICE_NAME", - "CANONICALIZE_HOST_NAME", - "SERVICE_REALM", - "AWS_SESSION_TOKEN", - "ENVIRONMENT", - "TOKEN_RESOURCE", - ] -) - - -def validate_auth_mechanism_properties(option: str, value: Any) -> dict[str, Union[bool, str]]: - """Validate authMechanismProperties.""" - props: dict[str, Any] = {} - if not isinstance(value, str): - if not isinstance(value, dict): - raise ValueError("Auth mechanism properties must be given as a string or a dictionary") - for key, value in value.items(): # noqa: B020 - if isinstance(value, str): - props[key] = value - elif isinstance(value, bool): - props[key] = str(value).lower() - elif key in ["ALLOWED_HOSTS"] and isinstance(value, list): - props[key] = value - elif key in ["OIDC_CALLBACK", "OIDC_HUMAN_CALLBACK"]: - from pymongo.synchronous.auth_oidc import OIDCCallback - - if not isinstance(value, OIDCCallback): - raise ValueError("callback must be an OIDCCallback object") - props[key] = value - else: - raise ValueError(f"Invalid type for auth mechanism property {key}, {type(value)}") - return props - - value = validate_string(option, value) - value = unquote_plus(value) - for opt in value.split(","): - key, _, val = opt.partition(":") - if not val: - raise ValueError("Malformed auth mechanism properties") - if key not in _MECHANISM_PROPS: - # Try not to leak the token. - if "AWS_SESSION_TOKEN" in key: - raise ValueError( - "auth mechanism properties must be " - "key:value pairs like AWS_SESSION_TOKEN:" - ) - - raise ValueError( - f"{key} is not a supported auth " - "mechanism property. Must be one of " - f"{tuple(_MECHANISM_PROPS)}." - ) - - if key == "CANONICALIZE_HOST_NAME": - props[key] = validate_boolean_or_string(key, val) - else: - props[key] = val - - return props - - -def validate_document_class( - option: str, value: Any -) -> Union[Type[MutableMapping], Type[RawBSONDocument]]: - """Validate the document_class option.""" - # issubclass can raise TypeError for generic aliases like SON[str, Any]. - # In that case we can use the base class for the comparison. 
- is_mapping = False - try: - is_mapping = issubclass(value, abc.MutableMapping) - except TypeError: - if hasattr(value, "__origin__"): - is_mapping = issubclass(value.__origin__, abc.MutableMapping) - if not is_mapping and not issubclass(value, RawBSONDocument): - raise TypeError( - f"{option} must be dict, bson.son.SON, " - "bson.raw_bson.RawBSONDocument, or a " - "subclass of collections.MutableMapping" - ) - return value - - -def validate_type_registry(option: Any, value: Any) -> Optional[TypeRegistry]: - """Validate the type_registry option.""" - if value is not None and not isinstance(value, TypeRegistry): - raise TypeError(f"{option} must be an instance of {TypeRegistry}") - return value - - -def validate_list(option: str, value: Any) -> list: - """Validates that 'value' is a list.""" - if not isinstance(value, list): - raise TypeError(f"{option} must be a list") - return value - - -def validate_list_or_none(option: Any, value: Any) -> Optional[list]: - """Validates that 'value' is a list or None.""" - if value is None: - return value - return validate_list(option, value) - - -def validate_list_or_mapping(option: Any, value: Any) -> None: - """Validates that 'value' is a list or a document.""" - if not isinstance(value, (abc.Mapping, list)): - raise TypeError( - f"{option} must either be a list or an instance of dict, " - "bson.son.SON, or any other type that inherits from " - "collections.Mapping" - ) - - -def validate_is_mapping(option: str, value: Any) -> None: - """Validate the type of method arguments that expect a document.""" - if not isinstance(value, abc.Mapping): - raise TypeError( - f"{option} must be an instance of dict, bson.son.SON, or " - "any other type that inherits from " - "collections.Mapping" - ) - - -def validate_is_document_type(option: str, value: Any) -> None: - """Validate the type of method arguments that expect a MongoDB document.""" - if not isinstance(value, (abc.MutableMapping, RawBSONDocument)): - raise TypeError( - f"{option} must be an instance of dict, bson.son.SON, " - "bson.raw_bson.RawBSONDocument, or " - "a type that inherits from " - "collections.MutableMapping" - ) - - -def validate_appname_or_none(option: str, value: Any) -> Optional[str]: - """Validate the appname option.""" - if value is None: - return value - validate_string(option, value) - # We need length in bytes, so encode utf8 first. 
- if len(value.encode("utf-8")) > 128: - raise ValueError(f"{option} must be <= 128 bytes") - return value - - -def validate_driver_or_none(option: Any, value: Any) -> Optional[DriverInfo]: - """Validate the driver keyword arg.""" - if value is None: - return value - if not isinstance(value, DriverInfo): - raise TypeError(f"{option} must be an instance of DriverInfo") - return value - - -def validate_server_api_or_none(option: Any, value: Any) -> Optional[ServerApi]: - """Validate the server_api keyword arg.""" - if value is None: - return value - if not isinstance(value, ServerApi): - raise TypeError(f"{option} must be an instance of ServerApi") - return value - - -def validate_is_callable_or_none(option: Any, value: Any) -> Optional[Callable]: - """Validates that 'value' is a callable.""" - if value is None: - return value - if not callable(value): - raise ValueError(f"{option} must be a callable") - return value - - -def validate_ok_for_replace(replacement: Mapping[str, Any]) -> None: - """Validate a replacement document.""" - validate_is_mapping("replacement", replacement) - # Replacement can be {} - if replacement and not isinstance(replacement, RawBSONDocument): - first = next(iter(replacement)) - if first.startswith("$"): - raise ValueError("replacement can not include $ operators") - - -def validate_ok_for_update(update: Any) -> None: - """Validate an update document.""" - validate_list_or_mapping("update", update) - # Update cannot be {}. - if not update: - raise ValueError("update cannot be empty") - - is_document = not isinstance(update, list) - first = next(iter(update)) - if is_document and not first.startswith("$"): - raise ValueError("update only works with $ operators") - - -_UNICODE_DECODE_ERROR_HANDLERS = frozenset(["strict", "replace", "ignore"]) - - -def validate_unicode_decode_error_handler(dummy: Any, value: str) -> str: - """Validate the Unicode decode error handler option of CodecOptions.""" - if value not in _UNICODE_DECODE_ERROR_HANDLERS: - raise ValueError( - f"{value} is an invalid Unicode decode error handler. " - "Must be one of " - f"{tuple(_UNICODE_DECODE_ERROR_HANDLERS)}" - ) - return value - - -def validate_tzinfo(dummy: Any, value: Any) -> Optional[datetime.tzinfo]: - """Validate the tzinfo option""" - if value is not None and not isinstance(value, datetime.tzinfo): - raise TypeError("%s must be an instance of datetime.tzinfo" % value) - return value - - -def validate_auto_encryption_opts_or_none(option: Any, value: Any) -> Optional[Any]: - """Validate the driver keyword arg.""" - if value is None: - return value - from pymongo.synchronous.encryption_options import AutoEncryptionOpts - - if not isinstance(value, AutoEncryptionOpts): - raise TypeError(f"{option} must be an instance of AutoEncryptionOpts") - - return value - - -def validate_datetime_conversion(option: Any, value: Any) -> Optional[DatetimeConversion]: - """Validate a DatetimeConversion string.""" - if value is None: - return DatetimeConversion.DATETIME - - if isinstance(value, str): - if value.isdigit(): - return DatetimeConversion(int(value)) - return DatetimeConversion[value] - elif isinstance(value, int): - return DatetimeConversion(value) - - raise TypeError(f"{option} must be a str or int representing DatetimeConversion") - - -def validate_server_monitoring_mode(option: str, value: str) -> str: - """Validate the serverMonitoringMode option.""" - if value not in {"auto", "stream", "poll"}: - raise ValueError( - f'{option}={value!r} is invalid. 
Must be one of "auto", "stream", or "poll"' - ) - return value - - -# Dictionary where keys are the names of public URI options, and values -# are lists of aliases for that option. -URI_OPTIONS_ALIAS_MAP: dict[str, list[str]] = { - "tls": ["ssl"], -} - -# Dictionary where keys are the names of URI options, and values -# are functions that validate user-input values for that option. If an option -# alias uses a different validator than its public counterpart, it should be -# included here as a key, value pair. -URI_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { - "appname": validate_appname_or_none, - "authmechanism": validate_auth_mechanism, - "authmechanismproperties": validate_auth_mechanism_properties, - "authsource": validate_string, - "compressors": validate_compressors, - "connecttimeoutms": validate_timeout_or_none_or_zero, - "directconnection": validate_boolean_or_string, - "heartbeatfrequencyms": validate_timeout_or_none, - "journal": validate_boolean_or_string, - "localthresholdms": validate_positive_float_or_zero, - "maxidletimems": validate_timeout_or_none, - "maxconnecting": validate_positive_integer, - "maxpoolsize": validate_non_negative_integer_or_none, - "maxstalenessseconds": validate_max_staleness, - "readconcernlevel": validate_string_or_none, - "readpreference": validate_read_preference_mode, - "readpreferencetags": validate_read_preference_tags, - "replicaset": validate_string_or_none, - "retryreads": validate_boolean_or_string, - "retrywrites": validate_boolean_or_string, - "loadbalanced": validate_boolean_or_string, - "serverselectiontimeoutms": validate_timeout_or_zero, - "sockettimeoutms": validate_timeout_or_none_or_zero, - "tls": validate_boolean_or_string, - "tlsallowinvalidcertificates": validate_boolean_or_string, - "tlsallowinvalidhostnames": validate_boolean_or_string, - "tlscafile": validate_readable, - "tlscertificatekeyfile": validate_readable, - "tlscertificatekeyfilepassword": validate_string_or_none, - "tlsdisableocspendpointcheck": validate_boolean_or_string, - "tlsinsecure": validate_boolean_or_string, - "w": validate_non_negative_int_or_basestring, - "wtimeoutms": validate_non_negative_integer, - "zlibcompressionlevel": validate_zlib_compression_level, - "srvservicename": validate_string, - "srvmaxhosts": validate_non_negative_integer, - "timeoutms": validate_timeoutms, - "servermonitoringmode": validate_server_monitoring_mode, -} - -# Dictionary where keys are the names of URI options specific to pymongo, -# and values are functions that validate user-input values for those options. -NONSPEC_OPTIONS_VALIDATOR_MAP: dict[str, Callable[[Any, Any], Any]] = { - "connect": validate_boolean_or_string, - "driver": validate_driver_or_none, - "server_api": validate_server_api_or_none, - "fsync": validate_boolean_or_string, - "minpoolsize": validate_non_negative_integer, - "tlscrlfile": validate_readable, - "tz_aware": validate_boolean_or_string, - "unicode_decode_error_handler": validate_unicode_decode_error_handler, - "uuidrepresentation": validate_uuid_representation, - "waitqueuemultiple": validate_non_negative_integer_or_none, - "waitqueuetimeoutms": validate_timeout_or_none, - "datetime_conversion": validate_datetime_conversion, -} - -# Dictionary where keys are the names of keyword-only options for the -# MongoClient constructor, and values are functions that validate user-input -# values for those options. 
-KW_VALIDATORS: dict[str, Callable[[Any, Any], Any]] = { - "document_class": validate_document_class, - "type_registry": validate_type_registry, - "read_preference": validate_read_preference, - "event_listeners": _validate_event_listeners, - "tzinfo": validate_tzinfo, - "username": validate_string_or_none, - "password": validate_string_or_none, - "server_selector": validate_is_callable_or_none, - "auto_encryption_opts": validate_auto_encryption_opts_or_none, - "authoidcallowedhosts": validate_list, -} - -# Dictionary where keys are any URI option name, and values are the -# internally-used names of that URI option. Options with only one name -# variant need not be included here. Options whose public and internal -# names are the same need not be included here. -INTERNAL_URI_OPTION_NAME_MAP: dict[str, str] = { - "ssl": "tls", -} - -# Map from deprecated URI option names to a tuple indicating the method of -# their deprecation and any additional information that may be needed to -# construct the warning message. -URI_OPTIONS_DEPRECATION_MAP: dict[str, tuple[str, str]] = { - # format: : (, ), - # Supported values: - # - 'renamed': should be the new option name. Note that case is - # preserved for renamed options as they are part of user warnings. - # - 'removed': may suggest the rationale for deprecating the - # option and/or recommend remedial action. - # For example: - # 'wtimeout': ('renamed', 'wTimeoutMS'), -} - -# Augment the option validator map with pymongo-specific option information. -URI_OPTIONS_VALIDATOR_MAP.update(NONSPEC_OPTIONS_VALIDATOR_MAP) -for optname, aliases in URI_OPTIONS_ALIAS_MAP.items(): - for alias in aliases: - if alias not in URI_OPTIONS_VALIDATOR_MAP: - URI_OPTIONS_VALIDATOR_MAP[alias] = URI_OPTIONS_VALIDATOR_MAP[optname] - -# Map containing all URI option and keyword argument validators. -VALIDATORS: dict[str, Callable[[Any, Any], Any]] = URI_OPTIONS_VALIDATOR_MAP.copy() -VALIDATORS.update(KW_VALIDATORS) - -# List of timeout-related options. -TIMEOUT_OPTIONS: list[str] = [ - "connecttimeoutms", - "heartbeatfrequencyms", - "maxidletimems", - "maxstalenessseconds", - "serverselectiontimeoutms", - "sockettimeoutms", - "waitqueuetimeoutms", -] - -_AUTH_OPTIONS = frozenset(["authmechanismproperties"]) - - -def validate_auth_option(option: str, value: Any) -> tuple[str, Any]: - """Validate optional authentication parameters.""" - lower, value = validate(option, value) - if lower not in _AUTH_OPTIONS: - raise ConfigurationError(f"Unknown option: {option}. Must be in {_AUTH_OPTIONS}") - return option, value - - -def _get_validator( - key: str, validators: dict[str, Callable[[Any, Any], Any]], normed_key: Optional[str] = None -) -> Callable: - normed_key = normed_key or key - try: - return validators[normed_key] - except KeyError: - suggestions = get_close_matches(normed_key, validators, cutoff=0.2) - raise_config_error(key, suggestions) - - -def validate(option: str, value: Any) -> tuple[str, Any]: - """Generic validation function.""" - validator = _get_validator(option, VALIDATORS, normed_key=option.lower()) - value = validator(option, value) - return option, value - - -def get_validated_options( - options: Mapping[str, Any], warn: bool = True -) -> MutableMapping[str, Any]: - """Validate each entry in options and raise a warning if it is not valid. - Returns a copy of options with invalid entries removed. - - :param opts: A dict containing MongoDB URI options. - :param warn: If ``True`` then warnings will be logged and - invalid options will be ignored. 
Otherwise, invalid options will - cause errors. - """ - validated_options: MutableMapping[str, Any] - if isinstance(options, _CaseInsensitiveDictionary): - validated_options = _CaseInsensitiveDictionary() - - def get_normed_key(x: str) -> str: - return x - - def get_setter_key(x: str) -> str: - return options.cased_key(x) # type: ignore[attr-defined] - - else: - validated_options = {} - - def get_normed_key(x: str) -> str: - return x.lower() - - def get_setter_key(x: str) -> str: - return x - - for opt, value in options.items(): - normed_key = get_normed_key(opt) - try: - validator = _get_validator(opt, URI_OPTIONS_VALIDATOR_MAP, normed_key=normed_key) - validated = validator(opt, value) - except (ValueError, TypeError, ConfigurationError) as exc: - if warn: - warnings.warn(str(exc), stacklevel=2) - else: - raise - else: - validated_options[get_setter_key(normed_key)] = validated - return validated_options - - -def _esc_coll_name(encrypted_fields: Mapping[str, Any], name: str) -> Any: - return encrypted_fields.get("escCollection", f"enxcol_.{name}.esc") - - -def _ecoc_coll_name(encrypted_fields: Mapping[str, Any], name: str) -> Any: - return encrypted_fields.get("ecocCollection", f"enxcol_.{name}.ecoc") - - -# List of write-concern-related options. -WRITE_CONCERN_OPTIONS = frozenset(["w", "wtimeout", "wtimeoutms", "fsync", "j", "journal"]) - - -class BaseObject: - """A base class that provides attributes and methods common - to multiple pymongo classes. - - SHOULD NOT BE USED BY DEVELOPERS EXTERNAL TO MONGODB. - """ - - def __init__( - self, - codec_options: CodecOptions, - read_preference: _ServerMode, - write_concern: WriteConcern, - read_concern: ReadConcern, - ) -> None: - if not isinstance(codec_options, CodecOptions): - raise TypeError("codec_options must be an instance of bson.codec_options.CodecOptions") - self._codec_options = codec_options - - if not isinstance(read_preference, _ServerMode): - raise TypeError( - f"{read_preference!r} is not valid for read_preference. See " - "pymongo.read_preferences for valid " - "options." - ) - self._read_preference = read_preference - - if not isinstance(write_concern, WriteConcern): - raise TypeError( - "write_concern must be an instance of pymongo.write_concern.WriteConcern" - ) - self._write_concern = write_concern - - if not isinstance(read_concern, ReadConcern): - raise TypeError("read_concern must be an instance of pymongo.read_concern.ReadConcern") - self._read_concern = read_concern - - @property - def codec_options(self) -> CodecOptions: - """Read only access to the :class:`~bson.codec_options.CodecOptions` - of this instance. - """ - return self._codec_options - - @property - def write_concern(self) -> WriteConcern: - """Read only access to the :class:`~pymongo.write_concern.WriteConcern` - of this instance. - - .. versionchanged:: 3.0 - The :attr:`write_concern` attribute is now read only. - """ - return self._write_concern - - def _write_concern_for(self, session: Optional[ClientSession]) -> WriteConcern: - """Read only access to the write concern of this instance or session.""" - # Override this operation's write concern with the transaction's. - if session and session.in_transaction: - return DEFAULT_WRITE_CONCERN - return self.write_concern - - @property - def read_preference(self) -> _ServerMode: - """Read only access to the read preference of this instance. - - .. versionchanged:: 3.0 - The :attr:`read_preference` attribute is now read only. 
- """ - return self._read_preference - - def _read_preference_for(self, session: Optional[ClientSession]) -> _ServerMode: - """Read only access to the read preference of this instance or session.""" - # Override this operation's read preference with the transaction's. - if session: - return session._txn_read_preference() or self._read_preference - return self._read_preference - - @property - def read_concern(self) -> ReadConcern: - """Read only access to the :class:`~pymongo.read_concern.ReadConcern` - of this instance. - - .. versionadded:: 3.2 - """ - return self._read_concern - - -class _CaseInsensitiveDictionary(MutableMapping[str, Any]): - def __init__(self, *args: Any, **kwargs: Any): - self.__casedkeys: dict[str, Any] = {} - self.__data: dict[str, Any] = {} - self.update(dict(*args, **kwargs)) - - def __contains__(self, key: str) -> bool: # type: ignore[override] - return key.lower() in self.__data - - def __len__(self) -> int: - return len(self.__data) - - def __iter__(self) -> Iterator[str]: - return (key for key in self.__casedkeys) - - def __repr__(self) -> str: - return str({self.__casedkeys[k]: self.__data[k] for k in self}) - - def __setitem__(self, key: str, value: Any) -> None: - lc_key = key.lower() - self.__casedkeys[lc_key] = key - self.__data[lc_key] = value - - def __getitem__(self, key: str) -> Any: - return self.__data[key.lower()] - - def __delitem__(self, key: str) -> None: - lc_key = key.lower() - del self.__casedkeys[lc_key] - del self.__data[lc_key] - - def __eq__(self, other: Any) -> bool: - if not isinstance(other, abc.Mapping): - return NotImplemented - if len(self) != len(other): - return False - for key in other: # noqa: SIM110 - if self[key] != other[key]: - return False - - return True - - def get(self, key: str, default: Optional[Any] = None) -> Any: - return self.__data.get(key.lower(), default) - - def pop(self, key: str, *args: Any, **kwargs: Any) -> Any: - lc_key = key.lower() - self.__casedkeys.pop(lc_key, None) - return self.__data.pop(lc_key, *args, **kwargs) - - def popitem(self) -> tuple[str, Any]: - lc_key, cased_key = self.__casedkeys.popitem() - value = self.__data.pop(lc_key) - return cased_key, value - - def clear(self) -> None: - self.__casedkeys.clear() - self.__data.clear() - - @overload - def setdefault(self, key: str, default: None = None) -> Optional[Any]: - ... - - @overload - def setdefault(self, key: str, default: Any) -> Any: - ... - - def setdefault(self, key: str, default: Optional[Any] = None) -> Optional[Any]: - lc_key = key.lower() - if key in self: - return self.__data[lc_key] - else: - self.__casedkeys[lc_key] = key - self.__data[lc_key] = default - return default - - def update(self, other: Mapping[str, Any]) -> None: # type: ignore[override] - if isinstance(other, _CaseInsensitiveDictionary): - for key in other: - self[other.cased_key(key)] = other[key] - else: - for key in other: - self[key] = other[key] - - def cased_key(self, key: str) -> Any: - return self.__casedkeys[key.lower()] diff --git a/pymongo/synchronous/compression_support.py b/pymongo/synchronous/compression_support.py deleted file mode 100644 index e5153f8c87..0000000000 --- a/pymongo/synchronous/compression_support.py +++ /dev/null @@ -1,178 +0,0 @@ -# Copyright 2018 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from __future__ import annotations - -import warnings -from typing import Any, Iterable, Optional, Union - -from pymongo.helpers_constants import _SENSITIVE_COMMANDS -from pymongo.synchronous.hello_compat import HelloCompat - -_IS_SYNC = True - - -_SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"} -_NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD} -_NO_COMPRESSION.update(_SENSITIVE_COMMANDS) - - -def _have_snappy() -> bool: - try: - import snappy # type:ignore[import] # noqa: F401 - - return True - except ImportError: - return False - - -def _have_zlib() -> bool: - try: - import zlib # noqa: F401 - - return True - except ImportError: - return False - - -def _have_zstd() -> bool: - try: - import zstandard # noqa: F401 - - return True - except ImportError: - return False - - -def validate_compressors(dummy: Any, value: Union[str, Iterable[str]]) -> list[str]: - try: - # `value` is string. - compressors = value.split(",") # type: ignore[union-attr] - except AttributeError: - # `value` is an iterable. - compressors = list(value) - - for compressor in compressors[:]: - if compressor not in _SUPPORTED_COMPRESSORS: - compressors.remove(compressor) - warnings.warn(f"Unsupported compressor: {compressor}", stacklevel=2) - elif compressor == "snappy" and not _have_snappy(): - compressors.remove(compressor) - warnings.warn( - "Wire protocol compression with snappy is not available. " - "You must install the python-snappy module for snappy support.", - stacklevel=2, - ) - elif compressor == "zlib" and not _have_zlib(): - compressors.remove(compressor) - warnings.warn( - "Wire protocol compression with zlib is not available. " - "The zlib module is not available.", - stacklevel=2, - ) - elif compressor == "zstd" and not _have_zstd(): - compressors.remove(compressor) - warnings.warn( - "Wire protocol compression with zstandard is not available. " - "You must install the zstandard module for zstandard support.", - stacklevel=2, - ) - return compressors - - -def validate_zlib_compression_level(option: str, value: Any) -> int: - try: - level = int(value) - except Exception: - raise TypeError(f"{option} must be an integer, not {value!r}.") from None - if level < -1 or level > 9: - raise ValueError("%s must be between -1 and 9, not %d." 
% (option, level)) - return level - - -class CompressionSettings: - def __init__(self, compressors: list[str], zlib_compression_level: int): - self.compressors = compressors - self.zlib_compression_level = zlib_compression_level - - def get_compression_context( - self, compressors: Optional[list[str]] - ) -> Union[SnappyContext, ZlibContext, ZstdContext, None]: - if compressors: - chosen = compressors[0] - if chosen == "snappy": - return SnappyContext() - elif chosen == "zlib": - return ZlibContext(self.zlib_compression_level) - elif chosen == "zstd": - return ZstdContext() - return None - return None - - -class SnappyContext: - compressor_id = 1 - - @staticmethod - def compress(data: bytes) -> bytes: - import snappy - - return snappy.compress(data) - - -class ZlibContext: - compressor_id = 2 - - def __init__(self, level: int): - self.level = level - - def compress(self, data: bytes) -> bytes: - import zlib - - return zlib.compress(data, self.level) - - -class ZstdContext: - compressor_id = 3 - - @staticmethod - def compress(data: bytes) -> bytes: - # ZstdCompressor is not thread safe. - # TODO: Use a pool? - - import zstandard - - return zstandard.ZstdCompressor().compress(data) - - -def decompress(data: bytes, compressor_id: int) -> bytes: - if compressor_id == SnappyContext.compressor_id: - # python-snappy doesn't support the buffer interface. - # https://github.com/andrix/python-snappy/issues/65 - # This only matters when data is a memoryview since - # id(bytes(data)) == id(data) when data is a bytes. - import snappy - - return snappy.uncompress(bytes(data)) - elif compressor_id == ZlibContext.compressor_id: - import zlib - - return zlib.decompress(data) - elif compressor_id == ZstdContext.compressor_id: - # ZstdDecompressor is not thread safe. - # TODO: Use a pool? 
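# As with ZstdContext.compress above, constructing a fresh (de)compressor per
# call trades a small allocation cost for thread safety.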
- import zstandard - - return zstandard.ZstdDecompressor().decompress(data) - else: - raise ValueError("Unknown compressorId %d" % (compressor_id,)) diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py index b74266a74e..cacaeb7aad 100644 --- a/pymongo/synchronous/cursor.py +++ b/pymongo/synchronous/cursor.py @@ -36,15 +36,16 @@ from bson import RE_TYPE, _convert_raw_document_lists_to_streams from bson.code import Code from bson.son import SON -from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort -from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure -from pymongo.lock import _create_lock -from pymongo.synchronous import helpers -from pymongo.synchronous.collation import validate_collation_or_none -from pymongo.synchronous.common import ( +from pymongo import helpers_shared +from pymongo.collation import validate_collation_or_none +from pymongo.common import ( validate_is_document_type, validate_is_mapping, ) +from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort +from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure +from pymongo.lock import _create_lock +from pymongo.response import PinnedResponse from pymongo.synchronous.helpers import next from pymongo.synchronous.message import ( _CursorAddress, @@ -55,18 +56,17 @@ _RawBatchGetMore, _RawBatchQuery, ) -from pymongo.synchronous.response import PinnedResponse -from pymongo.synchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType +from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import validate_boolean if TYPE_CHECKING: from _typeshed import SupportsItems from bson.codec_options import CodecOptions + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.collection import Collection from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode _IS_SYNC = True @@ -179,7 +179,7 @@ def __init__( allow_disk_use = validate_boolean("allow_disk_use", allow_disk_use) if projection is not None: - projection = helpers._fields_list_to_dict(projection, "projection") + projection = helpers_shared._fields_list_to_dict(projection, "projection") if let is not None: validate_is_document_type("let", let) @@ -191,7 +191,7 @@ def __init__( self._skip = skip self._limit = limit self._batch_size = batch_size - self._ordering = sort and helpers._index_document(sort) or None + self._ordering = sort and helpers_shared._index_document(sort) or None self._max_scan = max_scan self._explain = False self._comment = comment @@ -741,8 +741,8 @@ def sort( key, if not given :data:`~pymongo.ASCENDING` is assumed """ self._check_okay_to_chain() - keys = helpers._index_list(key_or_list, direction) - self._ordering = helpers._index_document(keys) + keys = helpers_shared._index_list(key_or_list, direction) + self._ordering = helpers_shared._index_document(keys) return self def explain(self) -> _DocumentType: @@ -773,7 +773,7 @@ def _set_hint(self, index: Optional[_Hint]) -> None: if isinstance(index, str): self._hint = index else: - self._hint = helpers._index_document(index) + self._hint = helpers_shared._index_document(index) def hint(self, index: Optional[_Hint]) -> Cursor[_DocumentType]: """Adds a 'hint', telling Mongo the proper index to use for the query. 
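The cursor changes above are mechanical: the sort, hint, and projection paths now call the shared ``pymongo.helpers_shared`` helpers instead of per-flavor copies, so the synchronous and asynchronous cursors share one implementation. A minimal sketch of what these private helpers do, assuming the shared module preserves the pre-move behavior (``_index_document`` returning an order-preserving ``SON``)::

    from pymongo import ASCENDING, helpers_shared

    # A key plus direction is first normalized into [(key, direction)] pairs...
    keys = helpers_shared._index_list("category", ASCENDING)
    # ...then into an ordered document usable as a sort or hint specification.
    print(helpers_shared._index_document(keys))  # SON([('category', 1)])

    # Projections given as a list of field names become a {field: 1} mapping.
    print(helpers_shared._fields_list_to_dict(["name", "qty"], "projection"))
    # {'name': 1, 'qty': 1}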
diff --git a/pymongo/synchronous/database.py b/pymongo/synchronous/database.py index 92521d7c14..eaef0558d5 100644 --- a/pymongo/synchronous/database.py +++ b/pymongo/synchronous/database.py @@ -33,18 +33,17 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions from bson.dbref import DBRef from bson.timestamp import Timestamp -from pymongo import _csot +from pymongo import _csot, common +from pymongo.common import _ecoc_coll_name, _esc_coll_name from pymongo.database_shared import _check_name, _CodecDocumentType from pymongo.errors import CollectionInvalid, InvalidOperation -from pymongo.synchronous import common +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.synchronous.aggregation import _DatabaseAggregationCommand from pymongo.synchronous.change_stream import DatabaseChangeStream from pymongo.synchronous.collection import Collection from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.common import _ecoc_coll_name, _esc_coll_name -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline +from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline if TYPE_CHECKING: import bson @@ -151,7 +150,7 @@ def with_options( >>> db1.read_preference Primary() - >>> from pymongo.synchronous.read_preferences import Secondary + >>> from pymongo.read_preferences import Secondary >>> db2 = db1.with_options(read_preference=Secondary([{'node': 'analytics'}])) >>> db1.read_preference Primary() diff --git a/pymongo/synchronous/encryption.py b/pymongo/synchronous/encryption.py index cb248c5643..1e95a36dcf 100644 --- a/pymongo/synchronous/encryption.py +++ b/pymongo/synchronous/encryption.py @@ -58,7 +58,9 @@ from bson.errors import BSONError from bson.raw_bson import DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, _inflate_bson from pymongo import _csot +from pymongo.common import CONNECT_TIMEOUT from pymongo.daemon import _spawn_daemon +from pymongo.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.errors import ( ConfigurationError, EncryptedCollectionError, @@ -68,19 +70,18 @@ ServerSelectionTimeoutError, ) from pymongo.network_layer import BLOCKING_IO_ERRORS, sendall +from pymongo.operations import UpdateOne +from pymongo.pool_options import PoolOptions from pymongo.read_concern import ReadConcern from pymongo.results import BulkWriteResult, DeleteResult from pymongo.ssl_support import get_ssl_context from pymongo.synchronous.collection import Collection -from pymongo.synchronous.common import CONNECT_TIMEOUT from pymongo.synchronous.cursor import Cursor from pymongo.synchronous.database import Database -from pymongo.synchronous.encryption_options import AutoEncryptionOpts, RangeOpts from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.operations import UpdateOne -from pymongo.synchronous.pool import PoolOptions, _configured_socket, _raise_connection_failure -from pymongo.synchronous.typings import _DocumentType, _DocumentTypeArg -from pymongo.synchronous.uri_parser import parse_host +from pymongo.synchronous.pool import _configured_socket, _raise_connection_failure +from pymongo.typings import _DocumentType, _DocumentTypeArg +from pymongo.uri_parser import parse_host from pymongo.write_concern import WriteConcern if TYPE_CHECKING: @@ -379,7 +380,10 @@ def 
_get_internal_client( ) io_callbacks = _EncryptionIO( # type:ignore[misc] - metadata_client, key_vault_coll, mongocryptd_client, opts + metadata_client, + key_vault_coll, # type:ignore[arg-type] + mongocryptd_client, + opts, ) self._auto_encrypter = AutoEncrypter( io_callbacks, diff --git a/pymongo/synchronous/encryption_options.py b/pymongo/synchronous/encryption_options.py deleted file mode 100644 index 03bc01d181..0000000000 --- a/pymongo/synchronous/encryption_options.py +++ /dev/null @@ -1,270 +0,0 @@ -# Copyright 2019-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Support for automatic client-side field level encryption.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional - -try: - import pymongocrypt # type:ignore[import] # noqa: F401 - - _HAVE_PYMONGOCRYPT = True -except ImportError: - _HAVE_PYMONGOCRYPT = False -from bson import int64 -from pymongo.errors import ConfigurationError -from pymongo.synchronous.common import validate_is_mapping -from pymongo.synchronous.uri_parser import _parse_kms_tls_options - -if TYPE_CHECKING: - from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.typings import _DocumentTypeArg - -_IS_SYNC = True - - -class AutoEncryptionOpts: - """Options to configure automatic client-side field level encryption.""" - - def __init__( - self, - kms_providers: Mapping[str, Any], - key_vault_namespace: str, - key_vault_client: Optional[MongoClient[_DocumentTypeArg]] = None, - schema_map: Optional[Mapping[str, Any]] = None, - bypass_auto_encryption: bool = False, - mongocryptd_uri: str = "mongodb://localhost:27020", - mongocryptd_bypass_spawn: bool = False, - mongocryptd_spawn_path: str = "mongocryptd", - mongocryptd_spawn_args: Optional[list[str]] = None, - kms_tls_options: Optional[Mapping[str, Any]] = None, - crypt_shared_lib_path: Optional[str] = None, - crypt_shared_lib_required: bool = False, - bypass_query_analysis: bool = False, - encrypted_fields_map: Optional[Mapping[str, Any]] = None, - ) -> None: - """Options to configure automatic client-side field level encryption. - - Automatic client-side field level encryption requires MongoDB >=4.2 - enterprise or a MongoDB >=4.2 Atlas cluster. Automatic encryption is not - supported for operations on a database or view and will result in - error. - - Although automatic encryption requires MongoDB >=4.2 enterprise or a - MongoDB >=4.2 Atlas cluster, automatic *decryption* is supported for all - users. To configure automatic *decryption* without automatic - *encryption* set ``bypass_auto_encryption=True``. Explicit - encryption and explicit decryption is also supported for all users - with the :class:`~pymongo.encryption.ClientEncryption` class. - - See :ref:`automatic-client-side-encryption` for an example. - - :param kms_providers: Map of KMS provider options. The `kms_providers` - map values differ by provider: - - - `aws`: Map with "accessKeyId" and "secretAccessKey" as strings. 
- These are the AWS access key ID and AWS secret access key used - to generate KMS messages. An optional "sessionToken" may be - included to support temporary AWS credentials. - - `azure`: Map with "tenantId", "clientId", and "clientSecret" as - strings. Additionally, "identityPlatformEndpoint" may also be - specified as a string (defaults to 'login.microsoftonline.com'). - These are the Azure Active Directory credentials used to - generate Azure Key Vault messages. - - `gcp`: Map with "email" as a string and "privateKey" - as `bytes` or a base64 encoded string. - Additionally, "endpoint" may also be specified as a string - (defaults to 'oauth2.googleapis.com'). These are the - credentials used to generate Google Cloud KMS messages. - - `kmip`: Map with "endpoint" as a host with required port. - For example: ``{"endpoint": "example.com:443"}``. - - `local`: Map with "key" as `bytes` (96 bytes in length) or - a base64 encoded string which decodes - to 96 bytes. "key" is the master key used to encrypt/decrypt - data keys. This key should be generated and stored as securely - as possible. - - KMS providers may be specified with an optional name suffix - separated by a colon, for example "kmip:name" or "aws:name". - Named KMS providers do not support :ref:`CSFLE on-demand credentials`. - Named KMS providers enables more than one of each KMS provider type to be configured. - For example, to configure multiple local KMS providers:: - - kms_providers = { - "local": {"key": local_kek1}, # Unnamed KMS provider. - "local:myname": {"key": local_kek2}, # Named KMS provider with name "myname". - } - - :param key_vault_namespace: The namespace for the key vault collection. - The key vault collection contains all data keys used for encryption - and decryption. Data keys are stored as documents in this MongoDB - collection. Data keys are protected with encryption by a KMS - provider. - :param key_vault_client: By default, the key vault collection - is assumed to reside in the same MongoDB cluster as the encrypted - MongoClient. Use this option to route data key queries to a - separate MongoDB cluster. - :param schema_map: Map of collection namespace ("db.coll") to - JSON Schema. By default, a collection's JSONSchema is periodically - polled with the listCollections command. But a JSONSchema may be - specified locally with the schemaMap option. - - **Supplying a `schema_map` provides more security than relying on - JSON Schemas obtained from the server. It protects against a - malicious server advertising a false JSON Schema, which could trick - the client into sending unencrypted data that should be - encrypted.** - - Schemas supplied in the schemaMap only apply to configuring - automatic encryption for client side encryption. Other validation - rules in the JSON schema will not be enforced by the driver and - will result in an error. - :param bypass_auto_encryption: If ``True``, automatic - encryption will be disabled but automatic decryption will still be - enabled. Defaults to ``False``. - :param mongocryptd_uri: The MongoDB URI used to connect - to the *local* mongocryptd process. Defaults to - ``'mongodb://localhost:27020'``. - :param mongocryptd_bypass_spawn: If ``True``, the encrypted - MongoClient will not attempt to spawn the mongocryptd process. - Defaults to ``False``. - :param mongocryptd_spawn_path: Used for spawning the - mongocryptd process. Defaults to ``'mongocryptd'`` and spawns - mongocryptd from the system path. 
- :param mongocryptd_spawn_args: A list of string arguments to - use when spawning the mongocryptd process. Defaults to - ``['--idleShutdownTimeoutSecs=60']``. If the list does not include - the ``idleShutdownTimeoutSecs`` option then - ``'--idleShutdownTimeoutSecs=60'`` will be added. - :param kms_tls_options: A map of KMS provider names to TLS - options to use when creating secure connections to KMS providers. - Accepts the same TLS options as - :class:`pymongo.mongo_client.MongoClient`. For example, to - override the system default CA file:: - - kms_tls_options={'kmip': {'tlsCAFile': certifi.where()}} - - Or to supply a client certificate:: - - kms_tls_options={'kmip': {'tlsCertificateKeyFile': 'client.pem'}} - :param crypt_shared_lib_path: Override the path to load the crypt_shared library. - :param crypt_shared_lib_required: If True, raise an error if libmongocrypt is - unable to load the crypt_shared library. - :param bypass_query_analysis: If ``True``, disable automatic analysis - of outgoing commands. Set `bypass_query_analysis` to use explicit - encryption on indexed fields without the MongoDB Enterprise Advanced - licensed crypt_shared library. - :param encrypted_fields_map: Map of collection namespace ("db.coll") to documents - that described the encrypted fields for Queryable Encryption. For example:: - - { - "db.encryptedCollection": { - "escCollection": "enxcol_.encryptedCollection.esc", - "ecocCollection": "enxcol_.encryptedCollection.ecoc", - "fields": [ - { - "path": "firstName", - "keyId": Binary.from_uuid(UUID('00000000-0000-0000-0000-000000000000')), - "bsonType": "string", - "queries": {"queryType": "equality"} - }, - { - "path": "ssn", - "keyId": Binary.from_uuid(UUID('04104104-1041-0410-4104-104104104104')), - "bsonType": "string" - } - ] - } - } - - .. versionchanged:: 4.2 - Added `encrypted_fields_map` `crypt_shared_lib_path`, `crypt_shared_lib_required`, - and `bypass_query_analysis` parameters. - - .. versionchanged:: 4.0 - Added the `kms_tls_options` parameter and the "kmip" KMS provider. - - .. versionadded:: 3.9 - """ - if not _HAVE_PYMONGOCRYPT: - raise ConfigurationError( - "client side encryption requires the pymongocrypt library: " - "install a compatible version with: " - "python -m pip install 'pymongo[encryption]'" - ) - if encrypted_fields_map: - validate_is_mapping("encrypted_fields_map", encrypted_fields_map) - self._encrypted_fields_map = encrypted_fields_map - self._bypass_query_analysis = bypass_query_analysis - self._crypt_shared_lib_path = crypt_shared_lib_path - self._crypt_shared_lib_required = crypt_shared_lib_required - self._kms_providers = kms_providers - self._key_vault_namespace = key_vault_namespace - self._key_vault_client = key_vault_client - self._schema_map = schema_map - self._bypass_auto_encryption = bypass_auto_encryption - self._mongocryptd_uri = mongocryptd_uri - self._mongocryptd_bypass_spawn = mongocryptd_bypass_spawn - self._mongocryptd_spawn_path = mongocryptd_spawn_path - if mongocryptd_spawn_args is None: - mongocryptd_spawn_args = ["--idleShutdownTimeoutSecs=60"] - self._mongocryptd_spawn_args = mongocryptd_spawn_args - if not isinstance(self._mongocryptd_spawn_args, list): - raise TypeError("mongocryptd_spawn_args must be a list") - if not any("idleShutdownTimeoutSecs" in s for s in self._mongocryptd_spawn_args): - self._mongocryptd_spawn_args.append("--idleShutdownTimeoutSecs=60") - # Maps KMS provider name to a SSLContext. 
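# _parse_kms_tls_options validates the kms_tls_options mapping documented
# above and builds one SSLContext per configured provider (e.g. a "kmip"
# entry with tlsCertificateKeyFile yields a client-certificate context
# keyed by "kmip").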
- self._kms_ssl_contexts = _parse_kms_tls_options(kms_tls_options) - self._bypass_query_analysis = bypass_query_analysis - - -class RangeOpts: - """Options to configure encrypted queries using the rangePreview algorithm.""" - - def __init__( - self, - sparsity: int, - min: Optional[Any] = None, - max: Optional[Any] = None, - precision: Optional[int] = None, - ) -> None: - """Options to configure encrypted queries using the rangePreview algorithm. - - .. note:: This feature is experimental only, and not intended for public use. - - :param sparsity: An integer. - :param min: A BSON scalar value corresponding to the type being queried. - :param max: A BSON scalar value corresponding to the type being queried. - :param precision: An integer, may only be set for double or decimal128 types. - - .. versionadded:: 4.4 - """ - self.min = min - self.max = max - self.sparsity = sparsity - self.precision = precision - - @property - def document(self) -> dict[str, Any]: - doc = {} - for k, v in [ - ("sparsity", int64.Int64(self.sparsity)), - ("precision", self.precision), - ("min", self.min), - ("max", self.max), - ]: - if v is not None: - doc[k] = v - return doc diff --git a/pymongo/synchronous/event_loggers.py b/pymongo/synchronous/event_loggers.py deleted file mode 100644 index fe9dd899d3..0000000000 --- a/pymongo/synchronous/event_loggers.py +++ /dev/null @@ -1,225 +0,0 @@ -# Copyright 2020-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Example event logger classes. - -.. versionadded:: 3.11 - -These loggers can be registered using :func:`register` or -:class:`~pymongo.mongo_client.MongoClient`. - -``monitoring.register(CommandLogger())`` - -or - -``MongoClient(event_listeners=[CommandLogger()])`` -""" -from __future__ import annotations - -import logging - -from pymongo.synchronous import monitoring - -_IS_SYNC = True - - -class CommandLogger(monitoring.CommandListener): - """A simple listener that logs command events. - - Listens for :class:`~pymongo.monitoring.CommandStartedEvent`, - :class:`~pymongo.monitoring.CommandSucceededEvent` and - :class:`~pymongo.monitoring.CommandFailedEvent` events and - logs them at the `INFO` severity level using :mod:`logging`. - .. 
versionadded:: 3.11 - """ - - def started(self, event: monitoring.CommandStartedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} started on server " - f"{event.connection_id}" - ) - - def succeeded(self, event: monitoring.CommandSucceededEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"succeeded in {event.duration_micros} " - "microseconds" - ) - - def failed(self, event: monitoring.CommandFailedEvent) -> None: - logging.info( - f"Command {event.command_name} with request id " - f"{event.request_id} on server {event.connection_id} " - f"failed in {event.duration_micros} " - "microseconds" - ) - - -class ServerLogger(monitoring.ServerListener): - """A simple listener that logs server discovery events. - - Listens for :class:`~pymongo.monitoring.ServerOpeningEvent`, - :class:`~pymongo.monitoring.ServerDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.ServerClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def opened(self, event: monitoring.ServerOpeningEvent) -> None: - logging.info(f"Server {event.server_address} added to topology {event.topology_id}") - - def description_changed(self, event: monitoring.ServerDescriptionChangedEvent) -> None: - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - f"Server {event.server_address} changed type from " - f"{event.previous_description.server_type_name} to " - f"{event.new_description.server_type_name}" - ) - - def closed(self, event: monitoring.ServerClosedEvent) -> None: - logging.warning(f"Server {event.server_address} removed from topology {event.topology_id}") - - -class HeartbeatLogger(monitoring.ServerHeartbeatListener): - """A simple listener that logs server heartbeat events. - - Listens for :class:`~pymongo.monitoring.ServerHeartbeatStartedEvent`, - :class:`~pymongo.monitoring.ServerHeartbeatSucceededEvent`, - and :class:`~pymongo.monitoring.ServerHeartbeatFailedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. versionadded:: 3.11 - """ - - def started(self, event: monitoring.ServerHeartbeatStartedEvent) -> None: - logging.info(f"Heartbeat sent to server {event.connection_id}") - - def succeeded(self, event: monitoring.ServerHeartbeatSucceededEvent) -> None: - # The reply.document attribute was added in PyMongo 3.4. - logging.info( - f"Heartbeat to server {event.connection_id} " - "succeeded with reply " - f"{event.reply.document}" - ) - - def failed(self, event: monitoring.ServerHeartbeatFailedEvent) -> None: - logging.warning( - f"Heartbeat to server {event.connection_id} failed with error {event.reply}" - ) - - -class TopologyLogger(monitoring.TopologyListener): - """A simple listener that logs server topology events. - - Listens for :class:`~pymongo.monitoring.TopologyOpenedEvent`, - :class:`~pymongo.monitoring.TopologyDescriptionChangedEvent`, - and :class:`~pymongo.monitoring.TopologyClosedEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. 
versionadded:: 3.11 - """ - - def opened(self, event: monitoring.TopologyOpenedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} opened") - - def description_changed(self, event: monitoring.TopologyDescriptionChangedEvent) -> None: - logging.info(f"Topology description updated for topology id {event.topology_id}") - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - f"Topology {event.topology_id} changed type from " - f"{event.previous_description.topology_type_name} to " - f"{event.new_description.topology_type_name}" - ) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event: monitoring.TopologyClosedEvent) -> None: - logging.info(f"Topology with id {event.topology_id} closed") - - -class ConnectionPoolLogger(monitoring.ConnectionPoolListener): - """A simple listener that logs server connection pool events. - - Listens for :class:`~pymongo.monitoring.PoolCreatedEvent`, - :class:`~pymongo.monitoring.PoolClearedEvent`, - :class:`~pymongo.monitoring.PoolClosedEvent`, - :~pymongo.monitoring.class:`ConnectionCreatedEvent`, - :class:`~pymongo.monitoring.ConnectionReadyEvent`, - :class:`~pymongo.monitoring.ConnectionClosedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutStartedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckOutFailedEvent`, - :class:`~pymongo.monitoring.ConnectionCheckedOutEvent`, - and :class:`~pymongo.monitoring.ConnectionCheckedInEvent` - events and logs them at the `INFO` severity level using :mod:`logging`. - - .. 
versionadded:: 3.11 - """ - - def pool_created(self, event: monitoring.PoolCreatedEvent) -> None: - logging.info(f"[pool {event.address}] pool created") - - def pool_ready(self, event: monitoring.PoolReadyEvent) -> None: - logging.info(f"[pool {event.address}] pool ready") - - def pool_cleared(self, event: monitoring.PoolClearedEvent) -> None: - logging.info(f"[pool {event.address}] pool cleared") - - def pool_closed(self, event: monitoring.PoolClosedEvent) -> None: - logging.info(f"[pool {event.address}] pool closed") - - def connection_created(self, event: monitoring.ConnectionCreatedEvent) -> None: - logging.info(f"[pool {event.address}][conn #{event.connection_id}] connection created") - - def connection_ready(self, event: monitoring.ConnectionReadyEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection setup succeeded" - ) - - def connection_closed(self, event: monitoring.ConnectionClosedEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] " - f'connection closed, reason: "{event.reason}"' - ) - - def connection_check_out_started( - self, event: monitoring.ConnectionCheckOutStartedEvent - ) -> None: - logging.info(f"[pool {event.address}] connection check out started") - - def connection_check_out_failed(self, event: monitoring.ConnectionCheckOutFailedEvent) -> None: - logging.info(f"[pool {event.address}] connection check out failed, reason: {event.reason}") - - def connection_checked_out(self, event: monitoring.ConnectionCheckedOutEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked out of pool" - ) - - def connection_checked_in(self, event: monitoring.ConnectionCheckedInEvent) -> None: - logging.info( - f"[pool {event.address}][conn #{event.connection_id}] connection checked into pool" - ) diff --git a/pymongo/synchronous/hello.py b/pymongo/synchronous/hello.py deleted file mode 100644 index 5c1d8438fc..0000000000 --- a/pymongo/synchronous/hello.py +++ /dev/null @@ -1,219 +0,0 @@ -# Copyright 2021-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
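The file being deleted here (``pymongo/synchronous/hello.py``) held the synchronous copy of the hello-response parser; per the diffstat, a single shared copy now lives at ``pymongo/hello.py``. A usage sketch of the ``Hello`` wrapper it defines, assuming the shared module keeps the same API (the response values are illustrative, not taken from this patch)::

    from pymongo.hello import Hello
    from pymongo.server_type import SERVER_TYPE

    # A trimmed replica-set hello response, shaped as _get_server_type expects.
    doc = {"ok": 1, "setName": "rs0", "isWritablePrimary": True}
    hello = Hello(doc)
    assert hello.server_type == SERVER_TYPE.RSPrimary
    assert hello.is_writable and hello.is_readable
    assert hello.replica_set_name == "rs0"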
- -"""Helpers for the 'hello' and legacy hello commands.""" -from __future__ import annotations - -import copy -import datetime -import itertools -from typing import Any, Generic, Mapping, Optional - -from bson.objectid import ObjectId -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import common -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.typings import ClusterTime, _DocumentType - -_IS_SYNC = True - - -def _get_server_type(doc: Mapping[str, Any]) -> int: - """Determine the server type from a hello response.""" - if not doc.get("ok"): - return SERVER_TYPE.Unknown - - if doc.get("serviceId"): - return SERVER_TYPE.LoadBalancer - elif doc.get("isreplicaset"): - return SERVER_TYPE.RSGhost - elif doc.get("setName"): - if doc.get("hidden"): - return SERVER_TYPE.RSOther - elif doc.get(HelloCompat.PRIMARY): - return SERVER_TYPE.RSPrimary - elif doc.get(HelloCompat.LEGACY_PRIMARY): - return SERVER_TYPE.RSPrimary - elif doc.get("secondary"): - return SERVER_TYPE.RSSecondary - elif doc.get("arbiterOnly"): - return SERVER_TYPE.RSArbiter - else: - return SERVER_TYPE.RSOther - elif doc.get("msg") == "isdbgrid": - return SERVER_TYPE.Mongos - else: - return SERVER_TYPE.Standalone - - -class Hello(Generic[_DocumentType]): - """Parse a hello response from the server. - - .. versionadded:: 3.12 - """ - - __slots__ = ("_doc", "_server_type", "_is_writable", "_is_readable", "_awaitable") - - def __init__(self, doc: _DocumentType, awaitable: bool = False) -> None: - self._server_type = _get_server_type(doc) - self._doc: _DocumentType = doc - self._is_writable = self._server_type in ( - SERVER_TYPE.RSPrimary, - SERVER_TYPE.Standalone, - SERVER_TYPE.Mongos, - SERVER_TYPE.LoadBalancer, - ) - - self._is_readable = self.server_type == SERVER_TYPE.RSSecondary or self._is_writable - self._awaitable = awaitable - - @property - def document(self) -> _DocumentType: - """The complete hello command response document. - - .. 
versionadded:: 3.4 - """ - return copy.copy(self._doc) - - @property - def server_type(self) -> int: - return self._server_type - - @property - def all_hosts(self) -> set[tuple[str, int]]: - """List of hosts, passives, and arbiters known to this server.""" - return set( - map( - common.clean_node, - itertools.chain( - self._doc.get("hosts", []), - self._doc.get("passives", []), - self._doc.get("arbiters", []), - ), - ) - ) - - @property - def tags(self) -> Mapping[str, Any]: - """Replica set member tags or empty dict.""" - return self._doc.get("tags", {}) - - @property - def primary(self) -> Optional[tuple[str, int]]: - """This server's opinion about who the primary is, or None.""" - if self._doc.get("primary"): - return common.partition_node(self._doc["primary"]) - else: - return None - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self._doc.get("setName") - - @property - def max_bson_size(self) -> int: - return self._doc.get("maxBsonObjectSize", common.MAX_BSON_SIZE) - - @property - def max_message_size(self) -> int: - return self._doc.get("maxMessageSizeBytes", 2 * self.max_bson_size) - - @property - def max_write_batch_size(self) -> int: - return self._doc.get("maxWriteBatchSize", common.MAX_WRITE_BATCH_SIZE) - - @property - def min_wire_version(self) -> int: - return self._doc.get("minWireVersion", common.MIN_WIRE_VERSION) - - @property - def max_wire_version(self) -> int: - return self._doc.get("maxWireVersion", common.MAX_WIRE_VERSION) - - @property - def set_version(self) -> Optional[int]: - return self._doc.get("setVersion") - - @property - def election_id(self) -> Optional[ObjectId]: - return self._doc.get("electionId") - - @property - def cluster_time(self) -> Optional[ClusterTime]: - return self._doc.get("$clusterTime") - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - return self._doc.get("logicalSessionTimeoutMinutes") - - @property - def is_writable(self) -> bool: - return self._is_writable - - @property - def is_readable(self) -> bool: - return self._is_readable - - @property - def me(self) -> Optional[tuple[str, int]]: - me = self._doc.get("me") - if me: - return common.clean_node(me) - return None - - @property - def last_write_date(self) -> Optional[datetime.datetime]: - return self._doc.get("lastWrite", {}).get("lastWriteDate") - - @property - def compressors(self) -> Optional[list[str]]: - return self._doc.get("compression") - - @property - def sasl_supported_mechs(self) -> list[str]: - """Supported authentication mechanisms for the current user. 
- - For example:: - - >>> hello.sasl_supported_mechs - ["SCRAM-SHA-1", "SCRAM-SHA-256"] - - """ - return self._doc.get("saslSupportedMechs", []) - - @property - def speculative_authenticate(self) -> Optional[Mapping[str, Any]]: - """The speculativeAuthenticate field.""" - return self._doc.get("speculativeAuthenticate") - - @property - def topology_version(self) -> Optional[Mapping[str, Any]]: - return self._doc.get("topologyVersion") - - @property - def awaitable(self) -> bool: - return self._awaitable - - @property - def service_id(self) -> Optional[ObjectId]: - return self._doc.get("serviceId") - - @property - def hello_ok(self) -> bool: - return self._doc.get("helloOk", False) - - @property - def connection_id(self) -> Optional[int]: - return self._doc.get("connectionId") diff --git a/pymongo/synchronous/hello_compat.py b/pymongo/synchronous/hello_compat.py deleted file mode 100644 index 126ed4bf54..0000000000 --- a/pymongo/synchronous/hello_compat.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The HelloCompat class, placed here to break circular import issues.""" -from __future__ import annotations - -_IS_SYNC = True - - -class HelloCompat: - CMD = "hello" - LEGACY_CMD = "ismaster" - PRIMARY = "isWritablePrimary" - LEGACY_PRIMARY = "ismaster" - LEGACY_ERROR = "not master" diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 892d6a93e3..56d20c7c10 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -1,4 +1,4 @@ -# Copyright 2009-present MongoDB, Inc. +# Copyright 2024-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,270 +12,25 @@ # See the License for the specific language governing permissions and # limitations under the License. 
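This hunk strips ``pymongo/synchronous/helpers.py`` down to the reauthentication wrapper; per the diffstat, the removed general-purpose helpers are consolidated into the new shared ``pymongo/helpers_shared.py``. As one example of the moved behavior, projection normalization works as sketched here (``_fields_list_to_dict`` is a private helper, and its shared location is an assumption based on the diffstat)::

    from pymongo.helpers_shared import _fields_list_to_dict

    _fields_list_to_dict(["a", "b"], "projection")
    # -> {"a": 1, "b": 1}
    _fields_list_to_dict({"a.b.c": 1, "d": 1}, "projection")
    # -> returned as-is; mappings pass through unchanged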
-"""Bits and pieces used by the driver that don't really fit elsewhere.""" +"""Miscellaneous pieces that need to be synchronized.""" from __future__ import annotations import builtins import sys -import traceback -from collections import abc from typing import ( - TYPE_CHECKING, Any, Callable, - Container, - Iterable, - Mapping, - NoReturn, - Optional, - Sequence, TypeVar, - Union, cast, ) -from pymongo import ASCENDING from pymongo.errors import ( - CursorNotFound, - DuplicateKeyError, - ExecutionTimeout, - NotPrimaryError, OperationFailure, - WriteConcernError, - WriteError, - WTimeoutError, - _wtimeout_error, ) -from pymongo.helpers_constants import _NOT_PRIMARY_CODES, _REAUTHENTICATION_REQUIRED_CODE -from pymongo.synchronous.hello_compat import HelloCompat - -if TYPE_CHECKING: - from pymongo.cursor_shared import _Hint - from pymongo.synchronous.operations import _IndexList - from pymongo.synchronous.typings import _DocumentOut +from pymongo.helpers_shared import _REAUTHENTICATION_REQUIRED_CODE _IS_SYNC = True - -def _gen_index_name(keys: _IndexList) -> str: - """Generate an index name from the set of fields it is over.""" - return "_".join(["{}_{}".format(*item) for item in keys]) - - -def _index_list( - key_or_list: _Hint, direction: Optional[Union[int, str]] = None -) -> Sequence[tuple[str, Union[int, str, Mapping[str, Any]]]]: - """Helper to generate a list of (key, direction) pairs. - - Takes such a list, or a single key, or a single key and direction. - """ - if direction is not None: - if not isinstance(key_or_list, str): - raise TypeError("Expected a string and a direction") - return [(key_or_list, direction)] - else: - if isinstance(key_or_list, str): - return [(key_or_list, ASCENDING)] - elif isinstance(key_or_list, abc.ItemsView): - return list(key_or_list) # type: ignore[arg-type] - elif isinstance(key_or_list, abc.Mapping): - return list(key_or_list.items()) - elif not isinstance(key_or_list, (list, tuple)): - raise TypeError("if no direction is specified, key_or_list must be an instance of list") - values: list[tuple[str, int]] = [] - for item in key_or_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - values.append(item) - return values - - -def _index_document(index_list: _IndexList) -> dict[str, Any]: - """Helper to generate an index specifying document. - - Takes a list of (key, direction) pairs. - """ - if not isinstance(index_list, (list, tuple, abc.Mapping)): - raise TypeError( - "must use a dictionary or a list of (key, direction) pairs, not: " + repr(index_list) - ) - if not len(index_list): - raise ValueError("key_or_list must not be empty") - - index: dict[str, Any] = {} - - if isinstance(index_list, abc.Mapping): - for key in index_list: - value = index_list[key] - _validate_index_key_pair(key, value) - index[key] = value - else: - for item in index_list: - if isinstance(item, str): - item = (item, ASCENDING) # noqa: PLW2901 - key, value = item - _validate_index_key_pair(key, value) - index[key] = value - return index - - -def _validate_index_key_pair(key: Any, value: Any) -> None: - if not isinstance(key, str): - raise TypeError("first item in each key pair must be an instance of str") - if not isinstance(value, (str, int, abc.Mapping)): - raise TypeError( - "second item in each key pair must be 1, -1, " - "'2d', or another valid MongoDB index specifier." 
- ) - - -def _check_command_response( - response: _DocumentOut, - max_wire_version: Optional[int], - allowable_errors: Optional[Container[Union[int, str]]] = None, - parse_write_concern_error: bool = False, -) -> None: - """Check the response to a command for errors.""" - if "ok" not in response: - # Server didn't recognize our message as a command. - raise OperationFailure( - response.get("$err"), # type: ignore[arg-type] - response.get("code"), - response, - max_wire_version, - ) - - if parse_write_concern_error and "writeConcernError" in response: - _error = response["writeConcernError"] - _labels = response.get("errorLabels") - if _labels: - _error.update({"errorLabels": _labels}) - _raise_write_concern_error(_error) - - if response["ok"]: - return - - details = response - # Mongos returns the error details in a 'raw' object - # for some errors. - if "raw" in response: - for shard in response["raw"].values(): - # Grab the first non-empty raw error from a shard. - if shard.get("errmsg") and not shard.get("ok"): - details = shard - break - - errmsg = details["errmsg"] - code = details.get("code") - - # For allowable errors, only check for error messages when the code is not - # included. - if allowable_errors: - if code is not None: - if code in allowable_errors: - return - elif errmsg in allowable_errors: - return - - # Server is "not primary" or "recovering" - if code is not None: - if code in _NOT_PRIMARY_CODES: - raise NotPrimaryError(errmsg, response) - elif HelloCompat.LEGACY_ERROR in errmsg or "node is recovering" in errmsg: - raise NotPrimaryError(errmsg, response) - - # Other errors - # findAndModify with upsert can raise duplicate key error - if code in (11000, 11001, 12582): - raise DuplicateKeyError(errmsg, code, response, max_wire_version) - elif code == 50: - raise ExecutionTimeout(errmsg, code, response, max_wire_version) - elif code == 43: - raise CursorNotFound(errmsg, code, response, max_wire_version) - - raise OperationFailure(errmsg, code, response, max_wire_version) - - -def _raise_last_write_error(write_errors: list[Any]) -> NoReturn: - # If the last batch had multiple errors only report - # the last error to emulate continue_on_error. - error = write_errors[-1] - if error.get("code") == 11000: - raise DuplicateKeyError(error.get("errmsg"), 11000, error) - raise WriteError(error.get("errmsg"), error.get("code"), error) - - -def _raise_write_concern_error(error: Any) -> NoReturn: - if _wtimeout_error(error): - # Make sure we raise WTimeoutError - raise WTimeoutError(error.get("errmsg"), error.get("code"), error) - raise WriteConcernError(error.get("errmsg"), error.get("code"), error) - - -def _get_wce_doc(result: Mapping[str, Any]) -> Optional[Mapping[str, Any]]: - """Return the writeConcernError or None.""" - wce = result.get("writeConcernError") - if wce: - # The server reports errorLabels at the top level but it's more - # convenient to attach it to the writeConcernError doc itself. - error_labels = result.get("errorLabels") - if error_labels: - # Copy to avoid changing the original document. 
-            wce = wce.copy()
-            wce["errorLabels"] = error_labels
-    return wce
-
-
-def _check_write_command_response(result: Mapping[str, Any]) -> None:
-    """Backward compatibility helper for write command error handling."""
-    # Prefer write errors over write concern errors
-    write_errors = result.get("writeErrors")
-    if write_errors:
-        _raise_last_write_error(write_errors)
-
-    wce = _get_wce_doc(result)
-    if wce:
-        _raise_write_concern_error(wce)
-
-
-def _fields_list_to_dict(
-    fields: Union[Mapping[str, Any], Iterable[str]], option_name: str
-) -> Mapping[str, Any]:
-    """Takes a sequence of field names and returns a matching dictionary.
-
-    ["a", "b"] becomes {"a": 1, "b": 1}
-
-    and
-
-    ["a.b.c", "d", "a.c"] becomes {"a.b.c": 1, "d": 1, "a.c": 1}
-    """
-    if isinstance(fields, abc.Mapping):
-        return fields
-
-    if isinstance(fields, (abc.Sequence, abc.Set)):
-        if not all(isinstance(field, str) for field in fields):
-            raise TypeError(f"{option_name} must be a list of key names, each an instance of str")
-        return dict.fromkeys(fields, 1)
-
-    raise TypeError(f"{option_name} must be a mapping or list of key names")
-
-
-def _handle_exception() -> None:
-    """Print exceptions raised by subscribers to stderr."""
-    # Heavily influenced by logging.Handler.handleError.
-
-    # See note here:
-    # https://docs.python.org/3.4/library/sys.html#sys.__stderr__
-    if sys.stderr:
-        einfo = sys.exc_info()
-        try:
-            traceback.print_exception(einfo[0], einfo[1], einfo[2], None, sys.stderr)
-        except OSError:
-            pass
-        finally:
-            del einfo
-
-
 # See https://mypy.readthedocs.io/en/stable/generics.html?#decorator-factories
 F = TypeVar("F", bound=Callable[..., Any])
@@ -292,7 +47,7 @@ def inner(*args: Any, **kwargs: Any) -> Any:
             if no_reauth:
                 raise
             if exc.code == _REAUTHENTICATION_REQUIRED_CODE:
-                # Look for an argument that either is a Connection
+                # Look for an argument that either is an AsyncConnection
                 # or has a connection attribute, so we can trigger
                 # a reauth.
                 conn = None
diff --git a/pymongo/synchronous/logger.py b/pymongo/synchronous/logger.py
deleted file mode 100644
index d0f539ee6f..0000000000
--- a/pymongo/synchronous/logger.py
+++ /dev/null
@@ -1,171 +0,0 @@
-# Copyright 2023-present MongoDB, Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
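The ``logger.py`` module deleted here is likewise promoted to a shared top-level ``pymongo/logger.py`` (see the diffstat and the rewritten imports in ``message.py`` later in this patch). The logger names it defines are ordinary standard-library ``logging`` loggers, so its structured messages can be enabled without any pymongo-specific API; a minimal sketch::

    import logging

    logging.basicConfig()
    # Logger names defined in the module below: "pymongo.command",
    # "pymongo.connection", "pymongo.serverSelection", "pymongo.client".
    logging.getLogger("pymongo.command").setLevel(logging.DEBUG)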
-from __future__ import annotations - -import enum -import logging -import os -import warnings -from typing import Any - -from bson import UuidRepresentation, json_util -from bson.json_util import JSONOptions, _truncate_documents -from pymongo.synchronous.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason - -_IS_SYNC = True - - -class _CommandStatusMessage(str, enum.Enum): - STARTED = "Command started" - SUCCEEDED = "Command succeeded" - FAILED = "Command failed" - - -class _ServerSelectionStatusMessage(str, enum.Enum): - STARTED = "Server selection started" - SUCCEEDED = "Server selection succeeded" - FAILED = "Server selection failed" - WAITING = "Waiting for suitable server to become available" - - -class _ConnectionStatusMessage(str, enum.Enum): - POOL_CREATED = "Connection pool created" - POOL_READY = "Connection pool ready" - POOL_CLOSED = "Connection pool closed" - POOL_CLEARED = "Connection pool cleared" - - CONN_CREATED = "Connection created" - CONN_READY = "Connection ready" - CONN_CLOSED = "Connection closed" - - CHECKOUT_STARTED = "Connection checkout started" - CHECKOUT_SUCCEEDED = "Connection checked out" - CHECKOUT_FAILED = "Connection checkout failed" - CHECKEDIN = "Connection checked in" - - -_DEFAULT_DOCUMENT_LENGTH = 1000 -_SENSITIVE_COMMANDS = [ - "authenticate", - "saslStart", - "saslContinue", - "getnonce", - "createUser", - "updateUser", - "copydbgetnonce", - "copydbsaslstart", - "copydb", -] -_HELLO_COMMANDS = ["hello", "ismaster", "isMaster"] -_REDACTED_FAILURE_FIELDS = ["code", "codeName", "errorLabels"] -_DOCUMENT_NAMES = ["command", "reply", "failure"] -_JSON_OPTIONS = JSONOptions(uuid_representation=UuidRepresentation.STANDARD) -_COMMAND_LOGGER = logging.getLogger("pymongo.command") -_CONNECTION_LOGGER = logging.getLogger("pymongo.connection") -_SERVER_SELECTION_LOGGER = logging.getLogger("pymongo.serverSelection") -_CLIENT_LOGGER = logging.getLogger("pymongo.client") -_VERBOSE_CONNECTION_ERROR_REASONS = { - ConnectionClosedReason.POOL_CLOSED: "Connection pool was closed", - ConnectionCheckOutFailedReason.POOL_CLOSED: "Connection pool was closed", - ConnectionClosedReason.STALE: "Connection pool was stale", - ConnectionClosedReason.ERROR: "An error occurred while using the connection", - ConnectionCheckOutFailedReason.CONN_ERROR: "An error occurred while trying to establish a new connection", - ConnectionClosedReason.IDLE: "Connection was idle too long", - ConnectionCheckOutFailedReason.TIMEOUT: "Connection exceeded the specified timeout", -} - - -def _debug_log(logger: logging.Logger, **fields: Any) -> None: - logger.debug(LogMessage(**fields)) - - -def _verbose_connection_error_reason(reason: str) -> str: - return _VERBOSE_CONNECTION_ERROR_REASONS.get(reason, reason) - - -def _info_log(logger: logging.Logger, **fields: Any) -> None: - logger.info(LogMessage(**fields)) - - -def _log_or_warn(logger: logging.Logger, message: str) -> None: - if logger.isEnabledFor(logging.INFO): - logger.info(message) - else: - # stacklevel=4 ensures that the warning is for the user's code. 
- warnings.warn(message, UserWarning, stacklevel=4) - - -class LogMessage: - __slots__ = ("_kwargs", "_redacted") - - def __init__(self, **kwargs: Any): - self._kwargs = kwargs - self._redacted = False - - def __str__(self) -> str: - self._redact() - return "%s" % ( - json_util.dumps( - self._kwargs, json_options=_JSON_OPTIONS, default=lambda o: o.__repr__() - ) - ) - - def _is_sensitive(self, doc_name: str) -> bool: - is_speculative_authenticate = ( - self._kwargs.pop("speculative_authenticate", False) - or "speculativeAuthenticate" in self._kwargs[doc_name] - ) - is_sensitive_command = ( - "commandName" in self._kwargs and self._kwargs["commandName"] in _SENSITIVE_COMMANDS - ) - - is_sensitive_hello = ( - self._kwargs["commandName"] in _HELLO_COMMANDS and is_speculative_authenticate - ) - - return is_sensitive_command or is_sensitive_hello - - def _redact(self) -> None: - if self._redacted: - return - self._kwargs = {k: v for k, v in self._kwargs.items() if v is not None} - if "durationMS" in self._kwargs and hasattr(self._kwargs["durationMS"], "total_seconds"): - self._kwargs["durationMS"] = self._kwargs["durationMS"].total_seconds() * 1000 - if "serviceId" in self._kwargs: - self._kwargs["serviceId"] = str(self._kwargs["serviceId"]) - document_length = int(os.getenv("MONGOB_LOG_MAX_DOCUMENT_LENGTH", _DEFAULT_DOCUMENT_LENGTH)) - if document_length < 0: - document_length = _DEFAULT_DOCUMENT_LENGTH - is_server_side_error = self._kwargs.pop("isServerSideError", False) - - for doc_name in _DOCUMENT_NAMES: - doc = self._kwargs.get(doc_name) - if doc: - if doc_name == "failure" and is_server_side_error: - doc = {k: v for k, v in doc.items() if k in _REDACTED_FAILURE_FIELDS} - if doc_name != "failure" and self._is_sensitive(doc_name): - doc = json_util.dumps({}) - else: - truncated_doc = _truncate_documents(doc, document_length)[0] - doc = json_util.dumps( - truncated_doc, - json_options=_JSON_OPTIONS, - default=lambda o: o.__repr__(), - ) - if len(doc) > document_length: - doc = ( - doc.encode()[:document_length].decode("unicode-escape", "ignore") - ) + "..." - self._kwargs[doc_name] = doc - self._redacted = True diff --git a/pymongo/synchronous/max_staleness_selectors.py b/pymongo/synchronous/max_staleness_selectors.py deleted file mode 100644 index cde43890df..0000000000 --- a/pymongo/synchronous/max_staleness_selectors.py +++ /dev/null @@ -1,125 +0,0 @@ -# Copyright 2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Criteria to select ServerDescriptions based on maxStalenessSeconds. - -The Max Staleness Spec says: When there is a known primary P, -a secondary S's staleness is estimated with this formula: - - (S.lastUpdateTime - S.lastWriteDate) - (P.lastUpdateTime - P.lastWriteDate) - + heartbeatFrequencyMS - -When there is no known primary, a secondary S's staleness is estimated with: - - SMax.lastWriteDate - S.lastWriteDate + heartbeatFrequencyMS - -where "SMax" is the secondary with the greatest lastWriteDate. 
-""" -from __future__ import annotations - -from typing import TYPE_CHECKING - -from pymongo.errors import ConfigurationError -from pymongo.server_type import SERVER_TYPE - -if TYPE_CHECKING: - from pymongo.synchronous.server_selectors import Selection - -_IS_SYNC = True - -# Constant defined in Max Staleness Spec: An idle primary writes a no-op every -# 10 seconds to refresh secondaries' lastWriteDate values. -IDLE_WRITE_PERIOD = 10 -SMALLEST_MAX_STALENESS = 90 - - -def _validate_max_staleness(max_staleness: int, heartbeat_frequency: int) -> None: - # We checked for max staleness -1 before this, it must be positive here. - if max_staleness < heartbeat_frequency + IDLE_WRITE_PERIOD: - raise ConfigurationError( - "maxStalenessSeconds must be at least heartbeatFrequencyMS +" - " %d seconds. maxStalenessSeconds is set to %d," - " heartbeatFrequencyMS is set to %d." - % (IDLE_WRITE_PERIOD, max_staleness, heartbeat_frequency * 1000) - ) - - if max_staleness < SMALLEST_MAX_STALENESS: - raise ConfigurationError( - "maxStalenessSeconds must be at least %d. " - "maxStalenessSeconds is set to %d." % (SMALLEST_MAX_STALENESS, max_staleness) - ) - - -def _with_primary(max_staleness: int, selection: Selection) -> Selection: - """Apply max_staleness, in seconds, to a Selection with a known primary.""" - primary = selection.primary - assert primary - sds = [] - - for s in selection.server_descriptions: - if s.server_type == SERVER_TYPE.RSSecondary: - # See max-staleness.rst for explanation of this formula. - assert s.last_write_date and primary.last_write_date # noqa: PT018 - staleness = ( - (s.last_update_time - s.last_write_date) - - (primary.last_update_time - primary.last_write_date) - + selection.heartbeat_frequency - ) - - if staleness <= max_staleness: - sds.append(s) - else: - sds.append(s) - - return selection.with_server_descriptions(sds) - - -def _no_primary(max_staleness: int, selection: Selection) -> Selection: - """Apply max_staleness, in seconds, to a Selection with no known primary.""" - # Secondary that's replicated the most recent writes. - smax = selection.secondary_with_max_last_write_date() - if not smax: - # No secondaries and no primary, short-circuit out of here. - return selection.with_server_descriptions([]) - - sds = [] - - for s in selection.server_descriptions: - if s.server_type == SERVER_TYPE.RSSecondary: - # See max-staleness.rst for explanation of this formula. - assert smax.last_write_date and s.last_write_date # noqa: PT018 - staleness = smax.last_write_date - s.last_write_date + selection.heartbeat_frequency - - if staleness <= max_staleness: - sds.append(s) - else: - sds.append(s) - - return selection.with_server_descriptions(sds) - - -def select(max_staleness: int, selection: Selection) -> Selection: - """Apply max_staleness, in seconds, to a Selection.""" - if max_staleness == -1: - return selection - - # Server Selection Spec: If the TopologyType is ReplicaSetWithPrimary or - # ReplicaSetNoPrimary, a client MUST raise an error if maxStaleness < - # heartbeatFrequency + IDLE_WRITE_PERIOD, or if maxStaleness < 90. 
- _validate_max_staleness(max_staleness, selection.heartbeat_frequency) - - if selection.primary: - return _with_primary(max_staleness, selection) - else: - return _no_primary(max_staleness, selection) diff --git a/pymongo/synchronous/message.py b/pymongo/synchronous/message.py index 0eca1e8f15..973345f3d2 100644 --- a/pymongo/synchronous/message.py +++ b/pymongo/synchronous/message.py @@ -64,27 +64,27 @@ OperationFailure, ProtocolError, ) -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.logger import ( +from pymongo.hello_compat import HelloCompat +from pymongo.logger import ( _COMMAND_LOGGER, _CommandStatusMessage, _debug_log, ) -from pymongo.synchronous.read_preferences import ReadPreference +from pymongo.read_preferences import ReadPreference +from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from datetime import timedelta + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.monitoring import _EventListeners from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.typings import _Address, _DocumentOut + from pymongo.typings import _Address, _DocumentOut _IS_SYNC = True @@ -908,7 +908,7 @@ def _get_more( class _BulkWriteContext: - """A wrapper around Connection for use with write splitting functions.""" + """A wrapper around AsyncConnection for use with write splitting functions.""" __slots__ = ( "db_name", @@ -1012,7 +1012,7 @@ def unack_write( docs: list[Mapping[str, Any]], client: MongoClient, ) -> Optional[Mapping[str, Any]]: - """A proxy for Connection.unack_write that handles event publishing.""" + """A proxy for AsyncConnection.unack_write that handles event publishing.""" if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index a44a4e039e..bd15409ecb 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -57,7 +57,8 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp -from pymongo import _csot, helpers_constants +from pymongo import _csot, common, helpers_shared, uri_parser +from pymongo.client_options import ClientOptions from pymongo.errors import ( AutoReconnect, BulkWriteError, @@ -72,29 +73,20 @@ WriteConcernError, ) from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks +from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.monitoring import ConnectionClosedReason +from pymongo.operations import _Op +from pymongo.read_preferences import ReadPreference, _ServerMode +from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import ( - client_session, - common, - database, - helpers, - message, - periodic_executor, - uri_parser, -) +from pymongo.synchronous import client_session, database, message, 
periodic_executor from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream -from pymongo.synchronous.client_options import ClientOptions from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.logger import _CLIENT_LOGGER, _log_or_warn -from pymongo.synchronous.monitoring import ConnectionClosedReason -from pymongo.synchronous.operations import _Op -from pymongo.synchronous.read_preferences import ReadPreference, _ServerMode -from pymongo.synchronous.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings from pymongo.synchronous.topology import Topology, _ErrorContext -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE, TopologyDescription -from pymongo.synchronous.typings import ( +from pymongo.topology_description import TOPOLOGY_TYPE, TopologyDescription +from pymongo.typings import ( ClusterTime, _Address, _CollationIn, @@ -102,7 +94,7 @@ _DocumentTypeArg, _Pipeline, ) -from pymongo.synchronous.uri_parser import ( +from pymongo.uri_parser import ( _check_options, _handle_option_deprecations, _handle_security_options, @@ -116,14 +108,14 @@ from bson.objectid import ObjectId from pymongo.read_concern import ReadConcern + from pymongo.response import Response + from pymongo.server_selectors import Selection from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession, _ServerSession from pymongo.synchronous.cursor import _ConnectionManager from pymongo.synchronous.message import _CursorAddress, _GetMore, _Query from pymongo.synchronous.pool import Connection - from pymongo.synchronous.response import Response from pymongo.synchronous.server import Server - from pymongo.synchronous.server_selectors import Selection if sys.version_info[:2] >= (3, 9): pass @@ -134,7 +126,10 @@ T = TypeVar("T") _WriteCall = Callable[[Optional["ClientSession"], "Connection", bool], T] -_ReadCall = Callable[[Optional["ClientSession"], "Server", "Connection", _ServerMode], T] +_ReadCall = Callable[ + [Optional["ClientSession"], "Server", "Connection", _ServerMode], + T, +] _IS_SYNC = True @@ -1975,7 +1970,7 @@ def _process_kill_cursors(self) -> None: # can be caught in _process_periodic_tasks raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # Don't re-open topology if it's closed and there's no pending cursors. if address_to_cursor_ids: @@ -1987,7 +1982,7 @@ def _process_kill_cursors(self) -> None: if isinstance(exc, InvalidOperation) and self._topology._closed: raise else: - helpers._handle_exception() + helpers_shared._handle_exception() # This method is run periodically by a background thread. def _process_periodic_tasks(self) -> None: @@ -2001,7 +1996,7 @@ def _process_periodic_tasks(self) -> None: if isinstance(exc, InvalidOperation) and self._topology._closed: return else: - helpers._handle_exception() + helpers_shared._handle_exception() def _return_server_session( self, server_session: Union[_ServerSession, _EmptyServerSession], lock: bool @@ -2211,7 +2206,7 @@ def _add_retryable_write_error(exc: PyMongoError, max_wire_version: int, is_mong # Do not consult writeConcernError for pre-4.4 mongos. 
if isinstance(exc, WriteConcernError) and is_mongos: pass - elif code in helpers_constants._RETRYABLE_ERROR_CODES: + elif code in helpers_shared._RETRYABLE_ERROR_CODES: exc._add_error_label("RetryableWriteError") # Connection errors are always retryable except NotPrimaryError and WaitQueueTimeoutError which is @@ -2367,7 +2362,7 @@ def run(self) -> T: exc_code = getattr(exc, "code", None) if self._is_not_eligible_for_retry() or ( isinstance(exc, OperationFailure) - and exc_code not in helpers_constants._RETRYABLE_ERROR_CODES + and exc_code not in helpers_shared._RETRYABLE_ERROR_CODES ): raise self._retrying = True diff --git a/pymongo/synchronous/monitor.py b/pymongo/synchronous/monitor.py index 96849e7349..8106c1922d 100644 --- a/pymongo/synchronous/monitor.py +++ b/pymongo/synchronous/monitor.py @@ -21,16 +21,17 @@ import weakref from typing import TYPE_CHECKING, Any, Mapping, Optional, cast +from pymongo import common from pymongo._csot import MovingMinimum from pymongo.errors import NetworkTimeout, NotPrimaryError, OperationFailure, _OperationCancelled +from pymongo.hello import Hello from pymongo.lock import _create_lock -from pymongo.synchronous import common, periodic_executor -from pymongo.synchronous.hello import Hello +from pymongo.pool_options import _is_faas +from pymongo.read_preferences import MovingAverage +from pymongo.server_description import ServerDescription +from pymongo.srv_resolver import _SrvResolver +from pymongo.synchronous import periodic_executor from pymongo.synchronous.periodic_executor import _shutdown_executors -from pymongo.synchronous.pool import _is_faas -from pymongo.synchronous.read_preferences import MovingAverage -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.srv_resolver import _SrvResolver if TYPE_CHECKING: from pymongo.synchronous.pool import Connection, Pool, _CancellationContext diff --git a/pymongo/synchronous/monitoring.py b/pymongo/synchronous/monitoring.py deleted file mode 100644 index a4b7296881..0000000000 --- a/pymongo/synchronous/monitoring.py +++ /dev/null @@ -1,1903 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Tools to monitor driver events. - -.. versionadded:: 3.1 - -.. attention:: Starting in PyMongo 3.11, the monitoring classes outlined below - are included in the PyMongo distribution under the - :mod:`~pymongo.event_loggers` submodule. - -Use :func:`register` to register global listeners for specific events. -Listeners must inherit from one of the abstract classes below and implement -the correct functions for that class. 
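Per the diffstat, the 1,903 lines deleted here move verbatim to the top-level ``pymongo/monitoring.py``, so a single listener implementation can serve both client flavors once the refactor lands. A sketch, assuming the in-progress async client accepts the same ``event_listeners`` option (the ``AsyncMongoClient`` name and import path are hypothetical here)::

    from pymongo import MongoClient, monitoring

    class CommandTimer(monitoring.CommandListener):
        def started(self, event): ...
        def succeeded(self, event): print(event.duration_micros)
        def failed(self, event): print(event.duration_micros)

    client = MongoClient(event_listeners=[CommandTimer()])
    # Hypothetical async counterpart after this refactor:
    # from pymongo.asynchronous.mongo_client import AsyncMongoClient
    # async_client = AsyncMongoClient(event_listeners=[CommandTimer()])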
- -For example, a simple command logger might be implemented like this:: - - import logging - - from pymongo import monitoring - - class CommandLogger(monitoring.CommandListener): - - def started(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} started on server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "succeeded in {0.duration_micros} " - "microseconds".format(event)) - - def failed(self, event): - logging.info("Command {0.command_name} with request id " - "{0.request_id} on server {0.connection_id} " - "failed in {0.duration_micros} " - "microseconds".format(event)) - - monitoring.register(CommandLogger()) - -Server discovery and monitoring events are also available. For example:: - - class ServerLogger(monitoring.ServerListener): - - def opened(self, event): - logging.info("Server {0.server_address} added to topology " - "{0.topology_id}".format(event)) - - def description_changed(self, event): - previous_server_type = event.previous_description.server_type - new_server_type = event.new_description.server_type - if new_server_type != previous_server_type: - # server_type_name was added in PyMongo 3.4 - logging.info( - "Server {0.server_address} changed type from " - "{0.previous_description.server_type_name} to " - "{0.new_description.server_type_name}".format(event)) - - def closed(self, event): - logging.warning("Server {0.server_address} removed from topology " - "{0.topology_id}".format(event)) - - - class HeartbeatLogger(monitoring.ServerHeartbeatListener): - - def started(self, event): - logging.info("Heartbeat sent to server " - "{0.connection_id}".format(event)) - - def succeeded(self, event): - # The reply.document attribute was added in PyMongo 3.4. - logging.info("Heartbeat to server {0.connection_id} " - "succeeded with reply " - "{0.reply.document}".format(event)) - - def failed(self, event): - logging.warning("Heartbeat to server {0.connection_id} " - "failed with error {0.reply}".format(event)) - - class TopologyLogger(monitoring.TopologyListener): - - def opened(self, event): - logging.info("Topology with id {0.topology_id} " - "opened".format(event)) - - def description_changed(self, event): - logging.info("Topology description updated for " - "topology id {0.topology_id}".format(event)) - previous_topology_type = event.previous_description.topology_type - new_topology_type = event.new_description.topology_type - if new_topology_type != previous_topology_type: - # topology_type_name was added in PyMongo 3.4 - logging.info( - "Topology {0.topology_id} changed type from " - "{0.previous_description.topology_type_name} to " - "{0.new_description.topology_type_name}".format(event)) - # The has_writable_server and has_readable_server methods - # were added in PyMongo 3.4. - if not event.new_description.has_writable_server(): - logging.warning("No writable servers available.") - if not event.new_description.has_readable_server(): - logging.warning("No readable servers available.") - - def closed(self, event): - logging.info("Topology with id {0.topology_id} " - "closed".format(event)) - -Connection monitoring and pooling events are also available. 
For example:: - - class ConnectionPoolLogger(ConnectionPoolListener): - - def pool_created(self, event): - logging.info("[pool {0.address}] pool created".format(event)) - - def pool_ready(self, event): - logging.info("[pool {0.address}] pool is ready".format(event)) - - def pool_cleared(self, event): - logging.info("[pool {0.address}] pool cleared".format(event)) - - def pool_closed(self, event): - logging.info("[pool {0.address}] pool closed".format(event)) - - def connection_created(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection created".format(event)) - - def connection_ready(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection setup succeeded".format(event)) - - def connection_closed(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection closed, reason: " - "{0.reason}".format(event)) - - def connection_check_out_started(self, event): - logging.info("[pool {0.address}] connection check out " - "started".format(event)) - - def connection_check_out_failed(self, event): - logging.info("[pool {0.address}] connection check out " - "failed, reason: {0.reason}".format(event)) - - def connection_checked_out(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked out of pool".format(event)) - - def connection_checked_in(self, event): - logging.info("[pool {0.address}][connection #{0.connection_id}] " - "connection checked into pool".format(event)) - - -Event listeners can also be registered per instance of -:class:`~pymongo.mongo_client.MongoClient`:: - - client = MongoClient(event_listeners=[CommandLogger()]) - -Note that previously registered global listeners are automatically included -when configuring per client event listeners. Registering a new global listener -will not add that listener to existing client instances. - -.. note:: Events are delivered **synchronously**. Application threads block - waiting for event handlers (e.g. :meth:`~CommandListener.started`) to - return. Care must be taken to ensure that your event handlers are efficient - enough to not adversely affect overall application performance. - -.. warning:: The command documents published through this API are *not* copies. - If you intend to modify them in any way you must copy them in your event - handler first. -""" - -from __future__ import annotations - -import datetime -from collections import abc, namedtuple -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from bson.objectid import ObjectId -from pymongo.helpers_constants import _SENSITIVE_COMMANDS -from pymongo.synchronous.hello import Hello -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.helpers import _handle_exception -from pymongo.synchronous.typings import _Address, _DocumentOut - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.synchronous.server_description import ServerDescription - from pymongo.synchronous.topology_description import TopologyDescription - -_IS_SYNC = True - -_Listeners = namedtuple( - "_Listeners", - ( - "command_listeners", - "server_listeners", - "server_heartbeat_listeners", - "topology_listeners", - "cmap_listeners", - ), -) - -_LISTENERS = _Listeners([], [], [], [], []) - - -class _EventListener: - """Abstract base class for all event listeners.""" - - -class CommandListener(_EventListener): - """Abstract base class for command listeners. 
- - Handles `CommandStartedEvent`, `CommandSucceededEvent`, - and `CommandFailedEvent`. - """ - - def started(self, event: CommandStartedEvent) -> None: - """Abstract method to handle a `CommandStartedEvent`. - - :param event: An instance of :class:`CommandStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: CommandSucceededEvent) -> None: - """Abstract method to handle a `CommandSucceededEvent`. - - :param event: An instance of :class:`CommandSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: CommandFailedEvent) -> None: - """Abstract method to handle a `CommandFailedEvent`. - - :param event: An instance of :class:`CommandFailedEvent`. - """ - raise NotImplementedError - - -class ConnectionPoolListener(_EventListener): - """Abstract base class for connection pool listeners. - - Handles all of the connection pool events defined in the Connection - Monitoring and Pooling Specification: - :class:`PoolCreatedEvent`, :class:`PoolClearedEvent`, - :class:`PoolClosedEvent`, :class:`ConnectionCreatedEvent`, - :class:`ConnectionReadyEvent`, :class:`ConnectionClosedEvent`, - :class:`ConnectionCheckOutStartedEvent`, - :class:`ConnectionCheckOutFailedEvent`, - :class:`ConnectionCheckedOutEvent`, - and :class:`ConnectionCheckedInEvent`. - - .. versionadded:: 3.9 - """ - - def pool_created(self, event: PoolCreatedEvent) -> None: - """Abstract method to handle a :class:`PoolCreatedEvent`. - - Emitted when a connection Pool is created. - - :param event: An instance of :class:`PoolCreatedEvent`. - """ - raise NotImplementedError - - def pool_ready(self, event: PoolReadyEvent) -> None: - """Abstract method to handle a :class:`PoolReadyEvent`. - - Emitted when a connection Pool is marked ready. - - :param event: An instance of :class:`PoolReadyEvent`. - - .. versionadded:: 4.0 - """ - raise NotImplementedError - - def pool_cleared(self, event: PoolClearedEvent) -> None: - """Abstract method to handle a `PoolClearedEvent`. - - Emitted when a connection Pool is cleared. - - :param event: An instance of :class:`PoolClearedEvent`. - """ - raise NotImplementedError - - def pool_closed(self, event: PoolClosedEvent) -> None: - """Abstract method to handle a `PoolClosedEvent`. - - Emitted when a connection Pool is closed. - - :param event: An instance of :class:`PoolClosedEvent`. - """ - raise NotImplementedError - - def connection_created(self, event: ConnectionCreatedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCreatedEvent`. - - Emitted when a connection Pool creates a Connection object. - - :param event: An instance of :class:`ConnectionCreatedEvent`. - """ - raise NotImplementedError - - def connection_ready(self, event: ConnectionReadyEvent) -> None: - """Abstract method to handle a :class:`ConnectionReadyEvent`. - - Emitted when a connection has finished its setup, and is now ready to - use. - - :param event: An instance of :class:`ConnectionReadyEvent`. - """ - raise NotImplementedError - - def connection_closed(self, event: ConnectionClosedEvent) -> None: - """Abstract method to handle a :class:`ConnectionClosedEvent`. - - Emitted when a connection Pool closes a connection. - - :param event: An instance of :class:`ConnectionClosedEvent`. - """ - raise NotImplementedError - - def connection_check_out_started(self, event: ConnectionCheckOutStartedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutStartedEvent`. - - Emitted when the driver starts attempting to check out a connection. 
- - :param event: An instance of :class:`ConnectionCheckOutStartedEvent`. - """ - raise NotImplementedError - - def connection_check_out_failed(self, event: ConnectionCheckOutFailedEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckOutFailedEvent`. - - Emitted when the driver's attempt to check out a connection fails. - - :param event: An instance of :class:`ConnectionCheckOutFailedEvent`. - """ - raise NotImplementedError - - def connection_checked_out(self, event: ConnectionCheckedOutEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedOutEvent`. - - Emitted when the driver successfully checks out a connection. - - :param event: An instance of :class:`ConnectionCheckedOutEvent`. - """ - raise NotImplementedError - - def connection_checked_in(self, event: ConnectionCheckedInEvent) -> None: - """Abstract method to handle a :class:`ConnectionCheckedInEvent`. - - Emitted when the driver checks in a connection back to the connection - Pool. - - :param event: An instance of :class:`ConnectionCheckedInEvent`. - """ - raise NotImplementedError - - -class ServerHeartbeatListener(_EventListener): - """Abstract base class for server heartbeat listeners. - - Handles `ServerHeartbeatStartedEvent`, `ServerHeartbeatSucceededEvent`, - and `ServerHeartbeatFailedEvent`. - - .. versionadded:: 3.3 - """ - - def started(self, event: ServerHeartbeatStartedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatStartedEvent`. - - :param event: An instance of :class:`ServerHeartbeatStartedEvent`. - """ - raise NotImplementedError - - def succeeded(self, event: ServerHeartbeatSucceededEvent) -> None: - """Abstract method to handle a `ServerHeartbeatSucceededEvent`. - - :param event: An instance of :class:`ServerHeartbeatSucceededEvent`. - """ - raise NotImplementedError - - def failed(self, event: ServerHeartbeatFailedEvent) -> None: - """Abstract method to handle a `ServerHeartbeatFailedEvent`. - - :param event: An instance of :class:`ServerHeartbeatFailedEvent`. - """ - raise NotImplementedError - - -class TopologyListener(_EventListener): - """Abstract base class for topology monitoring listeners. - Handles `TopologyOpenedEvent`, `TopologyDescriptionChangedEvent`, and - `TopologyClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: TopologyOpenedEvent) -> None: - """Abstract method to handle a `TopologyOpenedEvent`. - - :param event: An instance of :class:`TopologyOpenedEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: TopologyDescriptionChangedEvent) -> None: - """Abstract method to handle a `TopologyDescriptionChangedEvent`. - - :param event: An instance of :class:`TopologyDescriptionChangedEvent`. - """ - raise NotImplementedError - - def closed(self, event: TopologyClosedEvent) -> None: - """Abstract method to handle a `TopologyClosedEvent`. - - :param event: An instance of :class:`TopologyClosedEvent`. - """ - raise NotImplementedError - - -class ServerListener(_EventListener): - """Abstract base class for server listeners. - Handles `ServerOpeningEvent`, `ServerDescriptionChangedEvent`, and - `ServerClosedEvent`. - - .. versionadded:: 3.3 - """ - - def opened(self, event: ServerOpeningEvent) -> None: - """Abstract method to handle a `ServerOpeningEvent`. - - :param event: An instance of :class:`ServerOpeningEvent`. - """ - raise NotImplementedError - - def description_changed(self, event: ServerDescriptionChangedEvent) -> None: - """Abstract method to handle a `ServerDescriptionChangedEvent`. 
- - :param event: An instance of :class:`ServerDescriptionChangedEvent`. - """ - raise NotImplementedError - - def closed(self, event: ServerClosedEvent) -> None: - """Abstract method to handle a `ServerClosedEvent`. - - :param event: An instance of :class:`ServerClosedEvent`. - """ - raise NotImplementedError - - -def _to_micros(dur: timedelta) -> int: - """Convert duration 'dur' to microseconds.""" - return int(dur.total_seconds() * 10e5) - - -def _validate_event_listeners( - option: str, listeners: Sequence[_EventListeners] -) -> Sequence[_EventListeners]: - """Validate event listeners""" - if not isinstance(listeners, abc.Sequence): - raise TypeError(f"{option} must be a list or tuple") - for listener in listeners: - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {option} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - return listeners - - -def register(listener: _EventListener) -> None: - """Register a global event listener. - - :param listener: A subclasses of :class:`CommandListener`, - :class:`ServerHeartbeatListener`, :class:`ServerListener`, - :class:`TopologyListener`, or :class:`ConnectionPoolListener`. - """ - if not isinstance(listener, _EventListener): - raise TypeError( - f"Listeners for {listener} must be either a " - "CommandListener, ServerHeartbeatListener, " - "ServerListener, TopologyListener, or " - "ConnectionPoolListener." - ) - if isinstance(listener, CommandListener): - _LISTENERS.command_listeners.append(listener) - if isinstance(listener, ServerHeartbeatListener): - _LISTENERS.server_heartbeat_listeners.append(listener) - if isinstance(listener, ServerListener): - _LISTENERS.server_listeners.append(listener) - if isinstance(listener, TopologyListener): - _LISTENERS.topology_listeners.append(listener) - if isinstance(listener, ConnectionPoolListener): - _LISTENERS.cmap_listeners.append(listener) - - -# The "hello" command is also deemed sensitive when attempting speculative -# authentication. -def _is_speculative_authenticate(command_name: str, doc: Mapping[str, Any]) -> bool: - if ( - command_name.lower() in ("hello", HelloCompat.LEGACY_CMD) - and "speculativeAuthenticate" in doc - ): - return True - return False - - -class _CommandEvent: - """Base class for command events.""" - - __slots__ = ( - "__cmd_name", - "__rqst_id", - "__conn_id", - "__op_id", - "__service_id", - "__db", - "__server_conn_id", - ) - - def __init__( - self, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - self.__cmd_name = command_name - self.__rqst_id = request_id - self.__conn_id = connection_id - self.__op_id = operation_id - self.__service_id = service_id - self.__db = database_name - self.__server_conn_id = server_connection_id - - @property - def command_name(self) -> str: - """The command name.""" - return self.__cmd_name - - @property - def request_id(self) -> int: - """The request id for this operation.""" - return self.__rqst_id - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this command was sent to.""" - return self.__conn_id - - @property - def service_id(self) -> Optional[ObjectId]: - """The service_id this command was sent to, or ``None``. - - .. 
versionadded:: 3.12 - """ - return self.__service_id - - @property - def operation_id(self) -> Optional[int]: - """An id for this series of events or None.""" - return self.__op_id - - @property - def database_name(self) -> str: - """The database_name this command was sent to, or ``""``. - - .. versionadded:: 4.6 - """ - return self.__db - - @property - def server_connection_id(self) -> Optional[int]: - """The server-side connection id for the connection this command was sent on, or ``None``. - - .. versionadded:: 4.7 - """ - return self.__server_conn_id - - -class CommandStartedEvent(_CommandEvent): - """Event published when a command starts. - - :param command: The command document. - :param database_name: The name of the database this command was run against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - """ - - __slots__ = ("__cmd",) - - def __init__( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - server_connection_id: Optional[int] = None, - ) -> None: - if not command: - raise ValueError(f"{command!r} is not a valid command") - # Command name must be first key. - command_name = next(iter(command)) - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, command): - self.__cmd: _DocumentOut = {} - else: - self.__cmd = command - - @property - def command(self) -> _DocumentOut: - """The command document.""" - return self.__cmd - - @property - def database_name(self) -> str: - """The name of the database this command was run against.""" - return super().database_name - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.service_id, - self.server_connection_id, - ) - - -class CommandSucceededEvent(_CommandEvent): - """Event published when a command succeeds. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. 
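For orientation while reading this move: the public listener API is unchanged, only its home module. A minimal sketch of a command listener wired up through register() above (the logging format is illustrative; the class and function names are the ones defined in this file):

    import logging

    from pymongo import monitoring

    class CommandLogger(monitoring.CommandListener):
        def started(self, event):
            # Sensitive commands (and speculative-auth hellos) arrive with
            # event.command already redacted to {} by CommandStartedEvent.
            logging.info("Command %s with request id %s started on server %s",
                         event.command_name, event.request_id, event.connection_id)

        def succeeded(self, event):
            logging.info("Command %s succeeded in %s microseconds",
                         event.command_name, event.duration_micros)

        def failed(self, event):
            logging.info("Command %s failed with %r",
                         event.command_name, event.failure)

    # Register globally: every client created afterwards publishes to it.
    monitoring.register(CommandLogger())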
- """ - - __slots__ = ("__duration_micros", "__reply") - - def __init__( - self, - duration: datetime.timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - cmd_name = command_name.lower() - if cmd_name in _SENSITIVE_COMMANDS or _is_speculative_authenticate(cmd_name, reply): - self.__reply: _DocumentOut = {} - else: - self.__reply = reply - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def reply(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__reply - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.service_id, - self.server_connection_id, - ) - - -class CommandFailedEvent(_CommandEvent): - """Event published when a command fails. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this command - was sent to. - :param operation_id: An optional identifier for a series of related events. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. 
- """ - - __slots__ = ("__duration_micros", "__failure") - - def __init__( - self, - duration: datetime.timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - operation_id: Optional[int], - service_id: Optional[ObjectId] = None, - database_name: str = "", - server_connection_id: Optional[int] = None, - ) -> None: - super().__init__( - command_name, - request_id, - connection_id, - operation_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - self.__duration_micros = _to_micros(duration) - self.__failure = failure - - @property - def duration_micros(self) -> int: - """The duration of this operation in microseconds.""" - return self.__duration_micros - - @property - def failure(self) -> _DocumentOut: - """The server failure document for this operation.""" - return self.__failure - - def __repr__(self) -> str: - return ( - "<{} {} db: {!r}, command: {!r}, operation_id: {}, duration_micros: {}, " - "failure: {!r}, service_id: {}, server_connection_id: {}>" - ).format( - self.__class__.__name__, - self.connection_id, - self.database_name, - self.command_name, - self.operation_id, - self.duration_micros, - self.failure, - self.service_id, - self.server_connection_id, - ) - - -class _PoolEvent: - """Base class for pool events.""" - - __slots__ = ("__address",) - - def __init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server the pool is attempting - to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class PoolCreatedEvent(_PoolEvent): - """Published when a Connection Pool is created. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__options",) - - def __init__(self, address: _Address, options: dict[str, Any]) -> None: - super().__init__(address) - self.__options = options - - @property - def options(self) -> dict[str, Any]: - """Any non-default pool options that were set on this Connection Pool.""" - return self.__options - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__options!r})" - - -class PoolReadyEvent(_PoolEvent): - """Published when a Connection Pool is marked ready. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 4.0 - """ - - __slots__ = () - - -class PoolClearedEvent(_PoolEvent): - """Published when a Connection Pool is cleared. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - :param service_id: The service_id this command was sent to, or ``None``. - :param interrupt_connections: True if all active connections were interrupted by the Pool during clearing. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__service_id", "__interrupt_connections") - - def __init__( - self, - address: _Address, - service_id: Optional[ObjectId] = None, - interrupt_connections: bool = False, - ) -> None: - super().__init__(address) - self.__service_id = service_id - self.__interrupt_connections = interrupt_connections - - @property - def service_id(self) -> Optional[ObjectId]: - """Connections with this service_id are cleared. - - When service_id is ``None``, all connections in the pool are cleared. - - .. 
versionadded:: 3.12 - """ - return self.__service_id - - @property - def interrupt_connections(self) -> bool: - """If True, active connections are interrupted during clearing. - - .. versionadded:: 4.7 - """ - return self.__interrupt_connections - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__service_id!r}, {self.__interrupt_connections!r})" - - -class PoolClosedEvent(_PoolEvent): - """Published when a Connection Pool is closed. - - :param address: The address (host, port) pair of the server this Pool is - attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionClosedEvent`. - - .. versionadded:: 3.9 - """ - - STALE = "stale" - """The pool was cleared, making the connection no longer valid.""" - - IDLE = "idle" - """The connection became stale by being idle for too long (maxIdleTimeMS). - """ - - ERROR = "error" - """The connection experienced an error, making it no longer valid.""" - - POOL_CLOSED = "poolClosed" - """The pool was closed, making the connection no longer valid.""" - - -class ConnectionCheckOutFailedReason: - """An enum that defines values for `reason` on a - :class:`ConnectionCheckOutFailedEvent`. - - .. versionadded:: 3.9 - """ - - TIMEOUT = "timeout" - """The connection check out attempt exceeded the specified timeout.""" - - POOL_CLOSED = "poolClosed" - """The pool was previously closed, and cannot provide new connections.""" - - CONN_ERROR = "connectionError" - """The connection check out attempt experienced an error while setting up - a new connection. - """ - - -class _ConnectionEvent: - """Private base class for connection events.""" - - __slots__ = ("__address",) - - def __init__(self, address: _Address) -> None: - self.__address = address - - @property - def address(self) -> _Address: - """The address (host, port) pair of the server this connection is - attempting to connect to. - """ - return self.__address - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.__address!r})" - - -class _ConnectionIdEvent(_ConnectionEvent): - """Private base class for connection events with an id.""" - - __slots__ = ("__connection_id",) - - def __init__(self, address: _Address, connection_id: int) -> None: - super().__init__(address) - self.__connection_id = connection_id - - @property - def connection_id(self) -> int: - """The ID of the connection.""" - return self.__connection_id - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__connection_id!r})" - - -class _ConnectionDurationEvent(_ConnectionIdEvent): - """Private base class for connection events with a duration.""" - - __slots__ = ("__duration",) - - def __init__(self, address: _Address, connection_id: int, duration: Optional[float]) -> None: - super().__init__(address, connection_id) - self.__duration = duration - - @property - def duration(self) -> Optional[float]: - """The duration of the connection event. - - .. versionadded:: 4.7 - """ - return self.__duration - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.connection_id!r}, {self.__duration!r})" - - -class ConnectionCreatedEvent(_ConnectionIdEvent): - """Published when a Connection Pool creates a Connection object. - - NOTE: This connection is not ready for use until the - :class:`ConnectionReadyEvent` is published. 
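The CMAP event classes above pair with the ConnectionPoolListener ABC earlier in this file. A sketch of a listener that counts checkouts per server (the counting logic is illustrative; callbacks left unimplemented fall through to the base class, whose NotImplementedError is caught by the publish_* helpers later in this file):

    from collections import Counter

    from pymongo import monitoring

    class PoolActivity(monitoring.ConnectionPoolListener):
        def __init__(self):
            self.checkouts = Counter()

        def connection_checked_out(self, event):
            # event.address is the (host, port) pair of the server.
            self.checkouts[event.address] += 1

        def connection_check_out_failed(self, event):
            # event.reason is one of the ConnectionCheckOutFailedReason values.
            print("check out failed:", event.reason)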
- - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionReadyEvent(_ConnectionDurationEvent): - """Published when a Connection has finished its setup, and is ready to use. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionClosedEvent(_ConnectionIdEvent): - """Published when a Connection is closed. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - :param reason: A reason explaining why this connection was closed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, connection_id: int, reason: str): - super().__init__(address, connection_id) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why this connection was closed. - - The reason must be one of the strings from the - :class:`ConnectionClosedReason` enum. - """ - return self.__reason - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r})".format( - self.__class__.__name__, - self.address, - self.connection_id, - self.__reason, - ) - - -class ConnectionCheckOutStartedEvent(_ConnectionEvent): - """Published when the driver starts attempting to check out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckOutFailedEvent(_ConnectionDurationEvent): - """Published when the driver's attempt to check out a connection fails. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param reason: A reason explaining why connection check out failed. - - .. versionadded:: 3.9 - """ - - __slots__ = ("__reason",) - - def __init__(self, address: _Address, reason: str, duration: Optional[float]) -> None: - super().__init__(address=address, connection_id=0, duration=duration) - self.__reason = reason - - @property - def reason(self) -> str: - """A reason explaining why connection check out failed. - - The reason must be one of the strings from the - :class:`ConnectionCheckOutFailedReason` enum. - """ - return self.__reason - - def __repr__(self) -> str: - return f"{self.__class__.__name__}({self.address!r}, {self.__reason!r}, {self.duration!r})" - - -class ConnectionCheckedOutEvent(_ConnectionDurationEvent): - """Published when the driver successfully checks out a connection. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. versionadded:: 3.9 - """ - - __slots__ = () - - -class ConnectionCheckedInEvent(_ConnectionIdEvent): - """Published when the driver checks in a Connection into the Pool. - - :param address: The address (host, port) pair of the server this - Connection is attempting to connect to. - :param connection_id: The integer ID of the Connection in this Pool. - - .. 
versionadded:: 3.9 - """ - - __slots__ = () - - -class _ServerEvent: - """Base class for server events.""" - - __slots__ = ("__server_address", "__topology_id") - - def __init__(self, server_address: _Address, topology_id: ObjectId) -> None: - self.__server_address = server_address - self.__topology_id = topology_id - - @property - def server_address(self) -> _Address: - """The address (host, port) pair of the server""" - return self.__server_address - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.server_address} topology_id: {self.topology_id}>" - - -class ServerDescriptionChangedEvent(_ServerEvent): - """Published when server description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> ServerDescription: - """The previous - :class:`~pymongo.server_description.ServerDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> ServerDescription: - """The new - :class:`~pymongo.server_description.ServerDescription`. - """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.server_address, - self.previous_description, - self.new_description, - ) - - -class ServerOpeningEvent(_ServerEvent): - """Published when server is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerClosedEvent(_ServerEvent): - """Published when server is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyEvent: - """Base class for topology description events.""" - - __slots__ = ("__topology_id",) - - def __init__(self, topology_id: ObjectId) -> None: - self.__topology_id = topology_id - - @property - def topology_id(self) -> ObjectId: - """A unique identifier for the topology this server is a part of.""" - return self.__topology_id - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} topology_id: {self.topology_id}>" - - -class TopologyDescriptionChangedEvent(TopologyEvent): - """Published when the topology description changes. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__previous_description", "__new_description") - - def __init__( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - *args: Any, - ) -> None: - super().__init__(*args) - self.__previous_description = previous_description - self.__new_description = new_description - - @property - def previous_description(self) -> TopologyDescription: - """The previous - :class:`~pymongo.topology_description.TopologyDescription`. - """ - return self.__previous_description - - @property - def new_description(self) -> TopologyDescription: - """The new - :class:`~pymongo.topology_description.TopologyDescription`. 
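These server and topology events are the hooks SDAM monitoring tools build on. A small sketch of a ServerListener reacting to description changes (server_type and server_type_name come from pymongo's ServerDescription; the printing is illustrative):

    from pymongo import monitoring

    class ServerWatcher(monitoring.ServerListener):
        def opened(self, event):
            print("server", event.server_address, "added to topology", event.topology_id)

        def description_changed(self, event):
            if event.new_description.server_type != event.previous_description.server_type:
                print("server", event.server_address, "is now",
                      event.new_description.server_type_name)

        def closed(self, event):
            print("server", event.server_address, "removed from topology", event.topology_id)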
- """ - return self.__new_description - - def __repr__(self) -> str: - return "<{} topology_id: {} changed from: {}, to: {}>".format( - self.__class__.__name__, - self.topology_id, - self.previous_description, - self.new_description, - ) - - -class TopologyOpenedEvent(TopologyEvent): - """Published when the topology is initialized. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class TopologyClosedEvent(TopologyEvent): - """Published when the topology is closed. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class _ServerHeartbeatEvent: - """Base class for server heartbeat events.""" - - __slots__ = ("__connection_id", "__awaited") - - def __init__(self, connection_id: _Address, awaited: bool = False) -> None: - self.__connection_id = connection_id - self.__awaited = awaited - - @property - def connection_id(self) -> _Address: - """The address (host, port) of the server this heartbeat was sent - to. - """ - return self.__connection_id - - @property - def awaited(self) -> bool: - """Whether the heartbeat was issued as an awaitable hello command. - - .. versionadded:: 4.6 - """ - return self.__awaited - - def __repr__(self) -> str: - return f"<{self.__class__.__name__} {self.connection_id} awaited: {self.awaited}>" - - -class ServerHeartbeatStartedEvent(_ServerHeartbeatEvent): - """Published when a heartbeat is started. - - .. versionadded:: 3.3 - """ - - __slots__ = () - - -class ServerHeartbeatSucceededEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat succeeds. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Hello, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Hello: - """An instance of :class:`~pymongo.hello.Hello`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class ServerHeartbeatFailedEvent(_ServerHeartbeatEvent): - """Fired when the server heartbeat fails, either with an "ok: 0" - or a socket exception. - - .. versionadded:: 3.3 - """ - - __slots__ = ("__duration", "__reply") - - def __init__( - self, duration: float, reply: Exception, connection_id: _Address, awaited: bool = False - ) -> None: - super().__init__(connection_id, awaited) - self.__duration = duration - self.__reply = reply - - @property - def duration(self) -> float: - """The duration of this heartbeat in microseconds.""" - return self.__duration - - @property - def reply(self) -> Exception: - """A subclass of :exc:`Exception`.""" - return self.__reply - - @property - def awaited(self) -> bool: - """Whether the heartbeat was awaited. - - If true, then :meth:`duration` reflects the sum of the round trip time - to the server and the time that the server waited before sending a - response. - - .. 
versionadded:: 3.11 - """ - return super().awaited - - def __repr__(self) -> str: - return "<{} {} duration: {}, awaited: {}, reply: {!r}>".format( - self.__class__.__name__, - self.connection_id, - self.duration, - self.awaited, - self.reply, - ) - - -class _EventListeners: - """Configure event listeners for a client instance. - - Any event listeners registered globally are included by default. - - :param listeners: A list of event listeners. - """ - - def __init__(self, listeners: Optional[Sequence[_EventListener]]): - self.__command_listeners = _LISTENERS.command_listeners[:] - self.__server_listeners = _LISTENERS.server_listeners[:] - lst = _LISTENERS.server_heartbeat_listeners - self.__server_heartbeat_listeners = lst[:] - self.__topology_listeners = _LISTENERS.topology_listeners[:] - self.__cmap_listeners = _LISTENERS.cmap_listeners[:] - if listeners is not None: - for lst in listeners: - if isinstance(lst, CommandListener): - self.__command_listeners.append(lst) - if isinstance(lst, ServerListener): - self.__server_listeners.append(lst) - if isinstance(lst, ServerHeartbeatListener): - self.__server_heartbeat_listeners.append(lst) - if isinstance(lst, TopologyListener): - self.__topology_listeners.append(lst) - if isinstance(lst, ConnectionPoolListener): - self.__cmap_listeners.append(lst) - self.__enabled_for_commands = bool(self.__command_listeners) - self.__enabled_for_server = bool(self.__server_listeners) - self.__enabled_for_server_heartbeat = bool(self.__server_heartbeat_listeners) - self.__enabled_for_topology = bool(self.__topology_listeners) - self.__enabled_for_cmap = bool(self.__cmap_listeners) - - @property - def enabled_for_commands(self) -> bool: - """Are any CommandListener instances registered?""" - return self.__enabled_for_commands - - @property - def enabled_for_server(self) -> bool: - """Are any ServerListener instances registered?""" - return self.__enabled_for_server - - @property - def enabled_for_server_heartbeat(self) -> bool: - """Are any ServerHeartbeatListener instances registered?""" - return self.__enabled_for_server_heartbeat - - @property - def enabled_for_topology(self) -> bool: - """Are any TopologyListener instances registered?""" - return self.__enabled_for_topology - - @property - def enabled_for_cmap(self) -> bool: - """Are any ConnectionPoolListener instances registered?""" - return self.__enabled_for_cmap - - def event_listeners(self) -> list[_EventListeners]: - """List of registered event listeners.""" - return ( - self.__command_listeners - + self.__server_heartbeat_listeners - + self.__server_listeners - + self.__topology_listeners - + self.__cmap_listeners - ) - - def publish_command_start( - self, - command: _DocumentOut, - database_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - ) -> None: - """Publish a CommandStartedEvent to all command listeners. - - :param command: The command document. - :param database_name: The name of the database this command was run - against. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. 
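A matching sketch for the heartbeat events above, useful for eyeballing round-trip times (the printing is illustrative):

    from pymongo import monitoring

    class HeartbeatWatcher(monitoring.ServerHeartbeatListener):
        def started(self, event):
            pass

        def succeeded(self, event):
            # For awaited heartbeats, duration includes server wait time,
            # as the awaited property documents above.
            print("heartbeat to", event.connection_id, "took", event.duration)

        def failed(self, event):
            # event.reply is the Exception that caused the failure.
            print("heartbeat to", event.connection_id, "failed:", event.reply)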
- """ - if op_id is None: - op_id = request_id - event = CommandStartedEvent( - command, - database_name, - request_id, - connection_id, - op_id, - service_id=service_id, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_command_success( - self, - duration: timedelta, - reply: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - speculative_hello: bool = False, - database_name: str = "", - ) -> None: - """Publish a CommandSucceededEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param reply: The server reply document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param speculative_hello: Was the command sent with speculative auth? - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - if speculative_hello: - # Redact entire response when the command started contained - # speculativeAuthenticate. - reply = {} - event = CommandSucceededEvent( - duration, - reply, - command_name, - request_id, - connection_id, - op_id, - service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_command_failure( - self, - duration: timedelta, - failure: _DocumentOut, - command_name: str, - request_id: int, - connection_id: _Address, - server_connection_id: Optional[int], - op_id: Optional[int] = None, - service_id: Optional[ObjectId] = None, - database_name: str = "", - ) -> None: - """Publish a CommandFailedEvent to all command listeners. - - :param duration: The command duration as a datetime.timedelta. - :param failure: The server reply document or failure description - document. - :param command_name: The command name. - :param request_id: The request id for this operation. - :param connection_id: The address (host, port) of the server this - command was sent to. - :param op_id: The (optional) operation id for this operation. - :param service_id: The service_id this command was sent to, or ``None``. - :param database_name: The database this command was sent to, or ``""``. - """ - if op_id is None: - op_id = request_id - event = CommandFailedEvent( - duration, - failure, - command_name, - request_id, - connection_id, - op_id, - service_id=service_id, - database_name=database_name, - server_connection_id=server_connection_id, - ) - for subscriber in self.__command_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_started(self, connection_id: _Address, awaited: bool) -> None: - """Publish a ServerHeartbeatStartedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param awaited: True if this heartbeat is part of an awaitable hello command. 
- """ - event = ServerHeartbeatStartedEvent(connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.started(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_succeeded( - self, connection_id: _Address, duration: float, reply: Hello, awaited: bool - ) -> None: - """Publish a ServerHeartbeatSucceededEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatSucceededEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.succeeded(event) - except Exception: - _handle_exception() - - def publish_server_heartbeat_failed( - self, connection_id: _Address, duration: float, reply: Exception, awaited: bool - ) -> None: - """Publish a ServerHeartbeatFailedEvent to all server heartbeat - listeners. - - :param connection_id: The address (host, port) pair of the connection. - :param duration: The execution time of the event in the highest possible - resolution for the platform. - :param reply: The command reply. - :param awaited: True if the response was awaited. - """ - event = ServerHeartbeatFailedEvent(duration, reply, connection_id, awaited) - for subscriber in self.__server_heartbeat_listeners: - try: - subscriber.failed(event) - except Exception: - _handle_exception() - - def publish_server_opened(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerOpeningEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerOpeningEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_server_closed(self, server_address: _Address, topology_id: ObjectId) -> None: - """Publish a ServerClosedEvent to all server listeners. - - :param server_address: The address (host, port) pair of the server. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerClosedEvent(server_address, topology_id) - for subscriber in self.__server_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_server_description_changed( - self, - previous_description: ServerDescription, - new_description: ServerDescription, - server_address: _Address, - topology_id: ObjectId, - ) -> None: - """Publish a ServerDescriptionChangedEvent to all server listeners. - - :param previous_description: The previous server description. - :param server_address: The address (host, port) pair of the server. - :param new_description: The new server description. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = ServerDescriptionChangedEvent( - previous_description, new_description, server_address, topology_id - ) - for subscriber in self.__server_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_topology_opened(self, topology_id: ObjectId) -> None: - """Publish a TopologyOpenedEvent to all topology listeners. 
- - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyOpenedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.opened(event) - except Exception: - _handle_exception() - - def publish_topology_closed(self, topology_id: ObjectId) -> None: - """Publish a TopologyClosedEvent to all topology listeners. - - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyClosedEvent(topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.closed(event) - except Exception: - _handle_exception() - - def publish_topology_description_changed( - self, - previous_description: TopologyDescription, - new_description: TopologyDescription, - topology_id: ObjectId, - ) -> None: - """Publish a TopologyDescriptionChangedEvent to all topology listeners. - - :param previous_description: The previous topology description. - :param new_description: The new topology description. - :param topology_id: A unique identifier for the topology this server - is a part of. - """ - event = TopologyDescriptionChangedEvent(previous_description, new_description, topology_id) - for subscriber in self.__topology_listeners: - try: - subscriber.description_changed(event) - except Exception: - _handle_exception() - - def publish_pool_created(self, address: _Address, options: dict[str, Any]) -> None: - """Publish a :class:`PoolCreatedEvent` to all pool listeners.""" - event = PoolCreatedEvent(address, options) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_created(event) - except Exception: - _handle_exception() - - def publish_pool_ready(self, address: _Address) -> None: - """Publish a :class:`PoolReadyEvent` to all pool listeners.""" - event = PoolReadyEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_ready(event) - except Exception: - _handle_exception() - - def publish_pool_cleared( - self, - address: _Address, - service_id: Optional[ObjectId], - interrupt_connections: bool = False, - ) -> None: - """Publish a :class:`PoolClearedEvent` to all pool listeners.""" - event = PoolClearedEvent(address, service_id, interrupt_connections) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_cleared(event) - except Exception: - _handle_exception() - - def publish_pool_closed(self, address: _Address) -> None: - """Publish a :class:`PoolClosedEvent` to all pool listeners.""" - event = PoolClosedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.pool_closed(event) - except Exception: - _handle_exception() - - def publish_connection_created(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCreatedEvent` to all connection - listeners. 
- """ - event = ConnectionCreatedEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_created(event) - except Exception: - _handle_exception() - - def publish_connection_ready( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionReadyEvent` to all connection listeners.""" - event = ConnectionReadyEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_ready(event) - except Exception: - _handle_exception() - - def publish_connection_closed(self, address: _Address, connection_id: int, reason: str) -> None: - """Publish a :class:`ConnectionClosedEvent` to all connection - listeners. - """ - event = ConnectionClosedEvent(address, connection_id, reason) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_closed(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_started(self, address: _Address) -> None: - """Publish a :class:`ConnectionCheckOutStartedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutStartedEvent(address) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_started(event) - except Exception: - _handle_exception() - - def publish_connection_check_out_failed( - self, address: _Address, reason: str, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckOutFailedEvent` to all connection - listeners. - """ - event = ConnectionCheckOutFailedEvent(address, reason, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_check_out_failed(event) - except Exception: - _handle_exception() - - def publish_connection_checked_out( - self, address: _Address, connection_id: int, duration: float - ) -> None: - """Publish a :class:`ConnectionCheckedOutEvent` to all connection - listeners. - """ - event = ConnectionCheckedOutEvent(address, connection_id, duration) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_out(event) - except Exception: - _handle_exception() - - def publish_connection_checked_in(self, address: _Address, connection_id: int) -> None: - """Publish a :class:`ConnectionCheckedInEvent` to all connection - listeners. 
- """ - event = ConnectionCheckedInEvent(address, connection_id) - for subscriber in self.__cmap_listeners: - try: - subscriber.connection_checked_in(event) - except Exception: - _handle_exception() diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index 3f5319fd32..cdfb60e202 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -32,13 +32,17 @@ ) from bson import _decode_all_selective -from pymongo import _csot +from pymongo import _csot, helpers_shared +from pymongo.common import MAX_MESSAGE_SIZE +from pymongo.compression_support import _NO_COMPRESSION, decompress from pymongo.errors import ( NotPrimaryError, OperationFailure, ProtocolError, _OperationCancelled, ) +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( _POLL_TIMEOUT, _UNPACK_COMPRESSION_HEADER, @@ -47,24 +51,19 @@ sendall, ) from pymongo.socket_checker import _errno_from_exception -from pymongo.synchronous import helpers as _async_helpers -from pymongo.synchronous import message as _async_message -from pymongo.synchronous.common import MAX_MESSAGE_SIZE -from pymongo.synchronous.compression_support import _NO_COMPRESSION, decompress -from pymongo.synchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.synchronous import message from pymongo.synchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply -from pymongo.synchronous.monitoring import _is_speculative_authenticate if TYPE_CHECKING: from bson import CodecOptions + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.monitoring import _EventListeners from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.monitoring import _EventListeners from pymongo.synchronous.pool import Connection - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.typings import _Address, _CollationIn, _DocumentOut, _DocumentType + from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType from pymongo.write_concern import WriteConcern _IS_SYNC = True @@ -129,7 +128,7 @@ def command( orig = spec if is_mongos and not use_op_msg: assert read_preference is not None - spec = _async_message._maybe_add_read_preference(spec, read_preference) + spec = message._maybe_add_read_preference(spec, read_preference) if read_concern and not (session and session.in_transaction): if read_concern.level: spec["readConcern"] = read_concern.document @@ -157,22 +156,20 @@ def command( if use_op_msg: flags = _OpMsg.MORE_TO_COME if unacknowledged else 0 flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0 - request_id, msg, size, max_doc_size = _async_message._op_msg( + request_id, msg, size, max_doc_size = message._op_msg( flags, spec, dbname, read_preference, codec_options, ctx=compression_ctx ) # If this is an unacknowledged write then make sure the encoded doc(s) # are small enough, otherwise rely on the server to return an error. 
if unacknowledged and max_bson_size is not None and max_doc_size > max_bson_size: - _async_message._raise_document_too_large(name, size, max_bson_size) + message._raise_document_too_large(name, size, max_bson_size) else: - request_id, msg, size = _async_message._query( + request_id, msg, size = message._query( 0, ns, 0, -1, spec, None, codec_options, compression_ctx ) - if max_bson_size is not None and size > max_bson_size + _async_message._COMMAND_OVERHEAD: - _async_message._raise_document_too_large( - name, size, max_bson_size + _async_message._COMMAND_OVERHEAD - ) + if max_bson_size is not None and size > max_bson_size + message._COMMAND_OVERHEAD: + message._raise_document_too_large(name, size, max_bson_size + message._COMMAND_OVERHEAD) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( @@ -219,7 +216,7 @@ def command( if client: client._process_response(response_doc, session) if check: - _async_helpers._check_command_response( + helpers_shared._check_command_response( response_doc, conn.max_wire_version, allowable_errors, @@ -230,7 +227,7 @@ def command( if isinstance(exc, (NotPrimaryError, OperationFailure)): failure: _DocumentOut = exc.details # type: ignore[assignment] else: - failure = _async_message._convert_exception(exc) + failure = message._convert_exception(exc) if client is not None: if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( diff --git a/pymongo/synchronous/operations.py b/pymongo/synchronous/operations.py deleted file mode 100644 index 148f84a42c..0000000000 --- a/pymongo/synchronous/operations.py +++ /dev/null @@ -1,625 +0,0 @@ -# Copyright 2015-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Operation class definitions.""" -from __future__ import annotations - -import enum -from typing import ( - TYPE_CHECKING, - Any, - Generic, - Mapping, - Optional, - Sequence, - Tuple, - Union, -) - -from bson.raw_bson import RawBSONDocument -from pymongo.synchronous import helpers -from pymongo.synchronous.collation import validate_collation_or_none -from pymongo.synchronous.common import validate_is_mapping, validate_list -from pymongo.synchronous.helpers import _gen_index_name, _index_document, _index_list -from pymongo.synchronous.typings import _CollationIn, _DocumentType, _Pipeline -from pymongo.write_concern import validate_boolean - -if TYPE_CHECKING: - from pymongo.synchronous.bulk import _Bulk - -_IS_SYNC = True - -# Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z''], or a dictionary -_IndexList = Union[ - Sequence[Union[str, Tuple[str, Union[int, str, Mapping[str, Any]]]]], Mapping[str, Any] -] -_IndexKeyHint = Union[str, _IndexList] - - -class _Op(str, enum.Enum): - ABORT = "abortTransaction" - AGGREGATE = "aggregate" - COMMIT = "commitTransaction" - COUNT = "count" - CREATE = "create" - CREATE_INDEXES = "createIndexes" - CREATE_SEARCH_INDEXES = "createSearchIndexes" - DELETE = "delete" - DISTINCT = "distinct" - DROP = "drop" - DROP_DATABASE = "dropDatabase" - DROP_INDEXES = "dropIndexes" - DROP_SEARCH_INDEXES = "dropSearchIndexes" - END_SESSIONS = "endSessions" - FIND_AND_MODIFY = "findAndModify" - FIND = "find" - INSERT = "insert" - LIST_COLLECTIONS = "listCollections" - LIST_INDEXES = "listIndexes" - LIST_SEARCH_INDEX = "listSearchIndexes" - LIST_DATABASES = "listDatabases" - UPDATE = "update" - UPDATE_INDEX = "updateIndex" - UPDATE_SEARCH_INDEX = "updateSearchIndex" - RENAME = "rename" - GETMORE = "getMore" - KILL_CURSORS = "killCursors" - TEST = "testOperation" - - -class InsertOne(Generic[_DocumentType]): - """Represents an insert_one operation.""" - - __slots__ = ("_doc",) - - def __init__(self, document: _DocumentType) -> None: - """Create an InsertOne instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param document: The document to insert. If the document is missing an - _id field one will be added. - """ - self._doc = document - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_insert(self._doc) # type: ignore[arg-type] - - def __repr__(self) -> str: - return f"InsertOne({self._doc!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return other._doc == self._doc - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteOne: - """Represents a delete_one operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteOne instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the document to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. 
versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 1, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteOne({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class DeleteMany: - """Represents a delete_many operation.""" - - __slots__ = ("_filter", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a DeleteMany instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the documents to delete. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.4 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - if filter is not None: - validate_is_mapping("filter", filter) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_delete( - self._filter, - 0, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __repr__(self) -> str: - return f"DeleteMany({self._filter!r}, {self._collation!r}, {self._hint!r})" - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return (other._filter, other._collation, other._hint) == ( - self._filter, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - -class ReplaceOne(Generic[_DocumentType]): - """Represents a replace_one operation.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - replacement: Union[_DocumentType, RawBSONDocument], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create a ReplaceOne instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the document to replace. - :param replacement: The new document. - :param upsert: If ``True``, perform an insert if no documents - match the filter. 
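These operation objects do nothing on their own; they are collected and handed to bulk_write, which funnels them through _add_to_bulk. Typical usage (database and collection names are illustrative):

    from pymongo import DeleteOne, InsertOne, MongoClient, ReplaceOne

    coll = MongoClient()["test"]["docs"]
    result = coll.bulk_write([
        InsertOne({"x": 1}),
        DeleteOne({"x": 1}),
        ReplaceOne({"w": 1}, {"z": 1}, upsert=True),
    ])
    print(result.inserted_count, result.deleted_count, result.upserted_count)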
- :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the ``hint`` option. - .. versionchanged:: 3.5 - Added the ``collation`` option. - """ - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - self._filter = filter - self._doc = replacement - self._upsert = upsert - self._collation = collation - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_replace( - self._filter, - self._doc, - self._upsert, - collation=validate_collation_or_none(self._collation), - hint=self._hint, - ) - - def __eq__(self, other: Any) -> bool: - if type(other) == type(self): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - self._hint, - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._hint, - ) - - -class _UpdateOp: - """Private base class for update operations.""" - - __slots__ = ("_filter", "_doc", "_upsert", "_collation", "_array_filters", "_hint") - - def __init__( - self, - filter: Mapping[str, Any], - doc: Union[Mapping[str, Any], _Pipeline], - upsert: bool, - collation: Optional[_CollationIn], - array_filters: Optional[list[Mapping[str, Any]]], - hint: Optional[_IndexKeyHint], - ): - if filter is not None: - validate_is_mapping("filter", filter) - if upsert is not None: - validate_boolean("upsert", upsert) - if array_filters is not None: - validate_list("array_filters", array_filters) - if hint is not None and not isinstance(hint, str): - self._hint: Union[str, dict[str, Any], None] = helpers._index_document(hint) - else: - self._hint = hint - - self._filter = filter - self._doc = doc - self._upsert = upsert - self._collation = collation - self._array_filters = array_filters - - def __eq__(self, other: object) -> bool: - if isinstance(other, type(self)): - return ( - other._filter, - other._doc, - other._upsert, - other._collation, - other._array_filters, - other._hint, - ) == ( - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - return NotImplemented - - def __repr__(self) -> str: - return "{}({!r}, {!r}, {!r}, {!r}, {!r}, {!r})".format( - self.__class__.__name__, - self._filter, - self._doc, - self._upsert, - self._collation, - self._array_filters, - self._hint, - ) - - -class UpdateOne(_UpdateOp): - """Represents an update_one operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Represents an update_one
operation. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the document to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - False, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class UpdateMany(_UpdateOp): - """Represents an update_many operation.""" - - __slots__ = () - - def __init__( - self, - filter: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - upsert: bool = False, - collation: Optional[_CollationIn] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Optional[_IndexKeyHint] = None, - ) -> None: - """Create an UpdateMany instance. - - For use with :meth:`~pymongo.collection.Collection.bulk_write`. - - :param filter: A query that matches the documents to update. - :param update: The modifications to apply. - :param upsert: If ``True``, perform an insert if no documents - match the filter. - :param collation: An instance of - :class:`~pymongo.collation.Collation`. - :param array_filters: A list of filters specifying which - array elements an update should apply. - :param hint: An index to use to support the query - predicate specified either by its string name, or in the same - format as passed to - :meth:`~pymongo.collection.Collection.create_index` (e.g. - ``[('field', ASCENDING)]``). This option is only supported on - MongoDB 4.2 and above. - - .. versionchanged:: 3.11 - Added the `hint` option. - .. versionchanged:: 3.9 - Added the ability to accept a pipeline as the `update`. - .. versionchanged:: 3.6 - Added the `array_filters` option. - .. versionchanged:: 3.5 - Added the `collation` option. - """ - super().__init__(filter, update, upsert, collation, array_filters, hint) - - def _add_to_bulk(self, bulkobj: _Bulk) -> None: - """Add this operation to the _Bulk instance `bulkobj`.""" - bulkobj.add_update( - self._filter, - self._doc, - True, - self._upsert, - collation=validate_collation_or_none(self._collation), - array_filters=self._array_filters, - hint=self._hint, - ) - - -class IndexModel: - """Represents an index to create.""" - - __slots__ = ("__document",) - - def __init__(self, keys: _IndexKeyHint, **kwargs: Any) -> None: - """Create an Index instance. - - For use with :meth:`~pymongo.collection.Collection.create_indexes`. - - Takes either a single key or a list containing (key, direction) pairs - or keys. 
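For orientation, create_indexes consumes these models directly. A typical call (key names are illustrative):

    from pymongo import ASCENDING, DESCENDING, IndexModel, MongoClient

    coll = MongoClient()["test"]["docs"]
    index1 = IndexModel([("hello", DESCENDING), ("world", ASCENDING)],
                        name="hello_world")
    index2 = IndexModel([("goodbye", DESCENDING)])
    coll.create_indexes([index1, index2])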
If no direction is given, :data:`~pymongo.ASCENDING` will - be assumed. - The key(s) must be an instance of :class:`str`, and the direction(s) must - be one of (:data:`~pymongo.ASCENDING`, :data:`~pymongo.DESCENDING`, - :data:`~pymongo.GEO2D`, :data:`~pymongo.GEOSPHERE`, - :data:`~pymongo.HASHED`, :data:`~pymongo.TEXT`). - - Valid options include, but are not limited to: - - - `name`: custom name to use for this index - if none is - given, a name will be generated. - - `unique`: if ``True``, creates a uniqueness constraint on the index. - - `background`: if ``True``, this index should be created in the - background. - - `sparse`: if ``True``, omit from the index any documents that lack - the indexed field. - - `bucketSize`: for use with geoHaystack indexes. - Number of documents to group together within a certain proximity - to a given longitude and latitude. - - `min`: minimum value for keys in a :data:`~pymongo.GEO2D` - index. - - `max`: maximum value for keys in a :data:`~pymongo.GEO2D` - index. - - `expireAfterSeconds`: Used to create an expiring (TTL) - collection. MongoDB will automatically delete documents from - this collection after <int> seconds. The indexed field must - be a UTC datetime or the data will not expire. - - `partialFilterExpression`: A document that specifies a filter for - a partial index. - - `collation`: An instance of :class:`~pymongo.collation.Collation` - that specifies the collation to use. - - `wildcardProjection`: Allows users to include or exclude specific - field paths from a `wildcard index`_ using the { "$**" : 1} key - pattern. Requires MongoDB >= 4.2. - - `hidden`: if ``True``, this index will be hidden from the query - planner and will not be evaluated as part of query plan - selection. Requires MongoDB >= 4.4. - - See the MongoDB documentation for a full list of supported options by - server version. - - :param keys: a single key or a list containing (key, direction) pairs - or keys specifying the index to create. - :param kwargs: any additional index creation - options (see the above list) should be passed as keyword - arguments. - - .. versionchanged:: 3.11 - Added the ``hidden`` option. - .. versionchanged:: 3.2 - Added the ``partialFilterExpression`` option to support partial - indexes. - - .. _wildcard index: https://mongodb.com/docs/master/core/index-wildcard/ - """ - keys = _index_list(keys) - if kwargs.get("name") is None: - kwargs["name"] = _gen_index_name(keys) - kwargs["key"] = _index_document(keys) - collation = validate_collation_or_none(kwargs.pop("collation", None)) - self.__document = kwargs - if collation is not None: - self.__document["collation"] = collation - - @property - def document(self) -> dict[str, Any]: - """An index document suitable for passing to the createIndexes - command. - """ - return self.__document - - -class SearchIndexModel: - """Represents a search index to create.""" - - __slots__ = ("__document",) - - def __init__( - self, - definition: Mapping[str, Any], - name: Optional[str] = None, - type: Optional[str] = None, - **kwargs: Any, - ) -> None: - """Create a Search Index instance. - - For use with :meth:`~pymongo.collection.Collection.create_search_index` and :meth:`~pymongo.collection.Collection.create_search_indexes`. - - :param definition: The definition for this index. - :param name: The name for this index, if present. - :param type: The type for this index which defaults to "search". Alternative values include "vectorSearch". - :param kwargs: Keyword arguments supplying any additional options. - - ..
note:: Search indexes require a MongoDB server version 7.0+ Atlas cluster. - .. versionadded:: 4.5 - .. versionchanged:: 4.7 - Added the type and kwargs arguments. - """ - self.__document: dict[str, Any] = {} - if name is not None: - self.__document["name"] = name - self.__document["definition"] = definition - if type is not None: - self.__document["type"] = type - self.__document.update(kwargs) - - @property - def document(self) -> Mapping[str, Any]: - """The document for this index.""" - return self.__document diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 391db4e7a7..1637406ee5 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -16,17 +16,14 @@ import collections import contextlib -import copy import logging import os -import platform import socket import ssl import sys import threading import time import weakref -from pathlib import Path from typing import ( TYPE_CHECKING, Any, @@ -39,9 +36,15 @@ Union, ) -import bson from bson import DEFAULT_CODEC_OPTIONS -from pymongo import __version__, _csot +from pymongo import _csot, helpers_shared +from pymongo.common import ( + MAX_BSON_SIZE, + MAX_MESSAGE_SIZE, + MAX_WIRE_VERSION, + MAX_WRITE_BATCH_SIZE, + ORDERED_TYPES, +) from pymongo.errors import ( # type:ignore[attr-defined] AutoReconnect, ConfigurationError, @@ -56,62 +59,46 @@ WaitQueueTimeoutError, _CertificateError, ) +from pymongo.hello import Hello +from pymongo.hello_compat import HelloCompat from pymongo.lock import _create_lock -from pymongo.network_layer import sendall -from pymongo.server_api import _add_to_command -from pymongo.server_type import SERVER_TYPE -from pymongo.socket_checker import SocketChecker -from pymongo.ssl_support import HAS_SNI, SSLError -from pymongo.synchronous import helpers -from pymongo.synchronous.client_session import _validate_session_write_concern -from pymongo.synchronous.common import ( - MAX_BSON_SIZE, - MAX_CONNECTING, - MAX_IDLE_TIME_SEC, - MAX_MESSAGE_SIZE, - MAX_POOL_SIZE, - MAX_WIRE_VERSION, - MAX_WRITE_BATCH_SIZE, - MIN_POOL_SIZE, - ORDERED_TYPES, - WAIT_QUEUE_TIMEOUT, -) -from pymongo.synchronous.hello import Hello -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.logger import ( +from pymongo.logger import ( _CONNECTION_LOGGER, _ConnectionStatusMessage, _debug_log, _verbose_connection_error_reason, ) -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( ConnectionCheckOutFailedReason, ConnectionClosedReason, - _EventListeners, ) +from pymongo.network_layer import sendall +from pymongo.pool_options import PoolOptions +from pymongo.read_preferences import ReadPreference +from pymongo.server_api import _add_to_command +from pymongo.server_type import SERVER_TYPE +from pymongo.socket_checker import SocketChecker +from pymongo.ssl_support import HAS_SNI, SSLError +from pymongo.synchronous.client_session import _validate_session_write_concern +from pymongo.synchronous.helpers import _handle_reauth from pymongo.synchronous.network import command, receive_message -from pymongo.synchronous.read_preferences import ReadPreference if TYPE_CHECKING: from bson import CodecOptions from bson.objectid import ObjectId - from pymongo.driver_info import DriverInfo - from pymongo.pyopenssl_context import SSLContext, _sslConn - from pymongo.read_concern import ReadConcern - from pymongo.server_api import ServerApi - from pymongo.synchronous.auth import MongoCredential, _AuthContext - from 
pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.compression_support import ( - CompressionSettings, + from pymongo.compression_support import ( SnappyContext, ZlibContext, ZstdContext, ) + from pymongo.pyopenssl_context import _sslConn + from pymongo.read_concern import ReadConcern + from pymongo.read_preferences import _ServerMode + from pymongo.synchronous.auth import _AuthContext + from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.message import _OpMsg, _OpReply from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.typings import ClusterTime, _Address, _CollationIn + from pymongo.typings import ClusterTime, _Address, _CollationIn from pymongo.write_concern import WriteConcern try: @@ -191,217 +178,6 @@ def _set_keepalive_times(sock: socket.socket) -> None: _set_tcp_option(sock, "TCP_KEEPCNT", _MAX_TCP_KEEPCNT) -_METADATA: dict[str, Any] = {"driver": {"name": "PyMongo", "version": __version__}} - -if sys.platform.startswith("linux"): - # platform.linux_distribution was deprecated in Python 3.5 - # and removed in Python 3.8. Starting in Python 3.5 it - # raises DeprecationWarning - # DeprecationWarning: dist() and linux_distribution() functions are deprecated in Python 3.5 - _name = platform.system() - _METADATA["os"] = { - "type": _name, - "name": _name, - "architecture": platform.machine(), - # Kernel version (e.g. 4.4.0-17-generic). - "version": platform.release(), - } -elif sys.platform == "darwin": - _METADATA["os"] = { - "type": platform.system(), - "name": platform.system(), - "architecture": platform.machine(), - # (mac|i|tv)OS(X) version (e.g. 10.11.6) instead of darwin - # kernel version. - "version": platform.mac_ver()[0], - } -elif sys.platform == "win32": - _METADATA["os"] = { - "type": platform.system(), - # "Windows XP", "Windows 7", "Windows 10", etc. - "name": " ".join((platform.system(), platform.release())), - "architecture": platform.machine(), - # Windows patch level (e.g. 5.1.2600-SP3) - "version": "-".join(platform.win32_ver()[1:3]), - } -elif sys.platform.startswith("java"): - _name, _ver, _arch = platform.java_ver()[-1] - _METADATA["os"] = { - # Linux, Windows 7, Mac OS X, etc. - "type": _name, - "name": _name, - # x86, x86_64, AMD64, etc. - "architecture": _arch, - # Linux kernel version, OSX version, etc. - "version": _ver, - } -else: - # Get potential alias (e.g. 
SunOS 5.11 becomes Solaris 2.11) - _aliased = platform.system_alias(platform.system(), platform.release(), platform.version()) - _METADATA["os"] = { - "type": platform.system(), - "name": " ".join([part for part in _aliased[:2] if part]), - "architecture": platform.machine(), - "version": _aliased[2], - } - -if platform.python_implementation().startswith("PyPy"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.pypy_version_info)), # type: ignore - "(Python %s)" % ".".join(map(str, sys.version_info)), - ) - ) -elif sys.platform.startswith("java"): - _METADATA["platform"] = " ".join( - ( - platform.python_implementation(), - ".".join(map(str, sys.version_info)), - "(%s)" % " ".join((platform.system(), platform.release())), - ) - ) -else: - _METADATA["platform"] = " ".join( - (platform.python_implementation(), ".".join(map(str, sys.version_info))) - ) - -DOCKER_ENV_PATH = "/.dockerenv" -ENV_VAR_K8S = "KUBERNETES_SERVICE_HOST" - -RUNTIME_NAME_DOCKER = "docker" -ORCHESTRATOR_NAME_K8S = "kubernetes" - - -def get_container_env_info() -> dict[str, str]: - """Returns the runtime and orchestrator of a container. - If neither value is present, the metadata client.env.container field will be omitted.""" - container = {} - - if Path(DOCKER_ENV_PATH).exists(): - container["runtime"] = RUNTIME_NAME_DOCKER - if os.getenv(ENV_VAR_K8S): - container["orchestrator"] = ORCHESTRATOR_NAME_K8S - - return container - - -def _is_lambda() -> bool: - if os.getenv("AWS_LAMBDA_RUNTIME_API"): - return True - env = os.getenv("AWS_EXECUTION_ENV") - if env: - return env.startswith("AWS_Lambda_") - return False - - -def _is_azure_func() -> bool: - return bool(os.getenv("FUNCTIONS_WORKER_RUNTIME")) - - -def _is_gcp_func() -> bool: - return bool(os.getenv("K_SERVICE") or os.getenv("FUNCTION_NAME")) - - -def _is_vercel() -> bool: - return bool(os.getenv("VERCEL")) - - -def _is_faas() -> bool: - return _is_lambda() or _is_azure_func() or _is_gcp_func() or _is_vercel() - - -def _getenv_int(key: str) -> Optional[int]: - """Like os.getenv but returns an int, or None if the value is missing/malformed.""" - val = os.getenv(key) - if not val: - return None - try: - return int(val) - except ValueError: - return None - - -def _metadata_env() -> dict[str, Any]: - env: dict[str, Any] = {} - container = get_container_env_info() - if container: - env["container"] = container - # Skip if multiple (or no) envs are matched. 
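The env-detection helpers above feed the check that follows; it uses tuple.count to require that exactly one FaaS provider matched before naming an environment. A standalone illustration of the idiom (flag values invented):

flags = (True, False, False, False)  # e.g. only _is_lambda() returned True
print(flags.count(True) != 1)        # False: exactly one provider, so detection proceeds
flags = (True, True, False, False)   # ambiguous: two providers appear to match
print(flags.count(True) != 1)        # True: env-name detection is skipped entirely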
- if (_is_lambda(), _is_azure_func(), _is_gcp_func(), _is_vercel()).count(True) != 1: - return env - if _is_lambda(): - env["name"] = "aws.lambda" - region = os.getenv("AWS_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("AWS_LAMBDA_FUNCTION_MEMORY_SIZE") - if memory_mb is not None: - env["memory_mb"] = memory_mb - elif _is_azure_func(): - env["name"] = "azure.func" - elif _is_gcp_func(): - env["name"] = "gcp.func" - region = os.getenv("FUNCTION_REGION") - if region: - env["region"] = region - memory_mb = _getenv_int("FUNCTION_MEMORY_MB") - if memory_mb is not None: - env["memory_mb"] = memory_mb - timeout_sec = _getenv_int("FUNCTION_TIMEOUT_SEC") - if timeout_sec is not None: - env["timeout_sec"] = timeout_sec - elif _is_vercel(): - env["name"] = "vercel" - region = os.getenv("VERCEL_REGION") - if region: - env["region"] = region - return env - - -_MAX_METADATA_SIZE = 512 - - -# See: https://github.com/mongodb/specifications/blob/5112bcc/source/mongodb-handshake/handshake.rst#limitations -def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: - """Perform metadata truncation.""" - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 1. Omit fields from env except env.name. - env_name = metadata.get("env", {}).get("name") - if env_name: - metadata["env"] = {"name": env_name} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 2. Omit fields from os except os.type. - os_type = metadata.get("os", {}).get("type") - if os_type: - metadata["os"] = {"type": os_type} - if len(bson.encode(metadata)) <= _MAX_METADATA_SIZE: - return - # 3. Omit the env document entirely. - metadata.pop("env", None) - encoded_size = len(bson.encode(metadata)) - if encoded_size <= _MAX_METADATA_SIZE: - return - # 4. Truncate platform. - overflow = encoded_size - _MAX_METADATA_SIZE - plat = metadata.get("platform", "") - if plat: - plat = plat[:-overflow] - if plat: - metadata["platform"] = plat - else: - metadata.pop("platform", None) - - -# If the first getaddrinfo call of this interpreter's life is on a thread, -# while the main thread holds the import lock, getaddrinfo deadlocks trying -# to import the IDNA codec. Import it here, where presumably we're on the -# main thread, to avoid the deadlock. See PYTHON-607. -"foo".encode("idna") - - def _raise_connection_failure( address: Any, error: Exception, @@ -462,238 +238,6 @@ def format_timeout_details(details: Optional[dict[str, float]]) -> str: return result -class PoolOptions: - """Read only connection pool options for a MongoClient. - - Should not be instantiated directly by application developers. 
Access - a client's pool options via - :attr:`~pymongo.client_options.ClientOptions.pool_options` instead:: - - pool_opts = client.options.pool_options - pool_opts.max_pool_size - pool_opts.min_pool_size - - """ - - __slots__ = ( - "__max_pool_size", - "__min_pool_size", - "__max_idle_time_seconds", - "__connect_timeout", - "__socket_timeout", - "__wait_queue_timeout", - "__ssl_context", - "__tls_allow_invalid_hostnames", - "__event_listeners", - "__appname", - "__driver", - "__metadata", - "__compression_settings", - "__max_connecting", - "__pause_enabled", - "__server_api", - "__load_balanced", - "__credentials", - ) - - def __init__( - self, - max_pool_size: int = MAX_POOL_SIZE, - min_pool_size: int = MIN_POOL_SIZE, - max_idle_time_seconds: Optional[int] = MAX_IDLE_TIME_SEC, - connect_timeout: Optional[float] = None, - socket_timeout: Optional[float] = None, - wait_queue_timeout: Optional[int] = WAIT_QUEUE_TIMEOUT, - ssl_context: Optional[SSLContext] = None, - tls_allow_invalid_hostnames: bool = False, - event_listeners: Optional[_EventListeners] = None, - appname: Optional[str] = None, - driver: Optional[DriverInfo] = None, - compression_settings: Optional[CompressionSettings] = None, - max_connecting: int = MAX_CONNECTING, - pause_enabled: bool = True, - server_api: Optional[ServerApi] = None, - load_balanced: Optional[bool] = None, - credentials: Optional[MongoCredential] = None, - ): - self.__max_pool_size = max_pool_size - self.__min_pool_size = min_pool_size - self.__max_idle_time_seconds = max_idle_time_seconds - self.__connect_timeout = connect_timeout - self.__socket_timeout = socket_timeout - self.__wait_queue_timeout = wait_queue_timeout - self.__ssl_context = ssl_context - self.__tls_allow_invalid_hostnames = tls_allow_invalid_hostnames - self.__event_listeners = event_listeners - self.__appname = appname - self.__driver = driver - self.__compression_settings = compression_settings - self.__max_connecting = max_connecting - self.__pause_enabled = pause_enabled - self.__server_api = server_api - self.__load_balanced = load_balanced - self.__credentials = credentials - self.__metadata = copy.deepcopy(_METADATA) - if appname: - self.__metadata["application"] = {"name": appname} - - # Combine the "driver" MongoClient option with PyMongo's info, like: - # { - # 'driver': { - # 'name': 'PyMongo|MyDriver', - # 'version': '4.2.0|1.2.3', - # }, - # 'platform': 'CPython 3.8.0|MyPlatform' - # } - if driver: - if driver.name: - self.__metadata["driver"]["name"] = "{}|{}".format( - _METADATA["driver"]["name"], - driver.name, - ) - if driver.version: - self.__metadata["driver"]["version"] = "{}|{}".format( - _METADATA["driver"]["version"], - driver.version, - ) - if driver.platform: - self.__metadata["platform"] = "{}|{}".format(_METADATA["platform"], driver.platform) - - env = _metadata_env() - if env: - self.__metadata["env"] = env - - _truncate_metadata(self.__metadata) - - @property - def _credentials(self) -> Optional[MongoCredential]: - """A :class:`~pymongo.auth.MongoCredentials` instance or None.""" - return self.__credentials - - @property - def non_default_options(self) -> dict[str, Any]: - """The non-default options this pool was created with. - - Added for CMAP's :class:`PoolCreatedEvent`. 
- """ - opts = {} - if self.__max_pool_size != MAX_POOL_SIZE: - opts["maxPoolSize"] = self.__max_pool_size - if self.__min_pool_size != MIN_POOL_SIZE: - opts["minPoolSize"] = self.__min_pool_size - if self.__max_idle_time_seconds != MAX_IDLE_TIME_SEC: - assert self.__max_idle_time_seconds is not None - opts["maxIdleTimeMS"] = self.__max_idle_time_seconds * 1000 - if self.__wait_queue_timeout != WAIT_QUEUE_TIMEOUT: - assert self.__wait_queue_timeout is not None - opts["waitQueueTimeoutMS"] = self.__wait_queue_timeout * 1000 - if self.__max_connecting != MAX_CONNECTING: - opts["maxConnecting"] = self.__max_connecting - return opts - - @property - def max_pool_size(self) -> float: - """The maximum allowable number of concurrent connections to each - connected server. Requests to a server will block if there are - `maxPoolSize` outstanding connections to the requested server. - Defaults to 100. Cannot be 0. - - When a server's pool has reached `max_pool_size`, operations for that - server block waiting for a socket to be returned to the pool. If - ``waitQueueTimeoutMS`` is set, a blocked operation will raise - :exc:`~pymongo.errors.ConnectionFailure` after a timeout. - By default ``waitQueueTimeoutMS`` is not set. - """ - return self.__max_pool_size - - @property - def min_pool_size(self) -> int: - """The minimum required number of concurrent connections that the pool - will maintain to each connected server. Default is 0. - """ - return self.__min_pool_size - - @property - def max_connecting(self) -> int: - """The maximum number of concurrent connection creation attempts per - pool. Defaults to 2. - """ - return self.__max_connecting - - @property - def pause_enabled(self) -> bool: - return self.__pause_enabled - - @property - def max_idle_time_seconds(self) -> Optional[int]: - """The maximum number of seconds that a connection can remain - idle in the pool before being removed and replaced. Defaults to - `None` (no limit). - """ - return self.__max_idle_time_seconds - - @property - def connect_timeout(self) -> Optional[float]: - """How long a connection can take to be opened before timing out.""" - return self.__connect_timeout - - @property - def socket_timeout(self) -> Optional[float]: - """How long a send or receive on a socket can take before timing out.""" - return self.__socket_timeout - - @property - def wait_queue_timeout(self) -> Optional[int]: - """How long a thread will wait for a socket from the pool if the pool - has no free sockets. 
- """ - return self.__wait_queue_timeout - - @property - def _ssl_context(self) -> Optional[SSLContext]: - """An SSLContext instance or None.""" - return self.__ssl_context - - @property - def tls_allow_invalid_hostnames(self) -> bool: - """If True skip ssl.match_hostname.""" - return self.__tls_allow_invalid_hostnames - - @property - def _event_listeners(self) -> Optional[_EventListeners]: - """An instance of pymongo.monitoring._EventListeners.""" - return self.__event_listeners - - @property - def appname(self) -> Optional[str]: - """The application name, for sending with hello in server handshake.""" - return self.__appname - - @property - def driver(self) -> Optional[DriverInfo]: - """Driver name and version, for sending with hello in handshake.""" - return self.__driver - - @property - def _compression_settings(self) -> Optional[CompressionSettings]: - return self.__compression_settings - - @property - def metadata(self) -> dict[str, Any]: - """A dict of metadata about the application, driver, os, and platform.""" - return self.__metadata.copy() - - @property - def server_api(self) -> Optional[ServerApi]: - """A pymongo.server_api.ServerApi or None.""" - return self.__server_api - - @property - def load_balanced(self) -> Optional[bool]: - """True if this Pool is configured in load balanced mode.""" - return self.__load_balanced - - class _CancellationContext: def __init__(self) -> None: self._cancelled = False @@ -926,7 +470,7 @@ def _next_reply(self) -> dict[str, Any]: self.more_to_come = reply.more_to_come unpacked_docs = reply.unpack_response() response_doc = unpacked_docs[0] - helpers._check_command_response(response_doc, self.max_wire_version) + helpers_shared._check_command_response(response_doc, self.max_wire_version) return response_doc @_handle_reauth @@ -1079,7 +623,7 @@ def write_command( result = reply.command_response(codec_options) # Raises NotPrimaryError or OperationFailure. - helpers._check_command_response(result, self.max_wire_version) + helpers_shared._check_command_response(result, self.max_wire_version) return result def authenticate(self, reauthenticate: bool = False) -> None: diff --git a/pymongo/synchronous/read_preferences.py b/pymongo/synchronous/read_preferences.py deleted file mode 100644 index 464256c343..0000000000 --- a/pymongo/synchronous/read_preferences.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2012-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License", -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -"""Utilities for choosing which member of a replica set to read from.""" - -from __future__ import annotations - -from collections import abc -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence - -from pymongo.errors import ConfigurationError -from pymongo.synchronous import max_staleness_selectors -from pymongo.synchronous.server_selectors import ( - member_with_tags_server_selector, - secondary_with_tags_server_selector, -) - -if TYPE_CHECKING: - from pymongo.synchronous.server_selectors import Selection - from pymongo.synchronous.topology_description import TopologyDescription - -_IS_SYNC = True - -_PRIMARY = 0 -_PRIMARY_PREFERRED = 1 -_SECONDARY = 2 -_SECONDARY_PREFERRED = 3 -_NEAREST = 4 - - -_MONGOS_MODES = ( - "primary", - "primaryPreferred", - "secondary", - "secondaryPreferred", - "nearest", -) - -_Hedge = Mapping[str, Any] -_TagSets = Sequence[Mapping[str, Any]] - - -def _validate_tag_sets(tag_sets: Optional[_TagSets]) -> Optional[_TagSets]: - """Validate tag sets for a MongoClient.""" - if tag_sets is None: - return tag_sets - - if not isinstance(tag_sets, (list, tuple)): - raise TypeError(f"Tag sets {tag_sets!r} invalid, must be a sequence") - if len(tag_sets) == 0: - raise ValueError( - f"Tag sets {tag_sets!r} invalid, must be None or contain at least one set of tags" - ) - - for tags in tag_sets: - if not isinstance(tags, abc.Mapping): - raise TypeError( - f"Tag set {tags!r} invalid, must be an instance of dict, " - "bson.son.SON or other type that inherits from " - "collection.Mapping" - ) - - return list(tag_sets) - - -def _invalid_max_staleness_msg(max_staleness: Any) -> str: - return "maxStalenessSeconds must be a positive integer, not %s" % max_staleness - - -# Some duplication with common.py to avoid import cycle. 
-def _validate_max_staleness(max_staleness: Any) -> int: - """Validate max_staleness.""" - if max_staleness == -1: - return -1 - - if not isinstance(max_staleness, int): - raise TypeError(_invalid_max_staleness_msg(max_staleness)) - - if max_staleness <= 0: - raise ValueError(_invalid_max_staleness_msg(max_staleness)) - - return max_staleness - - -def _validate_hedge(hedge: Optional[_Hedge]) -> Optional[_Hedge]: - """Validate hedge.""" - if hedge is None: - return None - - if not isinstance(hedge, dict): - raise TypeError(f"hedge must be a dictionary, not {hedge!r}") - - return hedge - - -class _ServerMode: - """Base class for all read preferences.""" - - __slots__ = ("__mongos_mode", "__mode", "__tag_sets", "__max_staleness", "__hedge") - - def __init__( - self, - mode: int, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - self.__mongos_mode = _MONGOS_MODES[mode] - self.__mode = mode - self.__tag_sets = _validate_tag_sets(tag_sets) - self.__max_staleness = _validate_max_staleness(max_staleness) - self.__hedge = _validate_hedge(hedge) - - @property - def name(self) -> str: - """The name of this read preference.""" - return self.__class__.__name__ - - @property - def mongos_mode(self) -> str: - """The mongos mode of this read preference.""" - return self.__mongos_mode - - @property - def document(self) -> dict[str, Any]: - """Read preference as a document.""" - doc: dict[str, Any] = {"mode": self.__mongos_mode} - if self.__tag_sets not in (None, [{}]): - doc["tags"] = self.__tag_sets - if self.__max_staleness != -1: - doc["maxStalenessSeconds"] = self.__max_staleness - if self.__hedge not in (None, {}): - doc["hedge"] = self.__hedge - return doc - - @property - def mode(self) -> int: - """The mode of this read preference instance.""" - return self.__mode - - @property - def tag_sets(self) -> _TagSets: - """Set ``tag_sets`` to a list of dictionaries like [{'dc': 'ny'}] to - read only from members whose ``dc`` tag has the value ``"ny"``. - To specify a priority-order for tag sets, provide a list of - tag sets: ``[{'dc': 'ny'}, {'dc': 'la'}, {}]``. A final, empty tag - set, ``{}``, means "read from any member that matches the mode, - ignoring tags." MongoClient tries each set of tags in turn - until it finds a set of tags with at least one matching member. - For example, to only send a query to an analytic node:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - Or using :class:`SecondaryPreferred`:: - - SecondaryPreferred(tag_sets=[{"node":"analytics"}]) - - .. seealso:: `Data-Center Awareness - `_ - """ - return list(self.__tag_sets) if self.__tag_sets else [{}] - - @property - def max_staleness(self) -> int: - """The maximum estimated length of time (in seconds) a replica set - secondary can fall behind the primary in replication before it will - no longer be selected for operations, or -1 for no maximum. - """ - return self.__max_staleness - - @property - def hedge(self) -> Optional[_Hedge]: - """The read preference ``hedge`` parameter. - - A dictionary that configures how the server will perform hedged reads. - It consists of the following keys: - - - ``enabled``: Enables or disables hedged reads in sharded clusters. - - Hedged reads are automatically enabled in MongoDB 4.4+ when using a - ``nearest`` read preference. 
To explicitly enable hedged reads, set - the ``enabled`` key to ``true``:: - - >>> Nearest(hedge={'enabled': True}) - - To explicitly disable hedged reads, set the ``enabled`` key to - ``False``:: - - >>> Nearest(hedge={'enabled': False}) - - .. versionadded:: 3.11 - """ - return self.__hedge - - @property - def min_wire_version(self) -> int: - """The wire protocol version the server must support. - - Some read preferences impose version requirements on all servers (e.g. - maxStalenessSeconds requires MongoDB 3.4 / maxWireVersion 5). - - All servers' maxWireVersion must be at least this read preference's - `min_wire_version`, or the driver raises - :exc:`~pymongo.errors.ConfigurationError`. - """ - return 0 if self.__max_staleness == -1 else 5 - - def __repr__(self) -> str: - return "{}(tag_sets={!r}, max_staleness={!r}, hedge={!r})".format( - self.name, - self.__tag_sets, - self.__max_staleness, - self.__hedge, - ) - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return ( - self.mode == other.mode - and self.tag_sets == other.tag_sets - and self.max_staleness == other.max_staleness - and self.hedge == other.hedge - ) - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __getstate__(self) -> dict[str, Any]: - """Return value of object for pickling. - - Needed explicitly because __slots__() defined. - """ - return { - "mode": self.__mode, - "tag_sets": self.__tag_sets, - "max_staleness": self.__max_staleness, - "hedge": self.__hedge, - } - - def __setstate__(self, value: Mapping[str, Any]) -> None: - """Restore from pickling.""" - self.__mode = value["mode"] - self.__mongos_mode = _MONGOS_MODES[self.__mode] - self.__tag_sets = _validate_tag_sets(value["tag_sets"]) - self.__max_staleness = _validate_max_staleness(value["max_staleness"]) - self.__hedge = _validate_hedge(value["hedge"]) - - def __call__(self, selection: Selection) -> Selection: - return selection - - -class Primary(_ServerMode): - """Primary read preference. - - * When directly connected to one mongod queries are allowed if the server - is standalone or a replica set primary. - * When connected to a mongos queries are sent to the primary of a shard. - * When connected to a replica set queries are sent to the primary of - the replica set. - """ - - __slots__ = () - - def __init__(self) -> None: - super().__init__(_PRIMARY) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return selection.primary_selection - - def __repr__(self) -> str: - return "Primary()" - - def __eq__(self, other: Any) -> bool: - if isinstance(other, _ServerMode): - return other.mode == _PRIMARY - return NotImplemented - - -class PrimaryPreferred(_ServerMode): - """PrimaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are sent to the primary of a shard if - available, otherwise a shard secondary. - * When connected to a replica set queries are sent to the primary if - available, otherwise a secondary. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to an available secondary until the - primary of the replica set is discovered. - - :param tag_sets: The :attr:`~tag_sets` to use if the primary is not - available. 
- :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` to use if the primary is not available. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_PRIMARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - if selection.primary: - return selection.primary_selection - else: - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class Secondary(_ServerMode): - """Secondary read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries. An error is raised if no secondaries are available. - * When connected to a replica set queries are distributed among - secondaries. An error is raised if no secondaries are available. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class SecondaryPreferred(_ServerMode): - """SecondaryPreferred read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among shard - secondaries, or the shard primary if no secondary is available. - * When connected to a replica set queries are distributed among - secondaries, or the primary if no secondary is available. - - .. note:: When a :class:`~pymongo.mongo_client.MongoClient` is first - created reads will be routed to the primary of the replica set until - an available secondary is discovered. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. 
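How these constructor options surface on the wire can be checked through the document property shown earlier; a sketch (tag values invented):

from pymongo.read_preferences import SecondaryPreferred

pref = SecondaryPreferred(tag_sets=[{"dc": "ny"}, {}], max_staleness=120)
print(pref.document)
# {'mode': 'secondaryPreferred', 'tags': [{'dc': 'ny'}, {}], 'maxStalenessSeconds': 120}
print(pref.min_wire_version)  # 5: maxStalenessSeconds needs MongoDB 3.4+ (maxWireVersion 5)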
- """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_SECONDARY_PREFERRED, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - secondaries = secondary_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - if secondaries: - return secondaries - else: - return selection.primary_selection - - -class Nearest(_ServerMode): - """Nearest read preference. - - * When directly connected to one mongod queries are allowed to standalone - servers, to a replica set primary, or to replica set secondaries. - * When connected to a mongos queries are distributed among all members of - a shard. - * When connected to a replica set queries are distributed among all - members. - - :param tag_sets: The :attr:`~tag_sets` for this read preference. - :param max_staleness: (integer, in seconds) The maximum estimated - length of time a replica set secondary can fall behind the primary in - replication before it will no longer be selected for operations. - Default -1, meaning no maximum. If it is set, it must be at least - 90 seconds. - :param hedge: The :attr:`~hedge` for this read preference. - - .. versionchanged:: 3.11 - Added ``hedge`` parameter. - """ - - __slots__ = () - - def __init__( - self, - tag_sets: Optional[_TagSets] = None, - max_staleness: int = -1, - hedge: Optional[_Hedge] = None, - ) -> None: - super().__init__(_NEAREST, tag_sets, max_staleness, hedge) - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to Selection.""" - return member_with_tags_server_selector( - self.tag_sets, max_staleness_selectors.select(self.max_staleness, selection) - ) - - -class _AggWritePref: - """Agg $out/$merge write preference. - - * If there are readable servers and there is any pre-5.0 server, use - primary read preference. - * Otherwise use `pref` read preference. - - :param pref: The read preference to use on MongoDB 5.0+. - """ - - __slots__ = ("pref", "effective_pref") - - def __init__(self, pref: _ServerMode): - self.pref = pref - self.effective_pref: _ServerMode = ReadPreference.PRIMARY - - def selection_hook(self, topology_description: TopologyDescription) -> None: - common_wv = topology_description.common_wire_version - if ( - topology_description.has_readable_server(ReadPreference.PRIMARY_PREFERRED) - and common_wv - and common_wv < 13 - ): - self.effective_pref = ReadPreference.PRIMARY - else: - self.effective_pref = self.pref - - def __call__(self, selection: Selection) -> Selection: - """Apply this read preference to a Selection.""" - return self.effective_pref(selection) - - def __repr__(self) -> str: - return f"_AggWritePref(pref={self.pref!r})" - - # Proxy other calls to the effective_pref so that _AggWritePref can be - # used in place of an actual read preference. 
- def __getattr__(self, name: str) -> Any: - return getattr(self.effective_pref, name) - - -_ALL_READ_PREFERENCES = (Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest) - - -def make_read_preference( - mode: int, tag_sets: Optional[_TagSets], max_staleness: int = -1 -) -> _ServerMode: - if mode == _PRIMARY: - if tag_sets not in (None, [{}]): - raise ConfigurationError("Read preference primary cannot be combined with tags") - if max_staleness != -1: - raise ConfigurationError( - "Read preference primary cannot be combined with maxStalenessSeconds" - ) - return Primary() - return _ALL_READ_PREFERENCES[mode](tag_sets, max_staleness) # type: ignore - - -_MODES = ( - "PRIMARY", - "PRIMARY_PREFERRED", - "SECONDARY", - "SECONDARY_PREFERRED", - "NEAREST", -) - - -class ReadPreference: - """An enum that defines some commonly used read preference modes. - - Apps can also create a custom read preference, for example:: - - Nearest(tag_sets=[{"node":"analytics"}]) - - See :doc:`/examples/high_availability` for code examples. - - A read preference is used in three cases: - - :class:`~pymongo.mongo_client.MongoClient` connected to a single mongod: - - - ``PRIMARY``: Queries are allowed if the server is standalone or a replica - set primary. - - All other modes allow queries to standalone servers, to a replica set - primary, or to replica set secondaries. - - :class:`~pymongo.mongo_client.MongoClient` initialized with the - ``replicaSet`` option: - - - ``PRIMARY``: Read from the primary. This is the default, and provides the - strongest consistency. If no primary is available, raise - :class:`~pymongo.errors.AutoReconnect`. - - - ``PRIMARY_PREFERRED``: Read from the primary if available, or if there is - none, read from a secondary. - - - ``SECONDARY``: Read from a secondary. If no secondary is available, - raise :class:`~pymongo.errors.AutoReconnect`. - - - ``SECONDARY_PREFERRED``: Read from a secondary if available, otherwise - from the primary. - - - ``NEAREST``: Read from any member. - - :class:`~pymongo.mongo_client.MongoClient` connected to a mongos, with a - sharded cluster of replica sets: - - - ``PRIMARY``: Read from the primary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - This is the default. - - - ``PRIMARY_PREFERRED``: Read from the primary of the shard, or if there is - none, read from a secondary of the shard. - - - ``SECONDARY``: Read from a secondary of the shard, or raise - :class:`~pymongo.errors.OperationFailure` if there is none. - - - ``SECONDARY_PREFERRED``: Read from a secondary of the shard if available, - otherwise from the shard primary. - - - ``NEAREST``: Read from any shard member. - """ - - PRIMARY = Primary() - PRIMARY_PREFERRED = PrimaryPreferred() - SECONDARY = Secondary() - SECONDARY_PREFERRED = SecondaryPreferred() - NEAREST = Nearest() - - -def read_pref_mode_from_name(name: str) -> int: - """Get the read preference mode from mongos/uri name.""" - return _MONGOS_MODES.index(name) - - -class MovingAverage: - """Tracks an exponentially-weighted moving average.""" - - average: Optional[float] - - def __init__(self) -> None: - self.average = None - - def add_sample(self, sample: float) -> None: - if sample < 0: - # Likely system time change while waiting for hello response - # and not using time.monotonic. Ignore it, the next one will - # probably be valid. - return - if self.average is None: - self.average = sample - else: - # The Server Selection Spec requires an exponentially weighted - # average with alpha = 0.2. 
- self.average = 0.8 * self.average + 0.2 * sample - - def get(self) -> Optional[float]: - """Get the calculated average, or None if no samples yet.""" - return self.average - - def reset(self) -> None: - self.average = None diff --git a/pymongo/synchronous/response.py b/pymongo/synchronous/response.py index 94fd4df508..03b88fc77d 100644 --- a/pymongo/synchronous/response.py +++ b/pymongo/synchronous/response.py @@ -22,7 +22,7 @@ from pymongo.synchronous.message import _OpMsg, _OpReply from pymongo.synchronous.pool import Connection - from pymongo.synchronous.typings import _Address, _DocumentOut + from pymongo.typings import _Address, _DocumentOut _IS_SYNC = True @@ -103,7 +103,7 @@ def __init__( :param data: A network response message. :param address: (host, port) of the source server. :param conn: The Connection used for the initial query. :param request_id: The request id of this operation. :param duration: The duration of the operation. :param from_command: If the response is the result of a db command. @@ -117,7 +117,7 @@ def __init__( @property def conn(self) -> Connection: """The Connection used for the initial query. The server will send batches on this socket, without waiting for getMores from the client, until the result set is exhausted or there diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 4c79569992..883c802e07 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -28,23 +28,24 @@ from bson import _decode_all_selective from pymongo.errors import NotPrimaryError, OperationFailure -from pymongo.synchronous.helpers import _check_command_response, _handle_reauth -from pymongo.synchronous.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.helpers_shared import _check_command_response +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.response import PinnedResponse, Response +from pymongo.synchronous.helpers import _handle_reauth from pymongo.synchronous.message import _convert_exception, _GetMore, _OpMsg, _Query -from pymongo.synchronous.response import PinnedResponse, Response if TYPE_CHECKING: from queue import Queue from weakref import ReferenceType from bson.objectid import ObjectId + from pymongo.monitoring import _EventListeners + from pymongo.read_preferences import _ServerMode + from pymongo.server_description import ServerDescription from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler from pymongo.synchronous.monitor import Monitor - from pymongo.synchronous.monitoring import _EventListeners from pymongo.synchronous.pool import Connection, Pool - from pymongo.synchronous.read_preferences import _ServerMode - from pymongo.synchronous.server_description import ServerDescription - from pymongo.synchronous.typings import _DocumentOut + from pymongo.typings import _DocumentOut _IS_SYNC = True @@ -121,7 +122,7 @@ def run_operation( cursors. Can raise ConnectionFailure, OperationFailure, etc. - :param conn: A Connection instance. :param operation: A _Query or _GetMore object. - :param read_preference: The read preference to use. - :param listeners: Instance of _EventListeners or None.
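Stepping back to the MovingAverage implementation removed above: with alpha = 0.2 the update rule is avg = 0.8 * avg + 0.2 * sample, which can be sanity-checked by hand (samples invented):

avg = None
for sample in (10.0, 20.0, 30.0):
    # The first sample seeds the average; later samples blend in at alpha = 0.2.
    avg = sample if avg is None else 0.8 * avg + 0.2 * sample
    print(avg)  # 10.0, then 12.0, then 15.6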
diff --git a/pymongo/synchronous/server_description.py b/pymongo/synchronous/server_description.py deleted file mode 100644 index 4a23fc1293..0000000000 --- a/pymongo/synchronous/server_description.py +++ /dev/null @@ -1,301 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Represent one server the driver is connected to.""" -from __future__ import annotations - -import time -import warnings -from typing import Any, Mapping, Optional - -from bson import EPOCH_NAIVE -from bson.objectid import ObjectId -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.hello import Hello -from pymongo.synchronous.typings import ClusterTime, _Address - -_IS_SYNC = True - - -class ServerDescription: - """Immutable representation of one server. - - :param address: A (host, port) pair - :param hello: Optional Hello instance - :param round_trip_time: Optional float - :param error: Optional, the last error attempting to connect to the server - :param min_round_trip_time: Optional float, the min latency from the most recent samples - """ - - __slots__ = ( - "_address", - "_server_type", - "_all_hosts", - "_tags", - "_replica_set_name", - "_primary", - "_max_bson_size", - "_max_message_size", - "_max_write_batch_size", - "_min_wire_version", - "_max_wire_version", - "_round_trip_time", - "_min_round_trip_time", - "_me", - "_is_writable", - "_is_readable", - "_ls_timeout_minutes", - "_error", - "_set_version", - "_election_id", - "_cluster_time", - "_last_write_date", - "_last_update_time", - "_topology_version", - ) - - def __init__( - self, - address: _Address, - hello: Optional[Hello] = None, - round_trip_time: Optional[float] = None, - error: Optional[Exception] = None, - min_round_trip_time: float = 0.0, - ) -> None: - self._address = address - if not hello: - hello = Hello({}) - - self._server_type = hello.server_type - self._all_hosts = hello.all_hosts - self._tags = hello.tags - self._replica_set_name = hello.replica_set_name - self._primary = hello.primary - self._max_bson_size = hello.max_bson_size - self._max_message_size = hello.max_message_size - self._max_write_batch_size = hello.max_write_batch_size - self._min_wire_version = hello.min_wire_version - self._max_wire_version = hello.max_wire_version - self._set_version = hello.set_version - self._election_id = hello.election_id - self._cluster_time = hello.cluster_time - self._is_writable = hello.is_writable - self._is_readable = hello.is_readable - self._ls_timeout_minutes = hello.logical_session_timeout_minutes - self._round_trip_time = round_trip_time - self._min_round_trip_time = min_round_trip_time - self._me = hello.me - self._last_update_time = time.monotonic() - self._error = error - self._topology_version = hello.topology_version - if error: - details = getattr(error, "details", None) - if isinstance(details, dict): - self._topology_version = details.get("topologyVersion") - - self._last_write_date: Optional[float] - if hello.last_write_date: - # Convert from datetime to seconds.
- delta = hello.last_write_date - EPOCH_NAIVE - self._last_write_date = delta.total_seconds() - else: - self._last_write_date = None - - @property - def address(self) -> _Address: - """The address (host, port) of this server.""" - return self._address - - @property - def server_type(self) -> int: - """The type of this server.""" - return self._server_type - - @property - def server_type_name(self) -> str: - """The server type as a human readable string. - - .. versionadded:: 3.4 - """ - return SERVER_TYPE._fields[self._server_type] - - @property - def all_hosts(self) -> set[tuple[str, int]]: - """List of hosts, passives, and arbiters known to this server.""" - return self._all_hosts - - @property - def tags(self) -> Mapping[str, Any]: - return self._tags - - @property - def replica_set_name(self) -> Optional[str]: - """Replica set name or None.""" - return self._replica_set_name - - @property - def primary(self) -> Optional[tuple[str, int]]: - """This server's opinion about who the primary is, or None.""" - return self._primary - - @property - def max_bson_size(self) -> int: - return self._max_bson_size - - @property - def max_message_size(self) -> int: - return self._max_message_size - - @property - def max_write_batch_size(self) -> int: - return self._max_write_batch_size - - @property - def min_wire_version(self) -> int: - return self._min_wire_version - - @property - def max_wire_version(self) -> int: - return self._max_wire_version - - @property - def set_version(self) -> Optional[int]: - return self._set_version - - @property - def election_id(self) -> Optional[ObjectId]: - return self._election_id - - @property - def cluster_time(self) -> Optional[ClusterTime]: - return self._cluster_time - - @property - def election_tuple(self) -> tuple[Optional[int], Optional[ObjectId]]: - warnings.warn( - "'election_tuple' is deprecated, use 'set_version' and 'election_id' instead", - DeprecationWarning, - stacklevel=2, - ) - return self._set_version, self._election_id - - @property - def me(self) -> Optional[tuple[str, int]]: - return self._me - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - return self._ls_timeout_minutes - - @property - def last_write_date(self) -> Optional[float]: - return self._last_write_date - - @property - def last_update_time(self) -> float: - return self._last_update_time - - @property - def round_trip_time(self) -> Optional[float]: - """The current average latency or None.""" - # This override is for unittesting only! 
- if self._address in self._host_to_round_trip_time: - return self._host_to_round_trip_time[self._address] - - return self._round_trip_time - - @property - def min_round_trip_time(self) -> float: - """The min latency from the most recent samples.""" - return self._min_round_trip_time - - @property - def error(self) -> Optional[Exception]: - """The last error attempting to connect to the server, or None.""" - return self._error - - @property - def is_writable(self) -> bool: - return self._is_writable - - @property - def is_readable(self) -> bool: - return self._is_readable - - @property - def mongos(self) -> bool: - return self._server_type == SERVER_TYPE.Mongos - - @property - def is_server_type_known(self) -> bool: - return self.server_type != SERVER_TYPE.Unknown - - @property - def retryable_writes_supported(self) -> bool: - """Checks if this server supports retryable writes.""" - return ( - self._server_type in (SERVER_TYPE.Mongos, SERVER_TYPE.RSPrimary) - ) or self._server_type == SERVER_TYPE.LoadBalancer - - @property - def retryable_reads_supported(self) -> bool: - """Checks if this server supports retryable reads.""" - return self._max_wire_version >= 6 - - @property - def topology_version(self) -> Optional[Mapping[str, Any]]: - return self._topology_version - - def to_unknown(self, error: Optional[Exception] = None) -> ServerDescription: - unknown = ServerDescription(self.address, error=error) - unknown._topology_version = self.topology_version - return unknown - - def __eq__(self, other: Any) -> bool: - if isinstance(other, ServerDescription): - return ( - (self._address == other.address) - and (self._server_type == other.server_type) - and (self._min_wire_version == other.min_wire_version) - and (self._max_wire_version == other.max_wire_version) - and (self._me == other.me) - and (self._all_hosts == other.all_hosts) - and (self._tags == other.tags) - and (self._replica_set_name == other.replica_set_name) - and (self._set_version == other.set_version) - and (self._election_id == other.election_id) - and (self._primary == other.primary) - and (self._ls_timeout_minutes == other.logical_session_timeout_minutes) - and (self._error == other.error) - ) - - return NotImplemented - - def __ne__(self, other: Any) -> bool: - return not self == other - - def __repr__(self) -> str: - errmsg = "" - if self.error: - errmsg = f", error={self.error!r}" - return "<{} {} server_type: {}, rtt: {}{}>".format( - self.__class__.__name__, - self.address, - self.server_type_name, - self.round_trip_time, - errmsg, - ) - - # For unittesting only. Use under no circumstances! - _host_to_round_trip_time: dict = {} diff --git a/pymongo/synchronous/server_selectors.py b/pymongo/synchronous/server_selectors.py deleted file mode 100644 index a3b2066ab0..0000000000 --- a/pymongo/synchronous/server_selectors.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright 2014-2016 MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License.
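One subtlety of the ServerDescription equality above: round-trip time, last-write date, and update time are deliberately excluded from __eq__, so two descriptions of the same server compare equal across latency samples. A sketch (host is hypothetical, import path per this refactor):

from pymongo.server_description import ServerDescription

a = ServerDescription(("db1.example.com", 27017), round_trip_time=0.010)
b = ServerDescription(("db1.example.com", 27017), round_trip_time=0.050)
print(a == b)               # True: RTT plays no part in equality
print(a.to_unknown() == a)  # True here too: both are Unknown with no error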
- -"""Criteria to select some ServerDescriptions from a TopologyDescription.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, TypeVar, cast - -from pymongo.server_type import SERVER_TYPE - -if TYPE_CHECKING: - from pymongo.synchronous.server_description import ServerDescription - from pymongo.synchronous.topology_description import TopologyDescription - -_IS_SYNC = True - -T = TypeVar("T") -TagSet = Mapping[str, Any] -TagSets = Sequence[TagSet] - - -class Selection: - """Input or output of a server selector function.""" - - @classmethod - def from_topology_description(cls, topology_description: TopologyDescription) -> Selection: - known_servers = topology_description.known_servers - primary = None - for sd in known_servers: - if sd.server_type == SERVER_TYPE.RSPrimary: - primary = sd - break - - return Selection( - topology_description, - topology_description.known_servers, - topology_description.common_wire_version, - primary, - ) - - def __init__( - self, - topology_description: TopologyDescription, - server_descriptions: list[ServerDescription], - common_wire_version: Optional[int], - primary: Optional[ServerDescription], - ): - self.topology_description = topology_description - self.server_descriptions = server_descriptions - self.primary = primary - self.common_wire_version = common_wire_version - - def with_server_descriptions(self, server_descriptions: list[ServerDescription]) -> Selection: - return Selection( - self.topology_description, server_descriptions, self.common_wire_version, self.primary - ) - - def secondary_with_max_last_write_date(self) -> Optional[ServerDescription]: - secondaries = secondary_server_selector(self) - if secondaries.server_descriptions: - return max( - secondaries.server_descriptions, key=lambda sd: cast(float, sd.last_write_date) - ) - return None - - @property - def primary_selection(self) -> Selection: - primaries = [self.primary] if self.primary else [] - return self.with_server_descriptions(primaries) - - @property - def heartbeat_frequency(self) -> int: - return self.topology_description.heartbeat_frequency - - @property - def topology_type(self) -> int: - return self.topology_description.topology_type - - def __bool__(self) -> bool: - return bool(self.server_descriptions) - - def __getitem__(self, item: int) -> ServerDescription: - return self.server_descriptions[item] - - -def any_server_selector(selection: T) -> T: - return selection - - -def readable_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.is_readable] - ) - - -def writable_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.is_writable] - ) - - -def secondary_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSSecondary] - ) - - -def arbiter_server_selector(selection: Selection) -> Selection: - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if s.server_type == SERVER_TYPE.RSArbiter] - ) - - -def writable_preferred_server_selector(selection: Selection) -> Selection: - """Like PrimaryPreferred but doesn't use tags or latency.""" - return writable_server_selector(selection) or secondary_server_selector(selection) - - -def apply_single_tag_set(tag_set: TagSet, selection: Selection) 
-> Selection: - """All servers matching one tag set. - - A tag set is a dict. A server matches if its tags are a superset: - A server tagged {'a': '1', 'b': '2'} matches the tag set {'a': '1'}. - - The empty tag set {} matches any server. - """ - - def tags_match(server_tags: Mapping[str, Any]) -> bool: - for key, value in tag_set.items(): - if key not in server_tags or server_tags[key] != value: - return False - - return True - - return selection.with_server_descriptions( - [s for s in selection.server_descriptions if tags_match(s.tags)] - ) - - -def apply_tag_sets(tag_sets: TagSets, selection: Selection) -> Selection: - """All servers matching a list of tag sets. - - tag_sets is a list of dicts. The empty tag set {} matches any server, - and may be provided at the end of the list as a fallback. So - [{'a': 'value'}, {}] expresses a preference for servers tagged - {'a': 'value'}, but accepts any server if none matches the first - preference. - """ - for tag_set in tag_sets: - with_tag_set = apply_single_tag_set(tag_set, selection) - if with_tag_set: - return with_tag_set - - return selection.with_server_descriptions([]) - - -def secondary_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection: - """All near-enough secondaries matching the tag sets.""" - return apply_tag_sets(tag_sets, secondary_server_selector(selection)) - - -def member_with_tags_server_selector(tag_sets: TagSets, selection: Selection) -> Selection: - """All near-enough members matching the tag sets.""" - return apply_tag_sets(tag_sets, readable_server_selector(selection)) diff --git a/pymongo/synchronous/settings.py b/pymongo/synchronous/settings.py index f51b5307aa..8719e86083 100644 --- a/pymongo/synchronous/settings.py +++ b/pymongo/synchronous/settings.py @@ -20,12 +20,14 @@ from typing import Any, Collection, Optional, Type, Union from bson.objectid import ObjectId +from pymongo import common +from pymongo.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT from pymongo.errors import ConfigurationError -from pymongo.synchronous import common, monitor, pool -from pymongo.synchronous.common import LOCAL_THRESHOLD_MS, SERVER_SELECTION_TIMEOUT -from pymongo.synchronous.pool import Pool, PoolOptions -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE, _ServerSelector +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.synchronous import monitor, pool +from pymongo.synchronous.pool import Pool +from pymongo.topology_description import TOPOLOGY_TYPE, _ServerSelector _IS_SYNC = True diff --git a/pymongo/synchronous/srv_resolver.py b/pymongo/synchronous/srv_resolver.py deleted file mode 100644 index e5481305e0..0000000000 --- a/pymongo/synchronous/srv_resolver.py +++ /dev/null @@ -1,149 +0,0 @@ -# Copyright 2019-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License.
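A note on the tag-set selectors above: a server matches a tag set when its tags are a superset of that set, and the empty tag set matches everything. A minimal self-contained sketch of that rule (illustrative only, not part of the patch):

    from typing import Any, Mapping

    def tags_match(server_tags: Mapping[str, Any], tag_set: Mapping[str, Any]) -> bool:
        # A server matches when every key/value pair of the tag set appears in its tags.
        return all(key in server_tags and server_tags[key] == value for key, value in tag_set.items())

    assert tags_match({"a": "1", "b": "2"}, {"a": "1"})  # superset matches
    assert tags_match({"a": "1"}, {})  # the empty tag set matches any server
    assert not tags_match({"a": "1"}, {"a": "2"})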
- -"""Support for resolving hosts and options from mongodb+srv:// URIs.""" -from __future__ import annotations - -import ipaddress -import random -from typing import TYPE_CHECKING, Any, Optional, Union - -from pymongo.errors import ConfigurationError -from pymongo.synchronous.common import CONNECT_TIMEOUT - -if TYPE_CHECKING: - from dns import resolver - -_IS_SYNC = True - - -def _have_dnspython() -> bool: - try: - import dns # noqa: F401 - - return True - except ImportError: - return False - - -# dnspython can return bytes or str from various parts -# of its API depending on version. We always want str. -def maybe_decode(text: Union[str, bytes]) -> str: - if isinstance(text, bytes): - return text.decode() - return text - - -# PYTHON-2667 Lazily call dns.resolver methods for compatibility with eventlet. -def _resolve(*args: Any, **kwargs: Any) -> resolver.Answer: - from dns import resolver - - if hasattr(resolver, "resolve"): - # dnspython >= 2 - return resolver.resolve(*args, **kwargs) - # dnspython 1.X - return resolver.query(*args, **kwargs) - - -_INVALID_HOST_MSG = ( - "Invalid URI host: %s is not a valid hostname for 'mongodb+srv://'. " - "Did you mean to use 'mongodb://'?" -) - - -class _SrvResolver: - def __init__( - self, - fqdn: str, - connect_timeout: Optional[float], - srv_service_name: str, - srv_max_hosts: int = 0, - ): - self.__fqdn = fqdn - self.__srv = srv_service_name - self.__connect_timeout = connect_timeout or CONNECT_TIMEOUT - self.__srv_max_hosts = srv_max_hosts or 0 - # Validate the fully qualified domain name. - try: - ipaddress.ip_address(fqdn) - raise ConfigurationError(_INVALID_HOST_MSG % ("an IP address",)) - except ValueError: - pass - - try: - self.__plist = self.__fqdn.split(".")[1:] - except Exception: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) from None - self.__slen = len(self.__plist) - if self.__slen < 2: - raise ConfigurationError(_INVALID_HOST_MSG % (fqdn,)) - - def get_options(self) -> Optional[str]: - from dns import resolver - - try: - results = _resolve(self.__fqdn, "TXT", lifetime=self.__connect_timeout) - except (resolver.NoAnswer, resolver.NXDOMAIN): - # No TXT records - return None - except Exception as exc: - raise ConfigurationError(str(exc)) from None - if len(results) > 1: - raise ConfigurationError("Only one TXT record is supported") - return (b"&".join([b"".join(res.strings) for res in results])).decode("utf-8") - - def _resolve_uri(self, encapsulate_errors: bool) -> resolver.Answer: - try: - results = _resolve( - "_" + self.__srv + "._tcp." + self.__fqdn, "SRV", lifetime=self.__connect_timeout - ) - except Exception as exc: - if not encapsulate_errors: - # Raise the original error. - raise - # Else, raise all errors as ConfigurationError. 
- raise ConfigurationError(str(exc)) from None - return results - - def _get_srv_response_and_hosts( - self, encapsulate_errors: bool - ) -> tuple[resolver.Answer, list[tuple[str, Any]]]: - results = self._resolve_uri(encapsulate_errors) - - # Construct address tuples - nodes = [ - (maybe_decode(res.target.to_text(omit_final_dot=True)), res.port) for res in results - ] - - # Validate hosts - for node in nodes: - try: - nlist = node[0].lower().split(".")[1:][-self.__slen :] - except Exception: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") from None - if self.__plist != nlist: - raise ConfigurationError(f"Invalid SRV host: {node[0]}") - if self.__srv_max_hosts: - nodes = random.sample(nodes, min(self.__srv_max_hosts, len(nodes))) - return results, nodes - - def get_hosts(self) -> list[tuple[str, Any]]: - _, nodes = self._get_srv_response_and_hosts(True) - return nodes - - def get_hosts_and_min_ttl(self) -> tuple[list[tuple[str, Any]], int]: - results, nodes = self._get_srv_response_and_hosts(False) - rrset = results.rrset - ttl = rrset.ttl if rrset else 0 - return nodes, ttl diff --git a/pymongo/synchronous/topology.py b/pymongo/synchronous/topology.py index d76cef7bfc..6832b54f69 100644 --- a/pymongo/synchronous/topology.py +++ b/pymongo/synchronous/topology.py @@ -27,7 +27,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Mapping, Optional, cast -from pymongo import _csot, helpers_constants +from pymongo import _csot, common, helpers_shared from pymongo.errors import ( ConnectionFailure, InvalidOperation, @@ -38,27 +38,28 @@ ServerSelectionTimeoutError, WriteError, ) +from pymongo.hello import Hello from pymongo.lock import _create_lock -from pymongo.synchronous import common, periodic_executor -from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool -from pymongo.synchronous.hello import Hello -from pymongo.synchronous.logger import ( +from pymongo.logger import ( _SERVER_SELECTION_LOGGER, _debug_log, _ServerSelectionStatusMessage, ) -from pymongo.synchronous.monitor import SrvMonitor -from pymongo.synchronous.pool import Pool, PoolOptions -from pymongo.synchronous.server import Server -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import ( +from pymongo.pool_options import PoolOptions +from pymongo.server_description import ServerDescription +from pymongo.server_selectors import ( Selection, any_server_selector, arbiter_server_selector, secondary_server_selector, writable_server_selector, ) -from pymongo.synchronous.topology_description import ( +from pymongo.synchronous import periodic_executor +from pymongo.synchronous.client_session import _ServerSession, _ServerSessionPool +from pymongo.synchronous.monitor import SrvMonitor +from pymongo.synchronous.pool import Pool +from pymongo.synchronous.server import Server +from pymongo.topology_description import ( SRV_POLLING_TOPOLOGIES, TOPOLOGY_TYPE, TopologyDescription, @@ -69,7 +70,7 @@ if TYPE_CHECKING: from bson import ObjectId from pymongo.synchronous.settings import TopologySettings - from pymongo.synchronous.typings import ClusterTime, _Address + from pymongo.typings import ClusterTime, _Address _IS_SYNC = True @@ -788,8 +789,8 @@ def _handle_error(self, address: _Address, err_ctx: _ErrorContext) -> None: # Default error code if one does not exist. 
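# (10107 is MongoDB's legacy NotWritablePrimary / "not master" error code.)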
default = 10107 if isinstance(error, NotPrimaryError) else None err_code = error.details.get("code", default) # type: ignore[union-attr] - if err_code in helpers_constants._NOT_PRIMARY_CODES: - is_shutting_down = err_code in helpers_constants._SHUTDOWN_CODES + if err_code in helpers_shared._NOT_PRIMARY_CODES: + is_shutting_down = err_code in helpers_shared._SHUTDOWN_CODES # Mark server Unknown, clear the pool, and request check. if not self._settings.load_balanced: self._process_change(ServerDescription(address, error=error)) diff --git a/pymongo/synchronous/topology_description.py b/pymongo/synchronous/topology_description.py deleted file mode 100644 index 961b9da8d5..0000000000 --- a/pymongo/synchronous/topology_description.py +++ /dev/null @@ -1,678 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - -"""Represent a deployment of MongoDB servers.""" -from __future__ import annotations - -from random import sample -from typing import ( - Any, - Callable, - List, - Mapping, - MutableMapping, - NamedTuple, - Optional, - cast, -) - -from bson.min_key import MinKey -from bson.objectid import ObjectId -from pymongo.errors import ConfigurationError -from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import common -from pymongo.synchronous.read_preferences import ReadPreference, _AggWritePref, _ServerMode -from pymongo.synchronous.server_description import ServerDescription -from pymongo.synchronous.server_selectors import Selection -from pymongo.synchronous.typings import _Address - -_IS_SYNC = True - - -# Enumeration for various kinds of MongoDB cluster topologies. -class _TopologyType(NamedTuple): - Single: int - ReplicaSetNoPrimary: int - ReplicaSetWithPrimary: int - Sharded: int - Unknown: int - LoadBalanced: int - - -TOPOLOGY_TYPE = _TopologyType(*range(6)) - -# Topologies compatible with SRV record polling. -SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded) - - -_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]] - - -class TopologyDescription: - def __init__( - self, - topology_type: int, - server_descriptions: dict[_Address, ServerDescription], - replica_set_name: Optional[str], - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], - topology_settings: Any, - ) -> None: - """Representation of a deployment of MongoDB servers. 
- - :param topology_type: initial type - :param server_descriptions: dict of (address, ServerDescription) for - all seeds - :param replica_set_name: replica set name or None - :param max_set_version: greatest setVersion seen from a primary, or None - :param max_election_id: greatest electionId seen from a primary, or None - :param topology_settings: a TopologySettings - """ - self._topology_type = topology_type - self._replica_set_name = replica_set_name - self._server_descriptions = server_descriptions - self._max_set_version = max_set_version - self._max_election_id = max_election_id - - # The heartbeat_frequency is used in staleness estimates. - self._topology_settings = topology_settings - - # Is PyMongo compatible with all servers' wire protocols? - self._incompatible_err = None - if self._topology_type != TOPOLOGY_TYPE.LoadBalanced: - self._init_incompatible_err() - - # Server Discovery And Monitoring Spec: Whenever a client updates the - # TopologyDescription from a hello response, it MUST set - # TopologyDescription.logicalSessionTimeoutMinutes to the smallest - # logicalSessionTimeoutMinutes value among ServerDescriptions of all - # data-bearing server types. If any have a null - # logicalSessionTimeoutMinutes, then - # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null. - readable_servers = self.readable_servers - if not readable_servers: - self._ls_timeout_minutes = None - elif any(s.logical_session_timeout_minutes is None for s in readable_servers): - self._ls_timeout_minutes = None - else: - self._ls_timeout_minutes = min( # type: ignore[type-var] - s.logical_session_timeout_minutes for s in readable_servers - ) - - def _init_incompatible_err(self) -> None: - """Internal compatibility check for non-load balanced topologies.""" - for s in self._server_descriptions.values(): - if not s.is_server_type_known: - continue - - # s.min/max_wire_version is the server's wire protocol. - # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports. - server_too_new = ( - # Server too new. - s.min_wire_version is not None - and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION - ) - - server_too_old = ( - # Server too old. - s.max_wire_version is not None - and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION - ) - - if server_too_new: - self._incompatible_err = ( - "Server at %s:%d requires wire version %d, but this " # type: ignore - "version of PyMongo only supports up to %d." - % ( - s.address[0], - s.address[1] or 0, - s.min_wire_version, - common.MAX_SUPPORTED_WIRE_VERSION, - ) - ) - - elif server_too_old: - self._incompatible_err = ( - "Server at %s:%d reports wire version %d, but this " # type: ignore - "version of PyMongo requires at least %d (MongoDB %s)." - % ( - s.address[0], - s.address[1] or 0, - s.max_wire_version, - common.MIN_SUPPORTED_WIRE_VERSION, - common.MIN_SUPPORTED_SERVER_VERSION, - ) - ) - - break - - def check_compatible(self) -> None: - """Raise ConfigurationError if any server is incompatible. - - A server is incompatible if its wire protocol version range does not - overlap with PyMongo's.
- """ - if self._incompatible_err: - raise ConfigurationError(self._incompatible_err) - - def has_server(self, address: _Address) -> bool: - return address in self._server_descriptions - - def reset_server(self, address: _Address) -> TopologyDescription: - """A copy of this description, with one server marked Unknown.""" - unknown_sd = self._server_descriptions[address].to_unknown() - return updated_topology_description(self, unknown_sd) - - def reset(self) -> TopologyDescription: - """A copy of this description, with all servers marked Unknown.""" - if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - else: - topology_type = self._topology_type - - # The default ServerDescription's type is Unknown. - sds = {address: ServerDescription(address) for address in self._server_descriptions} - - return TopologyDescription( - topology_type, - sds, - self._replica_set_name, - self._max_set_version, - self._max_election_id, - self._topology_settings, - ) - - def server_descriptions(self) -> dict[_Address, ServerDescription]: - """dict of (address, - :class:`~pymongo.server_description.ServerDescription`). - """ - return self._server_descriptions.copy() - - @property - def topology_type(self) -> int: - """The type of this topology.""" - return self._topology_type - - @property - def topology_type_name(self) -> str: - """The topology type as a human readable string. - - .. versionadded:: 3.4 - """ - return TOPOLOGY_TYPE._fields[self._topology_type] - - @property - def replica_set_name(self) -> Optional[str]: - """The replica set name.""" - return self._replica_set_name - - @property - def max_set_version(self) -> Optional[int]: - """Greatest setVersion seen from a primary, or None.""" - return self._max_set_version - - @property - def max_election_id(self) -> Optional[ObjectId]: - """Greatest electionId seen from a primary, or None.""" - return self._max_election_id - - @property - def logical_session_timeout_minutes(self) -> Optional[int]: - """Minimum logical session timeout, or None.""" - return self._ls_timeout_minutes - - @property - def known_servers(self) -> list[ServerDescription]: - """List of Servers of types besides Unknown.""" - return [s for s in self._server_descriptions.values() if s.is_server_type_known] - - @property - def has_known_servers(self) -> bool: - """Whether there are any Servers of types besides Unknown.""" - return any(s for s in self._server_descriptions.values() if s.is_server_type_known) - - @property - def readable_servers(self) -> list[ServerDescription]: - """List of readable Servers.""" - return [s for s in self._server_descriptions.values() if s.is_readable] - - @property - def common_wire_version(self) -> Optional[int]: - """Minimum of all servers' max wire versions, or None.""" - servers = self.known_servers - if servers: - return min(s.max_wire_version for s in self.known_servers) - - return None - - @property - def heartbeat_frequency(self) -> int: - return self._topology_settings.heartbeat_frequency - - @property - def srv_max_hosts(self) -> int: - return self._topology_settings._srv_max_hosts - - def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]: - if not selection: - return [] - round_trip_times: list[float] = [] - for server in selection.server_descriptions: - if server.round_trip_time is None: - config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}" - raise 
ConfigurationError(config_err_msg) - round_trip_times.append(server.round_trip_time) - # Round trip time in seconds. - fastest = min(round_trip_times) - threshold = self._topology_settings.local_threshold_ms / 1000.0 - return [ - s - for s in selection.server_descriptions - if (cast(float, s.round_trip_time) - fastest) <= threshold - ] - - def apply_selector( - self, - selector: Any, - address: Optional[_Address] = None, - custom_selector: Optional[_ServerSelector] = None, - ) -> list[ServerDescription]: - """List of servers matching the provided selector(s). - - :param selector: a callable that takes a Selection as input and returns - a Selection as output. For example, an instance of a read - preference from :mod:`~pymongo.read_preferences`. - :param address: A server address to select. - :param custom_selector: A callable that augments server - selection rules. Accepts a list of - :class:`~pymongo.server_description.ServerDescription` objects and - returns a list of server descriptions that should be considered - suitable for the desired operation. - - .. versionadded:: 3.4 - """ - if getattr(selector, "min_wire_version", 0): - common_wv = self.common_wire_version - if common_wv and common_wv < selector.min_wire_version: - raise ConfigurationError( - "%s requires min wire version %d, but topology's min" - " wire version is %d" % (selector, selector.min_wire_version, common_wv) - ) - - if isinstance(selector, _AggWritePref): - selector.selection_hook(self) - - if self.topology_type == TOPOLOGY_TYPE.Unknown: - return [] - elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced): - # Ignore selectors for standalone and load balancer mode. - return self.known_servers - if address: - # Ignore selectors when explicit address is requested. - description = self.server_descriptions().get(address) - return [description] if description else [] - - selection = Selection.from_topology_description(self) - # Ignore read preference for sharded clusters. - if self.topology_type != TOPOLOGY_TYPE.Sharded: - selection = selector(selection) - - # Apply custom selector followed by localThresholdMS. - if custom_selector is not None and selection: - selection = selection.with_server_descriptions( - custom_selector(selection.server_descriptions) - ) - return self._apply_local_threshold(selection) - - def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool: - """Does this topology have any readable servers available matching the - given read preference? - - :param read_preference: an instance of a read preference from - :mod:`~pymongo.read_preferences`. Defaults to - :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. - - .. note:: When connected directly to a single server this method - always returns ``True``. - - .. versionadded:: 3.4 - """ - common.validate_read_preference("read_preference", read_preference) - return any(self.apply_selector(read_preference)) - - def has_writable_server(self) -> bool: - """Does this topology have a writable server available? - - .. note:: When connected directly to a single server this method - always returns ``True``. - - .. versionadded:: 3.4 - """ - return self.has_readable_server(ReadPreference.PRIMARY) - - def __repr__(self) -> str: - # Sort the servers by address.
- servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) - return "<{} id: {}, topology_type: {}, servers: {!r}>".format( - self.__class__.__name__, - self._topology_settings._topology_id, - self.topology_type_name, - servers, - ) - - -# If topology type is Unknown and we receive a hello response, what should -# the new topology type be? -_SERVER_TYPE_TO_TOPOLOGY_TYPE = { - SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, - SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, - SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, - # Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out. -} - - -def updated_topology_description( - topology_description: TopologyDescription, server_description: ServerDescription -) -> TopologyDescription: - """Return an updated copy of a TopologyDescription. - - :param topology_description: the current TopologyDescription - :param server_description: a new ServerDescription that resulted from - a hello call - - Called after attempting (successfully or not) to call hello on the - server at server_description.address. Does not modify topology_description. - """ - address = server_description.address - - # These values will be updated, if necessary, to form the new - # TopologyDescription. - topology_type = topology_description.topology_type - set_name = topology_description.replica_set_name - max_set_version = topology_description.max_set_version - max_election_id = topology_description.max_election_id - server_type = server_description.server_type - - # Don't mutate the original dict of server descriptions; copy it. - sds = topology_description.server_descriptions() - - # Replace this server's description with the new one. - sds[address] = server_description - - if topology_type == TOPOLOGY_TYPE.Single: - # Set server type to Unknown if replica set name does not match. - if set_name is not None and set_name != server_description.replica_set_name: - error = ConfigurationError( - "client is configured to connect to a replica set named " - "'{}' but this node belongs to a set named '{}'".format( - set_name, server_description.replica_set_name - ) - ) - sds[address] = server_description.to_unknown(error=error) - # Single type never changes. - return TopologyDescription( - TOPOLOGY_TYPE.Single, - sds, - set_name, - max_set_version, - max_election_id, - topology_description._topology_settings, - ) - - if topology_type == TOPOLOGY_TYPE.Unknown: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer): - if len(topology_description._topology_settings.seeds) == 1: - topology_type = TOPOLOGY_TYPE.Single - else: - # Remove standalone from Topology when given multiple seeds. 
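- # (A standalone cannot belong to a replica set or sharded cluster,
- # so it is dropped from a multi-seed topology.)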
- sds.pop(address) - elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost): - topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type] - - if topology_type == TOPOLOGY_TYPE.Sharded: - if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown): - sds.pop(address) - - elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): - sds.pop(address) - - elif server_type == SERVER_TYPE.RSPrimary: - (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( - sds, set_name, server_description, max_set_version, max_election_id - ) - - elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): - topology_type, set_name = _update_rs_no_primary_from_member( - sds, set_name, server_description - ) - - elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: - if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): - sds.pop(address) - topology_type = _check_has_primary(sds) - - elif server_type == SERVER_TYPE.RSPrimary: - (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( - sds, set_name, server_description, max_set_version, max_election_id - ) - - elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): - topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description) - - else: - # Server type is Unknown or RSGhost: did we just lose the primary? - topology_type = _check_has_primary(sds) - - # Return updated copy. - return TopologyDescription( - topology_type, - sds, - set_name, - max_set_version, - max_election_id, - topology_description._topology_settings, - ) - - -def _updated_topology_description_srv_polling( - topology_description: TopologyDescription, seedlist: list[tuple[str, Any]] -) -> TopologyDescription: - """Return an updated copy of a TopologyDescription. - - :param topology_description: the current TopologyDescription - :param seedlist: a list of new seeds discovered by polling the SRV record - """ - assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES - # Create a copy of the server descriptions. - sds = topology_description.server_descriptions() - - # If seeds haven't changed, don't do anything. - if set(sds.keys()) == set(seedlist): - return topology_description - - # Remove SDs corresponding to servers no longer part of the SRV record. - for address in list(sds.keys()): - if address not in seedlist: - sds.pop(address) - - if topology_description.srv_max_hosts != 0: - new_hosts = set(seedlist) - set(sds.keys()) - n_to_add = topology_description.srv_max_hosts - len(sds) - if n_to_add > 0: - seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts))) - else: - seedlist = [] - # Add SDs corresponding to servers recently added to the SRV record.
- for address in seedlist: - if address not in sds: - sds[address] = ServerDescription(address) - return TopologyDescription( - topology_description.topology_type, - sds, - topology_description.replica_set_name, - topology_description.max_set_version, - topology_description.max_election_id, - topology_description._topology_settings, - ) - - -def _update_rs_from_primary( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, - max_set_version: Optional[int], - max_election_id: Optional[ObjectId], -) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]: - """Update topology description from a primary's hello response. - - Pass in a dict of ServerDescriptions, current replica set name, the - ServerDescription we are processing, and the TopologyDescription's - max_set_version and max_election_id if any. - - Returns (new topology type, new replica_set_name, new max_set_version, - new max_election_id). - """ - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - # We found a primary but it doesn't have the replica_set_name - # provided by the user. - sds.pop(server_description.address) - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - - if server_description.max_wire_version is None or server_description.max_wire_version < 17: - new_election_tuple: tuple = (server_description.set_version, server_description.election_id) - max_election_tuple: tuple = (max_set_version, max_election_id) - if None not in new_election_tuple: - if None not in max_election_tuple and new_election_tuple < max_election_tuple: - # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - max_election_id = server_description.election_id - - if server_description.set_version is not None and ( - max_set_version is None or server_description.set_version > max_set_version - ): - max_set_version = server_description.set_version - else: - new_election_tuple = server_description.election_id, server_description.set_version - max_election_tuple = max_election_id, max_set_version - new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple) - max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple) - if new_election_safe < max_election_safe: - # Stale primary, set to type Unknown. - sds[server_description.address] = server_description.to_unknown() - return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id - else: - max_election_id = server_description.election_id - max_set_version = server_description.set_version - - # We've heard from the primary. Is it the same primary as before? - for server in sds.values(): - if ( - server.server_type is SERVER_TYPE.RSPrimary - and server.address != server_description.address - ): - # Reset old primary's type to Unknown. - sds[server.address] = server.to_unknown() - - # There can be only one prior primary. - break - - # Discover new hosts from this primary's response. - for new_address in server_description.all_hosts: - if new_address not in sds: - sds[new_address] = ServerDescription(new_address) - - # Remove hosts not in the response. 
- for addr in set(sds) - server_description.all_hosts: - sds.pop(addr) - - # If the host list differs from the seed list, we may not have a primary - # after all. - return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) - - -def _update_rs_with_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> int: - """RS with known primary. Process a response from a non-primary. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns new topology type. - """ - assert replica_set_name is not None - - if replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - elif server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - # Had this member been the primary? - return _check_has_primary(sds) - - -def _update_rs_no_primary_from_member( - sds: MutableMapping[_Address, ServerDescription], - replica_set_name: Optional[str], - server_description: ServerDescription, -) -> tuple[int, Optional[str]]: - """RS without known primary. Update from a non-primary's response. - - Pass in a dict of ServerDescriptions, current replica set name, and the - ServerDescription we are processing. - - Returns (new topology type, new replica_set_name). - """ - topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary - if replica_set_name is None: - replica_set_name = server_description.replica_set_name - - elif replica_set_name != server_description.replica_set_name: - sds.pop(server_description.address) - return topology_type, replica_set_name - - # This isn't the primary's response, so don't remove any servers - # it doesn't report. Only add new servers. - for address in server_description.all_hosts: - if address not in sds: - sds[address] = ServerDescription(address) - - if server_description.me and server_description.address != server_description.me: - sds.pop(server_description.address) - - return topology_type, replica_set_name - - -def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: - """Current topology type is ReplicaSetWithPrimary. Is primary still known? - - Pass in a dict of ServerDescriptions. - - Returns new topology type. - """ - for s in sds.values(): - if s.server_type == SERVER_TYPE.RSPrimary: - return TOPOLOGY_TYPE.ReplicaSetWithPrimary - else: # noqa: PLW0120 - return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/pymongo/synchronous/typings.py b/pymongo/synchronous/typings.py deleted file mode 100644 index bc3fb0938f..0000000000 --- a/pymongo/synchronous/typings.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright 2022-Present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
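Aside on `_check_has_primary` above: the `for`/`else` (hence the `noqa: PLW0120`) is equivalent to this more conventional spelling; a sketch against the shared modules this patch introduces, not a proposed change:

    from pymongo.server_type import SERVER_TYPE
    from pymongo.topology_description import TOPOLOGY_TYPE

    def check_has_primary(sds) -> int:
        # ReplicaSetWithPrimary iff any known server is currently the primary.
        if any(s.server_type == SERVER_TYPE.RSPrimary for s in sds.values()):
            return TOPOLOGY_TYPE.ReplicaSetWithPrimary
        return TOPOLOGY_TYPE.ReplicaSetNoPrimary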
- -"""Type aliases used by PyMongo""" -from __future__ import annotations - -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) - -from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg - -if TYPE_CHECKING: - from pymongo.synchronous.collation import Collation - -_IS_SYNC = True - -# Common Shared Types. -_Address = Tuple[str, Optional[int]] -_CollationIn = Union[Mapping[str, Any], "Collation"] -_Pipeline = Sequence[Mapping[str, Any]] -ClusterTime = Mapping[str, Any] - -_T = TypeVar("_T") - - -def strip_optional(elem: Optional[_T]) -> _T: - """This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T - while inside a list comprehension. - """ - assert elem is not None - return elem - - -__all__ = [ - "_DocumentOut", - "_DocumentType", - "_DocumentTypeArg", - "_Address", - "_CollationIn", - "_Pipeline", - "strip_optional", -] diff --git a/pymongo/synchronous/uri_parser.py b/pymongo/synchronous/uri_parser.py deleted file mode 100644 index 8e37bdc696..0000000000 --- a/pymongo/synchronous/uri_parser.py +++ /dev/null @@ -1,624 +0,0 @@ -# Copyright 2011-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you -# may not use this file except in compliance with the License. You -# may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or -# implied. See the License for the specific language governing -# permissions and limitations under the License. - - -"""Tools to parse and validate a MongoDB URI.""" -from __future__ import annotations - -import re -import sys -import warnings -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, - Sized, - Union, - cast, -) -from urllib.parse import unquote_plus - -from pymongo.errors import ConfigurationError, InvalidURI -from pymongo.synchronous.client_options import _parse_ssl_options -from pymongo.synchronous.common import ( - INTERNAL_URI_OPTION_NAME_MAP, - SRV_SERVICE_NAME, - URI_OPTIONS_DEPRECATION_MAP, - _CaseInsensitiveDictionary, - get_validated_options, -) -from pymongo.synchronous.srv_resolver import _have_dnspython, _SrvResolver -from pymongo.synchronous.typings import _Address - -if TYPE_CHECKING: - from pymongo.pyopenssl_context import SSLContext - -_IS_SYNC = True -SCHEME = "mongodb://" -SCHEME_LEN = len(SCHEME) -SRV_SCHEME = "mongodb+srv://" -SRV_SCHEME_LEN = len(SRV_SCHEME) -DEFAULT_PORT = 27017 - - -def _unquoted_percent(s: str) -> bool: - """Check for unescaped percent signs. - - :param s: A string. `s` can have things like '%25', '%2525', - and '%E2%85%A8' but cannot have unquoted percent like '%foo'. - """ - for i in range(len(s)): - if s[i] == "%": - sub = s[i : i + 3] - # If unquoting yields the same string this means there was an - # unquoted %. - if unquote_plus(sub) == sub: - return True - return False - - -def parse_userinfo(userinfo: str) -> tuple[str, str]: - """Validates the format of user information in a MongoDB URI. - Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", - "]", "@") as per RFC 3986 must be escaped. - - Returns a 2-tuple containing the unescaped username followed - by the unescaped password. 
- - :param userinfo: A string of the form <username>:<password> - """ - if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): - raise InvalidURI( - "Username and password must be escaped according to " - "RFC 3986, use urllib.parse.quote_plus" - ) - - user, _, passwd = userinfo.partition(":") - # No password is expected with GSSAPI authentication. - if not user: - raise InvalidURI("The empty string is not a valid username.") - - return unquote_plus(user), unquote_plus(passwd) - - -def parse_ipv6_literal_host( - entity: str, default_port: Optional[int] -) -> tuple[str, Optional[Union[str, int]]]: - """Validates an IPv6 literal host:port string. - - Returns a 2-tuple of IPv6 literal followed by port where - port is default_port if it wasn't specified in entity. - - :param entity: A string that represents an IPv6 literal enclosed - in brackets (e.g. '[::1]' or '[::1]:27017'). - :param default_port: The port number to use when one wasn't - specified in entity. - """ - if entity.find("]") == -1: - raise ValueError( - "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." - ) - i = entity.find("]:") - if i == -1: - return entity[1:-1], default_port - return entity[1:i], entity[i + 2 :] - - -def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: - """Validates a host string. - - Returns a 2-tuple of host followed by port where port is default_port - if it wasn't specified in the string. - - :param entity: A host or host:port string where host could be a - hostname or IP address. - :param default_port: The port number to use when one wasn't - specified in entity. - """ - host = entity - port: Optional[Union[str, int]] = default_port - if entity[0] == "[": - host, port = parse_ipv6_literal_host(entity, default_port) - elif entity.endswith(".sock"): - return entity, default_port - elif entity.find(":") != -1: - if entity.count(":") > 1: - raise ValueError( - "Reserved characters such as ':' must be " - "escaped according to RFC 2396. An IPv6 " - "address literal must be enclosed in '[' " - "and ']' according to RFC 2732." - ) - host, port = host.split(":", 1) - if isinstance(port, str): - if not port.isdigit() or int(port) > 65535 or int(port) <= 0: - raise ValueError(f"Port must be an integer between 1 and 65535: {port!r}") - port = int(port) - - # Normalize hostname to lowercase, since DNS is case-insensitive: - # http://tools.ietf.org/html/rfc4343 - # This prevents useless rediscovery if "foo.com" is in the seed list but - # "FOO.com" is in the hello response. - return host.lower(), port - - -# Options whose values are implicitly determined by tlsInsecure. -_IMPLICIT_TLSINSECURE_OPTS = { - "tlsallowinvalidcertificates", - "tlsallowinvalidhostnames", - "tlsdisableocspendpointcheck", -} - - -def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: - """Helper method for split_options which creates the options dict. - Also handles the creation of a list for the URI tag_sets/ - readpreferencetags portion, and the use of a unicode options string.
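-
-    For example (illustrative)::
-
-        _parse_options("replicaSet=rs0&readPreferenceTags=dc:ny", "&")
-        # -> {'replicaSet': 'rs0', 'readPreferenceTags': ['dc:ny']}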
- """ - options = _CaseInsensitiveDictionary() - for uriopt in opts.split(delim): - key, value = uriopt.split("=") - if key.lower() == "readpreferencetags": - options.setdefault(key, []).append(value) - else: - if key in options: - warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) - if key.lower() == "authmechanismproperties": - val = value - else: - val = unquote_plus(value) - options[key] = val - - return options - - -def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Raise appropriate errors when conflicting TLS options are present in - the options dictionary. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Implicitly defined options must not be explicitly specified. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - if opt in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) - ) - - # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. - tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") - if tlsallowinvalidcerts is not None: - if "tlsdisableocspendpointcheck" in options: - err_msg = "URI options %s and %s cannot be specified simultaneously." - raise InvalidURI( - err_msg - % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) - ) - if tlsallowinvalidcerts is True: - options["tlsdisableocspendpointcheck"] = True - - # Handle co-occurence of CRL and OCSP-related options. - tlscrlfile = options.get("tlscrlfile") - if tlscrlfile is not None: - for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): - if options.get(opt) is True: - err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." - raise InvalidURI(err_msg % (opt,)) - - if "ssl" in options and "tls" in options: - - def truth_value(val: Any) -> Any: - if val in ("true", "false"): - return val == "true" - if isinstance(val, bool): - return val - return val - - if truth_value(options.get("ssl")) != truth_value(options.get("tls")): - err_msg = "Can not specify conflicting values for URI options %s and %s." - raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls"))) - - return options - - -def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Issue appropriate warnings when deprecated options are present in the - options dictionary. Removes deprecated option key, value pairs if the - options dictionary is found to also have the renamed option. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - for optname in list(options): - if optname in URI_OPTIONS_DEPRECATION_MAP: - mode, message = URI_OPTIONS_DEPRECATION_MAP[optname] - if mode == "renamed": - newoptname = message - if newoptname in options: - warn_msg = "Deprecated option '%s' ignored in favor of '%s'." - warnings.warn( - warn_msg % (options.cased_key(optname), options.cased_key(newoptname)), - DeprecationWarning, - stacklevel=2, - ) - options.pop(optname) - continue - warn_msg = "Option '%s' is deprecated, use '%s' instead." - warnings.warn( - warn_msg % (options.cased_key(optname), newoptname), - DeprecationWarning, - stacklevel=2, - ) - elif mode == "removed": - warn_msg = "Option '%s' is deprecated. %s." 
- warnings.warn( - warn_msg % (options.cased_key(optname), message), - DeprecationWarning, - stacklevel=2, - ) - - return options - - -def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: - """Normalizes option names in the options dictionary by converting them to - their internally-used names. - - :param options: Instance of _CaseInsensitiveDictionary containing - MongoDB URI options. - """ - # Expand the tlsInsecure option. - tlsinsecure = options.get("tlsinsecure") - if tlsinsecure is not None: - for opt in _IMPLICIT_TLSINSECURE_OPTS: - # Implicit options are logically the same as tlsInsecure. - options[opt] = tlsinsecure - - for optname in list(options): - intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None) - if intname is not None: - options[intname] = options.pop(optname) - - return options - - -def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]: - """Validates and normalizes options passed in a MongoDB URI. - - Returns a new dictionary of validated and normalized options. If warn is - False then errors will be raised for invalid options, otherwise they will - be ignored and a warning will be issued. - - :param opts: A dict of MongoDB URI options. - :param warn: If ``True`` then warnings will be logged and - invalid options will be ignored. Otherwise invalid options will - cause errors. - """ - return get_validated_options(opts, warn) - - -def split_options( - opts: str, validate: bool = True, warn: bool = False, normalize: bool = True -) -> MutableMapping[str, Any]: - """Takes the options portion of a MongoDB URI, validates each option - and returns the options in a dictionary. - - :param opts: A string representing MongoDB URI options. - :param validate: If ``True`` (the default), validate and normalize all - options. - :param warn: If ``True``, warn about invalid options and ignore them; if - ``False`` (the default), raise errors for invalid options. - :param normalize: If ``True`` (the default), renames all options to their - internally-used names. - """ - and_idx = opts.find("&") - semi_idx = opts.find(";") - try: - if and_idx >= 0 and semi_idx >= 0: - raise InvalidURI("Cannot mix '&' and ';' for option separators.") - elif and_idx >= 0: - options = _parse_options(opts, "&") - elif semi_idx >= 0: - options = _parse_options(opts, ";") - elif opts.find("=") != -1: - options = _parse_options(opts, None) - else: - raise ValueError - except ValueError: - raise InvalidURI("MongoDB URI options are key=value pairs.") from None - - options = _handle_security_options(options) - - options = _handle_option_deprecations(options) - - if normalize: - options = _normalize_options(options) - - if validate: - options = cast(_CaseInsensitiveDictionary, validate_options(options, warn)) - if options.get("authsource") == "": - raise InvalidURI("the authSource database cannot be an empty string") - - return options - - -def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]: - """Takes a string of the form host1[:port],host2[:port]... and - splits it into (host, port) tuples. If [:port] isn't present the - default_port is used. - - Returns a list of 2-tuples containing the host name (or IP) followed by - port number. - - :param hosts: A string of the form host1[:port],host2[:port],... - :param default_port: The port number to use when one wasn't specified - for a host.
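-
-    For example (illustrative)::
-
-        split_hosts("db1.example.com,db2.example.com:27018")
-        # -> [('db1.example.com', 27017), ('db2.example.com', 27018)]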
- """ - nodes = [] - for entity in hosts.split(","): - if not entity: - raise ConfigurationError("Empty host (or extra comma in host list).") - port = default_port - # Unix socket entities don't have ports - if entity.endswith(".sock"): - port = None - nodes.append(parse_host(entity, port)) - return nodes - - -# Prohibited characters in database name. DB names also can't have ".", but for -# backward-compat we allow "db.collection" in URI. -_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]") - -_ALLOWED_TXT_OPTS = frozenset( - ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"] -) - - -def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None: - # Ensure directConnection was not True if there are multiple seeds. - if len(nodes) > 1 and options.get("directconnection"): - raise ConfigurationError("Cannot specify multiple hosts with directConnection=true") - - if options.get("loadbalanced"): - if len(nodes) > 1: - raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true") - if options.get("directconnection"): - raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true") - if options.get("replicaset"): - raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true") - - -def parse_uri( - uri: str, - default_port: Optional[int] = DEFAULT_PORT, - validate: bool = True, - warn: bool = False, - normalize: bool = True, - connect_timeout: Optional[float] = None, - srv_service_name: Optional[str] = None, - srv_max_hosts: Optional[int] = None, -) -> dict[str, Any]: - """Parse and validate a MongoDB URI. - - Returns a dict of the form:: - - { - 'nodelist': , - 'username': or None, - 'password': or None, - 'database': or None, - 'collection': or None, - 'options': , - 'fqdn': or None - } - - If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done - to build nodelist and options. - - :param uri: The MongoDB URI to parse. - :param default_port: The port number to use when one wasn't specified - for a host in the URI. - :param validate: If ``True`` (the default), validate and - normalize all options. Default: ``True``. - :param warn: When validating, if ``True`` then will warn - the user then ignore any invalid options or values. If ``False``, - validation will error when options are unsupported or values are - invalid. Default: ``False``. - :param normalize: If ``True``, convert names of URI options - to their internally-used names. Default: ``True``. - :param connect_timeout: The maximum time in milliseconds to - wait for a response from the DNS server. - :param srv_service_name: A custom SRV service name - - .. versionchanged:: 4.6 - The delimiting slash (``/``) between hosts and connection options is now optional. - For example, "mongodb://example.com?tls=true" is now a valid URI. - - .. versionchanged:: 4.0 - To better follow RFC 3986, unquoted percent signs ("%") are no longer - supported. - - .. versionchanged:: 3.9 - Added the ``normalize`` parameter. - - .. versionchanged:: 3.6 - Added support for mongodb+srv:// URIs. - - .. versionchanged:: 3.5 - Return the original value of the ``readPreference`` MongoDB URI option - instead of the validated read preference mode. - - .. versionchanged:: 3.1 - ``warn`` added so invalid options can be ignored. 
- """ - if uri.startswith(SCHEME): - is_srv = False - scheme_free = uri[SCHEME_LEN:] - elif uri.startswith(SRV_SCHEME): - if not _have_dnspython(): - python_path = sys.executable or "python" - raise ConfigurationError( - 'The "dnspython" module must be ' - "installed to use mongodb+srv:// URIs. " - "To fix this error install pymongo again:\n " - "%s -m pip install pymongo>=4.3" % (python_path) - ) - is_srv = True - scheme_free = uri[SRV_SCHEME_LEN:] - else: - raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP.") - - user = None - passwd = None - dbase = None - collection = None - options = _CaseInsensitiveDictionary() - - host_plus_db_part, _, opts = scheme_free.partition("?") - if "/" in host_plus_db_part: - host_part, _, dbase = host_plus_db_part.partition("/") - else: - host_part = host_plus_db_part - - if dbase: - dbase = unquote_plus(dbase) - if "." in dbase: - dbase, collection = dbase.split(".", 1) - if _BAD_DB_CHARS.search(dbase): - raise InvalidURI('Bad database name "%s"' % dbase) - else: - dbase = None - - if opts: - options.update(split_options(opts, validate, warn, normalize)) - if srv_service_name is None: - srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) - if "@" in host_part: - userinfo, _, hosts = host_part.rpartition("@") - user, passwd = parse_userinfo(userinfo) - else: - hosts = host_part - - if "/" in hosts: - raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) - - hosts = unquote_plus(hosts) - fqdn = None - srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") - if is_srv: - if options.get("directConnection"): - raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") - nodes = split_hosts(hosts, default_port=None) - if len(nodes) != 1: - raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") - fqdn, port = nodes[0] - if port is not None: - raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") - - # Use the connection timeout. connectTimeoutMS passed as a keyword - # argument overrides the same option passed in the connection string. 
- connect_timeout = connect_timeout or options.get("connectTimeoutMS") - dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) - nodes = dns_resolver.get_hosts() - dns_options = dns_resolver.get_options() - if dns_options: - parsed_dns_options = split_options(dns_options, validate, warn, normalize) - if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: - raise ConfigurationError( - "Only authSource, replicaSet, and loadBalanced are supported from DNS" - ) - for opt, val in parsed_dns_options.items(): - if opt not in options: - options[opt] = val - if options.get("loadBalanced") and srv_max_hosts: - raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") - if options.get("replicaSet") and srv_max_hosts: - raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") - if "tls" not in options and "ssl" not in options: - options["tls"] = True if validate else "true" - elif not is_srv and options.get("srvServiceName") is not None: - raise ConfigurationError( - "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" - ) - elif not is_srv and srv_max_hosts: - raise ConfigurationError( - "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" - ) - else: - nodes = split_hosts(hosts, default_port=default_port) - - _check_options(nodes, options) - - return { - "nodelist": nodes, - "username": user, - "password": passwd, - "database": dbase, - "collection": collection, - "options": options, - "fqdn": fqdn, - } - - -def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: - """Parse KMS TLS connection options.""" - if not kms_tls_options: - return {} - if not isinstance(kms_tls_options, dict): - raise TypeError("kms_tls_options must be a dict") - contexts = {} - for provider, options in kms_tls_options.items(): - if not isinstance(options, dict): - raise TypeError(f'kms_tls_options["{provider}"] must be a dict') - options.setdefault("tls", True) - opts = _CaseInsensitiveDictionary(options) - opts = _handle_security_options(opts) - opts = _normalize_options(opts) - opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) - ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) - if ssl_context is None: - raise ConfigurationError("TLS is required for KMS providers") - if allow_invalid_hostnames: - raise ConfigurationError("Insecure TLS options prohibited") - - for n in [ - "tlsInsecure", - "tlsAllowInvalidCertificates", - "tlsAllowInvalidHostnames", - "tlsDisableCertificateRevocationCheck", - ]: - if n in opts: - raise ConfigurationError(f"Insecure TLS options prohibited: {n}") - contexts[provider] = ssl_context - return contexts - - -if __name__ == "__main__": - import pprint - - try: - pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 - except InvalidURI as exc: - print(exc) # noqa: T201 - sys.exit(0) diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py index 201d9b390d..d28e11fc47 100644 --- a/pymongo/topology_description.py +++ b/pymongo/topology_description.py @@ -1,21 +1,678 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2014-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. 
You
+# may obtain a copy of the License at
 #
 #     http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+# implied. See the License for the specific language governing
+# permissions and limitations under the License.
 
-"""Re-import of synchronous TopologyDescription API for compatibility."""
+"""Represent a deployment of MongoDB servers."""
 
 from __future__ import annotations
 
-from pymongo.synchronous.topology_description import *  # noqa: F403
-from pymongo.synchronous.topology_description import __doc__ as original_doc
+from random import sample
+from typing import (
+    Any,
+    Callable,
+    List,
+    Mapping,
+    MutableMapping,
+    NamedTuple,
+    Optional,
+    cast,
+)
 
-__doc__ = original_doc
+from bson.min_key import MinKey
+from bson.objectid import ObjectId
+from pymongo import common
+from pymongo.errors import ConfigurationError
+from pymongo.read_preferences import ReadPreference, _AggWritePref, _ServerMode
+from pymongo.server_description import ServerDescription
+from pymongo.server_selectors import Selection
+from pymongo.server_type import SERVER_TYPE
+from pymongo.typings import _Address
+
+_IS_SYNC = False
+
+
+# Enumeration for various kinds of MongoDB cluster topologies.
+class _TopologyType(NamedTuple):
+    Single: int
+    ReplicaSetNoPrimary: int
+    ReplicaSetWithPrimary: int
+    Sharded: int
+    Unknown: int
+    LoadBalanced: int
+
+
+TOPOLOGY_TYPE = _TopologyType(*range(6))
+
+# Topologies compatible with SRV record polling.
+SRV_POLLING_TOPOLOGIES: tuple[int, int] = (TOPOLOGY_TYPE.Unknown, TOPOLOGY_TYPE.Sharded)
+
+
+_ServerSelector = Callable[[List[ServerDescription]], List[ServerDescription]]
+
+
+class TopologyDescription:
+    def __init__(
+        self,
+        topology_type: int,
+        server_descriptions: dict[_Address, ServerDescription],
+        replica_set_name: Optional[str],
+        max_set_version: Optional[int],
+        max_election_id: Optional[ObjectId],
+        topology_settings: Any,
+    ) -> None:
+        """Representation of a deployment of MongoDB servers.
+
+        :param topology_type: initial type
+        :param server_descriptions: dict of (address, ServerDescription) for
+            all seeds
+        :param replica_set_name: replica set name or None
+        :param max_set_version: greatest setVersion seen from a primary, or None
+        :param max_election_id: greatest electionId seen from a primary, or None
+        :param topology_settings: a TopologySettings
+        """
+        self._topology_type = topology_type
+        self._replica_set_name = replica_set_name
+        self._server_descriptions = server_descriptions
+        self._max_set_version = max_set_version
+        self._max_election_id = max_election_id
+
+        # The heartbeat_frequency is used in staleness estimates.
+        self._topology_settings = topology_settings
+
+        # Is PyMongo compatible with all servers' wire protocols?
+        self._incompatible_err = None
+        if self._topology_type != TOPOLOGY_TYPE.LoadBalanced:
+            self._init_incompatible_err()
+
+        # Server Discovery And Monitoring Spec: Whenever a client updates the
+        # TopologyDescription from a hello response, it MUST set
+        # TopologyDescription.logicalSessionTimeoutMinutes to the smallest
+        # logicalSessionTimeoutMinutes value among ServerDescriptions of all
+        # data-bearing server types.
If any have a null + # logicalSessionTimeoutMinutes, then + # TopologyDescription.logicalSessionTimeoutMinutes MUST be set to null. + readable_servers = self.readable_servers + if not readable_servers: + self._ls_timeout_minutes = None + elif any(s.logical_session_timeout_minutes is None for s in readable_servers): + self._ls_timeout_minutes = None + else: + self._ls_timeout_minutes = min( # type: ignore[type-var] + s.logical_session_timeout_minutes for s in readable_servers + ) + + def _init_incompatible_err(self) -> None: + """Internal compatibility check for non-load balanced topologies.""" + for s in self._server_descriptions.values(): + if not s.is_server_type_known: + continue + + # s.min/max_wire_version is the server's wire protocol. + # MIN/MAX_SUPPORTED_WIRE_VERSION is what PyMongo supports. + server_too_new = ( + # Server too new. + s.min_wire_version is not None + and s.min_wire_version > common.MAX_SUPPORTED_WIRE_VERSION + ) + + server_too_old = ( + # Server too old. + s.max_wire_version is not None + and s.max_wire_version < common.MIN_SUPPORTED_WIRE_VERSION + ) + + if server_too_new: + self._incompatible_err = ( + "Server at %s:%d requires wire version %d, but this " # type: ignore + "version of PyMongo only supports up to %d." + % ( + s.address[0], + s.address[1] or 0, + s.min_wire_version, + common.MAX_SUPPORTED_WIRE_VERSION, + ) + ) + + elif server_too_old: + self._incompatible_err = ( + "Server at %s:%d reports wire version %d, but this " # type: ignore + "version of PyMongo requires at least %d (MongoDB %s)." + % ( + s.address[0], + s.address[1] or 0, + s.max_wire_version, + common.MIN_SUPPORTED_WIRE_VERSION, + common.MIN_SUPPORTED_SERVER_VERSION, + ) + ) + + break + + def check_compatible(self) -> None: + """Raise ConfigurationError if any server is incompatible. + + A server is incompatible if its wire protocol version range does not + overlap with PyMongo's. + """ + if self._incompatible_err: + raise ConfigurationError(self._incompatible_err) + + def has_server(self, address: _Address) -> bool: + return address in self._server_descriptions + + def reset_server(self, address: _Address) -> TopologyDescription: + """A copy of this description, with one server marked Unknown.""" + unknown_sd = self._server_descriptions[address].to_unknown() + return updated_topology_description(self, unknown_sd) + + def reset(self) -> TopologyDescription: + """A copy of this description, with all servers marked Unknown.""" + if self._topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: + topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary + else: + topology_type = self._topology_type + + # The default ServerDescription's type is Unknown. + sds = {address: ServerDescription(address) for address in self._server_descriptions} + + return TopologyDescription( + topology_type, + sds, + self._replica_set_name, + self._max_set_version, + self._max_election_id, + self._topology_settings, + ) + + def server_descriptions(self) -> dict[_Address, ServerDescription]: + """dict of (address, + :class:`~pymongo.server_description.ServerDescription`). + """ + return self._server_descriptions.copy() + + @property + def topology_type(self) -> int: + """The type of this topology.""" + return self._topology_type + + @property + def topology_type_name(self) -> str: + """The topology type as a human readable string. + + .. 
versionadded:: 3.4
+        """
+        return TOPOLOGY_TYPE._fields[self._topology_type]
+
+    @property
+    def replica_set_name(self) -> Optional[str]:
+        """The replica set name."""
+        return self._replica_set_name
+
+    @property
+    def max_set_version(self) -> Optional[int]:
+        """Greatest setVersion seen from a primary, or None."""
+        return self._max_set_version
+
+    @property
+    def max_election_id(self) -> Optional[ObjectId]:
+        """Greatest electionId seen from a primary, or None."""
+        return self._max_election_id
+
+    @property
+    def logical_session_timeout_minutes(self) -> Optional[int]:
+        """Minimum logical session timeout, or None."""
+        return self._ls_timeout_minutes
+
+    @property
+    def known_servers(self) -> list[ServerDescription]:
+        """List of Servers of types besides Unknown."""
+        return [s for s in self._server_descriptions.values() if s.is_server_type_known]
+
+    @property
+    def has_known_servers(self) -> bool:
+        """Whether there are any Servers of types besides Unknown."""
+        return any(s for s in self._server_descriptions.values() if s.is_server_type_known)
+
+    @property
+    def readable_servers(self) -> list[ServerDescription]:
+        """List of readable Servers."""
+        return [s for s in self._server_descriptions.values() if s.is_readable]
+
+    @property
+    def common_wire_version(self) -> Optional[int]:
+        """Minimum of all servers' max wire versions, or None."""
+        servers = self.known_servers
+        if servers:
+            return min(s.max_wire_version for s in self.known_servers)
+
+        return None
+
+    @property
+    def heartbeat_frequency(self) -> int:
+        return self._topology_settings.heartbeat_frequency
+
+    @property
+    def srv_max_hosts(self) -> int:
+        return self._topology_settings._srv_max_hosts
+
+    def _apply_local_threshold(self, selection: Optional[Selection]) -> list[ServerDescription]:
+        if not selection:
+            return []
+        round_trip_times: list[float] = []
+        for server in selection.server_descriptions:
+            if server.round_trip_time is None:
+                config_err_msg = f"round_trip_time for server {server.address} is unexpectedly None: {self}, servers: {selection.server_descriptions}"
+                raise ConfigurationError(config_err_msg)
+            round_trip_times.append(server.round_trip_time)
+        # Round trip time in seconds.
+        fastest = min(round_trip_times)
+        threshold = self._topology_settings.local_threshold_ms / 1000.0
+        return [
+            s
+            for s in selection.server_descriptions
+            if (cast(float, s.round_trip_time) - fastest) <= threshold
+        ]
+
+    def apply_selector(
+        self,
+        selector: Any,
+        address: Optional[_Address] = None,
+        custom_selector: Optional[_ServerSelector] = None,
+    ) -> list[ServerDescription]:
+        """List of servers matching the provided selector(s).
+
+        :param selector: a callable that takes a Selection as input and returns
+            a Selection as output. For example, an instance of a read
+            preference from :mod:`~pymongo.read_preferences`.
+        :param address: A server address to select.
+        :param custom_selector: A callable that augments server
+            selection rules. Accepts a list of
+            :class:`~pymongo.server_description.ServerDescription` objects and
+            returns a list of server descriptions that should be considered
+            suitable for the desired operation.
+
+        .. 
versionadded:: 3.4 + """ + if getattr(selector, "min_wire_version", 0): + common_wv = self.common_wire_version + if common_wv and common_wv < selector.min_wire_version: + raise ConfigurationError( + "%s requires min wire version %d, but topology's min" + " wire version is %d" % (selector, selector.min_wire_version, common_wv) + ) + + if isinstance(selector, _AggWritePref): + selector.selection_hook(self) + + if self.topology_type == TOPOLOGY_TYPE.Unknown: + return [] + elif self.topology_type in (TOPOLOGY_TYPE.Single, TOPOLOGY_TYPE.LoadBalanced): + # Ignore selectors for standalone and load balancer mode. + return self.known_servers + if address: + # Ignore selectors when explicit address is requested. + description = self.server_descriptions().get(address) + return [description] if description else [] + + selection = Selection.from_topology_description(self) + # Ignore read preference for sharded clusters. + if self.topology_type != TOPOLOGY_TYPE.Sharded: + selection = selector(selection) + + # Apply custom selector followed by localThresholdMS. + if custom_selector is not None and selection: + selection = selection.with_server_descriptions( + custom_selector(selection.server_descriptions) + ) + return self._apply_local_threshold(selection) + + def has_readable_server(self, read_preference: _ServerMode = ReadPreference.PRIMARY) -> bool: + """Does this topology have any readable servers available matching the + given read preference? + + :param read_preference: an instance of a read preference from + :mod:`~pymongo.read_preferences`. Defaults to + :attr:`~pymongo.read_preferences.ReadPreference.PRIMARY`. + + .. note:: When connected directly to a single server this method + always returns ``True``. + + .. versionadded:: 3.4 + """ + common.validate_read_preference("read_preference", read_preference) + return any(self.apply_selector(read_preference)) + + def has_writable_server(self) -> bool: + """Does this topology have a writable server available? + + .. note:: When connected directly to a single server this method + always returns ``True``. + + .. versionadded:: 3.4 + """ + return self.has_readable_server(ReadPreference.PRIMARY) + + def __repr__(self) -> str: + # Sort the servers by address. + servers = sorted(self._server_descriptions.values(), key=lambda sd: sd.address) + return "<{} id: {}, topology_type: {}, servers: {!r}>".format( + self.__class__.__name__, + self._topology_settings._topology_id, + self.topology_type_name, + servers, + ) + + +# If topology type is Unknown and we receive a hello response, what should +# the new topology type be? +_SERVER_TYPE_TO_TOPOLOGY_TYPE = { + SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded, + SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary, + SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + SERVER_TYPE.RSArbiter: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + SERVER_TYPE.RSOther: TOPOLOGY_TYPE.ReplicaSetNoPrimary, + # Note: SERVER_TYPE.LoadBalancer and Unknown are intentionally left out. +} + + +def updated_topology_description( + topology_description: TopologyDescription, server_description: ServerDescription +) -> TopologyDescription: + """Return an updated copy of a TopologyDescription. + + :param topology_description: the current TopologyDescription + :param server_description: a new ServerDescription that resulted from + a hello call + + Called after attempting (successfully or not) to call hello on the + server at server_description.address. Does not modify topology_description. 
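The _SERVER_TYPE_TO_TOPOLOGY_TYPE table above drives the Unknown-state transitions implemented below. For illustration, a standalone sketch using the public constants; the local dict merely mirrors the private table and is not part of this patch::

    from pymongo.server_type import SERVER_TYPE
    from pymongo.topology_description import TOPOLOGY_TYPE

    # Illustrative mirror of the private mapping: the first hello reply
    # seen from an Unknown topology fixes the new topology type.
    first_reply_transition = {
        SERVER_TYPE.Mongos: TOPOLOGY_TYPE.Sharded,
        SERVER_TYPE.RSPrimary: TOPOLOGY_TYPE.ReplicaSetWithPrimary,
        SERVER_TYPE.RSSecondary: TOPOLOGY_TYPE.ReplicaSetNoPrimary,
    }

    new_type = first_reply_transition[SERVER_TYPE.RSPrimary]
    print(TOPOLOGY_TYPE._fields[new_type])  # ReplicaSetWithPrimary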
+ """ + address = server_description.address + + # These values will be updated, if necessary, to form the new + # TopologyDescription. + topology_type = topology_description.topology_type + set_name = topology_description.replica_set_name + max_set_version = topology_description.max_set_version + max_election_id = topology_description.max_election_id + server_type = server_description.server_type + + # Don't mutate the original dict of server descriptions; copy it. + sds = topology_description.server_descriptions() + + # Replace this server's description with the new one. + sds[address] = server_description + + if topology_type == TOPOLOGY_TYPE.Single: + # Set server type to Unknown if replica set name does not match. + if set_name is not None and set_name != server_description.replica_set_name: + error = ConfigurationError( + "client is configured to connect to a replica set named " + "'{}' but this node belongs to a set named '{}'".format( + set_name, server_description.replica_set_name + ) + ) + sds[address] = server_description.to_unknown(error=error) + # Single type never changes. + return TopologyDescription( + TOPOLOGY_TYPE.Single, + sds, + set_name, + max_set_version, + max_election_id, + topology_description._topology_settings, + ) + + if topology_type == TOPOLOGY_TYPE.Unknown: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.LoadBalancer): + if len(topology_description._topology_settings.seeds) == 1: + topology_type = TOPOLOGY_TYPE.Single + else: + # Remove standalone from Topology when given multiple seeds. + sds.pop(address) + elif server_type not in (SERVER_TYPE.Unknown, SERVER_TYPE.RSGhost): + topology_type = _SERVER_TYPE_TO_TOPOLOGY_TYPE[server_type] + + if topology_type == TOPOLOGY_TYPE.Sharded: + if server_type not in (SERVER_TYPE.Mongos, SERVER_TYPE.Unknown): + sds.pop(address) + + elif topology_type == TOPOLOGY_TYPE.ReplicaSetNoPrimary: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): + sds.pop(address) + + elif server_type == SERVER_TYPE.RSPrimary: + (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( + sds, set_name, server_description, max_set_version, max_election_id + ) + + elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): + topology_type, set_name = _update_rs_no_primary_from_member( + sds, set_name, server_description + ) + + elif topology_type == TOPOLOGY_TYPE.ReplicaSetWithPrimary: + if server_type in (SERVER_TYPE.Standalone, SERVER_TYPE.Mongos): + sds.pop(address) + topology_type = _check_has_primary(sds) + + elif server_type == SERVER_TYPE.RSPrimary: + (topology_type, set_name, max_set_version, max_election_id) = _update_rs_from_primary( + sds, set_name, server_description, max_set_version, max_election_id + ) + + elif server_type in (SERVER_TYPE.RSSecondary, SERVER_TYPE.RSArbiter, SERVER_TYPE.RSOther): + topology_type = _update_rs_with_primary_from_member(sds, set_name, server_description) + + else: + # Server type is Unknown or RSGhost: did we just lose the primary? + topology_type = _check_has_primary(sds) + + # Return updated copy. + return TopologyDescription( + topology_type, + sds, + set_name, + max_set_version, + max_election_id, + topology_description._topology_settings, + ) + + +def _updated_topology_description_srv_polling( + topology_description: TopologyDescription, seedlist: list[tuple[str, Any]] +) -> TopologyDescription: + """Return an updated copy of a TopologyDescription. 
+
+    :param topology_description: the current TopologyDescription
+    :param seedlist: a list of new seeds, as resolved from the SRV record
+    """
+    assert topology_description.topology_type in SRV_POLLING_TOPOLOGIES
+    # Create a copy of the server descriptions.
+    sds = topology_description.server_descriptions()
+
+    # If seeds haven't changed, don't do anything.
+    if set(sds.keys()) == set(seedlist):
+        return topology_description
+
+    # Remove SDs corresponding to servers no longer part of the SRV record.
+    for address in list(sds.keys()):
+        if address not in seedlist:
+            sds.pop(address)
+
+    if topology_description.srv_max_hosts != 0:
+        new_hosts = set(seedlist) - set(sds.keys())
+        n_to_add = topology_description.srv_max_hosts - len(sds)
+        if n_to_add > 0:
+            seedlist = sample(sorted(new_hosts), min(n_to_add, len(new_hosts)))
+        else:
+            seedlist = []
+    # Add SDs corresponding to servers recently added to the SRV record.
+    for address in seedlist:
+        if address not in sds:
+            sds[address] = ServerDescription(address)
+    return TopologyDescription(
+        topology_description.topology_type,
+        sds,
+        topology_description.replica_set_name,
+        topology_description.max_set_version,
+        topology_description.max_election_id,
+        topology_description._topology_settings,
+    )
+
+
+def _update_rs_from_primary(
+    sds: MutableMapping[_Address, ServerDescription],
+    replica_set_name: Optional[str],
+    server_description: ServerDescription,
+    max_set_version: Optional[int],
+    max_election_id: Optional[ObjectId],
+) -> tuple[int, Optional[str], Optional[int], Optional[ObjectId]]:
+    """Update topology description from a primary's hello response.
+
+    Pass in a dict of ServerDescriptions, current replica set name, the
+    ServerDescription we are processing, and the TopologyDescription's
+    max_set_version and max_election_id if any.
+
+    Returns (new topology type, new replica_set_name, new max_set_version,
+    new max_election_id).
+    """
+    if replica_set_name is None:
+        replica_set_name = server_description.replica_set_name
+
+    elif replica_set_name != server_description.replica_set_name:
+        # We found a primary but it doesn't have the replica_set_name
+        # provided by the user.
+        sds.pop(server_description.address)
+        return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
+
+    if server_description.max_wire_version is None or server_description.max_wire_version < 17:
+        new_election_tuple: tuple = (server_description.set_version, server_description.election_id)
+        max_election_tuple: tuple = (max_set_version, max_election_id)
+        if None not in new_election_tuple:
+            if None not in max_election_tuple and new_election_tuple < max_election_tuple:
+                # Stale primary, set to type Unknown.
+                sds[server_description.address] = server_description.to_unknown()
+                return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id
+            max_election_id = server_description.election_id
+
+        if server_description.set_version is not None and (
+            max_set_version is None or server_description.set_version > max_set_version
+        ):
+            max_set_version = server_description.set_version
+    else:
+        new_election_tuple = server_description.election_id, server_description.set_version
+        max_election_tuple = max_election_id, max_set_version
+        new_election_safe = tuple(MinKey() if i is None else i for i in new_election_tuple)
+        max_election_safe = tuple(MinKey() if i is None else i for i in max_election_tuple)
+        if new_election_safe < max_election_safe:
+            # Stale primary, set to type Unknown.
+ sds[server_description.address] = server_description.to_unknown() + return _check_has_primary(sds), replica_set_name, max_set_version, max_election_id + else: + max_election_id = server_description.election_id + max_set_version = server_description.set_version + + # We've heard from the primary. Is it the same primary as before? + for server in sds.values(): + if ( + server.server_type is SERVER_TYPE.RSPrimary + and server.address != server_description.address + ): + # Reset old primary's type to Unknown. + sds[server.address] = server.to_unknown() + + # There can be only one prior primary. + break + + # Discover new hosts from this primary's response. + for new_address in server_description.all_hosts: + if new_address not in sds: + sds[new_address] = ServerDescription(new_address) + + # Remove hosts not in the response. + for addr in set(sds) - server_description.all_hosts: + sds.pop(addr) + + # If the host list differs from the seed list, we may not have a primary + # after all. + return (_check_has_primary(sds), replica_set_name, max_set_version, max_election_id) + + +def _update_rs_with_primary_from_member( + sds: MutableMapping[_Address, ServerDescription], + replica_set_name: Optional[str], + server_description: ServerDescription, +) -> int: + """RS with known primary. Process a response from a non-primary. + + Pass in a dict of ServerDescriptions, current replica set name, and the + ServerDescription we are processing. + + Returns new topology type. + """ + assert replica_set_name is not None + + if replica_set_name != server_description.replica_set_name: + sds.pop(server_description.address) + elif server_description.me and server_description.address != server_description.me: + sds.pop(server_description.address) + + # Had this member been the primary? + return _check_has_primary(sds) + + +def _update_rs_no_primary_from_member( + sds: MutableMapping[_Address, ServerDescription], + replica_set_name: Optional[str], + server_description: ServerDescription, +) -> tuple[int, Optional[str]]: + """RS without known primary. Update from a non-primary's response. + + Pass in a dict of ServerDescriptions, current replica set name, and the + ServerDescription we are processing. + + Returns (new topology type, new replica_set_name). + """ + topology_type = TOPOLOGY_TYPE.ReplicaSetNoPrimary + if replica_set_name is None: + replica_set_name = server_description.replica_set_name + + elif replica_set_name != server_description.replica_set_name: + sds.pop(server_description.address) + return topology_type, replica_set_name + + # This isn't the primary's response, so don't remove any servers + # it doesn't report. Only add new servers. + for address in server_description.all_hosts: + if address not in sds: + sds[address] = ServerDescription(address) + + if server_description.me and server_description.address != server_description.me: + sds.pop(server_description.address) + + return topology_type, replica_set_name + + +def _check_has_primary(sds: Mapping[_Address, ServerDescription]) -> int: + """Current topology type is ReplicaSetWithPrimary. Is primary still known? + + Pass in a dict of ServerDescriptions. + + Returns new topology type. 
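The MongoDB 6.0+ branch above (maxWireVersion >= 17) orders primaries by the (electionId, setVersion) pair, substituting MinKey for missing values so they sort below any real value. A self-contained sketch of that comparison, using made-up ObjectIds::

    from bson.min_key import MinKey
    from bson.objectid import ObjectId

    def is_stale(new_tuple, max_tuple):
        # None sorts below everything, as in _update_rs_from_primary above.
        def safe(t):
            return tuple(MinKey() if v is None else v for v in t)

        return safe(new_tuple) < safe(max_tuple)

    seen_max = (ObjectId("64d2a8f4e4b0a1b2c3d4e5f6"), 3)  # (electionId, setVersion)
    print(is_stale((None, None), seen_max))  # True: a primary reporting no ids is stale
    print(is_stale((ObjectId("64d2a8f4e4b0a1b2c3d4e5f7"), 3), seen_max))  # False: newer electionId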
+ """ + for s in sds.values(): + if s.server_type == SERVER_TYPE.RSPrimary: + return TOPOLOGY_TYPE.ReplicaSetWithPrimary + else: # noqa: PLW0120 + return TOPOLOGY_TYPE.ReplicaSetNoPrimary diff --git a/pymongo/asynchronous/typings.py b/pymongo/typings.py similarity index 60% rename from pymongo/asynchronous/typings.py rename to pymongo/typings.py index 508c5b6dea..1923f918b1 100644 --- a/pymongo/asynchronous/typings.py +++ b/pymongo/typings.py @@ -29,7 +29,14 @@ from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg if TYPE_CHECKING: - from pymongo.asynchronous.collation import Collation + from pymongo import AsyncMongoClient, MongoClient + from pymongo.asynchronous.bulk import _AsyncBulk + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.collation import Collation + from pymongo.synchronous.bulk import _Bulk + from pymongo.synchronous.client_session import ClientSession + from pymongo.synchronous.pool import Connection _IS_SYNC = False @@ -41,9 +48,15 @@ _T = TypeVar("_T") +# Type hinting types for compatibility between async and sync classes +_AgnosticClientSession = Union["AsyncClientSession", "ClientSession"] +_AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] +_AgnosticBulk = Union["_AsyncBulk", "_Bulk"] +_AgnosticConnection = Union["AsyncConnection", "Connection"] + def strip_optional(elem: Optional[_T]) -> _T: - """This function is to allow us to cast all of the elements of an iterator from Optional[_T] to _T + """This function is to allow us to cast all the elements of an iterator from Optional[_T] to _T while inside a list comprehension. """ assert elem is not None @@ -58,4 +71,6 @@ def strip_optional(elem: Optional[_T]) -> _T: "_CollationIn", "_Pipeline", "strip_optional", + "_AgnosticClientSession", + "_AgnosticMongoClient", ] diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py index e74ef18831..4247d51fd1 100644 --- a/pymongo/uri_parser.py +++ b/pymongo/uri_parser.py @@ -1,21 +1,624 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2011-present MongoDB, Inc. # -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. 
-"""Re-import of synchronous URIParser API for compatibility.""" + +"""Tools to parse and validate a MongoDB URI.""" from __future__ import annotations -from pymongo.synchronous.uri_parser import * # noqa: F403 -from pymongo.synchronous.uri_parser import __doc__ as original_doc +import re +import sys +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Mapping, + MutableMapping, + Optional, + Sized, + Union, + cast, +) +from urllib.parse import unquote_plus + +from pymongo.client_options import _parse_ssl_options +from pymongo.common import ( + INTERNAL_URI_OPTION_NAME_MAP, + SRV_SERVICE_NAME, + URI_OPTIONS_DEPRECATION_MAP, + _CaseInsensitiveDictionary, + get_validated_options, +) +from pymongo.errors import ConfigurationError, InvalidURI +from pymongo.srv_resolver import _have_dnspython, _SrvResolver +from pymongo.typings import _Address + +if TYPE_CHECKING: + from pymongo.pyopenssl_context import SSLContext + +_IS_SYNC = False +SCHEME = "mongodb://" +SCHEME_LEN = len(SCHEME) +SRV_SCHEME = "mongodb+srv://" +SRV_SCHEME_LEN = len(SRV_SCHEME) +DEFAULT_PORT = 27017 + + +def _unquoted_percent(s: str) -> bool: + """Check for unescaped percent signs. + + :param s: A string. `s` can have things like '%25', '%2525', + and '%E2%85%A8' but cannot have unquoted percent like '%foo'. + """ + for i in range(len(s)): + if s[i] == "%": + sub = s[i : i + 3] + # If unquoting yields the same string this means there was an + # unquoted %. + if unquote_plus(sub) == sub: + return True + return False + + +def parse_userinfo(userinfo: str) -> tuple[str, str]: + """Validates the format of user information in a MongoDB URI. + Reserved characters that are gen-delimiters (":", "/", "?", "#", "[", + "]", "@") as per RFC 3986 must be escaped. + + Returns a 2-tuple containing the unescaped username followed + by the unescaped password. + + :param userinfo: A string of the form : + """ + if "@" in userinfo or userinfo.count(":") > 1 or _unquoted_percent(userinfo): + raise InvalidURI( + "Username and password must be escaped according to " + "RFC 3986, use urllib.parse.quote_plus" + ) + + user, _, passwd = userinfo.partition(":") + # No password is expected with GSSAPI authentication. + if not user: + raise InvalidURI("The empty string is not valid username.") + + return unquote_plus(user), unquote_plus(passwd) + + +def parse_ipv6_literal_host( + entity: str, default_port: Optional[int] +) -> tuple[str, Optional[Union[str, int]]]: + """Validates an IPv6 literal host:port string. + + Returns a 2-tuple of IPv6 literal followed by port where + port is default_port if it wasn't specified in entity. + + :param entity: A string that represents an IPv6 literal enclosed + in braces (e.g. '[::1]' or '[::1]:27017'). + :param default_port: The port number to use when one wasn't + specified in entity. + """ + if entity.find("]") == -1: + raise ValueError( + "an IPv6 address literal must be enclosed in '[' and ']' according to RFC 2732." + ) + i = entity.find("]:") + if i == -1: + return entity[1:-1], default_port + return entity[1:i], entity[i + 2 :] + + +def parse_host(entity: str, default_port: Optional[int] = DEFAULT_PORT) -> _Address: + """Validates a host string + + Returns a 2-tuple of host followed by port where port is default_port + if it wasn't specified in the string. + + :param entity: A host or host:port string where host could be a + hostname or IP address. + :param default_port: The port number to use when one wasn't + specified in entity. 
+ """ + host = entity + port: Optional[Union[str, int]] = default_port + if entity[0] == "[": + host, port = parse_ipv6_literal_host(entity, default_port) + elif entity.endswith(".sock"): + return entity, default_port + elif entity.find(":") != -1: + if entity.count(":") > 1: + raise ValueError( + "Reserved characters such as ':' must be " + "escaped according RFC 2396. An IPv6 " + "address literal must be enclosed in '[' " + "and ']' according to RFC 2732." + ) + host, port = host.split(":", 1) + if isinstance(port, str): + if not port.isdigit() or int(port) > 65535 or int(port) <= 0: + raise ValueError(f"Port must be an integer between 0 and 65535: {port!r}") + port = int(port) + + # Normalize hostname to lowercase, since DNS is case-insensitive: + # http://tools.ietf.org/html/rfc4343 + # This prevents useless rediscovery if "foo.com" is in the seed list but + # "FOO.com" is in the hello response. + return host.lower(), port + + +# Options whose values are implicitly determined by tlsInsecure. +_IMPLICIT_TLSINSECURE_OPTS = { + "tlsallowinvalidcertificates", + "tlsallowinvalidhostnames", + "tlsdisableocspendpointcheck", +} + + +def _parse_options(opts: str, delim: Optional[str]) -> _CaseInsensitiveDictionary: + """Helper method for split_options which creates the options dict. + Also handles the creation of a list for the URI tag_sets/ + readpreferencetags portion, and the use of a unicode options string. + """ + options = _CaseInsensitiveDictionary() + for uriopt in opts.split(delim): + key, value = uriopt.split("=") + if key.lower() == "readpreferencetags": + options.setdefault(key, []).append(value) + else: + if key in options: + warnings.warn(f"Duplicate URI option '{key}'.", stacklevel=2) + if key.lower() == "authmechanismproperties": + val = value + else: + val = unquote_plus(value) + options[key] = val + + return options + + +def _handle_security_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary: + """Raise appropriate errors when conflicting TLS options are present in + the options dictionary. + + :param options: Instance of _CaseInsensitiveDictionary containing + MongoDB URI options. + """ + # Implicitly defined options must not be explicitly specified. + tlsinsecure = options.get("tlsinsecure") + if tlsinsecure is not None: + for opt in _IMPLICIT_TLSINSECURE_OPTS: + if opt in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg % (options.cased_key("tlsinsecure"), options.cased_key(opt)) + ) + + # Handle co-occurence of OCSP & tlsAllowInvalidCertificates options. + tlsallowinvalidcerts = options.get("tlsallowinvalidcertificates") + if tlsallowinvalidcerts is not None: + if "tlsdisableocspendpointcheck" in options: + err_msg = "URI options %s and %s cannot be specified simultaneously." + raise InvalidURI( + err_msg + % ("tlsallowinvalidcertificates", options.cased_key("tlsdisableocspendpointcheck")) + ) + if tlsallowinvalidcerts is True: + options["tlsdisableocspendpointcheck"] = True + + # Handle co-occurence of CRL and OCSP-related options. + tlscrlfile = options.get("tlscrlfile") + if tlscrlfile is not None: + for opt in ("tlsinsecure", "tlsallowinvalidcertificates", "tlsdisableocspendpointcheck"): + if options.get(opt) is True: + err_msg = "URI option %s=True cannot be specified when CRL checking is enabled." 
+                raise InvalidURI(err_msg % (opt,))
+
+    if "ssl" in options and "tls" in options:
+
+        def truth_value(val: Any) -> Any:
+            if val in ("true", "false"):
+                return val == "true"
+            if isinstance(val, bool):
+                return val
+            return val
+
+        if truth_value(options.get("ssl")) != truth_value(options.get("tls")):
+            err_msg = "Can not specify conflicting values for URI options %s and %s."
+            raise InvalidURI(err_msg % (options.cased_key("ssl"), options.cased_key("tls")))
+
+    return options
+
+
+def _handle_option_deprecations(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
+    """Issue appropriate warnings when deprecated options are present in the
+    options dictionary. Removes deprecated option key, value pairs if the
+    options dictionary is found to also have the renamed option.
+
+    :param options: Instance of _CaseInsensitiveDictionary containing
+        MongoDB URI options.
+    """
+    for optname in list(options):
+        if optname in URI_OPTIONS_DEPRECATION_MAP:
+            mode, message = URI_OPTIONS_DEPRECATION_MAP[optname]
+            if mode == "renamed":
+                newoptname = message
+                if newoptname in options:
+                    warn_msg = "Deprecated option '%s' ignored in favor of '%s'."
+                    warnings.warn(
+                        warn_msg % (options.cased_key(optname), options.cased_key(newoptname)),
+                        DeprecationWarning,
+                        stacklevel=2,
+                    )
+                    options.pop(optname)
+                    continue
+                warn_msg = "Option '%s' is deprecated, use '%s' instead."
+                warnings.warn(
+                    warn_msg % (options.cased_key(optname), newoptname),
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+            elif mode == "removed":
+                warn_msg = "Option '%s' is deprecated. %s."
+                warnings.warn(
+                    warn_msg % (options.cased_key(optname), message),
+                    DeprecationWarning,
+                    stacklevel=2,
+                )
+
+    return options
+
+
+def _normalize_options(options: _CaseInsensitiveDictionary) -> _CaseInsensitiveDictionary:
+    """Normalizes option names in the options dictionary by converting them to
+    their internally-used names.
+
+    :param options: Instance of _CaseInsensitiveDictionary containing
+        MongoDB URI options.
+    """
+    # Expand the tlsInsecure option.
+    tlsinsecure = options.get("tlsinsecure")
+    if tlsinsecure is not None:
+        for opt in _IMPLICIT_TLSINSECURE_OPTS:
+            # Implicit options are logically the same as tlsInsecure.
+            options[opt] = tlsinsecure
+
+    for optname in list(options):
+        intname = INTERNAL_URI_OPTION_NAME_MAP.get(optname, None)
+        if intname is not None:
+            options[intname] = options.pop(optname)
+
+    return options
+
+
+def validate_options(opts: Mapping[str, Any], warn: bool = False) -> MutableMapping[str, Any]:
+    """Validates and normalizes options passed in a MongoDB URI.
+
+    Returns a new dictionary of validated and normalized options. If warn is
+    False then errors will be thrown for invalid options, otherwise they will
+    be ignored and a warning will be issued.
+
+    :param opts: A dict of MongoDB URI options.
+    :param warn: If ``True`` then warnings will be logged and
+        invalid options will be ignored. Otherwise invalid options will
+        cause errors.
+    """
+    return get_validated_options(opts, warn)
+
+
+def split_options(
+    opts: str, validate: bool = True, warn: bool = False, normalize: bool = True
+) -> MutableMapping[str, Any]:
+    """Takes the options portion of a MongoDB URI, validates each option
+    and returns the options in a dictionary.
+
+    :param opts: A string representing MongoDB URI options.
+    :param validate: If ``True`` (the default), validate and normalize all
+        options.
+    :param warn: If ``False`` (the default), suppress all warnings raised
+        during validation of options.
+    :param normalize: If ``True`` (the default), renames all options to their
+        internally-used names.
+    """
+    and_idx = opts.find("&")
+    semi_idx = opts.find(";")
+    try:
+        if and_idx >= 0 and semi_idx >= 0:
+            raise InvalidURI("Can not mix '&' and ';' for option separators.")
+        elif and_idx >= 0:
+            options = _parse_options(opts, "&")
+        elif semi_idx >= 0:
+            options = _parse_options(opts, ";")
+        elif opts.find("=") != -1:
+            options = _parse_options(opts, None)
+        else:
+            raise ValueError
+    except ValueError:
+        raise InvalidURI("MongoDB URI options are key=value pairs.") from None
+
+    options = _handle_security_options(options)
+
+    options = _handle_option_deprecations(options)
+
+    if normalize:
+        options = _normalize_options(options)
+
+    if validate:
+        options = cast(_CaseInsensitiveDictionary, validate_options(options, warn))
+        if options.get("authsource") == "":
+            raise InvalidURI("the authSource database cannot be an empty string")
+
+    return options
+
+
+def split_hosts(hosts: str, default_port: Optional[int] = DEFAULT_PORT) -> list[_Address]:
+    """Takes a string of the form host1[:port],host2[:port]... and
+    splits it into (host, port) tuples. If [:port] isn't present the
+    default_port is used.
+
+    Returns a list of 2-tuples containing the host name (or IP) followed by
+    port number.
+
+    :param hosts: A string of the form host1[:port],host2[:port],...
+    :param default_port: The port number to use when one wasn't specified
+        for a host.
+    """
+    nodes = []
+    for entity in hosts.split(","):
+        if not entity:
+            raise ConfigurationError("Empty host (or extra comma in host list).")
+        port = default_port
+        # Unix socket entities don't have ports
+        if entity.endswith(".sock"):
+            port = None
+        nodes.append(parse_host(entity, port))
+    return nodes
+
+
+# Prohibited characters in database name. DB names also can't have ".", but for
+# backward-compat we allow "db.collection" in URI.
+_BAD_DB_CHARS = re.compile("[" + re.escape(r'/ "$') + "]")
+
+_ALLOWED_TXT_OPTS = frozenset(
+    ["authsource", "authSource", "replicaset", "replicaSet", "loadbalanced", "loadBalanced"]
+)
+
+
+def _check_options(nodes: Sized, options: Mapping[str, Any]) -> None:
+    # Ensure directConnection was not True if there are multiple seeds.
+    if len(nodes) > 1 and options.get("directconnection"):
+        raise ConfigurationError("Cannot specify multiple hosts with directConnection=true")
+
+    if options.get("loadbalanced"):
+        if len(nodes) > 1:
+            raise ConfigurationError("Cannot specify multiple hosts with loadBalanced=true")
+        if options.get("directconnection"):
+            raise ConfigurationError("Cannot specify directConnection=true with loadBalanced=true")
+        if options.get("replicaset"):
+            raise ConfigurationError("Cannot specify replicaSet with loadBalanced=true")
+
+
+def parse_uri(
+    uri: str,
+    default_port: Optional[int] = DEFAULT_PORT,
+    validate: bool = True,
+    warn: bool = False,
+    normalize: bool = True,
+    connect_timeout: Optional[float] = None,
+    srv_service_name: Optional[str] = None,
+    srv_max_hosts: Optional[int] = None,
+) -> dict[str, Any]:
+    """Parse and validate a MongoDB URI.
+
+    Returns a dict of the form::
+
+        {
+            'nodelist': <list of (host, port) tuples>,
+            'username': <username> or None,
+            'password': <password> or None,
+            'database': <database name> or None,
+            'collection': <collection name> or None,
+            'options': <dict of MongoDB URI options>,
+            'fqdn': <fqdn of the MongoDB+SRV URI> or None
+        }
+
+    If the URI scheme is "mongodb+srv://" DNS SRV and TXT lookups will be done
+    to build nodelist and options.
+
+    :param uri: The MongoDB URI to parse.
+    :param default_port: The port number to use when one wasn't specified
+        for a host in the URI.
+ :param validate: If ``True`` (the default), validate and + normalize all options. Default: ``True``. + :param warn: When validating, if ``True`` then will warn + the user then ignore any invalid options or values. If ``False``, + validation will error when options are unsupported or values are + invalid. Default: ``False``. + :param normalize: If ``True``, convert names of URI options + to their internally-used names. Default: ``True``. + :param connect_timeout: The maximum time in milliseconds to + wait for a response from the DNS server. + :param srv_service_name: A custom SRV service name + + .. versionchanged:: 4.6 + The delimiting slash (``/``) between hosts and connection options is now optional. + For example, "mongodb://example.com?tls=true" is now a valid URI. + + .. versionchanged:: 4.0 + To better follow RFC 3986, unquoted percent signs ("%") are no longer + supported. + + .. versionchanged:: 3.9 + Added the ``normalize`` parameter. + + .. versionchanged:: 3.6 + Added support for mongodb+srv:// URIs. + + .. versionchanged:: 3.5 + Return the original value of the ``readPreference`` MongoDB URI option + instead of the validated read preference mode. + + .. versionchanged:: 3.1 + ``warn`` added so invalid options can be ignored. + """ + if uri.startswith(SCHEME): + is_srv = False + scheme_free = uri[SCHEME_LEN:] + elif uri.startswith(SRV_SCHEME): + if not _have_dnspython(): + python_path = sys.executable or "python" + raise ConfigurationError( + 'The "dnspython" module must be ' + "installed to use mongodb+srv:// URIs. " + "To fix this error install pymongo again:\n " + "%s -m pip install pymongo>=4.3" % (python_path) + ) + is_srv = True + scheme_free = uri[SRV_SCHEME_LEN:] + else: + raise InvalidURI(f"Invalid URI scheme: URI must begin with '{SCHEME}' or '{SRV_SCHEME}'") + + if not scheme_free: + raise InvalidURI("Must provide at least one hostname or IP.") + + user = None + passwd = None + dbase = None + collection = None + options = _CaseInsensitiveDictionary() + + host_plus_db_part, _, opts = scheme_free.partition("?") + if "/" in host_plus_db_part: + host_part, _, dbase = host_plus_db_part.partition("/") + else: + host_part = host_plus_db_part + + if dbase: + dbase = unquote_plus(dbase) + if "." in dbase: + dbase, collection = dbase.split(".", 1) + if _BAD_DB_CHARS.search(dbase): + raise InvalidURI('Bad database name "%s"' % dbase) + else: + dbase = None + + if opts: + options.update(split_options(opts, validate, warn, normalize)) + if srv_service_name is None: + srv_service_name = options.get("srvServiceName", SRV_SERVICE_NAME) + if "@" in host_part: + userinfo, _, hosts = host_part.rpartition("@") + user, passwd = parse_userinfo(userinfo) + else: + hosts = host_part + + if "/" in hosts: + raise InvalidURI("Any '/' in a unix domain socket must be percent-encoded: %s" % host_part) + + hosts = unquote_plus(hosts) + fqdn = None + srv_max_hosts = srv_max_hosts or options.get("srvMaxHosts") + if is_srv: + if options.get("directConnection"): + raise ConfigurationError(f"Cannot specify directConnection=true with {SRV_SCHEME} URIs") + nodes = split_hosts(hosts, default_port=None) + if len(nodes) != 1: + raise InvalidURI(f"{SRV_SCHEME} URIs must include one, and only one, hostname") + fqdn, port = nodes[0] + if port is not None: + raise InvalidURI(f"{SRV_SCHEME} URIs must not include a port number") + + # Use the connection timeout. connectTimeoutMS passed as a keyword + # argument overrides the same option passed in the connection string. 
+ connect_timeout = connect_timeout or options.get("connectTimeoutMS") + dns_resolver = _SrvResolver(fqdn, connect_timeout, srv_service_name, srv_max_hosts) + nodes = dns_resolver.get_hosts() + dns_options = dns_resolver.get_options() + if dns_options: + parsed_dns_options = split_options(dns_options, validate, warn, normalize) + if set(parsed_dns_options) - _ALLOWED_TXT_OPTS: + raise ConfigurationError( + "Only authSource, replicaSet, and loadBalanced are supported from DNS" + ) + for opt, val in parsed_dns_options.items(): + if opt not in options: + options[opt] = val + if options.get("loadBalanced") and srv_max_hosts: + raise InvalidURI("You cannot specify loadBalanced with srvMaxHosts") + if options.get("replicaSet") and srv_max_hosts: + raise InvalidURI("You cannot specify replicaSet with srvMaxHosts") + if "tls" not in options and "ssl" not in options: + options["tls"] = True if validate else "true" + elif not is_srv and options.get("srvServiceName") is not None: + raise ConfigurationError( + "The srvServiceName option is only allowed with 'mongodb+srv://' URIs" + ) + elif not is_srv and srv_max_hosts: + raise ConfigurationError( + "The srvMaxHosts option is only allowed with 'mongodb+srv://' URIs" + ) + else: + nodes = split_hosts(hosts, default_port=default_port) + + _check_options(nodes, options) + + return { + "nodelist": nodes, + "username": user, + "password": passwd, + "database": dbase, + "collection": collection, + "options": options, + "fqdn": fqdn, + } + + +def _parse_kms_tls_options(kms_tls_options: Optional[Mapping[str, Any]]) -> dict[str, SSLContext]: + """Parse KMS TLS connection options.""" + if not kms_tls_options: + return {} + if not isinstance(kms_tls_options, dict): + raise TypeError("kms_tls_options must be a dict") + contexts = {} + for provider, options in kms_tls_options.items(): + if not isinstance(options, dict): + raise TypeError(f'kms_tls_options["{provider}"] must be a dict') + options.setdefault("tls", True) + opts = _CaseInsensitiveDictionary(options) + opts = _handle_security_options(opts) + opts = _normalize_options(opts) + opts = cast(_CaseInsensitiveDictionary, validate_options(opts)) + ssl_context, allow_invalid_hostnames = _parse_ssl_options(opts) + if ssl_context is None: + raise ConfigurationError("TLS is required for KMS providers") + if allow_invalid_hostnames: + raise ConfigurationError("Insecure TLS options prohibited") + + for n in [ + "tlsInsecure", + "tlsAllowInvalidCertificates", + "tlsAllowInvalidHostnames", + "tlsDisableCertificateRevocationCheck", + ]: + if n in opts: + raise ConfigurationError(f"Insecure TLS options prohibited: {n}") + contexts[provider] = ssl_context + return contexts + + +if __name__ == "__main__": + import pprint -__doc__ = original_doc + try: + pprint.pprint(parse_uri(sys.argv[1])) # noqa: T203 + except InvalidURI as exc: + print(exc) # noqa: T201 + sys.exit(0) diff --git a/test/__init__.py b/test/__init__.py index a78fab3ca1..9b6368f4de 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -45,14 +45,15 @@ import pymongo import pymongo.errors from bson.son import SON +from pymongo import common +from pymongo.common import partition_node +from pymongo.hello_compat import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous import common, message -from pymongo.synchronous.common import partition_node +from pymongo.synchronous import message from pymongo.synchronous.database import Database -from 
pymongo.synchronous.hello_compat import HelloCompat from pymongo.synchronous.mongo_client import MongoClient -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser import parse_uri if HAVE_SSL: import ssl diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index d38065eb3f..d63ed77232 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -63,25 +63,20 @@ from contextlib import asynccontextmanager, contextmanager from functools import wraps from test.version import Version -from typing import Any, Callable, Dict, Generator, no_type_check +from typing import Any, Callable, Dict, Generator from unittest import SkipTest from urllib.parse import quote_plus import pymongo import pymongo.errors from bson.son import SON -from pymongo.asynchronous import common, message -from pymongo.asynchronous.common import partition_node from pymongo.asynchronous.database import AsyncDatabase -from pymongo.asynchronous.hello_compat import HelloCompat from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.uri_parser import parse_uri +from pymongo.common import partition_node +from pymongo.hello_compat import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -if HAVE_SSL: - import ssl - _IS_SYNC = False diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 078bad9e20..07a3d25cfe 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -40,10 +40,9 @@ wait_until, ) -from bson import encode +from bson import RawBSONDocument, encode from bson.codec_options import CodecOptions from bson.objectid import ObjectId -from bson.raw_bson import RawBSONDocument from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT @@ -53,8 +52,6 @@ from pymongo.asynchronous.helpers import anext from pymongo.asynchronous.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.asynchronous.mongo_client import AsyncMongoClient -from pymongo.asynchronous.operations import * -from pymongo.asynchronous.read_preferences import ReadPreference from pymongo.cursor_shared import CursorType from pymongo.errors import ( ConfigurationError, @@ -67,7 +64,9 @@ OperationFailure, WriteConcernError, ) +from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN +from pymongo.read_preferences import ReadPreference from pymongo.results import ( DeleteResult, InsertManyResult, @@ -1642,7 +1641,7 @@ async def try_invalid_session(): with await self.db.test.aggregate([], {}): # type:ignore pass - with self.assertRaisesRegex(ValueError, "must be a ClientSession"): + with self.assertRaisesRegex(ValueError, "must be an AsyncClientSession"): await try_invalid_session() async def test_large_limit(self): diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 3fb2894783..76676eb95e 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -36,13 +36,13 @@ from pymongo._gcp_helpers import _get_gcp_response from pymongo.cursor_shared import CursorType from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure +from pymongo.operations import InsertOne from pymongo.synchronous.auth_oidc import ( OIDCCallback, OIDCCallbackContext, OIDCCallbackResult, ) from pymongo.synchronous.hello_compat import HelloCompat -from 
pymongo.synchronous.operations import InsertOne from pymongo.synchronous.uri_parser import parse_uri ROOT = Path(__file__).parent.parent.resolve() diff --git a/test/lambda/mongodb/app.py b/test/lambda/mongodb/app.py index deb26bdf1e..5840347d9a 100644 --- a/test/lambda/mongodb/app.py +++ b/test/lambda/mongodb/app.py @@ -12,7 +12,7 @@ from bson import has_c as has_bson_c from pymongo import MongoClient from pymongo import has_c as has_pymongo_c -from pymongo.synchronous.monitoring import ( +from pymongo.monitoring import ( CommandListener, ConnectionPoolListener, ServerHeartbeatListener, diff --git a/test/mockupdb/test_mongos_command_read_mode.py b/test/mockupdb/test_mongos_command_read_mode.py index 1e91384dc4..8ee33431a8 100644 --- a/test/mockupdb/test_mongos_command_read_mode.py +++ b/test/mockupdb/test_mongos_command_read_mode.py @@ -20,7 +20,7 @@ from operations import operations # type: ignore[import] from pymongo import MongoClient, ReadPreference -from pymongo.synchronous.read_preferences import ( +from pymongo.read_preferences import ( _MONGOS_MODES, make_read_preference, read_pref_mode_from_name, diff --git a/test/mockupdb/test_network_disconnect_primary.py b/test/mockupdb/test_network_disconnect_primary.py index 36e004c05a..d05cfb531a 100644 --- a/test/mockupdb/test_network_disconnect_primary.py +++ b/test/mockupdb/test_network_disconnect_primary.py @@ -19,7 +19,7 @@ from pymongo import MongoClient from pymongo.errors import ConnectionFailure -from pymongo.synchronous.topology_description import TOPOLOGY_TYPE +from pymongo.topology_description import TOPOLOGY_TYPE class TestNetworkDisconnectPrimary(unittest.TestCase): diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py index aa2437f230..d36e5e02b6 100644 --- a/test/mockupdb/test_op_msg.py +++ b/test/mockupdb/test_op_msg.py @@ -20,7 +20,7 @@ from pymongo import MongoClient, WriteConcern from pymongo.cursor_shared import CursorType -from pymongo.synchronous.operations import DeleteOne, InsertOne, UpdateOne +from pymongo.operations import DeleteOne, InsertOne, UpdateOne Operation = namedtuple("Operation", ["name", "function", "request", "reply"]) diff --git a/test/mockupdb/test_op_msg_read_preference.py b/test/mockupdb/test_op_msg_read_preference.py index 36b8f4fbee..0fa7b84861 100644 --- a/test/mockupdb/test_op_msg_read_preference.py +++ b/test/mockupdb/test_op_msg_read_preference.py @@ -22,7 +22,7 @@ from operations import operations # type: ignore[import] from pymongo import MongoClient, ReadPreference -from pymongo.synchronous.read_preferences import ( +from pymongo.read_preferences import ( _MONGOS_MODES, make_read_preference, read_pref_mode_from_name, diff --git a/test/mockupdb/test_query_read_pref_sharded.py b/test/mockupdb/test_query_read_pref_sharded.py index 9eb4de28c8..5297709886 100644 --- a/test/mockupdb/test_query_read_pref_sharded.py +++ b/test/mockupdb/test_query_read_pref_sharded.py @@ -21,7 +21,7 @@ from bson import SON from pymongo import MongoClient -from pymongo.synchronous.read_preferences import ( +from pymongo.read_preferences import ( Nearest, Primary, PrimaryPreferred, diff --git a/test/mockupdb/test_reset_and_request_check.py b/test/mockupdb/test_reset_and_request_check.py index 080110020a..19dfb9e395 100644 --- a/test/mockupdb/test_reset_and_request_check.py +++ b/test/mockupdb/test_reset_and_request_check.py @@ -22,8 +22,8 @@ from pymongo import MongoClient from pymongo.errors import ConnectionFailure +from pymongo.operations import _Op from pymongo.server_type import SERVER_TYPE 
diff --git a/test/mockupdb/test_mongos_command_read_mode.py b/test/mockupdb/test_mongos_command_read_mode.py
index 1e91384dc4..8ee33431a8 100644
--- a/test/mockupdb/test_mongos_command_read_mode.py
+++ b/test/mockupdb/test_mongos_command_read_mode.py
@@ -20,7 +20,7 @@
 from operations import operations  # type: ignore[import]

 from pymongo import MongoClient, ReadPreference
-from pymongo.synchronous.read_preferences import (
+from pymongo.read_preferences import (
     _MONGOS_MODES,
     make_read_preference,
     read_pref_mode_from_name,
diff --git a/test/mockupdb/test_network_disconnect_primary.py b/test/mockupdb/test_network_disconnect_primary.py
index 36e004c05a..d05cfb531a 100644
--- a/test/mockupdb/test_network_disconnect_primary.py
+++ b/test/mockupdb/test_network_disconnect_primary.py
@@ -19,7 +19,7 @@
 from pymongo import MongoClient
 from pymongo.errors import ConnectionFailure
-from pymongo.synchronous.topology_description import TOPOLOGY_TYPE
+from pymongo.topology_description import TOPOLOGY_TYPE


 class TestNetworkDisconnectPrimary(unittest.TestCase):
diff --git a/test/mockupdb/test_op_msg.py b/test/mockupdb/test_op_msg.py
index aa2437f230..d36e5e02b6 100644
--- a/test/mockupdb/test_op_msg.py
+++ b/test/mockupdb/test_op_msg.py
@@ -20,7 +20,7 @@
 from pymongo import MongoClient, WriteConcern
 from pymongo.cursor_shared import CursorType
-from pymongo.synchronous.operations import DeleteOne, InsertOne, UpdateOne
+from pymongo.operations import DeleteOne, InsertOne, UpdateOne

 Operation = namedtuple("Operation", ["name", "function", "request", "reply"])
diff --git a/test/mockupdb/test_op_msg_read_preference.py b/test/mockupdb/test_op_msg_read_preference.py
index 36b8f4fbee..0fa7b84861 100644
--- a/test/mockupdb/test_op_msg_read_preference.py
+++ b/test/mockupdb/test_op_msg_read_preference.py
@@ -22,7 +22,7 @@
 from operations import operations  # type: ignore[import]

 from pymongo import MongoClient, ReadPreference
-from pymongo.synchronous.read_preferences import (
+from pymongo.read_preferences import (
     _MONGOS_MODES,
     make_read_preference,
     read_pref_mode_from_name,
diff --git a/test/mockupdb/test_query_read_pref_sharded.py b/test/mockupdb/test_query_read_pref_sharded.py
index 9eb4de28c8..5297709886 100644
--- a/test/mockupdb/test_query_read_pref_sharded.py
+++ b/test/mockupdb/test_query_read_pref_sharded.py
@@ -21,7 +21,7 @@
 from bson import SON
 from pymongo import MongoClient
-from pymongo.synchronous.read_preferences import (
+from pymongo.read_preferences import (
     Nearest,
     Primary,
     PrimaryPreferred,
diff --git a/test/mockupdb/test_reset_and_request_check.py b/test/mockupdb/test_reset_and_request_check.py
index 080110020a..19dfb9e395 100644
--- a/test/mockupdb/test_reset_and_request_check.py
+++ b/test/mockupdb/test_reset_and_request_check.py
@@ -22,8 +22,8 @@
 from pymongo import MongoClient
 from pymongo.errors import ConnectionFailure
+from pymongo.operations import _Op
 from pymongo.server_type import SERVER_TYPE
-from pymongo.synchronous.operations import _Op


 class TestResetAndRequestCheck(unittest.TestCase):
diff --git a/test/mockupdb/test_slave_okay_sharded.py b/test/mockupdb/test_slave_okay_sharded.py
index 9692465d56..45b7d51ba0 100644
--- a/test/mockupdb/test_slave_okay_sharded.py
+++ b/test/mockupdb/test_slave_okay_sharded.py
@@ -28,7 +28,7 @@
 from operations import operations  # type: ignore[import]

 from pymongo import MongoClient
-from pymongo.synchronous.read_preferences import make_read_preference, read_pref_mode_from_name
+from pymongo.read_preferences import make_read_preference, read_pref_mode_from_name


 class TestSlaveOkaySharded(unittest.TestCase):
diff --git a/test/mockupdb/test_slave_okay_single.py b/test/mockupdb/test_slave_okay_single.py
index bf1cdee74b..b03232807e 100644
--- a/test/mockupdb/test_slave_okay_single.py
+++ b/test/mockupdb/test_slave_okay_single.py
@@ -27,8 +27,8 @@
 from operations import operations  # type: ignore[import]

 from pymongo import MongoClient
-from pymongo.synchronous.read_preferences import make_read_preference, read_pref_mode_from_name
-from pymongo.synchronous.topology_description import TOPOLOGY_TYPE
+from pymongo.read_preferences import make_read_preference, read_pref_mode_from_name
+from pymongo.topology_description import TOPOLOGY_TYPE


 def topology_type_name(client):
diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py
index d3c1a271cd..90914927cb 100644
--- a/test/pymongo_mocks.py
+++ b/test/pymongo_mocks.py
@@ -20,13 +20,14 @@
 from functools import partial
 from test import client_context

+from pymongo import common
 from pymongo.errors import AutoReconnect, NetworkTimeout
-from pymongo.synchronous import common
-from pymongo.synchronous.hello import Hello, HelloCompat
+from pymongo.hello import Hello
+from pymongo.hello_compat import HelloCompat
+from pymongo.server_description import ServerDescription
 from pymongo.synchronous.mongo_client import MongoClient
 from pymongo.synchronous.monitor import Monitor
 from pymongo.synchronous.pool import Pool
-from pymongo.synchronous.server_description import ServerDescription


 class MockPool(Pool):
diff --git a/test/sigstop_sigcont.py b/test/sigstop_sigcont.py
index c5084f5943..bc1bacce33 100644
--- a/test/sigstop_sigcont.py
+++ b/test/sigstop_sigcont.py
@@ -21,8 +21,8 @@
 sys.path[0:0] = [""]

+from pymongo import monitoring
 from pymongo.server_api import ServerApi
-from pymongo.synchronous import monitoring
 from pymongo.synchronous.mongo_client import MongoClient

 SERVER_API = None
diff --git a/test/synchronous/__init__.py b/test/synchronous/__init__.py
index 6eb11eee85..1320561c8c 100644
--- a/test/synchronous/__init__.py
+++ b/test/synchronous/__init__.py
@@ -63,24 +63,19 @@
 from contextlib import contextmanager
 from functools import wraps
 from test.version import Version
-from typing import Any, Callable, Dict, Generator, no_type_check
+from typing import Any, Callable, Dict, Generator
 from unittest import SkipTest
 from urllib.parse import quote_plus

 import pymongo
 import pymongo.errors
 from bson.son import SON
+from pymongo.common import partition_node
+from pymongo.hello_compat import HelloCompat
 from pymongo.server_api import ServerApi
 from pymongo.ssl_support import HAVE_SSL, _ssl  # type:ignore[attr-defined]
-from pymongo.synchronous import common, message
-from pymongo.synchronous.common import partition_node
 from pymongo.synchronous.database import Database
-from pymongo.synchronous.hello_compat import HelloCompat
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.uri_parser import parse_uri
-
-if HAVE_SSL:
-    import ssl

 _IS_SYNC = True
diff --git a/test/synchronous/test_collection.py b/test/synchronous/test_collection.py
index 39d7e13a31..d0de0b0608 100644
--- a/test/synchronous/test_collection.py
+++ b/test/synchronous/test_collection.py
@@ -39,10 +39,9 @@
     wait_until,
 )

-from bson import encode
+from bson import RawBSONDocument, encode
 from bson.codec_options import CodecOptions
 from bson.objectid import ObjectId
-from bson.raw_bson import RawBSONDocument
 from bson.regex import Regex
 from bson.son import SON
 from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT
@@ -58,7 +57,9 @@
     OperationFailure,
     WriteConcernError,
 )
+from pymongo.operations import *
 from pymongo.read_concern import DEFAULT_READ_CONCERN
+from pymongo.read_preferences import ReadPreference
 from pymongo.results import (
     DeleteResult,
     InsertManyResult,
@@ -71,8 +72,6 @@
 from pymongo.synchronous.helpers import next
 from pymongo.synchronous.message import _COMMAND_OVERHEAD, _gen_find_command
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.operations import *
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern

 _IS_SYNC = True
@@ -1617,7 +1616,7 @@ def try_invalid_session():
             with self.db.test.aggregate([], {}):  # type:ignore
                 pass

-        with self.assertRaisesRegex(ValueError, "must be a ClientSession"):
+        with self.assertRaisesRegex(ValueError, "must be an AsyncClientSession"):
             try_invalid_session()

     def test_large_limit(self):
diff --git a/test/test_auth.py b/test/test_auth.py
index 6bc58e08c7..29cac352fd 100644
--- a/test/test_auth.py
+++ b/test/test_auth.py
@@ -33,13 +33,13 @@
     single_client_noauth,
 )

-from pymongo import MongoClient
-from pymongo.asynchronous.auth import HAVE_KERBEROS, _build_credentials_tuple
+from pymongo import MongoClient, monitoring
+from pymongo.asynchronous.auth import HAVE_KERBEROS
+from pymongo.auth_shared import _build_credentials_tuple
 from pymongo.errors import OperationFailure
+from pymongo.hello_compat import HelloCompat
+from pymongo.read_preferences import ReadPreference
 from pymongo.saslprep import HAVE_STRINGPREP
-from pymongo.synchronous import monitoring
-from pymongo.synchronous.hello_compat import HelloCompat
-from pymongo.synchronous.read_preferences import ReadPreference

 # YOU MUST RUN KINIT BEFORE RUNNING GSSAPI TESTS ON UNIX.
 GSSAPI_HOST = os.environ.get("GSSAPI_HOST")
diff --git a/test/test_binary.py b/test/test_binary.py
index 66a57dcb54..93f6d08315 100644
--- a/test/test_binary.py
+++ b/test/test_binary.py
@@ -33,7 +33,7 @@
 from bson.binary import *
 from bson.codec_options import CodecOptions
 from bson.son import SON
-from pymongo.synchronous.common import validate_uuid_representation
+from pymongo.common import validate_uuid_representation
 from pymongo.synchronous.mongo_client import MongoClient
 from pymongo.write_concern import WriteConcern
diff --git a/test/test_bulk.py b/test/test_bulk.py
index 42dbf5b152..c0f8594431 100644
--- a/test/test_bulk.py
+++ b/test/test_bulk.py
@@ -34,15 +34,15 @@
 from bson.binary import Binary, UuidRepresentation
 from bson.codec_options import CodecOptions
 from bson.objectid import ObjectId
+from pymongo.common import partition_node
 from pymongo.errors import (
     BulkWriteError,
     ConfigurationError,
     InvalidOperation,
     OperationFailure,
 )
+from pymongo.operations import *
 from pymongo.synchronous.collection import Collection
-from pymongo.synchronous.common import partition_node
-from pymongo.synchronous.operations import *
 from pymongo.write_concern import WriteConcern
diff --git a/test/test_client.py b/test/test_client.py
index af71c4890e..e21a899d02 100644
--- a/test/test_client.py
+++ b/test/test_client.py
@@ -37,7 +37,7 @@

 import pytest

-from pymongo.synchronous.operations import _Op
+from pymongo.operations import _Op

 sys.path[0:0] = [""]

@@ -83,6 +83,10 @@
 )
 from bson.son import SON
 from bson.tz_util import utc
+from pymongo import event_loggers, monitoring
+from pymongo.client_options import ClientOptions
+from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT
+from pymongo.compression_support import _have_snappy, _have_zstd
 from pymongo.driver_info import DriverInfo
 from pymongo.errors import (
     AutoReconnect,
@@ -96,28 +100,23 @@
     ServerSelectionTimeoutError,
     WriteConcernError,
 )
+from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
+from pymongo.pool_options import _METADATA, ENV_VAR_K8S, PoolOptions
+from pymongo.read_preferences import ReadPreference
+from pymongo.server_description import ServerDescription
+from pymongo.server_selectors import readable_server_selector, writable_server_selector
 from pymongo.server_type import SERVER_TYPE
-from pymongo.synchronous import event_loggers, message, monitoring
-from pymongo.synchronous.client_options import ClientOptions
+from pymongo.synchronous import message
 from pymongo.synchronous.command_cursor import CommandCursor
-from pymongo.synchronous.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT
-from pymongo.synchronous.compression_support import _have_snappy, _have_zstd
 from pymongo.synchronous.cursor import Cursor, CursorType
 from pymongo.synchronous.database import Database
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent
 from pymongo.synchronous.pool import (
-    _METADATA,
-    ENV_VAR_K8S,
     Connection,
-    PoolOptions,
 )
-from pymongo.synchronous.read_preferences import ReadPreference
-from pymongo.synchronous.server_description import ServerDescription
-from pymongo.synchronous.server_selectors import readable_server_selector, writable_server_selector
 from pymongo.synchronous.settings import TOPOLOGY_TYPE
 from pymongo.synchronous.topology import _ErrorContext
-from pymongo.synchronous.topology_description import TopologyDescription
+from pymongo.topology_description import TopologyDescription
 from pymongo.write_concern import WriteConcern

@@ -462,13 +461,13 @@ def test_uri_option_precedence(self):
     def test_connection_timeout_ms_propagates_to_DNS_resolver(self):
         # Patch the resolver.
-        from pymongo.synchronous.srv_resolver import _resolve
+        from pymongo.srv_resolver import _resolve

         patched_resolver = FunctionCallRecorder(_resolve)
-        pymongo.synchronous.srv_resolver._resolve = patched_resolver
+        pymongo.srv_resolver._resolve = patched_resolver

         def reset_resolver():
-            pymongo.synchronous.srv_resolver._resolve = _resolve
+            pymongo.srv_resolver._resolve = _resolve

         self.addCleanup(reset_resolver)

@@ -557,7 +556,7 @@ def test_validate_suggestion(self):
         with self.assertRaisesRegex(ConfigurationError, expected):
             MongoClient(**{typo: "standard"})  # type: ignore[arg-type]

-    @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
+    @patch("pymongo.srv_resolver._SrvResolver.get_hosts")
     def test_detected_environment_logging(self, mock_get_hosts):
         normal_hosts = [
             "normal.host.com",
@@ -579,7 +578,7 @@ def test_detected_environment_logging(self, mock_get_hosts):
         logs = [record.message for record in cm.records if record.name == "pymongo.client"]
         self.assertEqual(len(logs), 7)

-    @patch("pymongo.synchronous.srv_resolver._SrvResolver.get_hosts")
+    @patch("pymongo.srv_resolver._SrvResolver.get_hosts")
     def test_detected_environment_warning(self, mock_get_hosts):
         with self._caplog.at_level(logging.WARN):
             normal_hosts = [
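As the hunks above show, tests now patch SRV resolution at pymongo.srv_resolver
rather than pymongo.synchronous.srv_resolver. A sketch of the same pattern with
unittest.mock (the hostname is made up; _SrvResolver.get_hosts is the private
hook these tests already target):

    from unittest.mock import patch

    with patch("pymongo.srv_resolver._SrvResolver.get_hosts") as mock_get_hosts:
        # Any client created inside this block resolves SRV records to the fake host.
        mock_get_hosts.return_value = [("fake.example.com", 27017)]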
diff --git a/test/test_collation.py b/test/test_collation.py
index f4830da5d2..bedf0a2eaa 100644
--- a/test/test_collation.py
+++ b/test/test_collation.py
@@ -21,15 +21,15 @@
 from test.utils import EventListener, rs_or_single_client
 from typing import Any

-from pymongo.errors import ConfigurationError
-from pymongo.synchronous.collation import (
+from pymongo.collation import (
     Collation,
     CollationAlternate,
     CollationCaseFirst,
     CollationMaxVariable,
     CollationStrength,
 )
-from pymongo.synchronous.operations import (
+from pymongo.errors import ConfigurationError
+from pymongo.operations import (
     DeleteMany,
     DeleteOne,
     IndexModel,
diff --git a/test/test_collection.py b/test/test_collection.py
index 54f76336d5..5495e659b4 100644
--- a/test/test_collection.py
+++ b/test/test_collection.py
@@ -57,7 +57,9 @@
     OperationFailure,
     WriteConcernError,
 )
+from pymongo.operations import *
 from pymongo.read_concern import DEFAULT_READ_CONCERN
+from pymongo.read_preferences import ReadPreference
 from pymongo.results import (
     DeleteResult,
     InsertManyResult,
@@ -69,8 +71,6 @@
 from pymongo.synchronous.command_cursor import CommandCursor
 from pymongo.synchronous.message import _COMMAND_OVERHEAD, _gen_find_command
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.operations import *
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern
diff --git a/test/test_comment.py b/test/test_comment.py
index f9630655c9..931446ef3a 100644
--- a/test/test_comment.py
+++ b/test/test_comment.py
@@ -25,8 +25,8 @@
 from test.utils import EventListener, rs_or_single_client

 from bson.dbref import DBRef
+from pymongo.operations import IndexModel
 from pymongo.synchronous.command_cursor import CommandCursor
-from pymongo.synchronous.operations import IndexModel


 class Empty:
diff --git a/test/test_connection_monitoring.py b/test/test_connection_monitoring.py
index 8a0f104a79..5225739f30 100644
--- a/test/test_connection_monitoring.py
+++ b/test/test_connection_monitoring.py
@@ -45,7 +45,7 @@
     PyMongoError,
     WaitQueueTimeoutError,
 )
-from pymongo.synchronous.monitoring import (
+from pymongo.monitoring import (
     ConnectionCheckedInEvent,
     ConnectionCheckedOutEvent,
     ConnectionCheckOutFailedEvent,
@@ -60,9 +60,9 @@
     PoolCreatedEvent,
     PoolReadyEvent,
 )
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.pool import PoolState, _PoolClosedError
-from pymongo.synchronous.read_preferences import ReadPreference
-from pymongo.synchronous.topology_description import updated_topology_description
+from pymongo.topology_description import updated_topology_description

 OBJECT_TYPES = {
     # Event types.
diff --git a/test/test_connections_survive_primary_stepdown_spec.py b/test/test_connections_survive_primary_stepdown_spec.py
index bb80bda932..674612693c 100644
--- a/test/test_connections_survive_primary_stepdown_spec.py
+++ b/test/test_connections_survive_primary_stepdown_spec.py
@@ -28,8 +28,8 @@
 )

 from bson import SON
+from pymongo import monitoring
 from pymongo.errors import NotPrimaryError
-from pymongo.synchronous import monitoring
 from pymongo.synchronous.collection import Collection
 from pymongo.write_concern import WriteConcern
diff --git a/test/test_crud_v1.py b/test/test_crud_v1.py
index b13e4c8444..d528a1dfe7 100644
--- a/test/test_crud_v1.py
+++ b/test/test_crud_v1.py
@@ -29,14 +29,9 @@
     drop_collections,
 )

-from pymongo import WriteConcern
+from pymongo import WriteConcern, operations
 from pymongo.errors import PyMongoError
-from pymongo.read_concern import ReadConcern
-from pymongo.results import BulkWriteResult, _WriteResult
-from pymongo.synchronous import operations
-from pymongo.synchronous.command_cursor import CommandCursor
-from pymongo.synchronous.cursor import Cursor
-from pymongo.synchronous.operations import (
+from pymongo.operations import (
     DeleteMany,
     DeleteOne,
     InsertOne,
@@ -44,6 +39,10 @@
     UpdateMany,
     UpdateOne,
 )
+from pymongo.read_concern import ReadConcern
+from pymongo.results import BulkWriteResult, _WriteResult
+from pymongo.synchronous.command_cursor import CommandCursor
+from pymongo.synchronous.cursor import Cursor

 # Location of JSON test specifications.
 _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "crud", "v1")
diff --git a/test/test_cursor.py b/test/test_cursor.py
index c354c42b33..8a6f1f4043 100644
--- a/test/test_cursor.py
+++ b/test/test_cursor.py
@@ -41,14 +41,13 @@
 from bson import decode_all
 from bson.code import Code
-from bson.son import SON
 from pymongo import ASCENDING, DESCENDING
+from pymongo.collation import Collation
 from pymongo.errors import ExecutionTimeout, InvalidOperation, OperationFailure
+from pymongo.operations import _IndexList
 from pymongo.read_concern import ReadConcern
-from pymongo.synchronous.collation import Collation
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.cursor import Cursor, CursorType
-from pymongo.synchronous.operations import _IndexList
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern
diff --git a/test/test_database.py b/test/test_database.py
index 1520a4cc55..90bf07881d 100644
--- a/test/test_database.py
+++ b/test/test_database.py
@@ -38,6 +38,7 @@
 from bson.objectid import ObjectId
 from bson.regex import Regex
 from bson.son import SON
+from pymongo import helpers_shared
 from pymongo.asynchronous import auth
 from pymongo.errors import (
     CollectionInvalid,
@@ -48,11 +49,10 @@
     WriteConcernError,
 )
 from pymongo.read_concern import ReadConcern
-from pymongo.synchronous import helpers
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.collection import Collection
 from pymongo.synchronous.database import Database
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern

@@ -577,10 +577,10 @@ def test_command_response_without_ok(self):
         # Sometimes (SERVER-10891) the server's response to a badly-formatted
         # command document will have no 'ok' field. We should raise
         # OperationFailure instead of KeyError.
-        self.assertRaises(OperationFailure, helpers._check_command_response, {}, None)
+        self.assertRaises(OperationFailure, helpers_shared._check_command_response, {}, None)

         try:
-            helpers._check_command_response({"$err": "foo"}, None)
+            helpers_shared._check_command_response({"$err": "foo"}, None)
         except OperationFailure as e:
             self.assertEqual(e.args[0], "foo, full error: {'$err': 'foo'}")
         else:
@@ -594,7 +594,7 @@ def test_mongos_response(self):
         }

         with self.assertRaises(OperationFailure) as context:
-            helpers._check_command_response(error_document, None)
+            helpers_shared._check_command_response(error_document, None)

         self.assertIn("inner", str(context.exception))

@@ -604,7 +604,7 @@
         error_document = {"ok": 0, "errmsg": "outer", "raw": {"shard0/host0,host1": {}}}

         with self.assertRaises(OperationFailure) as context:
-            helpers._check_command_response(error_document, None)
+            helpers_shared._check_command_response(error_document, None)

         self.assertIn("outer", str(context.exception))

@@ -612,7 +612,7 @@
         error_document = {"ok": 0, "errmsg": "outer", "raw": {"shard0/host0,host1": {"ok": 0}}}

         with self.assertRaises(OperationFailure) as context:
-            helpers._check_command_response(error_document, None)
+            helpers_shared._check_command_response(error_document, None)

         self.assertIn("outer", str(context.exception))
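The command-response checker exercised above moved from pymongo.synchronous.helpers
to the shared pymongo.helpers_shared module. A minimal sketch of the behavior those
assertions rely on (the reply document is hand-written, not a real server response;
_check_command_response is a private helper):

    from pymongo.errors import OperationFailure
    from pymongo.helpers_shared import _check_command_response

    try:
        # An "ok": 0 reply is converted into an OperationFailure.
        _check_command_response({"ok": 0, "errmsg": "example failure"}, None)
    except OperationFailure as exc:
        print(exc)  # example failure, full error: {...}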
diff --git a/test/test_default_exports.py b/test/test_default_exports.py
index 91f94c9db4..d9301d2223 100644
--- a/test/test_default_exports.py
+++ b/test/test_default_exports.py
@@ -26,7 +26,7 @@
 GRIDFS_IGNORE = [
     "ASCENDING",
     "DESCENDING",
-    "ClientSession",
+    "AsyncClientSession",
     "Collection",
     "ObjectId",
     "validate_string",
diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py
index 53602eaeca..e584c17f4e 100644
--- a/test/test_discovery_and_monitoring.py
+++ b/test/test_discovery_and_monitoring.py
@@ -40,7 +40,7 @@
 from unittest.mock import patch

 from bson import Timestamp, json_util
-from pymongo import MongoClient
+from pymongo import MongoClient, common, monitoring
 from pymongo.errors import (
     AutoReconnect,
     ConfigurationError,
@@ -48,15 +48,15 @@
     NotPrimaryError,
     OperationFailure,
 )
-from pymongo.synchronous import common, monitoring
-from pymongo.synchronous.hello import Hello, HelloCompat
-from pymongo.synchronous.helpers import _check_command_response, _check_write_command_response
-from pymongo.synchronous.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
-from pymongo.synchronous.server_description import SERVER_TYPE, ServerDescription
+from pymongo.hello import Hello, HelloCompat
+from pymongo.hello_compat import HelloCompat
+from pymongo.helpers_shared import _check_command_response, _check_write_command_response
+from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent
+from pymongo.server_description import SERVER_TYPE, ServerDescription
 from pymongo.synchronous.settings import TopologySettings
 from pymongo.synchronous.topology import Topology, _ErrorContext
-from pymongo.synchronous.topology_description import TOPOLOGY_TYPE
-from pymongo.synchronous.uri_parser import parse_uri
+from pymongo.topology_description import TOPOLOGY_TYPE
+from pymongo.uri_parser import parse_uri

 # Location of JSON test specifications.
 SDAM_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "discovery_and_monitoring")
diff --git a/test/test_dns.py b/test/test_dns.py
index a2d0fd8b4d..b4c5e3684c 100644
--- a/test/test_dns.py
+++ b/test/test_dns.py
@@ -25,10 +25,10 @@
 from test import IntegrationTest, client_context, unittest
 from test.utils import wait_until

+from pymongo.common import validate_read_preference_tags
 from pymongo.errors import ConfigurationError
-from pymongo.synchronous.common import validate_read_preference_tags
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.uri_parser import parse_uri, split_hosts
+from pymongo.uri_parser import parse_uri, split_hosts


 class TestDNSRepl(unittest.TestCase):
diff --git a/test/test_encryption.py b/test/test_encryption.py
index 0e232f4401..a297e0f524 100644
--- a/test/test_encryption.py
+++ b/test/test_encryption.py
@@ -70,6 +70,7 @@
 from bson.son import SON
 from pymongo import ReadPreference
 from pymongo.cursor_shared import CursorType
+from pymongo.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts, RangeOpts
 from pymongo.errors import (
     AutoReconnect,
     BulkWriteError,
@@ -82,11 +83,10 @@
     ServerSelectionTimeoutError,
     WriteError,
 )
+from pymongo.operations import InsertOne, ReplaceOne, UpdateOne
 from pymongo.synchronous import encryption
 from pymongo.synchronous.encryption import Algorithm, ClientEncryption, QueryType
-from pymongo.synchronous.encryption_options import _HAVE_PYMONGOCRYPT, AutoEncryptionOpts, RangeOpts
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.operations import InsertOne, ReplaceOne, UpdateOne
 from pymongo.write_concern import WriteConcern

 KMS_PROVIDERS = {"local": {"key": b"\x00" * 96}}
diff --git a/test/test_examples.py b/test/test_examples.py
index f0d8bd5543..e003d8459a 100644
--- a/test/test_examples.py
+++ b/test/test_examples.py
@@ -27,8 +27,8 @@
 import pymongo
 from pymongo.errors import ConnectionFailure, OperationFailure
 from pymongo.read_concern import ReadConcern
+from pymongo.read_preferences import ReadPreference
 from pymongo.server_api import ServerApi
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern
diff --git a/test/test_gridfs.py b/test/test_gridfs.py
index 1ef17afc2b..d08c10224d 100644
--- a/test/test_gridfs.py
+++ b/test/test_gridfs.py
@@ -37,9 +37,9 @@
     NotPrimaryError,
     ServerSelectionTimeoutError,
 )
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.database import Database
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.read_preferences import ReadPreference


 class JustWrite(threading.Thread):
diff --git a/test/test_gridfs_bucket.py b/test/test_gridfs_bucket.py
index 6ce7b79228..c3945d1053 100644
--- a/test/test_gridfs_bucket.py
+++ b/test/test_gridfs_bucket.py
@@ -41,8 +41,8 @@
     ServerSelectionTimeoutError,
     WriteConcernError,
 )
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.read_preferences import ReadPreference


 class JustWrite(threading.Thread):
diff --git a/test/test_heartbeat_monitoring.py b/test/test_heartbeat_monitoring.py
index 0566fffe5b..1302df8fde 100644
--- a/test/test_heartbeat_monitoring.py
+++ b/test/test_heartbeat_monitoring.py
@@ -23,7 +23,7 @@
 from test.utils import HeartbeatEventListener, MockPool, single_client, wait_until

 from pymongo.errors import ConnectionFailure
-from pymongo.synchronous.hello import Hello, HelloCompat
+from pymongo.hello import Hello, HelloCompat
 from pymongo.synchronous.monitor import Monitor
diff --git a/test/test_index_management.py b/test/test_index_management.py
index b8409178d1..5b6653dcba 100644
--- a/test/test_index_management.py
+++ b/test/test_index_management.py
@@ -29,8 +29,8 @@
 from pymongo import MongoClient
 from pymongo.errors import OperationFailure
+from pymongo.operations import SearchIndexModel
 from pymongo.read_concern import ReadConcern
-from pymongo.synchronous.operations import SearchIndexModel
 from pymongo.write_concern import WriteConcern

 _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "index_management")
diff --git a/test/test_logger.py b/test/test_logger.py
index d1f84a8441..e8d1929b8b 100644
--- a/test/test_logger.py
+++ b/test/test_logger.py
@@ -20,7 +20,7 @@
 from bson import json_util
 from pymongo.errors import OperationFailure
-from pymongo.synchronous.logger import _DEFAULT_DOCUMENT_LENGTH
+from pymongo.logger import _DEFAULT_DOCUMENT_LENGTH

 # https://github.com/mongodb/specifications/tree/master/source/command-logging-and-monitoring/tests#prose-tests
diff --git a/test/test_max_staleness.py b/test/test_max_staleness.py
index d41f216eb8..1b0130f7d8 100644
--- a/test/test_max_staleness.py
+++ b/test/test_max_staleness.py
@@ -20,7 +20,7 @@
 import time
 import warnings

-from pymongo.synchronous.operations import _Op
+from pymongo.operations import _Op

 sys.path[0:0] = [""]

@@ -30,7 +30,7 @@
 from pymongo import MongoClient
 from pymongo.errors import ConfigurationError
-from pymongo.synchronous.server_selectors import writable_server_selector
+from pymongo.server_selectors import writable_server_selector

 # Location of JSON test specifications.
 _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "max_staleness")
diff --git a/test/test_mongos_load_balancing.py b/test/test_mongos_load_balancing.py
index 4ab4d30657..b59d6c3e19 100644
--- a/test/test_mongos_load_balancing.py
+++ b/test/test_mongos_load_balancing.py
@@ -18,7 +18,7 @@
 import sys
 import threading

-from pymongo.synchronous.operations import _Op
+from pymongo.operations import _Op

 sys.path[0:0] = [""]

@@ -27,8 +27,8 @@
 from test.utils import connected, wait_until

 from pymongo.errors import AutoReconnect, InvalidOperation
-from pymongo.synchronous.server_selectors import writable_server_selector
-from pymongo.synchronous.topology_description import TOPOLOGY_TYPE
+from pymongo.server_selectors import writable_server_selector
+from pymongo.topology_description import TOPOLOGY_TYPE


 @client_context.require_connection
diff --git a/test/test_monitoring.py b/test/test_monitoring.py
index 7f88888157..ed6a3d0bc2 100644
--- a/test/test_monitoring.py
+++ b/test/test_monitoring.py
@@ -27,11 +27,10 @@
 from bson.int64 import Int64
 from bson.objectid import ObjectId
 from bson.son import SON
-from pymongo import CursorType, DeleteOne, InsertOne, UpdateOne
+from pymongo import CursorType, DeleteOne, InsertOne, UpdateOne, monitoring
 from pymongo.errors import AutoReconnect, NotPrimaryError, OperationFailure
-from pymongo.synchronous import monitoring
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.command_cursor import CommandCursor
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern
diff --git a/test/test_pooling.py b/test/test_pooling.py
index 5ed701517a..da68f04e78 100644
--- a/test/test_pooling.py
+++ b/test/test_pooling.py
@@ -26,8 +26,8 @@
 from bson.son import SON
 from pymongo import MongoClient, timeout
 from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError
+from pymongo.hello_compat import HelloCompat
 from pymongo.synchronous import message
-from pymongo.synchronous.hello_compat import HelloCompat

 sys.path[0:0] = [""]
diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py
index 4f774aa87d..76ad14dcce 100644
--- a/test/test_read_preferences.py
+++ b/test/test_read_preferences.py
@@ -22,7 +22,7 @@
 import sys
 from typing import Any

-from pymongo.synchronous.operations import _Op
+from pymongo.operations import _Op

 sys.path[0:0] = [""]

@@ -39,10 +39,7 @@
 from bson.son import SON
 from pymongo.errors import ConfigurationError, OperationFailure
-from pymongo.server_type import SERVER_TYPE
-from pymongo.synchronous.message import _maybe_add_read_preference
-from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.read_preferences import (
+from pymongo.read_preferences import (
     MovingAverage,
     Nearest,
     Primary,
@@ -51,8 +48,11 @@
     Secondary,
     SecondaryPreferred,
 )
-from pymongo.synchronous.server_description import ServerDescription
-from pymongo.synchronous.server_selectors import Selection, readable_server_selector
+from pymongo.server_description import ServerDescription
+from pymongo.server_selectors import Selection, readable_server_selector
+from pymongo.server_type import SERVER_TYPE
+from pymongo.synchronous.message import _maybe_add_read_preference
+from pymongo.synchronous.mongo_client import MongoClient
 from pymongo.write_concern import WriteConcern
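Read preference types and helpers (_MONGOS_MODES, make_read_preference,
MovingAverage, and the mode classes) are likewise shared now. A small sketch under
the new import paths (the tag set is illustrative; the commented values are what
these calls produce):

    from pymongo.read_preferences import SecondaryPreferred, make_read_preference

    pref = SecondaryPreferred(tag_sets=[{"dc": "ny"}])
    print(pref.document)  # {'mode': 'secondaryPreferred', 'tags': [{'dc': 'ny'}]}

    # make_read_preference rebuilds the same object from a mode number and tag sets.
    same = make_read_preference(pref.mode, pref.tag_sets)
    assert same.document == pref.document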
diff --git a/test/test_read_write_concern_spec.py b/test/test_read_write_concern_spec.py
index 93986d824d..34aa1f7546 100644
--- a/test/test_read_write_concern_spec.py
+++ b/test/test_read_write_concern_spec.py
@@ -40,9 +40,9 @@
     WriteError,
     WTimeoutError,
 )
+from pymongo.operations import IndexModel, InsertOne
 from pymongo.read_concern import ReadConcern
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.operations import IndexModel, InsertOne
 from pymongo.write_concern import WriteConcern

 _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "read_write_concern")
diff --git a/test/test_retryable_reads.py b/test/test_retryable_reads.py
index 569f7c2751..9ea546ba9b 100644
--- a/test/test_retryable_reads.py
+++ b/test/test_retryable_reads.py
@@ -43,13 +43,13 @@
 )
 from test.utils_spec_runner import SpecRunner

-from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.monitoring import (
+from pymongo.monitoring import (
     ConnectionCheckedOutEvent,
     ConnectionCheckOutFailedEvent,
     ConnectionCheckOutFailedReason,
     PoolClearedEvent,
 )
+from pymongo.synchronous.mongo_client import MongoClient
 from pymongo.write_concern import WriteConcern

 # Location of JSON test specifications.
diff --git a/test/test_retryable_writes.py b/test/test_retryable_writes.py
index 347e6c1383..45a740e844 100644
--- a/test/test_retryable_writes.py
+++ b/test/test_retryable_writes.py
@@ -47,15 +47,14 @@
     ServerSelectionTimeoutError,
     WriteConcernError,
 )
-from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.monitoring import (
+from pymongo.monitoring import (
     CommandSucceededEvent,
     ConnectionCheckedOutEvent,
     ConnectionCheckOutFailedEvent,
     ConnectionCheckOutFailedReason,
     PoolClearedEvent,
 )
-from pymongo.synchronous.operations import (
+from pymongo.operations import (
     DeleteMany,
     DeleteOne,
     InsertOne,
@@ -63,6 +62,7 @@
     UpdateMany,
     UpdateOne,
 )
+from pymongo.synchronous.mongo_client import MongoClient
 from pymongo.write_concern import WriteConcern

 # Location of JSON test specifications.
diff --git a/test/test_sdam_monitoring_spec.py b/test/test_sdam_monitoring_spec.py
index c955dc4084..5faee9b103 100644
--- a/test/test_sdam_monitoring_spec.py
+++ b/test/test_sdam_monitoring_spec.py
@@ -31,15 +31,14 @@
 )

 from bson.json_util import object_hook
-from pymongo import MongoClient
+from pymongo import MongoClient, monitoring
+from pymongo.common import clean_node
 from pymongo.errors import ConnectionFailure, NotPrimaryError
-from pymongo.synchronous import monitoring
+from pymongo.hello import Hello
+from pymongo.server_description import ServerDescription
 from pymongo.synchronous.collection import Collection
-from pymongo.synchronous.common import clean_node
-from pymongo.synchronous.hello import Hello
 from pymongo.synchronous.monitor import Monitor
-from pymongo.synchronous.server_description import ServerDescription
-from pymongo.synchronous.topology_description import TOPOLOGY_TYPE
+from pymongo.topology_description import TOPOLOGY_TYPE

 # Location of JSON test specifications.
 _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sdam_monitoring")
diff --git a/test/test_server.py b/test/test_server.py
index b5c6c1365f..45d01c10de 100644
--- a/test/test_server.py
+++ b/test/test_server.py
@@ -21,9 +21,9 @@
 from test import unittest

-from pymongo.synchronous.hello import Hello
+from pymongo.hello import Hello
+from pymongo.server_description import ServerDescription
 from pymongo.synchronous.server import Server
-from pymongo.synchronous.server_description import ServerDescription


 class TestServer(unittest.TestCase):
diff --git a/test/test_server_description.py b/test/test_server_description.py
index 273c001c9e..ee05e95cf8 100644
--- a/test/test_server_description.py
+++ b/test/test_server_description.py
@@ -23,9 +23,9 @@
 from bson.int64 import Int64
 from bson.objectid import ObjectId
+from pymongo.hello import Hello, HelloCompat
+from pymongo.server_description import ServerDescription
 from pymongo.server_type import SERVER_TYPE
-from pymongo.synchronous.hello import Hello, HelloCompat
-from pymongo.synchronous.server_description import ServerDescription

 address = ("localhost", 27017)
diff --git a/test/test_server_selection.py b/test/test_server_selection.py
index 94289a00a3..42bd5a095d 100644
--- a/test/test_server_selection.py
+++ b/test/test_server_selection.py
@@ -20,12 +20,12 @@
 from pymongo import MongoClient, ReadPreference
 from pymongo.errors import ServerSelectionTimeoutError
-from pymongo.synchronous.hello_compat import HelloCompat
-from pymongo.synchronous.operations import _Op
-from pymongo.synchronous.server_selectors import writable_server_selector
+from pymongo.hello_compat import HelloCompat
+from pymongo.operations import _Op
+from pymongo.server_selectors import writable_server_selector
 from pymongo.synchronous.settings import TopologySettings
 from pymongo.synchronous.topology import Topology
-from pymongo.synchronous.typings import strip_optional
+from pymongo.typings import strip_optional

 sys.path[0:0] = [""]
diff --git a/test/test_server_selection_in_window.py b/test/test_server_selection_in_window.py
index c7384590d9..9dced595c9 100644
--- a/test/test_server_selection_in_window.py
+++ b/test/test_server_selection_in_window.py
@@ -27,9 +27,9 @@
 )
 from test.utils_selection_tests import create_topology

-from pymongo.synchronous.common import clean_node
-from pymongo.synchronous.operations import _Op
-from pymongo.synchronous.read_preferences import ReadPreference
+from pymongo.common import clean_node
+from pymongo.operations import _Op
+from pymongo.read_preferences import ReadPreference

 # Location of JSON test specifications.
 TEST_PATH = os.path.join(
diff --git a/test/test_server_selection_rtt.py b/test/test_server_selection_rtt.py
index 26e871c400..a129af4585 100644
--- a/test/test_server_selection_rtt.py
+++ b/test/test_server_selection_rtt.py
@@ -23,7 +23,7 @@
 from test import unittest

-from pymongo.synchronous.read_preferences import MovingAverage
+from pymongo.read_preferences import MovingAverage

 # Location of JSON test specifications.
 _TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "server_selection/rtt")
diff --git a/test/test_session.py b/test/test_session.py
index f746c6d7cb..a91a417066 100644
--- a/test/test_session.py
+++ b/test/test_session.py
@@ -35,14 +35,13 @@
 from bson import DBRef
 from gridfs import GridFS, GridFSBucket
-from pymongo import ASCENDING
+from pymongo import ASCENDING, monitoring
+from pymongo.common import _MAX_END_SESSIONS
 from pymongo.errors import ConfigurationError, InvalidOperation, OperationFailure
+from pymongo.operations import IndexModel, InsertOne, UpdateOne
 from pymongo.read_concern import ReadConcern
-from pymongo.synchronous import monitoring
 from pymongo.synchronous.command_cursor import CommandCursor
-from pymongo.synchronous.common import _MAX_END_SESSIONS
 from pymongo.synchronous.cursor import Cursor
-from pymongo.synchronous.operations import IndexModel, InsertOne, UpdateOne

 # Ignore auth commands like saslStart, so we can assert lsid is in all commands.
diff --git a/test/test_srv_polling.py b/test/test_srv_polling.py
index 0c293874b1..405db14ac6 100644
--- a/test/test_srv_polling.py
+++ b/test/test_srv_polling.py
@@ -25,10 +25,10 @@
 from test.utils import FunctionCallRecorder, wait_until

 import pymongo
+from pymongo import common
 from pymongo.errors import ConfigurationError
-from pymongo.synchronous import common
+from pymongo.srv_resolver import _have_dnspython
 from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.srv_resolver import _have_dnspython

 WAIT_TIME = 0.1

@@ -51,9 +51,7 @@ def __init__(
     def enable(self):
         self.old_min_srv_rescan_interval = common.MIN_SRV_RESCAN_INTERVAL
-        self.old_dns_resolver_response = (
-            pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl
-        )
+        self.old_dns_resolver_response = pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl

         if self.min_srv_rescan_interval is not None:
             common.MIN_SRV_RESCAN_INTERVAL = self.min_srv_rescan_interval
@@ -73,14 +71,14 @@ def mock_get_hosts_and_min_ttl(resolver, *args):
         else:
             patch_func = mock_get_hosts_and_min_ttl

-        pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func  # type: ignore
+        pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = patch_func  # type: ignore

     def __enter__(self):
         self.enable()

     def disable(self):
         common.MIN_SRV_RESCAN_INTERVAL = self.old_min_srv_rescan_interval  # type: ignore
-        pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl = (  # type: ignore
+        pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl = (  # type: ignore
             self.old_dns_resolver_response
         )

@@ -133,10 +131,7 @@ def assert_nodelist_nochange(self, expected_nodelist, client, timeout=(100 * WAI
         def predicate():
             if set(expected_nodelist) == set(self.get_nodelist(client)):
-                return (
-                    pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count
-                    >= 1
-                )
+                return pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count >= 1
             return False

         wait_until(predicate, "Node list equals expected nodelist", timeout=timeout)
@@ -146,7 +141,7 @@ def predicate():
             msg = "Client nodelist %s changed unexpectedly (expected %s)"
             raise self.fail(msg % (nodelist, expected_nodelist))
         self.assertGreaterEqual(
-            pymongo.synchronous.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,  # type: ignore
+            pymongo.srv_resolver._SrvResolver.get_hosts_and_min_ttl.call_count,  # type: ignore
             1,
             "resolver was never called",
         )
diff --git a/test/test_ssl.py b/test/test_ssl.py
index 56dd23a8e0..b123accdf6 100644
--- a/test/test_ssl.py
+++ b/test/test_ssl.py
@@ -33,8 +33,8 @@
 from pymongo import MongoClient, ssl_support
 from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure
+from pymongo.hello_compat import HelloCompat
 from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context
-from pymongo.synchronous.hello_compat import HelloCompat
 from pymongo.write_concern import WriteConcern

 _HAVE_PYOPENSSL = False
diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py
index 054910ca1f..97618e105e 100644
--- a/test/test_streaming_protocol.py
+++ b/test/test_streaming_protocol.py
@@ -29,8 +29,8 @@
     wait_until,
 )

-from pymongo.synchronous import monitoring
-from pymongo.synchronous.hello_compat import HelloCompat
+from pymongo import monitoring
+from pymongo.hello_compat import HelloCompat


 class TestStreamingProtocol(IntegrationTest):
diff --git a/test/test_topology.py b/test/test_topology.py
index e6fd5a3c0b..8f7fde9810 100644
--- a/test/test_topology.py
+++ b/test/test_topology.py
@@ -17,7 +17,7 @@

 import sys

-from pymongo.synchronous.operations import _Op
+from pymongo.operations import _Op

 sys.path[0:0] = [""]

@@ -26,19 +26,19 @@
 from test.utils import MockPool, wait_until

 from bson.objectid import ObjectId
+from pymongo import common
 from pymongo.errors import AutoReconnect, ConfigurationError, ConnectionFailure
+from pymongo.hello import Hello, HelloCompat
+from pymongo.read_preferences import ReadPreference, Secondary
+from pymongo.server_description import ServerDescription
+from pymongo.server_selectors import any_server_selector, writable_server_selector
 from pymongo.server_type import SERVER_TYPE
-from pymongo.synchronous import common
-from pymongo.synchronous.hello import Hello, HelloCompat
 from pymongo.synchronous.monitor import Monitor
 from pymongo.synchronous.pool import PoolOptions
-from pymongo.synchronous.read_preferences import ReadPreference, Secondary
 from pymongo.synchronous.server import Server
-from pymongo.synchronous.server_description import ServerDescription
-from pymongo.synchronous.server_selectors import any_server_selector, writable_server_selector
 from pymongo.synchronous.settings import TopologySettings
 from pymongo.synchronous.topology import Topology, _ErrorContext, _filter_servers
-from pymongo.synchronous.topology_description import TOPOLOGY_TYPE
+from pymongo.topology_description import TOPOLOGY_TYPE


 class SetNameDiscoverySettings(TopologySettings):
diff --git a/test/test_transactions.py b/test/test_transactions.py
index 4279c942ec..62525742d3 100644
--- a/test/test_transactions.py
+++ b/test/test_transactions.py
@@ -43,13 +43,13 @@
     InvalidOperation,
     OperationFailure,
 )
+from pymongo.operations import IndexModel, InsertOne
 from pymongo.read_concern import ReadConcern
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous import client_session
 from pymongo.synchronous.client_session import TransactionOptions
 from pymongo.synchronous.command_cursor import CommandCursor
 from pymongo.synchronous.cursor import Cursor
-from pymongo.synchronous.operations import IndexModel, InsertOne
-from pymongo.synchronous.read_preferences import ReadPreference

 _TXN_TESTS_DEBUG = os.environ.get("TRANSACTION_TESTS_DEBUG")
diff --git a/test/test_typing.py b/test/test_typing.py
index 552590c644..f423b70a3e 100644
--- a/test/test_typing.py
+++ b/test/test_typing.py
@@ -75,9 +75,9 @@ class ImplicitMovie(TypedDict):
 from bson.raw_bson import RawBSONDocument
 from bson.son import SON
 from pymongo import ASCENDING, MongoClient
+from pymongo.operations import DeleteOne, InsertOne, ReplaceOne
+from pymongo.read_preferences import ReadPreference
 from pymongo.synchronous.collection import Collection
-from pymongo.synchronous.operations import DeleteOne, InsertOne, ReplaceOne
-from pymongo.synchronous.read_preferences import ReadPreference

 TEST_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mypy_fails")
diff --git a/test/test_uri_parser.py b/test/test_uri_parser.py
index 09178e2802..27f5fd2fbc 100644
--- a/test/test_uri_parser.py
+++ b/test/test_uri_parser.py
@@ -28,7 +28,7 @@
 from bson.binary import JAVA_LEGACY
 from pymongo import ReadPreference
 from pymongo.errors import ConfigurationError, InvalidURI
-from pymongo.synchronous.uri_parser import (
+from pymongo.uri_parser import (
     parse_uri,
     parse_userinfo,
     split_hosts,
diff --git a/test/test_uri_spec.py b/test/test_uri_spec.py
index a5ec436498..3a8bf6275a 100644
--- a/test/test_uri_spec.py
+++ b/test/test_uri_spec.py
@@ -26,9 +26,9 @@
 from test import clear_warning_registry, unittest

-from pymongo.synchronous.common import INTERNAL_URI_OPTION_NAME_MAP, validate
-from pymongo.synchronous.compression_support import _have_snappy
-from pymongo.synchronous.uri_parser import parse_uri
+from pymongo.common import INTERNAL_URI_OPTION_NAME_MAP, validate
+from pymongo.compression_support import _have_snappy
+from pymongo.uri_parser import parse_uri

 CONN_STRING_TEST_PATH = os.path.join(
     os.path.dirname(os.path.realpath(__file__)), os.path.join("connection_string", "test")
diff --git a/test/unified_format.py b/test/unified_format.py
index fe1419c0d0..ecf0133e74 100644
--- a/test/unified_format.py
+++ b/test/unified_format.py
@@ -68,6 +68,7 @@
 from bson.regex import RE_TYPE, Regex
 from gridfs import GridFSBucket, GridOut
 from pymongo import ASCENDING, CursorType, MongoClient, _csot
+from pymongo.encryption_options import _HAVE_PYMONGOCRYPT
 from pymongo.errors import (
     BulkWriteError,
     ConfigurationError,
@@ -78,18 +79,7 @@
     OperationFailure,
     PyMongoError,
 )
-from pymongo.read_concern import ReadConcern
-from pymongo.results import BulkWriteResult
-from pymongo.server_api import ServerApi
-from pymongo.server_type import SERVER_TYPE
-from pymongo.synchronous.change_stream import ChangeStream
-from pymongo.synchronous.client_session import ClientSession, TransactionOptions, _TxnState
-from pymongo.synchronous.collection import Collection
-from pymongo.synchronous.command_cursor import CommandCursor
-from pymongo.synchronous.database import Database
-from pymongo.synchronous.encryption import ClientEncryption
-from pymongo.synchronous.encryption_options import _HAVE_PYMONGOCRYPT
-from pymongo.synchronous.monitoring import (
+from pymongo.monitoring import (
     _SENSITIVE_COMMANDS,
     CommandFailedEvent,
     CommandListener,
@@ -125,12 +115,22 @@
     _ServerEvent,
     _ServerHeartbeatEvent,
 )
-from pymongo.synchronous.operations import SearchIndexModel
-from pymongo.synchronous.read_preferences import ReadPreference
-from pymongo.synchronous.server_description import ServerDescription
-from pymongo.synchronous.server_selectors import Selection, writable_server_selector
-from pymongo.synchronous.topology_description import TopologyDescription
-from pymongo.synchronous.typings import _Address
+from pymongo.operations import SearchIndexModel
+from pymongo.read_concern import ReadConcern
+from pymongo.read_preferences import ReadPreference
+from pymongo.results import BulkWriteResult
+from pymongo.server_api import ServerApi
+from pymongo.server_description import ServerDescription
+from pymongo.server_selectors import Selection, writable_server_selector
+from pymongo.server_type import SERVER_TYPE
+from pymongo.synchronous.change_stream import ChangeStream
+from pymongo.synchronous.client_session import ClientSession, TransactionOptions, _TxnState
+from pymongo.synchronous.collection import Collection
+from pymongo.synchronous.command_cursor import CommandCursor
+from pymongo.synchronous.database import Database
+from pymongo.synchronous.encryption import ClientEncryption
+from pymongo.topology_description import TopologyDescription
+from pymongo.typings import _Address
 from pymongo.write_concern import WriteConcern

 JSON_OPTS = json_util.JSONOptions(tz_aware=False)

@@ -616,7 +616,7 @@ def get_lsid_for_session(self, session_name):
         session = self[session_name]
         if not isinstance(session, ClientSession):
             self.test.fail(
-                f"Expected entity {session_name} to be of type ClientSession, got {type(session)}"
+                f"Expected entity {session_name} to be of type AsyncClientSession, got {type(session)}"
            )

         try:
diff --git a/test/utils.py b/test/utils.py
index bd33270c11..97b39b38e7 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -36,18 +36,13 @@
 from bson import json_util
 from bson.objectid import ObjectId
 from bson.son import SON
-from pymongo import AsyncMongoClient
+from pymongo import AsyncMongoClient, monitoring, operations, read_preferences
 from pymongo.cursor_shared import CursorType
 from pymongo.errors import ConfigurationError, OperationFailure
-from pymongo.helpers_constants import _SENSITIVE_COMMANDS
+from pymongo.hello_compat import HelloCompat
+from pymongo.helpers_shared import _SENSITIVE_COMMANDS
 from pymongo.lock import _create_lock
-from pymongo.read_concern import ReadConcern
-from pymongo.server_type import SERVER_TYPE
-from pymongo.synchronous import monitoring, operations, read_preferences
-from pymongo.synchronous.collection import ReturnDocument
-from pymongo.synchronous.hello_compat import HelloCompat
-from pymongo.synchronous.mongo_client import MongoClient
-from pymongo.synchronous.monitoring import (
+from pymongo.monitoring import (
     ConnectionCheckedInEvent,
     ConnectionCheckedOutEvent,
     ConnectionCheckOutFailedEvent,
@@ -60,11 +55,15 @@
     PoolCreatedEvent,
     PoolReadyEvent,
 )
-from pymongo.synchronous.operations import _Op
+from pymongo.operations import _Op
+from pymongo.read_concern import ReadConcern
+from pymongo.read_preferences import ReadPreference
+from pymongo.server_selectors import any_server_selector, writable_server_selector
+from pymongo.server_type import SERVER_TYPE
+from pymongo.synchronous.collection import ReturnDocument
+from pymongo.synchronous.mongo_client import MongoClient
 from pymongo.synchronous.pool import _CancellationContext, _PoolGeneration
-from pymongo.synchronous.read_preferences import ReadPreference
-from pymongo.synchronous.server_selectors import any_server_selector, writable_server_selector
-from pymongo.synchronous.uri_parser import parse_uri
+from pymongo.uri_parser import parse_uri
 from pymongo.write_concern import WriteConcern

 IMPOSSIBLE_WRITE_CONCERN = WriteConcern(w=50)
diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py
index 7673e9bc27..e6fb829eb3 100644
--- a/test/utils_selection_tests.py
+++ b/test/utils_selection_tests.py
@@ -19,7 +19,7 @@
 import os
 import sys

-from pymongo.synchronous.operations import _Op
+from pymongo.operations import _Op

 sys.path[0:0] = [""]

@@ -28,11 +28,11 @@
 from test.utils import MockPool, parse_read_preference

 from bson import json_util
+from pymongo.common import HEARTBEAT_FREQUENCY, clean_node
 from pymongo.errors import AutoReconnect, ConfigurationError
-from pymongo.synchronous.common import HEARTBEAT_FREQUENCY, clean_node
-from pymongo.synchronous.hello import Hello, HelloCompat
-from pymongo.synchronous.server_description import ServerDescription
-from pymongo.synchronous.server_selectors import writable_server_selector
+from pymongo.hello import Hello, HelloCompat
+from pymongo.server_description import ServerDescription
+from pymongo.server_selectors import writable_server_selector
 from pymongo.synchronous.settings import TopologySettings
 from pymongo.synchronous.topology import Topology
diff --git a/test/utils_spec_runner.py b/test/utils_spec_runner.py
index e38d53b94a..0915335824 100644
--- a/test/utils_spec_runner.py
+++ b/test/utils_spec_runner.py
@@ -40,11 +40,11 @@
 from gridfs import GridFSBucket
 from pymongo.errors import BulkWriteError, OperationFailure, PyMongoError
 from pymongo.read_concern import ReadConcern
+from pymongo.read_preferences import ReadPreference
 from pymongo.results import BulkWriteResult, _WriteResult
 from pymongo.synchronous import client_session
 from pymongo.synchronous.command_cursor import CommandCursor
 from pymongo.synchronous.cursor import Cursor
-from pymongo.synchronous.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern
diff --git a/tools/synchro.py b/tools/synchro.py
index 2a0c4f4318..825cfb08a2 100644
--- a/tools/synchro.py
+++ b/tools/synchro.py
@@ -33,6 +33,9 @@
     "AsyncCommandCursor": "CommandCursor",
     "AsyncRawBatchCursor": "RawBatchCursor",
     "AsyncRawBatchCommandCursor": "RawBatchCommandCursor",
+    "AsyncClientSession": "ClientSession",
+    "_AsyncBulk": "_Bulk",
+    "AsyncConnection": "Connection",
     "async_command": "command",
     "async_receive_message": "receive_message",
     "async_sendall": "sendall",
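The three new entries extend the async-to-sync rename table that tools/synchro.py
applies when it regenerates the synchronous sources from the asynchronous ones. A
simplified sketch of what that table does (the real tool rewrites whole files; this
only shows the substitution step):

    replacements = {
        "AsyncClientSession": "ClientSession",
        "_AsyncBulk": "_Bulk",
        "AsyncConnection": "Connection",
    }

    def synchronize(source: str) -> str:
        # Apply each async->sync rename to produce the synchronous variant.
        for async_name, sync_name in replacements.items():
            source = source.replace(async_name, sync_name)
        return source

    print(synchronize("def run(session: AsyncClientSession) -> None: ..."))
    # def run(session: ClientSession) -> None: ...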
From 72256755a1e2a2d950da11c43245f02f84a37748 Mon Sep 17 00:00:00 2001
From: Noah Stapp
Date: Tue, 11 Jun 2024 12:56:30 -0700
Subject: [PATCH 02/11] WIP message refactor

---
 pymongo/asynchronous/bulk.py           |   10 +-
 pymongo/asynchronous/collection.py     |    5 +-
 pymongo/asynchronous/command_cursor.py |    8 +-
 pymongo/asynchronous/cursor.py         |   20 +-
 pymongo/asynchronous/message.py        | 1101 +-----------------------
 pymongo/asynchronous/mongo_client.py   |    6 +-
 pymongo/asynchronous/network.py        |    5 +-
 pymongo/asynchronous/pool.py           |    2 +-
 pymongo/asynchronous/server.py         |   39 +-
 pymongo/message.py                     | 1085 +++++++++++++++++++++++
 pymongo/response.py                    |    2 +-
 pymongo/synchronous/bulk.py            |   10 +-
 pymongo/synchronous/collection.py      |    5 +-
 pymongo/synchronous/command_cursor.py  |    8 +-
 pymongo/synchronous/cursor.py          |    8 +-
 pymongo/synchronous/message.py         | 1101 +-----------------------
 pymongo/synchronous/mongo_client.py    |    6 +-
 pymongo/synchronous/network.py         |    5 +-
 pymongo/synchronous/pool.py            |    2 +-
 pymongo/synchronous/response.py        |  133 ---
 pymongo/synchronous/server.py          |   35 +-
 test/test_change_stream.py             |    2 +-
 test/test_client.py                    |    6 +-
 test/test_collection.py                |    2 +-
 test/test_custom_types.py              |    2 +-
 test/test_grid_file.py                 |    2 +-
 test/test_pooling.py                   |    3 +-
 test/test_read_preferences.py          |    2 +-
 28 files changed, 1262 insertions(+), 2353 deletions(-)
 create mode 100644 pymongo/message.py
 delete mode 100644 pymongo/synchronous/response.py

diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py
index 66ed994142..8593123d0c 100644
--- a/pymongo/asynchronous/bulk.py
+++ b/pymongo/asynchronous/bulk.py
@@ -37,12 +37,8 @@
 from pymongo import _csot, common
 from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern
 from pymongo.asynchronous.message import (
-    _DELETE,
-    _INSERT,
-    _UPDATE,
     _BulkWriteContext,
     _EncryptedBulkWriteContext,
-    _randint,
 )
 from pymongo.common import (
     validate_is_document_type,
@@ -56,6 +52,12 @@
     OperationFailure,
 )
 from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES, _get_wce_doc
+from pymongo.message import (
+    _DELETE,
+    _INSERT,
+    _UPDATE,
+    _randint,
+)
 from pymongo.read_preferences import ReadPreference
 from pymongo.write_concern import WriteConcern
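With this patch the wire-protocol constants and small helpers shared by both bulk
implementations start moving into the new IO-free pymongo.message module. A sketch
of the split (these are private internals, shown only to illustrate where they now
live; the values come from the definitions removed below):

    from pymongo.message import _DELETE, _INSERT, _UPDATE, _randint

    print(_INSERT, _UPDATE, _DELETE)  # 0 1 2, the op codes keyed by _OP_MAP
    request_id = _randint()           # pseudo-random 32-bit request id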
diff --git a/pymongo/asynchronous/collection.py b/pymongo/asynchronous/collection.py
index 836d4c61e3..873a8be02e 100644
--- a/pymongo/asynchronous/collection.py
+++ b/pymongo/asynchronous/collection.py
@@ -41,8 +41,7 @@
 from bson.raw_bson import RawBSONDocument
 from bson.son import SON
 from bson.timestamp import Timestamp
-from pymongo import ASCENDING, _csot, common, helpers_shared
-from pymongo.asynchronous import message
+from pymongo import ASCENDING, _csot, common, helpers_shared, message
 from pymongo.asynchronous.aggregation import (
     _CollectionAggregationCommand,
     _CollectionRawAggregationCommand,
@@ -57,7 +56,6 @@
     AsyncCursor,
     AsyncRawBatchCursor,
 )
-from pymongo.asynchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS
 from pymongo.collation import validate_collation_or_none
 from pymongo.common import _ecoc_coll_name, _esc_coll_name
 from pymongo.errors import (
@@ -67,6 +65,7 @@
     OperationFailure,
 )
 from pymongo.helpers_shared import _check_write_command_response
+from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
 from pymongo.operations import (
     DeleteMany,
     DeleteOne,
diff --git a/pymongo/asynchronous/command_cursor.py b/pymongo/asynchronous/command_cursor.py
index 4dbc52802a..695933ef84 100644
--- a/pymongo/asynchronous/command_cursor.py
+++ b/pymongo/asynchronous/command_cursor.py
@@ -30,15 +30,15 @@
 from bson import CodecOptions, _convert_raw_document_lists_to_streams
 from pymongo.asynchronous.cursor import _ConnectionManager
-from pymongo.asynchronous.message import (
+from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
+from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
+from pymongo.message import (
     _CursorAddress,
     _GetMore,
     _OpMsg,
     _OpReply,
     _RawBatchGetMore,
 )
-from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
-from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
 from pymongo.response import PinnedResponse
 from pymongo.typings import _Address, _DocumentOut, _DocumentType

@@ -260,7 +260,7 @@ async def _send_message(self, operation: _GetMore) -> None:

         if isinstance(response, PinnedResponse):
             if not self._sock_mgr:
-                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
+                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)  # type: ignore[arg-type]
         if response.from_command:
             cursor = response.docs[0]["cursor"]
             documents = cursor["nextBatch"]
diff --git a/pymongo/asynchronous/cursor.py b/pymongo/asynchronous/cursor.py
index 8213e9e64e..5aa4ff39d2 100644
--- a/pymongo/asynchronous/cursor.py
+++ b/pymongo/asynchronous/cursor.py
@@ -38,15 +38,6 @@
 from bson.son import SON
 from pymongo import helpers_shared
 from pymongo.asynchronous.helpers import anext
-from pymongo.asynchronous.message import (
-    _CursorAddress,
-    _GetMore,
-    _OpMsg,
-    _OpReply,
-    _Query,
-    _RawBatchGetMore,
-    _RawBatchQuery,
-)
 from pymongo.collation import validate_collation_or_none
 from pymongo.common import (
     validate_is_document_type,
@@ -55,6 +46,15 @@
 from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
 from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
 from pymongo.lock import _ALock, _create_lock
+from pymongo.message import (
+    _CursorAddress,
+    _GetMore,
+    _OpMsg,
+    _OpReply,
+    _Query,
+    _RawBatchGetMore,
+    _RawBatchQuery,
+)
 from pymongo.response import PinnedResponse
 from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
 from pymongo.write_concern import validate_boolean
@@ -1104,7 +1104,7 @@ async def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
         self._address = response.address
         if isinstance(response, PinnedResponse):
             if not self._sock_mgr:
-                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
+                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)  # type: ignore[arg-type]

         cmd_name = operation.name
         docs = response.docs
diff --git a/pymongo/asynchronous/message.py b/pymongo/asynchronous/message.py
index 0973677d3a..9a41b3475e 100644
--- a/pymongo/asynchronous/message.py
+++ b/pymongo/asynchronous/message.py
@@ -23,889 +23,70 @@

 import datetime
 import logging
-import random
-import struct
 from io import BytesIO as _BytesIO
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
-    Iterable,
     Mapping,
     MutableMapping,
-    NoReturn,
     Optional,
-    Union,
 )

-import bson
-from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode
-from bson.int64 import Int64
+from bson import CodecOptions, _dict_to_bson, encode
 from bson.raw_bson import (
-    _RAW_ARRAY_BSON_OPTIONS,
     DEFAULT_RAW_BSON_OPTIONS,
-    RawBSONDocument,
     _inflate_bson,
 )
-
-try:
-    from pymongo import _cmessage  # type: ignore[attr-defined]
-
-    _use_c = True
-except ImportError:
-    _use_c = False
 from pymongo.asynchronous.helpers import _handle_reauth
 from pymongo.errors import (
-    ConfigurationError,
-    CursorNotFound,
-    DocumentTooLarge,
-    ExecutionTimeout,
     InvalidOperation,
     NotPrimaryError,
     OperationFailure,
-    ProtocolError,
 )
-from pymongo.hello_compat import HelloCompat
 from pymongo.logger import (
     _COMMAND_LOGGER,
     _CommandStatusMessage,
     _debug_log,
 )
-from pymongo.read_preferences import ReadPreference
+from pymongo.message import (
+    _BSONOBJ,
+    _COMMAND_OVERHEAD,
+    _FIELD_MAP,
+    _OP_MAP,
+    _OP_MSG_MAP,
+    _SKIPLIM,
+    _ZERO_8,
+    _ZERO_16,
+    _ZERO_32,
+    _ZERO_64,
+    _compress,
+    _convert_exception,
+    _convert_write_result,
+    _pack_int,
+    _raise_document_too_large,
+    _randint,
+)
 from pymongo.write_concern import WriteConcern

+try:
+    from pymongo import _cmessage  # type: ignore[attr-defined]
+
+    _use_c = True
+except ImportError:
+    _use_c = False
+
 if TYPE_CHECKING:
     from datetime import timedelta

     from pymongo.asynchronous.client_session import AsyncClientSession
     from pymongo.asynchronous.mongo_client import AsyncMongoClient
     from pymongo.asynchronous.pool import AsyncConnection
-    from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext
     from pymongo.monitoring import _EventListeners
-    from pymongo.read_concern import ReadConcern
-    from pymongo.read_preferences import _ServerMode
-    from pymongo.typings import _Address, _DocumentOut
+    from pymongo.typings import _DocumentOut

 _IS_SYNC = False

-MAX_INT32 = 2147483647
-MIN_INT32 = -2147483648
-
-# Overhead allowed for encoded command documents.
-_COMMAND_OVERHEAD = 16382
-
-_INSERT = 0
-_UPDATE = 1
-_DELETE = 2
-
-_EMPTY = b""
-_BSONOBJ = b"\x03"
-_ZERO_8 = b"\x00"
-_ZERO_16 = b"\x00\x00"
-_ZERO_32 = b"\x00\x00\x00\x00"
-_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00"
-_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff"
-_OP_MAP = {
-    _INSERT: b"\x04documents\x00\x00\x00\x00\x00",
-    _UPDATE: b"\x04updates\x00\x00\x00\x00\x00",
-    _DELETE: b"\x04deletes\x00\x00\x00\x00\x00",
-}
-_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"}
-
-_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions(
-    unicode_decode_error_handler="replace"
-)
-
-
-def _randint() -> int:
-    """Generate a pseudo random 32 bit integer."""
-    return random.randint(MIN_INT32, MAX_INT32)  # noqa: S311
-
-
-def _maybe_add_read_preference(
-    spec: MutableMapping[str, Any], read_preference: _ServerMode
-) -> MutableMapping[str, Any]:
-    """Add $readPreference to spec when appropriate."""
-    mode = read_preference.mode
-    document = read_preference.document
-    # Only add $readPreference if it's something other than primary to avoid
-    # problems with mongos versions that don't support read preferences. Also,
-    # for maximum backwards compatibility, don't add $readPreference for
-    # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting
-    # the secondaryOkay bit has the same effect).
-    if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1):
-        if "$query" not in spec:
-            spec = {"$query": spec}
-        spec["$readPreference"] = document
-    return spec
-
-
-def _convert_exception(exception: Exception) -> dict[str, Any]:
-    """Convert an Exception into a failure document for publishing."""
-    return {"errmsg": str(exception), "errtype": exception.__class__.__name__}
-
-
-def _convert_write_result(
-    operation: str, command: Mapping[str, Any], result: Mapping[str, Any]
-) -> dict[str, Any]:
-    """Convert a legacy write result to write command format."""
-    # Based on _merge_legacy from bulk.py
-    affected = result.get("n", 0)
-    res = {"ok": 1, "n": affected}
-    errmsg = result.get("errmsg", result.get("err", ""))
-    if errmsg:
-        # The write was successful on at least the primary so don't return.
-        if result.get("wtimeout"):
-            res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}}
-        else:
-            # The write failed.
-            error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg}
-            if "errInfo" in result:
-                error["errInfo"] = result["errInfo"]
-            res["writeErrors"] = [error]
-        return res
-    if operation == "insert":
-        # GLE result for insert is always 0 in most MongoDB versions.
-        res["n"] = len(command["documents"])
-    elif operation == "update":
-        if "upserted" in result:
-            res["upserted"] = [{"index": 0, "_id": result["upserted"]}]
-        # Versions of MongoDB before 2.6 don't return the _id for an
-        # upsert if _id is not an ObjectId.
-        elif result.get("updatedExisting") is False and affected == 1:
-            # If _id is in both the update document *and* the query spec
-            # the update document _id takes precedence.
-            update = command["updates"][0]
-            _id = update["u"].get("_id", update["q"].get("_id"))
-            res["upserted"] = [{"index": 0, "_id": _id}]
-    return res
-
-
-_OPTIONS = {
-    "tailable": 2,
-    "oplogReplay": 8,
-    "noCursorTimeout": 16,
-    "awaitData": 32,
-    "allowPartialResults": 128,
-}
-
-
-_MODIFIERS = {
-    "$query": "filter",
-    "$orderby": "sort",
-    "$hint": "hint",
-    "$comment": "comment",
-    "$maxScan": "maxScan",
-    "$maxTimeMS": "maxTimeMS",
-    "$max": "max",
-    "$min": "min",
-    "$returnKey": "returnKey",
-    "$showRecordId": "showRecordId",
-    "$showDiskLoc": "showRecordId",  # <= MongoDb 3.0
-    "$snapshot": "snapshot",
-}
-
-
-def _gen_find_command(
-    coll: str,
-    spec: Mapping[str, Any],
-    projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
-    skip: int,
-    limit: int,
-    batch_size: Optional[int],
-    options: Optional[int],
-    read_concern: ReadConcern,
-    collation: Optional[Mapping[str, Any]] = None,
-    session: Optional[AsyncClientSession] = None,
-    allow_disk_use: Optional[bool] = None,
-) -> dict[str, Any]:
-    """Generate a find command document."""
-    cmd: dict[str, Any] = {"find": coll}
-    if "$query" in spec:
-        cmd.update(
-            [
-                (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val)
-                for key, val in spec.items()
-            ]
-        )
-        if "$explain" in cmd:
-            cmd.pop("$explain")
-        if "$readPreference" in cmd:
-            cmd.pop("$readPreference")
-    else:
-        cmd["filter"] = spec
-
-    if projection:
-        cmd["projection"] = projection
-    if skip:
-        cmd["skip"] = skip
-    if limit:
-        cmd["limit"] = abs(limit)
-        if limit < 0:
-            cmd["singleBatch"] = True
-    if batch_size:
-        cmd["batchSize"] = batch_size
-    if read_concern.level and not (session and session.in_transaction):
-        cmd["readConcern"] = read_concern.document
-    if collation:
-        cmd["collation"] = collation
-    if allow_disk_use is not None:
-        cmd["allowDiskUse"] = allow_disk_use
-    if options:
-        cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val])
-
-    return cmd
-
-
-def _gen_get_more_command(
-    cursor_id: Optional[int],
-    coll: str,
-    batch_size: Optional[int],
-    max_await_time_ms: Optional[int],
-    comment: Optional[Any],
-    conn: AsyncConnection,
-) -> dict[str, Any]:
-    """Generate a getMore command document."""
-    cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll}
-    if batch_size:
-        cmd["batchSize"] = batch_size
-    if max_await_time_ms is not None:
-        cmd["maxTimeMS"] = max_await_time_ms
-    if comment is not None and conn.max_wire_version >= 9:
-        cmd["comment"] = comment
-    return cmd
-
-
-class _Query:
-    """A query operation."""
-
-    __slots__ = (
-        "flags",
-        "db",
-        "coll",
-        "ntoskip",
-        "spec",
-        "fields",
-        "codec_options",
-        "read_preference",
-        "limit",
-        "batch_size",
-        "name",
-        "read_concern",
-        "collation",
-        "session",
-        "client",
-        "allow_disk_use",
-        "_as_command",
-        "exhaust",
-    )
-
-    # For compatibility with the _GetMore class.
-    conn_mgr = None
-    cursor_id = None
-
-    def __init__(
-        self,
-        flags: int,
-        db: str,
-        coll: str,
-        ntoskip: int,
-        spec: Mapping[str, Any],
-        fields: Optional[Mapping[str, Any]],
-        codec_options: CodecOptions,
-        read_preference: _ServerMode,
-        limit: int,
-        batch_size: int,
-        read_concern: ReadConcern,
-        collation: Optional[Mapping[str, Any]],
-        session: Optional[AsyncClientSession],
-        client: AsyncMongoClient,
-        allow_disk_use: Optional[bool],
-        exhaust: bool,
-    ):
-        self.flags = flags
-        self.db = db
-        self.coll = coll
-        self.ntoskip = ntoskip
-        self.spec = spec
-        self.fields = fields
-        self.codec_options = codec_options
-        self.read_preference = read_preference
-        self.read_concern = read_concern
-        self.limit = limit
-        self.batch_size = batch_size
-        self.collation = collation
-        self.session = session
-        self.client = client
-        self.allow_disk_use = allow_disk_use
-        self.name = "find"
-        self._as_command: Optional[tuple[dict[str, Any], str]] = None
-        self.exhaust = exhaust
-
-    def reset(self) -> None:
-        self._as_command = None
-
-    def namespace(self) -> str:
-        return f"{self.db}.{self.coll}"
-
-    def use_command(self, conn: AsyncConnection) -> bool:
-        use_find_cmd = False
-        if not self.exhaust:
-            use_find_cmd = True
-        elif conn.max_wire_version >= 8:
-            # OP_MSG supports exhaust on MongoDB 4.2+
-            use_find_cmd = True
-        elif not self.read_concern.ok_for_legacy:
-            raise ConfigurationError(
-                "read concern level of %s is not valid "
-                "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version)
-            )
-
-        conn.validate_session(self.client, self.session)
-        return use_find_cmd
-
-    async def as_command(
-        self, conn: AsyncConnection, apply_timeout: bool = False
-    ) -> tuple[dict[str, Any], str]:
-        """Return a find command document for this query."""
-        # We use the command twice: on the wire and for command monitoring.
-        # Generate it once, for speed and to avoid repeating side-effects.
-        if self._as_command is not None:
-            return self._as_command
-
-        explain = "$explain" in self.spec
-        cmd: dict[str, Any] = _gen_find_command(
-            self.coll,
-            self.spec,
-            self.fields,
-            self.ntoskip,
-            self.limit,
-            self.batch_size,
-            self.flags,
-            self.read_concern,
-            self.collation,
-            self.session,
-            self.allow_disk_use,
-        )
-        if explain:
-            self.name = "explain"
-            cmd = {"explain": cmd}
-        session = self.session
-        conn.add_server_api(cmd)
-        if session:
-            await session._apply_to(cmd, False, self.read_preference, conn)
-            # Explain does not support readConcern.
-            if not explain and not session.in_transaction:
-                session._update_read_concern(cmd, conn)
-        conn.send_cluster_time(cmd, session, self.client)
-        # Support auto encryption
-        client = self.client
-        if client._encrypter and not client._encrypter._bypass_auto_encryption:
-            cmd = await client._encrypter.encrypt(self.db, cmd, self.codec_options)
-        # Support CSOT
-        if apply_timeout:
-            conn.apply_timeout(client, cmd)
-        self._as_command = cmd, self.db
-        return self._as_command
-
-    async def get_message(
-        self, read_preference: _ServerMode, conn: AsyncConnection, use_cmd: bool = False
-    ) -> tuple[int, bytes, int]:
-        """Get a query message, possibly setting the secondaryOk bit."""
-        # Use the read_preference decided by _socket_from_server.
-        self.read_preference = read_preference
-        if read_preference.mode:
-            # Set the secondaryOk bit.
- flags = self.flags | 4 - else: - flags = self.flags - - ns = self.namespace() - spec = self.spec - - if use_cmd: - spec = (await self.as_command(conn, apply_timeout=True))[0] - request_id, msg, size, _ = _op_msg( - 0, - spec, - self.db, - read_preference, - self.codec_options, - ctx=conn.compression_context, - ) - return request_id, msg, size - - # OP_QUERY treats ntoreturn of -1 and 1 the same, return - # one document and close the cursor. We have to use 2 for - # batch size if 1 is specified. - ntoreturn = self.batch_size == 1 and 2 or self.batch_size - if self.limit: - if ntoreturn: - ntoreturn = min(self.limit, ntoreturn) - else: - ntoreturn = self.limit - - if conn.is_mongos: - assert isinstance(spec, MutableMapping) - spec = _maybe_add_read_preference(spec, read_preference) - - return _query( - flags, - ns, - self.ntoskip, - ntoreturn, - spec, - None if use_cmd else self.fields, - self.codec_options, - ctx=conn.compression_context, - ) - - -class _GetMore: - """A getmore operation.""" - - __slots__ = ( - "db", - "coll", - "ntoreturn", - "cursor_id", - "max_await_time_ms", - "codec_options", - "read_preference", - "session", - "client", - "conn_mgr", - "_as_command", - "exhaust", - "comment", - ) - - name = "getMore" - - def __init__( - self, - db: str, - coll: str, - ntoreturn: int, - cursor_id: int, - codec_options: CodecOptions, - read_preference: _ServerMode, - session: Optional[AsyncClientSession], - client: AsyncMongoClient, - max_await_time_ms: Optional[int], - conn_mgr: Any, - exhaust: bool, - comment: Any, - ): - self.db = db - self.coll = coll - self.ntoreturn = ntoreturn - self.cursor_id = cursor_id - self.codec_options = codec_options - self.read_preference = read_preference - self.session = session - self.client = client - self.max_await_time_ms = max_await_time_ms - self.conn_mgr = conn_mgr - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - self.comment = comment - - def reset(self) -> None: - self._as_command = None - - def namespace(self) -> str: - return f"{self.db}.{self.coll}" - - def use_command(self, conn: AsyncConnection) -> bool: - use_cmd = False - if not self.exhaust: - use_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_cmd = True - - conn.validate_session(self.client, self.session) - return use_cmd - - async def as_command( - self, conn: AsyncConnection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a getMore command document for this query.""" - # See _Query.as_command for an explanation of this caching. 
- if self._as_command is not None:
- return self._as_command
-
- cmd: dict[str, Any] = _gen_get_more_command(
- self.cursor_id,
- self.coll,
- self.ntoreturn,
- self.max_await_time_ms,
- self.comment,
- conn,
- )
- if self.session:
- await self.session._apply_to(cmd, False, self.read_preference, conn)
- conn.add_server_api(cmd)
- conn.send_cluster_time(cmd, self.session, self.client)
- # Support auto encryption
- client = self.client
- if client._encrypter and not client._encrypter._bypass_auto_encryption:
- cmd = await client._encrypter.encrypt(self.db, cmd, self.codec_options)
- # Support CSOT
- if apply_timeout:
- conn.apply_timeout(client, cmd=None)
- self._as_command = cmd, self.db
- return self._as_command
-
- async def get_message(
- self, dummy0: Any, conn: AsyncConnection, use_cmd: bool = False
- ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]:
- """Get a getmore message."""
- ns = self.namespace()
- ctx = conn.compression_context
-
- if use_cmd:
- spec = (await self.as_command(conn, apply_timeout=True))[0]
- if self.conn_mgr and self.exhaust:
- flags = _OpMsg.EXHAUST_ALLOWED
- else:
- flags = 0
- request_id, msg, size, _ = _op_msg(
- flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context
- )
- return request_id, msg, size
-
- return _get_more(ns, self.ntoreturn, self.cursor_id, ctx)
-
-
-class _RawBatchQuery(_Query):
- def use_command(self, conn: AsyncConnection) -> bool:
- # Compatibility checks.
- super().use_command(conn)
- if conn.max_wire_version >= 8:
- # MongoDB 4.2+ supports exhaust over OP_MSG
- return True
- elif not self.exhaust:
- return True
- return False
-
-
-class _RawBatchGetMore(_GetMore):
- def use_command(self, conn: AsyncConnection) -> bool:
- # Compatibility checks.
- super().use_command(conn)
- if conn.max_wire_version >= 8:
- # MongoDB 4.2+ supports exhaust over OP_MSG
- return True
- elif not self.exhaust:
- return True
- return False
-
-
-class _CursorAddress(tuple):
- """The server address (host, port) of a cursor, with namespace property."""
-
- __namespace: Any
-
- def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
- self = tuple.__new__(cls, address)
- self.__namespace = namespace
- return self
-
- @property
- def namespace(self) -> str:
- """The namespace this cursor."""
- return self.__namespace
-
- def __hash__(self) -> int:
- # Two _CursorAddress instances with different namespaces
- # must not hash the same.
- return ((*self, self.__namespace)).__hash__()
-
- def __eq__(self, other: object) -> bool:
- if isinstance(other, _CursorAddress):
- return tuple(self) == tuple(other) and self.namespace == other.namespace
- return NotImplemented
-
- def __ne__(self, other: object) -> bool:
- return not self == other
-
-
-_pack_compression_header = struct.Struct("<iiiiiiB").pack
-_COMPRESSION_HEADER_SIZE = 25
-
-
-def _compress(
- operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
-) -> tuple[int, bytes]:
- """Takes message data, compresses it, and adds an OP_COMPRESSED header."""
- compressed = ctx.compress(data)
- request_id = _randint()
-
- header = _pack_compression_header(
- _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length
- request_id, # Request id
- 0, # responseTo
- 2012, # operation id
- operation, # original operation id
- len(data), # uncompressed message length
- ctx.compressor_id,
- ) # compressor id
- return request_id, header + compressed
-
-
-_pack_header = struct.Struct("<iiii").pack
-
-
-def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]:
- """Takes message data and adds a message header based on the operation.
-
- Returns the resultant message string.
- """ - rid = _randint() - message = _pack_header(16 + len(data), rid, 0, operation) - return rid, message + data - - -_pack_int = struct.Struct(" tuple[bytes, int, int]: - """Get a OP_MSG message. - - Note: this method handles multiple documents in a type one payload but - it does not perform batch splitting and the total message size is - only checked *after* generating the entire message. - """ - # Encode the command document in payload 0 without checking keys. - encoded = _dict_to_bson(command, False, opts) - flags_type = _pack_op_msg_flags_type(flags, 0) - total_size = len(encoded) - max_doc_size = 0 - if identifier and docs is not None: - type_one = _pack_byte(1) - cstring = _make_c_string(identifier) - encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] - size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 - encoded_size = _pack_int(size) - total_size += size - max_doc_size = max(len(doc) for doc in encoded_docs) - data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] - else: - data = [flags_type, encoded] - return b"".join(data), total_size, max_doc_size - - -def _op_msg_compressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes, int, int]: - """Internal OP_MSG message helper.""" - msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - rid, msg = _compress(2013, msg, ctx) - return rid, msg, total_size, max_bson_size - - -def _op_msg_uncompressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, -) -> tuple[int, bytes, int, int]: - """Internal compressed OP_MSG message helper.""" - data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - request_id, op_message = __pack_message(2013, data) - return request_id, op_message, total_size, max_bson_size - - -if _use_c: - _op_msg_uncompressed = _cmessage._op_msg - - -def _op_msg( - flags: int, - command: MutableMapping[str, Any], - dbname: str, - read_preference: Optional[_ServerMode], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes, int, int]: - """Get a OP_MSG message.""" - command["$db"] = dbname - # getMore commands do not send $readPreference. - if read_preference is not None and "$readPreference" not in command: - # Only send $readPreference if it's not primary (the default). - if read_preference.mode: - command["$readPreference"] = read_preference.document - name = next(iter(command)) - try: - identifier = _FIELD_MAP[name] - docs = command.pop(identifier) - except KeyError: - identifier = "" - docs = None - try: - if ctx: - return _op_msg_compressed(flags, command, identifier, docs, opts, ctx) - return _op_msg_uncompressed(flags, command, identifier, docs, opts) - finally: - # Add the field back to the command. 
- if identifier:
- command[identifier] = docs
-
-
-def _query_impl(
- options: int,
- collection_name: str,
- num_to_skip: int,
- num_to_return: int,
- query: Mapping[str, Any],
- field_selector: Optional[Mapping[str, Any]],
- opts: CodecOptions,
-) -> tuple[bytes, int]:
- """Get an OP_QUERY message."""
- encoded = _dict_to_bson(query, False, opts)
- if field_selector:
- efs = _dict_to_bson(field_selector, False, opts)
- else:
- efs = b""
- max_bson_size = max(len(encoded), len(efs))
- return (
- b"".join(
- [
- _pack_int(options),
- _make_c_string(collection_name),
- _pack_int(num_to_skip),
- _pack_int(num_to_return),
- encoded,
- efs,
- ]
- ),
- max_bson_size,
- )
-
-
-def _query_compressed(
- options: int,
- collection_name: str,
- num_to_skip: int,
- num_to_return: int,
- query: Mapping[str, Any],
- field_selector: Optional[Mapping[str, Any]],
- opts: CodecOptions,
- ctx: Union[SnappyContext, ZlibContext, ZstdContext],
-) -> tuple[int, bytes, int]:
- """Internal compressed query message helper."""
- op_query, max_bson_size = _query_impl(
- options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
- )
- rid, msg = _compress(2004, op_query, ctx)
- return rid, msg, max_bson_size
-
-
-def _query_uncompressed(
- options: int,
- collection_name: str,
- num_to_skip: int,
- num_to_return: int,
- query: Mapping[str, Any],
- field_selector: Optional[Mapping[str, Any]],
- opts: CodecOptions,
-) -> tuple[int, bytes, int]:
- """Internal query message helper."""
- op_query, max_bson_size = _query_impl(
- options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
- )
- rid, msg = __pack_message(2004, op_query)
- return rid, msg, max_bson_size
-
-
-if _use_c:
- _query_uncompressed = _cmessage._query_message
-
-
-def _query(
- options: int,
- collection_name: str,
- num_to_skip: int,
- num_to_return: int,
- query: Mapping[str, Any],
- field_selector: Optional[Mapping[str, Any]],
- opts: CodecOptions,
- ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
-) -> tuple[int, bytes, int]:
- """Get a **query** message."""
- if ctx:
- return _query_compressed(
- options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx
- )
- return _query_uncompressed(
- options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
- )
-
-
-_pack_long_long = struct.Struct("<q").pack
-
-
-def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes:
- """Get an OP_GET_MORE message."""
- return b"".join(
- [
- _ZERO_32,
- _make_c_string(collection_name),
- _pack_int(num_to_return),
- _pack_long_long(cursor_id),
- ]
- )
-
-
-def _get_more_compressed(
- collection_name: str,
- num_to_return: int,
- cursor_id: int,
- ctx: Union[SnappyContext, ZlibContext, ZstdContext],
-) -> tuple[int, bytes]:
- """Internal compressed getMore message helper."""
- return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx)
-
-
-def _get_more_uncompressed(
- collection_name: str, num_to_return: int, cursor_id: int
-) -> tuple[int, bytes]:
- """Internal getMore message helper."""
- return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id))
-
-
-if _use_c:
- _get_more_uncompressed = _cmessage._get_more_message
-
-
-def _get_more(
- collection_name: str,
- num_to_return: int,
- cursor_id: int,
- ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
-) -> tuple[int, bytes]:
- """Get a **getMore** message."""
- if ctx:
- return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
- return _get_more_uncompressed(collection_name, num_to_return, cursor_id)
-

 class _BulkWriteContext:
 """A wrapper around AsyncConnection for use with write splitting functions."""

@@ -1273,31 +454,6 @@ def max_split_size(self) -> int:
 return _MAX_SPLIT_SIZE_ENC
-
-
-def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn:
- """Internal helper for raising DocumentTooLarge."""
- if operation == "insert":
- raise DocumentTooLarge(
- "BSON document too large (%d bytes)"
- " - the connected server supports"
- " BSON document sizes up to %d"
- " bytes." % (doc_size, max_size)
- )
- else:
- # There's nothing intelligent we can say
- # about size for update and delete
- raise DocumentTooLarge(f"{operation!r} command document too large")
-
-
-# OP_MSG -------------------------------------------------------------
-
-
-_OP_MSG_MAP = {
- _INSERT: b"documents\x00",
- _UPDATE: b"updates\x00",
- _DELETE: b"deletes\x00",
-}
-
-
 def _batched_op_msg_impl(
 operation: int,
 command: Mapping[str, Any],
@@ -1555,206 +711,3 @@ def _batched_write_command_impl(
 buf.write(_pack_int(length - command_start))

 return to_send, length
-
-
-class _OpReply:
- """A MongoDB OP_REPLY response message."""
-
- __slots__ = ("flags", "cursor_id", "number_returned", "documents")
-
- UNPACK_FROM = struct.Struct("<iqii").unpack_from
-
- OP_CODE = 1
-
- def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
- self.flags = flags
- self.cursor_id = cursor_id
- self.number_returned = number_returned
- self.documents = documents
-
- def raw_response(self, cursor_id: Optional[int] = None) -> list[bytes]:
- """Check the response header from the database, without decoding BSON.
-
- Check the response for errors and unpack.
-
- Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
- OperationFailure.
-
- :param cursor_id: cursor_id we sent to get this response -
- used for raising an informative exception when we get cursor id not
- valid at server response.
- """
- if self.flags & 1:
- # Shouldn't get this response if we aren't doing a getMore
- if cursor_id is None:
- raise ProtocolError("No cursor id for getMore operation")
-
- # Fake a getMore command response. OP_GET_MORE provides no
- # document.
- msg = "Cursor not found, cursor id: %d" % (cursor_id,)
- errobj = {"ok": 0, "errmsg": msg, "code": 43}
- raise CursorNotFound(msg, 43, errobj)
- elif self.flags & 2:
- error_object: dict = bson.BSON(self.documents).decode()
- # Fake the ok field if it doesn't exist.
- error_object.setdefault("ok", 0)
- if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
- raise NotPrimaryError(error_object["$err"], error_object)
- elif error_object.get("code") == 50:
- default_msg = "operation exceeded time limit"
- raise ExecutionTimeout(
- error_object.get("$err", default_msg), error_object.get("code"), error_object
- )
- raise OperationFailure(
- "database error: %s" % error_object.get("$err"),
- error_object.get("code"),
- error_object,
- )
- if self.documents:
- return [self.documents]
- return []
-
- def unpack_response(
- self,
- cursor_id: Optional[int] = None,
- codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
- user_fields: Optional[Mapping[str, Any]] = None,
- legacy_response: bool = False,
- ) -> list[dict[str, Any]]:
- """Unpack a response from the database and decode the BSON document(s).
-
- Check the response for errors and unpack, returning a dictionary
- containing the response data.
-
- Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
- OperationFailure.
-
- :param cursor_id: cursor_id we sent to get this response -
- used for raising an informative exception when we get cursor id not
- valid at server response
- :param codec_options: an instance of
- :class:`~bson.codec_options.CodecOptions`
- :param user_fields: Response fields that should be decoded
- using the TypeDecoders from codec_options, passed to
- bson._decode_all_selective.
- """
- self.raw_response(cursor_id)
- if legacy_response:
- return bson.decode_all(self.documents, codec_options)
- return bson._decode_all_selective(self.documents, codec_options, user_fields)
-
- def command_response(self, codec_options: CodecOptions) -> dict[str, Any]:
- """Unpack a command response."""
- docs = self.unpack_response(codec_options=codec_options)
- assert self.number_returned == 1
- return docs[0]
-
- def raw_command_response(self) -> NoReturn:
- """Return the bytes of the command response."""
- # This should never be called on _OpReply.
- raise NotImplementedError
-
- @property
- def more_to_come(self) -> bool:
- """Is the moreToCome bit set on this response?"""
- return False
-
- @classmethod
- def unpack(cls, msg: bytes) -> _OpReply:
- """Construct an _OpReply from raw bytes."""
- # PYTHON-945: ignore starting_from field.
- flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)
-
- documents = msg[20:]
- return cls(flags, cursor_id, number_returned, documents)
-
-
-class _OpMsg:
- """A MongoDB OP_MSG response message."""
-
- __slots__ = ("flags", "cursor_id", "number_returned", "payload_document")
-
- UNPACK_FROM = struct.Struct("<IBi").unpack_from
-
- OP_CODE = 2013
-
- # Flag bits.
- CHECKSUM_PRESENT = 1
- MORE_TO_COME = 1 << 1
- EXHAUST_ALLOWED = 1 << 16
-
- def __init__(self, flags: int, payload_document: bytes):
- self.flags = flags
- self.payload_document = payload_document
-
- def raw_response(
- self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = {}
- ) -> list[Mapping[str, Any]]:
- """
- cursor_id is ignored
- user_fields is used to determine which fields must not be decoded
- """
- inflated_response = _decode_selective(
- RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS
- )
- return [inflated_response]
-
- def unpack_response(
- self,
- cursor_id: Optional[int] = None,
- codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS,
- user_fields: Optional[Mapping[str, Any]] = None,
- legacy_response: bool = False,
- ) -> list[dict[str, Any]]:
- """Unpack a OP_MSG command response.
-
- :param cursor_id: Ignored, for compatibility with _OpReply.
- :param codec_options: an instance of
- :class:`~bson.codec_options.CodecOptions`
- :param user_fields: Response fields that should be decoded
- using the TypeDecoders from codec_options, passed to
- bson._decode_all_selective.
- """
- # If _OpMsg is in-use, this cannot be a legacy response.
- assert not legacy_response - return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: - """Unpack a command response.""" - return self.unpack_response(codec_options=codec_options)[0] - - def raw_command_response(self) -> bytes: - """Return the bytes of the command response.""" - return self.payload_document - - @property - def more_to_come(self) -> bool: - """Is the moreToCome bit set on this response?""" - return bool(self.flags & self.MORE_TO_COME) - - @classmethod - def unpack(cls, msg: bytes) -> _OpMsg: - """Construct an _OpMsg from raw bytes.""" - flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) - if flags != 0: - if flags & cls.CHECKSUM_PRESENT: - raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") - - if flags ^ cls.MORE_TO_COME: - raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") - if first_payload_type != 0: - raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") - - if len(msg) != first_payload_size + 5: - raise ProtocolError("Unsupported OP_MSG reply: >1 section") - - payload_document = msg[5:] - return cls(flags, payload_document) - - -_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { - _OpReply.OP_CODE: _OpReply.unpack, - _OpMsg.OP_CODE: _OpMsg.unpack, -} diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index acba1c1e32..9ae1b3fffa 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -59,7 +59,7 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS, CodecOptions, TypeRegistry from bson.timestamp import Timestamp from pymongo import _csot, common, helpers_shared, uri_parser -from pymongo.asynchronous import client_session, database, message, periodic_executor +from pymongo.asynchronous import client_session, database, periodic_executor from pymongo.asynchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.asynchronous.client_session import _EmptyServerSession from pymongo.asynchronous.command_cursor import AsyncCommandCursor @@ -81,6 +81,7 @@ ) from pymongo.lock import _HAS_REGISTER_AT_FORK, _ALock, _create_lock, _release_locks from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference, _ServerMode @@ -111,7 +112,6 @@ from pymongo.asynchronous.bulk import _AsyncBulk from pymongo.asynchronous.client_session import AsyncClientSession, _ServerSession from pymongo.asynchronous.cursor import _ConnectionManager - from pymongo.asynchronous.message import _CursorAddress, _GetMore, _Query from pymongo.asynchronous.pool import AsyncConnection from pymongo.asynchronous.server import Server from pymongo.read_concern import ReadConcern @@ -1706,7 +1706,7 @@ async def _cmd( operation.read_preference, operation.session, address=address, - retryable=isinstance(operation, message._Query), + retryable=isinstance(operation, _Query), operation=operation.name, ) diff --git a/pymongo/asynchronous/network.py b/pymongo/asynchronous/network.py index 7c3444e071..ff43a5ffcb 100644 --- a/pymongo/asynchronous/network.py +++ b/pymongo/asynchronous/network.py @@ -33,9 +33,7 @@ ) from bson import _decode_all_selective -from pymongo import _csot, helpers_shared -from pymongo.asynchronous import 
message -from pymongo.asynchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply +from pymongo import _csot, helpers_shared, message from pymongo.common import MAX_MESSAGE_SIZE from pymongo.compression_support import _NO_COMPRESSION, decompress from pymongo.errors import ( @@ -45,6 +43,7 @@ _OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( _POLL_TIMEOUT, diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 265da13187..4fe5af6fd5 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -88,13 +88,13 @@ from bson.objectid import ObjectId from pymongo.asynchronous.auth import _AuthContext from pymongo.asynchronous.client_session import AsyncClientSession - from pymongo.asynchronous.message import _OpMsg, _OpReply from pymongo.asynchronous.mongo_client import AsyncMongoClient, _MongoClientErrorHandler from pymongo.compression_support import ( SnappyContext, ZlibContext, ZstdContext, ) + from pymongo.message import _OpMsg, _OpReply from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 0e6ae1574d..ecd3691d23 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -28,10 +28,10 @@ from bson import _decode_all_selective from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.asynchronous.message import _convert_exception, _GetMore, _OpMsg, _Query from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.helpers_shared import _check_command_response from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response if TYPE_CHECKING: @@ -106,6 +106,34 @@ def request_check(self) -> None: """Check the server's state soon.""" self._monitor.request_check() + async def operation_to_command( + self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + is_query = isinstance(operation, _Query) + if is_query: + explain = "$explain" in operation.spec + cmd, db = operation.as_command() + else: + explain = False + cmd, db = operation.as_command(conn) + if operation.session: + await operation.session._apply_to(cmd, False, operation.read_preference, conn) + # Explain does not support readConcern. + if is_query and not explain and not operation.session.in_transaction: + operation.session._update_read_concern(cmd, conn) + # Support auto encryption + if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption: + cmd = await operation.client._encrypter.encrypt( + operation.db, cmd, operation.codec_options + ) + + conn.add_server_api(cmd) + conn.send_cluster_time(cmd, operation.session, operation.client) + # Support CSOT + if apply_timeout: + conn.apply_timeout(operation.client, cmd=cmd if is_query else None) + return cmd, db + @_handle_reauth async def run_operation( self, @@ -122,26 +150,26 @@ async def run_operation( cursors. Can raise ConnectionFailure, OperationFailure, etc. - :param conn: A AsyncConnection instance. + :param conn: An AsyncConnection instance. 
:param operation: A _Query or _GetMore object. :param read_preference: The read preference to use. :param listeners: Instance of _EventListeners or None. :param unpack_res: A callable that decodes the wire protocol response. + :param client: An AsyncMongoClient instance. """ - duration = None assert listeners is not None publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come + cmd, dbn = await self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 else: - message = await operation.get_message(read_preference, conn, use_cmd) + message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - cmd, dbn = await operation.as_command(conn) if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, @@ -160,7 +188,6 @@ async def run_operation( ) if publish: - cmd, dbn = await operation.as_command(conn) if "$db" not in cmd: cmd["$db"] = dbn assert listeners is not None diff --git a/pymongo/message.py b/pymongo/message.py new file mode 100644 index 0000000000..623003cd09 --- /dev/null +++ b/pymongo/message.py @@ -0,0 +1,1085 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tools for creating `messages +`_ to be sent to +MongoDB. + +.. note:: This module is for internal use and is generally not needed by + application developers. +""" +from __future__ import annotations + +import random +import struct +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Iterable, + Mapping, + MutableMapping, + NoReturn, + Optional, + Union, +) + +import bson +from bson import CodecOptions, _dict_to_bson +from bson.int64 import Int64 +from bson.raw_bson import ( + _RAW_ARRAY_BSON_OPTIONS, + RawBSONDocument, +) +from pymongo.hello_compat import HelloCompat + +try: + from pymongo import _cmessage # type: ignore[attr-defined] + + _use_c = True +except ImportError: + _use_c = False +from pymongo.errors import ( + ConfigurationError, + CursorNotFound, + DocumentTooLarge, + ExecutionTimeout, + NotPrimaryError, + OperationFailure, + ProtocolError, +) +from pymongo.read_preferences import ReadPreference, _ServerMode + +if TYPE_CHECKING: + from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.mongo_client import AsyncMongoClient + from pymongo.asynchronous.pool import AsyncConnection + from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext + from pymongo.read_concern import ReadConcern + from pymongo.typings import _Address, _AgnosticClientSession, _AgnosticConnection + +MAX_INT32 = 2147483647 +MIN_INT32 = -2147483648 + +# Overhead allowed for encoded command documents. 
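+# A sketch of how this allowance is applied (illustrative only, not part of
+# this patch): an encoded command may exceed the server's maxBsonObjectSize
+# by up to _COMMAND_OVERHEAD bytes before DocumentTooLarge is raised, e.g.:
+#
+#     if len(encoded_cmd) > max_bson_size + _COMMAND_OVERHEAD:
+#         _raise_document_too_large(name, len(encoded_cmd), max_bson_size)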
+_COMMAND_OVERHEAD = 16382 + +_INSERT = 0 +_UPDATE = 1 +_DELETE = 2 + +_EMPTY = b"" +_BSONOBJ = b"\x03" +_ZERO_8 = b"\x00" +_ZERO_16 = b"\x00\x00" +_ZERO_32 = b"\x00\x00\x00\x00" +_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00" +_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff" +_OP_MAP = { + _INSERT: b"\x04documents\x00\x00\x00\x00\x00", + _UPDATE: b"\x04updates\x00\x00\x00\x00\x00", + _DELETE: b"\x04deletes\x00\x00\x00\x00\x00", +} +_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"} + +_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions( + unicode_decode_error_handler="replace" +) + + +def _randint() -> int: + """Generate a pseudo random 32 bit integer.""" + return random.randint(MIN_INT32, MAX_INT32) # noqa: S311 + + +def _maybe_add_read_preference( + spec: MutableMapping[str, Any], read_preference: _ServerMode +) -> MutableMapping[str, Any]: + """Add $readPreference to spec when appropriate.""" + mode = read_preference.mode + document = read_preference.document + # Only add $readPreference if it's something other than primary to avoid + # problems with mongos versions that don't support read preferences. Also, + # for maximum backwards compatibility, don't add $readPreference for + # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting + # the secondaryOkay bit has the same effect). + if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1): + if "$query" not in spec: + spec = {"$query": spec} + spec["$readPreference"] = document + return spec + + +def _convert_exception(exception: Exception) -> dict[str, Any]: + """Convert an Exception into a failure document for publishing.""" + return {"errmsg": str(exception), "errtype": exception.__class__.__name__} + + +def _convert_write_result( + operation: str, command: Mapping[str, Any], result: Mapping[str, Any] +) -> dict[str, Any]: + """Convert a legacy write result to write command format.""" + # Based on _merge_legacy from bulk.py + affected = result.get("n", 0) + res = {"ok": 1, "n": affected} + errmsg = result.get("errmsg", result.get("err", "")) + if errmsg: + # The write was successful on at least the primary so don't return. + if result.get("wtimeout"): + res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} + else: + # The write failed. + error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} + if "errInfo" in result: + error["errInfo"] = result["errInfo"] + res["writeErrors"] = [error] + return res + if operation == "insert": + # GLE result for insert is always 0 in most MongoDB versions. + res["n"] = len(command["documents"]) + elif operation == "update": + if "upserted" in result: + res["upserted"] = [{"index": 0, "_id": result["upserted"]}] + # Versions of MongoDB before 2.6 don't return the _id for an + # upsert if _id is not an ObjectId. + elif result.get("updatedExisting") is False and affected == 1: + # If _id is in both the update document *and* the query spec + # the update document _id takes precedence. 
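+ # A hypothetical example: for an upsert sent as
+ #     {"q": {"_id": 1}, "u": {"_id": 2, "x": 1}}
+ # the reported result is upserted [{"index": 0, "_id": 2}].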
+ update = command["updates"][0]
+ _id = update["u"].get("_id", update["q"].get("_id"))
+ res["upserted"] = [{"index": 0, "_id": _id}]
+ return res
+
+
+_OPTIONS = {
+ "tailable": 2,
+ "oplogReplay": 8,
+ "noCursorTimeout": 16,
+ "awaitData": 32,
+ "allowPartialResults": 128,
+}
+
+
+_MODIFIERS = {
+ "$query": "filter",
+ "$orderby": "sort",
+ "$hint": "hint",
+ "$comment": "comment",
+ "$maxScan": "maxScan",
+ "$maxTimeMS": "maxTimeMS",
+ "$max": "max",
+ "$min": "min",
+ "$returnKey": "returnKey",
+ "$showRecordId": "showRecordId",
+ "$showDiskLoc": "showRecordId", # <= MongoDb 3.0
+ "$snapshot": "snapshot",
+}
+
+
+def _gen_find_command(
+ coll: str,
+ spec: Mapping[str, Any],
+ projection: Optional[Union[Mapping[str, Any], Iterable[str]]],
+ skip: int,
+ limit: int,
+ batch_size: Optional[int],
+ options: Optional[int],
+ read_concern: ReadConcern,
+ collation: Optional[Mapping[str, Any]] = None,
+ session: Optional[_AgnosticClientSession] = None,
+ allow_disk_use: Optional[bool] = None,
+) -> dict[str, Any]:
+ """Generate a find command document."""
+ cmd: dict[str, Any] = {"find": coll}
+ if "$query" in spec:
+ cmd.update(
+ [
+ (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val)
+ for key, val in spec.items()
+ ]
+ )
+ if "$explain" in cmd:
+ cmd.pop("$explain")
+ if "$readPreference" in cmd:
+ cmd.pop("$readPreference")
+ else:
+ cmd["filter"] = spec
+
+ if projection:
+ cmd["projection"] = projection
+ if skip:
+ cmd["skip"] = skip
+ if limit:
+ cmd["limit"] = abs(limit)
+ if limit < 0:
+ cmd["singleBatch"] = True
+ if batch_size:
+ cmd["batchSize"] = batch_size
+ if read_concern.level and not (session and session.in_transaction):
+ cmd["readConcern"] = read_concern.document
+ if collation:
+ cmd["collation"] = collation
+ if allow_disk_use is not None:
+ cmd["allowDiskUse"] = allow_disk_use
+ if options:
+ cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val])
+
+ return cmd
+
+
+def _gen_get_more_command(
+ cursor_id: Optional[int],
+ coll: str,
+ batch_size: Optional[int],
+ max_await_time_ms: Optional[int],
+ comment: Optional[Any],
+ conn: _AgnosticConnection,
+) -> dict[str, Any]:
+ """Generate a getMore command document."""
+ cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll}
+ if batch_size:
+ cmd["batchSize"] = batch_size
+ if max_await_time_ms is not None:
+ cmd["maxTimeMS"] = max_await_time_ms
+ if comment is not None and conn.max_wire_version >= 9:
+ cmd["comment"] = comment
+ return cmd
+
+
+class _OpReply:
+ """A MongoDB OP_REPLY response message."""
+
+ __slots__ = ("flags", "cursor_id", "number_returned", "documents")
+
+ UNPACK_FROM = struct.Struct("<iqii").unpack_from
+
+ OP_CODE = 1
+
+ def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes):
+ self.flags = flags
+ self.cursor_id = cursor_id
+ self.number_returned = number_returned
+ self.documents = documents
+
+ def raw_response(self, cursor_id: Optional[int] = None) -> list[bytes]:
+ """Check the response header from the database, without decoding BSON.
+
+ Check the response for errors and unpack.
+
+ Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
+ OperationFailure.
+
+ :param cursor_id: cursor_id we sent to get this response -
+ used for raising an informative exception when we get cursor id not
+ valid at server response.
+ """
+ if self.flags & 1:
+ # Shouldn't get this response if we aren't doing a getMore
+ if cursor_id is None:
+ raise ProtocolError("No cursor id for getMore operation")
+
+ # Fake a getMore command response. OP_GET_MORE provides no
+ # document.
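+ # (OP_REPLY responseFlags assigns bit 0 to CursorNotFound and bit 1
+ # to QueryFailure, hence the flags & 1 and flags & 2 checks here.)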
+ msg = "Cursor not found, cursor id: %d" % (cursor_id,) + errobj = {"ok": 0, "errmsg": msg, "code": 43} + raise CursorNotFound(msg, 43, errobj) + elif self.flags & 2: + error_object: dict = bson.BSON(self.documents).decode() + # Fake the ok field if it doesn't exist. + error_object.setdefault("ok", 0) + if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): + raise NotPrimaryError(error_object["$err"], error_object) + elif error_object.get("code") == 50: + default_msg = "operation exceeded time limit" + raise ExecutionTimeout( + error_object.get("$err", default_msg), error_object.get("code"), error_object + ) + raise OperationFailure( + "database error: %s" % error_object.get("$err"), + error_object.get("code"), + error_object, + ) + if self.documents: + return [self.documents] + return [] + + def unpack_response( + self, + cursor_id: Optional[int] = None, + codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> list[dict[str, Any]]: + """Unpack a response from the database and decode the BSON document(s). + + Check the response for errors and unpack, returning a dictionary + containing the response data. + + Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or + OperationFailure. + + :param cursor_id: cursor_id we sent to get this response - + used for raising an informative exception when we get cursor id not + valid at server response + :param codec_options: an instance of + :class:`~bson.codec_options.CodecOptions` + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + """ + self.raw_response(cursor_id) + if legacy_response: + return bson.decode_all(self.documents, codec_options) + return bson._decode_all_selective(self.documents, codec_options, user_fields) + + def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: + """Unpack a command response.""" + docs = self.unpack_response(codec_options=codec_options) + assert self.number_returned == 1 + return docs[0] + + def raw_command_response(self) -> NoReturn: + """Return the bytes of the command response.""" + # This should never be called on _OpReply. + raise NotImplementedError + + @property + def more_to_come(self) -> bool: + """Is the moreToCome bit set on this response?""" + return False + + @classmethod + def unpack(cls, msg: bytes) -> _OpReply: + """Construct an _OpReply from raw bytes.""" + # PYTHON-945: ignore starting_from field. + flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) + + documents = msg[20:] + return cls(flags, cursor_id, number_returned, documents) + + +class _OpMsg: + """A MongoDB OP_MSG response message.""" + + __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") + + UNPACK_FROM = struct.Struct(" list[Mapping[str, Any]]: + """ + cursor_id is ignored + user_fields is used to determine which fields must not be decoded + """ + inflated_response = bson._decode_selective( + RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS + ) + return [inflated_response] + + def unpack_response( + self, + cursor_id: Optional[int] = None, + codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> list[dict[str, Any]]: + """Unpack a OP_MSG command response. + + :param cursor_id: Ignored, for compatibility with _OpReply. 
+ :param codec_options: an instance of + :class:`~bson.codec_options.CodecOptions` + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + """ + # If _OpMsg is in-use, this cannot be a legacy response. + assert not legacy_response + return bson._decode_all_selective(self.payload_document, codec_options, user_fields) + + def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: + """Unpack a command response.""" + return self.unpack_response(codec_options=codec_options)[0] + + def raw_command_response(self) -> bytes: + """Return the bytes of the command response.""" + return self.payload_document + + @property + def more_to_come(self) -> bool: + """Is the moreToCome bit set on this response?""" + return bool(self.flags & self.MORE_TO_COME) + + @classmethod + def unpack(cls, msg: bytes) -> _OpMsg: + """Construct an _OpMsg from raw bytes.""" + flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) + if flags != 0: + if flags & cls.CHECKSUM_PRESENT: + raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") + + if flags ^ cls.MORE_TO_COME: + raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") + if first_payload_type != 0: + raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") + + if len(msg) != first_payload_size + 5: + raise ProtocolError("Unsupported OP_MSG reply: >1 section") + + payload_document = msg[5:] + return cls(flags, payload_document) + + +_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { + _OpReply.OP_CODE: _OpReply.unpack, + _OpMsg.OP_CODE: _OpMsg.unpack, +} + + +class _Query: + """A query operation.""" + + __slots__ = ( + "flags", + "db", + "coll", + "ntoskip", + "spec", + "fields", + "codec_options", + "read_preference", + "limit", + "batch_size", + "name", + "read_concern", + "collation", + "session", + "client", + "allow_disk_use", + "_as_command", + "exhaust", + ) + + # For compatibility with the _GetMore class. + conn_mgr = None + cursor_id = None + + def __init__( + self, + flags: int, + db: str, + coll: str, + ntoskip: int, + spec: Mapping[str, Any], + fields: Optional[Mapping[str, Any]], + codec_options: CodecOptions, + read_preference: _ServerMode, + limit: int, + batch_size: int, + read_concern: ReadConcern, + collation: Optional[Mapping[str, Any]], + session: Optional[AsyncClientSession], + client: AsyncMongoClient, + allow_disk_use: Optional[bool], + exhaust: bool, + ): + self.flags = flags + self.db = db + self.coll = coll + self.ntoskip = ntoskip + self.spec = spec + self.fields = fields + self.codec_options = codec_options + self.read_preference = read_preference + self.read_concern = read_concern + self.limit = limit + self.batch_size = batch_size + self.collation = collation + self.session = session + self.client = client + self.allow_disk_use = allow_disk_use + self.name = "find" + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust + + def reset(self) -> None: + self._as_command = None + + def namespace(self) -> str: + return f"{self.db}.{self.coll}" + + def use_command(self, conn: AsyncConnection) -> bool: + use_find_cmd = False + if not self.exhaust: + use_find_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_find_cmd = True + elif not self.read_concern.ok_for_legacy: + raise ConfigurationError( + "read concern level of %s is not valid " + "with a max wire version of %d." 
% (self.read_concern.level, conn.max_wire_version) + ) + + conn.validate_session(self.client, self.session) + return use_find_cmd + + def as_command(self, dummy0: Optional[Any] = None) -> tuple[dict[str, Any], str]: + """Return a find command document for this query.""" + # We use the command twice: on the wire and for command monitoring. + # Generate it once, for speed and to avoid repeating side-effects. + if self._as_command is not None: + return self._as_command + + explain = "$explain" in self.spec + cmd: dict[str, Any] = _gen_find_command( + self.coll, + self.spec, + self.fields, + self.ntoskip, + self.limit, + self.batch_size, + self.flags, + self.read_concern, + self.collation, + self.session, + self.allow_disk_use, + ) + if explain: + self.name = "explain" + cmd = {"explain": cmd} + self._as_command = cmd, self.db + return self._as_command + + def get_message( + self, read_preference: _ServerMode, conn: AsyncConnection, use_cmd: bool = False + ) -> tuple[int, bytes, int]: + """Get a query message, possibly setting the secondaryOk bit.""" + # Use the read_preference decided by _socket_from_server. + self.read_preference = read_preference + if read_preference.mode: + # Set the secondaryOk bit. + flags = self.flags | 4 + else: + flags = self.flags + + ns = self.namespace() + spec = self.spec + + if use_cmd: + spec = self.as_command(None)[0] + request_id, msg, size, _ = _op_msg( + 0, + spec, + self.db, + read_preference, + self.codec_options, + ctx=conn.compression_context, + ) + return request_id, msg, size + + # OP_QUERY treats ntoreturn of -1 and 1 the same, return + # one document and close the cursor. We have to use 2 for + # batch size if 1 is specified. + ntoreturn = self.batch_size == 1 and 2 or self.batch_size + if self.limit: + if ntoreturn: + ntoreturn = min(self.limit, ntoreturn) + else: + ntoreturn = self.limit + + if conn.is_mongos: + assert isinstance(spec, MutableMapping) + spec = _maybe_add_read_preference(spec, read_preference) + + return _query( + flags, + ns, + self.ntoskip, + ntoreturn, + spec, + None if use_cmd else self.fields, + self.codec_options, + ctx=conn.compression_context, + ) + + +class _GetMore: + """A getmore operation.""" + + __slots__ = ( + "db", + "coll", + "ntoreturn", + "cursor_id", + "max_await_time_ms", + "codec_options", + "read_preference", + "session", + "client", + "conn_mgr", + "_as_command", + "exhaust", + "comment", + ) + + name = "getMore" + + def __init__( + self, + db: str, + coll: str, + ntoreturn: int, + cursor_id: int, + codec_options: CodecOptions, + read_preference: _ServerMode, + session: Optional[AsyncClientSession], + client: AsyncMongoClient, + max_await_time_ms: Optional[int], + conn_mgr: Any, + exhaust: bool, + comment: Any, + ): + self.db = db + self.coll = coll + self.ntoreturn = ntoreturn + self.cursor_id = cursor_id + self.codec_options = codec_options + self.read_preference = read_preference + self.session = session + self.client = client + self.max_await_time_ms = max_await_time_ms + self.conn_mgr = conn_mgr + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust + self.comment = comment + + def reset(self) -> None: + self._as_command = None + + def namespace(self) -> str: + return f"{self.db}.{self.coll}" + + def use_command(self, conn: AsyncConnection) -> bool: + use_cmd = False + if not self.exhaust: + use_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_cmd = True + + conn.validate_session(self.client, self.session) + return use_cmd + 
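+ # A sketch of what as_command returns for a getMore (values are
+ # illustrative): ({"getMore": 4567, "collection": "coll",
+ # "batchSize": 2, "maxTimeMS": 500}, "db"). Unlike the asynchronous
+ # version being removed above, session state, cluster time, auto
+ # encryption, and CSOT are now applied by Server.operation_to_command
+ # rather than here.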
+ def as_command(self, conn: AsyncConnection) -> tuple[dict[str, Any], str]:
+ """Return a getMore command document for this query."""
+ # See _Query.as_command for an explanation of this caching.
+ if self._as_command is not None:
+ return self._as_command
+
+ cmd: dict[str, Any] = _gen_get_more_command(
+ self.cursor_id,
+ self.coll,
+ self.ntoreturn,
+ self.max_await_time_ms,
+ self.comment,
+ conn,
+ )
+ self._as_command = cmd, self.db
+ return self._as_command
+
+ def get_message(
+ self, dummy0: Any, conn: AsyncConnection, use_cmd: bool = False
+ ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]:
+ """Get a getmore message."""
+ ns = self.namespace()
+ ctx = conn.compression_context
+
+ if use_cmd:
+ spec = self.as_command(conn)[0]
+ if self.conn_mgr and self.exhaust:
+ flags = _OpMsg.EXHAUST_ALLOWED
+ else:
+ flags = 0
+ request_id, msg, size, _ = _op_msg(
+ flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context
+ )
+ return request_id, msg, size
+
+ return _get_more(ns, self.ntoreturn, self.cursor_id, ctx)
+
+
+class _RawBatchQuery(_Query):
+ def use_command(self, conn: AsyncConnection) -> bool:
+ # Compatibility checks.
+ super().use_command(conn)
+ if conn.max_wire_version >= 8:
+ # MongoDB 4.2+ supports exhaust over OP_MSG
+ return True
+ elif not self.exhaust:
+ return True
+ return False
+
+
+class _RawBatchGetMore(_GetMore):
+ def use_command(self, conn: AsyncConnection) -> bool:
+ # Compatibility checks.
+ super().use_command(conn)
+ if conn.max_wire_version >= 8:
+ # MongoDB 4.2+ supports exhaust over OP_MSG
+ return True
+ elif not self.exhaust:
+ return True
+ return False
+
+
+class _CursorAddress(tuple):
+ """The server address (host, port) of a cursor, with namespace property."""
+
+ __namespace: Any
+
+ def __new__(cls, address: _Address, namespace: str) -> _CursorAddress:
+ self = tuple.__new__(cls, address)
+ self.__namespace = namespace
+ return self
+
+ @property
+ def namespace(self) -> str:
+ """The namespace this cursor."""
+ return self.__namespace
+
+ def __hash__(self) -> int:
+ # Two _CursorAddress instances with different namespaces
+ # must not hash the same.
+ return ((*self, self.__namespace)).__hash__()
+
+ def __eq__(self, other: object) -> bool:
+ if isinstance(other, _CursorAddress):
+ return tuple(self) == tuple(other) and self.namespace == other.namespace
+ return NotImplemented
+
+ def __ne__(self, other: object) -> bool:
+ return not self == other
+
+
+_pack_compression_header = struct.Struct("<iiiiiiB").pack
+_COMPRESSION_HEADER_SIZE = 25
+
+
+def _compress(
+ operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext]
+) -> tuple[int, bytes]:
+ """Takes message data, compresses it, and adds an OP_COMPRESSED header."""
+ compressed = ctx.compress(data)
+ request_id = _randint()
+
+ header = _pack_compression_header(
+ _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length
+ request_id, # Request id
+ 0, # responseTo
+ 2012, # operation id
+ operation, # original operation id
+ len(data), # uncompressed message length
+ ctx.compressor_id,
+ ) # compressor id
+ return request_id, header + compressed
+
+
+_pack_header = struct.Struct("<iiii").pack
+
+
+def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]:
+ """Takes message data and adds a message header based on the operation.
+
+ Returns the resultant message string.
+ """
+ rid = _randint()
+ message = _pack_header(16 + len(data), rid, 0, operation)
+ return rid, message + data
+
+
+_pack_int = struct.Struct("<i").pack
+_pack_op_msg_flags_type = struct.Struct("<IB").pack
+_pack_byte = struct.Struct("<B").pack
+
+
+def _op_msg_no_header(
+ flags: int,
+ command: Mapping[str, Any],
+ identifier: str,
+ docs: Optional[list[Mapping[str, Any]]],
+ opts: CodecOptions,
+) -> tuple[bytes, int, int]:
+ """Get a OP_MSG message.
+ + Note: this method handles multiple documents in a type one payload but + it does not perform batch splitting and the total message size is + only checked *after* generating the entire message. + """ + # Encode the command document in payload 0 without checking keys. + encoded = _dict_to_bson(command, False, opts) + flags_type = _pack_op_msg_flags_type(flags, 0) + total_size = len(encoded) + max_doc_size = 0 + if identifier and docs is not None: + type_one = _pack_byte(1) + cstring = bson._make_c_string(identifier) + encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] + size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 + encoded_size = _pack_int(size) + total_size += size + max_doc_size = max(len(doc) for doc in encoded_docs) + data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] + else: + data = [flags_type, encoded] + return b"".join(data), total_size, max_doc_size + + +def _op_msg_compressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[list[Mapping[str, Any]]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> tuple[int, bytes, int, int]: + """Internal OP_MSG message helper.""" + msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) + rid, msg = _compress(2013, msg, ctx) + return rid, msg, total_size, max_bson_size + + +def _op_msg_uncompressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[list[Mapping[str, Any]]], + opts: CodecOptions, +) -> tuple[int, bytes, int, int]: + """Internal compressed OP_MSG message helper.""" + data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) + request_id, op_message = __pack_message(2013, data) + return request_id, op_message, total_size, max_bson_size + + +if _use_c: + _op_msg_uncompressed = _cmessage._op_msg + + +def _op_msg( + flags: int, + command: MutableMapping[str, Any], + dbname: str, + read_preference: Optional[_ServerMode], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> tuple[int, bytes, int, int]: + """Get a OP_MSG message.""" + command["$db"] = dbname + # getMore commands do not send $readPreference. + if read_preference is not None and "$readPreference" not in command: + # Only send $readPreference if it's not primary (the default). + if read_preference.mode: + command["$readPreference"] = read_preference.document + name = next(iter(command)) + try: + identifier = _FIELD_MAP[name] + docs = command.pop(identifier) + except KeyError: + identifier = "" + docs = None + try: + if ctx: + return _op_msg_compressed(flags, command, identifier, docs, opts, ctx) + return _op_msg_uncompressed(flags, command, identifier, docs, opts) + finally: + # Add the field back to the command. 
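+ # (The "documents"/"updates"/"deletes" field was popped above so it
+ # could be sent as an OP_MSG payload type 1 document sequence: e.g.
+ # {"insert": "coll", "documents": [...]} becomes payload 0
+ # {"insert": "coll", "$db": ...} plus a "documents" section.
+ # Restoring it keeps the caller's command mapping intact.)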
+
+
+def _query_impl(
+    options: int,
+    collection_name: str,
+    num_to_skip: int,
+    num_to_return: int,
+    query: Mapping[str, Any],
+    field_selector: Optional[Mapping[str, Any]],
+    opts: CodecOptions,
+) -> tuple[bytes, int]:
+    """Get an OP_QUERY message."""
+    encoded = _dict_to_bson(query, False, opts)
+    if field_selector:
+        efs = _dict_to_bson(field_selector, False, opts)
+    else:
+        efs = b""
+    max_bson_size = max(len(encoded), len(efs))
+    return (
+        b"".join(
+            [
+                _pack_int(options),
+                bson._make_c_string(collection_name),
+                _pack_int(num_to_skip),
+                _pack_int(num_to_return),
+                encoded,
+                efs,
+            ]
+        ),
+        max_bson_size,
+    )
+
+
+def _query_compressed(
+    options: int,
+    collection_name: str,
+    num_to_skip: int,
+    num_to_return: int,
+    query: Mapping[str, Any],
+    field_selector: Optional[Mapping[str, Any]],
+    opts: CodecOptions,
+    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
+) -> tuple[int, bytes, int]:
+    """Internal compressed query message helper."""
+    op_query, max_bson_size = _query_impl(
+        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
+    )
+    rid, msg = _compress(2004, op_query, ctx)
+    return rid, msg, max_bson_size
+
+
+def _query_uncompressed(
+    options: int,
+    collection_name: str,
+    num_to_skip: int,
+    num_to_return: int,
+    query: Mapping[str, Any],
+    field_selector: Optional[Mapping[str, Any]],
+    opts: CodecOptions,
+) -> tuple[int, bytes, int]:
+    """Internal query message helper."""
+    op_query, max_bson_size = _query_impl(
+        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
+    )
+    rid, msg = __pack_message(2004, op_query)
+    return rid, msg, max_bson_size
+
+
+if _use_c:
+    _query_uncompressed = _cmessage._query_message
+
+
+def _query(
+    options: int,
+    collection_name: str,
+    num_to_skip: int,
+    num_to_return: int,
+    query: Mapping[str, Any],
+    field_selector: Optional[Mapping[str, Any]],
+    opts: CodecOptions,
+    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
+) -> tuple[int, bytes, int]:
+    """Get a **query** message."""
+    if ctx:
+        return _query_compressed(
+            options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx
+        )
+    return _query_uncompressed(
+        options, collection_name, num_to_skip, num_to_return, query, field_selector, opts
+    )
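For comparison with the OP_MSG path, the legacy OP_QUERY body assembled by `_query_impl` is just a flags int32, the NUL-terminated full collection name, two int32s (numberToSkip, numberToReturn), and one or two BSON documents. A standalone sketch of that layout (illustrative only; OP_QUERY is the legacy wire format):

```python
import struct
import bson

_pack_int = struct.Struct("<i").pack

body = b"".join(
    [
        _pack_int(0),           # flags
        b"db.coll\x00",         # full collection name, NUL-terminated
        _pack_int(0),           # numberToSkip
        _pack_int(1),           # numberToReturn
        bson.encode({"x": 1}),  # query document (projection omitted)
    ]
)
print(len(body))
```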
+
+
+_pack_long_long = struct.Struct("<q").pack
+
+
+def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes:
+    """Get an OP_GET_MORE message."""
+    return b"".join(
+        [
+            _ZERO_32,
+            bson._make_c_string(collection_name),
+            _pack_int(num_to_return),
+            _pack_long_long(cursor_id),
+        ]
+    )
+
+
+def _get_more_compressed(
+    collection_name: str,
+    num_to_return: int,
+    cursor_id: int,
+    ctx: Union[SnappyContext, ZlibContext, ZstdContext],
+) -> tuple[int, bytes]:
+    """Internal compressed getMore message helper."""
+    return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx)
+
+
+def _get_more_uncompressed(
+    collection_name: str, num_to_return: int, cursor_id: int
+) -> tuple[int, bytes]:
+    """Internal getMore message helper."""
+    return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id))
+
+
+if _use_c:
+    _get_more_uncompressed = _cmessage._get_more_message
+
+
+def _get_more(
+    collection_name: str,
+    num_to_return: int,
+    cursor_id: int,
+    ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None,
+) -> tuple[int, bytes]:
+    """Get a **getMore** message."""
+    if ctx:
+        return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx)
+    return _get_more_uncompressed(collection_name, num_to_return, cursor_id)
+
+
+def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn:
+    """Internal helper for raising DocumentTooLarge."""
+    if operation == "insert":
+        raise DocumentTooLarge(
+            "BSON document too large (%d bytes)"
+            " - the connected server supports"
+            " BSON document sizes up to %d"
+            " bytes." % (doc_size, max_size)
+        )
+    else:
+        # There's nothing intelligent we can say
+        # about size for update and delete
+        raise DocumentTooLarge(f"{operation!r} command document too large")
+
+
+# OP_MSG -------------------------------------------------------------
+
+
+_OP_MSG_MAP = {
+    _INSERT: b"documents\x00",
+    _UPDATE: b"updates\x00",
+    _DELETE: b"deletes\x00",
+}
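The OP_GET_MORE body built by `_get_more_impl` above is similarly small: a zeroed int32, the NUL-terminated collection name, an int32 numberToReturn, and the int64 cursor id. A quick standalone size check (illustrative only, not part of the patch):

```python
import struct

name = b"db.coll\x00"
body = b"".join(
    [
        b"\x00\x00\x00\x00",               # reserved ZERO int32
        name,                              # full collection name
        struct.Struct("<i").pack(101),     # numberToReturn
        struct.Struct("<q").pack(2**33),   # 64-bit cursor id
    ]
)
assert len(body) == 4 + len(name) + 4 + 8
```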
diff --git a/pymongo/response.py b/pymongo/response.py
index 99a154efae..850794567c 100644
--- a/pymongo/response.py
+++ b/pymongo/response.py
@@ -20,7 +20,7 @@
 if TYPE_CHECKING:
     from datetime import timedelta
 
-    from pymongo.asynchronous.message import _OpMsg, _OpReply
+    from pymongo.message import _OpMsg, _OpReply
     from pymongo.typings import _Address, _AgnosticConnection, _DocumentOut
 
 _IS_SYNC = False
diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py
index 8d3d0e10fd..95699ac09e 100644
--- a/pymongo/synchronous/bulk.py
+++ b/pymongo/synchronous/bulk.py
@@ -47,15 +47,17 @@
     OperationFailure,
 )
 from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES, _get_wce_doc
-from pymongo.read_preferences import ReadPreference
-from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
-from pymongo.synchronous.message import (
+from pymongo.message import (
     _DELETE,
     _INSERT,
     _UPDATE,
+    _randint,
+)
+from pymongo.read_preferences import ReadPreference
+from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern
+from pymongo.synchronous.message import (
     _BulkWriteContext,
     _EncryptedBulkWriteContext,
-    _randint,
 )
 from pymongo.write_concern import WriteConcern
 
diff --git a/pymongo/synchronous/collection.py b/pymongo/synchronous/collection.py
index b8fd39f2d6..9bfc11c900 100644
--- a/pymongo/synchronous/collection.py
+++ b/pymongo/synchronous/collection.py
@@ -40,7 +40,7 @@
 from bson.raw_bson import RawBSONDocument
 from bson.son import SON
 from bson.timestamp import Timestamp
-from pymongo import ASCENDING, _csot, common, helpers_shared
+from pymongo import ASCENDING, _csot, common, helpers_shared, message
 from pymongo.collation import validate_collation_or_none
 from pymongo.common import _ecoc_coll_name, _esc_coll_name
 from pymongo.errors import (
@@ -50,6 +50,7 @@
     OperationFailure,
 )
 from pymongo.helpers_shared import _check_write_command_response
+from pymongo.message import _UNICODE_REPLACE_CODEC_OPTIONS
 from pymongo.operations import (
     DeleteMany,
     DeleteOne,
@@ -72,7 +73,6 @@
     InsertOneResult,
     UpdateResult,
 )
-from pymongo.synchronous import message
 from pymongo.synchronous.aggregation import (
     _CollectionAggregationCommand,
     _CollectionRawAggregationCommand,
@@ -87,7 +87,6 @@
     Cursor,
     RawBatchCursor,
 )
-from pymongo.synchronous.message import _UNICODE_REPLACE_CODEC_OPTIONS
 from pymongo.typings import _CollationIn, _DocumentType, _DocumentTypeArg, _Pipeline
 from pymongo.write_concern import DEFAULT_WRITE_CONCERN, WriteConcern, validate_boolean
 
diff --git a/pymongo/synchronous/command_cursor.py b/pymongo/synchronous/command_cursor.py
index ba9bf6ef10..e313d552ce 100644
--- a/pymongo/synchronous/command_cursor.py
+++ b/pymongo/synchronous/command_cursor.py
@@ -31,15 +31,15 @@
 from bson import CodecOptions, _convert_raw_document_lists_to_streams
 from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS
 from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
-from pymongo.response import PinnedResponse
-from pymongo.synchronous.cursor import _ConnectionManager
-from pymongo.synchronous.message import (
+from pymongo.message import (
     _CursorAddress,
     _GetMore,
     _OpMsg,
     _OpReply,
     _RawBatchGetMore,
 )
+from pymongo.response import PinnedResponse
+from pymongo.synchronous.cursor import _ConnectionManager
 from pymongo.typings import _Address, _DocumentOut, _DocumentType
 
 if TYPE_CHECKING:
@@ -260,7 +260,7 @@ def _send_message(self, operation: _GetMore) -> None:
 
         if isinstance(response, PinnedResponse):
             if not self._sock_mgr:
-                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
+                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)  # type: ignore[arg-type]
         if response.from_command:
             cursor = response.docs[0]["cursor"]
             documents = cursor["nextBatch"]
diff --git a/pymongo/synchronous/cursor.py b/pymongo/synchronous/cursor.py
index cacaeb7aad..d0dfb50fdc 100644
--- a/pymongo/synchronous/cursor.py
+++ b/pymongo/synchronous/cursor.py
@@ -45,9 +45,7 @@
 from pymongo.cursor_shared import _CURSOR_CLOSED_ERRORS, _QUERY_OPTIONS, CursorType, _Hint, _Sort
 from pymongo.errors import ConnectionFailure, InvalidOperation, OperationFailure
 from pymongo.lock import _create_lock
-from pymongo.response import PinnedResponse
-from pymongo.synchronous.helpers import next
-from pymongo.synchronous.message import (
+from pymongo.message import (
     _CursorAddress,
     _GetMore,
     _OpMsg,
@@ -56,6 +54,8 @@
     _RawBatchGetMore,
     _RawBatchQuery,
 )
+from pymongo.response import PinnedResponse
+from pymongo.synchronous.helpers import next
 from pymongo.typings import _Address, _CollationIn, _DocumentOut, _DocumentType
 from pymongo.write_concern import validate_boolean
 
@@ -1102,7 +1102,7 @@ def _send_message(self, operation: Union[_Query, _GetMore]) -> None:
         self._address = response.address
         if isinstance(response, PinnedResponse):
             if not self._sock_mgr:
-                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)
+                self._sock_mgr = _ConnectionManager(response.conn, response.more_to_come)  # type: ignore[arg-type]
         cmd_name = operation.name
         docs = response.docs
diff --git a/pymongo/synchronous/message.py b/pymongo/synchronous/message.py
index 973345f3d2..69e7a8ec49 100644
--- a/pymongo/synchronous/message.py
+++ b/pymongo/synchronous/message.py
@@ -23,889 +23,70 @@
 import datetime
 import logging
-import random
-import struct
 from io import BytesIO as _BytesIO
 from typing import (
     TYPE_CHECKING,
     Any,
-    Callable,
-    Iterable,
     Mapping,
     MutableMapping,
-    NoReturn,
     Optional,
-    Union,
 )
 
-import bson
-from bson import CodecOptions, _decode_selective, _dict_to_bson, _make_c_string, encode
-from bson.int64 import Int64
+from bson import CodecOptions, _dict_to_bson, encode
 from bson.raw_bson import (
-    _RAW_ARRAY_BSON_OPTIONS,
     DEFAULT_RAW_BSON_OPTIONS,
-    RawBSONDocument,
     _inflate_bson,
 )
-
-try:
-    from pymongo import _cmessage  # type: ignore[attr-defined]
-
-    _use_c = True
-except ImportError:
-    _use_c = False
 from pymongo.errors import (
-    ConfigurationError,
-    CursorNotFound,
-    DocumentTooLarge,
-    ExecutionTimeout,
     InvalidOperation,
     NotPrimaryError,
     OperationFailure,
-    ProtocolError,
 )
-from pymongo.hello_compat import HelloCompat
 from pymongo.logger import (
     _COMMAND_LOGGER,
     _CommandStatusMessage,
     _debug_log,
 )
-from 
pymongo.read_preferences import ReadPreference +from pymongo.message import ( + _BSONOBJ, + _COMMAND_OVERHEAD, + _FIELD_MAP, + _OP_MAP, + _OP_MSG_MAP, + _SKIPLIM, + _ZERO_8, + _ZERO_16, + _ZERO_32, + _ZERO_64, + _compress, + _convert_exception, + _convert_write_result, + _pack_int, + _raise_document_too_large, + _randint, +) from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern +try: + from pymongo import _cmessage # type: ignore[attr-defined] + + _use_c = True +except ImportError: + _use_c = False + if TYPE_CHECKING: from datetime import timedelta - from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.monitoring import _EventListeners - from pymongo.read_concern import ReadConcern - from pymongo.read_preferences import _ServerMode from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection - from pymongo.typings import _Address, _DocumentOut + from pymongo.typings import _DocumentOut _IS_SYNC = True -MAX_INT32 = 2147483647 -MIN_INT32 = -2147483648 - -# Overhead allowed for encoded command documents. -_COMMAND_OVERHEAD = 16382 - -_INSERT = 0 -_UPDATE = 1 -_DELETE = 2 - -_EMPTY = b"" -_BSONOBJ = b"\x03" -_ZERO_8 = b"\x00" -_ZERO_16 = b"\x00\x00" -_ZERO_32 = b"\x00\x00\x00\x00" -_ZERO_64 = b"\x00\x00\x00\x00\x00\x00\x00\x00" -_SKIPLIM = b"\x00\x00\x00\x00\xff\xff\xff\xff" -_OP_MAP = { - _INSERT: b"\x04documents\x00\x00\x00\x00\x00", - _UPDATE: b"\x04updates\x00\x00\x00\x00\x00", - _DELETE: b"\x04deletes\x00\x00\x00\x00\x00", -} -_FIELD_MAP = {"insert": "documents", "update": "updates", "delete": "deletes"} - -_UNICODE_REPLACE_CODEC_OPTIONS: CodecOptions[Mapping[str, Any]] = CodecOptions( - unicode_decode_error_handler="replace" -) - - -def _randint() -> int: - """Generate a pseudo random 32 bit integer.""" - return random.randint(MIN_INT32, MAX_INT32) # noqa: S311 - - -def _maybe_add_read_preference( - spec: MutableMapping[str, Any], read_preference: _ServerMode -) -> MutableMapping[str, Any]: - """Add $readPreference to spec when appropriate.""" - mode = read_preference.mode - document = read_preference.document - # Only add $readPreference if it's something other than primary to avoid - # problems with mongos versions that don't support read preferences. Also, - # for maximum backwards compatibility, don't add $readPreference for - # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting - # the secondaryOkay bit has the same effect). - if mode and (mode != ReadPreference.SECONDARY_PREFERRED.mode or len(document) > 1): - if "$query" not in spec: - spec = {"$query": spec} - spec["$readPreference"] = document - return spec - - -def _convert_exception(exception: Exception) -> dict[str, Any]: - """Convert an Exception into a failure document for publishing.""" - return {"errmsg": str(exception), "errtype": exception.__class__.__name__} - - -def _convert_write_result( - operation: str, command: Mapping[str, Any], result: Mapping[str, Any] -) -> dict[str, Any]: - """Convert a legacy write result to write command format.""" - # Based on _merge_legacy from bulk.py - affected = result.get("n", 0) - res = {"ok": 1, "n": affected} - errmsg = result.get("errmsg", result.get("err", "")) - if errmsg: - # The write was successful on at least the primary so don't return. 
- if result.get("wtimeout"): - res["writeConcernError"] = {"errmsg": errmsg, "code": 64, "errInfo": {"wtimeout": True}} - else: - # The write failed. - error = {"index": 0, "code": result.get("code", 8), "errmsg": errmsg} - if "errInfo" in result: - error["errInfo"] = result["errInfo"] - res["writeErrors"] = [error] - return res - if operation == "insert": - # GLE result for insert is always 0 in most MongoDB versions. - res["n"] = len(command["documents"]) - elif operation == "update": - if "upserted" in result: - res["upserted"] = [{"index": 0, "_id": result["upserted"]}] - # Versions of MongoDB before 2.6 don't return the _id for an - # upsert if _id is not an ObjectId. - elif result.get("updatedExisting") is False and affected == 1: - # If _id is in both the update document *and* the query spec - # the update document _id takes precedence. - update = command["updates"][0] - _id = update["u"].get("_id", update["q"].get("_id")) - res["upserted"] = [{"index": 0, "_id": _id}] - return res - - -_OPTIONS = { - "tailable": 2, - "oplogReplay": 8, - "noCursorTimeout": 16, - "awaitData": 32, - "allowPartialResults": 128, -} - - -_MODIFIERS = { - "$query": "filter", - "$orderby": "sort", - "$hint": "hint", - "$comment": "comment", - "$maxScan": "maxScan", - "$maxTimeMS": "maxTimeMS", - "$max": "max", - "$min": "min", - "$returnKey": "returnKey", - "$showRecordId": "showRecordId", - "$showDiskLoc": "showRecordId", # <= MongoDb 3.0 - "$snapshot": "snapshot", -} - - -def _gen_find_command( - coll: str, - spec: Mapping[str, Any], - projection: Optional[Union[Mapping[str, Any], Iterable[str]]], - skip: int, - limit: int, - batch_size: Optional[int], - options: Optional[int], - read_concern: ReadConcern, - collation: Optional[Mapping[str, Any]] = None, - session: Optional[ClientSession] = None, - allow_disk_use: Optional[bool] = None, -) -> dict[str, Any]: - """Generate a find command document.""" - cmd: dict[str, Any] = {"find": coll} - if "$query" in spec: - cmd.update( - [ - (_MODIFIERS[key], val) if key in _MODIFIERS else (key, val) - for key, val in spec.items() - ] - ) - if "$explain" in cmd: - cmd.pop("$explain") - if "$readPreference" in cmd: - cmd.pop("$readPreference") - else: - cmd["filter"] = spec - - if projection: - cmd["projection"] = projection - if skip: - cmd["skip"] = skip - if limit: - cmd["limit"] = abs(limit) - if limit < 0: - cmd["singleBatch"] = True - if batch_size: - cmd["batchSize"] = batch_size - if read_concern.level and not (session and session.in_transaction): - cmd["readConcern"] = read_concern.document - if collation: - cmd["collation"] = collation - if allow_disk_use is not None: - cmd["allowDiskUse"] = allow_disk_use - if options: - cmd.update([(opt, True) for opt, val in _OPTIONS.items() if options & val]) - - return cmd - - -def _gen_get_more_command( - cursor_id: Optional[int], - coll: str, - batch_size: Optional[int], - max_await_time_ms: Optional[int], - comment: Optional[Any], - conn: Connection, -) -> dict[str, Any]: - """Generate a getMore command document.""" - cmd: dict[str, Any] = {"getMore": cursor_id, "collection": coll} - if batch_size: - cmd["batchSize"] = batch_size - if max_await_time_ms is not None: - cmd["maxTimeMS"] = max_await_time_ms - if comment is not None and conn.max_wire_version >= 9: - cmd["comment"] = comment - return cmd - - -class _Query: - """A query operation.""" - - __slots__ = ( - "flags", - "db", - "coll", - "ntoskip", - "spec", - "fields", - "codec_options", - "read_preference", - "limit", - "batch_size", - "name", - 
"read_concern", - "collation", - "session", - "client", - "allow_disk_use", - "_as_command", - "exhaust", - ) - - # For compatibility with the _GetMore class. - conn_mgr = None - cursor_id = None - - def __init__( - self, - flags: int, - db: str, - coll: str, - ntoskip: int, - spec: Mapping[str, Any], - fields: Optional[Mapping[str, Any]], - codec_options: CodecOptions, - read_preference: _ServerMode, - limit: int, - batch_size: int, - read_concern: ReadConcern, - collation: Optional[Mapping[str, Any]], - session: Optional[ClientSession], - client: MongoClient, - allow_disk_use: Optional[bool], - exhaust: bool, - ): - self.flags = flags - self.db = db - self.coll = coll - self.ntoskip = ntoskip - self.spec = spec - self.fields = fields - self.codec_options = codec_options - self.read_preference = read_preference - self.read_concern = read_concern - self.limit = limit - self.batch_size = batch_size - self.collation = collation - self.session = session - self.client = client - self.allow_disk_use = allow_disk_use - self.name = "find" - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - - def reset(self) -> None: - self._as_command = None - - def namespace(self) -> str: - return f"{self.db}.{self.coll}" - - def use_command(self, conn: Connection) -> bool: - use_find_cmd = False - if not self.exhaust: - use_find_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_find_cmd = True - elif not self.read_concern.ok_for_legacy: - raise ConfigurationError( - "read concern level of %s is not valid " - "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) - ) - - conn.validate_session(self.client, self.session) - return use_find_cmd - - def as_command( - self, conn: Connection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a find command document for this query.""" - # We use the command twice: on the wire and for command monitoring. - # Generate it once, for speed and to avoid repeating side-effects. - if self._as_command is not None: - return self._as_command - - explain = "$explain" in self.spec - cmd: dict[str, Any] = _gen_find_command( - self.coll, - self.spec, - self.fields, - self.ntoskip, - self.limit, - self.batch_size, - self.flags, - self.read_concern, - self.collation, - self.session, - self.allow_disk_use, - ) - if explain: - self.name = "explain" - cmd = {"explain": cmd} - session = self.session - conn.add_server_api(cmd) - if session: - session._apply_to(cmd, False, self.read_preference, conn) - # Explain does not support readConcern. - if not explain and not session.in_transaction: - session._update_read_concern(cmd, conn) - conn.send_cluster_time(cmd, session, self.client) - # Support auto encryption - client = self.client - if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) - # Support CSOT - if apply_timeout: - conn.apply_timeout(client, cmd) - self._as_command = cmd, self.db - return self._as_command - - def get_message( - self, read_preference: _ServerMode, conn: Connection, use_cmd: bool = False - ) -> tuple[int, bytes, int]: - """Get a query message, possibly setting the secondaryOk bit.""" - # Use the read_preference decided by _socket_from_server. - self.read_preference = read_preference - if read_preference.mode: - # Set the secondaryOk bit. 
- flags = self.flags | 4 - else: - flags = self.flags - - ns = self.namespace() - spec = self.spec - - if use_cmd: - spec = (self.as_command(conn, apply_timeout=True))[0] - request_id, msg, size, _ = _op_msg( - 0, - spec, - self.db, - read_preference, - self.codec_options, - ctx=conn.compression_context, - ) - return request_id, msg, size - - # OP_QUERY treats ntoreturn of -1 and 1 the same, return - # one document and close the cursor. We have to use 2 for - # batch size if 1 is specified. - ntoreturn = self.batch_size == 1 and 2 or self.batch_size - if self.limit: - if ntoreturn: - ntoreturn = min(self.limit, ntoreturn) - else: - ntoreturn = self.limit - - if conn.is_mongos: - assert isinstance(spec, MutableMapping) - spec = _maybe_add_read_preference(spec, read_preference) - - return _query( - flags, - ns, - self.ntoskip, - ntoreturn, - spec, - None if use_cmd else self.fields, - self.codec_options, - ctx=conn.compression_context, - ) - - -class _GetMore: - """A getmore operation.""" - - __slots__ = ( - "db", - "coll", - "ntoreturn", - "cursor_id", - "max_await_time_ms", - "codec_options", - "read_preference", - "session", - "client", - "conn_mgr", - "_as_command", - "exhaust", - "comment", - ) - - name = "getMore" - - def __init__( - self, - db: str, - coll: str, - ntoreturn: int, - cursor_id: int, - codec_options: CodecOptions, - read_preference: _ServerMode, - session: Optional[ClientSession], - client: MongoClient, - max_await_time_ms: Optional[int], - conn_mgr: Any, - exhaust: bool, - comment: Any, - ): - self.db = db - self.coll = coll - self.ntoreturn = ntoreturn - self.cursor_id = cursor_id - self.codec_options = codec_options - self.read_preference = read_preference - self.session = session - self.client = client - self.max_await_time_ms = max_await_time_ms - self.conn_mgr = conn_mgr - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - self.comment = comment - - def reset(self) -> None: - self._as_command = None - - def namespace(self) -> str: - return f"{self.db}.{self.coll}" - - def use_command(self, conn: Connection) -> bool: - use_cmd = False - if not self.exhaust: - use_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_cmd = True - - conn.validate_session(self.client, self.session) - return use_cmd - - def as_command( - self, conn: Connection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a getMore command document for this query.""" - # See _Query.as_command for an explanation of this caching. 
- if self._as_command is not None: - return self._as_command - - cmd: dict[str, Any] = _gen_get_more_command( - self.cursor_id, - self.coll, - self.ntoreturn, - self.max_await_time_ms, - self.comment, - conn, - ) - if self.session: - self.session._apply_to(cmd, False, self.read_preference, conn) - conn.add_server_api(cmd) - conn.send_cluster_time(cmd, self.session, self.client) - # Support auto encryption - client = self.client - if client._encrypter and not client._encrypter._bypass_auto_encryption: - cmd = client._encrypter.encrypt(self.db, cmd, self.codec_options) - # Support CSOT - if apply_timeout: - conn.apply_timeout(client, cmd=None) - self._as_command = cmd, self.db - return self._as_command - - def get_message( - self, dummy0: Any, conn: Connection, use_cmd: bool = False - ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: - """Get a getmore message.""" - ns = self.namespace() - ctx = conn.compression_context - - if use_cmd: - spec = (self.as_command(conn, apply_timeout=True))[0] - if self.conn_mgr and self.exhaust: - flags = _OpMsg.EXHAUST_ALLOWED - else: - flags = 0 - request_id, msg, size, _ = _op_msg( - flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context - ) - return request_id, msg, size - - return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) - - -class _RawBatchQuery(_Query): - def use_command(self, conn: Connection) -> bool: - # Compatibility checks. - super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False - - -class _RawBatchGetMore(_GetMore): - def use_command(self, conn: Connection) -> bool: - # Compatibility checks. - super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False - - -class _CursorAddress(tuple): - """The server address (host, port) of a cursor, with namespace property.""" - - __namespace: Any - - def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: - self = tuple.__new__(cls, address) - self.__namespace = namespace - return self - - @property - def namespace(self) -> str: - """The namespace this cursor.""" - return self.__namespace - - def __hash__(self) -> int: - # Two _CursorAddress instances with different namespaces - # must not hash the same. - return ((*self, self.__namespace)).__hash__() - - def __eq__(self, other: object) -> bool: - if isinstance(other, _CursorAddress): - return tuple(self) == tuple(other) and self.namespace == other.namespace - return NotImplemented - - def __ne__(self, other: object) -> bool: - return not self == other - - -_pack_compression_header = struct.Struct(" tuple[int, bytes]: - """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" - compressed = ctx.compress(data) - request_id = _randint() - - header = _pack_compression_header( - _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length - request_id, # Request id - 0, # responseTo - 2012, # operation id - operation, # original operation id - len(data), # uncompressed message length - ctx.compressor_id, - ) # compressor id - return request_id, header + compressed - - -_pack_header = struct.Struct(" tuple[int, bytes]: - """Takes message data and adds a message header based on the operation. - - Returns the resultant message string. 
- """ - rid = _randint() - message = _pack_header(16 + len(data), rid, 0, operation) - return rid, message + data - - -_pack_int = struct.Struct(" tuple[bytes, int, int]: - """Get a OP_MSG message. - - Note: this method handles multiple documents in a type one payload but - it does not perform batch splitting and the total message size is - only checked *after* generating the entire message. - """ - # Encode the command document in payload 0 without checking keys. - encoded = _dict_to_bson(command, False, opts) - flags_type = _pack_op_msg_flags_type(flags, 0) - total_size = len(encoded) - max_doc_size = 0 - if identifier and docs is not None: - type_one = _pack_byte(1) - cstring = _make_c_string(identifier) - encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] - size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 - encoded_size = _pack_int(size) - total_size += size - max_doc_size = max(len(doc) for doc in encoded_docs) - data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] - else: - data = [flags_type, encoded] - return b"".join(data), total_size, max_doc_size - - -def _op_msg_compressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes, int, int]: - """Internal OP_MSG message helper.""" - msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - rid, msg = _compress(2013, msg, ctx) - return rid, msg, total_size, max_bson_size - - -def _op_msg_uncompressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, -) -> tuple[int, bytes, int, int]: - """Internal compressed OP_MSG message helper.""" - data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - request_id, op_message = __pack_message(2013, data) - return request_id, op_message, total_size, max_bson_size - - -if _use_c: - _op_msg_uncompressed = _cmessage._op_msg - - -def _op_msg( - flags: int, - command: MutableMapping[str, Any], - dbname: str, - read_preference: Optional[_ServerMode], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes, int, int]: - """Get a OP_MSG message.""" - command["$db"] = dbname - # getMore commands do not send $readPreference. - if read_preference is not None and "$readPreference" not in command: - # Only send $readPreference if it's not primary (the default). - if read_preference.mode: - command["$readPreference"] = read_preference.document - name = next(iter(command)) - try: - identifier = _FIELD_MAP[name] - docs = command.pop(identifier) - except KeyError: - identifier = "" - docs = None - try: - if ctx: - return _op_msg_compressed(flags, command, identifier, docs, opts, ctx) - return _op_msg_uncompressed(flags, command, identifier, docs, opts) - finally: - # Add the field back to the command. 
- if identifier: - command[identifier] = docs - - -def _query_impl( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, -) -> tuple[bytes, int]: - """Get an OP_QUERY message.""" - encoded = _dict_to_bson(query, False, opts) - if field_selector: - efs = _dict_to_bson(field_selector, False, opts) - else: - efs = b"" - max_bson_size = max(len(encoded), len(efs)) - return ( - b"".join( - [ - _pack_int(options), - _make_c_string(collection_name), - _pack_int(num_to_skip), - _pack_int(num_to_return), - encoded, - efs, - ] - ), - max_bson_size, - ) - - -def _query_compressed( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes, int]: - """Internal compressed query message helper.""" - op_query, max_bson_size = _query_impl( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - rid, msg = _compress(2004, op_query, ctx) - return rid, msg, max_bson_size - - -def _query_uncompressed( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, -) -> tuple[int, bytes, int]: - """Internal query message helper.""" - op_query, max_bson_size = _query_impl( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - rid, msg = __pack_message(2004, op_query) - return rid, msg, max_bson_size - - -if _use_c: - _query_uncompressed = _cmessage._query_message - - -def _query( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes, int]: - """Get a **query** message.""" - if ctx: - return _query_compressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx - ) - return _query_uncompressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - - -_pack_long_long = struct.Struct(" bytes: - """Get an OP_GET_MORE message.""" - return b"".join( - [ - _ZERO_32, - _make_c_string(collection_name), - _pack_int(num_to_return), - _pack_long_long(cursor_id), - ] - ) - - -def _get_more_compressed( - collection_name: str, - num_to_return: int, - cursor_id: int, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes]: - """Internal compressed getMore message helper.""" - return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx) - - -def _get_more_uncompressed( - collection_name: str, num_to_return: int, cursor_id: int -) -> tuple[int, bytes]: - """Internal getMore message helper.""" - return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id)) - - -if _use_c: - _get_more_uncompressed = _cmessage._get_more_message - - -def _get_more( - collection_name: str, - num_to_return: int, - cursor_id: int, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes]: - """Get a **getMore** message.""" - if ctx: - return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx) - return _get_more_uncompressed(collection_name, 
num_to_return, cursor_id) - class _BulkWriteContext: """A wrapper around AsyncConnection for use with write splitting functions.""" @@ -1273,31 +454,6 @@ def max_split_size(self) -> int: return _MAX_SPLIT_SIZE_ENC -def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn: - """Internal helper for raising DocumentTooLarge.""" - if operation == "insert": - raise DocumentTooLarge( - "BSON document too large (%d bytes)" - " - the connected server supports" - " BSON document sizes up to %d" - " bytes." % (doc_size, max_size) - ) - else: - # There's nothing intelligent we can say - # about size for update and delete - raise DocumentTooLarge(f"{operation!r} command document too large") - - -# OP_MSG ------------------------------------------------------------- - - -_OP_MSG_MAP = { - _INSERT: b"documents\x00", - _UPDATE: b"updates\x00", - _DELETE: b"deletes\x00", -} - - def _batched_op_msg_impl( operation: int, command: Mapping[str, Any], @@ -1555,206 +711,3 @@ def _batched_write_command_impl( buf.write(_pack_int(length - command_start)) return to_send, length - - -class _OpReply: - """A MongoDB OP_REPLY response message.""" - - __slots__ = ("flags", "cursor_id", "number_returned", "documents") - - UNPACK_FROM = struct.Struct(" list[bytes]: - """Check the response header from the database, without decoding BSON. - - Check the response for errors and unpack. - - Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or - OperationFailure. - - :param cursor_id: cursor_id we sent to get this response - - used for raising an informative exception when we get cursor id not - valid at server response. - """ - if self.flags & 1: - # Shouldn't get this response if we aren't doing a getMore - if cursor_id is None: - raise ProtocolError("No cursor id for getMore operation") - - # Fake a getMore command response. OP_GET_MORE provides no - # document. - msg = "Cursor not found, cursor id: %d" % (cursor_id,) - errobj = {"ok": 0, "errmsg": msg, "code": 43} - raise CursorNotFound(msg, 43, errobj) - elif self.flags & 2: - error_object: dict = bson.BSON(self.documents).decode() - # Fake the ok field if it doesn't exist. - error_object.setdefault("ok", 0) - if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): - raise NotPrimaryError(error_object["$err"], error_object) - elif error_object.get("code") == 50: - default_msg = "operation exceeded time limit" - raise ExecutionTimeout( - error_object.get("$err", default_msg), error_object.get("code"), error_object - ) - raise OperationFailure( - "database error: %s" % error_object.get("$err"), - error_object.get("code"), - error_object, - ) - if self.documents: - return [self.documents] - return [] - - def unpack_response( - self, - cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, - user_fields: Optional[Mapping[str, Any]] = None, - legacy_response: bool = False, - ) -> list[dict[str, Any]]: - """Unpack a response from the database and decode the BSON document(s). - - Check the response for errors and unpack, returning a dictionary - containing the response data. - - Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or - OperationFailure. 
- - :param cursor_id: cursor_id we sent to get this response - - used for raising an informative exception when we get cursor id not - valid at server response - :param codec_options: an instance of - :class:`~bson.codec_options.CodecOptions` - :param user_fields: Response fields that should be decoded - using the TypeDecoders from codec_options, passed to - bson._decode_all_selective. - """ - self.raw_response(cursor_id) - if legacy_response: - return bson.decode_all(self.documents, codec_options) - return bson._decode_all_selective(self.documents, codec_options, user_fields) - - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: - """Unpack a command response.""" - docs = self.unpack_response(codec_options=codec_options) - assert self.number_returned == 1 - return docs[0] - - def raw_command_response(self) -> NoReturn: - """Return the bytes of the command response.""" - # This should never be called on _OpReply. - raise NotImplementedError - - @property - def more_to_come(self) -> bool: - """Is the moreToCome bit set on this response?""" - return False - - @classmethod - def unpack(cls, msg: bytes) -> _OpReply: - """Construct an _OpReply from raw bytes.""" - # PYTHON-945: ignore starting_from field. - flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) - - documents = msg[20:] - return cls(flags, cursor_id, number_returned, documents) - - -class _OpMsg: - """A MongoDB OP_MSG response message.""" - - __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") - - UNPACK_FROM = struct.Struct(" list[Mapping[str, Any]]: - """ - cursor_id is ignored - user_fields is used to determine which fields must not be decoded - """ - inflated_response = _decode_selective( - RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS - ) - return [inflated_response] - - def unpack_response( - self, - cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, - user_fields: Optional[Mapping[str, Any]] = None, - legacy_response: bool = False, - ) -> list[dict[str, Any]]: - """Unpack a OP_MSG command response. - - :param cursor_id: Ignored, for compatibility with _OpReply. - :param codec_options: an instance of - :class:`~bson.codec_options.CodecOptions` - :param user_fields: Response fields that should be decoded - using the TypeDecoders from codec_options, passed to - bson._decode_all_selective. - """ - # If _OpMsg is in-use, this cannot be a legacy response. 
- assert not legacy_response - return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: - """Unpack a command response.""" - return self.unpack_response(codec_options=codec_options)[0] - - def raw_command_response(self) -> bytes: - """Return the bytes of the command response.""" - return self.payload_document - - @property - def more_to_come(self) -> bool: - """Is the moreToCome bit set on this response?""" - return bool(self.flags & self.MORE_TO_COME) - - @classmethod - def unpack(cls, msg: bytes) -> _OpMsg: - """Construct an _OpMsg from raw bytes.""" - flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) - if flags != 0: - if flags & cls.CHECKSUM_PRESENT: - raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") - - if flags ^ cls.MORE_TO_COME: - raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") - if first_payload_type != 0: - raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") - - if len(msg) != first_payload_size + 5: - raise ProtocolError("Unsupported OP_MSG reply: >1 section") - - payload_document = msg[5:] - return cls(flags, payload_document) - - -_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { - _OpReply.OP_CODE: _OpReply.unpack, - _OpMsg.OP_CODE: _OpMsg.unpack, -} diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index bd15409ecb..72b81485b4 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -74,12 +74,13 @@ ) from pymongo.lock import _HAS_REGISTER_AT_FORK, _create_lock, _release_locks from pymongo.logger import _CLIENT_LOGGER, _log_or_warn +from pymongo.message import _CursorAddress, _GetMore, _Query from pymongo.monitoring import ConnectionClosedReason from pymongo.operations import _Op from pymongo.read_preferences import ReadPreference, _ServerMode from pymongo.server_selectors import writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import client_session, database, message, periodic_executor +from pymongo.synchronous import client_session, database, periodic_executor from pymongo.synchronous.change_stream import ChangeStream, ClusterChangeStream from pymongo.synchronous.client_session import _EmptyServerSession from pymongo.synchronous.command_cursor import CommandCursor @@ -113,7 +114,6 @@ from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession, _ServerSession from pymongo.synchronous.cursor import _ConnectionManager - from pymongo.synchronous.message import _CursorAddress, _GetMore, _Query from pymongo.synchronous.pool import Connection from pymongo.synchronous.server import Server @@ -1703,7 +1703,7 @@ def _cmd( operation.read_preference, operation.session, address=address, - retryable=isinstance(operation, message._Query), + retryable=isinstance(operation, _Query), operation=operation.name, ) diff --git a/pymongo/synchronous/network.py b/pymongo/synchronous/network.py index cdfb60e202..c1978087a9 100644 --- a/pymongo/synchronous/network.py +++ b/pymongo/synchronous/network.py @@ -32,7 +32,7 @@ ) from bson import _decode_all_selective -from pymongo import _csot, helpers_shared +from pymongo import _csot, helpers_shared, message from pymongo.common import MAX_MESSAGE_SIZE from pymongo.compression_support import _NO_COMPRESSION, decompress from pymongo.errors import ( @@ -42,6 +42,7 @@ 
_OperationCancelled, ) from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply from pymongo.monitoring import _is_speculative_authenticate from pymongo.network_layer import ( _POLL_TIMEOUT, @@ -51,8 +52,6 @@ sendall, ) from pymongo.socket_checker import _errno_from_exception -from pymongo.synchronous import message -from pymongo.synchronous.message import _UNPACK_REPLY, _OpMsg, _OpReply if TYPE_CHECKING: from bson import CodecOptions diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 1637406ee5..197409e84a 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -91,12 +91,12 @@ ZlibContext, ZstdContext, ) + from pymongo.message import _OpMsg, _OpReply from pymongo.pyopenssl_context import _sslConn from pymongo.read_concern import ReadConcern from pymongo.read_preferences import _ServerMode from pymongo.synchronous.auth import _AuthContext from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.message import _OpMsg, _OpReply from pymongo.synchronous.mongo_client import MongoClient, _MongoClientErrorHandler from pymongo.typings import ClusterTime, _Address, _CollationIn from pymongo.write_concern import WriteConcern diff --git a/pymongo/synchronous/response.py b/pymongo/synchronous/response.py deleted file mode 100644 index 03b88fc77d..0000000000 --- a/pymongo/synchronous/response.py +++ /dev/null @@ -1,133 +0,0 @@ -# Copyright 2014-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Represent a response from the server.""" -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence, Union - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.synchronous.message import _OpMsg, _OpReply - from pymongo.synchronous.pool import Connection - from pymongo.typings import _Address, _DocumentOut - -_IS_SYNC = True - - -class Response: - __slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs") - - def __init__( - self, - data: Union[_OpMsg, _OpReply], - address: _Address, - request_id: int, - duration: Optional[timedelta], - from_command: bool, - docs: Sequence[Mapping[str, Any]], - ): - """Represent a response from the server. - - :param data: A network response message. - :param address: (host, port) of the source server. - :param request_id: The request id of this operation. - :param duration: The duration of the operation. - :param from_command: if the response is the result of a db command. 
- """ - self._data = data - self._address = address - self._request_id = request_id - self._duration = duration - self._from_command = from_command - self._docs = docs - - @property - def data(self) -> Union[_OpMsg, _OpReply]: - """Server response's raw BSON bytes.""" - return self._data - - @property - def address(self) -> _Address: - """(host, port) of the source server.""" - return self._address - - @property - def request_id(self) -> int: - """The request id of this operation.""" - return self._request_id - - @property - def duration(self) -> Optional[timedelta]: - """The duration of the operation.""" - return self._duration - - @property - def from_command(self) -> bool: - """If the response is a result from a db command.""" - return self._from_command - - @property - def docs(self) -> Sequence[Mapping[str, Any]]: - """The decoded document(s).""" - return self._docs - - -class PinnedResponse(Response): - __slots__ = ("_conn", "_more_to_come") - - def __init__( - self, - data: Union[_OpMsg, _OpReply], - address: _Address, - conn: Connection, - request_id: int, - duration: Optional[timedelta], - from_command: bool, - docs: list[_DocumentOut], - more_to_come: bool, - ): - """Represent a response to an exhaust cursor's initial query. - - :param data: A network response message. - :param address: (host, port) of the source server. - :param conn: The AsyncConnection used for the initial query. - :param request_id: The request id of this operation. - :param duration: The duration of the operation. - :param from_command: If the response is the result of a db command. - :param docs: List of documents. - :param more_to_come: Bool indicating whether cursor is ready to be - exhausted. - """ - super().__init__(data, address, request_id, duration, from_command, docs) - self._conn = conn - self._more_to_come = more_to_come - - @property - def conn(self) -> Connection: - """The AsyncConnection used for the initial query. - - The server will send batches on this socket, without waiting for - getMores from the client, until the result set is exhausted or there - is an error. - """ - return self._conn - - @property - def more_to_come(self) -> bool: - """If true, server is ready to send batches on the socket until the - result set is exhausted or there is an error. 
- """ - return self._more_to_come diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 883c802e07..4f9f1c1462 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -30,9 +30,9 @@ from pymongo.errors import NotPrimaryError, OperationFailure from pymongo.helpers_shared import _check_command_response from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log +from pymongo.message import _convert_exception, _GetMore, _OpMsg, _Query from pymongo.response import PinnedResponse, Response from pymongo.synchronous.helpers import _handle_reauth -from pymongo.synchronous.message import _convert_exception, _GetMore, _OpMsg, _Query if TYPE_CHECKING: from queue import Queue @@ -106,6 +106,32 @@ def request_check(self) -> None: """Check the server's state soon.""" self._monitor.request_check() + def operation_to_command( + self, operation: Union[_Query, _GetMore], conn: Connection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + is_query = isinstance(operation, _Query) + if is_query: + explain = "$explain" in operation.spec + cmd, db = operation.as_command() + else: + explain = False + cmd, db = operation.as_command(conn) + if operation.session: + operation.session._apply_to(cmd, False, operation.read_preference, conn) + # Explain does not support readConcern. + if is_query and not explain and not operation.session.in_transaction: + operation.session._update_read_concern(cmd, conn) + # Support auto encryption + if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption: + cmd = operation.client._encrypter.encrypt(operation.db, cmd, operation.codec_options) + + conn.add_server_api(cmd) + conn.send_cluster_time(cmd, operation.session, operation.client) + # Support CSOT + if apply_timeout: + conn.apply_timeout(operation.client, cmd=cmd if is_query else None) + return cmd, db + @_handle_reauth def run_operation( self, @@ -122,26 +148,26 @@ def run_operation( cursors. Can raise ConnectionFailure, OperationFailure, etc. - :param conn: A AsyncConnection instance. + :param conn: An AsyncConnection instance. :param operation: A _Query or _GetMore object. :param read_preference: The read preference to use. :param listeners: Instance of _EventListeners or None. :param unpack_res: A callable that decodes the wire protocol response. + :param client: An AsyncMongoClient instance. 
""" - duration = None assert listeners is not None publish = listeners.enabled_for_commands start = datetime.now() use_cmd = operation.use_command(conn) more_to_come = operation.conn_mgr and operation.conn_mgr.more_to_come + cmd, dbn = self.operation_to_command(operation, conn, use_cmd) if more_to_come: request_id = 0 else: message = operation.get_message(read_preference, conn, use_cmd) request_id, data, max_doc_size = self._split_message(message) - cmd, dbn = operation.as_command(conn) if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, @@ -160,7 +186,6 @@ def run_operation( ) if publish: - cmd, dbn = operation.as_command(conn) if "$db" not in cmd: cmd["$db"] = dbn assert listeners is not None diff --git a/test/test_change_stream.py b/test/test_change_stream.py index 4d8422667f..e00aaa6403 100644 --- a/test/test_change_stream.py +++ b/test/test_change_stream.py @@ -45,9 +45,9 @@ OperationFailure, ServerSelectionTimeoutError, ) +from pymongo.message import _CursorAddress from pymongo.read_concern import ReadConcern from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.message import _CursorAddress from pymongo.write_concern import WriteConcern diff --git a/test/test_client.py b/test/test_client.py index e21a899d02..b5c438a66f 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -83,7 +83,7 @@ ) from bson.son import SON from bson.tz_util import utc -from pymongo import event_loggers, monitoring +from pymongo import event_loggers, message, monitoring from pymongo.client_options import ClientOptions from pymongo.common import _UUID_REPRESENTATIONS, CONNECT_TIMEOUT from pymongo.compression_support import _have_snappy, _have_zstd @@ -106,7 +106,7 @@ from pymongo.server_description import ServerDescription from pymongo.server_selectors import readable_server_selector, writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import message +from pymongo.synchronous import message as message_old from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.cursor import Cursor, CursorType from pymongo.synchronous.database import Database @@ -1458,7 +1458,7 @@ def test_stale_getmore(self): with self.assertRaises(AutoReconnect): client = rs_client(connect=False, serverSelectionTimeoutMS=100) client._run_operation( - operation=message._GetMore( + operation=message_old._GetMore( "pymongo_test", "collection", 101, diff --git a/test/test_collection.py b/test/test_collection.py index 5495e659b4..7f0e16e0aa 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -57,6 +57,7 @@ OperationFailure, WriteConcernError, ) +from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference @@ -69,7 +70,6 @@ from pymongo.synchronous.bulk import BulkWriteError from pymongo.synchronous.collection import Collection, ReturnDocument from pymongo.synchronous.command_cursor import CommandCursor -from pymongo.synchronous.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern diff --git a/test/test_custom_types.py b/test/test_custom_types.py index d946eee173..7daf83244d 100644 --- a/test/test_custom_types.py +++ b/test/test_custom_types.py @@ -53,8 +53,8 @@ from bson.raw_bson import RawBSONDocument from gridfs import GridIn, GridOut from 
pymongo.errors import DuplicateKeyError +from pymongo.message import _CursorAddress from pymongo.synchronous.collection import ReturnDocument -from pymongo.synchronous.message import _CursorAddress class DecimalEncoder(TypeEncoder): diff --git a/test/test_grid_file.py b/test/test_grid_file.py index c45c5b5771..f663f13653 100644 --- a/test/test_grid_file.py +++ b/test/test_grid_file.py @@ -42,7 +42,7 @@ ) from pymongo import MongoClient from pymongo.errors import ConfigurationError, ServerSelectionTimeoutError -from pymongo.synchronous.message import _CursorAddress +from pymongo.message import _CursorAddress class TestGridFileNoConnect(unittest.TestCase): diff --git a/test/test_pooling.py b/test/test_pooling.py index da68f04e78..3cc544d2ea 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -24,10 +24,9 @@ from bson.codec_options import DEFAULT_CODEC_OPTIONS from bson.son import SON -from pymongo import MongoClient, timeout +from pymongo import MongoClient, message, timeout from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError from pymongo.hello_compat import HelloCompat -from pymongo.synchronous import message sys.path[0:0] = [""] diff --git a/test/test_read_preferences.py b/test/test_read_preferences.py index 76ad14dcce..0c0e24946f 100644 --- a/test/test_read_preferences.py +++ b/test/test_read_preferences.py @@ -39,6 +39,7 @@ from bson.son import SON from pymongo.errors import ConfigurationError, OperationFailure +from pymongo.message import _maybe_add_read_preference from pymongo.read_preferences import ( MovingAverage, Nearest, @@ -51,7 +52,6 @@ from pymongo.server_description import ServerDescription from pymongo.server_selectors import Selection, readable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous.message import _maybe_add_read_preference from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern From a5ca4292792991e700dea4071af85e9fe4c8064f Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 12 Jun 2024 09:19:41 -0700 Subject: [PATCH 03/11] Move message out of asynchronous --- pymongo/asynchronous/bulk.py | 347 ++++++++---- pymongo/asynchronous/client_session.py | 4 +- pymongo/asynchronous/helpers.py | 4 +- pymongo/asynchronous/message.py | 713 ------------------------- pymongo/asynchronous/mongo_client.py | 6 +- pymongo/asynchronous/server.py | 27 +- pymongo/bulk_shared.py | 131 +++++ pymongo/message.py | 486 ++++++++++++++++- pymongo/synchronous/bulk.py | 347 ++++++++---- pymongo/synchronous/helpers.py | 4 +- pymongo/synchronous/message.py | 713 ------------------------- pymongo/synchronous/mongo_client.py | 6 +- pymongo/synchronous/server.py | 29 +- pymongo/typings.py | 190 ++++++- test/__init__.py | 3 +- test/asynchronous/test_collection.py | 2 +- test/auth_aws/test_auth_aws.py | 2 +- test/auth_oidc/test_auth_oidc.py | 4 +- test/synchronous/test_collection.py | 2 +- test/test_client.py | 3 +- test/test_collection.py | 2 +- 21 files changed, 1293 insertions(+), 1732 deletions(-) delete mode 100644 pymongo/asynchronous/message.py create mode 100644 pymongo/bulk_shared.py delete mode 100644 pymongo/synchronous/message.py diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 285ac821f1..725596ab6d 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import 
islice from typing import ( @@ -26,7 +28,6 @@ Any, Iterator, Mapping, - NoReturn, Optional, Type, Union, @@ -36,9 +37,13 @@ from bson.raw_bson import RawBSONDocument from pymongo import _csot, common from pymongo.asynchronous.client_session import AsyncClientSession, _validate_session_write_concern -from pymongo.asynchronous.message import ( - _BulkWriteContext, - _EncryptedBulkWriteContext, +from pymongo.asynchronous.helpers import _handle_reauth +from pymongo.bulk_shared import ( + _COMMANDS, + _DELETE_ALL, + _merge_command, + _raise_bulk_write_error, + _Run, ) from pymongo.common import ( validate_is_document_type, @@ -46,16 +51,21 @@ validate_ok_for_update, ) from pymongo.errors import ( - BulkWriteError, ConfigurationError, InvalidOperation, + NotPrimaryError, OperationFailure, ) -from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES, _get_wce_doc +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, + _BulkWriteContext, + _convert_exception, + _convert_write_result, + _EncryptedBulkWriteContext, _randint, ) from pymongo.read_preferences import ReadPreference @@ -63,111 +73,12 @@ if TYPE_CHECKING: from pymongo.asynchronous.collection import AsyncCollection + from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.pool import AsyncConnection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline _IS_SYNC = False -_DELETE_ALL: int = 0 -_DELETE_ONE: int = 1 - -# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err -_BAD_VALUE: int = 2 -_UNKNOWN_ERROR: int = 8 -_WRITE_CONCERN_ERROR: int = 64 - -_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") - - -class _Run: - """Represents a batch of write operations.""" - - def __init__(self, op_type: int) -> None: - """Initialize a new Run object.""" - self.op_type: int = op_type - self.index_map: list[int] = [] - self.ops: list[Any] = [] - self.idx_offset: int = 0 - - def index(self, idx: int) -> int: - """Get the original index of an operation in this run. - - :param idx: The Run index that maps to the original index. - """ - return self.index_map[idx] - - def add(self, original_index: int, operation: Any) -> None: - """Add an operation to this Run instance. - - :param original_index: The original index of this operation - within a larger bulk operation. - :param operation: The operation document. - """ - self.index_map.append(original_index) - self.ops.append(operation) - - -def _merge_command( - run: _Run, - full_result: MutableMapping[str, Any], - offset: int, - result: Mapping[str, Any], -) -> None: - """Merge a write command result into the full bulk result.""" - affected = result.get("n", 0) - - if run.op_type == _INSERT: - full_result["nInserted"] += affected - - elif run.op_type == _DELETE: - full_result["nRemoved"] += affected - - elif run.op_type == _UPDATE: - upserted = result.get("upserted") - if upserted: - n_upserted = len(upserted) - for doc in upserted: - doc["index"] = run.index(doc["index"] + offset) - full_result["upserted"].extend(upserted) - full_result["nUpserted"] += n_upserted - full_result["nMatched"] += affected - n_upserted - else: - full_result["nMatched"] += affected - full_result["nModified"] += result["nModified"] - - write_errors = result.get("writeErrors") - if write_errors: - for doc in write_errors: - # Leave the server response intact for APM. 
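The _merge_command logic being lifted out of this module into pymongo/bulk_shared.py does two jobs: it accumulates the per-batch counters (nInserted, nMatched, nUpserted, and so on) and it rewrites batch-relative indexes back to the caller's original operation indexes. A minimal, self-contained sketch of the update-merging case; the function name and the simplified index_map argument are illustrative, not the driver's API:

# A stand-alone model of how one update batch's server reply folds into
# the full bulk result. index_map plays the role of _Run.index().
full_result = {
    "nInserted": 0, "nMatched": 0, "nModified": 0,
    "nUpserted": 0, "nRemoved": 0,
    "upserted": [], "writeErrors": [], "writeConcernErrors": [],
}

def merge_update_result(full, reply, offset, index_map):
    affected = reply.get("n", 0)
    upserted = reply.get("upserted")
    if upserted:
        for doc in upserted:
            # Map the batch-relative index back to the original request.
            doc["index"] = index_map[doc["index"] + offset]
        full["upserted"].extend(upserted)
        full["nUpserted"] += len(upserted)
        full["nMatched"] += affected - len(upserted)
    else:
        full["nMatched"] += affected
    full["nModified"] += reply["nModified"]

# Two updates in the batch; the second one upserted a new document.
merge_update_result(
    full_result,
    {"n": 2, "nModified": 1, "upserted": [{"index": 1, "_id": 42}]},
    offset=0,
    index_map=[0, 1],
)
print(full_result["nMatched"], full_result["nUpserted"])  # 1 1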
- replacement = doc.copy() - idx = doc["index"] + offset - replacement["index"] = run.index(idx) - # Add the failed operation to the error document. - replacement["op"] = run.ops[idx] - full_result["writeErrors"].append(replacement) - - wce = _get_wce_doc(result) - if wce: - full_result["writeConcernErrors"].append(wce) - - -def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: - """Raise a BulkWriteError from the full bulk api result.""" - # retryWrites on MMAPv1 should raise an actionable error. - if full_result["writeErrors"]: - full_result["writeErrors"].sort(key=lambda error: error["index"]) - err = full_result["writeErrors"][0] - code = err["code"] - msg = err["errmsg"] - if code == 20 and msg.startswith("Transaction numbers"): - errmsg = ( - "This MongoDB deployment does not support " - "retryable writes. Please add retryWrites=false " - "to your connection string." - ) - raise OperationFailure(errmsg, code, full_result) - raise BulkWriteError(full_result) - class _AsyncBulk: """The private guts of the bulk write API.""" @@ -204,13 +115,16 @@ def __init__( # Extra state so that we know where to pick up on a retry attempt. self.current_run = None self.next_run = None + self.is_encrypted = False @property def bulk_ctx_class(self) -> Type[_BulkWriteContext]: encrypter = self.collection.database.client._encrypter if encrypter and not encrypter._bypass_auto_encryption: + self.is_encrypted = True return _EncryptedBulkWriteContext else: + self.is_encrypted = False return _BulkWriteContext def add_insert(self, document: _DocumentOut) -> None: @@ -315,6 +229,180 @@ def gen_unordered(self) -> Iterator[_Run]: if run.ops: yield run + @_handle_reauth + async def write_command( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + docs: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> dict[str, Any]: + """A proxy for SocketInfo.write_command that handles event publishing.""" + cmd[bwc.field] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, docs) + try: + reply = await bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if 
_COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return reply # type: ignore[return-value] + + async def unack_write( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + max_doc_size: int, + docs: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> Optional[Mapping[str, Any]]: + """A proxy for AsyncConnection.unack_write that handles event publishing.""" + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, docs) + try: + result = await bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. 
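All of the _debug_log and publish calls threaded through write_command and unack_write ultimately surface through PyMongo's documented command monitoring API. For orientation, a minimal listener using that public interface; the connection string is a placeholder:

from pymongo import MongoClient, monitoring

class CommandLogger(monitoring.CommandListener):
    def started(self, event):
        print(f"{event.command_name} started, request {event.request_id}")

    def succeeded(self, event):
        # Mirrors the duration bookkeeping done around write_command.
        print(f"{event.command_name} took {event.duration_micros} us")

    def failed(self, event):
        print(f"{event.command_name} failed: {event.failure}")

client = MongoClient("mongodb://localhost:27017",
                     event_listeners=[CommandLogger()])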
+ reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return result # type: ignore[return-value] + async def _execute_command( self, generator: Iterator[Any], @@ -387,7 +475,21 @@ async def _execute_command( # Run as many ops as possible in one command. if write_concern.acknowledged: - result, to_send = await bwc.execute(cmd, ops, client) + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + result = await bwc.conn.command( + bwc.db_name, + batched_cmd, + codec_options=bwc.codec, + session=bwc.session, + client=client, + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + result = await self.write_command( + bwc, cmd, request_id, msg, to_send, client + ) + await client._process_response(result, bwc.session) # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) @@ -407,7 +509,23 @@ async def _execute_command( if self.ordered and "writeErrors" in result: break else: - to_send = await bwc.execute_unack(cmd, ops, client) + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + await bwc.conn.command( + bwc.db_name, + batched_cmd, + write_concern=WriteConcern(w=0), + session=bwc.session, + client=client, + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + # Though this isn't strictly a "legacy" write, the helper + # handles publishing commands and sending our message + # without receiving a result. Send 0 for max_doc_size + # to disable size checking. Size checking is handled while + # the documents are encoded to BSON. + await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) run.idx_offset += len(to_send) @@ -501,7 +619,18 @@ async def execute_op_msg_no_results( conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. 
- to_send = await bwc.execute_unack(cmd, ops, client) + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + await bwc.conn.command( + bwc.db_name, + batched_cmd, + write_concern=WriteConcern(w=0), + session=bwc.session, + client=client, + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index d5931013f7..9773742fca 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -554,14 +554,14 @@ def options(self) -> SessionOptions: return self._options @property - async def session_id(self) -> Mapping[str, Any]: + def session_id(self) -> Mapping[str, Any]: """A BSON document, the opaque server session identifier.""" self._check_ended() self._materialize(self._client.topology_description.logical_session_timeout_minutes) return self._server_session.session_id @property - async def _transaction_id(self) -> Int64: + def _transaction_id(self) -> Int64: """The current transaction id for the underlying server session.""" self._materialize(self._client.topology_description.logical_session_timeout_minutes) return self._server_session.transaction_id diff --git a/pymongo/asynchronous/helpers.py b/pymongo/asynchronous/helpers.py index c939bfabe1..7531783214 100644 --- a/pymongo/asynchronous/helpers.py +++ b/pymongo/asynchronous/helpers.py @@ -38,8 +38,8 @@ def _handle_reauth(func: F) -> F: async def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) - from pymongo.asynchronous.message import _BulkWriteContext from pymongo.asynchronous.pool import AsyncConnection + from pymongo.message import _BulkWriteContext try: return await func(*args, **kwargs) @@ -56,7 +56,7 @@ async def inner(*args: Any, **kwargs: Any) -> Any: conn = arg break if isinstance(arg, _BulkWriteContext): - conn = arg.conn + conn = arg.conn # type: ignore[assignment] break if conn: await conn.authenticate(reauthenticate=True) diff --git a/pymongo/asynchronous/message.py b/pymongo/asynchronous/message.py deleted file mode 100644 index 9a41b3475e..0000000000 --- a/pymongo/asynchronous/message.py +++ /dev/null @@ -1,713 +0,0 @@ -# Copyright 2009-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for creating `messages -`_ to be sent to -MongoDB. - -.. note:: This module is for internal use and is generally not needed by - application developers. 
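The client_session.py hunk above drops async from the session_id and _transaction_id getters for a good reason: an async def under @property hands back a coroutine object instead of the value, and callers of a property never expect to await it. A self-contained illustration of the pitfall:

import asyncio

class Session:
    @property
    async def broken_id(self):
        return 42

    @property
    def working_id(self):
        return 42

s = Session()
print(s.working_id)       # 42
coro = s.broken_id        # a coroutine object, not 42
print(asyncio.run(coro))  # 42, but only after an explicit await/run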
-""" -from __future__ import annotations - -import datetime -import logging -from io import BytesIO as _BytesIO -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, -) - -from bson import CodecOptions, _dict_to_bson, encode -from bson.raw_bson import ( - DEFAULT_RAW_BSON_OPTIONS, - _inflate_bson, -) -from pymongo.asynchronous.helpers import _handle_reauth -from pymongo.errors import ( - InvalidOperation, - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import ( - _COMMAND_LOGGER, - _CommandStatusMessage, - _debug_log, -) -from pymongo.message import ( - _BSONOBJ, - _COMMAND_OVERHEAD, - _FIELD_MAP, - _OP_MAP, - _OP_MSG_MAP, - _SKIPLIM, - _ZERO_8, - _ZERO_16, - _ZERO_32, - _ZERO_64, - _compress, - _convert_exception, - _convert_write_result, - _pack_int, - _raise_document_too_large, - _randint, -) -from pymongo.write_concern import WriteConcern - -try: - from pymongo import _cmessage # type: ignore[attr-defined] - - _use_c = True -except ImportError: - _use_c = False - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.asynchronous.client_session import AsyncClientSession - from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import AsyncConnection - from pymongo.monitoring import _EventListeners - from pymongo.typings import _DocumentOut - - -_IS_SYNC = False - - -class _BulkWriteContext: - """A wrapper around AsyncConnection for use with write splitting functions.""" - - __slots__ = ( - "db_name", - "conn", - "op_id", - "name", - "field", - "publish", - "start_time", - "listeners", - "session", - "compress", - "op_type", - "codec", - ) - - def __init__( - self, - database_name: str, - cmd_name: str, - conn: AsyncConnection, - operation_id: int, - listeners: _EventListeners, - session: AsyncClientSession, - op_type: int, - codec: CodecOptions, - ): - self.db_name = database_name - self.conn = conn - self.op_id = operation_id - self.listeners = listeners - self.publish = listeners.enabled_for_commands - self.name = cmd_name - self.field = _FIELD_MAP[self.name] - self.start_time = datetime.datetime.now() - self.session = session - self.compress = bool(conn.compression_context) - self.op_type = op_type - self.codec = codec - - def __batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[int, bytes, list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - request_id, msg, to_send = _do_batched_op_msg( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - return request_id, msg, to_send - - async def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - result = await self.write_command(cmd, request_id, msg, to_send, client) - await client._process_response(result, self.session) - return result, to_send - - async def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> list[Mapping[str, Any]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - # Though this isn't strictly a "legacy" write, the helper - # handles publishing commands and sending our message - # without receiving a result. Send 0 for max_doc_size - # to disable size checking. Size checking is handled while - # the documents are encoded to BSON. 
- await self.unack_write(cmd, request_id, msg, 0, to_send, client) - return to_send - - @property - def max_bson_size(self) -> int: - """A proxy for SockInfo.max_bson_size.""" - return self.conn.max_bson_size - - @property - def max_message_size(self) -> int: - """A proxy for SockInfo.max_message_size.""" - if self.compress: - # Subtract 16 bytes for the message header. - return self.conn.max_message_size - 16 - return self.conn.max_message_size - - @property - def max_write_batch_size(self) -> int: - """A proxy for SockInfo.max_write_batch_size.""" - return self.conn.max_write_batch_size - - @property - def max_split_size(self) -> int: - """The maximum size of a BSON command before batch splitting.""" - return self.max_bson_size - - async def unack_write( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - max_doc_size: int, - docs: list[Mapping[str, Any]], - client: AsyncMongoClient, - ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - cmd = self._start(cmd, request_id, docs) - try: - result = await self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value] - duration = datetime.datetime.now() - self.start_time - if result is not None: - reply = _convert_write_result(self.name, cmd, result) - else: - # Comply with APM spec. 
- reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if self.publish: - assert self.start_time is not None - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return result - - @_handle_reauth - async def write_command( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - docs: list[Mapping[str, Any]], - client: AsyncMongoClient, - ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" - cmd[self.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._start(cmd, request_id, docs) - try: - reply = await self.conn.write_command(request_id, msg, self.codec) - duration = datetime.datetime.now() - self.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if 
_COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if self.publish: - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return reply - - def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd[self.field] = docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - - def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None: - """Publish a CommandSucceededEvent.""" - self.listeners.publish_command_success( - duration, - reply, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None: - """Publish a CommandFailedEvent.""" - self.listeners.publish_command_failure( - duration, - failure, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - -# From the Client Side Encryption spec: -# Because automatic encryption increases the size of commands, the driver -# MUST split bulk writes at a reduced size limit before undergoing automatic -# encryption. The write payload MUST be split at 2MiB (2097152). -_MAX_SPLIT_SIZE_ENC = 2097152 - - -class _EncryptedBulkWriteContext(_BulkWriteContext): - __slots__ = () - - def __batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - msg, to_send = _encode_batched_write_command( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - - # Chop off the OP_QUERY header to get a properly batched write command. 
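The header-chopping arithmetic above (msg.index(b"\x00", 4) + 9) relies on the OP_QUERY payload layout: 4 flag bytes, a NUL-terminated namespace, then a 4-byte skip and a 4-byte limit before the BSON command document begins. A small round-trip that reproduces the offset math; the namespace and command are made up:

import struct
import bson

ns = b"db.$cmd"
payload = (
    b"\x00\x00\x00\x00"          # flags
    + ns + b"\x00"               # namespace, NUL-terminated
    + struct.pack("<ii", 0, -1)  # skip=0, limit=-1
    + bson.encode({"insert": "coll"})
)
# Find the namespace terminator (skipping the 4 flag bytes), then step
# over the NUL (1 byte) plus skip/limit (8 bytes): 9 bytes total.
cmd_start = payload.index(b"\x00", 4) + 9
print(bson.decode(payload[cmd_start:]))  # {'insert': 'coll'}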
- cmd_start = msg.index(b"\x00", 4) + 9 - outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) - return outgoing, to_send - - async def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - result: Mapping[str, Any] = await self.conn.command( - self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client - ) - return result, to_send - - async def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: AsyncMongoClient - ) -> list[Mapping[str, Any]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - await self.conn.command( - self.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=self.session, - client=client, - ) - return to_send - - @property - def max_split_size(self) -> int: - """Reduce the batch splitting size.""" - return _MAX_SPLIT_SIZE_ENC - - -def _batched_op_msg_impl( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_MSG write.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - max_message_size = ctx.max_message_size - - flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" - # Flags - buf.write(flags) - - # Type 0 Section - buf.write(b"\x00") - buf.write(_dict_to_bson(command, False, opts)) - - # Type 1 Section - buf.write(b"\x01") - size_location = buf.tell() - # Save space for size - buf.write(b"\x00\x00\x00\x00") - try: - buf.write(_OP_MSG_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None - - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - value = _dict_to_bson(doc, False, opts) - doc_length = len(value) - new_message_size = buf.tell() + doc_length - # Does first document exceed max_message_size? - doc_too_large = idx == 0 and (new_message_size > max_message_size) - # When OP_MSG is used unacknowledged we have to check - # document size client side or applications won't be notified. - # Otherwise we let the server deal with documents that are too large - # since ordered=False causes those documents to be skipped instead of - # halting the bulk write operation. - unacked_doc_too_large = not ack and (doc_length > max_bson_size) - if doc_too_large or unacked_doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - # We have enough data, return this batch. - if new_message_size > max_message_size: - break - buf.write(value) - to_send.append(doc) - idx += 1 - # We have enough documents, return this batch. - if idx == max_write_batch_size: - break - - # Write type 1 section size - length = buf.tell() - buf.seek(size_location) - buf.write(_pack_int(length - size_location)) - - return to_send, length - - -def _encode_batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete operation - as OP_MSG. 
- """ - buf = _BytesIO() - - to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - return buf.getvalue(), to_send - - -if _use_c: - _encode_batched_op_msg = _cmessage._encode_batched_op_msg - - -def _batched_op_msg_compressed( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - with OP_MSG, compressed. - """ - data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) - - assert ctx.conn.compression_context is not None - request_id, msg = _compress(2013, data, ctx.conn.compression_context) - return request_id, msg, to_send - - -def _batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """OP_MSG implementation entry point.""" - buf = _BytesIO() - - # Save space for message length and request id - buf.write(_ZERO_64) - # responseTo, opCode - buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") - - to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - - # Header - request id and message length - buf.seek(4) - request_id = _randint() - buf.write(_pack_int(request_id)) - buf.seek(0) - buf.write(_pack_int(length)) - - return request_id, buf.getvalue(), to_send - - -if _use_c: - _batched_op_msg = _cmessage._batched_op_msg - - -def _do_batched_op_msg( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - using OP_MSG. - """ - command["$db"] = namespace.split(".", 1)[0] - if "writeConcern" in command: - ack = bool(command["writeConcern"].get("w", 1)) - else: - ack = True - if ctx.conn.compression_context: - return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) - return _batched_op_msg(operation, command, docs, ack, opts, ctx) - - -# End OP_MSG ----------------------------------------------------- - - -def _encode_batched_write_command( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete command.""" - buf = _BytesIO() - - to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) - return buf.getvalue(), to_send - - -if _use_c: - _encode_batched_write_command = _cmessage._encode_batched_write_command - - -def _batched_write_command_impl( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_QUERY write command.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - # Max BSON object size + 16k - 2 bytes for ending NUL bytes. - # Server guarantees there is enough room: SERVER-10643. 
- max_cmd_size = max_bson_size + _COMMAND_OVERHEAD - max_split_size = ctx.max_split_size - - # No options - buf.write(_ZERO_32) - # Namespace as C string - buf.write(namespace.encode("utf8")) - buf.write(_ZERO_8) - # Skip: 0, Limit: -1 - buf.write(_SKIPLIM) - - # Where to write command document length - command_start = buf.tell() - buf.write(encode(command)) - - # Start of payload - buf.seek(-1, 2) - # Work around some Jython weirdness. - buf.truncate() - try: - buf.write(_OP_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None - - # Where to write list document length - list_start = buf.tell() - 4 - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - key = str(idx).encode("utf8") - value = _dict_to_bson(doc, False, opts) - # Is there enough room to add this document? max_cmd_size accounts for - # the two trailing null bytes. - doc_too_large = len(value) > max_cmd_size - if doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size - enough_documents = idx >= max_write_batch_size - if enough_data or enough_documents: - break - buf.write(_BSONOBJ) - buf.write(key) - buf.write(_ZERO_8) - buf.write(value) - to_send.append(doc) - idx += 1 - - # Finalize the current OP_QUERY message. - # Close list and command documents - buf.write(_ZERO_16) - - # Write document lengths and request id - length = buf.tell() - buf.seek(list_start) - buf.write(_pack_int(length - list_start - 1)) - buf.seek(command_start) - buf.write(_pack_int(length - command_start)) - - return to_send, length diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index da55a1160c..94edcccf82 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1670,13 +1670,13 @@ async def _run_operation( if operation.conn_mgr: server = await self._select_server( operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] operation.name, address=address, ) async with operation.conn_mgr._alock: - async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: + async with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return await server.run_operation( operation.conn_mgr.conn, @@ -1706,7 +1706,7 @@ async def _cmd( return await self._retryable_read( _cmd, operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] address=address, retryable=isinstance(operation, _Query), operation=operation.name, diff --git a/pymongo/asynchronous/server.py b/pymongo/asynchronous/server.py index 34735edd8d..892594c97d 100644 --- a/pymongo/asynchronous/server.py +++ b/pymongo/asynchronous/server.py @@ -109,29 +109,14 @@ def request_check(self) -> None: async def operation_to_command( self, operation: Union[_Query, _GetMore], conn: AsyncConnection, apply_timeout: bool = False ) -> tuple[dict[str, Any], str]: - is_query = isinstance(operation, _Query) - if is_query: - explain = "$explain" in operation.spec - cmd, db = operation.as_command() - else: - explain = False - cmd, db = operation.as_command(conn) - if operation.session: - operation.session._apply_to(cmd, False, operation.read_preference, conn) - # Explain does not support readConcern. 
- if is_query and not explain and not operation.session.in_transaction: - operation.session._update_read_concern(cmd, conn) + cmd, db = operation.as_command(conn, apply_timeout) # Support auto encryption if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption: - cmd = await operation.client._encrypter.encrypt( + cmd = await operation.client._encrypter.encrypt( # type: ignore[misc, assignment] operation.db, cmd, operation.codec_options ) + operation.update_command(cmd) - conn.add_server_api(cmd) - conn.send_cluster_time(cmd, operation.session, operation.client) - # Support CSOT - if apply_timeout: - conn.apply_timeout(operation.client, cmd=cmd if is_query else None) return cmd, db @_handle_reauth @@ -223,7 +208,7 @@ async def run_operation( ) if use_cmd: first = docs[0] - await operation.client._process_response(first, operation.session) + await operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] _check_command_response(first, conn.max_wire_version) except Exception as exc: duration = datetime.now() - start @@ -306,7 +291,7 @@ async def run_operation( ) # Decrypt response. - client = operation.client + client = operation.client # type: ignore[assignment] if client and client._encrypter: if use_cmd: decrypted = client._encrypter.decrypt(reply.raw_command_response()) @@ -314,7 +299,7 @@ async def run_operation( response: Response - if client._should_pin_cursor(operation.session) or operation.exhaust: + if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the diff --git a/pymongo/bulk_shared.py b/pymongo/bulk_shared.py new file mode 100644 index 0000000000..7aa6340d55 --- /dev/null +++ b/pymongo/bulk_shared.py @@ -0,0 +1,131 @@ +# Copyright 2024-present MongoDB, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you +# may not use this file except in compliance with the License. You +# may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or +# implied. See the License for the specific language governing +# permissions and limitations under the License. + + +"""Constants, types, and classes shared across Bulk Write API implementations.""" +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Mapping, MutableMapping, NoReturn + +from pymongo.errors import BulkWriteError, OperationFailure +from pymongo.helpers_shared import _get_wce_doc +from pymongo.message import ( + _DELETE, + _INSERT, + _UPDATE, +) + +if TYPE_CHECKING: + from pymongo.typings import _DocumentOut + + +_DELETE_ALL: int = 0 +_DELETE_ONE: int = 1 + +# For backwards compatibility. See MongoDB src/mongo/base/error_codes.err +_BAD_VALUE: int = 2 +_UNKNOWN_ERROR: int = 8 +_WRITE_CONCERN_ERROR: int = 64 + +_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") + + +class _Run: + """Represents a batch of write operations.""" + + def __init__(self, op_type: int) -> None: + """Initialize a new Run object.""" + self.op_type: int = op_type + self.index_map: list[int] = [] + self.ops: list[Any] = [] + self.idx_offset: int = 0 + + def index(self, idx: int) -> int: + """Get the original index of an operation in this run. 
+ + :param idx: The Run index that maps to the original index. + """ + return self.index_map[idx] + + def add(self, original_index: int, operation: Any) -> None: + """Add an operation to this Run instance. + + :param original_index: The original index of this operation + within a larger bulk operation. + :param operation: The operation document. + """ + self.index_map.append(original_index) + self.ops.append(operation) + + +def _merge_command( + run: _Run, + full_result: MutableMapping[str, Any], + offset: int, + result: Mapping[str, Any], +) -> None: + """Merge a write command result into the full bulk result.""" + affected = result.get("n", 0) + + if run.op_type == _INSERT: + full_result["nInserted"] += affected + + elif run.op_type == _DELETE: + full_result["nRemoved"] += affected + + elif run.op_type == _UPDATE: + upserted = result.get("upserted") + if upserted: + n_upserted = len(upserted) + for doc in upserted: + doc["index"] = run.index(doc["index"] + offset) + full_result["upserted"].extend(upserted) + full_result["nUpserted"] += n_upserted + full_result["nMatched"] += affected - n_upserted + else: + full_result["nMatched"] += affected + full_result["nModified"] += result["nModified"] + + write_errors = result.get("writeErrors") + if write_errors: + for doc in write_errors: + # Leave the server response intact for APM. + replacement = doc.copy() + idx = doc["index"] + offset + replacement["index"] = run.index(idx) + # Add the failed operation to the error document. + replacement["op"] = run.ops[idx] + full_result["writeErrors"].append(replacement) + + wce = _get_wce_doc(result) + if wce: + full_result["writeConcernErrors"].append(wce) + + +def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: + """Raise a BulkWriteError from the full bulk api result.""" + # retryWrites on MMAPv1 should raise an actionable error. + if full_result["writeErrors"]: + full_result["writeErrors"].sort(key=lambda error: error["index"]) + err = full_result["writeErrors"][0] + code = err["code"] + msg = err["errmsg"] + if code == 20 and msg.startswith("Transaction numbers"): + errmsg = ( + "This MongoDB deployment does not support " + "retryable writes. Please add retryWrites=false " + "to your connection string." 
+ ) + raise OperationFailure(errmsg, code, full_result) + raise BulkWriteError(full_result) diff --git a/pymongo/message.py b/pymongo/message.py index 623003cd09..cf9977ac6f 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -21,8 +21,10 @@ """ from __future__ import annotations +import datetime import random import struct +from io import BytesIO as _BytesIO from typing import ( TYPE_CHECKING, Any, @@ -40,9 +42,12 @@ from bson.int64 import Int64 from bson.raw_bson import ( _RAW_ARRAY_BSON_OPTIONS, + DEFAULT_RAW_BSON_OPTIONS, RawBSONDocument, + _inflate_bson, ) from pymongo.hello_compat import HelloCompat +from pymongo.monitoring import _EventListeners try: from pymongo import _cmessage # type: ignore[attr-defined] @@ -55,6 +60,7 @@ CursorNotFound, DocumentTooLarge, ExecutionTimeout, + InvalidOperation, NotPrimaryError, OperationFailure, ProtocolError, @@ -62,12 +68,16 @@ from pymongo.read_preferences import ReadPreference, _ServerMode if TYPE_CHECKING: - from pymongo.asynchronous.client_session import AsyncClientSession - from pymongo.asynchronous.mongo_client import AsyncMongoClient - from pymongo.asynchronous.pool import AsyncConnection from pymongo.compression_support import SnappyContext, ZlibContext, ZstdContext from pymongo.read_concern import ReadConcern - from pymongo.typings import _Address, _AgnosticClientSession, _AgnosticConnection + from pymongo.typings import ( + _Address, + _AgnosticClientSession, + _AgnosticConnection, + _AgnosticMongoClient, + _DocumentOut, + ) + MAX_INT32 = 2147483647 MIN_INT32 = -2147483648 @@ -502,8 +512,8 @@ def __init__( batch_size: int, read_concern: ReadConcern, collation: Optional[Mapping[str, Any]], - session: Optional[AsyncClientSession], - client: AsyncMongoClient, + session: Optional[_AgnosticClientSession], + client: _AgnosticMongoClient, allow_disk_use: Optional[bool], exhaust: bool, ): @@ -532,7 +542,7 @@ def reset(self) -> None: def namespace(self) -> str: return f"{self.db}.{self.coll}" - def use_command(self, conn: AsyncConnection) -> bool: + def use_command(self, conn: _AgnosticConnection) -> bool: use_find_cmd = False if not self.exhaust: use_find_cmd = True @@ -545,10 +555,15 @@ def use_command(self, conn: AsyncConnection) -> bool: "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) ) - conn.validate_session(self.client, self.session) + conn.validate_session(self.client, self.session) # type: ignore[arg-type] return use_find_cmd - def as_command(self, dummy0: Optional[Any] = None) -> tuple[dict[str, Any], str]: + def update_command(self, cmd: dict[str, Any]) -> None: + self._as_command = cmd, self.db + + def as_command( + self, conn: _AgnosticConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: """Return a find command document for this query.""" # We use the command twice: on the wire and for command monitoring. # Generate it once, for speed and to avoid repeating side-effects. @@ -572,11 +587,21 @@ def as_command(self, dummy0: Optional[Any] = None) -> tuple[dict[str, Any], str] if explain: self.name = "explain" cmd = {"explain": cmd} + conn.add_server_api(cmd) + if self.session: + self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] + # Explain does not support readConcern. 
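The session plumbing now done inside as_command (_apply_to, _update_read_concern, send_cluster_time) is what backs causally consistent sessions at the public API level. A short example using the documented interface; the host is a placeholder:

from pymongo import MongoClient

client = MongoClient("mongodb://localhost:27017")
with client.start_session(causal_consistency=True) as session:
    coll = client.test.coll
    coll.insert_one({"x": 1}, session=session)
    # The session's operationTime/clusterTime ride along in the command,
    # so this read is guaranteed to observe the write above.
    print(coll.find_one({"x": 1}, session=session))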
+ if not explain and not self.session.in_transaction: + self.session._update_read_concern(cmd, conn) # type: ignore[arg-type] + conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] + # Support CSOT + if apply_timeout: + conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type] self._as_command = cmd, self.db return self._as_command def get_message( - self, read_preference: _ServerMode, conn: AsyncConnection, use_cmd: bool = False + self, read_preference: _ServerMode, conn: _AgnosticConnection, use_cmd: bool = False ) -> tuple[int, bytes, int]: """Get a query message, possibly setting the secondaryOk bit.""" # Use the read_preference decided by _socket_from_server. @@ -591,7 +616,7 @@ def get_message( spec = self.spec if use_cmd: - spec = self.as_command(None)[0] + spec = self.as_command(conn)[0] request_id, msg, size, _ = _op_msg( 0, spec, @@ -657,8 +682,8 @@ def __init__( cursor_id: int, codec_options: CodecOptions, read_preference: _ServerMode, - session: Optional[AsyncClientSession], - client: AsyncMongoClient, + session: Optional[_AgnosticClientSession], + client: _AgnosticMongoClient, max_await_time_ms: Optional[int], conn_mgr: Any, exhaust: bool, @@ -684,7 +709,7 @@ def reset(self) -> None: def namespace(self) -> str: return f"{self.db}.{self.coll}" - def use_command(self, conn: AsyncConnection) -> bool: + def use_command(self, conn: _AgnosticConnection) -> bool: use_cmd = False if not self.exhaust: use_cmd = True @@ -692,10 +717,15 @@ def use_command(self, conn: AsyncConnection) -> bool: # OP_MSG supports exhaust on MongoDB 4.2+ use_cmd = True - conn.validate_session(self.client, self.session) + conn.validate_session(self.client, self.session) # type: ignore[arg-type] return use_cmd - def as_command(self, conn: AsyncConnection) -> tuple[dict[str, Any], str]: + def update_command(self, cmd: dict[str, Any]) -> None: + self._as_command = cmd, self.db + + def as_command( + self, conn: _AgnosticConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: """Return a getMore command document for this query.""" # See _Query.as_command for an explanation of this caching. if self._as_command is not None: @@ -709,11 +739,18 @@ def as_command(self, conn: AsyncConnection) -> tuple[dict[str, Any], str]: self.comment, conn, ) + if self.session: + self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] + conn.add_server_api(cmd) + conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] + # Support CSOT + if apply_timeout: + conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type] self._as_command = cmd, self.db return self._as_command def get_message( - self, dummy0: Any, conn: AsyncConnection, use_cmd: bool = False + self, dummy0: Any, conn: _AgnosticConnection, use_cmd: bool = False ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: """Get a getmore message.""" ns = self.namespace() @@ -734,7 +771,7 @@ def get_message( class _RawBatchQuery(_Query): - def use_command(self, conn: AsyncConnection) -> bool: + def use_command(self, conn: _AgnosticConnection) -> bool: # Compatibility checks. super().use_command(conn) if conn.max_wire_version >= 8: @@ -746,7 +783,7 @@ def use_command(self, conn: AsyncConnection) -> bool: class _RawBatchGetMore(_GetMore): - def use_command(self, conn: AsyncConnection) -> bool: + def use_command(self, conn: _AgnosticConnection) -> bool: # Compatibility checks. 
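as_command's "generate it once" caching, together with the new update_command hook that overwrites the cache after auto-encryption, follows a simple memoization pattern: the same command document must be what goes on the wire and what command monitoring reports. A minimal sketch of the pattern with stand-in names:

class Op:
    def __init__(self, spec):
        self.spec = spec
        self._as_command = None

    def as_command(self):
        if self._as_command is not None:
            return self._as_command
        cmd = dict(self.spec)  # the expensive build happens only once
        self._as_command = cmd, "db"
        return self._as_command

    def update_command(self, cmd):
        # Replace the cached command, e.g. after auto-encryption.
        self._as_command = cmd, "db"

op = Op({"find": "coll"})
assert op.as_command() is op.as_command()  # wire and APM see one object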
super().use_command(conn) if conn.max_wire_version >= 8: @@ -1083,3 +1120,414 @@ def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> N _UPDATE: b"updates\x00", _DELETE: b"deletes\x00", } + + +class _BulkWriteContext: + """A wrapper around AsyncConnection for use with write splitting functions.""" + + __slots__ = ( + "db_name", + "conn", + "op_id", + "name", + "field", + "publish", + "start_time", + "listeners", + "session", + "compress", + "op_type", + "codec", + ) + + def __init__( + self, + database_name: str, + cmd_name: str, + conn: _AgnosticConnection, + operation_id: int, + listeners: _EventListeners, + session: _AgnosticClientSession, + op_type: int, + codec: CodecOptions, + ): + self.db_name = database_name + self.conn = conn + self.op_id = operation_id + self.listeners = listeners + self.publish = listeners.enabled_for_commands + self.name = cmd_name + self.field = _FIELD_MAP[self.name] + self.start_time = datetime.datetime.now() + self.session = session + self.compress = bool(conn.compression_context) + self.op_type = op_type + self.codec = codec + + def batch_command( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]]]: + namespace = self.db_name + ".$cmd" + request_id, msg, to_send = _do_batched_op_msg( + namespace, self.op_type, cmd, docs, self.codec, self + ) + if not to_send: + raise InvalidOperation("cannot do an empty bulk write") + return request_id, msg, to_send + + @property + def max_bson_size(self) -> int: + """A proxy for SockInfo.max_bson_size.""" + return self.conn.max_bson_size + + @property + def max_message_size(self) -> int: + """A proxy for SockInfo.max_message_size.""" + if self.compress: + # Subtract 16 bytes for the message header. 
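The 16 bytes subtracted in max_message_size are the standard wire header (messageLength, requestID, responseTo, opCode: four little-endian int32s) that precedes every message, compressed or not. Packing one makes the size concrete; 2012 is the OP_COMPRESSED opcode:

import struct

def message_header(length, request_id, response_to, op_code):
    return struct.pack("<iiii", length, request_id, response_to, op_code)

hdr = message_header(length=100, request_id=1, response_to=0, op_code=2012)
print(len(hdr))  # 16 -- this fixed header is never compressed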
+ return self.conn.max_message_size - 16 + return self.conn.max_message_size + + @property + def max_write_batch_size(self) -> int: + """A proxy for SockInfo.max_write_batch_size.""" + return self.conn.max_write_batch_size + + @property + def max_split_size(self) -> int: + """The maximum size of a BSON command before batch splitting.""" + return self.max_bson_size + + def _start( + self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] + ) -> MutableMapping[str, Any]: + """Publish a CommandStartedEvent.""" + cmd[self.field] = docs + self.listeners.publish_command_start( + cmd, + self.db_name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + ) + return cmd + + def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None: + """Publish a CommandSucceededEvent.""" + self.listeners.publish_command_success( + duration, + reply, + self.name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + database_name=self.db_name, + ) + + def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None: + """Publish a CommandFailedEvent.""" + self.listeners.publish_command_failure( + duration, + failure, + self.name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + database_name=self.db_name, + ) + + +class _EncryptedBulkWriteContext(_BulkWriteContext): + __slots__ = () + + def batch_command( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[int, dict[str, Any], list[Mapping[str, Any]]]: + namespace = self.db_name + ".$cmd" + msg, to_send = _encode_batched_write_command( + namespace, self.op_type, cmd, docs, self.codec, self + ) + if not to_send: + raise InvalidOperation("cannot do an empty bulk write") + + # Chop off the OP_QUERY header to get a properly batched write command. + cmd_start = msg.index(b"\x00", 4) + 9 + outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) + return -1, outgoing, to_send + + @property + def max_split_size(self) -> int: + """Reduce the batch splitting size.""" + return _MAX_SPLIT_SIZE_ENC + + +# From the Client Side Encryption spec: +# Because automatic encryption increases the size of commands, the driver +# MUST split bulk writes at a reduced size limit before undergoing automatic +# encryption. The write payload MUST be split at 2MiB (2097152). 
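The 2 MiB figure comes straight from the Client Side Encryption spec quoted above; the practical effect is simply fewer documents per batch, leaving headroom for ciphertext growth. Back-of-the-envelope arithmetic, assuming the common 16 MiB maxBsonObjectSize and roughly 16 KiB documents:

_MAX_SPLIT_SIZE_ENC = 2_097_152       # 2 MiB, per the CSE spec
MAX_BSON_SIZE = 16 * 1024 * 1024      # typical maxBsonObjectSize

docs_per_batch_plain = MAX_BSON_SIZE // 16_384
docs_per_batch_enc = _MAX_SPLIT_SIZE_ENC // 16_384
print(docs_per_batch_plain, docs_per_batch_enc)  # 1024 128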
+_MAX_SPLIT_SIZE_ENC = 2097152 + + +def _batched_op_msg_impl( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> tuple[list[Mapping[str, Any]], int]: + """Create a batched OP_MSG write.""" + max_bson_size = ctx.max_bson_size + max_write_batch_size = ctx.max_write_batch_size + max_message_size = ctx.max_message_size + + flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" + # Flags + buf.write(flags) + + # Type 0 Section + buf.write(b"\x00") + buf.write(_dict_to_bson(command, False, opts)) + + # Type 1 Section + buf.write(b"\x01") + size_location = buf.tell() + # Save space for size + buf.write(b"\x00\x00\x00\x00") + try: + buf.write(_OP_MSG_MAP[operation]) + except KeyError: + raise InvalidOperation("Unknown command") from None + + to_send = [] + idx = 0 + for doc in docs: + # Encode the current operation + value = _dict_to_bson(doc, False, opts) + doc_length = len(value) + new_message_size = buf.tell() + doc_length + # Does first document exceed max_message_size? + doc_too_large = idx == 0 and (new_message_size > max_message_size) + # When OP_MSG is used unacknowledged we have to check + # document size client side or applications won't be notified. + # Otherwise we let the server deal with documents that are too large + # since ordered=False causes those documents to be skipped instead of + # halting the bulk write operation. + unacked_doc_too_large = not ack and (doc_length > max_bson_size) + if doc_too_large or unacked_doc_too_large: + write_op = list(_FIELD_MAP.keys())[operation] + _raise_document_too_large(write_op, len(value), max_bson_size) + # We have enough data, return this batch. + if new_message_size > max_message_size: + break + buf.write(value) + to_send.append(doc) + idx += 1 + # We have enough documents, return this batch. + if idx == max_write_batch_size: + break + + # Write type 1 section size + length = buf.tell() + buf.seek(size_location) + buf.write(_pack_int(length - size_location)) + + return to_send, length + + +def _encode_batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[bytes, list[Mapping[str, Any]]]: + """Encode the next batched insert, update, or delete operation + as OP_MSG. + """ + buf = _BytesIO() + + to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) + return buf.getvalue(), to_send + + +if _use_c: + _encode_batched_op_msg = _cmessage._encode_batched_op_msg + + +def _batched_op_msg_compressed( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """Create the next batched insert, update, or delete operation + with OP_MSG, compressed. 
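The flag bytes chosen in _batched_op_msg_impl encode OP_MSG flagBits: bit 1 (value 2) is moreToCome, set for unacknowledged writes so the server knows no reply is expected. Reproducing the two byte strings:

import struct

MORE_TO_COME = 1 << 1
for ack in (True, False):
    flags = 0 if ack else MORE_TO_COME
    print(ack, struct.pack("<i", flags))
# True b'\x00\x00\x00\x00'
# False b'\x02\x00\x00\x00'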
+ """ + data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) + + assert ctx.conn.compression_context is not None + request_id, msg = _compress(2013, data, ctx.conn.compression_context) + return request_id, msg, to_send + + +def _batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """OP_MSG implementation entry point.""" + buf = _BytesIO() + + # Save space for message length and request id + buf.write(_ZERO_64) + # responseTo, opCode + buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") + + to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) + + # Header - request id and message length + buf.seek(4) + request_id = _randint() + buf.write(_pack_int(request_id)) + buf.seek(0) + buf.write(_pack_int(length)) + + return request_id, buf.getvalue(), to_send + + +if _use_c: + _batched_op_msg = _cmessage._batched_op_msg + + +def _do_batched_op_msg( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """Create the next batched insert, update, or delete operation + using OP_MSG. + """ + command["$db"] = namespace.split(".", 1)[0] + if "writeConcern" in command: + ack = bool(command["writeConcern"].get("w", 1)) + else: + ack = True + if ctx.conn.compression_context: + return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) + return _batched_op_msg(operation, command, docs, ack, opts, ctx) + + +# End OP_MSG ----------------------------------------------------- + + +def _encode_batched_write_command( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[bytes, list[Mapping[str, Any]]]: + """Encode the next batched insert, update, or delete command.""" + buf = _BytesIO() + + to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) + return buf.getvalue(), to_send + + +if _use_c: + _encode_batched_write_command = _cmessage._encode_batched_write_command + + +def _batched_write_command_impl( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> tuple[list[Mapping[str, Any]], int]: + """Create a batched OP_QUERY write command.""" + max_bson_size = ctx.max_bson_size + max_write_batch_size = ctx.max_write_batch_size + # Max BSON object size + 16k - 2 bytes for ending NUL bytes. + # Server guarantees there is enough room: SERVER-10643. + max_cmd_size = max_bson_size + _COMMAND_OVERHEAD + max_split_size = ctx.max_split_size + + # No options + buf.write(_ZERO_32) + # Namespace as C string + buf.write(namespace.encode("utf8")) + buf.write(_ZERO_8) + # Skip: 0, Limit: -1 + buf.write(_SKIPLIM) + + # Where to write command document length + command_start = buf.tell() + buf.write(bson.encode(command)) + + # Start of payload + buf.seek(-1, 2) + # Work around some Jython weirdness. 
+ buf.truncate() + try: + buf.write(_OP_MAP[operation]) + except KeyError: + raise InvalidOperation("Unknown command") from None + + # Where to write list document length + list_start = buf.tell() - 4 + to_send = [] + idx = 0 + for doc in docs: + # Encode the current operation + key = str(idx).encode("utf8") + value = _dict_to_bson(doc, False, opts) + # Is there enough room to add this document? max_cmd_size accounts for + # the two trailing null bytes. + doc_too_large = len(value) > max_cmd_size + if doc_too_large: + write_op = list(_FIELD_MAP.keys())[operation] + _raise_document_too_large(write_op, len(value), max_bson_size) + enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size + enough_documents = idx >= max_write_batch_size + if enough_data or enough_documents: + break + buf.write(_BSONOBJ) + buf.write(key) + buf.write(_ZERO_8) + buf.write(value) + to_send.append(doc) + idx += 1 + + # Finalize the current OP_QUERY message. + # Close list and command documents + buf.write(_ZERO_16) + + # Write document lengths and request id + length = buf.tell() + buf.seek(list_start) + buf.write(_pack_int(length - list_start - 1)) + buf.seek(command_start) + buf.write(_pack_int(length - command_start)) + + return to_send, length diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index 95699ac09e..ae47842053 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -19,6 +19,8 @@ from __future__ import annotations import copy +import datetime +import logging from collections.abc import MutableMapping from itertools import islice from typing import ( @@ -26,7 +28,6 @@ Any, Iterator, Mapping, - NoReturn, Optional, Type, Union, @@ -35,139 +36,49 @@ from bson.objectid import ObjectId from bson.raw_bson import RawBSONDocument from pymongo import _csot, common +from pymongo.bulk_shared import ( + _COMMANDS, + _DELETE_ALL, + _merge_command, + _raise_bulk_write_error, + _Run, +) from pymongo.common import ( validate_is_document_type, validate_ok_for_replace, validate_ok_for_update, ) from pymongo.errors import ( - BulkWriteError, ConfigurationError, InvalidOperation, + NotPrimaryError, OperationFailure, ) -from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES, _get_wce_doc +from pymongo.helpers_shared import _RETRYABLE_ERROR_CODES +from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log from pymongo.message import ( _DELETE, _INSERT, _UPDATE, + _BulkWriteContext, + _convert_exception, + _convert_write_result, + _EncryptedBulkWriteContext, _randint, ) from pymongo.read_preferences import ReadPreference from pymongo.synchronous.client_session import ClientSession, _validate_session_write_concern -from pymongo.synchronous.message import ( - _BulkWriteContext, - _EncryptedBulkWriteContext, -) +from pymongo.synchronous.helpers import _handle_reauth from pymongo.write_concern import WriteConcern if TYPE_CHECKING: from pymongo.synchronous.collection import Collection + from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection from pymongo.typings import _DocumentOut, _DocumentType, _Pipeline _IS_SYNC = True -_DELETE_ALL: int = 0 -_DELETE_ONE: int = 1 - -# For backwards compatibility. 
See MongoDB src/mongo/base/error_codes.err -_BAD_VALUE: int = 2 -_UNKNOWN_ERROR: int = 8 -_WRITE_CONCERN_ERROR: int = 64 - -_COMMANDS: tuple[str, str, str] = ("insert", "update", "delete") - - -class _Run: - """Represents a batch of write operations.""" - - def __init__(self, op_type: int) -> None: - """Initialize a new Run object.""" - self.op_type: int = op_type - self.index_map: list[int] = [] - self.ops: list[Any] = [] - self.idx_offset: int = 0 - - def index(self, idx: int) -> int: - """Get the original index of an operation in this run. - - :param idx: The Run index that maps to the original index. - """ - return self.index_map[idx] - - def add(self, original_index: int, operation: Any) -> None: - """Add an operation to this Run instance. - - :param original_index: The original index of this operation - within a larger bulk operation. - :param operation: The operation document. - """ - self.index_map.append(original_index) - self.ops.append(operation) - - -def _merge_command( - run: _Run, - full_result: MutableMapping[str, Any], - offset: int, - result: Mapping[str, Any], -) -> None: - """Merge a write command result into the full bulk result.""" - affected = result.get("n", 0) - - if run.op_type == _INSERT: - full_result["nInserted"] += affected - - elif run.op_type == _DELETE: - full_result["nRemoved"] += affected - - elif run.op_type == _UPDATE: - upserted = result.get("upserted") - if upserted: - n_upserted = len(upserted) - for doc in upserted: - doc["index"] = run.index(doc["index"] + offset) - full_result["upserted"].extend(upserted) - full_result["nUpserted"] += n_upserted - full_result["nMatched"] += affected - n_upserted - else: - full_result["nMatched"] += affected - full_result["nModified"] += result["nModified"] - - write_errors = result.get("writeErrors") - if write_errors: - for doc in write_errors: - # Leave the server response intact for APM. - replacement = doc.copy() - idx = doc["index"] + offset - replacement["index"] = run.index(idx) - # Add the failed operation to the error document. - replacement["op"] = run.ops[idx] - full_result["writeErrors"].append(replacement) - - wce = _get_wce_doc(result) - if wce: - full_result["writeConcernErrors"].append(wce) - - -def _raise_bulk_write_error(full_result: _DocumentOut) -> NoReturn: - """Raise a BulkWriteError from the full bulk api result.""" - # retryWrites on MMAPv1 should raise an actionable error. - if full_result["writeErrors"]: - full_result["writeErrors"].sort(key=lambda error: error["index"]) - err = full_result["writeErrors"][0] - code = err["code"] - msg = err["errmsg"] - if code == 20 and msg.startswith("Transaction numbers"): - errmsg = ( - "This MongoDB deployment does not support " - "retryable writes. Please add retryWrites=false " - "to your connection string." - ) - raise OperationFailure(errmsg, code, full_result) - raise BulkWriteError(full_result) - class _Bulk: """The private guts of the bulk write API.""" @@ -204,13 +115,16 @@ def __init__( # Extra state so that we know where to pick up on a retry attempt. 
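[Editor's note] These `_Run`/`_merge_command` helpers now live in `pymongo.bulk_shared`, and the index bookkeeping they implement is easy to miss in the move. A standalone sketch (class name illustrative) of how a batch-relative error index maps back to the caller's original operation order:

```python
# Each run records the original position of every op it carries, so a server
# writeError reported at batch index i maps back via run.index(i + offset).
class _RunSketch:
    def __init__(self) -> None:
        self.index_map: list[int] = []
        self.ops: list[str] = []

    def add(self, original_index: int, op: str) -> None:
        self.index_map.append(original_index)
        self.ops.append(op)

    def index(self, idx: int) -> int:
        return self.index_map[idx]

run = _RunSketch()
for original, op in [(3, "u1"), (5, "u2"), (7, "u3")]:
    run.add(original, op)

offset = 1             # ops from this run already sent in earlier batches
batch_error_index = 1  # server reports an error on the 2nd doc of the batch
assert run.index(batch_error_index + offset) == 7
```

With `offset` equal to `run.idx_offset`, `doc["index"] + offset` addresses `run.ops`, and `index_map` recovers the position the user originally supplied, which is what `BulkWriteError` reports.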
self.current_run = None self.next_run = None + self.is_encrypted = False @property def bulk_ctx_class(self) -> Type[_BulkWriteContext]: encrypter = self.collection.database.client._encrypter if encrypter and not encrypter._bypass_auto_encryption: + self.is_encrypted = True return _EncryptedBulkWriteContext else: + self.is_encrypted = False return _BulkWriteContext def add_insert(self, document: _DocumentOut) -> None: @@ -315,6 +229,180 @@ def gen_unordered(self) -> Iterator[_Run]: if run.ops: yield run + @_handle_reauth + def write_command( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + docs: list[Mapping[str, Any]], + client: MongoClient, + ) -> dict[str, Any]: + """A proxy for SocketInfo.write_command that handles event publishing.""" + cmd[bwc.field] = docs + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._start(cmd, request_id, docs) + try: + reply = bwc.conn.write_command(request_id, msg, bwc.codec) # type: ignore[misc] + duration = datetime.datetime.now() - bwc.start_time + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) # type: ignore[arg-type] + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, (NotPrimaryError, OperationFailure)): + failure: _DocumentOut = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + + if bwc.publish: + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return reply # type: ignore[return-value] + + def unack_write( + self, + bwc: _BulkWriteContext, + cmd: MutableMapping[str, Any], + request_id: int, + msg: bytes, + max_doc_size: int, + docs: list[Mapping[str, Any]], + client: MongoClient, + ) -> Optional[Mapping[str, Any]]: + """A proxy for AsyncConnection.unack_write that handles event publishing.""" + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + 
message=_CommandStatusMessage.STARTED, + command=cmd, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + cmd = bwc._start(cmd, request_id, docs) + try: + result = bwc.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value, misc, override] + duration = datetime.datetime.now() - bwc.start_time + if result is not None: + reply = _convert_write_result(bwc.name, cmd, result) # type: ignore[arg-type] + else: + # Comply with APM spec. + reply = {"ok": 1} + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.SUCCEEDED, + durationMS=duration, + reply=reply, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + ) + if bwc.publish: + bwc._succeed(request_id, reply, duration) + except Exception as exc: + duration = datetime.datetime.now() - bwc.start_time + if isinstance(exc, OperationFailure): + failure: _DocumentOut = _convert_write_result(bwc.name, cmd, exc.details) # type: ignore[arg-type] + elif isinstance(exc, NotPrimaryError): + failure = exc.details # type: ignore[assignment] + else: + failure = _convert_exception(exc) + if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): + _debug_log( + _COMMAND_LOGGER, + clientId=client._topology_settings._topology_id, + message=_CommandStatusMessage.FAILED, + durationMS=duration, + failure=failure, + commandName=next(iter(cmd)), + databaseName=bwc.db_name, + requestId=request_id, + operationId=request_id, + driverConnectionId=bwc.conn.id, + serverConnectionId=bwc.conn.server_connection_id, + serverHost=bwc.conn.address[0], + serverPort=bwc.conn.address[1], + serviceId=bwc.conn.service_id, + isServerSideError=isinstance(exc, OperationFailure), + ) + if bwc.publish: + assert bwc.start_time is not None + bwc._fail(request_id, failure, duration) + raise + finally: + bwc.start_time = datetime.datetime.now() + return result # type: ignore[return-value] + def _execute_command( self, generator: Iterator[Any], @@ -387,7 +475,19 @@ def _execute_command( # Run as many ops as possible in one command. if write_concern.acknowledged: - result, to_send = bwc.execute(cmd, ops, client) + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + result = bwc.conn.command( + bwc.db_name, + batched_cmd, + codec_options=bwc.codec, + session=bwc.session, + client=client, + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + result = self.write_command(bwc, cmd, request_id, msg, to_send, client) + client._process_response(result, bwc.session) # Retryable writeConcernErrors halt the execution of this run. 
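[Editor's note] The acknowledged branch above feeds its result into the retry bookkeeping that follows. A sketch of that check; the codes shown are an illustrative subset of `pymongo.helpers_shared._RETRYABLE_ERROR_CODES`, not the authoritative frozenset:

```python
# A retryable writeConcernError (shutdown/stepdown family) halts the current
# run so the remaining operations can be retried against a new primary.
RETRYABLE_ERROR_CODES = {91, 189, 10107, 11600, 11602}  # illustrative subset

def halts_run(result: dict) -> bool:
    wce = result.get("writeConcernError", {})
    return bool(wce) and wce.get("code") in RETRYABLE_ERROR_CODES

assert halts_run({"writeConcernError": {"code": 189, "errmsg": "stepping down"}})
assert not halts_run({"ok": 1})
```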
wce = result.get("writeConcernError", {}) @@ -407,7 +507,23 @@ def _execute_command( if self.ordered and "writeErrors" in result: break else: - to_send = bwc.execute_unack(cmd, ops, client) + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + bwc.conn.command( + bwc.db_name, + batched_cmd, + write_concern=WriteConcern(w=0), + session=bwc.session, + client=client, + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + # Though this isn't strictly a "legacy" write, the helper + # handles publishing commands and sending our message + # without receiving a result. Send 0 for max_doc_size + # to disable size checking. Size checking is handled while + # the documents are encoded to BSON. + self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) run.idx_offset += len(to_send) @@ -499,7 +615,18 @@ def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. - to_send = bwc.execute_unack(cmd, ops, client) + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + bwc.conn.command( + bwc.db_name, + batched_cmd, + write_concern=WriteConcern(w=0), + session=bwc.session, + client=client, + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index 56d20c7c10..f581caae69 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -38,7 +38,7 @@ def _handle_reauth(func: F) -> F: def inner(*args: Any, **kwargs: Any) -> Any: no_reauth = kwargs.pop("no_reauth", False) - from pymongo.synchronous.message import _BulkWriteContext + from pymongo.message import _BulkWriteContext from pymongo.synchronous.pool import Connection try: @@ -56,7 +56,7 @@ def inner(*args: Any, **kwargs: Any) -> Any: conn = arg break if isinstance(arg, _BulkWriteContext): - conn = arg.conn + conn = arg.conn # type: ignore[assignment] break if conn: conn.authenticate(reauthenticate=True) diff --git a/pymongo/synchronous/message.py b/pymongo/synchronous/message.py deleted file mode 100644 index 69e7a8ec49..0000000000 --- a/pymongo/synchronous/message.py +++ /dev/null @@ -1,713 +0,0 @@ -# Copyright 2009-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Tools for creating `messages -`_ to be sent to -MongoDB. - -.. note:: This module is for internal use and is generally not needed by - application developers. 
-""" -from __future__ import annotations - -import datetime -import logging -from io import BytesIO as _BytesIO -from typing import ( - TYPE_CHECKING, - Any, - Mapping, - MutableMapping, - Optional, -) - -from bson import CodecOptions, _dict_to_bson, encode -from bson.raw_bson import ( - DEFAULT_RAW_BSON_OPTIONS, - _inflate_bson, -) -from pymongo.errors import ( - InvalidOperation, - NotPrimaryError, - OperationFailure, -) -from pymongo.logger import ( - _COMMAND_LOGGER, - _CommandStatusMessage, - _debug_log, -) -from pymongo.message import ( - _BSONOBJ, - _COMMAND_OVERHEAD, - _FIELD_MAP, - _OP_MAP, - _OP_MSG_MAP, - _SKIPLIM, - _ZERO_8, - _ZERO_16, - _ZERO_32, - _ZERO_64, - _compress, - _convert_exception, - _convert_write_result, - _pack_int, - _raise_document_too_large, - _randint, -) -from pymongo.synchronous.helpers import _handle_reauth -from pymongo.write_concern import WriteConcern - -try: - from pymongo import _cmessage # type: ignore[attr-defined] - - _use_c = True -except ImportError: - _use_c = False - -if TYPE_CHECKING: - from datetime import timedelta - - from pymongo.monitoring import _EventListeners - from pymongo.synchronous.client_session import ClientSession - from pymongo.synchronous.mongo_client import MongoClient - from pymongo.synchronous.pool import Connection - from pymongo.typings import _DocumentOut - - -_IS_SYNC = True - - -class _BulkWriteContext: - """A wrapper around AsyncConnection for use with write splitting functions.""" - - __slots__ = ( - "db_name", - "conn", - "op_id", - "name", - "field", - "publish", - "start_time", - "listeners", - "session", - "compress", - "op_type", - "codec", - ) - - def __init__( - self, - database_name: str, - cmd_name: str, - conn: Connection, - operation_id: int, - listeners: _EventListeners, - session: ClientSession, - op_type: int, - codec: CodecOptions, - ): - self.db_name = database_name - self.conn = conn - self.op_id = operation_id - self.listeners = listeners - self.publish = listeners.enabled_for_commands - self.name = cmd_name - self.field = _FIELD_MAP[self.name] - self.start_time = datetime.datetime.now() - self.session = session - self.compress = bool(conn.compression_context) - self.op_type = op_type - self.codec = codec - - def __batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[int, bytes, list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - request_id, msg, to_send = _do_batched_op_msg( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - return request_id, msg, to_send - - def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - result = self.write_command(cmd, request_id, msg, to_send, client) - client._process_response(result, self.session) - return result, to_send - - def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> list[Mapping[str, Any]]: - request_id, msg, to_send = self.__batch_command(cmd, docs) - # Though this isn't strictly a "legacy" write, the helper - # handles publishing commands and sending our message - # without receiving a result. Send 0 for max_doc_size - # to disable size checking. Size checking is handled while - # the documents are encoded to BSON. 
- self.unack_write(cmd, request_id, msg, 0, to_send, client) - return to_send - - @property - def max_bson_size(self) -> int: - """A proxy for SockInfo.max_bson_size.""" - return self.conn.max_bson_size - - @property - def max_message_size(self) -> int: - """A proxy for SockInfo.max_message_size.""" - if self.compress: - # Subtract 16 bytes for the message header. - return self.conn.max_message_size - 16 - return self.conn.max_message_size - - @property - def max_write_batch_size(self) -> int: - """A proxy for SockInfo.max_write_batch_size.""" - return self.conn.max_write_batch_size - - @property - def max_split_size(self) -> int: - """The maximum size of a BSON command before batch splitting.""" - return self.max_bson_size - - def unack_write( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - max_doc_size: int, - docs: list[Mapping[str, Any]], - client: MongoClient, - ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - cmd = self._start(cmd, request_id, docs) - try: - result = self.conn.unack_write(msg, max_doc_size) # type: ignore[func-returns-value] - duration = datetime.datetime.now() - self.start_time - if result is not None: - reply = _convert_write_result(self.name, cmd, result) - else: - # Comply with APM spec. 
- reply = {"ok": 1} - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, OperationFailure): - failure: _DocumentOut = _convert_write_result(self.name, cmd, exc.details) # type: ignore[arg-type] - elif isinstance(exc, NotPrimaryError): - failure = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - if self.publish: - assert self.start_time is not None - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return result - - @_handle_reauth - def write_command( - self, - cmd: MutableMapping[str, Any], - request_id: int, - msg: bytes, - docs: list[Mapping[str, Any]], - client: MongoClient, - ) -> dict[str, Any]: - """A proxy for SocketInfo.write_command that handles event publishing.""" - cmd[self.field] = docs - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.STARTED, - command=cmd, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._start(cmd, request_id, docs) - try: - reply = self.conn.write_command(request_id, msg, self.codec) - duration = datetime.datetime.now() - self.start_time - if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.SUCCEEDED, - durationMS=duration, - reply=reply, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - ) - if self.publish: - self._succeed(request_id, reply, duration) - except Exception as exc: - duration = datetime.datetime.now() - self.start_time - if isinstance(exc, (NotPrimaryError, OperationFailure)): - failure: _DocumentOut = exc.details # type: ignore[assignment] - else: - failure = _convert_exception(exc) - if 
_COMMAND_LOGGER.isEnabledFor(logging.DEBUG): - _debug_log( - _COMMAND_LOGGER, - clientId=client._topology_settings._topology_id, - message=_CommandStatusMessage.FAILED, - durationMS=duration, - failure=failure, - commandName=next(iter(cmd)), - databaseName=self.db_name, - requestId=request_id, - operationId=request_id, - driverConnectionId=self.conn.id, - serverConnectionId=self.conn.server_connection_id, - serverHost=self.conn.address[0], - serverPort=self.conn.address[1], - serviceId=self.conn.service_id, - isServerSideError=isinstance(exc, OperationFailure), - ) - - if self.publish: - self._fail(request_id, failure, duration) - raise - finally: - self.start_time = datetime.datetime.now() - return reply - - def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd[self.field] = docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - - def _succeed(self, request_id: int, reply: _DocumentOut, duration: timedelta) -> None: - """Publish a CommandSucceededEvent.""" - self.listeners.publish_command_success( - duration, - reply, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - def _fail(self, request_id: int, failure: _DocumentOut, duration: timedelta) -> None: - """Publish a CommandFailedEvent.""" - self.listeners.publish_command_failure( - duration, - failure, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - -# From the Client Side Encryption spec: -# Because automatic encryption increases the size of commands, the driver -# MUST split bulk writes at a reduced size limit before undergoing automatic -# encryption. The write payload MUST be split at 2MiB (2097152). -_MAX_SPLIT_SIZE_ENC = 2097152 - - -class _EncryptedBulkWriteContext(_BulkWriteContext): - __slots__ = () - - def __batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - msg, to_send = _encode_batched_write_command( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - - # Chop off the OP_QUERY header to get a properly batched write command. 
- cmd_start = msg.index(b"\x00", 4) + 9 - outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) - return outgoing, to_send - - def execute( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> tuple[Mapping[str, Any], list[Mapping[str, Any]]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - result: Mapping[str, Any] = self.conn.command( - self.db_name, batched_cmd, codec_options=self.codec, session=self.session, client=client - ) - return result, to_send - - def execute_unack( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]], client: MongoClient - ) -> list[Mapping[str, Any]]: - batched_cmd, to_send = self.__batch_command(cmd, docs) - self.conn.command( - self.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=self.session, - client=client, - ) - return to_send - - @property - def max_split_size(self) -> int: - """Reduce the batch splitting size.""" - return _MAX_SPLIT_SIZE_ENC - - -def _batched_op_msg_impl( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_MSG write.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - max_message_size = ctx.max_message_size - - flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" - # Flags - buf.write(flags) - - # Type 0 Section - buf.write(b"\x00") - buf.write(_dict_to_bson(command, False, opts)) - - # Type 1 Section - buf.write(b"\x01") - size_location = buf.tell() - # Save space for size - buf.write(b"\x00\x00\x00\x00") - try: - buf.write(_OP_MSG_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None - - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - value = _dict_to_bson(doc, False, opts) - doc_length = len(value) - new_message_size = buf.tell() + doc_length - # Does first document exceed max_message_size? - doc_too_large = idx == 0 and (new_message_size > max_message_size) - # When OP_MSG is used unacknowledged we have to check - # document size client side or applications won't be notified. - # Otherwise we let the server deal with documents that are too large - # since ordered=False causes those documents to be skipped instead of - # halting the bulk write operation. - unacked_doc_too_large = not ack and (doc_length > max_bson_size) - if doc_too_large or unacked_doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - # We have enough data, return this batch. - if new_message_size > max_message_size: - break - buf.write(value) - to_send.append(doc) - idx += 1 - # We have enough documents, return this batch. - if idx == max_write_batch_size: - break - - # Write type 1 section size - length = buf.tell() - buf.seek(size_location) - buf.write(_pack_int(length - size_location)) - - return to_send, length - - -def _encode_batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete operation - as OP_MSG. 
- """ - buf = _BytesIO() - - to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - return buf.getvalue(), to_send - - -if _use_c: - _encode_batched_op_msg = _cmessage._encode_batched_op_msg - - -def _batched_op_msg_compressed( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - with OP_MSG, compressed. - """ - data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) - - assert ctx.conn.compression_context is not None - request_id, msg = _compress(2013, data, ctx.conn.compression_context) - return request_id, msg, to_send - - -def _batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """OP_MSG implementation entry point.""" - buf = _BytesIO() - - # Save space for message length and request id - buf.write(_ZERO_64) - # responseTo, opCode - buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") - - to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - - # Header - request id and message length - buf.seek(4) - request_id = _randint() - buf.write(_pack_int(request_id)) - buf.seek(0) - buf.write(_pack_int(length)) - - return request_id, buf.getvalue(), to_send - - -if _use_c: - _batched_op_msg = _cmessage._batched_op_msg - - -def _do_batched_op_msg( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - using OP_MSG. - """ - command["$db"] = namespace.split(".", 1)[0] - if "writeConcern" in command: - ack = bool(command["writeConcern"].get("w", 1)) - else: - ack = True - if ctx.conn.compression_context: - return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) - return _batched_op_msg(operation, command, docs, ack, opts, ctx) - - -# End OP_MSG ----------------------------------------------------- - - -def _encode_batched_write_command( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete command.""" - buf = _BytesIO() - - to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) - return buf.getvalue(), to_send - - -if _use_c: - _encode_batched_write_command = _cmessage._encode_batched_write_command - - -def _batched_write_command_impl( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_QUERY write command.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - # Max BSON object size + 16k - 2 bytes for ending NUL bytes. - # Server guarantees there is enough room: SERVER-10643. 
- max_cmd_size = max_bson_size + _COMMAND_OVERHEAD - max_split_size = ctx.max_split_size - - # No options - buf.write(_ZERO_32) - # Namespace as C string - buf.write(namespace.encode("utf8")) - buf.write(_ZERO_8) - # Skip: 0, Limit: -1 - buf.write(_SKIPLIM) - - # Where to write command document length - command_start = buf.tell() - buf.write(encode(command)) - - # Start of payload - buf.seek(-1, 2) - # Work around some Jython weirdness. - buf.truncate() - try: - buf.write(_OP_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None - - # Where to write list document length - list_start = buf.tell() - 4 - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - key = str(idx).encode("utf8") - value = _dict_to_bson(doc, False, opts) - # Is there enough room to add this document? max_cmd_size accounts for - # the two trailing null bytes. - doc_too_large = len(value) > max_cmd_size - if doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size - enough_documents = idx >= max_write_batch_size - if enough_data or enough_documents: - break - buf.write(_BSONOBJ) - buf.write(key) - buf.write(_ZERO_8) - buf.write(value) - to_send.append(doc) - idx += 1 - - # Finalize the current OP_QUERY message. - # Close list and command documents - buf.write(_ZERO_16) - - # Write document lengths and request id - length = buf.tell() - buf.seek(list_start) - buf.write(_pack_int(length - list_start - 1)) - buf.seek(command_start) - buf.write(_pack_int(length - command_start)) - - return to_send, length diff --git a/pymongo/synchronous/mongo_client.py b/pymongo/synchronous/mongo_client.py index 4052508ea4..0ef87907f3 100644 --- a/pymongo/synchronous/mongo_client.py +++ b/pymongo/synchronous/mongo_client.py @@ -1667,13 +1667,13 @@ def _run_operation( if operation.conn_mgr: server = self._select_server( operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] operation.name, address=address, ) with operation.conn_mgr._alock: - with _MongoClientErrorHandler(self, server, operation.session) as err_handler: + with _MongoClientErrorHandler(self, server, operation.session) as err_handler: # type: ignore[arg-type] err_handler.contribute_socket(operation.conn_mgr.conn) return server.run_operation( operation.conn_mgr.conn, @@ -1703,7 +1703,7 @@ def _cmd( return self._retryable_read( _cmd, operation.read_preference, - operation.session, + operation.session, # type: ignore[arg-type] address=address, retryable=isinstance(operation, _Query), operation=operation.name, diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index 4f9f1c1462..fea55b8382 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -109,27 +109,14 @@ def request_check(self) -> None: def operation_to_command( self, operation: Union[_Query, _GetMore], conn: Connection, apply_timeout: bool = False ) -> tuple[dict[str, Any], str]: - is_query = isinstance(operation, _Query) - if is_query: - explain = "$explain" in operation.spec - cmd, db = operation.as_command() - else: - explain = False - cmd, db = operation.as_command(conn) - if operation.session: - operation.session._apply_to(cmd, False, operation.read_preference, conn) - # Explain does not support readConcern. 
- if is_query and not explain and not operation.session.in_transaction: - operation.session._update_read_concern(cmd, conn) + cmd, db = operation.as_command(conn, apply_timeout) # Support auto encryption if operation.client._encrypter and not operation.client._encrypter._bypass_auto_encryption: - cmd = operation.client._encrypter.encrypt(operation.db, cmd, operation.codec_options) + cmd = operation.client._encrypter.encrypt( # type: ignore[misc, assignment] + operation.db, cmd, operation.codec_options + ) + operation.update_command(cmd) - conn.add_server_api(cmd) - conn.send_cluster_time(cmd, operation.session, operation.client) - # Support CSOT - if apply_timeout: - conn.apply_timeout(operation.client, cmd=cmd if is_query else None) return cmd, db @_handle_reauth @@ -221,7 +208,7 @@ def run_operation( ) if use_cmd: first = docs[0] - operation.client._process_response(first, operation.session) + operation.client._process_response(first, operation.session) # type: ignore[misc, arg-type] _check_command_response(first, conn.max_wire_version) except Exception as exc: duration = datetime.now() - start @@ -304,7 +291,7 @@ def run_operation( ) # Decrypt response. - client = operation.client + client = operation.client # type: ignore[assignment] if client and client._encrypter: if use_cmd: decrypted = client._encrypter.decrypt(reply.raw_command_response()) @@ -312,7 +299,7 @@ def run_operation( response: Response - if client._should_pin_cursor(operation.session) or operation.exhaust: + if client._should_pin_cursor(operation.session) or operation.exhaust: # type: ignore[arg-type] conn.pin_cursor() if isinstance(reply, _OpMsg): # In OP_MSG, the server keeps sending only if the diff --git a/pymongo/typings.py b/pymongo/typings.py index 1923f918b1..5a4a2a0fc3 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -15,25 +15,34 @@ """Type aliases used by PyMongo""" from __future__ import annotations +from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, + Iterator, Mapping, + MutableMapping, + NoReturn, Optional, Sequence, Tuple, + Type, TypeVar, Union, ) +from bson import Int64, Timestamp from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg +from pymongo.read_preferences import _ServerMode if TYPE_CHECKING: from pymongo import AsyncMongoClient, MongoClient from pymongo.asynchronous.bulk import _AsyncBulk - from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.client_session import AsyncClientSession, SessionOptions from pymongo.asynchronous.pool import AsyncConnection + from pymongo.bulk_shared import _Run from pymongo.collation import Collation + from pymongo.message import _BulkWriteContext from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.pool import Connection @@ -49,10 +58,184 @@ _T = TypeVar("_T") # Type hinting types for compatibility between async and sync classes -_AgnosticClientSession = Union["AsyncClientSession", "ClientSession"] + + +class _BaseBulk(ABC): + @property + @abstractmethod + def bulk_ctx_class(self) -> Type[_BulkWriteContext]: + ... + + @abstractmethod + def add_insert(self, document: _DocumentOut) -> None: + ... 
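[Editor's note] The `_BaseBulk` ABC being introduced here (and removed again by a later commit in this series) pins down the shared surface of `_Bulk` and `_AsyncBulk`. A generic sketch of the pattern, with illustrative names; the real class continues below:

```python
# The ABC fixes the method surface so the sync and async implementations
# stay in step; instantiation fails until every abstract member is provided.
from abc import ABC, abstractmethod
from typing import Any, Mapping

class _BaseBulkSketch(ABC):
    @property
    @abstractmethod
    def bulk_ctx_class(self) -> type: ...

    @abstractmethod
    def add_insert(self, document: Mapping[str, Any]) -> None: ...

class _SyncBulkSketch(_BaseBulkSketch):
    @property
    def bulk_ctx_class(self) -> type:
        return dict

    def add_insert(self, document: Mapping[str, Any]) -> None:
        print("queued", dict(document))

_SyncBulkSketch().add_insert({"x": 1})  # ok: all abstract members implemented
```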
+ + @abstractmethod + def add_update( + self, + selector: Mapping[str, Any], + update: Union[Mapping[str, Any], _Pipeline], + multi: bool = False, + upsert: bool = False, + collation: Optional[Mapping[str, Any]] = None, + array_filters: Optional[list[Mapping[str, Any]]] = None, + hint: Union[str, dict[str, Any], None] = None, + ) -> None: + ... + + @abstractmethod + def add_replace( + self, + selector: Mapping[str, Any], + replacement: Mapping[str, Any], + upsert: bool = False, + collation: Optional[Mapping[str, Any]] = None, + hint: Union[str, dict[str, Any], None] = None, + ) -> None: + ... + + @abstractmethod + def add_delete( + self, + selector: Mapping[str, Any], + limit: int, + collation: Optional[Mapping[str, Any]] = None, + hint: Union[str, dict[str, Any], None] = None, + ) -> None: + ... + + @abstractmethod + def gen_ordered(self) -> Iterator[Optional[_Run]]: + ... + + @abstractmethod + def gen_unordered(self) -> Iterator[_Run]: + ... + + +class _BaseClientSession(ABC): + @abstractmethod + def _check_ended(self) -> None: + ... + + @property + @abstractmethod + def client(self) -> AsyncMongoClient: + ... + + @property + @abstractmethod + def options(self) -> SessionOptions: + ... + + @property + @abstractmethod + def session_id(self) -> Mapping[str, Any]: + ... + + @property + @abstractmethod + def _transaction_id(self) -> Int64: + ... + + @property + @abstractmethod + def cluster_time(self) -> Optional[ClusterTime]: + ... + + @property + @abstractmethod + def operation_time(self) -> Optional[Timestamp]: + ... + + @abstractmethod + def _inherit_option(self, name: str, val: _T) -> _T: + ... + + @abstractmethod + def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None: + ... + + @abstractmethod + def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None: + ... + + @abstractmethod + def _advance_operation_time(self, operation_time: Optional[Timestamp]) -> None: + ... + + @abstractmethod + def advance_operation_time(self, operation_time: Timestamp) -> None: + ... + + @abstractmethod + def _process_response(self, reply: Mapping[str, Any]) -> None: + ... + + @property + @abstractmethod + def has_ended(self) -> bool: + ... + + @property + @abstractmethod + def in_transaction(self) -> bool: + ... + + @property + @abstractmethod + def _starting_transaction(self) -> bool: + ... + + @property + @abstractmethod + def _pinned_address(self) -> Optional[_Address]: + ... + + @property + @abstractmethod + def _pinned_connection(self) -> Optional[Any]: + ... + + @abstractmethod + def _pin(self, server: Any, conn: Any) -> None: + ... + + @abstractmethod + def _txn_read_preference(self) -> Optional[_ServerMode]: + ... + + @abstractmethod + def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: + ... + + @abstractmethod + def _apply_to( + self, + command: MutableMapping[str, Any], + is_retryable: bool, + read_preference: _ServerMode, + conn: Any, + ) -> None: + ... + + @abstractmethod + def _start_retryable_write(self) -> None: + ... + + @abstractmethod + def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Any) -> None: + ... + + @abstractmethod + def __copy__(self) -> NoReturn: + ... 
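[Editor's note] The `_Agnostic*` unions that follow are the piece of this hunk that survives commit 04. The module defines them at import time with string forward references; the sketch below (function name illustrative) shows the equivalent way a shared helper can consume them without importing either client stack at runtime:

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
    from pymongo import AsyncMongoClient, MongoClient

    _AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"]

def describe(client: _AgnosticMongoClient) -> str:
    # PEP 563 string annotations mean the TYPE_CHECKING-only alias is never
    # evaluated at runtime, so no client stack is imported just for typing.
    return type(client).__name__
```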
+ + _AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] -_AgnosticBulk = Union["_AsyncBulk", "_Bulk"] _AgnosticConnection = Union["AsyncConnection", "Connection"] +_AgnosticClientSession = Union["AsyncClientSession", "ClientSession"] +_AgnosticBulk = Union["_AsyncBulk", "_Bulk"] def strip_optional(elem: Optional[_T]) -> _T: @@ -71,6 +254,5 @@ def strip_optional(elem: Optional[_T]) -> _T: "_CollationIn", "_Pipeline", "strip_optional", - "_AgnosticClientSession", "_AgnosticMongoClient", ] diff --git a/test/__init__.py b/test/__init__.py index 9b6368f4de..f8b27ce193 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -45,12 +45,11 @@ import pymongo import pymongo.errors from bson.son import SON -from pymongo import common +from pymongo import common, message from pymongo.common import partition_node from pymongo.hello_compat import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] -from pymongo.synchronous import message from pymongo.synchronous.database import Database from pymongo.synchronous.mongo_client import MongoClient from pymongo.uri_parser import parse_uri diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 07a3d25cfe..afd2766849 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -50,7 +50,6 @@ from pymongo.asynchronous.collection import AsyncCollection, ReturnDocument from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.helpers import anext -from pymongo.asynchronous.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.cursor_shared import CursorType from pymongo.errors import ( @@ -64,6 +63,7 @@ OperationFailure, WriteConcernError, ) +from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference diff --git a/test/auth_aws/test_auth_aws.py b/test/auth_aws/test_auth_aws.py index f6a6c96949..3e5dcec563 100644 --- a/test/auth_aws/test_auth_aws.py +++ b/test/auth_aws/test_auth_aws.py @@ -26,7 +26,7 @@ from pymongo import MongoClient from pymongo.errors import OperationFailure -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser import parse_uri class TestAuthAWS(unittest.TestCase): diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 76676eb95e..9f7941fbf7 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -36,14 +36,14 @@ from pymongo._gcp_helpers import _get_gcp_response from pymongo.cursor_shared import CursorType from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure +from pymongo.hello_compat import HelloCompat from pymongo.operations import InsertOne from pymongo.synchronous.auth_oidc import ( OIDCCallback, OIDCCallbackContext, OIDCCallbackResult, ) -from pymongo.synchronous.hello_compat import HelloCompat -from pymongo.synchronous.uri_parser import parse_uri +from pymongo.uri_parser import parse_uri ROOT = Path(__file__).parent.parent.resolve() TEST_PATH = ROOT / "auth" / "unified" diff --git a/test/synchronous/test_collection.py b/test/synchronous/test_collection.py index d0de0b0608..84c42feabc 100644 --- a/test/synchronous/test_collection.py +++ b/test/synchronous/test_collection.py @@ -57,6 +57,7 @@ OperationFailure, 
WriteConcernError, ) +from pymongo.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.operations import * from pymongo.read_concern import DEFAULT_READ_CONCERN from pymongo.read_preferences import ReadPreference @@ -70,7 +71,6 @@ from pymongo.synchronous.collection import Collection, ReturnDocument from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.helpers import next -from pymongo.synchronous.message import _COMMAND_OVERHEAD, _gen_find_command from pymongo.synchronous.mongo_client import MongoClient from pymongo.write_concern import WriteConcern diff --git a/test/test_client.py b/test/test_client.py index b5c438a66f..e5d1174d05 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -106,7 +106,6 @@ from pymongo.server_description import ServerDescription from pymongo.server_selectors import readable_server_selector, writable_server_selector from pymongo.server_type import SERVER_TYPE -from pymongo.synchronous import message as message_old from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.cursor import Cursor, CursorType from pymongo.synchronous.database import Database @@ -1458,7 +1457,7 @@ def test_stale_getmore(self): with self.assertRaises(AutoReconnect): client = rs_client(connect=False, serverSelectionTimeoutMS=100) client._run_operation( - operation=message_old._GetMore( + operation=message._GetMore( "pymongo_test", "collection", 101, diff --git a/test/test_collection.py b/test/test_collection.py index d2d5d60830..0de506e0f3 100644 --- a/test/test_collection.py +++ b/test/test_collection.py @@ -45,6 +45,7 @@ from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT +from pymongo.bulk_shared import BulkWriteError from pymongo.cursor_shared import CursorType from pymongo.errors import ( ConfigurationError, @@ -67,7 +68,6 @@ InsertOneResult, UpdateResult, ) -from pymongo.synchronous.bulk import BulkWriteError from pymongo.synchronous.collection import Collection, ReturnDocument from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.mongo_client import MongoClient From 7126a17fb9934dc56e1a980266677ee1d18998a6 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 12 Jun 2024 15:16:08 -0700 Subject: [PATCH 04/11] Remove unused abstract parent classes --- pymongo/typings.py | 185 +-------------------------- test/asynchronous/test_collection.py | 2 +- test/synchronous/test_collection.py | 2 +- 3 files changed, 3 insertions(+), 186 deletions(-) diff --git a/pymongo/typings.py b/pymongo/typings.py index 5a4a2a0fc3..e0593517a8 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -15,34 +15,25 @@ """Type aliases used by PyMongo""" from __future__ import annotations -from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, Any, - Iterator, Mapping, - MutableMapping, - NoReturn, Optional, Sequence, Tuple, - Type, TypeVar, Union, ) -from bson import Int64, Timestamp from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg -from pymongo.read_preferences import _ServerMode if TYPE_CHECKING: from pymongo import AsyncMongoClient, MongoClient from pymongo.asynchronous.bulk import _AsyncBulk - from pymongo.asynchronous.client_session import AsyncClientSession, SessionOptions + from pymongo.asynchronous.client_session import AsyncClientSession from pymongo.asynchronous.pool import AsyncConnection - from pymongo.bulk_shared import _Run from pymongo.collation import Collation - from 
pymongo.message import _BulkWriteContext from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession from pymongo.synchronous.pool import Connection @@ -58,180 +49,6 @@ _T = TypeVar("_T") # Type hinting types for compatibility between async and sync classes - - -class _BaseBulk(ABC): - @property - @abstractmethod - def bulk_ctx_class(self) -> Type[_BulkWriteContext]: - ... - - @abstractmethod - def add_insert(self, document: _DocumentOut) -> None: - ... - - @abstractmethod - def add_update( - self, - selector: Mapping[str, Any], - update: Union[Mapping[str, Any], _Pipeline], - multi: bool = False, - upsert: bool = False, - collation: Optional[Mapping[str, Any]] = None, - array_filters: Optional[list[Mapping[str, Any]]] = None, - hint: Union[str, dict[str, Any], None] = None, - ) -> None: - ... - - @abstractmethod - def add_replace( - self, - selector: Mapping[str, Any], - replacement: Mapping[str, Any], - upsert: bool = False, - collation: Optional[Mapping[str, Any]] = None, - hint: Union[str, dict[str, Any], None] = None, - ) -> None: - ... - - @abstractmethod - def add_delete( - self, - selector: Mapping[str, Any], - limit: int, - collation: Optional[Mapping[str, Any]] = None, - hint: Union[str, dict[str, Any], None] = None, - ) -> None: - ... - - @abstractmethod - def gen_ordered(self) -> Iterator[Optional[_Run]]: - ... - - @abstractmethod - def gen_unordered(self) -> Iterator[_Run]: - ... - - -class _BaseClientSession(ABC): - @abstractmethod - def _check_ended(self) -> None: - ... - - @property - @abstractmethod - def client(self) -> AsyncMongoClient: - ... - - @property - @abstractmethod - def options(self) -> SessionOptions: - ... - - @property - @abstractmethod - def session_id(self) -> Mapping[str, Any]: - ... - - @property - @abstractmethod - def _transaction_id(self) -> Int64: - ... - - @property - @abstractmethod - def cluster_time(self) -> Optional[ClusterTime]: - ... - - @property - @abstractmethod - def operation_time(self) -> Optional[Timestamp]: - ... - - @abstractmethod - def _inherit_option(self, name: str, val: _T) -> _T: - ... - - @abstractmethod - def _advance_cluster_time(self, cluster_time: Optional[Mapping[str, Any]]) -> None: - ... - - @abstractmethod - def advance_cluster_time(self, cluster_time: Mapping[str, Any]) -> None: - ... - - @abstractmethod - def _advance_operation_time(self, operation_time: Optional[Timestamp]) -> None: - ... - - @abstractmethod - def advance_operation_time(self, operation_time: Timestamp) -> None: - ... - - @abstractmethod - def _process_response(self, reply: Mapping[str, Any]) -> None: - ... - - @property - @abstractmethod - def has_ended(self) -> bool: - ... - - @property - @abstractmethod - def in_transaction(self) -> bool: - ... - - @property - @abstractmethod - def _starting_transaction(self) -> bool: - ... - - @property - @abstractmethod - def _pinned_address(self) -> Optional[_Address]: - ... - - @property - @abstractmethod - def _pinned_connection(self) -> Optional[Any]: - ... - - @abstractmethod - def _pin(self, server: Any, conn: Any) -> None: - ... - - @abstractmethod - def _txn_read_preference(self) -> Optional[_ServerMode]: - ... - - @abstractmethod - def _materialize(self, logical_session_timeout_minutes: Optional[int] = None) -> None: - ... - - @abstractmethod - def _apply_to( - self, - command: MutableMapping[str, Any], - is_retryable: bool, - read_preference: _ServerMode, - conn: Any, - ) -> None: - ... - - @abstractmethod - def _start_retryable_write(self) -> None: - ... 
- - @abstractmethod - def _update_read_concern(self, cmd: MutableMapping[str, Any], conn: Any) -> None: - ... - - @abstractmethod - def __copy__(self) -> NoReturn: - ... - - _AgnosticMongoClient = Union["AsyncMongoClient", "MongoClient"] _AgnosticConnection = Union["AsyncConnection", "Connection"] _AgnosticClientSession = Union["AsyncClientSession", "ClientSession"] diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index afd2766849..7e907eaf34 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -46,11 +46,11 @@ from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT -from pymongo.asynchronous.bulk import BulkWriteError from pymongo.asynchronous.collection import AsyncCollection, ReturnDocument from pymongo.asynchronous.command_cursor import AsyncCommandCursor from pymongo.asynchronous.helpers import anext from pymongo.asynchronous.mongo_client import AsyncMongoClient +from pymongo.bulk_shared import BulkWriteError from pymongo.cursor_shared import CursorType from pymongo.errors import ( ConfigurationError, diff --git a/test/synchronous/test_collection.py b/test/synchronous/test_collection.py index 84c42feabc..a11d1b6bb4 100644 --- a/test/synchronous/test_collection.py +++ b/test/synchronous/test_collection.py @@ -45,6 +45,7 @@ from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT +from pymongo.bulk_shared import BulkWriteError from pymongo.cursor_shared import CursorType from pymongo.errors import ( ConfigurationError, @@ -67,7 +68,6 @@ InsertOneResult, UpdateResult, ) -from pymongo.synchronous.bulk import BulkWriteError from pymongo.synchronous.collection import Collection, ReturnDocument from pymongo.synchronous.command_cursor import CommandCursor from pymongo.synchronous.helpers import next From 55ff5f0f8d7da29bb91eab8e67311ea007395620 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Wed, 12 Jun 2024 16:06:25 -0700 Subject: [PATCH 05/11] Fix gridfs synchronous docstrings --- gridfs/synchronous/grid_file.py | 60 ++++++++++++++++----------------- tools/synchro.py | 2 +- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/gridfs/synchronous/grid_file.py b/gridfs/synchronous/grid_file.py index ee43f01897..98374cc8cb 100644 --- a/gridfs/synchronous/grid_file.py +++ b/gridfs/synchronous/grid_file.py @@ -164,7 +164,7 @@ def get(self, file_id: Any, session: Optional[ClientSession] = None) -> GridOut: :param file_id: ``"_id"`` of the file to get :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -206,7 +206,7 @@ def get_version( :param version: version of the file to get (defaults to -1, the most recent version uploaded) :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` :param kwargs: find files by custom metadata. .. versionchanged:: 3.6 @@ -248,7 +248,7 @@ def get_last_version( :param filename: ``"filename"`` of the file to get, or `None` :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` :param kwargs: find files by custom metadata. .. 
versionchanged:: 3.6 @@ -273,7 +273,7 @@ def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None: :param file_id: ``"_id"`` of the file to delete :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -290,7 +290,7 @@ def list(self, session: Optional[ClientSession] = None) -> list[str]: :class:`GridFS`. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -329,7 +329,7 @@ def find_one( :param args: any additional positional arguments are the same as the arguments to :meth:`find`. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` :param kwargs: any additional keyword arguments are the same as the arguments to :meth:`find`. @@ -372,7 +372,7 @@ def find(self, *args: Any, **kwargs: Any) -> GridOutCursor: :meth:`~pymongo.collection.Collection.find` in :class:`~pymongo.collection.Collection`. - If a :class:`~pymongo.client_session.AsyncClientSession` is passed to + If a :class:`~pymongo.client_session.ClientSession` is passed to :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances are associated with that session. @@ -440,7 +440,7 @@ def exists( :param document_or_id: query document, or _id of the document to check for :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` :param kwargs: keyword arguments are used as a query document, if they're present. @@ -501,7 +501,7 @@ def __init__( .. seealso:: The MongoDB documentation on `gridfs `_. """ if not isinstance(db, Database): - raise TypeError("database must be an instance of AsyncDatabase") + raise TypeError("database must be an instance of Database") db = _clear_entity_type_registry(db) @@ -558,7 +558,7 @@ def open_upload_stream( files collection document. If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -617,7 +617,7 @@ def open_upload_stream_with_id( files collection document. If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -674,7 +674,7 @@ def upload_from_stream( files collection document. If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -724,7 +724,7 @@ def upload_from_stream_with_id( files collection document. If not provided the metadata field will be omitted from the files collection document. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -755,7 +755,7 @@ def open_download_stream( :param file_id: The _id of the file to be downloaded. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. 
versionchanged:: 3.6 Added ``session`` parameter. @@ -790,7 +790,7 @@ def download_to_stream( :param file_id: The _id of the file to be downloaded. :param destination: a file-like object implementing :meth:`write`. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -819,7 +819,7 @@ def delete(self, file_id: Any, session: Optional[ClientSession] = None) -> None: :param file_id: The _id of the file to be deleted. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -859,7 +859,7 @@ def find(self, *args: Any, **kwargs: Any) -> GridOutCursor: :meth:`~pymongo.collection.Collection.find` in :class:`~pymongo.collection.Collection`. - If a :class:`~pymongo.client_session.AsyncClientSession` is passed to + If a :class:`~pymongo.client_session.ClientSession` is passed to :meth:`find`, all returned :class:`~gridfs.grid_file.GridOut` instances are associated with that session. @@ -902,7 +902,7 @@ def open_download_stream_by_name( filename and different uploadDate) of the file to retrieve. Defaults to -1 (the most recent revision). :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` :Note: Revision numbers are defined as follows: @@ -961,7 +961,7 @@ def download_to_stream_by_name( filename and different uploadDate) of the file to retrieve. Defaults to -1 (the most recent revision). :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` :Note: Revision numbers are defined as follows: @@ -1000,7 +1000,7 @@ def rename( :param file_id: The _id of the file to be renamed. :param new_filename: The new name of the file. :param session: a - :class:`~pymongo.client_session.AsyncClientSession` + :class:`~pymongo.client_session.ClientSession` .. versionchanged:: 3.6 Added ``session`` parameter. @@ -1032,7 +1032,7 @@ def __init__( provided by :class:`~gridfs.GridFS`. Raises :class:`TypeError` if `root_collection` is not an - instance of :class:`~pymongo.collection.AsyncCollection`. + instance of :class:`~pymongo.collection.Collection`. Any of the file level options specified in the `GridFS Spec `_ may be passed as @@ -1057,7 +1057,7 @@ def __init__( :param root_collection: root collection to write to :param session: a - :class:`~pymongo.client_session.AsyncClientSession` to use for all + :class:`~pymongo.client_session.ClientSession` to use for all commands :param kwargs: Any: file level options (see above) @@ -1073,10 +1073,10 @@ def __init__( .. versionchanged:: 3.0 `root_collection` must use an acknowledged - :attr:`~pymongo.collection.AsyncCollection.write_concern` + :attr:`~pymongo.collection.Collection.write_concern` """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError("root_collection must be an instance of Collection") if not root_collection.write_concern.acknowledged: raise ConfigurationError("root_collection must use acknowledged write_concern") @@ -1405,14 +1405,14 @@ def __init__( Either `file_id` or `file_document` must be specified, `file_document` will be given priority if present. Raises :class:`TypeError` if `root_collection` is not an instance of - :class:`~pymongo.collection.AsyncCollection`. 
+ :class:`~pymongo.collection.Collection`. :param root_collection: root collection to read from :param file_id: value of ``"_id"`` for the file to read :param file_document: file document from `root_collection.files` :param session: a - :class:`~pymongo.client_session.AsyncClientSession` to use for all + :class:`~pymongo.client_session.ClientSession` to use for all commands .. versionchanged:: 3.8 @@ -1428,7 +1428,7 @@ from the server. Metadata is fetched when first needed. """ if not isinstance(root_collection, Collection): - raise TypeError("root_collection must be an instance of AsyncCollection") + raise TypeError("root_collection must be an instance of Collection") _disallow_transactions(session) root_collection = _clear_entity_type_registry(root_collection) @@ -1486,7 +1486,7 @@ def __getattr__(self, name: str) -> Any: self.open() # type: ignore[unused-coroutine] elif not self._file: raise InvalidOperation( - "You must call AsyncGridOut.open() before accessing the %s property" % name + "You must call GridOut.open() before accessing the %s property" % name ) if name in self._file: return self._file[name] @@ -1681,13 +1681,13 @@ def writable(self) -> bool: return False def __enter__(self) -> GridOut: - """Makes it possible to use :class:`AsyncGridOut` files - with the async context manager protocol. + """Makes it possible to use :class:`GridOut` files + with the context manager protocol. """ return self def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> Any: - """Makes it possible to use :class:`AsyncGridOut` files - with the async context manager protocol. + """Makes it possible to use :class:`GridOut` files + with the context manager protocol. """ self.close() diff --git a/tools/synchro.py b/tools/synchro.py index 825cfb08a2..f4b0683ce6 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -147,7 +147,7 @@ "pool.py", "topology.py", ] -] +] + [_gridfs_dest_base + f for f in ["grid_file.py"]] def process_files(files: list[str]) -> None: From 1eb50d40d1d21b0ba5abf51ada089f655cdbbf37 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 13 Jun 2024 10:30:02 -0700 Subject: [PATCH 06/11] Merge hello_compat and hello --- pymongo/asynchronous/pool.py | 3 +-- pymongo/compression_support.py | 2 +- pymongo/hello.py | 9 ++++++++- pymongo/hello_compat.py | 26 -------------------------- pymongo/helpers_shared.py | 2 +- pymongo/message.py | 2 +- pymongo/monitoring.py | 3 +-- pymongo/synchronous/pool.py | 3 +-- pymongo/typings.py | 3 ++- test/__init__.py | 2 +- test/asynchronous/__init__.py | 2 +- test/auth_oidc/test_auth_oidc.py | 2 +- test/pymongo_mocks.py | 3 +-- test/synchronous/__init__.py | 2 +- test/test_auth.py | 2 +- test/test_discovery_and_monitoring.py | 1 - test/test_pooling.py | 2 +- test/test_server_selection.py | 2 +- test/test_ssl.py | 2 +- test/test_streaming_protocol.py | 2 +- test/utils.py | 2 +- 21 files changed, 27 insertions(+), 50 deletions(-) delete mode 100644 pymongo/hello_compat.py diff --git a/pymongo/asynchronous/pool.py b/pymongo/asynchronous/pool.py index 12697a36e4..4dd4ad8df3 100644 --- a/pymongo/asynchronous/pool.py +++ b/pymongo/asynchronous/pool.py @@ -62,8 +62,7 @@ WaitQueueTimeoutError, _CertificateError, ) -from pymongo.hello import Hello -from pymongo.hello_compat import HelloCompat +from pymongo.hello import Hello, HelloCompat from pymongo.lock import _ACondition, _ALock, _create_lock from pymongo.logger import ( _CONNECTION_LOGGER, diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py index 7a0f2a36dd..de7b856506 100644 --- a/pymongo/compression_support.py +++
b/pymongo/compression_support.py @@ -16,7 +16,7 @@ import warnings from typing import Any, Iterable, Optional, Union -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.helpers_shared import _SENSITIVE_COMMANDS _IS_SYNC = False diff --git a/pymongo/hello.py b/pymongo/hello.py index 40bd842c0a..89b51980b3 100644 --- a/pymongo/hello.py +++ b/pymongo/hello.py @@ -22,7 +22,6 @@ from bson.objectid import ObjectId from pymongo import common -from pymongo.hello_compat import HelloCompat from pymongo.server_type import SERVER_TYPE from pymongo.typings import ClusterTime, _DocumentType @@ -57,6 +56,14 @@ def _get_server_type(doc: Mapping[str, Any]) -> int: return SERVER_TYPE.Standalone +class HelloCompat: + CMD = "hello" + LEGACY_CMD = "ismaster" + PRIMARY = "isWritablePrimary" + LEGACY_PRIMARY = "ismaster" + LEGACY_ERROR = "not master" + + class Hello(Generic[_DocumentType]): """Parse a hello response from the server. diff --git a/pymongo/hello_compat.py b/pymongo/hello_compat.py deleted file mode 100644 index 9bc8b088c5..0000000000 --- a/pymongo/hello_compat.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2024-present MongoDB, Inc. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""The HelloCompat class, placed here to break circular import issues.""" -from __future__ import annotations - -_IS_SYNC = False - - -class HelloCompat: - CMD = "hello" - LEGACY_CMD = "ismaster" - PRIMARY = "isWritablePrimary" - LEGACY_PRIMARY = "ismaster" - LEGACY_ERROR = "not master" diff --git a/pymongo/helpers_shared.py b/pymongo/helpers_shared.py index 884a008385..c3324e162a 100644 --- a/pymongo/helpers_shared.py +++ b/pymongo/helpers_shared.py @@ -42,7 +42,7 @@ WTimeoutError, _wtimeout_error, ) -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat if TYPE_CHECKING: from pymongo.cursor_shared import _Hint diff --git a/pymongo/message.py b/pymongo/message.py index cf9977ac6f..f6f4d60dd7 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -46,7 +46,7 @@ RawBSONDocument, _inflate_bson, ) -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.monitoring import _EventListeners try: diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py index 87451d5180..1e48905aee 100644 --- a/pymongo/monitoring.py +++ b/pymongo/monitoring.py @@ -190,8 +190,7 @@ def connection_checked_in(self, event): from typing import TYPE_CHECKING, Any, Mapping, Optional, Sequence from bson.objectid import ObjectId -from pymongo.hello import Hello -from pymongo.hello_compat import HelloCompat +from pymongo.hello import Hello, HelloCompat from pymongo.helpers_shared import _SENSITIVE_COMMANDS, _handle_exception from pymongo.typings import _Address, _DocumentOut diff --git a/pymongo/synchronous/pool.py b/pymongo/synchronous/pool.py index 197409e84a..59c4aaec63 100644 --- a/pymongo/synchronous/pool.py +++ b/pymongo/synchronous/pool.py @@ -59,8 +59,7 @@ WaitQueueTimeoutError, _CertificateError, ) -from pymongo.hello import 
Hello -from pymongo.hello_compat import HelloCompat +from pymongo.hello import Hello, HelloCompat from pymongo.lock import _create_lock from pymongo.logger import ( _CONNECTION_LOGGER, diff --git a/pymongo/typings.py b/pymongo/typings.py index e0593517a8..c89f5e2abc 100644 --- a/pymongo/typings.py +++ b/pymongo/typings.py @@ -29,13 +29,14 @@ from bson.typings import _DocumentOut, _DocumentType, _DocumentTypeArg if TYPE_CHECKING: - from pymongo import AsyncMongoClient, MongoClient from pymongo.asynchronous.bulk import _AsyncBulk from pymongo.asynchronous.client_session import AsyncClientSession + from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.asynchronous.pool import AsyncConnection from pymongo.collation import Collation from pymongo.synchronous.bulk import _Bulk from pymongo.synchronous.client_session import ClientSession + from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.pool import Connection _IS_SYNC = False diff --git a/test/__init__.py b/test/__init__.py index f8b27ce193..f45c83b061 100644 --- a/test/__init__.py +++ b/test/__init__.py @@ -47,7 +47,7 @@ from bson.son import SON from pymongo import common, message from pymongo.common import partition_node -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] from pymongo.synchronous.database import Database diff --git a/test/asynchronous/__init__.py b/test/asynchronous/__init__.py index d63ed77232..0a74366ae8 100644 --- a/test/asynchronous/__init__.py +++ b/test/asynchronous/__init__.py @@ -73,7 +73,7 @@ from pymongo.asynchronous.database import AsyncDatabase from pymongo.asynchronous.mongo_client import AsyncMongoClient from pymongo.common import partition_node -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] diff --git a/test/auth_oidc/test_auth_oidc.py b/test/auth_oidc/test_auth_oidc.py index 9f7941fbf7..83e25e685a 100644 --- a/test/auth_oidc/test_auth_oidc.py +++ b/test/auth_oidc/test_auth_oidc.py @@ -36,7 +36,7 @@ from pymongo._gcp_helpers import _get_gcp_response from pymongo.cursor_shared import CursorType from pymongo.errors import AutoReconnect, ConfigurationError, OperationFailure -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.operations import InsertOne from pymongo.synchronous.auth_oidc import ( OIDCCallback, diff --git a/test/pymongo_mocks.py b/test/pymongo_mocks.py index 90914927cb..9cbca169cf 100644 --- a/test/pymongo_mocks.py +++ b/test/pymongo_mocks.py @@ -22,8 +22,7 @@ from pymongo import common from pymongo.errors import AutoReconnect, NetworkTimeout -from pymongo.hello import Hello -from pymongo.hello_compat import HelloCompat +from pymongo.hello import Hello, HelloCompat from pymongo.server_description import ServerDescription from pymongo.synchronous.mongo_client import MongoClient from pymongo.synchronous.monitor import Monitor diff --git a/test/synchronous/__init__.py b/test/synchronous/__init__.py index 1320561c8c..9176b22d1b 100644 --- a/test/synchronous/__init__.py +++ b/test/synchronous/__init__.py @@ -71,7 +71,7 @@ import pymongo.errors from bson.son import SON from pymongo.common import partition_node -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from 
pymongo.server_api import ServerApi from pymongo.ssl_support import HAVE_SSL, _ssl # type:ignore[attr-defined] from pymongo.synchronous.database import Database diff --git a/test/test_auth.py b/test/test_auth.py index 29cac352fd..45d047e7d9 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -37,7 +37,7 @@ from pymongo.asynchronous.auth import HAVE_KERBEROS from pymongo.auth_shared import _build_credentials_tuple from pymongo.errors import OperationFailure -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.read_preferences import ReadPreference from pymongo.saslprep import HAVE_STRINGPREP diff --git a/test/test_discovery_and_monitoring.py b/test/test_discovery_and_monitoring.py index e584c17f4e..ef32afbcd4 100644 --- a/test/test_discovery_and_monitoring.py +++ b/test/test_discovery_and_monitoring.py @@ -49,7 +49,6 @@ OperationFailure, ) from pymongo.hello import Hello, HelloCompat -from pymongo.hello_compat import HelloCompat from pymongo.helpers_shared import _check_command_response, _check_write_command_response from pymongo.monitoring import ServerHeartbeatFailedEvent, ServerHeartbeatStartedEvent from pymongo.server_description import SERVER_TYPE, ServerDescription diff --git a/test/test_pooling.py b/test/test_pooling.py index 3cc544d2ea..aa32f9f774 100644 --- a/test/test_pooling.py +++ b/test/test_pooling.py @@ -26,7 +26,7 @@ from bson.son import SON from pymongo import MongoClient, message, timeout from pymongo.errors import AutoReconnect, ConnectionFailure, DuplicateKeyError -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat sys.path[0:0] = [""] diff --git a/test/test_server_selection.py b/test/test_server_selection.py index 42bd5a095d..d3526617f6 100644 --- a/test/test_server_selection.py +++ b/test/test_server_selection.py @@ -20,7 +20,7 @@ from pymongo import MongoClient, ReadPreference from pymongo.errors import ServerSelectionTimeoutError -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.operations import _Op from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings diff --git a/test/test_ssl.py b/test/test_ssl.py index b123accdf6..3b307df39e 100644 --- a/test/test_ssl.py +++ b/test/test_ssl.py @@ -33,7 +33,7 @@ from pymongo import MongoClient, ssl_support from pymongo.errors import ConfigurationError, ConnectionFailure, OperationFailure -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.ssl_support import HAVE_SSL, _ssl, get_ssl_context from pymongo.write_concern import WriteConcern diff --git a/test/test_streaming_protocol.py b/test/test_streaming_protocol.py index 97618e105e..44e673822a 100644 --- a/test/test_streaming_protocol.py +++ b/test/test_streaming_protocol.py @@ -30,7 +30,7 @@ ) from pymongo import monitoring -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat class TestStreamingProtocol(IntegrationTest): diff --git a/test/utils.py b/test/utils.py index 97b39b38e7..98666e271d 100644 --- a/test/utils.py +++ b/test/utils.py @@ -39,7 +39,7 @@ from pymongo import AsyncMongoClient, monitoring, operations, read_preferences from pymongo.cursor_shared import CursorType from pymongo.errors import ConfigurationError, OperationFailure -from pymongo.hello_compat import HelloCompat +from pymongo.hello import HelloCompat from pymongo.helpers_shared import _SENSITIVE_COMMANDS from pymongo.lock import 
_create_lock from pymongo.monitoring import ( From 74d7f61f1cacf31eabe26d53eb67481410b48579 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Thu, 20 Jun 2024 16:06:55 -0700 Subject: [PATCH 07/11] Address review --- pymongo/asynchronous/mongo_client.py | 2 +- pymongo/message.py | 2140 +++++++++++++------------- pymongo/synchronous/server.py | 4 +- test/unified_format.py | 2 +- test/utils_selection_tests.py | 3 +- tools/synchro.py | 1 + 6 files changed, 1076 insertions(+), 1076 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 4107efbf06..407beff5b8 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -1883,7 +1883,7 @@ async def _cleanup_cursor_lock( cursor_id: int, address: Optional[_CursorAddress], conn_mgr: _ConnectionManager, - session: Optional[ClientSession], + session: Optional[AsyncClientSession], explicit_session: bool, ) -> None: """Cleanup a cursor from cursor.close() using a lock. diff --git a/pymongo/message.py b/pymongo/message.py index f6f4d60dd7..bcb4ce10ec 100644 --- a/pymongo/message.py +++ b/pymongo/message.py @@ -1,4 +1,4 @@ -# Copyright 2024-present MongoDB, Inc. +# Copyright 2009-present MongoDB, Inc. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ ) import bson -from bson import CodecOptions, _dict_to_bson +from bson import CodecOptions, _dict_to_bson, _make_c_string from bson.int64 import Int64 from bson.raw_bson import ( _RAW_ARRAY_BSON_OPTIONS, @@ -267,1267 +267,1267 @@ def _gen_get_more_command( return cmd -class _OpReply: - """A MongoDB OP_REPLY response message.""" +_pack_compression_header = struct.Struct("<iiiiiiB").pack +_COMPRESSION_HEADER_SIZE = 25 - __slots__ = ("flags", "cursor_id", "number_returned", "documents") - UNPACK_FROM = struct.Struct("<iqii").unpack_from - OP_CODE = 1 +def _compress( + operation: int, data: bytes, ctx: Union[SnappyContext, ZlibContext, ZstdContext] +) -> tuple[int, bytes]: + """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" + compressed = ctx.compress(data) + request_id = _randint() - def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes): - self.flags = flags - self.cursor_id = Int64(cursor_id) - self.number_returned = number_returned - self.documents = documents + header = _pack_compression_header( + _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length + request_id, # Request id + 0, # responseTo + 2012, # operation id + operation, # original operation id + len(data), # uncompressed message length + ctx.compressor_id, + ) # compressor id + return request_id, header + compressed - def raw_response( - self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None - ) -> list[bytes]: - """Check the response header from the database, without decoding BSON. - Check the response for errors and unpack. - Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or - OperationFailure. - """ - if self.flags & 1: - # Shouldn't get this response if we aren't doing a getMore - if cursor_id is None: - raise ProtocolError("No cursor id for getMore operation") +_pack_header = struct.Struct("<iiii").pack +def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]: + """Takes message data and adds a message header based on the operation. - # Fake a getMore command response. OP_GET_MORE provides no - # document. - msg = "Cursor not found, cursor id: %d" % (cursor_id,) - errobj = {"ok": 0, "errmsg": msg, "code": 43} - raise CursorNotFound(msg, 43, errobj) - elif self.flags & 2: - error_object: dict = bson.BSON(self.documents).decode() - # Fake the ok field if it doesn't exist.
- error_object.setdefault("ok", 0) - if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): - raise NotPrimaryError(error_object["$err"], error_object) - elif error_object.get("code") == 50: - default_msg = "operation exceeded time limit" - raise ExecutionTimeout( - error_object.get("$err", default_msg), error_object.get("code"), error_object - ) - raise OperationFailure( - "database error: %s" % error_object.get("$err"), - error_object.get("code"), - error_object, - ) - if self.documents: - return [self.documents] - return [] + Returns the resultant message string. + """ + rid = _randint() + message = _pack_header(16 + len(data), rid, 0, operation) + return rid, message + data - def unpack_response( - self, - cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, - user_fields: Optional[Mapping[str, Any]] = None, - legacy_response: bool = False, - ) -> list[dict[str, Any]]: - """Unpack a response from the database and decode the BSON document(s). - Check the response for errors and unpack, returning a dictionary - containing the response data. +_pack_int = struct.Struct("<i").pack +_pack_op_msg_flags_type = struct.Struct("<IB").pack +_pack_byte = struct.Struct("<B").pack +def _op_msg_no_header( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[list[Mapping[str, Any]]], + opts: CodecOptions, +) -> tuple[bytes, int, int]: + """Get an OP_MSG message. Note: this method handles multiple documents in a type one payload but it does not perform batch splitting and the total message size is only checked *after* generating the entire message. """ + # Encode the command document in payload 0 without checking keys. + encoded = _dict_to_bson(command, False, opts) + flags_type = _pack_op_msg_flags_type(flags, 0) + total_size = len(encoded) + max_doc_size = 0 + if identifier and docs is not None: + type_one = _pack_byte(1) + cstring = _make_c_string(identifier) + encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] + size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 + encoded_size = _pack_int(size) + total_size += size + max_doc_size = max(len(doc) for doc in encoded_docs) + data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] + else: + data = [flags_type, encoded] + return b"".join(data), total_size, max_doc_size +def _op_msg_compressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[list[Mapping[str, Any]]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> tuple[int, bytes, int, int]: + """Internal compressed OP_MSG message helper.""" + msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) + rid, msg = _compress(2013, msg, ctx) + return rid, msg, total_size, max_bson_size
- flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) - documents = msg[20:] - return cls(flags, cursor_id, number_returned, documents) +def _op_msg_uncompressed( + flags: int, + command: Mapping[str, Any], + identifier: str, + docs: Optional[list[Mapping[str, Any]]], + opts: CodecOptions, +) -> tuple[int, bytes, int, int]: + """Internal OP_MSG message helper.""" + data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) + request_id, op_message = __pack_message(2013, data) + return request_id, op_message, total_size, max_bson_size -class _OpMsg: - """A MongoDB OP_MSG response message.""" +if _use_c: + _op_msg_uncompressed = _cmessage._op_msg - __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") - UNPACK_FROM = struct.Struct("<IBi").unpack_from - OP_CODE = 2013 +def _op_msg( + flags: int, + command: MutableMapping[str, Any], + dbname: str, + read_preference: Optional[_ServerMode], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> tuple[int, bytes, int, int]: + """Get an OP_MSG message.""" + command["$db"] = dbname + # getMore commands do not send $readPreference. + if read_preference is not None and "$readPreference" not in command: + # Only send $readPreference if it's not primary (the default). + if read_preference.mode: + command["$readPreference"] = read_preference.document + name = next(iter(command)) + try: + identifier = _FIELD_MAP[name] + docs = command.pop(identifier) + except KeyError: + identifier = "" + docs = None + try: + if ctx: + return _op_msg_compressed(flags, command, identifier, docs, opts, ctx) + return _op_msg_uncompressed(flags, command, identifier, docs, opts) + finally: + # Add the field back to the command. + if identifier: + command[identifier] = docs - # Flag bits. - CHECKSUM_PRESENT = 1 - MORE_TO_COME = 1 << 1 - EXHAUST_ALLOWED = 1 << 16 # Only present on requests. - def __init__(self, flags: int, payload_document: bytes): - self.flags = flags - self.payload_document = payload_document +def _query_impl( + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]], + opts: CodecOptions, +) -> tuple[bytes, int]: + """Get an OP_QUERY message.""" + encoded = _dict_to_bson(query, False, opts) + if field_selector: + efs = _dict_to_bson(field_selector, False, opts) + else: + efs = b"" + max_bson_size = max(len(encoded), len(efs)) + return ( + b"".join( + [ + _pack_int(options), + bson._make_c_string(collection_name), + _pack_int(num_to_skip), + _pack_int(num_to_return), + encoded, + efs, + ] + ), + max_bson_size, + ) - def raw_response( - self, - cursor_id: Optional[int] = None, - user_fields: Optional[Mapping[str, Any]] = {}, - ) -> list[Mapping[str, Any]]: - """ - cursor_id is ignored - user_fields is used to determine which fields must not be decoded - """ - inflated_response = bson._decode_selective( - RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS + +def _query_compressed( + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> tuple[int, bytes, int]: + """Internal compressed query message helper.""" + op_query, max_bson_size = _query_impl( + options, collection_name, num_to_skip, num_to_return, query, field_selector, opts ) + rid, msg = _compress(2004, op_query, ctx) + return rid, msg, max_bson_size - return [inflated_response] def _query_uncompressed( + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]],
+ opts: CodecOptions, +) -> tuple[int, bytes, int]: + """Internal query message helper.""" + op_query, max_bson_size = _query_impl( + options, collection_name, num_to_skip, num_to_return, query, field_selector, opts ) + rid, msg = __pack_message(2004, op_query) + return rid, msg, max_bson_size - def unpack_response( - self, - cursor_id: Optional[int] = None, - codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, - user_fields: Optional[Mapping[str, Any]] = None, - legacy_response: bool = False, - ) -> list[dict[str, Any]]: - """Unpack a OP_MSG command response. :param cursor_id: Ignored, for compatibility with _OpReply. :param codec_options: an instance of :class:`~bson.codec_options.CodecOptions` :param user_fields: Response fields that should be decoded using the TypeDecoders from codec_options, passed to bson._decode_all_selective. """ - # If _OpMsg is in-use, this cannot be a legacy response. - assert not legacy_response - return bson._decode_all_selective(self.payload_document, codec_options, user_fields) +if _use_c: + _query_uncompressed = _cmessage._query_message +def _query( + options: int, + collection_name: str, + num_to_skip: int, + num_to_return: int, + query: Mapping[str, Any], + field_selector: Optional[Mapping[str, Any]], + opts: CodecOptions, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> tuple[int, bytes, int]: + """Get a **query** message.""" + if ctx: + return _query_compressed( + options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx ) + return _query_uncompressed( + options, collection_name, num_to_skip, num_to_return, query, field_selector, opts ) +_pack_long_long = struct.Struct("<q").pack - def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: - """Unpack a command response.""" - return self.unpack_response(codec_options=codec_options)[0] def _get_more_impl(collection_name: str, num_to_return: int, cursor_id: int) -> bytes: + """Get an OP_GET_MORE message.""" + return b"".join( + [ + _ZERO_32, + bson._make_c_string(collection_name), + _pack_int(num_to_return), + _pack_long_long(cursor_id), + ] ) - def raw_command_response(self) -> bytes: - """Return the bytes of the command response.""" - return self.payload_document - @property - def more_to_come(self) -> bool: - """Is the moreToCome bit set on this response?""" - return bool(self.flags & self.MORE_TO_COME) - @classmethod - def unpack(cls, msg: bytes) -> _OpMsg: - """Construct an _OpMsg from raw bytes.""" - flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) - if flags != 0: - if flags & cls.CHECKSUM_PRESENT: - raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") +def _get_more_compressed( + collection_name: str, + num_to_return: int, + cursor_id: int, + ctx: Union[SnappyContext, ZlibContext, ZstdContext], +) -> tuple[int, bytes]: + """Internal compressed getMore message helper.""" + return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx) - if flags ^ cls.MORE_TO_COME: - raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") - if first_payload_type != 0: - raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") - if len(msg) != first_payload_size + 5: - raise ProtocolError("Unsupported OP_MSG reply: >1 section") +def _get_more_uncompressed( + collection_name: str, num_to_return: int, cursor_id: int +) -> tuple[int, bytes]: + """Internal getMore message helper.""" + return __pack_message(2005, _get_more_impl(collection_name, num_to_return,
cursor_id)) - payload_document = msg[5:] - return cls(flags, payload_document) + +if _use_c: + _get_more_uncompressed = _cmessage._get_more_message -_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { - _OpReply.OP_CODE: _OpReply.unpack, - _OpMsg.OP_CODE: _OpMsg.unpack, +def _get_more( + collection_name: str, + num_to_return: int, + cursor_id: int, + ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, +) -> tuple[int, bytes]: + """Get a **getMore** message.""" + if ctx: + return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx) + return _get_more_uncompressed(collection_name, num_to_return, cursor_id) + + +# OP_MSG ------------------------------------------------------------- + + +_OP_MSG_MAP = { + _INSERT: b"documents\x00", + _UPDATE: b"updates\x00", + _DELETE: b"deletes\x00", } -class _Query: - """A query operation.""" +class _BulkWriteContext: + """A wrapper around AsyncConnection for use with write splitting functions.""" __slots__ = ( - "flags", - "db", - "coll", - "ntoskip", - "spec", - "fields", - "codec_options", - "read_preference", - "limit", - "batch_size", + "db_name", + "conn", + "op_id", "name", - "read_concern", - "collation", + "field", + "publish", + "start_time", + "listeners", "session", - "client", - "allow_disk_use", - "_as_command", - "exhaust", + "compress", + "op_type", + "codec", ) - # For compatibility with the _GetMore class. - conn_mgr = None - cursor_id = None - def __init__( self, - flags: int, - db: str, - coll: str, - ntoskip: int, - spec: Mapping[str, Any], - fields: Optional[Mapping[str, Any]], - codec_options: CodecOptions, - read_preference: _ServerMode, - limit: int, - batch_size: int, - read_concern: ReadConcern, - collation: Optional[Mapping[str, Any]], - session: Optional[_AgnosticClientSession], - client: _AgnosticMongoClient, - allow_disk_use: Optional[bool], - exhaust: bool, + database_name: str, + cmd_name: str, + conn: _AgnosticConnection, + operation_id: int, + listeners: _EventListeners, + session: _AgnosticClientSession, + op_type: int, + codec: CodecOptions, ): - self.flags = flags - self.db = db - self.coll = coll - self.ntoskip = ntoskip - self.spec = spec - self.fields = fields - self.codec_options = codec_options - self.read_preference = read_preference - self.read_concern = read_concern - self.limit = limit - self.batch_size = batch_size - self.collation = collation + self.db_name = database_name + self.conn = conn + self.op_id = operation_id + self.listeners = listeners + self.publish = listeners.enabled_for_commands + self.name = cmd_name + self.field = _FIELD_MAP[self.name] + self.start_time = datetime.datetime.now() self.session = session - self.client = client - self.allow_disk_use = allow_disk_use - self.name = "find" - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust + self.compress = bool(conn.compression_context) + self.op_type = op_type + self.codec = codec - def reset(self) -> None: - self._as_command = None + def batch_command( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]]]: + namespace = self.db_name + ".$cmd" + request_id, msg, to_send = _do_batched_op_msg( + namespace, self.op_type, cmd, docs, self.codec, self + ) + if not to_send: + raise InvalidOperation("cannot do an empty bulk write") + return request_id, msg, to_send - def namespace(self) -> str: - return f"{self.db}.{self.coll}" + @property + def max_bson_size(self) -> int: + """A 
proxy for SockInfo.max_bson_size.""" + return self.conn.max_bson_size - def use_command(self, conn: _AgnosticConnection) -> bool: - use_find_cmd = False - if not self.exhaust: - use_find_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_find_cmd = True - elif not self.read_concern.ok_for_legacy: - raise ConfigurationError( - "read concern level of %s is not valid " - "with a max wire version of %d." % (self.read_concern.level, conn.max_wire_version) - ) + @property + def max_message_size(self) -> int: + """A proxy for SockInfo.max_message_size.""" + if self.compress: + # Subtract 16 bytes for the message header. + return self.conn.max_message_size - 16 + return self.conn.max_message_size - conn.validate_session(self.client, self.session) # type: ignore[arg-type] - return use_find_cmd + @property + def max_write_batch_size(self) -> int: + """A proxy for SockInfo.max_write_batch_size.""" + return self.conn.max_write_batch_size - def update_command(self, cmd: dict[str, Any]) -> None: - self._as_command = cmd, self.db + @property + def max_split_size(self) -> int: + """The maximum size of a BSON command before batch splitting.""" + return self.max_bson_size - def as_command( - self, conn: _AgnosticConnection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a find command document for this query.""" - # We use the command twice: on the wire and for command monitoring. - # Generate it once, for speed and to avoid repeating side-effects. - if self._as_command is not None: - return self._as_command + def _start( + self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] + ) -> MutableMapping[str, Any]: + """Publish a CommandStartedEvent.""" + cmd[self.field] = docs + self.listeners.publish_command_start( + cmd, + self.db_name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + ) + return cmd - explain = "$explain" in self.spec - cmd: dict[str, Any] = _gen_find_command( - self.coll, - self.spec, - self.fields, - self.ntoskip, - self.limit, - self.batch_size, - self.flags, - self.read_concern, - self.collation, - self.session, - self.allow_disk_use, + def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None: + """Publish a CommandSucceededEvent.""" + self.listeners.publish_command_success( + duration, + reply, + self.name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + database_name=self.db_name, ) - if explain: - self.name = "explain" - cmd = {"explain": cmd} - conn.add_server_api(cmd) - if self.session: - self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] - # Explain does not support readConcern. - if not explain and not self.session.in_transaction: - self.session._update_read_concern(cmd, conn) # type: ignore[arg-type] - conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] - # Support CSOT - if apply_timeout: - conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type] - self._as_command = cmd, self.db - return self._as_command - def get_message( - self, read_preference: _ServerMode, conn: _AgnosticConnection, use_cmd: bool = False - ) -> tuple[int, bytes, int]: - """Get a query message, possibly setting the secondaryOk bit.""" - # Use the read_preference decided by _socket_from_server. 
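The limit properties above mirror what the server advertises in its hello response: maxBsonObjectSize, maxMessageSizeBytes, and maxWriteBatchSize. A rough standalone sketch of the same bookkeeping, using illustrative hard-coded limits in place of a live connection (the constant below is an assumption, not a server value):

    # Sketch only: mirrors _BulkWriteContext.max_message_size, which
    # reserves 16 bytes for the OP_COMPRESSED header that _compress()
    # prepends when a compression context is active.
    MAX_MESSAGE_SIZE = 48_000_000  # assumed maxMessageSizeBytes

    def effective_message_limit(compressing: bool) -> int:
        return MAX_MESSAGE_SIZE - 16 if compressing else MAX_MESSAGE_SIZE
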
- self.read_preference = read_preference - if read_preference.mode: - # Set the secondaryOk bit. - flags = self.flags | 4 - else: - flags = self.flags + def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None: + """Publish a CommandFailedEvent.""" + self.listeners.publish_command_failure( + duration, + failure, + self.name, + request_id, + self.conn.address, + self.conn.server_connection_id, + self.op_id, + self.conn.service_id, + database_name=self.db_name, + ) - ns = self.namespace() - spec = self.spec - if use_cmd: - spec = self.as_command(conn)[0] - request_id, msg, size, _ = _op_msg( - 0, - spec, - self.db, - read_preference, - self.codec_options, - ctx=conn.compression_context, - ) - return request_id, msg, size +class _EncryptedBulkWriteContext(_BulkWriteContext): + __slots__ = () - # OP_QUERY treats ntoreturn of -1 and 1 the same, return - # one document and close the cursor. We have to use 2 for - # batch size if 1 is specified. - ntoreturn = self.batch_size == 1 and 2 or self.batch_size - if self.limit: - if ntoreturn: - ntoreturn = min(self.limit, ntoreturn) - else: - ntoreturn = self.limit + def batch_command( + self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] + ) -> tuple[int, dict[str, Any], list[Mapping[str, Any]]]: + namespace = self.db_name + ".$cmd" + msg, to_send = _encode_batched_write_command( + namespace, self.op_type, cmd, docs, self.codec, self + ) + if not to_send: + raise InvalidOperation("cannot do an empty bulk write") - if conn.is_mongos: - assert isinstance(spec, MutableMapping) - spec = _maybe_add_read_preference(spec, read_preference) + # Chop off the OP_QUERY header to get a properly batched write command. + cmd_start = msg.index(b"\x00", 4) + 9 + outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) + return -1, outgoing, to_send - return _query( - flags, - ns, - self.ntoskip, - ntoreturn, - spec, - None if use_cmd else self.fields, - self.codec_options, - ctx=conn.compression_context, - ) + @property + def max_split_size(self) -> int: + """Reduce the batch splitting size.""" + return _MAX_SPLIT_SIZE_ENC -class _GetMore: - """A getmore operation.""" +def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn: + """Internal helper for raising DocumentTooLarge.""" + if operation == "insert": + raise DocumentTooLarge( + "BSON document too large (%d bytes)" + " - the connected server supports" + " BSON document sizes up to %d" + " bytes." % (doc_size, max_size) + ) + else: + # There's nothing intelligent we can say + # about size for update and delete + raise DocumentTooLarge(f"{operation!r} command document too large") - __slots__ = ( - "db", - "coll", - "ntoreturn", - "cursor_id", - "max_await_time_ms", - "codec_options", - "read_preference", - "session", - "client", - "conn_mgr", - "_as_command", - "exhaust", - "comment", - ) - name = "getMore" +# From the Client Side Encryption spec: +# Because automatic encryption increases the size of commands, the driver +# MUST split bulk writes at a reduced size limit before undergoing automatic +# encryption. The write payload MUST be split at 2MiB (2097152). 
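The 2 MiB figure cited in the spec comment above is bound to _MAX_SPLIT_SIZE_ENC just below. A toy illustration of why the reduced split size matters, assuming uniformly sized documents (real batches also respect maxWriteBatchSize and the per-document BSON limit):

    def batches_needed(num_docs: int, doc_size: int, split_size: int = 2_097_152) -> int:
        # Ceiling division: how many size-limited batches a bulk write becomes.
        per_batch = max(1, split_size // doc_size)
        return -(-num_docs // per_batch)

    print(batches_needed(10_000, 1024))              # 5 batches at the 2 MiB encrypted limit
    print(batches_needed(10_000, 1024, 16_777_216))  # 1 batch at a plain 16 MiB limit
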
+_MAX_SPLIT_SIZE_ENC = 2097152 - def __init__( - self, - db: str, - coll: str, - ntoreturn: int, - cursor_id: int, - codec_options: CodecOptions, - read_preference: _ServerMode, - session: Optional[_AgnosticClientSession], - client: _AgnosticMongoClient, - max_await_time_ms: Optional[int], - conn_mgr: Any, - exhaust: bool, - comment: Any, - ): - self.db = db - self.coll = coll - self.ntoreturn = ntoreturn - self.cursor_id = cursor_id - self.codec_options = codec_options - self.read_preference = read_preference - self.session = session - self.client = client - self.max_await_time_ms = max_await_time_ms - self.conn_mgr = conn_mgr - self._as_command: Optional[tuple[dict[str, Any], str]] = None - self.exhaust = exhaust - self.comment = comment - def reset(self) -> None: - self._as_command = None +def _batched_op_msg_impl( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> tuple[list[Mapping[str, Any]], int]: + """Create a batched OP_MSG write.""" + max_bson_size = ctx.max_bson_size + max_write_batch_size = ctx.max_write_batch_size + max_message_size = ctx.max_message_size - def namespace(self) -> str: - return f"{self.db}.{self.coll}" + flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" + # Flags + buf.write(flags) - def use_command(self, conn: _AgnosticConnection) -> bool: - use_cmd = False - if not self.exhaust: - use_cmd = True - elif conn.max_wire_version >= 8: - # OP_MSG supports exhaust on MongoDB 4.2+ - use_cmd = True - - conn.validate_session(self.client, self.session) # type: ignore[arg-type] - return use_cmd + # Type 0 Section + buf.write(b"\x00") + buf.write(_dict_to_bson(command, False, opts)) - def update_command(self, cmd: dict[str, Any]) -> None: - self._as_command = cmd, self.db + # Type 1 Section + buf.write(b"\x01") + size_location = buf.tell() + # Save space for size + buf.write(b"\x00\x00\x00\x00") + try: + buf.write(_OP_MSG_MAP[operation]) + except KeyError: + raise InvalidOperation("Unknown command") from None - def as_command( - self, conn: _AgnosticConnection, apply_timeout: bool = False - ) -> tuple[dict[str, Any], str]: - """Return a getMore command document for this query.""" - # See _Query.as_command for an explanation of this caching. - if self._as_command is not None: - return self._as_command + to_send = [] + idx = 0 + for doc in docs: + # Encode the current operation + value = _dict_to_bson(doc, False, opts) + doc_length = len(value) + new_message_size = buf.tell() + doc_length + # Does first document exceed max_message_size? + doc_too_large = idx == 0 and (new_message_size > max_message_size) + # When OP_MSG is used unacknowledged we have to check + # document size client side or applications won't be notified. + # Otherwise we let the server deal with documents that are too large + # since ordered=False causes those documents to be skipped instead of + # halting the bulk write operation. + unacked_doc_too_large = not ack and (doc_length > max_bson_size) + if doc_too_large or unacked_doc_too_large: + write_op = list(_FIELD_MAP.keys())[operation] + _raise_document_too_large(write_op, len(value), max_bson_size) + # We have enough data, return this batch. + if new_message_size > max_message_size: + break + buf.write(value) + to_send.append(doc) + idx += 1 + # We have enough documents, return this batch. 
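_batched_op_msg_impl stops filling a batch either because the next document would push the message past max_message_size (the break above) or because the batch has hit max_write_batch_size (the check just below). The OP_MSG layout it writes can be sketched independently; this is a simplified reconstruction for illustration, not the code the patch ships:

    import struct
    import bson  # PyMongo's bundled BSON package

    def tiny_op_msg_body(command: dict, identifier: str, docs: list[dict]) -> bytes:
        buf = bytearray()
        buf += struct.pack("<I", 0)  # flag bits: 0 for an acknowledged write
        buf += b"\x00" + bson.encode(command)  # payload type 0: the command itself
        section = identifier.encode("utf8") + b"\x00"  # e.g. b"documents\x00"
        section += b"".join(bson.encode(d) for d in docs)
        # Payload type 1: int32 size (including itself), identifier cstring, documents.
        buf += b"\x01" + struct.pack("<i", 4 + len(section)) + section
        return bytes(buf)
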
+ if idx == max_write_batch_size: + break - cmd: dict[str, Any] = _gen_get_more_command( - self.cursor_id, - self.coll, - self.ntoreturn, - self.max_await_time_ms, - self.comment, - conn, - ) - if self.session: - self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] - conn.add_server_api(cmd) - conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] - # Support CSOT - if apply_timeout: - conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type] - self._as_command = cmd, self.db - return self._as_command + # Write type 1 section size + length = buf.tell() + buf.seek(size_location) + buf.write(_pack_int(length - size_location)) - def get_message( - self, dummy0: Any, conn: _AgnosticConnection, use_cmd: bool = False - ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: - """Get a getmore message.""" - ns = self.namespace() - ctx = conn.compression_context + return to_send, length - if use_cmd: - spec = self.as_command(conn)[0] - if self.conn_mgr and self.exhaust: - flags = _OpMsg.EXHAUST_ALLOWED - else: - flags = 0 - request_id, msg, size, _ = _op_msg( - flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context - ) - return request_id, msg, size - return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) +def _encode_batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[bytes, list[Mapping[str, Any]]]: + """Encode the next batched insert, update, or delete operation + as OP_MSG. + """ + buf = _BytesIO() + to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) + return buf.getvalue(), to_send -class _RawBatchQuery(_Query): - def use_command(self, conn: _AgnosticConnection) -> bool: - # Compatibility checks. - super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False +if _use_c: + _encode_batched_op_msg = _cmessage._encode_batched_op_msg -class _RawBatchGetMore(_GetMore): - def use_command(self, conn: _AgnosticConnection) -> bool: - # Compatibility checks. - super().use_command(conn) - if conn.max_wire_version >= 8: - # MongoDB 4.2+ supports exhaust over OP_MSG - return True - elif not self.exhaust: - return True - return False +def _batched_op_msg_compressed( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """Create the next batched insert, update, or delete operation + with OP_MSG, compressed. 
+ """ + data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) -class _CursorAddress(tuple): - """The server address (host, port) of a cursor, with namespace property.""" + assert ctx.conn.compression_context is not None + request_id, msg = _compress(2013, data, ctx.conn.compression_context) + return request_id, msg, to_send - __namespace: Any - def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: - self = tuple.__new__(cls, address) - self.__namespace = namespace - return self +def _batched_op_msg( + operation: int, + command: Mapping[str, Any], + docs: list[Mapping[str, Any]], + ack: bool, + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """OP_MSG implementation entry point.""" + buf = _BytesIO() - @property - def namespace(self) -> str: - """The namespace this cursor.""" - return self.__namespace + # Save space for message length and request id + buf.write(_ZERO_64) + # responseTo, opCode + buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") - def __hash__(self) -> int: - # Two _CursorAddress instances with different namespaces - # must not hash the same. - return ((*self, self.__namespace)).__hash__() + to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - def __eq__(self, other: object) -> bool: - if isinstance(other, _CursorAddress): - return tuple(self) == tuple(other) and self.namespace == other.namespace - return NotImplemented + # Header - request id and message length + buf.seek(4) + request_id = _randint() + buf.write(_pack_int(request_id)) + buf.seek(0) + buf.write(_pack_int(length)) - def __ne__(self, other: object) -> bool: - return not self == other + return request_id, buf.getvalue(), to_send -_pack_compression_header = struct.Struct(" tuple[int, bytes]: - """Takes message data, compresses it, and adds an OP_COMPRESSED header.""" - compressed = ctx.compress(data) - request_id = _randint() +def _do_batched_op_msg( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], + opts: CodecOptions, + ctx: _BulkWriteContext, +) -> tuple[int, bytes, list[Mapping[str, Any]]]: + """Create the next batched insert, update, or delete operation + using OP_MSG. + """ + command["$db"] = namespace.split(".", 1)[0] + if "writeConcern" in command: + ack = bool(command["writeConcern"].get("w", 1)) + else: + ack = True + if ctx.conn.compression_context: + return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) + return _batched_op_msg(operation, command, docs, ack, opts, ctx) - header = _pack_compression_header( - _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length - request_id, # Request id - 0, # responseTo - 2012, # operation id - operation, # original operation id - len(data), # uncompressed message length - ctx.compressor_id, - ) # compressor id - return request_id, header + compressed +# End OP_MSG ----------------------------------------------------- -_pack_header = struct.Struct(" tuple[bytes, list[Mapping[str, Any]]]: + """Encode the next batched insert, update, or delete command.""" + buf = _BytesIO() -def __pack_message(operation: int, data: bytes) -> tuple[int, bytes]: - """Takes message data and adds a message header based on the operation. + to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) + return buf.getvalue(), to_send - Returns the resultant message string. 
- """ - rid = _randint() - message = _pack_header(16 + len(data), rid, 0, operation) - return rid, message + data +if _use_c: + _encode_batched_write_command = _cmessage._encode_batched_write_command -_pack_int = struct.Struct(" tuple[bytes, int, int]: - """Get a OP_MSG message. - - Note: this method handles multiple documents in a type one payload but - it does not perform batch splitting and the total message size is - only checked *after* generating the entire message. - """ - # Encode the command document in payload 0 without checking keys. - encoded = _dict_to_bson(command, False, opts) - flags_type = _pack_op_msg_flags_type(flags, 0) - total_size = len(encoded) - max_doc_size = 0 - if identifier and docs is not None: - type_one = _pack_byte(1) - cstring = bson._make_c_string(identifier) - encoded_docs = [_dict_to_bson(doc, False, opts) for doc in docs] - size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4 - encoded_size = _pack_int(size) - total_size += size - max_doc_size = max(len(doc) for doc in encoded_docs) - data = [flags_type, encoded, type_one, encoded_size, cstring, *encoded_docs] - else: - data = [flags_type, encoded] - return b"".join(data), total_size, max_doc_size - - -def _op_msg_compressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes, int, int]: - """Internal OP_MSG message helper.""" - msg, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - rid, msg = _compress(2013, msg, ctx) - return rid, msg, total_size, max_bson_size - - -def _op_msg_uncompressed( - flags: int, - command: Mapping[str, Any], - identifier: str, - docs: Optional[list[Mapping[str, Any]]], +def _batched_write_command_impl( + namespace: str, + operation: int, + command: MutableMapping[str, Any], + docs: list[Mapping[str, Any]], opts: CodecOptions, -) -> tuple[int, bytes, int, int]: - """Internal compressed OP_MSG message helper.""" - data, total_size, max_bson_size = _op_msg_no_header(flags, command, identifier, docs, opts) - request_id, op_message = __pack_message(2013, data) - return request_id, op_message, total_size, max_bson_size - + ctx: _BulkWriteContext, + buf: _BytesIO, +) -> tuple[list[Mapping[str, Any]], int]: + """Create a batched OP_QUERY write command.""" + max_bson_size = ctx.max_bson_size + max_write_batch_size = ctx.max_write_batch_size + # Max BSON object size + 16k - 2 bytes for ending NUL bytes. + # Server guarantees there is enough room: SERVER-10643. + max_cmd_size = max_bson_size + _COMMAND_OVERHEAD + max_split_size = ctx.max_split_size -if _use_c: - _op_msg_uncompressed = _cmessage._op_msg + # No options + buf.write(_ZERO_32) + # Namespace as C string + buf.write(namespace.encode("utf8")) + buf.write(_ZERO_8) + # Skip: 0, Limit: -1 + buf.write(_SKIPLIM) + # Where to write command document length + command_start = buf.tell() + buf.write(bson.encode(command)) -def _op_msg( - flags: int, - command: MutableMapping[str, Any], - dbname: str, - read_preference: Optional[_ServerMode], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes, int, int]: - """Get a OP_MSG message.""" - command["$db"] = dbname - # getMore commands do not send $readPreference. - if read_preference is not None and "$readPreference" not in command: - # Only send $readPreference if it's not primary (the default). 
- if read_preference.mode: - command["$readPreference"] = read_preference.document - name = next(iter(command)) + # Start of payload + buf.seek(-1, 2) + # Work around some Jython weirdness. + buf.truncate() try: - identifier = _FIELD_MAP[name] - docs = command.pop(identifier) + buf.write(_OP_MAP[operation]) except KeyError: - identifier = "" - docs = None - try: - if ctx: - return _op_msg_compressed(flags, command, identifier, docs, opts, ctx) - return _op_msg_uncompressed(flags, command, identifier, docs, opts) - finally: - # Add the field back to the command. - if identifier: - command[identifier] = docs - + raise InvalidOperation("Unknown command") from None -def _query_impl( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, -) -> tuple[bytes, int]: - """Get an OP_QUERY message.""" - encoded = _dict_to_bson(query, False, opts) - if field_selector: - efs = _dict_to_bson(field_selector, False, opts) - else: - efs = b"" - max_bson_size = max(len(encoded), len(efs)) - return ( - b"".join( - [ - _pack_int(options), - bson._make_c_string(collection_name), - _pack_int(num_to_skip), - _pack_int(num_to_return), - encoded, - efs, - ] - ), - max_bson_size, - ) + # Where to write list document length + list_start = buf.tell() - 4 + to_send = [] + idx = 0 + for doc in docs: + # Encode the current operation + key = str(idx).encode("utf8") + value = _dict_to_bson(doc, False, opts) + # Is there enough room to add this document? max_cmd_size accounts for + # the two trailing null bytes. + doc_too_large = len(value) > max_cmd_size + if doc_too_large: + write_op = list(_FIELD_MAP.keys())[operation] + _raise_document_too_large(write_op, len(value), max_bson_size) + enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size + enough_documents = idx >= max_write_batch_size + if enough_data or enough_documents: + break + buf.write(_BSONOBJ) + buf.write(key) + buf.write(_ZERO_8) + buf.write(value) + to_send.append(doc) + idx += 1 + # Finalize the current OP_QUERY message. 
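# --- Editor's aside: illustrative sketch, not part of the patch. The loop
# above closes a batch on either of two limits: accumulated bytes reaching
# max_split_size, or the operation count reaching max_write_batch_size.
# A stand-alone model of that splitting rule over pre-encoded documents
# (limits are made-up example values, and the per-document key bytes that
# the real code also counts are ignored here):
def split_batches(encoded_docs, max_split_size, max_write_batch_size):
    batches, current, size = [], [], 0
    for doc in encoded_docs:
        if current and (size + len(doc) >= max_split_size
                        or len(current) >= max_write_batch_size):
            batches.append(current)
            current, size = [], 0
        current.append(doc)
        size += len(doc)
    if current:
        batches.append(current)
    return batches

batches = split_batches([b"x" * 300] * 7, max_split_size=1000, max_write_batch_size=5)
assert [len(b) for b in batches] == [3, 3, 1]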
+ # Close list and command documents + buf.write(_ZERO_16) -def _query_compressed( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes, int]: - """Internal compressed query message helper.""" - op_query, max_bson_size = _query_impl( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - rid, msg = _compress(2004, op_query, ctx) - return rid, msg, max_bson_size + # Write document lengths and request id + length = buf.tell() + buf.seek(list_start) + buf.write(_pack_int(length - list_start - 1)) + buf.seek(command_start) + buf.write(_pack_int(length - command_start)) + return to_send, length -def _query_uncompressed( - options: int, - collection_name: str, - num_to_skip: int, - num_to_return: int, - query: Mapping[str, Any], - field_selector: Optional[Mapping[str, Any]], - opts: CodecOptions, -) -> tuple[int, bytes, int]: - """Internal query message helper.""" - op_query, max_bson_size = _query_impl( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) - rid, msg = __pack_message(2004, op_query) - return rid, msg, max_bson_size +class _OpReply: + """A MongoDB OP_REPLY response message.""" -if _use_c: - _query_uncompressed = _cmessage._query_message + __slots__ = ("flags", "cursor_id", "number_returned", "documents") + UNPACK_FROM = struct.Struct(" tuple[int, bytes, int]: - """Get a **query** message.""" - if ctx: - return _query_compressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts, ctx - ) - return _query_uncompressed( - options, collection_name, num_to_skip, num_to_return, query, field_selector, opts - ) + def __init__(self, flags: int, cursor_id: int, number_returned: int, documents: bytes): + self.flags = flags + self.cursor_id = Int64(cursor_id) + self.number_returned = number_returned + self.documents = documents + def raw_response( + self, cursor_id: Optional[int] = None, user_fields: Optional[Mapping[str, Any]] = None + ) -> list[bytes]: + """Check the response header from the database, without decoding BSON. -_pack_long_long = struct.Struct(" bytes: - """Get an OP_GET_MORE message.""" - return b"".join( - [ - _ZERO_32, - bson._make_c_string(collection_name), - _pack_int(num_to_return), - _pack_long_long(cursor_id), - ] - ) + :param cursor_id: cursor_id we sent to get this response - + used for raising an informative exception when we get cursor id not + valid at server response. + """ + if self.flags & 1: + # Shouldn't get this response if we aren't doing a getMore + if cursor_id is None: + raise ProtocolError("No cursor id for getMore operation") + # Fake a getMore command response. OP_GET_MORE provides no + # document. + msg = "Cursor not found, cursor id: %d" % (cursor_id,) + errobj = {"ok": 0, "errmsg": msg, "code": 43} + raise CursorNotFound(msg, 43, errobj) + elif self.flags & 2: + error_object: dict = bson.BSON(self.documents).decode() + # Fake the ok field if it doesn't exist. 
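# --- Editor's aside: illustrative sketch, not part of the patch.
# raw_response above branches on the OP_REPLY responseFlags bits: bit 0
# (value 1) is CursorNotFound and bit 1 (value 2) is QueryFailure, whose
# body carries a "$err" document. A minimal decoder for just those bits:
def classify_reply_flags(flags):
    if flags & 1:
        return "cursor_not_found"
    if flags & 2:
        return "query_failure"
    return "ok"

assert classify_reply_flags(0) == "ok"
assert classify_reply_flags(1) == "cursor_not_found"
assert classify_reply_flags(2) == "query_failure"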
+ error_object.setdefault("ok", 0) + if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR): + raise NotPrimaryError(error_object["$err"], error_object) + elif error_object.get("code") == 50: + default_msg = "operation exceeded time limit" + raise ExecutionTimeout( + error_object.get("$err", default_msg), error_object.get("code"), error_object + ) + raise OperationFailure( + "database error: %s" % error_object.get("$err"), + error_object.get("code"), + error_object, + ) + if self.documents: + return [self.documents] + return [] -def _get_more_compressed( - collection_name: str, - num_to_return: int, - cursor_id: int, - ctx: Union[SnappyContext, ZlibContext, ZstdContext], -) -> tuple[int, bytes]: - """Internal compressed getMore message helper.""" - return _compress(2005, _get_more_impl(collection_name, num_to_return, cursor_id), ctx) + def unpack_response( + self, + cursor_id: Optional[int] = None, + codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> list[dict[str, Any]]: + """Unpack a response from the database and decode the BSON document(s). + Check the response for errors and unpack, returning a dictionary + containing the response data. -def _get_more_uncompressed( - collection_name: str, num_to_return: int, cursor_id: int -) -> tuple[int, bytes]: - """Internal getMore message helper.""" - return __pack_message(2005, _get_more_impl(collection_name, num_to_return, cursor_id)) + Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or + OperationFailure. + :param cursor_id: cursor_id we sent to get this response - + used for raising an informative exception when we get cursor id not + valid at server response + :param codec_options: an instance of + :class:`~bson.codec_options.CodecOptions` + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + """ + self.raw_response(cursor_id) + if legacy_response: + return bson.decode_all(self.documents, codec_options) + return bson._decode_all_selective(self.documents, codec_options, user_fields) -if _use_c: - _get_more_uncompressed = _cmessage._get_more_message + def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: + """Unpack a command response.""" + docs = self.unpack_response(codec_options=codec_options) + assert self.number_returned == 1 + return docs[0] + def raw_command_response(self) -> NoReturn: + """Return the bytes of the command response.""" + # This should never be called on _OpReply. + raise NotImplementedError -def _get_more( - collection_name: str, - num_to_return: int, - cursor_id: int, - ctx: Union[SnappyContext, ZlibContext, ZstdContext, None] = None, -) -> tuple[int, bytes]: - """Get a **getMore** message.""" - if ctx: - return _get_more_compressed(collection_name, num_to_return, cursor_id, ctx) - return _get_more_uncompressed(collection_name, num_to_return, cursor_id) + @property + def more_to_come(self) -> bool: + """Is the moreToCome bit set on this response?""" + return False + @classmethod + def unpack(cls, msg: bytes) -> _OpReply: + """Construct an _OpReply from raw bytes.""" + # PYTHON-945: ignore starting_from field. 
+ flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg) -def _raise_document_too_large(operation: str, doc_size: int, max_size: int) -> NoReturn: - """Internal helper for raising DocumentTooLarge.""" - if operation == "insert": - raise DocumentTooLarge( - "BSON document too large (%d bytes)" - " - the connected server supports" - " BSON document sizes up to %d" - " bytes." % (doc_size, max_size) - ) - else: - # There's nothing intelligent we can say - # about size for update and delete - raise DocumentTooLarge(f"{operation!r} command document too large") + documents = msg[20:] + return cls(flags, cursor_id, number_returned, documents) -# OP_MSG ------------------------------------------------------------- +class _OpMsg: + """A MongoDB OP_MSG response message.""" + __slots__ = ("flags", "cursor_id", "number_returned", "payload_document") -_OP_MSG_MAP = { - _INSERT: b"documents\x00", - _UPDATE: b"updates\x00", - _DELETE: b"deletes\x00", -} + UNPACK_FROM = struct.Struct(" list[Mapping[str, Any]]: + """ + cursor_id is ignored + user_fields is used to determine which fields must not be decoded + """ + inflated_response = bson._decode_selective( + RawBSONDocument(self.payload_document), user_fields, _RAW_ARRAY_BSON_OPTIONS + ) + return [inflated_response] - def __init__( + def unpack_response( self, - database_name: str, - cmd_name: str, - conn: _AgnosticConnection, - operation_id: int, - listeners: _EventListeners, - session: _AgnosticClientSession, - op_type: int, - codec: CodecOptions, - ): - self.db_name = database_name - self.conn = conn - self.op_id = operation_id - self.listeners = listeners - self.publish = listeners.enabled_for_commands - self.name = cmd_name - self.field = _FIELD_MAP[self.name] - self.start_time = datetime.datetime.now() - self.session = session - self.compress = bool(conn.compression_context) - self.op_type = op_type - self.codec = codec + cursor_id: Optional[int] = None, + codec_options: CodecOptions = _UNICODE_REPLACE_CODEC_OPTIONS, + user_fields: Optional[Mapping[str, Any]] = None, + legacy_response: bool = False, + ) -> list[dict[str, Any]]: + """Unpack a OP_MSG command response. - def batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[int, Union[bytes, dict[str, Any]], list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - request_id, msg, to_send = _do_batched_op_msg( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - return request_id, msg, to_send + :param cursor_id: Ignored, for compatibility with _OpReply. + :param codec_options: an instance of + :class:`~bson.codec_options.CodecOptions` + :param user_fields: Response fields that should be decoded + using the TypeDecoders from codec_options, passed to + bson._decode_all_selective. + """ + # If _OpMsg is in-use, this cannot be a legacy response. + assert not legacy_response + return bson._decode_all_selective(self.payload_document, codec_options, user_fields) - @property - def max_bson_size(self) -> int: - """A proxy for SockInfo.max_bson_size.""" - return self.conn.max_bson_size + def command_response(self, codec_options: CodecOptions) -> dict[str, Any]: + """Unpack a command response.""" + return self.unpack_response(codec_options=codec_options)[0] - @property - def max_message_size(self) -> int: - """A proxy for SockInfo.max_message_size.""" - if self.compress: - # Subtract 16 bytes for the message header. 
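# --- Editor's aside: illustrative sketch, not part of the patch. The
# subtraction on the next line shrinks the outgoing budget when a compressor
# is active: the 16-byte wire header is sent uncompressed, so only the
# remainder of max_message_size is available for the compressed body.
def effective_message_budget(max_message_size, compressing):
    return max_message_size - 16 if compressing else max_message_size

assert effective_message_budget(48_000_000, True) == 47_999_984
assert effective_message_budget(48_000_000, False) == 48_000_000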
- return self.conn.max_message_size - 16 - return self.conn.max_message_size + def raw_command_response(self) -> bytes: + """Return the bytes of the command response.""" + return self.payload_document @property - def max_write_batch_size(self) -> int: - """A proxy for SockInfo.max_write_batch_size.""" - return self.conn.max_write_batch_size + def more_to_come(self) -> bool: + """Is the moreToCome bit set on this response?""" + return bool(self.flags & self.MORE_TO_COME) - @property - def max_split_size(self) -> int: - """The maximum size of a BSON command before batch splitting.""" - return self.max_bson_size + @classmethod + def unpack(cls, msg: bytes) -> _OpMsg: + """Construct an _OpMsg from raw bytes.""" + flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg) + if flags != 0: + if flags & cls.CHECKSUM_PRESENT: + raise ProtocolError(f"Unsupported OP_MSG flag checksumPresent: 0x{flags:x}") - def _start( - self, cmd: MutableMapping[str, Any], request_id: int, docs: list[Mapping[str, Any]] - ) -> MutableMapping[str, Any]: - """Publish a CommandStartedEvent.""" - cmd[self.field] = docs - self.listeners.publish_command_start( - cmd, - self.db_name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - ) - return cmd - - def _succeed(self, request_id: int, reply: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandSucceededEvent.""" - self.listeners.publish_command_success( - duration, - reply, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) - - def _fail(self, request_id: int, failure: _DocumentOut, duration: datetime.timedelta) -> None: - """Publish a CommandFailedEvent.""" - self.listeners.publish_command_failure( - duration, - failure, - self.name, - request_id, - self.conn.address, - self.conn.server_connection_id, - self.op_id, - self.conn.service_id, - database_name=self.db_name, - ) + if flags ^ cls.MORE_TO_COME: + raise ProtocolError(f"Unsupported OP_MSG flags: 0x{flags:x}") + if first_payload_type != 0: + raise ProtocolError(f"Unsupported OP_MSG payload type: 0x{first_payload_type:x}") + if len(msg) != first_payload_size + 5: + raise ProtocolError("Unsupported OP_MSG reply: >1 section") -class _EncryptedBulkWriteContext(_BulkWriteContext): - __slots__ = () + payload_document = msg[5:] + return cls(flags, payload_document) - def batch_command( - self, cmd: MutableMapping[str, Any], docs: list[Mapping[str, Any]] - ) -> tuple[int, dict[str, Any], list[Mapping[str, Any]]]: - namespace = self.db_name + ".$cmd" - msg, to_send = _encode_batched_write_command( - namespace, self.op_type, cmd, docs, self.codec, self - ) - if not to_send: - raise InvalidOperation("cannot do an empty bulk write") - # Chop off the OP_QUERY header to get a properly batched write command. 
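# --- Editor's aside: illustrative sketch, not part of the patch. The chop
# on the next line relies on the OP_QUERY body layout built earlier in this
# module: 4 flag bytes, a NUL-terminated namespace, then 4 skip bytes and 4
# limit bytes before the BSON command document. Finding the namespace's NUL
# (searching from offset 4) and stepping over NUL + skip + limit = 9 bytes
# lands exactly on the command document:
import struct

flags, ns, doc = b"\x00" * 4, b"db.$cmd\x00", b"(pretend bson)"
msg = flags + ns + struct.pack("<ii", 0, -1) + doc  # skip 0, limit -1
cmd_start = msg.index(b"\x00", 4) + 9
assert msg[cmd_start:] == doc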
- cmd_start = msg.index(b"\x00", 4) + 9 - outgoing = _inflate_bson(memoryview(msg)[cmd_start:], DEFAULT_RAW_BSON_OPTIONS) - return -1, outgoing, to_send +_UNPACK_REPLY: dict[int, Callable[[bytes], Union[_OpReply, _OpMsg]]] = { + _OpReply.OP_CODE: _OpReply.unpack, + _OpMsg.OP_CODE: _OpMsg.unpack, +} - @property - def max_split_size(self) -> int: - """Reduce the batch splitting size.""" - return _MAX_SPLIT_SIZE_ENC +class _Query: + """A query operation.""" -# From the Client Side Encryption spec: -# Because automatic encryption increases the size of commands, the driver -# MUST split bulk writes at a reduced size limit before undergoing automatic -# encryption. The write payload MUST be split at 2MiB (2097152). -_MAX_SPLIT_SIZE_ENC = 2097152 + __slots__ = ( + "flags", + "db", + "coll", + "ntoskip", + "spec", + "fields", + "codec_options", + "read_preference", + "limit", + "batch_size", + "name", + "read_concern", + "collation", + "session", + "client", + "allow_disk_use", + "_as_command", + "exhaust", + ) + # For compatibility with the _GetMore class. + conn_mgr = None + cursor_id = None -def _batched_op_msg_impl( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_MSG write.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - max_message_size = ctx.max_message_size + def __init__( + self, + flags: int, + db: str, + coll: str, + ntoskip: int, + spec: Mapping[str, Any], + fields: Optional[Mapping[str, Any]], + codec_options: CodecOptions, + read_preference: _ServerMode, + limit: int, + batch_size: int, + read_concern: ReadConcern, + collation: Optional[Mapping[str, Any]], + session: Optional[_AgnosticClientSession], + client: _AgnosticMongoClient, + allow_disk_use: Optional[bool], + exhaust: bool, + ): + self.flags = flags + self.db = db + self.coll = coll + self.ntoskip = ntoskip + self.spec = spec + self.fields = fields + self.codec_options = codec_options + self.read_preference = read_preference + self.read_concern = read_concern + self.limit = limit + self.batch_size = batch_size + self.collation = collation + self.session = session + self.client = client + self.allow_disk_use = allow_disk_use + self.name = "find" + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust - flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00" - # Flags - buf.write(flags) + def reset(self) -> None: + self._as_command = None - # Type 0 Section - buf.write(b"\x00") - buf.write(_dict_to_bson(command, False, opts)) + def namespace(self) -> str: + return f"{self.db}.{self.coll}" - # Type 1 Section - buf.write(b"\x01") - size_location = buf.tell() - # Save space for size - buf.write(b"\x00\x00\x00\x00") - try: - buf.write(_OP_MSG_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None + def use_command(self, conn: _AgnosticConnection) -> bool: + use_find_cmd = False + if not self.exhaust: + use_find_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_find_cmd = True + elif not self.read_concern.ok_for_legacy: + raise ConfigurationError( + "read concern level of %s is not valid " + "with a max wire version of %d." 
% (self.read_concern.level, conn.max_wire_version) + ) - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - value = _dict_to_bson(doc, False, opts) - doc_length = len(value) - new_message_size = buf.tell() + doc_length - # Does first document exceed max_message_size? - doc_too_large = idx == 0 and (new_message_size > max_message_size) - # When OP_MSG is used unacknowledged we have to check - # document size client side or applications won't be notified. - # Otherwise we let the server deal with documents that are too large - # since ordered=False causes those documents to be skipped instead of - # halting the bulk write operation. - unacked_doc_too_large = not ack and (doc_length > max_bson_size) - if doc_too_large or unacked_doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - # We have enough data, return this batch. - if new_message_size > max_message_size: - break - buf.write(value) - to_send.append(doc) - idx += 1 - # We have enough documents, return this batch. - if idx == max_write_batch_size: - break + conn.validate_session(self.client, self.session) # type: ignore[arg-type] + return use_find_cmd - # Write type 1 section size - length = buf.tell() - buf.seek(size_location) - buf.write(_pack_int(length - size_location)) + def update_command(self, cmd: dict[str, Any]) -> None: + self._as_command = cmd, self.db - return to_send, length + def as_command( + self, conn: _AgnosticConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + """Return a find command document for this query.""" + # We use the command twice: on the wire and for command monitoring. + # Generate it once, for speed and to avoid repeating side-effects. + if self._as_command is not None: + return self._as_command + explain = "$explain" in self.spec + cmd: dict[str, Any] = _gen_find_command( + self.coll, + self.spec, + self.fields, + self.ntoskip, + self.limit, + self.batch_size, + self.flags, + self.read_concern, + self.collation, + self.session, + self.allow_disk_use, + ) + if explain: + self.name = "explain" + cmd = {"explain": cmd} + conn.add_server_api(cmd) + if self.session: + self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] + # Explain does not support readConcern. + if not explain and not self.session.in_transaction: + self.session._update_read_concern(cmd, conn) # type: ignore[arg-type] + conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] + # Support CSOT + if apply_timeout: + conn.apply_timeout(self.client, cmd=cmd) # type: ignore[arg-type] + self._as_command = cmd, self.db + return self._as_command -def _encode_batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete operation - as OP_MSG. - """ - buf = _BytesIO() + def get_message( + self, read_preference: _ServerMode, conn: _AgnosticConnection, use_cmd: bool = False + ) -> tuple[int, bytes, int]: + """Get a query message, possibly setting the secondaryOk bit.""" + # Use the read_preference decided by _socket_from_server. + self.read_preference = read_preference + if read_preference.mode: + # Set the secondaryOk bit. 
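# --- Editor's aside: illustrative sketch, not part of the patch. The
# `self.flags | 4` just below sets OP_QUERY's secondaryOk bit (bit 2,
# value 4), which permits a secondary to serve the query; any non-primary
# read preference mode turns it on:
SECONDARY_OK = 4

def query_flags(base_flags, rp_mode):
    return base_flags | SECONDARY_OK if rp_mode else base_flags

assert query_flags(0, 0) == 0  # primary: bit stays clear
assert query_flags(0, 2) == 4  # secondary: bit set
assert query_flags(2, 1) == 6  # existing flag bits are preserved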
+ flags = self.flags | 4 + else: + flags = self.flags - to_send, _ = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) - return buf.getvalue(), to_send + ns = self.namespace() + spec = self.spec + if use_cmd: + spec = self.as_command(conn)[0] + request_id, msg, size, _ = _op_msg( + 0, + spec, + self.db, + read_preference, + self.codec_options, + ctx=conn.compression_context, + ) + return request_id, msg, size -if _use_c: - _encode_batched_op_msg = _cmessage._encode_batched_op_msg + # OP_QUERY treats ntoreturn of -1 and 1 the same, return + # one document and close the cursor. We have to use 2 for + # batch size if 1 is specified. + ntoreturn = self.batch_size == 1 and 2 or self.batch_size + if self.limit: + if ntoreturn: + ntoreturn = min(self.limit, ntoreturn) + else: + ntoreturn = self.limit + if conn.is_mongos: + assert isinstance(spec, MutableMapping) + spec = _maybe_add_read_preference(spec, read_preference) -def _batched_op_msg_compressed( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - with OP_MSG, compressed. - """ - data, to_send = _encode_batched_op_msg(operation, command, docs, ack, opts, ctx) + return _query( + flags, + ns, + self.ntoskip, + ntoreturn, + spec, + None if use_cmd else self.fields, + self.codec_options, + ctx=conn.compression_context, + ) - assert ctx.conn.compression_context is not None - request_id, msg = _compress(2013, data, ctx.conn.compression_context) - return request_id, msg, to_send +class _GetMore: + """A getmore operation.""" -def _batched_op_msg( - operation: int, - command: Mapping[str, Any], - docs: list[Mapping[str, Any]], - ack: bool, - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """OP_MSG implementation entry point.""" - buf = _BytesIO() + __slots__ = ( + "db", + "coll", + "ntoreturn", + "cursor_id", + "max_await_time_ms", + "codec_options", + "read_preference", + "session", + "client", + "conn_mgr", + "_as_command", + "exhaust", + "comment", + ) - # Save space for message length and request id - buf.write(_ZERO_64) - # responseTo, opCode - buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00") + name = "getMore" - to_send, length = _batched_op_msg_impl(operation, command, docs, ack, opts, ctx, buf) + def __init__( + self, + db: str, + coll: str, + ntoreturn: int, + cursor_id: int, + codec_options: CodecOptions, + read_preference: _ServerMode, + session: Optional[_AgnosticClientSession], + client: _AgnosticMongoClient, + max_await_time_ms: Optional[int], + conn_mgr: Any, + exhaust: bool, + comment: Any, + ): + self.db = db + self.coll = coll + self.ntoreturn = ntoreturn + self.cursor_id = cursor_id + self.codec_options = codec_options + self.read_preference = read_preference + self.session = session + self.client = client + self.max_await_time_ms = max_await_time_ms + self.conn_mgr = conn_mgr + self._as_command: Optional[tuple[dict[str, Any], str]] = None + self.exhaust = exhaust + self.comment = comment - # Header - request id and message length - buf.seek(4) - request_id = _randint() - buf.write(_pack_int(request_id)) - buf.seek(0) - buf.write(_pack_int(length)) + def reset(self) -> None: + self._as_command = None - return request_id, buf.getvalue(), to_send + def namespace(self) -> str: + return f"{self.db}.{self.coll}" + def use_command(self, conn: _AgnosticConnection) -> 
bool: + use_cmd = False + if not self.exhaust: + use_cmd = True + elif conn.max_wire_version >= 8: + # OP_MSG supports exhaust on MongoDB 4.2+ + use_cmd = True -if _use_c: - _batched_op_msg = _cmessage._batched_op_msg + conn.validate_session(self.client, self.session) # type: ignore[arg-type] + return use_cmd + def update_command(self, cmd: dict[str, Any]) -> None: + self._as_command = cmd, self.db -def _do_batched_op_msg( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[int, bytes, list[Mapping[str, Any]]]: - """Create the next batched insert, update, or delete operation - using OP_MSG. - """ - command["$db"] = namespace.split(".", 1)[0] - if "writeConcern" in command: - ack = bool(command["writeConcern"].get("w", 1)) - else: - ack = True - if ctx.conn.compression_context: - return _batched_op_msg_compressed(operation, command, docs, ack, opts, ctx) - return _batched_op_msg(operation, command, docs, ack, opts, ctx) + def as_command( + self, conn: _AgnosticConnection, apply_timeout: bool = False + ) -> tuple[dict[str, Any], str]: + """Return a getMore command document for this query.""" + # See _Query.as_command for an explanation of this caching. + if self._as_command is not None: + return self._as_command + cmd: dict[str, Any] = _gen_get_more_command( + self.cursor_id, + self.coll, + self.ntoreturn, + self.max_await_time_ms, + self.comment, + conn, + ) + if self.session: + self.session._apply_to(cmd, False, self.read_preference, conn) # type: ignore[arg-type] + conn.add_server_api(cmd) + conn.send_cluster_time(cmd, self.session, self.client) # type: ignore[arg-type] + # Support CSOT + if apply_timeout: + conn.apply_timeout(self.client, cmd=None) # type: ignore[arg-type] + self._as_command = cmd, self.db + return self._as_command -# End OP_MSG ----------------------------------------------------- + def get_message( + self, dummy0: Any, conn: _AgnosticConnection, use_cmd: bool = False + ) -> Union[tuple[int, bytes, int], tuple[int, bytes]]: + """Get a getmore message.""" + ns = self.namespace() + ctx = conn.compression_context + if use_cmd: + spec = self.as_command(conn)[0] + if self.conn_mgr and self.exhaust: + flags = _OpMsg.EXHAUST_ALLOWED + else: + flags = 0 + request_id, msg, size, _ = _op_msg( + flags, spec, self.db, None, self.codec_options, ctx=conn.compression_context + ) + return request_id, msg, size -def _encode_batched_write_command( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, -) -> tuple[bytes, list[Mapping[str, Any]]]: - """Encode the next batched insert, update, or delete command.""" - buf = _BytesIO() + return _get_more(ns, self.ntoreturn, self.cursor_id, ctx) - to_send, _ = _batched_write_command_impl(namespace, operation, command, docs, opts, ctx, buf) - return buf.getvalue(), to_send +class _RawBatchQuery(_Query): + def use_command(self, conn: _AgnosticConnection) -> bool: + # Compatibility checks. + super().use_command(conn) + if conn.max_wire_version >= 8: + # MongoDB 4.2+ supports exhaust over OP_MSG + return True + elif not self.exhaust: + return True + return False -if _use_c: - _encode_batched_write_command = _cmessage._encode_batched_write_command +class _RawBatchGetMore(_GetMore): + def use_command(self, conn: _AgnosticConnection) -> bool: + # Compatibility checks. 
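# --- Editor's aside: illustrative sketch, not part of the patch. Both
# raw-batch classes share this decision: use a command (OP_MSG) unless the
# cursor is an exhaust cursor on a pre-4.2 server, because exhaust over
# OP_MSG requires maxWireVersion >= 8:
def should_use_command(max_wire_version, exhaust):
    return max_wire_version >= 8 or not exhaust

assert should_use_command(8, exhaust=True)
assert should_use_command(6, exhaust=False)
assert not should_use_command(6, exhaust=True)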
+ super().use_command(conn) + if conn.max_wire_version >= 8: + # MongoDB 4.2+ supports exhaust over OP_MSG + return True + elif not self.exhaust: + return True + return False -def _batched_write_command_impl( - namespace: str, - operation: int, - command: MutableMapping[str, Any], - docs: list[Mapping[str, Any]], - opts: CodecOptions, - ctx: _BulkWriteContext, - buf: _BytesIO, -) -> tuple[list[Mapping[str, Any]], int]: - """Create a batched OP_QUERY write command.""" - max_bson_size = ctx.max_bson_size - max_write_batch_size = ctx.max_write_batch_size - # Max BSON object size + 16k - 2 bytes for ending NUL bytes. - # Server guarantees there is enough room: SERVER-10643. - max_cmd_size = max_bson_size + _COMMAND_OVERHEAD - max_split_size = ctx.max_split_size - # No options - buf.write(_ZERO_32) - # Namespace as C string - buf.write(namespace.encode("utf8")) - buf.write(_ZERO_8) - # Skip: 0, Limit: -1 - buf.write(_SKIPLIM) +class _CursorAddress(tuple): + """The server address (host, port) of a cursor, with namespace property.""" - # Where to write command document length - command_start = buf.tell() - buf.write(bson.encode(command)) + __namespace: Any - # Start of payload - buf.seek(-1, 2) - # Work around some Jython weirdness. - buf.truncate() - try: - buf.write(_OP_MAP[operation]) - except KeyError: - raise InvalidOperation("Unknown command") from None + def __new__(cls, address: _Address, namespace: str) -> _CursorAddress: + self = tuple.__new__(cls, address) + self.__namespace = namespace + return self - # Where to write list document length - list_start = buf.tell() - 4 - to_send = [] - idx = 0 - for doc in docs: - # Encode the current operation - key = str(idx).encode("utf8") - value = _dict_to_bson(doc, False, opts) - # Is there enough room to add this document? max_cmd_size accounts for - # the two trailing null bytes. - doc_too_large = len(value) > max_cmd_size - if doc_too_large: - write_op = list(_FIELD_MAP.keys())[operation] - _raise_document_too_large(write_op, len(value), max_bson_size) - enough_data = idx >= 1 and (buf.tell() + len(key) + len(value)) >= max_split_size - enough_documents = idx >= max_write_batch_size - if enough_data or enough_documents: - break - buf.write(_BSONOBJ) - buf.write(key) - buf.write(_ZERO_8) - buf.write(value) - to_send.append(doc) - idx += 1 + @property + def namespace(self) -> str: + """The namespace this cursor.""" + return self.__namespace - # Finalize the current OP_QUERY message. - # Close list and command documents - buf.write(_ZERO_16) + def __hash__(self) -> int: + # Two _CursorAddress instances with different namespaces + # must not hash the same. + return ((*self, self.__namespace)).__hash__() - # Write document lengths and request id - length = buf.tell() - buf.seek(list_start) - buf.write(_pack_int(length - list_start - 1)) - buf.seek(command_start) - buf.write(_pack_int(length - command_start)) + def __eq__(self, other: object) -> bool: + if isinstance(other, _CursorAddress): + return tuple(self) == tuple(other) and self.namespace == other.namespace + return NotImplemented - return to_send, length + def __ne__(self, other: object) -> bool: + return not self == other diff --git a/pymongo/synchronous/server.py b/pymongo/synchronous/server.py index fea55b8382..347155784f 100644 --- a/pymongo/synchronous/server.py +++ b/pymongo/synchronous/server.py @@ -135,12 +135,12 @@ def run_operation( cursors. Can raise ConnectionFailure, OperationFailure, etc. - :param conn: An AsyncConnection instance. + :param conn: A Connection instance. 
:param operation: A _Query or _GetMore object. :param read_preference: The read preference to use. :param listeners: Instance of _EventListeners or None. :param unpack_res: A callable that decodes the wire protocol response. - :param client: An AsyncMongoClient instance. + :param client: A MongoClient instance. """ assert listeners is not None publish = listeners.enabled_for_commands diff --git a/test/unified_format.py b/test/unified_format.py index ecf0133e74..fe3bb27b54 100644 --- a/test/unified_format.py +++ b/test/unified_format.py @@ -616,7 +616,7 @@ def get_lsid_for_session(self, session_name): session = self[session_name] if not isinstance(session, ClientSession): self.test.fail( - f"Expected entity {session_name} to be of type AsyncClientSession, got {type(session)}" + f"Expected entity {session_name} to be of type ClientSession, got {type(session)}" ) try: diff --git a/test/utils_selection_tests.py b/test/utils_selection_tests.py index e6fb829eb3..cef5780d21 100644 --- a/test/utils_selection_tests.py +++ b/test/utils_selection_tests.py @@ -19,8 +19,6 @@ import os import sys -from pymongo.operations import _Op - sys.path[0:0] = [""] from test import unittest @@ -31,6 +29,7 @@ from pymongo.common import HEARTBEAT_FREQUENCY, clean_node from pymongo.errors import AutoReconnect, ConfigurationError from pymongo.hello import Hello, HelloCompat +from pymongo.operations import _Op from pymongo.server_description import ServerDescription from pymongo.server_selectors import writable_server_selector from pymongo.synchronous.settings import TopologySettings diff --git a/tools/synchro.py b/tools/synchro.py index f4b0683ce6..2b35bc7e03 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -146,6 +146,7 @@ "operations.py", "pool.py", "topology.py", + "server.py", ] ] + [_gridfs_dest_base + f for f in ["grid_file.py"]] From 161de216859ecabfe2143abee5d53df007f1df57 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 24 Jun 2024 14:28:14 -0700 Subject: [PATCH 08/11] Update async test_collection --- pymongo/asynchronous/mongo_client.py | 2 +- test/asynchronous/test_collection.py | 3 +- test/synchronous/test_collection.py | 7 +++-- tools/synchro.py | 44 +++++++++++++++------------- 4 files changed, 31 insertions(+), 25 deletions(-) diff --git a/pymongo/asynchronous/mongo_client.py b/pymongo/asynchronous/mongo_client.py index 407beff5b8..b12e3dd694 100644 --- a/pymongo/asynchronous/mongo_client.py +++ b/pymongo/asynchronous/mongo_client.py @@ -2040,7 +2040,7 @@ async def _tmp_session( """If provided session is None, lend a temporary session.""" if session is not None: if not isinstance(session, client_session.AsyncClientSession): - raise ValueError("'session' argument must be a AsyncClientSession or None.") + raise ValueError("'session' argument must be an AsyncClientSession or None.") # Don't call end_session. 
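# --- Editor's aside: illustrative sketch, not part of the patch. The
# contract of _tmp_session, modeled with a stand-in session type: an
# explicit session is type-checked and lent back untouched (the caller owns
# its lifetime), while None yields a temporary session closed on exit:
import contextlib

class FakeSession:
    def __init__(self):
        self.ended = False

    def end(self):
        self.ended = True

@contextlib.contextmanager
def tmp_session(session):
    if session is not None:
        if not isinstance(session, FakeSession):
            raise ValueError("'session' argument must be a FakeSession or None.")
        yield session  # lent, not ended: the caller owns it
        return
    temp = FakeSession()
    try:
        yield temp
    finally:
        temp.end()

mine = FakeSession()
with tmp_session(mine) as s:
    assert s is mine
assert not mine.ended
with tmp_session(None) as s:
    pass
assert s.ended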
yield session return diff --git a/test/asynchronous/test_collection.py b/test/asynchronous/test_collection.py index 7e907eaf34..45f1c34c9b 100644 --- a/test/asynchronous/test_collection.py +++ b/test/asynchronous/test_collection.py @@ -40,9 +40,10 @@ wait_until, ) -from bson import RawBSONDocument, encode +from bson import encode from bson.codec_options import CodecOptions from bson.objectid import ObjectId +from bson.raw_bson import RawBSONDocument from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT diff --git a/test/synchronous/test_collection.py b/test/synchronous/test_collection.py index a11d1b6bb4..d46b43989b 100644 --- a/test/synchronous/test_collection.py +++ b/test/synchronous/test_collection.py @@ -39,9 +39,10 @@ wait_until, ) -from bson import RawBSONDocument, encode +from bson import encode from bson.codec_options import CodecOptions from bson.objectid import ObjectId +from bson.raw_bson import RawBSONDocument from bson.regex import Regex from bson.son import SON from pymongo import ASCENDING, DESCENDING, GEO2D, GEOSPHERE, HASHED, TEXT @@ -133,7 +134,7 @@ def test_iteration(self): if _IS_SYNC: msg = "'Collection' object is not iterable" else: - msg = "'AsyncCollection' object is not iterable" + msg = "'Collection' object is not iterable" # Iteration fails with self.assertRaisesRegex(TypeError, msg): for _ in coll: # type: ignore[misc] # error: "None" not callable [misc] @@ -1616,7 +1617,7 @@ def try_invalid_session(): with self.db.test.aggregate([], {}): # type:ignore pass - with self.assertRaisesRegex(ValueError, "must be an AsyncClientSession"): + with self.assertRaisesRegex(ValueError, "must be a ClientSession"): try_invalid_session() def test_large_limit(self): diff --git a/tools/synchro.py b/tools/synchro.py index 2b35bc7e03..586c80fd81 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -128,27 +128,31 @@ ] -docstring_translate_files = [ - _pymongo_dest_base + f - for f in [ - "aggregation.py", - "change_stream.py", - "collection.py", - "command_cursor.py", - "cursor.py", - "client_options.py", - "client_session.py", - "database.py", - "encryption.py", - "encryption_options.py", - "mongo_client.py", - "network.py", - "operations.py", - "pool.py", - "topology.py", - "server.py", +docstring_translate_files = ( + [ + _pymongo_dest_base + f + for f in [ + "aggregation.py", + "change_stream.py", + "collection.py", + "command_cursor.py", + "cursor.py", + "client_options.py", + "client_session.py", + "database.py", + "encryption.py", + "encryption_options.py", + "mongo_client.py", + "network.py", + "operations.py", + "pool.py", + "topology.py", + "server.py", + ] ] -] + [_gridfs_dest_base + f for f in ["grid_file.py"]] + + [_gridfs_dest_base + f for f in ["grid_file.py"]] + + sync_test_files +) def process_files(files: list[str]) -> None: From a3b15f04c6bf77bfaaa8d39582c4ceaddb21d15b Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 24 Jun 2024 15:18:54 -0700 Subject: [PATCH 09/11] Resolve merge conflicts --- pymongo/pool_options.py | 23 +++++++++++++++++++++++ test/test_client.py | 22 +++++++++++++++++++++- 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/pymongo/pool_options.py b/pymongo/pool_options.py index 4170bb5cb6..668b82635a 100644 --- a/pymongo/pool_options.py +++ b/pymongo/pool_options.py @@ -243,6 +243,29 @@ def _truncate_metadata(metadata: MutableMapping[str, Any]) -> None: metadata["platform"] = plat else: metadata.pop("platform", None) + encoded_size = 
len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # 5. Truncate driver info. + overflow = encoded_size - _MAX_METADATA_SIZE + driver = metadata.get("driver", {}) + if driver: + # Truncate driver version. + driver_version = driver.get("version")[:-overflow] + if len(driver_version) >= len(_METADATA["driver"]["version"]): + metadata["driver"]["version"] = driver_version + else: + metadata["driver"]["version"] = _METADATA["driver"]["version"] + encoded_size = len(bson.encode(metadata)) + if encoded_size <= _MAX_METADATA_SIZE: + return + # Truncate driver name. + overflow = encoded_size - _MAX_METADATA_SIZE + driver_name = driver.get("name")[:-overflow] + if len(driver_name) >= len(_METADATA["driver"]["name"]): + metadata["driver"]["name"] = driver_name + else: + metadata["driver"]["name"] = _METADATA["driver"]["name"] # If the first getaddrinfo call of this interpreter's life is on a thread, diff --git a/test/test_client.py b/test/test_client.py index 43d43d6d00..73fbe749b1 100644 --- a/test/test_client.py +++ b/test/test_client.py @@ -36,6 +36,7 @@ import pytest +import bson from pymongo.operations import _Op sys.path[0:0] = [""] @@ -100,7 +101,7 @@ WriteConcernError, ) from pymongo.monitoring import ServerHeartbeatListener, ServerHeartbeatStartedEvent -from pymongo.pool_options import _METADATA, ENV_VAR_K8S, PoolOptions +from pymongo.pool_options import _MAX_METADATA_SIZE, _METADATA, ENV_VAR_K8S, PoolOptions from pymongo.read_preferences import ReadPreference from pymongo.server_description import ServerDescription from pymongo.server_selectors import readable_server_selector, writable_server_selector @@ -359,6 +360,25 @@ def test_metadata(self): ) options = client.options self.assertEqual(options.pool_options.metadata, metadata) + # Test truncating driver info metadata. 
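# --- Editor's aside: illustrative sketch, not part of the patch. The
# behavior this new test exercises (added to pool_options above): when the
# encoded metadata document exceeds the cap, cut exactly the overflow off
# the end of the user-supplied driver string, but never truncate into the
# built-in default value. Plain len() stands in for bson.encode() here:
def truncate_field(value, default, encoded_size, max_size):
    overflow = encoded_size - max_size
    if overflow <= 0:
        return value
    truncated = value[:-overflow]
    return truncated if len(truncated) >= len(default) else default

assert truncate_field("pymongo|myapp-1.2.3", "pymongo", 600, 593) == "pymongo|myap"
assert truncate_field("pymongo|myapp", "pymongo", 700, 512) == "pymongo"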
+ client = MongoClient( + driver=DriverInfo(name="s" * _MAX_METADATA_SIZE), + connect=False, + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) + client = MongoClient( + driver=DriverInfo(name="s" * _MAX_METADATA_SIZE, version="s" * _MAX_METADATA_SIZE), + connect=False, + ) + options = client.options + self.assertLessEqual( + len(bson.encode(options.pool_options.metadata)), + _MAX_METADATA_SIZE, + ) @mock.patch.dict("os.environ", {ENV_VAR_K8S: "1"}) def test_container_metadata(self): From d4559b7f3d37b448f41d374754da0d1a52e812fb Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Mon, 24 Jun 2024 16:56:10 -0700 Subject: [PATCH 10/11] Fix bulk typechecking errors --- pymongo/asynchronous/bulk.py | 99 +++++++++++++++++-------------- pymongo/synchronous/bulk.py | 105 ++++++++++++++++++--------------- pymongo/synchronous/helpers.py | 4 +- tools/synchro.py | 26 +------- 4 files changed, 115 insertions(+), 119 deletions(-) diff --git a/pymongo/asynchronous/bulk.py b/pymongo/asynchronous/bulk.py index 725596ab6d..c200899dd1 100644 --- a/pymongo/asynchronous/bulk.py +++ b/pymongo/asynchronous/bulk.py @@ -403,6 +403,56 @@ async def unack_write( bwc.start_time = datetime.datetime.now() return result # type: ignore[return-value] + async def _execute_batch_unack( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> list[Mapping[str, Any]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + await bwc.conn.command( # type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + write_concern=WriteConcern(w=0), + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + # Though this isn't strictly a "legacy" write, the helper + # handles publishing commands and sending our message + # without receiving a result. Send 0 for max_doc_size + # to disable size checking. Size checking is handled while + # the documents are encoded to BSON. + await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type] + + return to_send + + async def _execute_batch( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: AsyncMongoClient, + ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + result = await bwc.conn.command( # type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + codec_options=bwc.codec, + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + result = await self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type] + await client._process_response(result, bwc.session) # type: ignore[arg-type] + + return result, to_send # type: ignore[return-value] + async def _execute_command( self, generator: Iterator[Any], @@ -475,21 +525,7 @@ async def _execute_command( # Run as many ops as possible in one command. 
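# --- Editor's aside: illustrative sketch, not part of the patch. The
# refactor around this point funnels both write paths through one batching
# loop: execute a slice of pending operations, advance the offset by however
# many actually fit in the batch, and repeat. A toy version of that driver
# loop, with a stand-in executor that fits three operations per batch:
def execute_all(ops, execute_batch):
    idx_offset = 0
    while idx_offset < len(ops):
        to_send = execute_batch(ops[idx_offset:])
        idx_offset += len(to_send)

sent_batches = []
execute_all(
    list(range(7)),
    lambda pending: sent_batches.append(pending[:3]) or pending[:3],
)
assert sent_batches == [[0, 1, 2], [3, 4, 5], [6]]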
if write_concern.acknowledged: - if self.is_encrypted: - _, batched_cmd, to_send = bwc.batch_command(cmd, ops) - result = await bwc.conn.command( - bwc.db_name, - batched_cmd, - codec_options=bwc.codec, - session=bwc.session, - client=client, - ) - else: - request_id, msg, to_send = bwc.batch_command(cmd, ops) - result = await self.write_command( - bwc, cmd, request_id, msg, to_send, client - ) - await client._process_response(result, bwc.session) + result, to_send = await self._execute_batch(bwc, cmd, ops, client) # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) @@ -509,23 +545,7 @@ async def _execute_command( if self.ordered and "writeErrors" in result: break else: - if self.is_encrypted: - _, batched_cmd, to_send = bwc.batch_command(cmd, ops) - await bwc.conn.command( - bwc.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=bwc.session, - client=client, - ) - else: - request_id, msg, to_send = bwc.batch_command(cmd, ops) - # Though this isn't strictly a "legacy" write, the helper - # handles publishing commands and sending our message - # without receiving a result. Send 0 for max_doc_size - # to disable size checking. Size checking is handled while - # the documents are encoded to BSON. - await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) + to_send = await self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) @@ -576,7 +596,7 @@ async def retryable_bulk( retryable_bulk, session, operation, - bulk=self, + bulk=self, # type: ignore[arg-type] operation_id=op_id, ) @@ -619,18 +639,7 @@ async def execute_op_msg_no_results( conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. - if self.is_encrypted: - _, batched_cmd, to_send = bwc.batch_command(cmd, ops) - await bwc.conn.command( - bwc.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=bwc.session, - client=client, - ) - else: - request_id, msg, to_send = bwc.batch_command(cmd, ops) - await self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) + to_send = await self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) diff --git a/pymongo/synchronous/bulk.py b/pymongo/synchronous/bulk.py index ae47842053..4da64c4a78 100644 --- a/pymongo/synchronous/bulk.py +++ b/pymongo/synchronous/bulk.py @@ -91,7 +91,7 @@ def __init__( comment: Optional[str] = None, let: Optional[Any] = None, ) -> None: - """Initialize a _AsyncBulk instance.""" + """Initialize a _Bulk instance.""" self.collection = collection.with_options( codec_options=collection.codec_options._replace( unicode_decode_error_handler="replace", document_class=dict @@ -323,7 +323,7 @@ def unack_write( docs: list[Mapping[str, Any]], client: MongoClient, ) -> Optional[Mapping[str, Any]]: - """A proxy for AsyncConnection.unack_write that handles event publishing.""" + """A proxy for Connection.unack_write that handles event publishing.""" if _COMMAND_LOGGER.isEnabledFor(logging.DEBUG): _debug_log( _COMMAND_LOGGER, @@ -403,6 +403,56 @@ def unack_write( bwc.start_time = datetime.datetime.now() return result # type: ignore[return-value] + def _execute_batch_unack( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: MongoClient, + ) -> list[Mapping[str, Any]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + bwc.conn.command( # 
type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + write_concern=WriteConcern(w=0), + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + # Though this isn't strictly a "legacy" write, the helper + # handles publishing commands and sending our message + # without receiving a result. Send 0 for max_doc_size + # to disable size checking. Size checking is handled while + # the documents are encoded to BSON. + self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) # type: ignore[arg-type] + + return to_send + + def _execute_batch( + self, + bwc: Union[_BulkWriteContext, _EncryptedBulkWriteContext], + cmd: dict[str, Any], + ops: list[Mapping[str, Any]], + client: MongoClient, + ) -> tuple[dict[str, Any], list[Mapping[str, Any]]]: + if self.is_encrypted: + _, batched_cmd, to_send = bwc.batch_command(cmd, ops) + result = bwc.conn.command( # type: ignore[misc] + bwc.db_name, + batched_cmd, # type: ignore[arg-type] + codec_options=bwc.codec, + session=bwc.session, # type: ignore[arg-type] + client=client, # type: ignore[arg-type] + ) + else: + request_id, msg, to_send = bwc.batch_command(cmd, ops) + result = self.write_command(bwc, cmd, request_id, msg, to_send, client) # type: ignore[arg-type] + client._process_response(result, bwc.session) # type: ignore[arg-type] + + return result, to_send # type: ignore[return-value] + def _execute_command( self, generator: Iterator[Any], @@ -423,8 +473,8 @@ def _execute_command( self.next_run = None run = self.current_run - # AsyncConnection.command validates the session, but we use - # AsyncConnection.write_command + # Connection.command validates the session, but we use + # Connection.write_command conn.validate_session(client, session) last_run = False @@ -475,19 +525,7 @@ def _execute_command( # Run as many ops as possible in one command. if write_concern.acknowledged: - if self.is_encrypted: - _, batched_cmd, to_send = bwc.batch_command(cmd, ops) - result = bwc.conn.command( - bwc.db_name, - batched_cmd, - codec_options=bwc.codec, - session=bwc.session, - client=client, - ) - else: - request_id, msg, to_send = bwc.batch_command(cmd, ops) - result = self.write_command(bwc, cmd, request_id, msg, to_send, client) - client._process_response(result, bwc.session) + result, to_send = self._execute_batch(bwc, cmd, ops, client) # Retryable writeConcernErrors halt the execution of this run. wce = result.get("writeConcernError", {}) @@ -507,23 +545,7 @@ def _execute_command( if self.ordered and "writeErrors" in result: break else: - if self.is_encrypted: - _, batched_cmd, to_send = bwc.batch_command(cmd, ops) - bwc.conn.command( - bwc.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=bwc.session, - client=client, - ) - else: - request_id, msg, to_send = bwc.batch_command(cmd, ops) - # Though this isn't strictly a "legacy" write, the helper - # handles publishing commands and sending our message - # without receiving a result. Send 0 for max_doc_size - # to disable size checking. Size checking is handled while - # the documents are encoded to BSON. 
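# --- Editor's aside: illustrative sketch, not part of the patch. The w=0
# write concern used on this unacknowledged path mirrors how
# _do_batched_op_msg, earlier in this series, derives acknowledgement from
# the command itself: w=0 means fire-and-forget, anything else (or no
# writeConcern at all) is acknowledged:
def is_acknowledged(command):
    if "writeConcern" in command:
        return bool(command["writeConcern"].get("w", 1))
    return True

assert is_acknowledged({"insert": "coll"})
assert is_acknowledged({"insert": "coll", "writeConcern": {"w": "majority"}})
assert not is_acknowledged({"insert": "coll", "writeConcern": {"w": 0}})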
- self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) + to_send = self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) @@ -574,7 +596,7 @@ def retryable_bulk( retryable_bulk, session, operation, - bulk=self, + bulk=self, # type: ignore[arg-type] operation_id=op_id, ) @@ -615,18 +637,7 @@ def execute_op_msg_no_results(self, conn: Connection, generator: Iterator[Any]) conn.add_server_api(cmd) ops = islice(run.ops, run.idx_offset, None) # Run as many ops as possible. - if self.is_encrypted: - _, batched_cmd, to_send = bwc.batch_command(cmd, ops) - bwc.conn.command( - bwc.db_name, - batched_cmd, - write_concern=WriteConcern(w=0), - session=bwc.session, - client=client, - ) - else: - request_id, msg, to_send = bwc.batch_command(cmd, ops) - self.unack_write(bwc, cmd, request_id, msg, 0, to_send, client) + to_send = self._execute_batch_unack(bwc, cmd, ops, client) run.idx_offset += len(to_send) self.current_run = run = next(generator, None) diff --git a/pymongo/synchronous/helpers.py b/pymongo/synchronous/helpers.py index f581caae69..9b6809613e 100644 --- a/pymongo/synchronous/helpers.py +++ b/pymongo/synchronous/helpers.py @@ -47,7 +47,7 @@ def inner(*args: Any, **kwargs: Any) -> Any: if no_reauth: raise if exc.code == _REAUTHENTICATION_REQUIRED_CODE: - # Look for an argument that either is a AsyncConnection + # Look for an argument that either is a Connection # or has a connection attribute, so we can trigger # a reauth. conn = None @@ -69,7 +69,7 @@ def inner(*args: Any, **kwargs: Any) -> Any: def next(cls: Any) -> Any: - """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#anext.""" + """Compatibility function until we drop 3.9 support: https://docs.python.org/3/library/functions.html#next.""" if sys.version_info >= (3, 10): return builtins.next(cls) else: diff --git a/tools/synchro.py b/tools/synchro.py index 586c80fd81..b13a0a351d 100644 --- a/tools/synchro.py +++ b/tools/synchro.py @@ -128,31 +128,7 @@ ] -docstring_translate_files = ( - [ - _pymongo_dest_base + f - for f in [ - "aggregation.py", - "change_stream.py", - "collection.py", - "command_cursor.py", - "cursor.py", - "client_options.py", - "client_session.py", - "database.py", - "encryption.py", - "encryption_options.py", - "mongo_client.py", - "network.py", - "operations.py", - "pool.py", - "topology.py", - "server.py", - ] - ] - + [_gridfs_dest_base + f for f in ["grid_file.py"]] - + sync_test_files -) +docstring_translate_files = sync_files + sync_gridfs_files + sync_test_files def process_files(files: list[str]) -> None: From b5574ec787125e5e935f5caa930c23b7823240af Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 25 Jun 2024 15:08:52 -0700 Subject: [PATCH 11/11] Remove _IS_SYNC from moved modules --- pymongo/client_options.py | 2 -- pymongo/collation.py | 2 -- pymongo/common.py | 1 - pymongo/compression_support.py | 3 --- pymongo/encryption_options.py | 2 -- pymongo/event_loggers.py | 2 -- pymongo/hello.py | 2 -- pymongo/helpers_shared.py | 1 - pymongo/logger.py | 2 -- pymongo/max_staleness_selectors.py | 1 - pymongo/monitoring.py | 1 - pymongo/operations.py | 1 - pymongo/read_preferences.py | 1 - pymongo/response.py | 2 -- pymongo/server_description.py | 2 -- pymongo/server_selectors.py | 1 - pymongo/srv_resolver.py | 2 -- pymongo/topology_description.py | 2 -- pymongo/typings.py | 1 - pymongo/uri_parser.py | 1 - 20 files changed, 32 deletions(-) diff --git a/pymongo/client_options.py b/pymongo/client_options.py index 
From b5574ec787125e5e935f5caa930c23b7823240af Mon Sep 17 00:00:00 2001
From: Noah Stapp
Date: Tue, 25 Jun 2024 15:08:52 -0700
Subject: [PATCH 11/11] Remove _IS_SYNC from moved modules

---
 pymongo/client_options.py          | 2 --
 pymongo/collation.py               | 2 --
 pymongo/common.py                  | 1 -
 pymongo/compression_support.py     | 3 ---
 pymongo/encryption_options.py      | 2 --
 pymongo/event_loggers.py           | 2 --
 pymongo/hello.py                   | 2 --
 pymongo/helpers_shared.py          | 1 -
 pymongo/logger.py                  | 2 --
 pymongo/max_staleness_selectors.py | 1 -
 pymongo/monitoring.py              | 1 -
 pymongo/operations.py              | 1 -
 pymongo/read_preferences.py        | 1 -
 pymongo/response.py                | 2 --
 pymongo/server_description.py      | 2 --
 pymongo/server_selectors.py        | 1 -
 pymongo/srv_resolver.py            | 2 --
 pymongo/topology_description.py    | 2 --
 pymongo/typings.py                 | 1 -
 pymongo/uri_parser.py              | 1 -
 20 files changed, 32 deletions(-)

diff --git a/pymongo/client_options.py b/pymongo/client_options.py
index ddc22c3aff..2fb7b30c7b 100644
--- a/pymongo/client_options.py
+++ b/pymongo/client_options.py
@@ -40,8 +40,6 @@
     from pymongo.pyopenssl_context import SSLContext
     from pymongo.topology_description import _ServerSelector
 
-_IS_SYNC = False
-
 
 def _parse_credentials(
     username: str, password: str, database: Optional[str], options: Mapping[str, Any]
diff --git a/pymongo/collation.py b/pymongo/collation.py
index 115c8c7e88..9956872965 100644
--- a/pymongo/collation.py
+++ b/pymongo/collation.py
@@ -23,8 +23,6 @@
 from pymongo import common
 from pymongo.write_concern import validate_boolean
 
-_IS_SYNC = False
-
 
 class CollationStrength:
     """
diff --git a/pymongo/common.py b/pymongo/common.py
index 16f3ff2580..2ef9fa92f1 100644
--- a/pymongo/common.py
+++ b/pymongo/common.py
@@ -55,7 +55,6 @@
 if TYPE_CHECKING:
     from pymongo.typings import _AgnosticClientSession
 
-_IS_SYNC = False
 
 ORDERED_TYPES: Sequence[Type] = (SON, OrderedDict)
diff --git a/pymongo/compression_support.py b/pymongo/compression_support.py
index de7b856506..7123b90dfe 100644
--- a/pymongo/compression_support.py
+++ b/pymongo/compression_support.py
@@ -19,9 +19,6 @@
 from pymongo.hello import HelloCompat
 from pymongo.helpers_shared import _SENSITIVE_COMMANDS
 
-_IS_SYNC = False
-
-
 _SUPPORTED_COMPRESSORS = {"snappy", "zlib", "zstd"}
 _NO_COMPRESSION = {HelloCompat.CMD, HelloCompat.LEGACY_CMD}
 _NO_COMPRESSION.update(_SENSITIVE_COMMANDS)
diff --git a/pymongo/encryption_options.py b/pymongo/encryption_options.py
index f90bf8d129..f12e6e6f79 100644
--- a/pymongo/encryption_options.py
+++ b/pymongo/encryption_options.py
@@ -31,8 +31,6 @@
 if TYPE_CHECKING:
     from pymongo.typings import _AgnosticMongoClient, _DocumentTypeArg
 
-_IS_SYNC = False
-
 
 class AutoEncryptionOpts:
     """Options to configure automatic client-side field level encryption."""
diff --git a/pymongo/event_loggers.py b/pymongo/event_loggers.py
index 3a241df52b..287db3fc4d 100644
--- a/pymongo/event_loggers.py
+++ b/pymongo/event_loggers.py
@@ -32,8 +32,6 @@
 from pymongo import monitoring
 
-_IS_SYNC = False
-
 
 class CommandLogger(monitoring.CommandListener):
     """A simple listener that logs command events.
diff --git a/pymongo/hello.py b/pymongo/hello.py
index 89b51980b3..62bb799805 100644
--- a/pymongo/hello.py
+++ b/pymongo/hello.py
@@ -25,8 +25,6 @@
 from pymongo.server_type import SERVER_TYPE
 from pymongo.typings import ClusterTime, _DocumentType
 
-_IS_SYNC = False
-
 
 def _get_server_type(doc: Mapping[str, Any]) -> int:
     """Determine the server type from a hello response."""
diff --git a/pymongo/helpers_shared.py b/pymongo/helpers_shared.py
index c3324e162a..83ea2ddf78 100644
--- a/pymongo/helpers_shared.py
+++ b/pymongo/helpers_shared.py
@@ -49,7 +49,6 @@
     from pymongo.operations import _IndexList
     from pymongo.typings import _DocumentOut
 
-_IS_SYNC = False
 
 # From the SDAM spec, the "node is shutting down" codes.
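One practical effect of these removals is that the shared modules are now plain, variant-agnostic code reachable at their long-standing public paths. For example, `pymongo.event_loggers` keeps working unchanged with the synchronous client (the connection string below is a placeholder):

```python
import logging

from pymongo import MongoClient
from pymongo.event_loggers import CommandLogger

logging.basicConfig(level=logging.INFO)

# CommandLogger logs command started/succeeded/failed monitoring events.
client = MongoClient("mongodb://localhost:27017", event_listeners=[CommandLogger()])
client.admin.command("ping")  # emits log lines via the listener
```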
diff --git a/pymongo/logger.py b/pymongo/logger.py
index ed398c8329..2caafa778d 100644
--- a/pymongo/logger.py
+++ b/pymongo/logger.py
@@ -23,8 +23,6 @@
 from bson.json_util import JSONOptions, _truncate_documents
 from pymongo.monitoring import ConnectionCheckOutFailedReason, ConnectionClosedReason
 
-_IS_SYNC = False
-
 
 class _CommandStatusMessage(str, enum.Enum):
     STARTED = "Command started"
diff --git a/pymongo/max_staleness_selectors.py b/pymongo/max_staleness_selectors.py
index d9b2396a0c..89bfa65281 100644
--- a/pymongo/max_staleness_selectors.py
+++ b/pymongo/max_staleness_selectors.py
@@ -36,7 +36,6 @@
 if TYPE_CHECKING:
     from pymongo.server_selectors import Selection
 
-_IS_SYNC = False
 
 # Constant defined in Max Staleness Spec: An idle primary writes a no-op every
 # 10 seconds to refresh secondaries' lastWriteDate values.
diff --git a/pymongo/monitoring.py b/pymongo/monitoring.py
index 1e48905aee..260213e18b 100644
--- a/pymongo/monitoring.py
+++ b/pymongo/monitoring.py
@@ -200,7 +200,6 @@ def connection_checked_in(self, event):
 from pymongo.server_description import ServerDescription
 from pymongo.topology_description import TopologyDescription
 
-_IS_SYNC = False
 
 _Listeners = namedtuple(
     "_Listeners",
diff --git a/pymongo/operations.py b/pymongo/operations.py
index f43f8bdc8c..2967a29441 100644
--- a/pymongo/operations.py
+++ b/pymongo/operations.py
@@ -38,7 +38,6 @@
 if TYPE_CHECKING:
     from pymongo.typings import _AgnosticBulk
 
-_IS_SYNC = False
 
 # Hint supports index name, "myIndex", a list of either strings or index pairs: [('x', 1), ('y', -1), 'z'], or a dictionary
 _IndexList = Union[
diff --git a/pymongo/read_preferences.py b/pymongo/read_preferences.py
index 10deba7bbe..a7e138cd90 100644
--- a/pymongo/read_preferences.py
+++ b/pymongo/read_preferences.py
@@ -30,7 +30,6 @@
 from pymongo.server_selectors import Selection
 from pymongo.topology_description import TopologyDescription
 
-_IS_SYNC = False
 
 _PRIMARY = 0
 _PRIMARY_PREFERRED = 1
diff --git a/pymongo/response.py b/pymongo/response.py
index 850794567c..e47749423f 100644
--- a/pymongo/response.py
+++ b/pymongo/response.py
@@ -23,8 +23,6 @@
 from pymongo.message import _OpMsg, _OpReply
 from pymongo.typings import _Address, _AgnosticConnection, _DocumentOut
 
-_IS_SYNC = False
-
 
 class Response:
     __slots__ = ("_data", "_address", "_request_id", "_duration", "_from_command", "_docs")
diff --git a/pymongo/server_description.py b/pymongo/server_description.py
index 5a2e62837d..6393fce0a1 100644
--- a/pymongo/server_description.py
+++ b/pymongo/server_description.py
@@ -25,8 +25,6 @@
 from pymongo.server_type import SERVER_TYPE
 from pymongo.typings import ClusterTime, _Address
 
-_IS_SYNC = False
-
 
 class ServerDescription:
     """Immutable representation of one server.
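`read_preferences` and `max_staleness_selectors` likewise remain single, shared implementations of the server-selection knobs exposed through client options. A minimal usage sketch (URI and tag set are placeholders):

```python
from pymongo import MongoClient
from pymongo.read_preferences import Secondary

# maxStalenessSeconds is enforced by max_staleness_selectors during
# server selection; 90 seconds is the smallest value the spec allows.
client = MongoClient(
    "mongodb://localhost:27017/?replicaSet=rs0",
    readPreference="secondary",
    maxStalenessSeconds=90,
)

# Per-collection override using a tag-aware read preference.
coll = client.db.coll.with_options(read_preference=Secondary(tag_sets=[{"dc": "ny"}]))
```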
diff --git a/pymongo/server_selectors.py b/pymongo/server_selectors.py
index c0f7ad6ea6..c22ad599ee 100644
--- a/pymongo/server_selectors.py
+++ b/pymongo/server_selectors.py
@@ -23,7 +23,6 @@
 from pymongo.server_description import ServerDescription
 from pymongo.topology_description import TopologyDescription
 
-_IS_SYNC = False
 
 T = TypeVar("T")
 TagSet = Mapping[str, Any]
diff --git a/pymongo/srv_resolver.py b/pymongo/srv_resolver.py
index 2d699f9c1f..6f6cc285fa 100644
--- a/pymongo/srv_resolver.py
+++ b/pymongo/srv_resolver.py
@@ -25,8 +25,6 @@
 if TYPE_CHECKING:
     from dns import resolver
 
-_IS_SYNC = False
-
 
 def _have_dnspython() -> bool:
     try:
diff --git a/pymongo/topology_description.py b/pymongo/topology_description.py
index d28e11fc47..cc2330cbab 100644
--- a/pymongo/topology_description.py
+++ b/pymongo/topology_description.py
@@ -37,8 +37,6 @@
 from pymongo.server_type import SERVER_TYPE
 from pymongo.typings import _Address
 
-_IS_SYNC = False
-
 
 # Enumeration for various kinds of MongoDB cluster topologies.
 class _TopologyType(NamedTuple):
diff --git a/pymongo/typings.py b/pymongo/typings.py
index c89f5e2abc..9f6d7b1669 100644
--- a/pymongo/typings.py
+++ b/pymongo/typings.py
@@ -39,7 +39,6 @@
     from pymongo.synchronous.mongo_client import MongoClient
     from pymongo.synchronous.pool import Connection
 
-_IS_SYNC = False
 
 # Common Shared Types.
 _Address = Tuple[str, Optional[int]]
diff --git a/pymongo/uri_parser.py b/pymongo/uri_parser.py
index 4247d51fd1..4ebd3008c3 100644
--- a/pymongo/uri_parser.py
+++ b/pymongo/uri_parser.py
@@ -46,7 +46,6 @@
 if TYPE_CHECKING:
     from pymongo.pyopenssl_context import SSLContext
 
-_IS_SYNC = False
 
 SCHEME = "mongodb://"
 SCHEME_LEN = len(SCHEME)
 SRV_SCHEME = "mongodb+srv://"
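Finally, `uri_parser` is now a purely shared module, so both client variants funnel through the same `parse_uri` entry point. A quick sketch of its documented return shape (the URI is a placeholder; option keys are case-insensitive, so the exact casing in the printed dict may differ):

```python
from pymongo.uri_parser import parse_uri

parsed = parse_uri("mongodb://user:pw@localhost:27017/testdb?replicaSet=rs0")
print(parsed["nodelist"])  # [('localhost', 27017)]
print(parsed["database"])  # 'testdb'
print(parsed["options"])   # e.g. {'replicaset': 'rs0'}
```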