diff --git a/src/anyio/_backends/_asyncio.py b/src/anyio/_backends/_asyncio.py index a7f74412..d6bfa578 100644 --- a/src/anyio/_backends/_asyncio.py +++ b/src/anyio/_backends/_asyncio.py @@ -98,7 +98,6 @@ from ..abc._eventloop import StrOrBytesPath from ..lowlevel import RunVar from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream -from ._selector_thread import _get_selector_windows if sys.version_info >= (3, 10): from typing import ParamSpec @@ -2684,19 +2683,20 @@ async def wait_socket_readable(cls, sock: socket.socket) -> None: raise BusyResourceError("reading from") from None loop = get_running_loop() - if ( - sys.platform == "win32" - and asyncio.get_event_loop_policy().__class__.__name__ - == "WindowsProactorEventLoopPolicy" - ): + add_reader = loop.add_reader + event = read_events[sock] = asyncio.Event() + #try: + # add_reader(sock, event.set) + #except NotImplementedError: + if True: + # Proactor on Windows does not yet implement add/remove reader + from ._selector_thread import _get_selector_windows + selector = _get_selector_windows(loop) - add_reader = selector.add_reader + selector.add_reader(sock, event.set) remove_reader = selector.remove_reader else: - add_reader = loop.add_reader remove_reader = loop.remove_reader - event = read_events[sock] = asyncio.Event() - add_reader(sock, event.set) try: await event.wait() finally: @@ -2722,13 +2722,24 @@ async def wait_socket_writable(cls, sock: socket.socket) -> None: raise BusyResourceError("writing to") from None loop = get_running_loop() + add_writer = loop.add_writer event = write_events[sock] = asyncio.Event() - loop.add_writer(sock.fileno(), event.set) + try: + add_writer(sock.fileno(), event.set) + except NotImplementedError: + # Proactor on Windows does not yet implement add/remove writer + from ._selector_thread import _get_selector_windows + + selector = _get_selector_windows(loop) + selector.add_writer(sock, event.set) + remove_writer = selector.remove_writer + else: + remove_writer = loop.remove_writer try: await event.wait() finally: if write_events.pop(sock, None) is not None: - loop.remove_writer(sock) + remove_writer(sock) writable = True else: writable = False diff --git a/src/anyio/_backends/_selector_thread.py b/src/anyio/_backends/_selector_thread.py index 10635fe1..c86080a8 100644 --- a/src/anyio/_backends/_selector_thread.py +++ b/src/anyio/_backends/_selector_thread.py @@ -7,7 +7,6 @@ from __future__ import annotations import asyncio -import atexit import errno import functools import select @@ -21,6 +20,8 @@ ) from weakref import WeakKeyDictionary +from ._asyncio import find_root_task + if typing.TYPE_CHECKING: from typing_extensions import Protocol @@ -38,7 +39,7 @@ def fileno(self) -> int: _selector_loops: set[SelectorThread] = set() -def _atexit_callback() -> None: +def _at_loop_close_callback(future: asyncio.Future) -> None: for loop in _selector_loops: with loop._select_cond: loop._closing_selector = True @@ -56,12 +57,7 @@ def _atexit_callback() -> None: _selector_loops.clear() -atexit.register(_atexit_callback) - - # SelectorThread from tornado 6.4.0 - - class SelectorThread: """Define ``add_reader`` methods to be called in a background select thread. @@ -84,19 +80,6 @@ def __init__(self, real_loop: asyncio.AbstractEventLoop) -> None: ) = None self._closing_selector = False self._thread: threading.Thread | None = None - self._thread_manager_handle = self._thread_manager() - - async def thread_manager_anext() -> None: - # the anext builtin wasn't added until 3.10. We just need to iterate - # this generator one step. - await self._thread_manager_handle.__anext__() - - # When the loop starts, start the thread. Not too soon because we can't - # clean up if we get to this point but the event loop is closed without - # starting. - self._real_loop.call_soon( - lambda: self._real_loop.create_task(thread_manager_anext()) - ) self._readers: dict[_FileDescriptorLike, Callable] = {} self._writers: dict[_FileDescriptorLike, Callable] = {} @@ -108,6 +91,7 @@ async def thread_manager_anext() -> None: self._waker_w.setblocking(False) _selector_loops.add(self) self.add_reader(self._waker_r, self._consume_waker) + self._thread_manager() def close(self) -> None: if self._closed: @@ -124,30 +108,19 @@ def close(self) -> None: self._waker_w.close() self._closed = True - async def _thread_manager(self) -> typing.AsyncGenerator[None, None]: + def _thread_manager(self) -> typing.AsyncGenerator[None, None]: # Create a thread to run the select system call. We manage this thread - # manually so we can trigger a clean shutdown from an atexit hook. Note + # manually so we can trigger a clean shutdown at loop teardown. Note # that due to the order of operations at shutdown, only daemon threads # can be shut down in this way (non-daemon threads would require the # introduction of a new hook: https://bugs.python.org/issue41962) self._thread = threading.Thread( - name="Tornado selector", + name="AnyIO selector", daemon=True, target=self._run_select, ) self._thread.start() self._start_select() - try: - # The presense of this yield statement means that this coroutine - # is actually an asynchronous generator, which has a special - # shutdown protocol. We wait at this yield point until the - # event loop's shutdown_asyncgens method is called, at which point - # we will get a GeneratorExit exception and can shut down the - # selector thread. - yield - except GeneratorExit: - self.close() - raise def _wake_selector(self) -> None: if self._closed: @@ -298,6 +271,7 @@ def _get_selector_windows( if asyncio_loop in _selectors: return _selectors[asyncio_loop] + find_root_task().add_done_callback(_at_loop_close_callback) selector_thread = _selectors[asyncio_loop] = SelectorThread(asyncio_loop) # patch loop.close to also close the selector thread diff --git a/tests/test_sockets.py b/tests/test_sockets.py index 0e0f794e..cbe6200a 100644 --- a/tests/test_sockets.py +++ b/tests/test_sockets.py @@ -1858,17 +1858,14 @@ def client(port: int) -> None: sock.connect(("127.0.0.1", port)) sock.sendall(b"Hello, world") - with move_on_after(0.1): - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: - sock.bind(("127.0.0.1", 0)) - port = sock.getsockname()[1] - sock.listen() - thread = Thread(target=client, args=(port,), daemon=True) - thread.start() - conn, addr = sock.accept() + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + port = sock.getsockname()[1] + sock.listen() + thread = Thread(target=client, args=(port,)) + thread.start() + thread.join() + conn, addr = sock.accept() + with fail_after(5): with conn: await wait_socket_readable(conn) - socket_readable = True - - assert socket_readable - thread.join()