diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst index 513ecf0b..c32eef70 100644 --- a/docs/versionhistory.rst +++ b/docs/versionhistory.rst @@ -7,9 +7,10 @@ This library adheres to `Semantic Versioning 2.0 `_. - Fixed a misleading ``ValueError`` in the context of DNS failures (`#815 `_; PR by @graingert) -- Allowed ``wait_socket_readable`` and ``wait_socket_writable`` to accept a socket - file descriptor (`#824 `_) - (PR by @davidbrochart) +- Added ``wait_readable`` and ``wait_writable`` functions that accept an object with a + ``.fileno()`` method or an integer handle, and deprecated ``wait_socket_readable`` + and ``wait_socket_writable``. + (`#824 `_) (PR by @davidbrochart) **4.6.2** diff --git a/src/anyio/__init__.py b/src/anyio/__init__.py index fd9fe06b..0738e595 100644 --- a/src/anyio/__init__.py +++ b/src/anyio/__init__.py @@ -34,8 +34,10 @@ from ._core._sockets import create_unix_listener as create_unix_listener from ._core._sockets import getaddrinfo as getaddrinfo from ._core._sockets import getnameinfo as getnameinfo +from ._core._sockets import wait_readable as wait_readable from ._core._sockets import wait_socket_readable as wait_socket_readable from ._core._sockets import wait_socket_writable as wait_socket_writable +from ._core._sockets import wait_writable as wait_writable from ._core._streams import create_memory_object_stream as create_memory_object_stream from ._core._subprocesses import open_process as open_process from ._core._subprocesses import run_process as run_process diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index 50ba8cbb..9d071a6b 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -1722,8 +1722,8 @@ async def send(self, item: bytes) -> None: return -_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events") -_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events") +_read_events: RunVar[dict[socket.socket | int, asyncio.Event]] = RunVar("read_events") +_write_events: RunVar[dict[socket.socket | int, asyncio.Event]] = RunVar("write_events") # @@ -2675,7 +2675,7 @@ async def getnameinfo( return await get_running_loop().getnameinfo(sockaddr, flags) @classmethod - async def wait_socket_readable(cls, sock: HasFileno | int) -> None: + async def wait_socket_readable(cls, sock: socket.socket) -> None: await cls.checkpoint() try: read_events = _read_events.get() @@ -2683,9 +2683,6 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None: read_events = {} _read_events.set(read_events) - if not isinstance(sock, int): - sock = sock.fileno() - if read_events.get(sock): raise BusyResourceError("reading from") from None @@ -2705,7 +2702,7 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None: raise ClosedResourceError @classmethod - async def wait_socket_writable(cls, sock: HasFileno | int) -> None: + async def wait_socket_writable(cls, sock: socket.socket) -> None: await cls.checkpoint() try: write_events = _write_events.get() @@ -2713,15 +2710,12 @@ async def wait_socket_writable(cls, sock: HasFileno | int) -> None: write_events = {} _write_events.set(write_events) - if not isinstance(sock, int): - sock = sock.fileno() - if write_events.get(sock): raise BusyResourceError("writing to") from None loop = get_running_loop() event = write_events[sock] = asyncio.Event() - loop.add_writer(sock, event.set) + loop.add_writer(sock.fileno(), event.set) try: await event.wait() finally: @@ -2734,6 +2728,66 @@ async def wait_socket_writable(cls, sock: HasFileno | int) -> None: if not writable: raise ClosedResourceError + @classmethod + async def wait_readable(cls, obj: HasFileno | int) -> None: + await cls.checkpoint() + try: + read_events = _read_events.get() + except LookupError: + read_events = {} + _read_events.set(read_events) + + if not isinstance(obj, int): + obj = obj.fileno() + + if read_events.get(obj): + raise BusyResourceError("reading from") from None + + loop = get_running_loop() + event = read_events[obj] = asyncio.Event() + loop.add_reader(obj, event.set) + try: + await event.wait() + finally: + if read_events.pop(obj, None) is not None: + loop.remove_reader(obj) + readable = True + else: + readable = False + + if not readable: + raise ClosedResourceError + + @classmethod + async def wait_writable(cls, obj: HasFileno | int) -> None: + await cls.checkpoint() + try: + write_events = _write_events.get() + except LookupError: + write_events = {} + _write_events.set(write_events) + + if not isinstance(obj, int): + obj = obj.fileno() + + if write_events.get(obj): + raise BusyResourceError("writing to") from None + + loop = get_running_loop() + event = write_events[obj] = asyncio.Event() + loop.add_writer(obj, event.set) + try: + await event.wait() + finally: + if write_events.pop(obj, None) is not None: + loop.remove_writer(obj) + writable = True + else: + writable = False + + if not writable: + raise ClosedResourceError + @classmethod def current_default_thread_limiter(cls) -> CapacityLimiter: try: diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py index a6749e41..c33c97a6 100644 --- a/src/anyio/_backends/_trio.py +++ b/src/anyio/_backends/_trio.py @@ -1264,7 +1264,7 @@ async def getnameinfo( return await trio.socket.getnameinfo(sockaddr, flags) @classmethod - async def wait_socket_readable(cls, sock: HasFileno | int) -> None: + async def wait_socket_readable(cls, sock: socket.socket) -> None: try: await wait_readable(sock) except trio.ClosedResourceError as exc: @@ -1273,7 +1273,7 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None: raise BusyResourceError("reading from") from None @classmethod - async def wait_socket_writable(cls, sock: HasFileno | int) -> None: + async def wait_socket_writable(cls, sock: socket.socket) -> None: try: await wait_writable(sock) except trio.ClosedResourceError as exc: @@ -1281,6 +1281,24 @@ async def wait_socket_writable(cls, sock: HasFileno | int) -> None: except trio.BusyResourceError: raise BusyResourceError("writing to") from None + @classmethod + async def wait_readable(cls, obj: HasFileno | int) -> None: + try: + await wait_readable(obj) + except trio.ClosedResourceError as exc: + raise ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("reading from") from None + + @classmethod + async def wait_writable(cls, obj: HasFileno | int) -> None: + try: + await wait_writable(obj) + except trio.ClosedResourceError as exc: + raise ClosedResourceError().with_traceback(exc.__traceback__) from None + except trio.BusyResourceError: + raise BusyResourceError("writing to") from None + @classmethod def current_default_thread_limiter(cls) -> CapacityLimiter: try: diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py index c8c189ec..1c8e74e2 100644 --- a/src/anyio/_core/_sockets.py +++ b/src/anyio/_core/_sockets.py @@ -11,6 +11,7 @@ from os import PathLike, chmod from socket import AddressFamily, SocketKind from typing import TYPE_CHECKING, Any, Literal, cast, overload +from warnings import warn from .. import to_thread from ..abc import ( @@ -596,8 +597,10 @@ def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str return get_async_backend().getnameinfo(sockaddr, flags) -def wait_socket_readable(sock: HasFileno | int) -> Awaitable[None]: +def wait_socket_readable(sock: socket.socket) -> Awaitable[None]: """ + Deprecated, use `wait_readable` instead. + Wait until the given socket has data to be read. This does **NOT** work on Windows when using the asyncio backend with a proactor @@ -606,18 +609,25 @@ def wait_socket_readable(sock: HasFileno | int) -> Awaitable[None]: .. warning:: Only use this on raw sockets that have not been wrapped by any higher level constructs like socket streams! - :param sock: a socket object or its file descriptor + :param sock: a socket object :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the socket to become readable :raises ~anyio.BusyResourceError: if another task is already waiting for the socket to become readable """ + warn( + "This function is deprecated; use `wait_readable` instead", + DeprecationWarning, + stacklevel=2, + ) return get_async_backend().wait_socket_readable(sock) -def wait_socket_writable(sock: HasFileno | int) -> Awaitable[None]: +def wait_socket_writable(sock: socket.socket) -> Awaitable[None]: """ + Deprecated, use `wait_writable` instead. + Wait until the given socket can be written to. This does **NOT** work on Windows when using the asyncio backend with a proactor @@ -626,16 +636,73 @@ def wait_socket_writable(sock: HasFileno | int) -> Awaitable[None]: .. warning:: Only use this on raw sockets that have not been wrapped by any higher level constructs like socket streams! - :param sock: a socket object or its file descriptor + :param sock: a socket object :raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the socket to become writable :raises ~anyio.BusyResourceError: if another task is already waiting for the socket to become writable """ + warn( + "This function is deprecated; use `wait_writable` instead", + DeprecationWarning, + stacklevel=2, + ) return get_async_backend().wait_socket_writable(sock) +def wait_readable(obj: HasFileno | int) -> Awaitable[None]: + """ + Wait until the given object has data to be read. + + On Unix systems, ``obj`` must either be an integer file descriptor, or else an + object with a ``.fileno()`` method which returns an integer file descriptor. Any + kind of file descriptor can be passed, though the exact semantics will depend on + your kernel. For example, this probably won't do anything useful for on-disk files. + + On Windows systems, ``obj`` must either be an integer ``SOCKET`` handle, or else an + object with a ``.fileno()`` method which returns an integer ``SOCKET`` handle. File + descriptors aren't supported, and neither are handles that refer to anything besides + a ``SOCKET``. + + This does **NOT** work on Windows when using the asyncio backend with a proactor + event loop (default on py3.8+). + + .. warning:: Only use this on raw sockets that have not been wrapped by any higher + level constructs like socket streams! + + :param obj: an object with a ``.fileno()`` method or an integer handle. + :raises ~anyio.ClosedResourceError: if the object was closed while waiting for the + object to become readable + :raises ~anyio.BusyResourceError: if another task is already waiting for the object + to become readable + + """ + return get_async_backend().wait_readable(obj) + + +def wait_writable(obj: HasFileno | int) -> Awaitable[None]: + """ + Wait until the given object can be written to. + + See `wait_readable` for the definition of ``obj``. + + This does **NOT** work on Windows when using the asyncio backend with a proactor + event loop (default on py3.8+). + + .. warning:: Only use this on raw sockets that have not been wrapped by any higher + level constructs like socket streams! + + :param obj: an object with a ``.fileno()`` method or an integer handle. + :raises ~anyio.ClosedResourceError: if the object was closed while waiting for the + object to become writable + :raises ~anyio.BusyResourceError: if another task is already waiting for the object + to become writable + + """ + return get_async_backend().wait_writable(obj) + + # # Private API # diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py index 3d7866fd..05bf83d5 100644 --- a/src/anyio/abc/_eventloop.py +++ b/src/anyio/abc/_eventloop.py @@ -335,12 +335,22 @@ async def getnameinfo( @classmethod @abstractmethod - async def wait_socket_readable(cls, sock: HasFileno | int) -> None: + async def wait_socket_readable(cls, sock: socket) -> None: pass @classmethod @abstractmethod - async def wait_socket_writable(cls, sock: HasFileno | int) -> None: + async def wait_socket_writable(cls, sock: socket) -> None: + pass + + @classmethod + @abstractmethod + async def wait_readable(cls, obj: HasFileno | int) -> None: + pass + + @classmethod + @abstractmethod + async def wait_writable(cls, obj: HasFileno | int) -> None: pass @classmethod diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 8e376ab2..cf41d8e6 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -46,8 +46,10 @@ getnameinfo, move_on_after, wait_all_tasks_blocked, + wait_readable, wait_socket_readable, wait_socket_writable, + wait_writable, ) from anyio.abc import ( IPSockAddrType, @@ -1866,7 +1868,7 @@ async def test_wait_socket( if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy": pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop") - wait_socket = wait_socket_readable if event == "readable" else wait_socket_writable + wait = wait_readable if event == "readable" else wait_writable def client(port: int) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: @@ -1884,4 +1886,27 @@ def client(port: int) -> None: with conn: sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn with fail_after(10): - await wait_socket(sock_or_fd) + await wait(sock_or_fd) + + +async def test_deprecated_wait_socket(anyio_backend_name: str) -> None: + if anyio_backend_name == "asyncio" and sys.platform == "win32": + import asyncio + + policy = asyncio.get_event_loop_policy() + if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy": + pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop") + + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + with pytest.warns( + DeprecationWarning, + match="This function is deprecated; use `wait_readable` instead", + ): + with move_on_after(0.1): + await wait_socket_readable(sock) + with pytest.warns( + DeprecationWarning, + match="This function is deprecated; use `wait_writable` instead", + ): + with move_on_after(0.1): + await wait_socket_writable(sock)