Skip to content

Commit

Permalink
Speed up StateReachability pass for large state machines
Browse files Browse the repository at this point in the history
  • Loading branch information
tbennun committed Sep 20, 2023
1 parent 680a956 commit aa88d82
Showing 1 changed file with 39 additions and 25 deletions.
64 changes: 39 additions & 25 deletions dace/transformation/passes/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]]
SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]]


@properties.make_properties
class StateReachability(ppl.Pass):
"""
Expand All @@ -35,10 +36,20 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Set[SDFGSta
"""
reachable: Dict[int, Dict[SDFGState, Set[SDFGState]]] = {}
for sdfg in top_sdfg.all_sdfgs_recursive():
reachable[sdfg.sdfg_id] = {}
tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)
for state in sdfg.nodes():
reachable[sdfg.sdfg_id][state] = set(tc.successors(state))
result: Dict[SDFGState, Set[SDFGState]] = {}

# In networkx this is currently implemented naively for directed graphs.
# The implementation below is faster
# tc: nx.DiGraph = nx.transitive_closure(sdfg.nx)

for n, v in nx.all_pairs_shortest_path_length(sdfg.nx):
result[n] = set(t for t, l in v.items() if l > 0)
# Add self-edges
if n in sdfg.successors(n):
result[n].add(n)

reachable[sdfg.sdfg_id] = result

return reachable


Expand All @@ -57,9 +68,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
# If anything was modified, reapply
return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes

def apply_pass(
self, top_sdfg: SDFG, _
) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
def apply_pass(self, top_sdfg: SDFG,
_) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]:
"""
:return: A dictionary mapping each state to a tuple of its (read, written) data descriptors.
"""
Expand Down Expand Up @@ -216,9 +226,8 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {SymbolAccessSets, StateReachability}

def _find_dominating_write(
self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], state_idom: Dict[SDFGState, SDFGState]
) -> Optional[Edge[InterstateEdge]]:
def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]],
state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]:
last_state: SDFGState = read if isinstance(read, SDFGState) else read.src

in_edges = last_state.parent.in_edges(last_state)
Expand Down Expand Up @@ -257,9 +266,9 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int,

idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state)
all_doms = cfg.all_dominators(sdfg, idom)
symbol_access_sets: Dict[
Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]
] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]],
Tuple[Set[str],
Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.sdfg_id]
state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.sdfg_id]

for read_loc, (reads, _) in symbol_access_sets.items():
Expand Down Expand Up @@ -321,12 +330,14 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
def depends_on(self):
return {AccessSets, FindAccessNodes, StateReachability}

def _find_dominating_write(
self, desc: str, state: SDFGState, read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState], access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False
) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
def _find_dominating_write(self,
desc: str,
state: SDFGState,
read: Union[nd.AccessNode, InterstateEdge],
access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]],
state_idom: Dict[SDFGState, SDFGState],
access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]],
no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]:
if isinstance(read, nd.AccessNode):
# If the read is also a write, it shadows itself.
iedges = state.in_edges(read)
Expand Down Expand Up @@ -408,18 +419,21 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i
for oedge in out_edges:
syms = oedge.data.free_symbols & anames
if desc in syms:
write = self._find_dominating_write(
desc, state, oedge.data, access_nodes, idom, access_sets
)
write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom,
access_sets)
result[desc][write].add((state, oedge.data))
# Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not
# dominating any reads and are thus not part of the results yet.
for state in desc_states_with_nodes:
for write_node in access_nodes[desc][state][1]:
if not (state, write_node) in result[desc]:
write = self._find_dominating_write(
desc, state, write_node, access_nodes, idom, access_sets, no_self_shadowing=True
)
write = self._find_dominating_write(desc,
state,
write_node,
access_nodes,
idom,
access_sets,
no_self_shadowing=True)
result[desc][write].add((state, write_node))

# If any write A is dominated by another write B and any reads in B's scope are also reachable by A,
Expand Down

0 comments on commit aa88d82

Please sign in to comment.