Skip to content

Commit

Permalink
router: explicitly track endpoints
Browse files Browse the repository at this point in the history
The router currently keeps a mapping of open channels to the endpoints
responsible for them.  This presents two problems:

 - when we close down an endpoint, we need to iterate all open channels
   in order to determine which channels belong to that endpoint
 - it's possible to have active endpoints associated with the router
   which the router has no idea about

Move to a more explicit model where we add a second mapping: endpoints
to their set of open channels.  This makes endpoint shutdown easier and
adds the advantage that an endpoint with no channels can still be
tracked by the router (with an empty channel list).

The second point of this will be useful in future commits when the
router (and not the routing rules) become responsible for ensuring that
all endpoints are correctly shutdown.
  • Loading branch information
allisonkarlitskaya committed Nov 2, 2023
1 parent 0c634d1 commit be5a593
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
6 changes: 2 additions & 4 deletions src/cockpit/beiboot.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,11 +260,9 @@ class SshBridge(Router):
ssh_peer: SshPeer

def __init__(self, args: argparse.Namespace):
super().__init__([])
self.ssh_peer = SshPeer(self, args.destination, args)

super().__init__([
DefaultRoutingRule(self, self.ssh_peer),
])
self.routing_rules.append(DefaultRoutingRule(self, self.ssh_peer))

def do_send_init(self):
pass # wait for the peer to do it first
Expand Down
2 changes: 1 addition & 1 deletion src/cockpit/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def is_closing(self) -> bool:
return self._close_args is not None

def _close_now(self):
self.send_control('close', **self._close_args)
self.shutdown_endpoint(**self._close_args)

def _task_done(self, task):
# Strictly speaking, we should read the result and check for exceptions but:
Expand Down
9 changes: 7 additions & 2 deletions src/cockpit/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Endpoint:
__endpoint_frozen_queue: Optional[ExecutionQueue] = None

def __init__(self, router: 'Router'):
router.endpoints[self] = set()
self.router = router

def freeze_endpoint(self):
Expand Down Expand Up @@ -94,6 +95,7 @@ def send_channel_message(self, channel: str, **kwargs: JsonDocument) -> None:
def send_channel_control(self, channel, command, **kwargs: JsonDocument) -> None:
self.router.write_control(channel=channel, command=command, **kwargs)
if command == 'close':
self.router.endpoints[self].remove(channel)
self.router.drop_channel(channel)

def shutdown_endpoint(self, **kwargs: JsonDocument) -> None:
Expand Down Expand Up @@ -130,13 +132,15 @@ def shutdown(self):
class Router(CockpitProtocolServer):
routing_rules: List[RoutingRule]
open_channels: Dict[str, Endpoint]
endpoints: 'dict[Endpoint, set[str]]'
_eof: bool = False

def __init__(self, routing_rules: List[RoutingRule]):
for rule in routing_rules:
rule.router = self
self.routing_rules = routing_rules
self.open_channels = {}
self.endpoints = {}

def check_rules(self, options: JsonObject) -> Endpoint:
for rule in self.routing_rules:
Expand All @@ -161,14 +165,14 @@ def drop_channel(self, channel: str) -> None:
self.transport.close()

def shutdown_endpoint(self, endpoint: Endpoint, **kwargs) -> None:
channels = {key for key, value in self.open_channels.items() if value == endpoint}
channels = self.endpoints.pop(endpoint)
logger.debug('shutdown_endpoint(%s, %s) will close %s', endpoint, kwargs, channels)
for channel in channels:
self.write_control(command='close', channel=channel, **kwargs)
self.drop_channel(channel)

def do_kill(self, host: Optional[str], group: Optional[str]) -> None:
endpoints = set(self.open_channels.values())
endpoints = set(self.endpoints)
logger.debug('do_kill(%s, %s). Considering %d endpoints.', host, group, len(endpoints))
for endpoint in endpoints:
endpoint.do_kill(host, group)
Expand All @@ -189,6 +193,7 @@ def channel_control_received(self, channel: str, command: str, message: JsonObje
return

self.open_channels[channel] = endpoint
self.endpoints[endpoint].add(channel)
else:
try:
endpoint = self.open_channels[channel]
Expand Down

0 comments on commit be5a593

Please sign in to comment.