Skip to content

Commit

Permalink
Async pyopenssl support
Browse files Browse the repository at this point in the history
  • Loading branch information
NoahStapp committed Sep 20, 2024
1 parent 7f71430 commit d69b5f6
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 13 deletions.
27 changes: 16 additions & 11 deletions pymongo/network_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def _async_receive_ssl(
) -> memoryview:
mv = memoryview(bytearray(length))
fd = conn.fileno()
read = 0
total_read = 0

def _is_ready(fut: Future) -> None:
loop.remove_writer(fd)
Expand All @@ -136,11 +136,12 @@ def _is_ready(fut: Future) -> None:
return
fut.set_result(None)

while read < length:
while total_read < length:
try:
read += conn.recv_into(mv[read:])
read = conn.recv_into(mv[total_read:])
if read == 0:
raise OSError("connection closed")
total_read += read
except BLOCKING_IO_ERRORS as exc:
fd = conn.fileno()
# Check for closed socket.
Expand Down Expand Up @@ -228,15 +229,19 @@ async def async_receive_data(
else:
read_task = asyncio.create_task(_async_receive(sock, length, loop)) # type: ignore[arg-type]
tasks = [read_task, cancellation_task]
result = await asyncio.wait(tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED)
if len(result[1]) == 2:
done, pending = await asyncio.wait(
tasks, timeout=timeout, return_when=asyncio.FIRST_COMPLETED
)
for task in pending:
task.cancel()
if len(done) == 0:
raise socket.timeout("timed out")
finished = next(iter(result[0]))
next(iter(result[1])).cancel()
if finished == read_task:
return finished.result() # type: ignore[return-value]
else:
raise _OperationCancelled("operation cancelled")
for task in done:
if task == read_task:
return read_task.result()
else:
raise _OperationCancelled("operation cancelled")
return None # type: ignore[return-value]
finally:
sock.settimeout(sock_timeout)

Expand Down
11 changes: 9 additions & 2 deletions pymongo/pyopenssl_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,16 @@ def _ragged_eof(exc: BaseException) -> bool:
# https://docs.python.org/3/library/ssl.html#notes-on-non-blocking-sockets
class _sslConn(_SSL.Connection):
def __init__(
self, ctx: _SSL.Context, sock: Optional[_socket.socket], suppress_ragged_eofs: bool
self,
ctx: _SSL.Context,
sock: Optional[_socket.socket],
suppress_ragged_eofs: bool,
is_async: bool = False,
):
self.socket_checker = _SocketChecker()
self.suppress_ragged_eofs = suppress_ragged_eofs
super().__init__(ctx, sock)
self._is_async = is_async

def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
timeout = self.gettimeout()
Expand All @@ -119,6 +124,8 @@ def _call(self, call: Callable[..., _T], *args: Any, **kwargs: Any) -> _T:
try:
return call(*args, **kwargs)
except BLOCKING_IO_ERRORS as exc:
if self._is_async:
raise exc
# Check for closed socket.
if self.fileno() == -1:
if timeout and _time.monotonic() - start > timeout:
Expand Down Expand Up @@ -381,7 +388,7 @@ async def a_wrap_socket(
"""Wrap an existing Python socket connection and return a TLS socket
object.
"""
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs)
ssl_conn = _sslConn(self._ctx, sock, suppress_ragged_eofs, True)
loop = asyncio.get_running_loop()
if session:
ssl_conn.set_session(session)
Expand Down

0 comments on commit d69b5f6

Please sign in to comment.