diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 2fa16182e3..5efd155d6b 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1210,7 +1210,7 @@ def parse_program(self, program: ast.FunctionDef, is_tasklet: bool = False): for stmt in program.body: self.visit_TopLevel(stmt) if len(self.sdfg.nodes()) == 0: - self.cfg_target.add_state('EmptyState') + self.sdfg.add_state('EmptyState') # Handle return values # Assignments to return values become __return* arrays @@ -1277,7 +1277,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List return new_nodes # Map view access nodes to their respective data - for state in self.sdfg.nodes(): + for state in self.sdfg.states(): # NOTE: We need to support views of views nodes = list(state.data_nodes()) while nodes: @@ -2023,7 +2023,7 @@ def _add_dependencies(self, else: name = memlet.data vname = "{c}_in_from_{s}{n}".format(c=conn, - s=self.sdfg.nodes().index(state), + s=self.sdfg.states().index(state), n=('_%s' % state.node_id(entry_node) if entry_node else '')) self.accesses[(name, scope_memlet.subset, 'r')] = (vname, orng) orig_shape = orng.size() @@ -2113,7 +2113,7 @@ def _add_dependencies(self, else: name = memlet.data vname = "{c}_out_of_{s}{n}".format(c=conn, - s=self.sdfg.nodes().index(state), + s=self.sdfg.states().index(state), n=('_%s' % state.node_id(exit_node) if exit_node else '')) self.accesses[(name, scope_memlet.subset, 'w')] = (vname, orng) orig_shape = orng.size() @@ -2210,7 +2210,8 @@ def _recursive_visit(self, # Restore previous target self.cfg_target = previous_target self.last_cfg_target = previous_last_cfg_target - self.last_block = previous_block + if not unconnected_last_block: + self.last_block = previous_block return previous_block, first_innner_block, last_inner_block, has_return_statement @@ -2358,7 +2359,7 @@ def visit_For(self, node: ast.For): # The state that all "break" edges go to state = self.cfg_target.add_state(f'postloop_{node.lineno}') if self.last_block is not None: - self.sdfg.add_edge(self.last_block, state, dace.InterstateEdge()) + self.cfg_target.add_edge(self.last_block, state, dace.InterstateEdge()) self.last_block = state return state @@ -2473,25 +2474,25 @@ def visit_If(self, node: ast.If): # Visit recursively laststate, first_if_state, last_if_state, return_stmt = \ - self._recursive_visit(node.body, 'if', node.lineno) + self._recursive_visit(node.body, 'if', node.lineno, self.cfg_target, True) end_if_state = self.last_block # Connect the states - self.sdfg.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) - self.sdfg.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) + self.cfg_target.add_edge(laststate, first_if_state, dace.InterstateEdge(cond)) + self.cfg_target.add_edge(last_if_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) # Process 'else'/'elif' statements if len(node.orelse) > 0: # Visit recursively _, first_else_state, last_else_state, return_stmt = \ - self._recursive_visit(node.orelse, 'else', node.lineno, False) + self._recursive_visit(node.orelse, 'else', node.lineno, self.cfg_target, False) # Connect the states - self.sdfg.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) - self.sdfg.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) - self.last_block = end_if_state + self.cfg_target.add_edge(laststate, first_else_state, dace.InterstateEdge(cond_else)) + self.cfg_target.add_edge(last_else_state, end_if_state, dace.InterstateEdge(condition=f"{not return_stmt}")) else: - self.sdfg.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) + self.cfg_target.add_edge(laststate, end_if_state, dace.InterstateEdge(cond_else)) + self.last_block = end_if_state def _parse_tasklet(self, state: SDFGState, node: TaskletType, name=None): @@ -3323,7 +3324,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): # Handle output indirection output_indirection = None if _subset_has_indirection(rng, self): - output_indirection = self.sdfg.add_state('wslice_%s_%d' % (new_name, node.lineno)) + output_indirection = self.cfg_target.add_state('wslice_%s_%d' % (new_name, node.lineno)) wnode = output_indirection.add_write(new_name, debuginfo=self.current_lineinfo) memlet = Memlet.simple(new_name, str(rng)) # Dependent augmented assignments need WCR in the @@ -3369,7 +3370,7 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): # Connect states properly when there is output indirection if output_indirection: - self.sdfg.add_edge(self.last_block, output_indirection, dace.sdfg.InterstateEdge()) + self.cfg_target.add_edge(self.last_block, output_indirection, dace.sdfg.InterstateEdge()) self.last_block = output_indirection def visit_AugAssign(self, node: ast.AugAssign): @@ -3814,8 +3815,8 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no for sym, local in mapping.items(): if isinstance(local, str) and local in self.sdfg.arrays: # Add assignment state and inter-state edge - symassign_state = self.sdfg.add_state_before(state) - isedge = self.sdfg.edges_between(symassign_state, state)[0] + symassign_state = self.cfg_target.add_state_before(state) + isedge = self.cfg_target.edges_between(symassign_state, state)[0] newsym = self.sdfg.find_new_symbol(f'sym_{local}') desc = self.sdfg.arrays[local] self.sdfg.add_symbol(newsym, desc.dtype) @@ -3879,7 +3880,7 @@ def _parse_sdfg_call(self, funcname: str, func: Union[SDFG, SDFGConvertible], no # Delete the old read descriptor if not isinput: conn_used = False - for s in self.sdfg.nodes(): + for s in self.sdfg.states(): for n in s.data_nodes(): if n.data == aname: conn_used = True @@ -4866,7 +4867,7 @@ def _promote(node: ast.AST) -> Union[Any, str, symbolic.symbol]: # `not sym` returns True. This exception is benign. pass state = self._add_state(f'promote_{scalar}_to_{str(sym)}') - edge = self.sdfg.in_edges(state)[0] + edge = state.parent.in_edges(state)[0] edge.data.assignments = {str(sym): scalar} return sym return scalar diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 7d433b670c..34ff74fc34 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1261,7 +1261,8 @@ def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]] # Add free state symbols used_before_assignment = set() - b_free_syms, b_defined_syms, b_used_before_syms = super().used_symbols(all_symbols) + b_free_syms, b_defined_syms, b_used_before_syms = super().used_symbols(all_symbols, defined_syms, free_syms, + used_before_assignment) free_syms |= b_free_syms defined_syms |= b_defined_syms used_before_assignment |= b_used_before_syms @@ -2030,10 +2031,6 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': ############################ # DaCe Compilation Process # - # Convert any scope blocks to old-school state machines for now. - # TODO: Adapt codegen to deal wiht scope blocks instead. - sdutils.inline_loop_blocks(self) - if self._regenerate_code or not os.path.isdir(build_folder): # Clone SDFG as the other modules may modify its contents sdfg = copy.deepcopy(self) @@ -2041,6 +2038,10 @@ def compile(self, output_file=None, validate=True) -> 'CompiledSDFG': # if the codegen modifies the SDFG (thereby changing its hash) sdfg.build_folder = build_folder + # Convert any scope blocks to old-school state machines for now. + # TODO: Adapt codegen to deal wiht scope blocks instead. + sdutils.inline_loop_blocks(sdfg) + # Rename SDFG to avoid runtime issues with clashing names index = 0 while sdfg.is_loaded(): diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 857d8afb21..033be384a8 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -2368,10 +2368,14 @@ def add_loop( return before_state, loop_scope, after_state @abc.abstractmethod - def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]: - defined_syms = set() - free_syms = set() - used_before_assignment = set() + def used_symbols(self, + all_symbols: bool, + defined_syms: Optional[Set]=None, + free_syms: Optional[Set]=None, + used_before_assignment: Optional[Set]=None) -> Tuple[Set[str], Set[str], Set[str]]: + defined_syms = set() if defined_syms is None else defined_syms + free_syms = set() if free_syms is None else free_syms + used_before_assignment = set() if used_before_assignment is None else used_before_assignment try: ordered_blocks = self.topological_sort(self.start_block) diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index aa999ae480..a4af63b482 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -237,31 +237,30 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # All transients become transients of the parent (if data already # exists, find new name) - for cf in nsdfg.all_state_scopes_recursive(): - for nblock in cf.nodes(): - for node in nblock.nodes(): - if isinstance(node, nodes.AccessNode): - datadesc = nsdfg.arrays[node.data] - if node.data not in transients and datadesc.transient: - new_name = node.data + for state in nsdfg.states(): + for node in state.nodes(): + if isinstance(node, nodes.AccessNode): + datadesc = nsdfg.arrays[node.data] + if node.data not in transients and datadesc.transient: + new_name = node.data + if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants): + new_name = f'{nsdfg.label}_{node.data}' + + name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) + transients[node.data] = name + + # All transients of edges between code nodes are also added to parent + for edge in state.edges(): + if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)): + if edge.data.data is not None: + datadesc = nsdfg.arrays[edge.data.data] + if edge.data.data not in transients and datadesc.transient: + new_name = edge.data.data if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants): - new_name = f'{nsdfg.label}_{node.data}' + new_name = f'{nsdfg.label}_{edge.data.data}' name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) - transients[node.data] = name - - # All transients of edges between code nodes are also added to parent - for edge in nblock.edges(): - if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)): - if edge.data.data is not None: - datadesc = nsdfg.arrays[edge.data.data] - if edge.data.data not in transients and datadesc.transient: - new_name = edge.data.data - if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants): - new_name = f'{nsdfg.label}_{edge.data.data}' - - name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True) - transients[edge.data.data] = name + transients[edge.data.data] = name # All constants (and associated transients) become constants of the parent @@ -413,11 +412,12 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e.data, outer_edge.data) # Replace nested SDFG parents and SDFG pointers. - for cf in nsdfg.all_state_scopes_recursive(): - for nblock in cf.nodes(): - nblock.parent = cf - nblock.sdfg = sdfg - for node in nblock.nodes(): + for n in nsdfg.nodes(): + n.parent = outer_state.parent + for block in nsdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False): + block.sdfg = outer_state.sdfg + if isinstance(block, SDFGState): + for node in block.nodes(): if isinstance(node, nodes.NestedSDFG): node.sdfg.parent_sdfg = sdfg node.sdfg.parent_nsdfg_node = node diff --git a/tests/memlet_propagation_test.py b/tests/memlet_propagation_test.py index f90834cbb7..f1196348da 100644 --- a/tests/memlet_propagation_test.py +++ b/tests/memlet_propagation_test.py @@ -73,7 +73,7 @@ def sparse(A: dace.float32[M, N], ind: dace.int32[M, N]): propagate_memlets_sdfg(sdfg) # Verify all memlet subsets and volumes in the main state of the program, i.e. around the NSDFG. - map_state = sdfg.states()[1] + map_state = list(sdfg.states())[1] i = dace.symbol('i') j = dace.symbol('j')