diff --git a/pymongo/network_layer.py b/pymongo/network_layer.py index 46805ad1cb..199dd6763f 100644 --- a/pymongo/network_layer.py +++ b/pymongo/network_layer.py @@ -127,7 +127,7 @@ async def _async_receive_ssl( ) -> memoryview: mv = memoryview(bytearray(length)) fd = conn.fileno() - read = 0 + total_read = 0 def _is_ready(fut: Future) -> None: loop.remove_writer(fd) @@ -136,11 +136,12 @@ def _is_ready(fut: Future) -> None: return fut.set_result(None) - while read < length: + while total_read < length: try: - read += conn.recv_into(mv[read:]) + read = conn.recv_into(mv[total_read:]) if read == 0: raise OSError("connection closed") + total_read += read except BLOCKING_IO_ERRORS as exc: fd = conn.fileno() # Check for closed socket. @@ -228,15 +229,19 @@ async def async_receive_data( else: read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type] tasks = [read_task, cancellation_task] - result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED) - if len(result[1]) == 2: + done, pending = await asyncio.wait( + tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED + ) + for task in pending: + task.cancel() + if len(done) == 0: raise socket.timeout("timed out") - finished = next(iter(result[0])) - next(iter(result[1])).cancel() - if finished == read_task: - return finished.result() # type: ignore[return-value] - else: - raise _OperationCancelled("operation cancelled") + for task in done: + if task == read_task: + return read_task.result() + else: + raise _OperationCancelled("operation cancelled") + return None # type: ignore[return-value] finally: sock.settimeout(sock_timeout) diff --git a/pymongo/pyopenssl_context.py b/pymongo/pyopenssl_context.py index 4f6f6f4a89..e521a92789 100644 --- a/pymongo/pyopenssl_context.py +++ b/pymongo/pyopenssl_context.py @@ -105,11 +105,16 @@ def _ragged_eof(exc: BaseException) -> bool: # https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets class _sslConn(_SSL.Connection): def __init__( - self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool + self, + ctx: _SSL.Context, + sock: Optional[_socket.socket], + suppress_ragged_eofs: bool, + is_async: bool = False, ): self.socket_checker = _SocketChecker() self.suppress_ragged_eofs = suppress_ragged_eofs super().__init__(ctx, sock) + self._is_async = is_async def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: timeout = self.gettimeout() @@ -119,6 +124,8 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T: try: return call(*args, **kwargs) except BLOCKING_IO_ERRORS as exc: + if self._is_async: + raise exc # Check for closed socket. if self.fileno() == -1: if timeout and _time.monotonic() - start > timeout: @@ -381,7 +388,7 @@ async def a_wrap_socket( """Wrap an existing Python socket connection and return a TLS socket object. """ - ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs) + ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True) loop = asyncio.get_running_loop() if session: ssl_conn.set_session(session)