Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowed wait_socket_readable/writable to accept a file descriptor #824

Merged
merged 18 commits into from
Nov 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,10 @@ Sockets and networking
.. autofunction:: anyio.create_connected_udp_socket
.. autofunction:: anyio.getaddrinfo
.. autofunction:: anyio.getnameinfo
.. autofunction:: anyio.wait_readable
.. autofunction:: anyio.wait_socket_readable
.. autofunction:: anyio.wait_socket_writable
.. autofunction:: anyio.wait_writable

.. autoclass:: anyio.abc.SocketAttribute
.. autoclass:: anyio.abc.SocketStream()
Expand Down
4 changes: 4 additions & 0 deletions docs/versionhistory.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ This library adheres to `Semantic Versioning 2.0 <http://semver.org/>`_.

- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 <https://github.com/agronholm/anyio/issues/815>`_; PR by @graingert)
- Added the ``wait_readable()`` and ``wait_writable()`` functions which will accept
an object with a ``.fileno()`` method or an integer handle, and deprecated
their now obsolete versions (``wait_socket_readable()`` and
``wait_socket_writable()`` (PR by @davidbrochart)

**4.6.2**

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ dependencies = [
"exceptiongroup >= 1.0.2; python_version < '3.11'",
"idna >= 2.8",
"sniffio >= 1.1",
"typing_extensions >= 4.1; python_version < '3.11'",
"typing_extensions >= 4.5; python_version < '3.13'",
]
dynamic = ["version"]

Expand Down
2 changes: 2 additions & 0 deletions src/anyio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,10 @@
from ._core._sockets import create_unix_listener as create_unix_listener
from ._core._sockets import getaddrinfo as getaddrinfo
from ._core._sockets import getnameinfo as getnameinfo
from ._core._sockets import wait_readable as wait_readable
from ._core._sockets import wait_socket_readable as wait_socket_readable
from ._core._sockets import wait_socket_writable as wait_socket_writable
from ._core._sockets import wait_writable as wait_writable
from ._core._streams import create_memory_object_stream as create_memory_object_stream
from ._core._subprocesses import open_process as open_process
from ._core._subprocesses import run_process as run_process
Expand Down
38 changes: 24 additions & 14 deletions src/anyio/_backends/_asyncio.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Optional,
TypeVar,
Expand Down Expand Up @@ -99,6 +100,9 @@
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream

if TYPE_CHECKING:
from _typeshed import HasFileno

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
Expand Down Expand Up @@ -1718,8 +1722,8 @@ async def send(self, item: bytes) -> None:
return


_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")


#
Expand Down Expand Up @@ -2671,25 +2675,28 @@ async def getnameinfo(
return await get_running_loop().getnameinfo(sockaddr, flags)

@classmethod
async def wait_socket_readable(cls, sock: socket.socket) -> None:
async def wait_readable(cls, obj: HasFileno | int) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
except LookupError:
read_events = {}
_read_events.set(read_events)

if read_events.get(sock):
if not isinstance(obj, int):
obj = obj.fileno()

if read_events.get(obj):
raise BusyResourceError("reading from") from None

loop = get_running_loop()
event = read_events[sock] = asyncio.Event()
loop.add_reader(sock, event.set)
event = read_events[obj] = asyncio.Event()
loop.add_reader(obj, event.set)
try:
await event.wait()
finally:
if read_events.pop(sock, None) is not None:
loop.remove_reader(sock)
if read_events.pop(obj, None) is not None:
loop.remove_reader(obj)
readable = True
else:
readable = False
Expand All @@ -2698,25 +2705,28 @@ async def wait_socket_readable(cls, sock: socket.socket) -> None:
raise ClosedResourceError

@classmethod
async def wait_socket_writable(cls, sock: socket.socket) -> None:
async def wait_writable(cls, obj: HasFileno | int) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
except LookupError:
write_events = {}
_write_events.set(write_events)

if write_events.get(sock):
if not isinstance(obj, int):
obj = obj.fileno()

if write_events.get(obj):
raise BusyResourceError("writing to") from None

loop = get_running_loop()
event = write_events[sock] = asyncio.Event()
loop.add_writer(sock.fileno(), event.set)
event = write_events[obj] = asyncio.Event()
loop.add_writer(obj, event.set)
try:
await event.wait()
finally:
if write_events.pop(sock, None) is not None:
loop.remove_writer(sock)
if write_events.pop(obj, None) is not None:
loop.remove_writer(obj)
writable = True
else:
writable = False
Expand Down
12 changes: 8 additions & 4 deletions src/anyio/_backends/_trio.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from types import TracebackType
from typing import (
IO,
TYPE_CHECKING,
Any,
Generic,
NoReturn,
Expand Down Expand Up @@ -80,6 +81,9 @@
from ..abc._eventloop import AsyncBackend, StrOrBytesPath
from ..streams.memory import MemoryObjectSendStream

if TYPE_CHECKING:
from _typeshed import HasFileno

if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
Expand Down Expand Up @@ -1260,18 +1264,18 @@ async def getnameinfo(
return await trio.socket.getnameinfo(sockaddr, flags)

@classmethod
async def wait_socket_readable(cls, sock: socket.socket) -> None:
async def wait_readable(cls, obj: HasFileno | int) -> None:
try:
await wait_readable(sock)
await wait_readable(obj)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
raise BusyResourceError("reading from") from None

@classmethod
async def wait_socket_writable(cls, sock: socket.socket) -> None:
async def wait_writable(cls, obj: HasFileno | int) -> None:
try:
await wait_writable(sock)
await wait_writable(obj)
except trio.ClosedResourceError as exc:
raise ClosedResourceError().with_traceback(exc.__traceback__) from None
except trio.BusyResourceError:
Expand Down
77 changes: 74 additions & 3 deletions src/anyio/_core/_sockets.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from ipaddress import IPv6Address, ip_address
from os import PathLike, chmod
from socket import AddressFamily, SocketKind
from typing import Any, Literal, cast, overload
from typing import TYPE_CHECKING, Any, Literal, cast, overload

from .. import to_thread
from ..abc import (
Expand All @@ -31,9 +31,19 @@
from ._synchronization import Event
from ._tasks import create_task_group, move_on_after

if TYPE_CHECKING:
from _typeshed import HasFileno
else:
HasFileno = object

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup

if sys.version_info < (3, 13):
from typing_extensions import deprecated
else:
from warnings import deprecated

IPPROTO_IPV6 = getattr(socket, "IPPROTO_IPV6", 41) # https://bugs.python.org/issue29515

AnyIPAddressFamily = Literal[
Expand Down Expand Up @@ -591,8 +601,12 @@ def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str
return get_async_backend().getnameinfo(sockaddr, flags)


@deprecated("This function is deprecated; use `wait_readable` instead")
def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
"""
.. deprecated:: 4.7.0
Use :func:`wait_readable` instead.

Wait until the given socket has data to be read.

This does **NOT** work on Windows when using the asyncio backend with a proactor
Expand All @@ -608,11 +622,15 @@ def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
to become readable

"""
return get_async_backend().wait_socket_readable(sock)
return get_async_backend().wait_readable(sock.fileno())


@deprecated("This function is deprecated; use `wait_writable` instead")
def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
"""
.. deprecated:: 4.7.0
Use :func:`wait_writable` instead.

Wait until the given socket can be written to.

This does **NOT** work on Windows when using the asyncio backend with a proactor
Expand All @@ -628,7 +646,60 @@ def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
to become writable

"""
return get_async_backend().wait_socket_writable(sock)
return get_async_backend().wait_writable(sock.fileno())


def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
"""
Wait until the given object has data to be read.

On Unix systems, ``obj`` must either be an integer file descriptor, or else an
object with a ``.fileno()`` method which returns an integer file descriptor. Any
kind of file descriptor can be passed, though the exact semantics will depend on
your kernel. For example, this probably won't do anything useful for on-disk files.

On Windows systems, ``obj`` must either be an integer ``SOCKET`` handle, or else an
object with a ``.fileno()`` method which returns an integer ``SOCKET`` handle. File
descriptors aren't supported, and neither are handles that refer to anything besides
a ``SOCKET``.

This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).

.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!

:param obj: an object with a ``.fileno()`` method or an integer handle
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
object to become readable
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
to become readable

"""
return get_async_backend().wait_readable(obj)


def wait_writable(obj: HasFileno | int) -> Awaitable[None]:
"""
Wait until the given object can be written to.

This does **NOT** work on Windows when using the asyncio backend with a proactor
event loop (default on py3.8+).

.. seealso:: See the documentation of :func:`wait_readable` for the definition of
``obj``.

.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!

:param obj: an object with a ``.fileno()`` method or an integer handle
:raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
object to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the object
to become writable

"""
return get_async_backend().wait_writable(obj)


#
Expand Down
6 changes: 4 additions & 2 deletions src/anyio/abc/_eventloop.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from typing_extensions import TypeAlias

if TYPE_CHECKING:
from _typeshed import HasFileno

from .._core._synchronization import CapacityLimiter, Event, Lock, Semaphore
from .._core._tasks import CancelScope
from .._core._testing import TaskInfo
Expand Down Expand Up @@ -333,12 +335,12 @@ async def getnameinfo(

@classmethod
@abstractmethod
async def wait_socket_readable(cls, sock: socket) -> None:
async def wait_readable(cls, obj: HasFileno | int) -> None:
pass

@classmethod
@abstractmethod
async def wait_socket_writable(cls, sock: socket) -> None:
async def wait_writable(cls, obj: HasFileno | int) -> None:
pass

@classmethod
Expand Down
Loading