diff --git a/docs/versionhistory.rst b/docs/versionhistory.rst
index 513ecf0b..c32eef70 100644
--- a/docs/versionhistory.rst
+++ b/docs/versionhistory.rst
@@ -7,9 +7,10 @@ This library adheres to `Semantic Versioning 2.0 `_.
- Fixed a misleading ``ValueError`` in the context of DNS failures
(`#815 `_; PR by @graingert)
-- Allowed ``wait_socket_readable`` and ``wait_socket_writable`` to accept a socket
- file descriptor (`#824 `_)
- (PR by @davidbrochart)
+- Added ``wait_readable`` and ``wait_writable`` functions that accept an object with a
+ ``.fileno()`` method or an integer handle, and deprecated ``wait_socket_readable``
+ and ``wait_socket_writable``.
+ (`#824 `_) (PR by @davidbrochart)
**4.6.2**
diff --git a/src/anyio/__init__.py b/src/anyio/__init__.py
index fd9fe06b..0738e595 100644
--- a/src/anyio/__init__.py
+++ b/src/anyio/__init__.py
@@ -34,8 +34,10 @@
from ._core._sockets import create_unix_listener as create_unix_listener
from ._core._sockets import getaddrinfo as getaddrinfo
from ._core._sockets import getnameinfo as getnameinfo
+from ._core._sockets import wait_readable as wait_readable
from ._core._sockets import wait_socket_readable as wait_socket_readable
from ._core._sockets import wait_socket_writable as wait_socket_writable
+from ._core._sockets import wait_writable as wait_writable
from ._core._streams import create_memory_object_stream as create_memory_object_stream
from ._core._subprocesses import open_process as open_process
from ._core._subprocesses import run_process as run_process
diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py
index 50ba8cbb..9d071a6b 100644
--- a/src/anyio/_backends/_asyncio.py
+++ b/src/anyio/_backends/_asyncio.py
@@ -1722,8 +1722,8 @@ async def send(self, item: bytes) -> None:
return
-_read_events: RunVar[dict[int, asyncio.Event]] = RunVar("read_events")
-_write_events: RunVar[dict[int, asyncio.Event]] = RunVar("write_events")
+_read_events: RunVar[dict[socket.socket | int, asyncio.Event]] = RunVar("read_events")
+_write_events: RunVar[dict[socket.socket | int, asyncio.Event]] = RunVar("write_events")
#
@@ -2675,7 +2675,7 @@ async def getnameinfo(
return await get_running_loop().getnameinfo(sockaddr, flags)
@classmethod
- async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
+ async def wait_socket_readable(cls, sock: socket.socket) -> None:
await cls.checkpoint()
try:
read_events = _read_events.get()
@@ -2683,9 +2683,6 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
read_events = {}
_read_events.set(read_events)
- if not isinstance(sock, int):
- sock = sock.fileno()
-
if read_events.get(sock):
raise BusyResourceError("reading from") from None
@@ -2705,7 +2702,7 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
raise ClosedResourceError
@classmethod
- async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
+ async def wait_socket_writable(cls, sock: socket.socket) -> None:
await cls.checkpoint()
try:
write_events = _write_events.get()
@@ -2713,15 +2710,12 @@ async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
write_events = {}
_write_events.set(write_events)
- if not isinstance(sock, int):
- sock = sock.fileno()
-
if write_events.get(sock):
raise BusyResourceError("writing to") from None
loop = get_running_loop()
event = write_events[sock] = asyncio.Event()
- loop.add_writer(sock, event.set)
+ loop.add_writer(sock.fileno(), event.set)
try:
await event.wait()
finally:
@@ -2734,6 +2728,66 @@ async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
if not writable:
raise ClosedResourceError
+ @classmethod
+ async def wait_readable(cls, obj: HasFileno | int) -> None:
+ await cls.checkpoint()
+ try:
+ read_events = _read_events.get()
+ except LookupError:
+ read_events = {}
+ _read_events.set(read_events)
+
+ if not isinstance(obj, int):
+ obj = obj.fileno()
+
+ if read_events.get(obj):
+ raise BusyResourceError("reading from") from None
+
+ loop = get_running_loop()
+ event = read_events[obj] = asyncio.Event()
+ loop.add_reader(obj, event.set)
+ try:
+ await event.wait()
+ finally:
+ if read_events.pop(obj, None) is not None:
+ loop.remove_reader(obj)
+ readable = True
+ else:
+ readable = False
+
+ if not readable:
+ raise ClosedResourceError
+
+ @classmethod
+ async def wait_writable(cls, obj: HasFileno | int) -> None:
+ await cls.checkpoint()
+ try:
+ write_events = _write_events.get()
+ except LookupError:
+ write_events = {}
+ _write_events.set(write_events)
+
+ if not isinstance(obj, int):
+ obj = obj.fileno()
+
+ if write_events.get(obj):
+ raise BusyResourceError("writing to") from None
+
+ loop = get_running_loop()
+ event = write_events[obj] = asyncio.Event()
+ loop.add_writer(obj, event.set)
+ try:
+ await event.wait()
+ finally:
+ if write_events.pop(obj, None) is not None:
+ loop.remove_writer(obj)
+ writable = True
+ else:
+ writable = False
+
+ if not writable:
+ raise ClosedResourceError
+
@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
try:
diff --git a/src/anyio/_backends/_trio.py b/src/anyio/_backends/_trio.py
index a6749e41..c33c97a6 100644
--- a/src/anyio/_backends/_trio.py
+++ b/src/anyio/_backends/_trio.py
@@ -1264,7 +1264,7 @@ async def getnameinfo(
return await trio.socket.getnameinfo(sockaddr, flags)
@classmethod
- async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
+ async def wait_socket_readable(cls, sock: socket.socket) -> None:
try:
await wait_readable(sock)
except trio.ClosedResourceError as exc:
@@ -1273,7 +1273,7 @@ async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
raise BusyResourceError("reading from") from None
@classmethod
- async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
+ async def wait_socket_writable(cls, sock: socket.socket) -> None:
try:
await wait_writable(sock)
except trio.ClosedResourceError as exc:
@@ -1281,6 +1281,24 @@ async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
except trio.BusyResourceError:
raise BusyResourceError("writing to") from None
+ @classmethod
+ async def wait_readable(cls, obj: HasFileno | int) -> None:
+ try:
+ await wait_readable(obj)
+ except trio.ClosedResourceError as exc:
+ raise ClosedResourceError().with_traceback(exc.__traceback__) from None
+ except trio.BusyResourceError:
+ raise BusyResourceError("reading from") from None
+
+ @classmethod
+ async def wait_writable(cls, obj: HasFileno | int) -> None:
+ try:
+ await wait_writable(obj)
+ except trio.ClosedResourceError as exc:
+ raise ClosedResourceError().with_traceback(exc.__traceback__) from None
+ except trio.BusyResourceError:
+ raise BusyResourceError("writing to") from None
+
@classmethod
def current_default_thread_limiter(cls) -> CapacityLimiter:
try:
diff --git a/src/anyio/_core/_sockets.py b/src/anyio/_core/_sockets.py
index c8c189ec..1c8e74e2 100644
--- a/src/anyio/_core/_sockets.py
+++ b/src/anyio/_core/_sockets.py
@@ -11,6 +11,7 @@
from os import PathLike, chmod
from socket import AddressFamily, SocketKind
from typing import TYPE_CHECKING, Any, Literal, cast, overload
+from warnings import warn
from .. import to_thread
from ..abc import (
@@ -596,8 +597,10 @@ def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> Awaitable[tuple[str
return get_async_backend().getnameinfo(sockaddr, flags)
-def wait_socket_readable(sock: HasFileno | int) -> Awaitable[None]:
+def wait_socket_readable(sock: socket.socket) -> Awaitable[None]:
"""
+ Deprecated, use `wait_readable` instead.
+
Wait until the given socket has data to be read.
This does **NOT** work on Windows when using the asyncio backend with a proactor
@@ -606,18 +609,25 @@ def wait_socket_readable(sock: HasFileno | int) -> Awaitable[None]:
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
- :param sock: a socket object or its file descriptor
+ :param sock: a socket object
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
socket to become readable
:raises ~anyio.BusyResourceError: if another task is already waiting for the socket
to become readable
"""
+ warn(
+ "This function is deprecated; use `wait_readable` instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
return get_async_backend().wait_socket_readable(sock)
-def wait_socket_writable(sock: HasFileno | int) -> Awaitable[None]:
+def wait_socket_writable(sock: socket.socket) -> Awaitable[None]:
"""
+ Deprecated, use `wait_writable` instead.
+
Wait until the given socket can be written to.
This does **NOT** work on Windows when using the asyncio backend with a proactor
@@ -626,16 +636,73 @@ def wait_socket_writable(sock: HasFileno | int) -> Awaitable[None]:
.. warning:: Only use this on raw sockets that have not been wrapped by any higher
level constructs like socket streams!
- :param sock: a socket object or its file descriptor
+ :param sock: a socket object
:raises ~anyio.ClosedResourceError: if the socket was closed while waiting for the
socket to become writable
:raises ~anyio.BusyResourceError: if another task is already waiting for the socket
to become writable
"""
+ warn(
+ "This function is deprecated; use `wait_writable` instead",
+ DeprecationWarning,
+ stacklevel=2,
+ )
return get_async_backend().wait_socket_writable(sock)
+def wait_readable(obj: HasFileno | int) -> Awaitable[None]:
+ """
+ Wait until the given object has data to be read.
+
+ On Unix systems, ``obj`` must either be an integer file descriptor, or else an
+ object with a ``.fileno()`` method which returns an integer file descriptor. Any
+ kind of file descriptor can be passed, though the exact semantics will depend on
+ your kernel. For example, this probably won't do anything useful for on-disk files.
+
+ On Windows systems, ``obj`` must either be an integer ``SOCKET`` handle, or else an
+ object with a ``.fileno()`` method which returns an integer ``SOCKET`` handle. File
+ descriptors aren't supported, and neither are handles that refer to anything besides
+ a ``SOCKET``.
+
+ This does **NOT** work on Windows when using the asyncio backend with a proactor
+ event loop (default on py3.8+).
+
+ .. warning:: Only use this on raw sockets that have not been wrapped by any higher
+ level constructs like socket streams!
+
+ :param obj: an object with a ``.fileno()`` method or an integer handle.
+ :raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
+ object to become readable
+ :raises ~anyio.BusyResourceError: if another task is already waiting for the object
+ to become readable
+
+ """
+ return get_async_backend().wait_readable(obj)
+
+
+def wait_writable(obj: HasFileno | int) -> Awaitable[None]:
+ """
+ Wait until the given object can be written to.
+
+ See `wait_readable` for the definition of ``obj``.
+
+ This does **NOT** work on Windows when using the asyncio backend with a proactor
+ event loop (default on py3.8+).
+
+ .. warning:: Only use this on raw sockets that have not been wrapped by any higher
+ level constructs like socket streams!
+
+ :param obj: an object with a ``.fileno()`` method or an integer handle.
+ :raises ~anyio.ClosedResourceError: if the object was closed while waiting for the
+ object to become writable
+ :raises ~anyio.BusyResourceError: if another task is already waiting for the object
+ to become writable
+
+ """
+ return get_async_backend().wait_writable(obj)
+
+
#
# Private API
#
diff --git a/src/anyio/abc/_eventloop.py b/src/anyio/abc/_eventloop.py
index 3d7866fd..05bf83d5 100644
--- a/src/anyio/abc/_eventloop.py
+++ b/src/anyio/abc/_eventloop.py
@@ -335,12 +335,22 @@ async def getnameinfo(
@classmethod
@abstractmethod
- async def wait_socket_readable(cls, sock: HasFileno | int) -> None:
+ async def wait_socket_readable(cls, sock: socket) -> None:
pass
@classmethod
@abstractmethod
- async def wait_socket_writable(cls, sock: HasFileno | int) -> None:
+ async def wait_socket_writable(cls, sock: socket) -> None:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def wait_readable(cls, obj: HasFileno | int) -> None:
+ pass
+
+ @classmethod
+ @abstractmethod
+ async def wait_writable(cls, obj: HasFileno | int) -> None:
pass
@classmethod
diff --git a/tests/test_sockets.py b/tests/test_sockets.py
index 8e376ab2..cf41d8e6 100644
--- a/tests/test_sockets.py
+++ b/tests/test_sockets.py
@@ -46,8 +46,10 @@
getnameinfo,
move_on_after,
wait_all_tasks_blocked,
+ wait_readable,
wait_socket_readable,
wait_socket_writable,
+ wait_writable,
)
from anyio.abc import (
IPSockAddrType,
@@ -1866,7 +1868,7 @@ async def test_wait_socket(
if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")
- wait_socket = wait_socket_readable if event == "readable" else wait_socket_writable
+ wait = wait_readable if event == "readable" else wait_writable
def client(port: int) -> None:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
@@ -1884,4 +1886,27 @@ def client(port: int) -> None:
with conn:
sock_or_fd: HasFileno | int = conn.fileno() if socket_type == "fd" else conn
with fail_after(10):
- await wait_socket(sock_or_fd)
+ await wait(sock_or_fd)
+
+
+async def test_deprecated_wait_socket(anyio_backend_name: str) -> None:
+ if anyio_backend_name == "asyncio" and sys.platform == "win32":
+ import asyncio
+
+ policy = asyncio.get_event_loop_policy()
+ if policy.__class__.__name__ == "WindowsProactorEventLoopPolicy":
+ pytest.skip("Does not work on asyncio/Windows/ProactorEventLoop")
+
+ with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
+ with pytest.warns(
+ DeprecationWarning,
+ match="This function is deprecated; use `wait_readable` instead",
+ ):
+ with move_on_after(0.1):
+ await wait_socket_readable(sock)
+ with pytest.warns(
+ DeprecationWarning,
+ match="This function is deprecated; use `wait_writable` instead",
+ ):
+ with move_on_after(0.1):
+ await wait_socket_writable(sock)