From 017c6c21fb5d97e546097da5107efa5b15cae3ee Mon Sep 17 00:00:00 2001 From: Vaughn Kottler Date: Fri, 19 Jul 2024 23:50:27 -0700 Subject: [PATCH] Improve tftp command-line interface Fix talkback --- .github/workflows/python-package.yml | 2 +- README.md | 6 +-- local/arbiter/.gitignore | 1 + local/variables/package.yaml | 2 +- pyproject.toml | 2 +- runtimepy/__init__.py | 4 +- runtimepy/commands/tftp.py | 5 +- runtimepy/net/udp/tftp/__init__.py | 53 ++++++++++----------- runtimepy/net/udp/tftp/base.py | 70 +++++++++++++++++++++++++--- runtimepy/net/udp/tftp/endpoint.py | 11 ++++- runtimepy/net/util.py | 18 +++++++ tests/commands/test_tftp.py | 2 +- tests/net/test_util.py | 22 +++++++++ 13 files changed, 153 insertions(+), 45 deletions(-) create mode 100644 local/arbiter/.gitignore create mode 100644 tests/net/test_util.py diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 56eac567..d38aece2 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -77,7 +77,7 @@ jobs: - run: | mk python-release owner=vkottler \ - repo=runtimepy version=5.4.1 + repo=runtimepy version=5.4.2 if: | matrix.python-version == '3.12' && matrix.system == 'ubuntu-latest' diff --git a/README.md b/README.md index 13d49ba4..8013c056 100644 --- a/README.md +++ b/README.md @@ -2,11 +2,11 @@ ===================================== generator=datazen version=3.1.4 - hash=039855eb758d9eb1ea70df0654e31b61 + hash=4f8a71a6066638ed1a90f375188f0578 ===================================== --> -# runtimepy ([5.4.1](https://pypi.org/project/runtimepy/)) +# runtimepy ([5.4.2](https://pypi.org/project/runtimepy/)) [![python](https://img.shields.io/pypi/pyversions/runtimepy.svg)](https://pypi.org/project/runtimepy/) ![Build Status](https://github.com/vkottler/runtimepy/workflows/Python%20Package/badge.svg) @@ -155,7 +155,7 @@ options: $ ./venv3.12/bin/runtimepy tftp -h usage: runtimepy tftp [-h] [-p PORT] [-m MODE] [-t TIMEOUT] [-r REEMIT] - {read,write} host our_file their_file + {read,write} host our_file [their_file] positional arguments: {read,write} action to perform diff --git a/local/arbiter/.gitignore b/local/arbiter/.gitignore new file mode 100644 index 00000000..4871fd52 --- /dev/null +++ b/local/arbiter/.gitignore @@ -0,0 +1 @@ +test.txt diff --git a/local/variables/package.yaml b/local/variables/package.yaml index 5d0f9fe1..de679e9b 100644 --- a/local/variables/package.yaml +++ b/local/variables/package.yaml @@ -1,5 +1,5 @@ --- major: 5 minor: 4 -patch: 1 +patch: 2 entry: runtimepy diff --git a/pyproject.toml b/pyproject.toml index c7f58257..8ecab8bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta:__legacy__" [project] name = "runtimepy" -version = "5.4.1" +version = "5.4.2" description = "A framework for implementing Python services." readme = "README.md" requires-python = ">=3.11" diff --git a/runtimepy/__init__.py b/runtimepy/__init__.py index ae5947c0..56e07970 100644 --- a/runtimepy/__init__.py +++ b/runtimepy/__init__.py @@ -1,7 +1,7 @@ # ===================================== # generator=datazen # version=3.1.4 -# hash=acff0a8ea3e1379494862c3188de7a76 +# hash=e6970089f5f2935c496cb3e9bb06b774 # ===================================== """ @@ -10,7 +10,7 @@ DESCRIPTION = "A framework for implementing Python services." PKG_NAME = "runtimepy" -VERSION = "5.4.1" +VERSION = "5.4.2" # runtimepy-specific content. METRICS_NAME = "metrics" diff --git a/runtimepy/commands/tftp.py b/runtimepy/commands/tftp.py index 6b55f73f..d7949f8d 100644 --- a/runtimepy/commands/tftp.py +++ b/runtimepy/commands/tftp.py @@ -33,6 +33,9 @@ def tftp_cmd(args: argparse.Namespace) -> int: "process_kwargs": {"stop_sig": stop_sig}, } + if not args.their_file: + args.their_file = str(args.our_file) + if args.operation == "read": task = tftp_read(addr, args.our_file, args.their_file, **kwargs) else: @@ -83,6 +86,6 @@ def add_tftp_cmd(parser: argparse.ArgumentParser) -> CommandFunction: parser.add_argument("host", help="host to message") parser.add_argument("our_file", type=Path, help="path to our file") - parser.add_argument("their_file", type=str, help="path to their file") + parser.add_argument("their_file", nargs="?", help="path to their file") return tftp_cmd diff --git a/runtimepy/net/udp/tftp/__init__.py b/runtimepy/net/udp/tftp/__init__.py index e7eb5b10..13f2373e 100644 --- a/runtimepy/net/udp/tftp/__init__.py +++ b/runtimepy/net/udp/tftp/__init__.py @@ -38,19 +38,19 @@ async def request_read( ) -> bool: """Request a tftp read operation.""" - endpoint = self.endpoint(addr) end_of_data = False idx = 1 - def ack_sender() -> None: - """Send acks.""" - nonlocal idx - self.send_ack(block=idx - 1, addr=addr) - async with AsyncExitStack() as stack: # Claim read lock and ignore cancellation. stack.enter_context(suppress(asyncio.CancelledError)) - await stack.enter_async_context(endpoint.lock) + + endpoint, event = await self._await_first_block(stack, addr=addr) + + def ack_sender() -> None: + """Send acks.""" + nonlocal idx + endpoint.ack_sender(idx - 1, endpoint.addr) def send_rrq() -> None: """Send request""" @@ -60,9 +60,6 @@ def send_rrq() -> None: "Requesting '%s' (%s) -> %s.", filename, mode, destination ) - event = asyncio.Event() - endpoint.awaiting_blocks[idx] = event - with self.log_time("Awaiting first data block", reminder=True): # Wait for first data block. if not await repeat_until( @@ -112,21 +109,22 @@ def write_block() -> None: if success: write_block() - # Repeat last ack in the background. - if end_of_data: - self._conn_tasks.append( - asyncio.create_task( - repeat_until( # type: ignore - ack_sender, - asyncio.Event(), - endpoint.period.value, - endpoint.timeout.value, + # Repeat last ack in the background. + if end_of_data: + self._conn_tasks.append( + asyncio.create_task( + repeat_until( # type: ignore + ack_sender, + asyncio.Event(), + endpoint.period.value, + endpoint.timeout.value, + ) ) ) - ) - # Make a to-string or log method for vcorelib FileInfo? - # + # Ensure at least one ack sends. + await asyncio.sleep(0.01) + self.logger.info( "Read %s (%s).", FileInfo.from_file(destination), @@ -146,16 +144,14 @@ async def request_write( """Request a tftp write operation.""" result = False - endpoint = self.endpoint(addr) with as_path(source) as src: async with AsyncExitStack() as stack: # Claim write lock and ignore cancellation. stack.enter_context(suppress(asyncio.CancelledError)) - await stack.enter_async_context(endpoint.lock) - event = asyncio.Event() - endpoint.awaiting_acks[0] = event + # Set up first-ack handling. + endpoint, event = await self._await_first_ack(stack, addr=addr) def send_wrq() -> None: """Send request.""" @@ -183,6 +179,11 @@ def send_wrq() -> None: ) # Compare hashes. + self.logger.info( + "Reading '%s' %s.", + filename, + "succeeded" if result else "failed", + ) if result: result = file_md5_hex(src) == file_md5_hex(tmp) self.logger.info( diff --git a/runtimepy/net/udp/tftp/base.py b/runtimepy/net/udp/tftp/base.py index 18a1d249..1511adc2 100644 --- a/runtimepy/net/udp/tftp/base.py +++ b/runtimepy/net/udp/tftp/base.py @@ -3,6 +3,8 @@ """ # built-in +import asyncio +from contextlib import AsyncExitStack from io import BytesIO import logging from pathlib import Path @@ -22,6 +24,7 @@ encode_filename_mode, parse_filename_mode, ) +from runtimepy.net.util import normalize_host from runtimepy.primitives import Double, Uint16 REEMIT_PERIOD_S = 0.20 @@ -124,9 +127,39 @@ def error_sender( self.error_sender = error_sender - self._endpoints: dict[str, TftpEndpoint] = {} + self._endpoints: dict[IpHost, TftpEndpoint] = {} + self._awaiting_first_ack: dict[str, TftpEndpoint] = {} + self._awaiting_first_block: dict[str, TftpEndpoint] = {} # self._self = self.endpoint(self.local_address) + async def _await_first_ack( + self, + stack: AsyncExitStack, + addr: Union[IpHost, tuple[str, int]] = None, + ) -> tuple[TftpEndpoint, asyncio.Event]: + """Set up an endpoint to wait for an initial ack from a server.""" + + endpoint = self.endpoint(addr) + await stack.enter_async_context(endpoint.lock) + event = asyncio.Event() + endpoint.awaiting_acks[0] = event + self._awaiting_first_ack[endpoint.addr.hostname] = endpoint + return endpoint, event + + async def _await_first_block( + self, + stack: AsyncExitStack, + addr: Union[IpHost, tuple[str, int]] = None, + ) -> tuple[TftpEndpoint, asyncio.Event]: + """Set up an endpoint to wait for an initial block from a server.""" + + endpoint = self.endpoint(addr) + await stack.enter_async_context(endpoint.lock) + event = asyncio.Event() + endpoint.awaiting_blocks[1] = event + self._awaiting_first_block[endpoint.addr.hostname] = endpoint + return endpoint, event + def endpoint( self, addr: Union[IpHost, tuple[str, int]] = None ) -> TftpEndpoint: @@ -136,10 +169,10 @@ def endpoint( addr = self.remote_address assert addr is not None - key = f"{addr[0]}:{addr[1]}" + addr = normalize_host(*addr) - if key not in self._endpoints: - self._endpoints[key] = TftpEndpoint( + if addr not in self._endpoints: + self._endpoints[addr] = TftpEndpoint( self._path, self.logger, addr, @@ -150,7 +183,7 @@ def endpoint( self.endpoint_timeout, ) - return self._endpoints[key] + return self._endpoints[addr] def send_rrq( self, @@ -270,15 +303,38 @@ async def _handle_data( ) -> None: """Handle a data message.""" + endpoint = self.endpoint(addr) block = self._read_block_number(stream) - self.endpoint(addr).handle_data(block, stream.read()) + + # Check if we're currently waiting for an initial block. + hostname = endpoint.addr.hostname + if block == 1 and hostname in self._awaiting_first_block: + to_update = self._awaiting_first_block[hostname] + del self._awaiting_first_block[hostname] + self._endpoints[endpoint.addr] = to_update + endpoint = to_update.update_from_other(endpoint) + + endpoint.handle_data(block, stream.read()) async def _handle_ack( self, stream: BinaryIO, addr: tuple[str, int] ) -> None: """Handle an acknowledge message.""" - self.endpoint(addr).handle_ack(self._read_block_number(stream)) + endpoint = self.endpoint(addr) + block = self._read_block_number(stream) + + # Check if we're currently waiting for an initial acknowledgement. This + # will come from the same host but a different port, so update + # references when this is detected. + hostname = endpoint.addr.hostname + if block == 0 and hostname in self._awaiting_first_ack: + to_update = self._awaiting_first_ack[hostname] + del self._awaiting_first_ack[hostname] + self._endpoints[endpoint.addr] = to_update + endpoint = to_update.update_from_other(endpoint) + + endpoint.handle_ack(block) def _read_block_number(self, stream: BinaryIO) -> int: """Read block number from the stream.""" diff --git a/runtimepy/net/udp/tftp/endpoint.py b/runtimepy/net/udp/tftp/endpoint.py index 1c279960..ef7c9365 100644 --- a/runtimepy/net/udp/tftp/endpoint.py +++ b/runtimepy/net/udp/tftp/endpoint.py @@ -40,7 +40,7 @@ def __init__( self, root: Path, logger: LoggerType, - addr: Union[IpHost, tuple[str, int]], + addr: IpHost, data_sender: TftpDataSender, ack_sender: TftpAckSender, error_sender: TftpErrorSender, @@ -75,6 +75,13 @@ def __init__( self.timeout = timeout self.log_limiter = RateLimiter.from_s(1.0) + def update_from_other(self, other: "TftpEndpoint") -> "TftpEndpoint": + """Update this endpoint's attributes with attributes of another's.""" + + self.logger.info("Updating address to '%s'.", other.addr) + self.addr = other.addr + return self + def chunk_sender(self, block: int, data: bytes) -> Callable[[], None]: """Create a method that sends a specific block of data.""" @@ -135,7 +142,7 @@ def handle_ack(self, block: int) -> None: def __str__(self) -> str: """Get this instance as a string.""" - return f"{self.addr[0]}:{self.addr[1]}" + return str(self.addr) def handle_error(self, error_code: TftpErrorCode, message: str) -> None: """Handle a tftp error message.""" diff --git a/runtimepy/net/util.py b/runtimepy/net/util.py index 6585db9a..5020f41c 100644 --- a/runtimepy/net/util.py +++ b/runtimepy/net/util.py @@ -19,6 +19,11 @@ class IPv4Host(NamedTuple): name: str = "" port: int = 0 + @property + def hostname(self) -> str: + """Get a hostname for this instance.""" + return hostname(self.name) + @property def address(self) -> ipaddress.IPv4Address: """Get an address object for this hostname.""" @@ -28,6 +33,10 @@ def __str__(self) -> str: """Get this host as a string.""" return hostname_port(self.name, self.port) + def __hash__(self) -> int: + """Get a hash for this instance.""" + return hash(str(self)) + class IPv6Host(NamedTuple): """See: https://docs.python.org/3/library/socket.html#socket-families.""" @@ -37,6 +46,11 @@ class IPv6Host(NamedTuple): flowinfo: int = 0 scope_id: int = 0 + @property + def hostname(self) -> str: + """Get a hostname for this instance.""" + return hostname(self.name) + @property def address(self) -> ipaddress.IPv6Address: """Get an address object for this hostname.""" @@ -46,6 +60,10 @@ def __str__(self) -> str: """Get this host as a string.""" return hostname_port(self.name, self.port) + def __hash__(self) -> int: + """Get a hash for this instance.""" + return hash(str(self)) + IpHost = _Union[IPv4Host, IPv6Host] IpHostlike = _Union[str, int, IpHost, None] diff --git a/tests/commands/test_tftp.py b/tests/commands/test_tftp.py index 3f3855ce..8fff1581 100644 --- a/tests/commands/test_tftp.py +++ b/tests/commands/test_tftp.py @@ -33,4 +33,4 @@ def test_tftp_command_basic(): path_fd.write("Hello, world!\n") runtimepy_main(base + ["write", "localhost", str(ours), str(theirs)]) - runtimepy_main(base + ["read", "localhost", str(ours), str(theirs)]) + runtimepy_main(base + ["read", "localhost", str(ours)]) diff --git a/tests/net/test_util.py b/tests/net/test_util.py new file mode 100644 index 00000000..634e0830 --- /dev/null +++ b/tests/net/test_util.py @@ -0,0 +1,22 @@ +""" +Test the 'net.util' module. +""" + +# module under test +from runtimepy.net.util import IPv4Host, IPv6Host + + +def test_ip_hosts(): + """Test basic instantiations of IP host instances.""" + + ipv4 = IPv4Host("127.0.0.1", 0) + assert ipv4.hostname + assert ipv4.address + assert str(ipv4) + assert hash(ipv4) + + ipv6 = IPv6Host("::1", 0) + assert ipv6.hostname + assert ipv6.address + assert str(ipv6) + assert hash(ipv6)