From 97d36b797101051c27ac09a709d2a78eca333271 Mon Sep 17 00:00:00 2001 From: Vaughn Kottler Date: Thu, 19 Oct 2023 22:08:42 -0500 Subject: [PATCH] 2.14.0 - Add support for connection restarting --- .github/workflows/python-package.yml | 2 +- README.md | 4 +- config | 2 +- local/variables/package.yaml | 4 +- pyproject.toml | 2 +- runtimepy/__init__.py | 4 +- runtimepy/net/connection.py | 61 ++++++++++++++++-- runtimepy/net/manager.py | 14 ++++- runtimepy/net/mixin.py | 9 ++- runtimepy/net/tcp/connection.py | 48 +++++++++++++-- tests/net/tcp/test_connection.py | 92 +++++++++++++++++++++++++++- tests/net/test_connection.py | 4 ++ 12 files changed, 223 insertions(+), 23 deletions(-) diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index aef0978b..23807057 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -68,7 +68,7 @@ jobs: - run: | mk python-release owner=vkottler \ - repo=runtimepy version=2.13.4 + repo=runtimepy version=2.14.0 if: | matrix.python-version == '3.11' && matrix.system == 'ubuntu-latest' diff --git a/README.md b/README.md index fc6c4916..a6630810 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ ===================================== generator=datazen version=3.1.3 - hash=50efa13ced2b135f66aff7b19a87d83d + hash=995dec4050d0b078008ad8e4911d515c ===================================== --> -# runtimepy ([2.13.4](https://pypi.org/project/runtimepy/)) +# runtimepy ([2.14.0](https://pypi.org/project/runtimepy/)) [![python](https://img.shields.io/pypi/pyversions/runtimepy.svg)](https://pypi.org/project/runtimepy/) ![Build Status](https://github.com/vkottler/runtimepy/workflows/Python%20Package/badge.svg) diff --git a/config b/config index 3eb2947d..0c555bf6 160000 --- a/config +++ b/config @@ -1 +1 @@ -Subproject commit 3eb2947d46a1767933af2d126638db8e49565862 +Subproject commit 0c555bf6565cc5d90408adbad3c162edca43a7e8 diff --git a/local/variables/package.yaml b/local/variables/package.yaml index 9c04892e..767d4ff1 100644 --- a/local/variables/package.yaml +++ b/local/variables/package.yaml @@ -1,5 +1,5 @@ --- major: 2 -minor: 13 -patch: 4 +minor: 14 +patch: 0 entry: runtimepy diff --git a/pyproject.toml b/pyproject.toml index 19d9d166..68d1b989 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta:__legacy__" [project] name = "runtimepy" -version = "2.13.4" +version = "2.14.0" description = "A framework for implementing Python services." readme = "README.md" requires-python = ">=3.11" diff --git a/runtimepy/__init__.py b/runtimepy/__init__.py index c1f276eb..a3affa3b 100644 --- a/runtimepy/__init__.py +++ b/runtimepy/__init__.py @@ -1,7 +1,7 @@ # ===================================== # generator=datazen # version=3.1.3 -# hash=95514f08f4202b6195f2631ae525def6 +# hash=ce4d96baa832f20ead1736b028ac1141 # ===================================== """ @@ -10,7 +10,7 @@ DESCRIPTION = "A framework for implementing Python services." PKG_NAME = "runtimepy" -VERSION = "2.13.4" +VERSION = "2.14.0" # runtimepy-specific content. METRICS_NAME = "metrics" diff --git a/runtimepy/net/connection.py b/runtimepy/net/connection.py index e81f95d4..6fa38ed8 100644 --- a/runtimepy/net/connection.py +++ b/runtimepy/net/connection.py @@ -20,7 +20,7 @@ from runtimepy.metrics import ConnectionMetrics from runtimepy.mixins.environment import ChannelEnvironmentMixin from runtimepy.mixins.logging import LoggerMixinLevelControl -from runtimepy.primitives import Bool +from runtimepy.primitives import Bool, Uint8 from runtimepy.primitives.byte_order import DEFAULT_BYTE_ORDER, ByteOrder BinaryMessage = _Union[bytes, bytearray, memoryview] @@ -44,7 +44,6 @@ def __init__( """Initialize this connection.""" LoggerMixinLevelControl.__init__(self, logger=logger) - self._enabled = Bool(True) # A queue for out-going text messages. Connections that don't use # this can set 'uses_text_tx_queue' to False to avoid scheduling a @@ -60,7 +59,6 @@ def __init__( self._tasks: _List[_asyncio.Task[None]] = [] self.initialized = _asyncio.Event() - self.disabled_event = _asyncio.Event() self.metrics = ConnectionMetrics() @@ -71,10 +69,24 @@ def __init__( self.register_connection_metrics(self.metrics) # State. + self._enabled = Bool() + self.disabled_event = _asyncio.Event() self.env.channel("enabled", self._enabled) + self._set_enabled(True) + + self._restarts = Uint8() + self.env.channel("restarts", self._restarts) + + self._auto_restart = Bool() + self.env.channel("auto_restart", self._auto_restart) self.init() + @property + def auto_restart(self) -> bool: + """Determine if this connection should be automatically restarted.""" + return bool(self._auto_restart) + def init(self) -> None: """Initialize this instance.""" @@ -132,13 +144,22 @@ def disabled(self) -> bool: def disable_extra(self) -> None: """Additional tasks to perform when disabling.""" + def _set_enabled(self, state: bool) -> None: + """Set the enabled state for this connection.""" + + self._enabled.value = state + if not state: + self.disabled_event.set() + self.initialized.clear() + else: + self.disabled_event.clear() + def disable(self, reason: str) -> None: """Disable this connection.""" if self._enabled: self.logger.info("Disabling connection: '%s'.", reason) self.disable_extra() - self._enabled.value = False # Cancel tasks. for task in self._tasks: @@ -146,10 +167,11 @@ def disable(self, reason: str) -> None: task.cancel() # Signal that this connection has been disabled. - self.disabled_event.set() + self._set_enabled(False) async def _wait_sig(self, stop_sig: _asyncio.Event) -> None: """Disable the connection if a stop signal gets set.""" + await stop_sig.wait() self.disable("stop signal") @@ -164,16 +186,43 @@ async def _async_init(self) -> None: self.env.finalize(strict=False) self.initialized.set() - async def process(self, stop_sig: _asyncio.Event = None) -> None: + async def restart(self) -> bool: + """ + Reset necessary underlying state for this connection to 'process' + again. + """ + raise NotImplementedError + + async def disable_in(self, time: float) -> None: + """A method for disabling a connection after some delay.""" + + await _asyncio.sleep(time) + self.disable(f"timed disable ({time}s)") + + async def process( + self, stop_sig: _asyncio.Event = None, disable_time: float = None + ) -> None: """ Process tasks for this connection while the connection is active. """ + # Try to re-enable the connection if necessary. + if self.disabled and (stop_sig is None or not stop_sig.is_set()): + assert await self.restart() + self._set_enabled(True) + self._restarts.raw.value += 1 + self._tasks = [ _asyncio.create_task(self._process_read()), _asyncio.create_task(self._async_init()), ] + # Disable the connection automatically if requested. + if disable_time is not None: + self._tasks.append( + _asyncio.create_task(self.disable_in(disable_time)) + ) + if self.uses_text_tx_queue: self._tasks.append( _asyncio.create_task(self._process_write_text()) diff --git a/runtimepy/net/manager.py b/runtimepy/net/manager.py index 13315284..bb23404a 100644 --- a/runtimepy/net/manager.py +++ b/runtimepy/net/manager.py @@ -76,7 +76,19 @@ async def manage(self, stop_sig: _asyncio.Event) -> None: next_tasks = _log_exceptions(tasks) # Filter out disabled connections. - self._conns = [x for x in self._conns if not x.disabled] + enabled = [] + for conn in self._conns: + if not conn.disabled: + enabled.append(conn) + + # Check if this connection should be restarted. + elif conn.auto_restart: + next_tasks.append( + _asyncio.create_task(conn.process(stop_sig=stop_sig)) + ) + enabled.append(conn) + + self._conns = enabled # If a new connection was made, register a task for processing # it. diff --git a/runtimepy/net/mixin.py b/runtimepy/net/mixin.py index d7cae657..ca458af3 100644 --- a/runtimepy/net/mixin.py +++ b/runtimepy/net/mixin.py @@ -30,8 +30,8 @@ class TransportMixin: _transport: _asyncio.BaseTransport - def __init__(self, transport: _asyncio.BaseTransport) -> None: - """Initialize this instance.""" + def set_transport(self, transport: _asyncio.BaseTransport) -> None: + """Set the transport for this instance.""" self._transport = transport @@ -45,6 +45,11 @@ def __init__(self, transport: _asyncio.BaseTransport) -> None: # None). self.remote_address = self._remote_address() + def __init__(self, transport: _asyncio.BaseTransport) -> None: + """Initialize this instance.""" + + self.set_transport(transport) + @property def socket(self) -> _SocketType: """Get this instance's underlying socket.""" diff --git a/runtimepy/net/tcp/connection.py b/runtimepy/net/tcp/connection.py index e92fc694..a28cc084 100644 --- a/runtimepy/net/tcp/connection.py +++ b/runtimepy/net/tcp/connection.py @@ -85,10 +85,18 @@ def __init__(self, transport: _Transport, protocol: QueueProtocol) -> None: # Re-assign with updated type information. self._transport: _Transport = transport + self._set_protocol(protocol) + + super().__init__(_getLogger(self.logger_name("TCP "))) + + # Store connection-instantiation arguments. + self._conn_kwargs: dict[str, _Any] = {} + + def _set_protocol(self, protocol: QueueProtocol) -> None: + """Set a new protocol for this instance.""" self._protocol = protocol self._protocol.conn = self - super().__init__(_getLogger(self.logger_name("TCP "))) async def _await_message(self) -> _Optional[_Union[_BinaryMessage, str]]: """Await the next message. Return None on error or failure.""" @@ -108,8 +116,13 @@ def send_binary(self, data: _BinaryMessage) -> None: self.metrics.tx.increment(len(data)) @classmethod - async def create_connection(cls: _Type[T], **kwargs) -> T: - """Create a TCP connection.""" + async def _transport_protocol( + cls: _Type[T], **kwargs + ) -> tuple[_Transport, QueueProtocol]: + """ + Create a transport and protocol pair relevant for this class's + implementation. + """ eloop = _get_event_loop() @@ -117,7 +130,34 @@ async def create_connection(cls: _Type[T], **kwargs) -> T: transport, protocol = await eloop.create_connection( QueueProtocol, **kwargs ) - return cls(transport, protocol) + return transport, protocol + + async def restart(self) -> bool: + """ + Reset necessary underlying state for this connection to 'process' + again. + """ + + transport, protocol = await self._transport_protocol( + **self._conn_kwargs + ) + self.set_transport(transport) + self._set_protocol(protocol) + + return True + + @classmethod + async def create_connection(cls: _Type[T], **kwargs) -> T: + """Create a TCP connection.""" + + transport, protocol = await cls._transport_protocol(**kwargs) + inst = cls(transport, protocol) + + # Is there a better way to do this? We can't restart a server's side + # of a connection (seems okay). + inst._conn_kwargs = {**kwargs} + + return inst @classmethod @_asynccontextmanager diff --git a/tests/net/tcp/test_connection.py b/tests/net/tcp/test_connection.py index fd730406..797ca17f 100644 --- a/tests/net/tcp/test_connection.py +++ b/tests/net/tcp/test_connection.py @@ -9,7 +9,7 @@ from pytest import mark # module under test -from runtimepy.net import sockname +from runtimepy.net import get_free_socket_name, normalize_host, sockname from runtimepy.net.manager import ConnectionManager # internal @@ -38,6 +38,96 @@ async def test_tcp_connection_basic(): ) +@mark.asyncio +async def test_tcp_connection_restart(): + """Test that a TCP connection can be restarted.""" + + host = "127.0.0.1" + port = get_free_socket_name(local=normalize_host(host)).port + + async with SampleTcpConnection.serve(host=host, port=port): + client = await SampleTcpConnection.create_connection( + host=host, port=port + ) + + # Run the connection for a bit. + await client.process(disable_time=0.1) + + # Run the connection again (triggering a restart). + await client.process(disable_time=0.1) + + # Confirm the connection did restart. + assert client.env.value("restarts") == 1 + + +@mark.asyncio +async def test_tcp_connection_manager_auto_restart(): + """ + Test that a connection manager can automatically restart TCP connections. + """ + + manager = ConnectionManager() + sig = asyncio.Event() + host_queue: asyncio.Queue = asyncio.Queue() + + def app(conn: SampleTcpConnection) -> None: + """A sample application callback.""" + assert conn + + def serve_cb(server) -> None: + """Publish the server host.""" + host_queue.put_nowait(sockname(server.sockets[0])) + + async def connect() -> None: + """Connect to the server a number of times.""" + + # Wait for the server to start. + host = await host_queue.get() + + conn = await SampleTcpConnection.create_connection( + host="localhost", port=host.port + ) + + # Allow the connection manager to manage this connection. + await manager.queue.put(conn) + + await conn.initialized.wait() + + # Enable connection restart. + conn.env.set("auto_restart", True) + + # Disable the connection. + conn.disable("testing restart") + + # Wait for re-connect. + await conn.initialized.wait() + + # Confirm the connection did restart. + assert conn.env.value("restarts") == 1 + + # End test. + conn.env.set("auto_restart", False) + conn.disable("ending test") + sig.set() + + await asyncio.wait( + [ + asyncio.create_task(x) + for x in [ + connect(), + SampleTcpConnection.app( + sig, + callback=app, + serving_callback=serve_cb, + port=0, + manager=manager, + ), + ] + ], + return_when=asyncio.ALL_COMPLETED, + ) + + @mark.asyncio async def test_tcp_connection_app(): """Test the TCP connection's application interface.""" diff --git a/tests/net/test_connection.py b/tests/net/test_connection.py index 7f7ab6b2..56ef421a 100644 --- a/tests/net/test_connection.py +++ b/tests/net/test_connection.py @@ -37,3 +37,7 @@ async def test_connection_basic(): await conn._send_binay_message( # pylint: disable=protected-access "test".encode() ) + + conn.disable("testing") + with raises(NotImplementedError): + await conn.process()