Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Fix topology verification #123

Merged
merged 7 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/qref/schema_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def _validate_connections(self) -> Self:
if missed_ports:
raise ValueError(
"The following ports appear in a connection but are not "
"among routine's port or their children's ports: {missed_ports}."
f"among routine's port or their children's ports: {missed_ports}."
)
return self

Expand Down
118 changes: 82 additions & 36 deletions src/qref/verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import defaultdict
from collections import Counter, defaultdict
from dataclasses import dataclass
from typing import Callable

from .functools import accepts_all_qref_types
from .schema_v1 import RoutineV1, SchemaV1
from .schema_v1 import RoutineV1

AdjacencyList = dict[str, list[str]]

Expand All @@ -36,55 +37,61 @@ def __bool__(self) -> bool:


@accepts_all_qref_types
def verify_topology(routine: SchemaV1 | RoutineV1) -> TopologyVerificationOutput:
def verify_topology(routine: RoutineV1) -> TopologyVerificationOutput:
"""Checks whether program has correct topology.

Correct topology cannot include cycles or disconnected ports.

Args:
routine: Routine or program to be verified.
"""
if isinstance(routine, SchemaV1):
routine = routine.program
problems = _verify_routine_topology(routine)
return TopologyVerificationOutput(problems)


def _verify_routine_topology(routine: RoutineV1) -> list[str]:
adjacency_list = _get_adjacency_list_from_routine(routine, path=None)
def _make_prefixer(ancestor_path: tuple[str, ...]) -> Callable[[str], str]:
def _prefix(text: str) -> str:
return ".".join([*ancestor_path, text])

return _prefix


def _verify_routine_topology(routine: RoutineV1, ancestor_path: tuple[str, ...] = ()) -> list[str]:
adjacency_list = _get_adjacency_list_from_routine(routine, path=ancestor_path)

return [
*_find_cycles(adjacency_list),
*_find_disconnected_ports(routine),
*[problem for child in routine.children for problem in _verify_routine_topology(child)],
*_find_cycles(adjacency_list, ancestor_path),
*_find_disconnected_ports(routine, ancestor_path),
*[
problem
for child in routine.children
for problem in _verify_routine_topology(child, ancestor_path + (routine.name,))
],
]


def _get_adjacency_list_from_routine(routine: RoutineV1, path: str | None) -> AdjacencyList:
def _get_adjacency_list_from_routine(routine: RoutineV1, path: tuple[str, ...]) -> AdjacencyList:
"""This function creates a flat graph representing one hierarchy level of a routine.

Nodes represent ports and edges represent connections (they're directed).
Additionaly, we add node for each children and edges coming from all the input ports
into the children, and from the children into all the output ports.
"""
graph = defaultdict[str, list[str]](list)
if path is None:
current_path = routine.name
else:
current_path = ".".join([path, routine.name])
_prefix = _make_prefixer(path + (routine.name,))

# First, we go through all the connections and add them as adges to the graph
for connection in routine.connections:
source = ".".join([current_path, connection.source])
target = ".".join([current_path, connection.target])
source = _prefix(connection.source)
target = _prefix(connection.target)
graph[source].append(target)

# Then for each children we add an extra node and set of connections
for child in routine.children:
input_ports: list[str] = []
output_ports: list[str] = []

child_path = ".".join([current_path, child.name])
child_path = _prefix(child.name)
for port in child.ports:
if port.direction == "input":
input_ports.append(".".join([child_path, port.name]))
Expand All @@ -99,20 +106,22 @@ def _get_adjacency_list_from_routine(routine: RoutineV1, path: str | None) -> Ad
return graph


def _find_cycles(adjacency_list: AdjacencyList) -> list[str]:
def _find_cycles(adjacency_list: AdjacencyList, ancestor_path: tuple[str, ...]) -> list[str]:
# Note: it only returns the first detected cycle.
for node in list(adjacency_list):
problem = _dfs_iteration(adjacency_list, node)
problem = _dfs_iteration(adjacency_list, node, ancestor_path)
if problem:
return problem
return []


def _dfs_iteration(adjacency_list: AdjacencyList, start_node: str) -> list[str]:
def _dfs_iteration(adjacency_list: AdjacencyList, start_node: str, ancestor_path: tuple[str, ...]) -> list[str]:
to_visit: list[str] = [start_node]
visited: list[str] = []
predecessors: dict[str, str] = {}

_prefix = _make_prefixer(ancestor_path)

while to_visit:
node = to_visit.pop()
visited.append(node)
Expand All @@ -122,29 +131,66 @@ def _dfs_iteration(adjacency_list: AdjacencyList, start_node: str) -> list[str]:
# Reconstruct the cycle
cycle = [neighbour]
while len(cycle) < 2 or cycle[-1] != start_node:
cycle.append(predecessors[cycle[-1]])
cycle.append(_prefix(predecessors[cycle[-1]]))
return [f"Cycle detected: {cycle[::-1]}"]
if neighbour not in visited:
to_visit.append(neighbour)
return []


def _find_disconnected_ports(routine: RoutineV1) -> list[str]:
def _find_disconnected_ports(routine: RoutineV1, ancestor_path: tuple[str, ...]) -> list[str]:
problems: list[str] = []

_prefix = _make_prefixer(ancestor_path + (routine.name,))

sources_counts = Counter[str]()
target_counts = Counter[str]()

for connection in routine.connections:
sources_counts[connection.source] += 1
target_counts[connection.target] += 1

multi_sources = [source for source, count in sources_counts.items() if count > 1]

multi_targets = [target for target, count in target_counts.items() if count > 1]

if multi_sources:
problems.append(f"Too many outgoing connections from {','.join(_prefix(target) for target in multi_sources)}.")

if multi_targets:
problems.append(f"Too many incoming connections to {','.join(_prefix(target) for target in multi_targets)}.")

requiring_outgoing = set[str]()
requiring_incoming = set[str]()
thru_ports = set[str]()

for port in routine.ports:
if port.direction == "input" and routine.children:
requiring_outgoing.add(port.name)
elif port.direction == "output" and routine.children:
requiring_incoming.add(port.name)
elif port.direction == "through": # Note: through ports have to be valid regardless of existence of children
thru_ports.add(port.name)

for child in routine.children:
# Directions are reversed compared to parent + through ports have to be connected on both ends
for port in child.ports:
pname = f"{routine.name}.{child.name}.{port.name}"
if port.direction == "input":
matches_in = [c for c in routine.connections if c.target == f"{child.name}.{port.name}"]
if len(matches_in) == 0:
problems.append(f"No incoming connections to {pname}.")
elif len(matches_in) > 1:
problems.append(f"Too many incoming connections to {pname}.")
elif port.direction == "output":
matches_out = [c for c in routine.connections if c.source == f"{child.name}.{port.name}"]
if len(matches_out) == 0:
problems.append(f"No outgoing connections from {pname}.")
elif len(matches_out) > 1:
problems.append(f"Too many outgoing connections from {pname}.")
pname = f"{child.name}.{port.name}"
if port.direction != "output":
requiring_incoming.add(pname)
if port.direction != "input":
requiring_outgoing.add(pname)

for pname in requiring_outgoing:
if pname not in sources_counts:
problems.append(f"No outgoing connection from {_prefix(pname)}.")

for pname in requiring_incoming:
if pname not in target_counts:
problems.append(f"No incoming connection to {_prefix(pname)}.")

for pname in thru_ports:
if pname in sources_counts or pname in target_counts:
problems.append(f"A through port {_prefix(pname)} is connected via an internal connection.")

return problems
62 changes: 60 additions & 2 deletions tests/qref/data/invalid_topology_programs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,65 @@
version: v1
description: Program has badly connected ports
problems:
- "No incoming connections to root.child_1.in_1."
- "No incoming connection to root.child_1.in_1."
- "Too many incoming connections to root.child_2.in_0."
- "Too many outgoing connections from root.child_2.out_0."
- "No outgoing connections from root.child_2.out_1."
- "No outgoing connection from root.child_2.out_1."
- input:
version: v1
program:
name: root
ports:
- direction: through
name: thru_0
size: 1
- direction: output
name: out_0
size: 1
connections:
- thru_0 -> out_0
description: "A routine with its thru port connected to its output port."
problems:
- "A through port root.thru_0 is connected via an internal connection."
- input:
version: v1
program:
name: root
ports:
- direction: input
name: in_0
size: N
- direction: output
name: out_0
size: N
- direction: output
name: out_1
size: N
connections:
- in_0 -> foo.in_0
- foo.out_0 -> out_0
- foo.out_1 -> out_1
children:
- name: foo
ports:
- direction: input
name: in_0
size: N
- direction: output
name: out_0
size: N
- direction: output
name: out_1
size: N
connections:
- bar.out_0 -> out_1
children:
- name: bar
ports:
- name: out_0
size: N
direction: output
description: "Program with disconnected container ports"
problems:
- "No outgoing connection from root.foo.in_0."
- "No incoming connection to root.foo.out_0."
18 changes: 18 additions & 0 deletions tests/qref/test_schema_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,21 @@ def test_invalid_program_fails_to_validate_with_pydantic_model_v1(input):

def test_valid_program_succesfully_validate_with_pydantic_model_v1(valid_program):
SchemaV1.model_validate(valid_program)


def test_validation_error_includes_name_of_the_missed_port():
input = {
"version": "v1",
"program": {
"name": "root",
"ports": [{"name": "in_0", "direction": "input", "size": 1}],
"connections": ["in_0 -> out_0"],
},
}

pattern = (
"The following ports appear in a connection but are not among routine's port "
r"or their children's ports: \['out_0'\]." # <- out_0 is the important bit here
)
with pytest.raises(pydantic.ValidationError, match=pattern):
SchemaV1.model_validate(input)
7 changes: 3 additions & 4 deletions tests/qref/test_topology_verification.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ def test_correct_routines_pass_topology_validation(valid_program):
def test_invalid_program_fails_to_validate_with_schema_v1(input, problems):
verification_output = verify_topology(SchemaV1(**input))

assert not verification_output
assert len(problems) == len(verification_output.problems)
for expected_problem, problem in zip(problems, verification_output.problems):
assert expected_problem == problem
# We use sorted here, to make sure that we don't test the order in which the
# problems appear, as the order is only an implementation detail.
assert sorted(verification_output.problems) == sorted(problems)
Loading