Skip to content

Commit

Permalink
Improve tftp command-line interface
Browse files Browse the repository at this point in the history
Fix talkback
  • Loading branch information
vkottler committed Jul 22, 2024
1 parent 88b824c commit 017c6c2
Show file tree
Hide file tree
Showing 13 changed files with 153 additions and 45 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ jobs:
- run: |
mk python-release owner=vkottler \
repo=runtimepy version=5.4.1
repo=runtimepy version=5.4.2
if: |
matrix.python-version == '3.12'
&& matrix.system == 'ubuntu-latest'
Expand Down
6 changes: 3 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@
=====================================
generator=datazen
version=3.1.4
hash=039855eb758d9eb1ea70df0654e31b61
hash=4f8a71a6066638ed1a90f375188f0578
=====================================
-->

# runtimepy ([5.4.1](https://pypi.org/project/runtimepy/))
# runtimepy ([5.4.2](https://pypi.org/project/runtimepy/))

[![python](https://img.shields.io/pypi/pyversions/runtimepy.svg)](https://pypi.org/project/runtimepy/)
![Build Status](https://github.com/vkottler/runtimepy/workflows/Python%20Package/badge.svg)
Expand Down Expand Up @@ -155,7 +155,7 @@ options:
$ ./venv3.12/bin/runtimepy tftp -h
usage: runtimepy tftp [-h] [-p PORT] [-m MODE] [-t TIMEOUT] [-r REEMIT]
{read,write} host our_file their_file
{read,write} host our_file [their_file]
positional arguments:
{read,write} action to perform
Expand Down
1 change: 1 addition & 0 deletions local/arbiter/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
test.txt
2 changes: 1 addition & 1 deletion local/variables/package.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
---
major: 5
minor: 4
patch: 1
patch: 2
entry: runtimepy
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta:__legacy__"

[project]
name = "runtimepy"
version = "5.4.1"
version = "5.4.2"
description = "A framework for implementing Python services."
readme = "README.md"
requires-python = ">=3.11"
Expand Down
4 changes: 2 additions & 2 deletions runtimepy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# =====================================
# generator=datazen
# version=3.1.4
# hash=acff0a8ea3e1379494862c3188de7a76
# hash=e6970089f5f2935c496cb3e9bb06b774
# =====================================

"""
Expand All @@ -10,7 +10,7 @@

DESCRIPTION = "A framework for implementing Python services."
PKG_NAME = "runtimepy"
VERSION = "5.4.1"
VERSION = "5.4.2"

# runtimepy-specific content.
METRICS_NAME = "metrics"
Expand Down
5 changes: 4 additions & 1 deletion runtimepy/commands/tftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ def tftp_cmd(args: argparse.Namespace) -> int:
"process_kwargs": {"stop_sig": stop_sig},
}

if not args.their_file:
args.their_file = str(args.our_file)

if args.operation == "read":
task = tftp_read(addr, args.our_file, args.their_file, **kwargs)
else:
Expand Down Expand Up @@ -83,6 +86,6 @@ def add_tftp_cmd(parser: argparse.ArgumentParser) -> CommandFunction:
parser.add_argument("host", help="host to message")

parser.add_argument("our_file", type=Path, help="path to our file")
parser.add_argument("their_file", type=str, help="path to their file")
parser.add_argument("their_file", nargs="?", help="path to their file")

return tftp_cmd
53 changes: 27 additions & 26 deletions runtimepy/net/udp/tftp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ async def request_read(
) -> bool:
"""Request a tftp read operation."""

endpoint = self.endpoint(addr)
end_of_data = False
idx = 1

def ack_sender() -> None:
"""Send acks."""
nonlocal idx
self.send_ack(block=idx - 1, addr=addr)

async with AsyncExitStack() as stack:
# Claim read lock and ignore cancellation.
stack.enter_context(suppress(asyncio.CancelledError))
await stack.enter_async_context(endpoint.lock)

endpoint, event = await self._await_first_block(stack, addr=addr)

def ack_sender() -> None:
"""Send acks."""
nonlocal idx
endpoint.ack_sender(idx - 1, endpoint.addr)

def send_rrq() -> None:
"""Send request"""
Expand All @@ -60,9 +60,6 @@ def send_rrq() -> None:
"Requesting '%s' (%s) -> %s.", filename, mode, destination
)

event = asyncio.Event()
endpoint.awaiting_blocks[idx] = event

with self.log_time("Awaiting first data block", reminder=True):
# Wait for first data block.
if not await repeat_until(
Expand Down Expand Up @@ -112,21 +109,22 @@ def write_block() -> None:
if success:
write_block()

# Repeat last ack in the background.
if end_of_data:
self._conn_tasks.append(
asyncio.create_task(
repeat_until( # type: ignore
ack_sender,
asyncio.Event(),
endpoint.period.value,
endpoint.timeout.value,
# Repeat last ack in the background.
if end_of_data:
self._conn_tasks.append(
asyncio.create_task(
repeat_until( # type: ignore
ack_sender,
asyncio.Event(),
endpoint.period.value,
endpoint.timeout.value,
)
)
)
)

# Make a to-string or log method for vcorelib FileInfo?
#
# Ensure at least one ack sends.
await asyncio.sleep(0.01)

self.logger.info(
"Read %s (%s).",
FileInfo.from_file(destination),
Expand All @@ -146,16 +144,14 @@ async def request_write(
"""Request a tftp write operation."""

result = False
endpoint = self.endpoint(addr)

with as_path(source) as src:
async with AsyncExitStack() as stack:
# Claim write lock and ignore cancellation.
stack.enter_context(suppress(asyncio.CancelledError))
await stack.enter_async_context(endpoint.lock)

event = asyncio.Event()
endpoint.awaiting_acks[0] = event
# Set up first-ack handling.
endpoint, event = await self._await_first_ack(stack, addr=addr)

def send_wrq() -> None:
"""Send request."""
Expand Down Expand Up @@ -183,6 +179,11 @@ def send_wrq() -> None:
)

# Compare hashes.
self.logger.info(
"Reading '%s' %s.",
filename,
"succeeded" if result else "failed",
)
if result:
result = file_md5_hex(src) == file_md5_hex(tmp)
self.logger.info(
Expand Down
70 changes: 63 additions & 7 deletions runtimepy/net/udp/tftp/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
"""

# built-in
import asyncio
from contextlib import AsyncExitStack
from io import BytesIO
import logging
from pathlib import Path
Expand All @@ -22,6 +24,7 @@
encode_filename_mode,
parse_filename_mode,
)
from runtimepy.net.util import normalize_host
from runtimepy.primitives import Double, Uint16

REEMIT_PERIOD_S = 0.20
Expand Down Expand Up @@ -124,9 +127,39 @@ def error_sender(

self.error_sender = error_sender

self._endpoints: dict[str, TftpEndpoint] = {}
self._endpoints: dict[IpHost, TftpEndpoint] = {}
self._awaiting_first_ack: dict[str, TftpEndpoint] = {}
self._awaiting_first_block: dict[str, TftpEndpoint] = {}
# self._self = self.endpoint(self.local_address)

async def _await_first_ack(
self,
stack: AsyncExitStack,
addr: Union[IpHost, tuple[str, int]] = None,
) -> tuple[TftpEndpoint, asyncio.Event]:
"""Set up an endpoint to wait for an initial ack from a server."""

endpoint = self.endpoint(addr)
await stack.enter_async_context(endpoint.lock)
event = asyncio.Event()
endpoint.awaiting_acks[0] = event
self._awaiting_first_ack[endpoint.addr.hostname] = endpoint
return endpoint, event

async def _await_first_block(
self,
stack: AsyncExitStack,
addr: Union[IpHost, tuple[str, int]] = None,
) -> tuple[TftpEndpoint, asyncio.Event]:
"""Set up an endpoint to wait for an initial block from a server."""

endpoint = self.endpoint(addr)
await stack.enter_async_context(endpoint.lock)
event = asyncio.Event()
endpoint.awaiting_blocks[1] = event
self._awaiting_first_block[endpoint.addr.hostname] = endpoint
return endpoint, event

def endpoint(
self, addr: Union[IpHost, tuple[str, int]] = None
) -> TftpEndpoint:
Expand All @@ -136,10 +169,10 @@ def endpoint(
addr = self.remote_address

assert addr is not None
key = f"{addr[0]}:{addr[1]}"
addr = normalize_host(*addr)

if key not in self._endpoints:
self._endpoints[key] = TftpEndpoint(
if addr not in self._endpoints:
self._endpoints[addr] = TftpEndpoint(
self._path,
self.logger,
addr,
Expand All @@ -150,7 +183,7 @@ def endpoint(
self.endpoint_timeout,
)

return self._endpoints[key]
return self._endpoints[addr]

def send_rrq(
self,
Expand Down Expand Up @@ -270,15 +303,38 @@ async def _handle_data(
) -> None:
"""Handle a data message."""

endpoint = self.endpoint(addr)
block = self._read_block_number(stream)
self.endpoint(addr).handle_data(block, stream.read())

# Check if we're currently waiting for an initial block.
hostname = endpoint.addr.hostname
if block == 1 and hostname in self._awaiting_first_block:
to_update = self._awaiting_first_block[hostname]
del self._awaiting_first_block[hostname]
self._endpoints[endpoint.addr] = to_update
endpoint = to_update.update_from_other(endpoint)

endpoint.handle_data(block, stream.read())

async def _handle_ack(
self, stream: BinaryIO, addr: tuple[str, int]
) -> None:
"""Handle an acknowledge message."""

self.endpoint(addr).handle_ack(self._read_block_number(stream))
endpoint = self.endpoint(addr)
block = self._read_block_number(stream)

# Check if we're currently waiting for an initial acknowledgement. This
# will come from the same host but a different port, so update
# references when this is detected.
hostname = endpoint.addr.hostname
if block == 0 and hostname in self._awaiting_first_ack:
to_update = self._awaiting_first_ack[hostname]
del self._awaiting_first_ack[hostname]
self._endpoints[endpoint.addr] = to_update
endpoint = to_update.update_from_other(endpoint)

endpoint.handle_ack(block)

def _read_block_number(self, stream: BinaryIO) -> int:
"""Read block number from the stream."""
Expand Down
11 changes: 9 additions & 2 deletions runtimepy/net/udp/tftp/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
self,
root: Path,
logger: LoggerType,
addr: Union[IpHost, tuple[str, int]],
addr: IpHost,
data_sender: TftpDataSender,
ack_sender: TftpAckSender,
error_sender: TftpErrorSender,
Expand Down Expand Up @@ -75,6 +75,13 @@ def __init__(
self.timeout = timeout
self.log_limiter = RateLimiter.from_s(1.0)

def update_from_other(self, other: "TftpEndpoint") -> "TftpEndpoint":
"""Update this endpoint's attributes with attributes of another's."""

self.logger.info("Updating address to '%s'.", other.addr)
self.addr = other.addr
return self

def chunk_sender(self, block: int, data: bytes) -> Callable[[], None]:
"""Create a method that sends a specific block of data."""

Expand Down Expand Up @@ -135,7 +142,7 @@ def handle_ack(self, block: int) -> None:

def __str__(self) -> str:
"""Get this instance as a string."""
return f"{self.addr[0]}:{self.addr[1]}"
return str(self.addr)

def handle_error(self, error_code: TftpErrorCode, message: str) -> None:
"""Handle a tftp error message."""
Expand Down
18 changes: 18 additions & 0 deletions runtimepy/net/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ class IPv4Host(NamedTuple):
name: str = ""
port: int = 0

@property
def hostname(self) -> str:
"""Get a hostname for this instance."""
return hostname(self.name)

@property
def address(self) -> ipaddress.IPv4Address:
"""Get an address object for this hostname."""
Expand All @@ -28,6 +33,10 @@ def __str__(self) -> str:
"""Get this host as a string."""
return hostname_port(self.name, self.port)

def __hash__(self) -> int:
"""Get a hash for this instance."""
return hash(str(self))


class IPv6Host(NamedTuple):
"""See: https://docs.python.org/3/library/socket.html#socket-families."""
Expand All @@ -37,6 +46,11 @@ class IPv6Host(NamedTuple):
flowinfo: int = 0
scope_id: int = 0

@property
def hostname(self) -> str:
"""Get a hostname for this instance."""
return hostname(self.name)

@property
def address(self) -> ipaddress.IPv6Address:
"""Get an address object for this hostname."""
Expand All @@ -46,6 +60,10 @@ def __str__(self) -> str:
"""Get this host as a string."""
return hostname_port(self.name, self.port)

def __hash__(self) -> int:
"""Get a hash for this instance."""
return hash(str(self))


IpHost = _Union[IPv4Host, IPv6Host]
IpHostlike = _Union[str, int, IpHost, None]
Expand Down
2 changes: 1 addition & 1 deletion tests/commands/test_tftp.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@ def test_tftp_command_basic():
path_fd.write("Hello, world!\n")

runtimepy_main(base + ["write", "localhost", str(ours), str(theirs)])
runtimepy_main(base + ["read", "localhost", str(ours), str(theirs)])
runtimepy_main(base + ["read", "localhost", str(ours)])
Loading

0 comments on commit 017c6c2

Please sign in to comment.