diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index aef522bcf1..912dc171c7 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -963,7 +963,7 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False) _sdfg: Optional['dace.SDFG'] = None - _parent: Optional['ControlFlowBlock'] = None + _parent: Optional['ScopeBlock'] = None _label: str def __init__(self, @@ -1012,12 +1012,12 @@ def name(self) -> str: return self._label @property - def parent(self) -> Optional['ControlFlowBlock']: + def parent(self) -> Optional['ScopeBlock']: """ Returns the parent block of this block. """ return self._parent @parent.setter - def parent(self, block: Optional['ControlFlowBlock']): + def parent(self, block: Optional['ScopeBlock']): self._parent = block @property diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 4d560ab70a..d1f92ff628 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -158,15 +158,15 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sdfg.append_exit_code(code.code, loc) # Environments - for nstate in nsdfg.nodes(): - for node in nstate.nodes(): - if isinstance(node, nodes.CodeNode): - node.environments |= nsdfg_node.environments + for node, _ in nsdfg.all_nodes_recursive(): + if isinstance(node, nodes.CodeNode): + node.environments |= nsdfg_node.environments # Symbols outer_symbols = {str(k): v for k, v in sdfg.symbols.items()} - for ise in sdfg.edges(): - outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) + for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False): + for ise in cf.edges(): + outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) # Find original source/destination edges (there is only one edge per # connector, according to match) @@ -189,12 +189,14 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # Collect and modify interstate edges as necessary outer_assignments = set() - for e in sdfg.edges(): - outer_assignments |= e.data.assignments.keys() + for cf in sdfg.all_cfgs_recursive(): + for e in cf.edges(): + outer_assignments |= e.data.assignments.keys() inner_assignments = set() - for e in nsdfg.edges(): - inner_assignments |= e.data.assignments.keys() + for cf in nsdfg.all_cfgs_recursive(): + for e in cf.edges(): + inner_assignments |= e.data.assignments.keys() allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys()) assignments_to_replace = inner_assignments & (outer_assignments | allnames) @@ -235,30 +237,31 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # All transients become transients of the parent (if data already # exists, find new name) - for nstate in nsdfg.nodes(): - for node in nstate.nodes(): - if isinstance(node, nodes.AccessNode): - datadesc = nsdfg.arrays[node.data] - if node.data not in transients and datadesc.transient: - new_name = node.data - if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants): - new_name = f'{nsdfg.label}_{node.data}' - - name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) - transients[node.data] = name - - # All transients of edges between code nodes are also added to parent - for edge in nstate.edges(): - if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)): - if edge.data.data is not None: - datadesc = nsdfg.arrays[edge.data.data] - if edge.data.data not in transients and datadesc.transient: - new_name = edge.data.data + for cf in nsdfg.all_cfgs_recursive(): + for nblock in cf.nodes(): + for node in nblock.nodes(): + if isinstance(node, nodes.AccessNode): + datadesc = nsdfg.arrays[node.data] + if node.data not in transients and datadesc.transient: + new_name = node.data if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants): - new_name = f'{nsdfg.label}_{edge.data.data}' + new_name = f'{nsdfg.label}_{node.data}' name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) - transients[edge.data.data] = name + transients[node.data] = name + + # All transients of edges between code nodes are also added to parent + for edge in nblock.edges(): + if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)): + if edge.data.data is not None: + datadesc = nsdfg.arrays[edge.data.data] + if edge.data.data not in transients and datadesc.transient: + new_name = edge.data.data + if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants): + new_name = f'{nsdfg.label}_{edge.data.data}' + + name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) + transients[edge.data.data] = name # All constants (and associated transients) become constants of the parent @@ -329,12 +332,12 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e.dst_conn, e.data) # Make unique names for states - statenames = set(s.label for s in sdfg.nodes()) - for nstate in nsdfg.nodes(): - if nstate.label in statenames: - newname = data.find_new_name(nstate.label, statenames) - statenames.add(newname) - nstate.label = newname + blocknames = set(s.label for s in sdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False)) + for nblock in nsdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False): + if nblock.label in blocknames: + newname = data.find_new_name(nblock.label, blocknames) + blocknames.add(newname) + nblock.label = newname ####################################################### # Add nested SDFG states into top-level SDFG @@ -352,19 +355,20 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sinks = nsdfg.sink_nodes() # Reconnect state machine - for e in sdfg.in_edges(outer_state): - sdfg.add_edge(e.src, source, e.data) - for e in sdfg.out_edges(outer_state): + parent_graph = outer_state.parent + for e in parent_graph.in_edges(outer_state): + parent_graph.add_edge(e.src, source, e.data) + for e in parent_graph.out_edges(outer_state): for sink in sinks: - sdfg.add_edge(sink, e.dst, dc(e.data)) + parent_graph.add_edge(sink, e.dst, dc(e.data)) # Redirect sink incoming edges with a `False` condition to e.dst (return statements) - for e2 in sdfg.in_edges(sink): + for e2 in parent_graph.in_edges(sink): if e2.data.condition_sympy() == False: - sdfg.add_edge(e2.src, e.dst, InterstateEdge()) + parent_graph.add_edge(e2.src, e.dst, InterstateEdge()) # Modify start state as necessary if outer_start_state is outer_state: - sdfg.start_state = sdfg.node_id(source) + parent_graph.start_block = parent_graph.node_id(source) # TODO: Modify memlets by offsetting # If both source and sink nodes are inputs/outputs, reconnect once @@ -408,13 +412,15 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e._data = helpers.unsqueeze_memlet( # e.data, outer_edge.data) - # Replace nested SDFG parents with new SDFG - for nstate in nsdfg.nodes(): - nstate.parent = sdfg - for node in nstate.nodes(): - if isinstance(node, nodes.NestedSDFG): - node.sdfg.parent_sdfg = sdfg - node.sdfg.parent_nsdfg_node = node + # Replace nested SDFG parents and SDFG pointers. + for cf in nsdfg.all_cfgs_recursive(): + for nblock in cf.nodes(): + nblock.parent = cf + nblock.sdfg = sdfg + for node in nblock.nodes(): + if isinstance(node, nodes.NestedSDFG): + node.sdfg.parent_sdfg = sdfg + node.sdfg.parent_nsdfg_node = node ####################################################### # Remove nested SDFG and state