Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PYTHON-4636 Stop blocking the I/O Loop for socket reads #1871

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 5 additions & 76 deletions pymongo/asynchronous/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,8 @@
"""Internal network layer helper methods."""
from __future__ import annotations

import asyncio
import datetime
import errno
import logging
import socket
import time
from typing import (
TYPE_CHECKING,
Expand All @@ -40,19 +37,16 @@
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
BLOCKING_IO_ERRORS,
async_receive_data,
async_sendall,
)
from pymongo.socket_checker import _errno_from_exception

if TYPE_CHECKING:
from bson import CodecOptions
Expand Down Expand Up @@ -318,9 +312,7 @@ async def receive_message(
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(
await _receive_data_on_socket(conn, 16, deadline)
)
length, _, response_to, op_code = _UNPACK_HEADER(await async_receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
Expand All @@ -336,11 +328,11 @@ async def receive_message(
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
await _receive_data_on_socket(conn, 9, deadline)
await async_receive_data(conn, 9, deadline)
)
data = decompress(await _receive_data_on_socket(conn, length - 25, deadline), compressor_id)
data = decompress(await async_receive_data(conn, length - 25, deadline), compressor_id)
else:
data = await _receive_data_on_socket(conn, length - 16, deadline)
data = await async_receive_data(conn, length - 16, deadline)

try:
unpack_reply = _UNPACK_REPLY[op_code]
Expand All @@ -349,66 +341,3 @@ async def receive_message(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)


async def wait_for_read(conn: AsyncConnection, deadline: Optional[float]) -> None:
    """Block until at least one byte is read, or a timeout, or a cancel.

    Polls the connection's raw socket in up-to-500ms slices so that a
    cancellation (``conn.cancel_context``) is noticed promptly even while
    waiting for data.  Raises ``_OperationCancelled`` on cancel and
    ``socket.timeout`` when ``deadline`` (a ``time.monotonic()`` value)
    has passed.  Returns silently when the socket is already closed.

    NOTE(review): ``socket_checker.select`` presumably performs a blocking
    ``select()``, which stalls the running event loop for up to
    ``_POLL_TIMEOUT`` seconds per iteration — this is the behavior the PR
    replaces with ``async_receive_data``.
    """
    sock = conn.conn
    timed_out = False
    # Check if the connection's socket has been manually closed
    if sock.fileno() == -1:
        return
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        if hasattr(sock, "pending") and sock.pending() > 0:
            readable = True
        else:
            # Wait up to 500ms for the socket to become readable and then
            # check for cancellation.
            if deadline:
                remaining = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                if remaining <= 0:
                    timed_out = True
                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
            else:
                timeout = _POLL_TIMEOUT
            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
        # Cancellation takes priority over both readability and timeout.
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if readable:
            return
        if timed_out:
            raise socket.timeout("timed out")
        # Yield once to the event loop so other tasks get a chance to run
        # between poll slices.
        await asyncio.sleep(0)


async def _receive_data_on_socket(
    conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` and return them as a memoryview.

    Waits for readability via ``wait_for_read`` (which honors ``deadline``
    and cancellation) before each ``recv_into``.  Raises ``socket.timeout``
    on timeout and ``OSError("connection closed")`` if the peer closes the
    connection mid-message.
    """
    buf = bytearray(length)
    mv = memoryview(buf)
    bytes_read = 0
    while bytes_read < length:
        try:
            await wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            chunk_length = conn.conn.recv_into(mv[bytes_read:])
        except BLOCKING_IO_ERRORS:
            # A would-block error after the deadline-driven recv means the
            # data did not arrive in time.
            raise socket.timeout("timed out") from None
        except OSError as exc:
            # recv interrupted by a signal: retry; anything else propagates.
            if _errno_from_exception(exc) == errno.EINTR:
                continue
            raise
        # recv_into returning 0 means the peer performed an orderly shutdown.
        if chunk_length == 0:
            raise OSError("connection closed")

        bytes_read += chunk_length

    return mv
115 changes: 114 additions & 1 deletion pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,21 @@
from __future__ import annotations

import asyncio
import errno
import socket
import struct
import sys
import time
from asyncio import AbstractEventLoop, Future
from typing import (
TYPE_CHECKING,
Optional,
Union,
)

from pymongo import ssl_support
from pymongo import _csot, ssl_support
from pymongo.errors import _OperationCancelled
from pymongo.socket_checker import _errno_from_exception

try:
from ssl import SSLError, SSLSocket
Expand All @@ -51,6 +57,10 @@
BLOCKING_IO_WRITE_ERROR,
)

if TYPE_CHECKING:
from pymongo.asynchronous.pool import AsyncConnection
from pymongo.synchronous.pool import Connection

_UNPACK_HEADER = struct.Struct("<iiii").unpack
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack
_POLL_TIMEOUT = 0.5
Expand Down Expand Up @@ -131,3 +141,106 @@ async def _async_sendall_ssl(

def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
sock.sendall(buf)


async def async_receive_data(
    conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` without blocking the event loop.

    :param conn: the connection whose raw socket (``conn.conn``) is read.
    :param length: exact number of bytes to receive.
    :param deadline: optional absolute ``time.monotonic()`` deadline; when
        set, the overall wait is capped at ``deadline - now`` (floored at 0
        so an already-expired deadline still permits one final non-blocking
        read of locally buffered data — helps on AWS Lambda/FaaS).
    :return: a memoryview over the received bytes.
    :raises socket.timeout: when the deadline (or socket timeout) elapses.
    """
    sock = conn.conn
    sock_timeout = sock.gettimeout()
    if deadline:
        # When the timeout has expired perform one final check to
        # see if the socket is readable. This helps avoid spurious
        # timeouts on AWS Lambda and other FaaS environments.
        timeout = max(deadline - time.monotonic(), 0)
    else:
        timeout = sock_timeout

    # The helpers below use non-blocking reads driven by the event loop.
    sock.settimeout(0.0)
    # get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here and can bind the wrong loop.
    loop = asyncio.get_running_loop()
    try:
        if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
            return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout)
        else:
            return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout)  # type: ignore[arg-type]
    except asyncio.TimeoutError as exc:
        # Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
        raise socket.timeout("timed out") from exc
    finally:
        # Always restore the caller's socket timeout, even on error.
        sock.settimeout(sock_timeout)


async def _async_receive(conn: socket.socket, length: int, loop: AbstractEventLoop) -> memoryview:
mv = memoryview(bytearray(length))
bytes_read = 0
while bytes_read < length:
chunk_length = await loop.sock_recv_into(conn, mv[bytes_read:])
if chunk_length == 0:
raise OSError("connection closed")
bytes_read += chunk_length
return mv


async def _async_receive_ssl(conn: _sslConn, length: int, loop: AbstractEventLoop) -> memoryview:  # noqa: ARG001
    """Asynchronously read exactly ``length`` bytes from an SSL socket.

    Fixes the placeholder that returned ``memoryview(b"")``: every TLS read
    would have yielded zero bytes.  ``loop.sock_recv_into`` does not support
    SSL sockets, so read with non-blocking ``recv_into`` directly and yield
    to the event loop whenever no decrypted data is ready.

    :raises OSError: with "connection closed" if the peer shuts down early.
    """
    mv = memoryview(bytearray(length))
    total_read = 0
    while total_read < length:
        try:
            read = conn.recv_into(mv[total_read:])
        except BLOCKING_IO_ERRORS:
            # No decrypted bytes available yet (SSLWantRead/Write et al.);
            # yield control so other tasks can run, then retry.
            await asyncio.sleep(0)
            continue
        if read == 0:
            raise OSError("connection closed")
        total_read += read
    return mv


# Sync version:
def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Block until at least one byte is read, or a timeout, or a cancel.

    Polls in slices of at most ``_POLL_TIMEOUT`` seconds so cancellation is
    observed promptly.  Raises ``_OperationCancelled`` when cancelled and
    ``socket.timeout`` once ``deadline`` has passed with no data.
    """
    sock = conn.conn
    # A manually closed socket reports fileno() == -1; nothing to wait on.
    if sock.fileno() == -1:
        return
    deadline_passed = False
    while True:
        # An SSLSocket may hold already-decrypted bytes in its internal
        # buffer that select() cannot see — check those first.
        if hasattr(sock, "pending") and sock.pending() > 0:
            got_data = True
        else:
            if deadline:
                remaining = deadline - time.monotonic()
                # Once the deadline passes, still perform one last zero-wait
                # select: avoids spurious timeouts on AWS Lambda and other
                # FaaS environments where data may already be queued.
                deadline_passed = remaining <= 0
                poll_window = max(min(remaining, _POLL_TIMEOUT), 0)
            else:
                poll_window = _POLL_TIMEOUT
            got_data = conn.socket_checker.select(sock, read=True, timeout=poll_window)
        # Cancellation wins over both readability and timeout.
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if got_data:
            return
        if deadline_passed:
            raise socket.timeout("timed out")


def receive_data(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
    """Read exactly ``length`` bytes from ``conn``, honoring deadline and cancel.

    Raises ``socket.timeout`` on timeout and ``OSError("connection closed")``
    if the peer disconnects before the full message arrives.
    """
    view = memoryview(bytearray(length))
    received = 0
    while received < length:
        try:
            wait_for_read(conn, deadline)
            # CSOT: shrink the socket timeout as the deadline approaches.
            # At zero we still attempt one final non-blocking recv in case
            # the response is already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            n = conn.conn.recv_into(view[received:])
        except BLOCKING_IO_ERRORS:
            raise socket.timeout("timed out") from None
        except OSError as exc:
            # Retry reads interrupted by a signal; re-raise everything else.
            if _errno_from_exception(exc) != errno.EINTR:
                raise
            continue
        if n == 0:
            # recv_into returning 0 means an orderly shutdown by the peer.
            raise OSError("connection closed")
        received += n
    return view
77 changes: 5 additions & 72 deletions pymongo/synchronous/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@
from __future__ import annotations

import datetime
import errno
import logging
import socket
import time
from typing import (
TYPE_CHECKING,
Expand All @@ -39,19 +37,16 @@
NotPrimaryError,
OperationFailure,
ProtocolError,
_OperationCancelled,
)
from pymongo.logger import _COMMAND_LOGGER, _CommandStatusMessage, _debug_log
from pymongo.message import _UNPACK_REPLY, _OpMsg, _OpReply
from pymongo.monitoring import _is_speculative_authenticate
from pymongo.network_layer import (
_POLL_TIMEOUT,
_UNPACK_COMPRESSION_HEADER,
_UNPACK_HEADER,
BLOCKING_IO_ERRORS,
receive_data,
sendall,
)
from pymongo.socket_checker import _errno_from_exception

if TYPE_CHECKING:
from bson import CodecOptions
Expand Down Expand Up @@ -317,7 +312,7 @@ def receive_message(
else:
deadline = None
# Ignore the response's request id.
length, _, response_to, op_code = _UNPACK_HEADER(_receive_data_on_socket(conn, 16, deadline))
length, _, response_to, op_code = _UNPACK_HEADER(receive_data(conn, 16, deadline))
# No request_id for exhaust cursor "getMore".
if request_id is not None:
if request_id != response_to:
Expand All @@ -332,12 +327,10 @@ def receive_message(
f"message size ({max_message_size!r})"
)
if op_code == 2012:
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
_receive_data_on_socket(conn, 9, deadline)
)
data = decompress(_receive_data_on_socket(conn, length - 25, deadline), compressor_id)
op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(receive_data(conn, 9, deadline))
data = decompress(receive_data(conn, length - 25, deadline), compressor_id)
else:
data = _receive_data_on_socket(conn, length - 16, deadline)
data = receive_data(conn, length - 16, deadline)

try:
unpack_reply = _UNPACK_REPLY[op_code]
Expand All @@ -346,63 +339,3 @@ def receive_message(
f"Got opcode {op_code!r} but expected {_UNPACK_REPLY.keys()!r}"
) from None
return unpack_reply(data)


def wait_for_read(conn: Connection, deadline: Optional[float]) -> None:
    """Block until at least one byte is read, or a timeout, or a cancel.

    Polls the raw socket in up-to-500ms slices so a cancellation via
    ``conn.cancel_context`` is noticed promptly.  Raises
    ``_OperationCancelled`` on cancel and ``socket.timeout`` when
    ``deadline`` (an absolute ``time.monotonic()`` value) has passed.
    Returns immediately when the socket has already been closed.
    """
    sock = conn.conn
    timed_out = False
    # Check if the connection's socket has been manually closed
    if sock.fileno() == -1:
        return
    while True:
        # SSLSocket can have buffered data which won't be caught by select.
        if hasattr(sock, "pending") and sock.pending() > 0:
            readable = True
        else:
            # Wait up to 500ms for the socket to become readable and then
            # check for cancellation.
            if deadline:
                remaining = deadline - time.monotonic()
                # When the timeout has expired perform one final check to
                # see if the socket is readable. This helps avoid spurious
                # timeouts on AWS Lambda and other FaaS environments.
                if remaining <= 0:
                    timed_out = True
                timeout = max(min(remaining, _POLL_TIMEOUT), 0)
            else:
                timeout = _POLL_TIMEOUT
            readable = conn.socket_checker.select(sock, read=True, timeout=timeout)
        # Cancellation takes priority over readability and timeout.
        if conn.cancel_context.cancelled:
            raise _OperationCancelled("operation cancelled")
        if readable:
            return
        if timed_out:
            raise socket.timeout("timed out")


def _receive_data_on_socket(conn: Connection, length: int, deadline: Optional[float]) -> memoryview:
    """Read exactly ``length`` bytes from ``conn`` and return them as a memoryview.

    Waits for readability via ``wait_for_read`` (which honors ``deadline``
    and cancellation) before each ``recv_into``.  Raises ``socket.timeout``
    on timeout and ``OSError("connection closed")`` if the peer disconnects
    before the full message arrives.
    """
    buf = bytearray(length)
    mv = memoryview(buf)
    bytes_read = 0
    while bytes_read < length:
        try:
            wait_for_read(conn, deadline)
            # CSOT: Update timeout. When the timeout has expired perform one
            # final non-blocking recv. This helps avoid spurious timeouts when
            # the response is actually already buffered on the client.
            if _csot.get_timeout() and deadline is not None:
                conn.set_conn_timeout(max(deadline - time.monotonic(), 0))
            chunk_length = conn.conn.recv_into(mv[bytes_read:])
        except BLOCKING_IO_ERRORS:
            # A would-block error after waiting means the data did not
            # arrive within the deadline.
            raise socket.timeout("timed out") from None
        except OSError as exc:
            # recv interrupted by a signal: retry; anything else propagates.
            if _errno_from_exception(exc) == errno.EINTR:
                continue
            raise
        # recv_into returning 0 means the peer performed an orderly shutdown.
        if chunk_length == 0:
            raise OSError("connection closed")

        bytes_read += chunk_length

    return mv
1 change: 1 addition & 0 deletions tools/synchro.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
"AsyncConnection": "Connection",
"async_command": "command",
"async_receive_message": "receive_message",
"async_receive_data": "receive_data",
"async_sendall": "sendall",
"asynchronous": "synchronous",
"Asynchronous": "Synchronous",
Expand Down
Loading