diff --git a/tests/test_subprocess.py b/tests/test_subprocess.py index 93191bacb..7b5d0ef40 100644 --- a/tests/test_subprocess.py +++ b/tests/test_subprocess.py @@ -3,7 +3,7 @@ import socket from unittest.mock import patch -from uvicorn._subprocess import SpawnProcess, get_subprocess, subprocess_started +from uvicorn._subprocess import SocketSharePickle, SpawnProcess, get_subprocess, subprocess_started from uvicorn._types import ASGIReceiveCallable, ASGISendCallable, Scope from uvicorn.config import Config @@ -36,7 +36,7 @@ def test_subprocess_started() -> None: with patch("tests.test_subprocess.server_run") as mock_run: with patch.object(config, "configure_logging") as mock_config_logging: - subprocess_started(config, server_run, [fdsock], None) + subprocess_started(config, server_run, [SocketSharePickle(fdsock)], None) mock_run.assert_called_once() mock_config_logging.assert_called_once() diff --git a/uvicorn/_subprocess.py b/uvicorn/_subprocess.py index 1c06844de..ea99681b6 100644 --- a/uvicorn/_subprocess.py +++ b/uvicorn/_subprocess.py @@ -7,9 +7,9 @@ import multiprocessing import os +import socket import sys from multiprocessing.context import SpawnProcess -from socket import socket from typing import Callable from uvicorn.config import Config @@ -18,10 +18,39 @@ spawn = multiprocessing.get_context("spawn") +class SocketSharePickle: + def __init__(self, sock: socket.socket): + self._sock = sock + + def get(self) -> socket.socket: + return self._sock + + +class SocketShareRebind: + def __init__(self, sock: socket.socket): + if (sys.platform == "linux" and hasattr(socket, "SO_REUSEPORT")) or hasattr(socket, "SO_REUSEPORT_LB"): + raise RuntimeError("socket_load_balance not supported") + sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1) + self._family = sock.family + self._sockname = sock.getsockname() + + def get(self) -> socket.socket: + try: + sock = socket.socket(family=self._family) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, getattr(socket, "SO_REUSEPORT_LB", socket.SO_REUSEPORT), 1) + + sock.bind(self._sockname) + return sock + except BaseException: + sock.close() + raise + + def get_subprocess( config: Config, target: Callable[..., None], - sockets: list[socket], + sockets: list[socket.socket], ) -> SpawnProcess: """ Called in the parent process, to instantiate a new child process instance. @@ -41,10 +70,15 @@ def get_subprocess( except (AttributeError, OSError): stdin_fileno = None + socket_shares: list[SocketShareRebind] | list[SocketSharePickle] + if config.socket_load_balance: + socket_shares = [SocketShareRebind(s) for s in sockets] + else: + socket_shares = [SocketSharePickle(s) for s in sockets] kwargs = { "config": config, "target": target, - "sockets": sockets, + "sockets": socket_shares, "stdin_fileno": stdin_fileno, } @@ -54,7 +88,7 @@ def get_subprocess( def subprocess_started( config: Config, target: Callable[..., None], - sockets: list[socket], + sockets: list[SocketSharePickle] | list[SocketShareRebind], stdin_fileno: int | None, ) -> None: """ @@ -77,7 +111,7 @@ def subprocess_started( try: # Now we can call into `Server.run(sockets=sockets)` - target(sockets=sockets) + target(sockets=[s.get() for s in sockets]) except KeyboardInterrupt: # pragma: no cover # supress the exception to avoid a traceback from subprocess.Popen # the parent already expects us to end, so no vital information is lost diff --git a/uvicorn/config.py b/uvicorn/config.py index 9aff8c968..54a2f9b3f 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -223,6 +223,7 @@ def __init__( headers: list[tuple[str, str]] | None = None, factory: bool = False, h11_max_incomplete_event_size: int | None = None, + socket_load_balance: bool = False, ): self.app = app self.host = host @@ -268,6 +269,7 @@ def __init__( self.encoded_headers: list[tuple[bytes, bytes]] = [] self.factory = factory self.h11_max_incomplete_event_size = h11_max_incomplete_event_size + self.socket_load_balance = socket_load_balance self.loaded = False self.configure_logging() diff --git a/uvicorn/main.py b/uvicorn/main.py index 43956622d..755fb4a06 100644 --- a/uvicorn/main.py +++ b/uvicorn/main.py @@ -360,6 +360,13 @@ def print_version(ctx: click.Context, param: click.Parameter, value: bool) -> No help="Treat APP as an application factory, i.e. a () -> callable.", show_default=True, ) +@click.option( + "--socket-load-balance", + is_flag=True, + default=False, + help="Use kernel support for socket load balancing", + show_default=True, +) def main( app: str, host: str, @@ -408,6 +415,7 @@ def main( app_dir: str, h11_max_incomplete_event_size: int | None, factory: bool, + socket_load_balance: bool = False, ) -> None: run( app, @@ -457,6 +465,7 @@ def main( factory=factory, app_dir=app_dir, h11_max_incomplete_event_size=h11_max_incomplete_event_size, + socket_load_balance=socket_load_balance, ) @@ -509,6 +518,7 @@ def run( app_dir: str | None = None, factory: bool = False, h11_max_incomplete_event_size: int | None = None, + socket_load_balance: bool = False, ) -> None: if app_dir is not None: sys.path.insert(0, app_dir) @@ -560,6 +570,7 @@ def run( use_colors=use_colors, factory=factory, h11_max_incomplete_event_size=h11_max_incomplete_event_size, + socket_load_balance=socket_load_balance, ) server = Server(config=config) @@ -570,11 +581,11 @@ def run( try: if config.should_reload: - sock = config.bind_socket() - ChangeReload(config, target=server.run, sockets=[sock]).run() + with config.bind_socket() as sock: + ChangeReload(config, target=server.run, sockets=[sock]).run() elif config.workers > 1: - sock = config.bind_socket() - Multiprocess(config, target=server.run, sockets=[sock]).run() + with config.bind_socket() as sock: + Multiprocess(config, target=server.run, sockets=[sock]).run() else: server.run() except KeyboardInterrupt: