Skip to content

Commit

Permalink
Add async support for cancellation
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp committed Sep 20, 2024
1 parent 3d399da commit 7f71430
Showing 1 changed file with 24 additions and 5 deletions.
29 changes: 24 additions & 5 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ def _is_ready(fut: Future) -> None:
while read < length:
try:
read += conn.recv_into(mv[read:])
if read == 0:
raise OSError("connection closed")
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
# Check for closed socket.
Expand Down Expand Up @@ -195,11 +197,20 @@ def sendall(sock: Union[socket.socket, _sslConn], buf: bytes) -> None:
sock.sendall(buf)


async def _poll_cancellation(conn: AsyncConnection) -> None:
while True:
if conn.cancel_context.cancelled:
return

await asyncio.sleep(_POLL_TIMEOUT)


async def async_receive_data(
conn: AsyncConnection, length: int, deadline: Optional[float]
) -> memoryview:
sock = conn.conn
sock_timeout = sock.gettimeout()
timeout: Optional[Union[float, int]]
if deadline:
# When the timeout has expired perform one final check to
# see if the socket is readable. This helps avoid spurious
Expand All @@ -210,14 +221,22 @@ async def async_receive_data(

sock.settimeout(0.0)
loop = asyncio.get_event_loop()
cancellation_task = asyncio.create_task(_poll_cancellation(conn))
try:
if _HAVE_SSL and isinstance(sock, (SSLSocket, _sslConn)):
return await asyncio.wait_for(_async_receive_ssl(sock, length, loop), timeout=timeout)
read_task = asyncio.create_task(_async_receive_ssl(sock, length, loop)) # type: ignore[arg-type]
else:
return await asyncio.wait_for(_async_receive(sock, length, loop), timeout=timeout) # type: ignore[arg-type]
except asyncio.TimeoutError as exc:
# Convert the asyncio.wait_for timeout error to socket.timeout which pool.py understands.
raise socket.timeout("timed out") from exc
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
tasks = [read_task, cancellation_task]
result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
if len(result[1]) == 2:
raise socket.timeout("timed out")
finished = next(iter(result[0]))
next(iter(result[1])).cancel()
if finished == read_task:
return finished.result() # type: ignore[return-value]
else:
raise _OperationCancelled("operation cancelled")
finally:
sock.settimeout(sock_timeout)

Expand Down

0 comments on commit 7f71430

Please sign in to comment.