diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index 182604c892..1b97241e47 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -82,6 +82,9 @@ class ControlFlow: # a string with its generated code. dispatch_state: Callable[[SDFGState], str] + # The parent control flow block of this one, used to avoid generating extraneous ``goto``s + parent: Optional['ControlFlow'] + @property def first_state(self) -> SDFGState: """ @@ -222,11 +225,18 @@ def as_cpp(self, codegen, symbols) -> str: out_edges = sdfg.out_edges(elem.state) for j, e in enumerate(out_edges): if e not in self.gotos_to_ignore: - # If this is the last generated edge and it leads - # to the next state, skip emitting goto + # Skip gotos to immediate successors successor = None - if (j == (len(out_edges) - 1) and (i + 1) < len(self.elements)): - successor = self.elements[i + 1].first_state + # If this is the last generated edge + if j == (len(out_edges) - 1): + if (i + 1) < len(self.elements): + # If last edge leads to next state in block + successor = self.elements[i + 1].first_state + elif i == len(self.elements) - 1: + # If last edge leads to first state in next block + next_block = _find_next_block(self) + if next_block is not None: + successor = next_block.first_state expr += elem.generate_transition(sdfg, e, successor) else: @@ -478,13 +488,14 @@ def children(self) -> List[ControlFlow]: def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[InterstateEdge], leave_edge: Edge[InterstateEdge], back_edges: List[Edge[InterstateEdge]], - dispatch_state: Callable[[SDFGState], str]) -> Union[ForScope, WhileScope]: + dispatch_state: Callable[[SDFGState], + str], parent_block: GeneralBlock) -> Union[ForScope, WhileScope]: """ Helper method that constructs the correct structured loop construct from a set of states. Can construct for or while loops. """ - body = GeneralBlock(dispatch_state, [], [], [], [], [], True) + body = GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True) guard_inedges = sdfg.in_edges(guard) increment_edges = [e for e in guard_inedges if e in back_edges] @@ -535,10 +546,10 @@ def _loop_from_structure(sdfg: SDFG, guard: SDFGState, enter_edge: Edge[Intersta # Also ignore assignments in increment edge (handled in for stmt) body.assignments_to_ignore.append(increment_edge) - return ForScope(dispatch_state, itvar, guard, init, condition, update, body, init_edges) + return ForScope(dispatch_state, parent_block, itvar, guard, init, condition, update, body, init_edges) # Otherwise, it is a while loop - return WhileScope(dispatch_state, guard, condition, body) + return WhileScope(dispatch_state, parent_block, guard, condition, body) def _cases_from_branches( @@ -617,6 +628,31 @@ def _child_of(node: SDFGState, parent: SDFGState, ptree: Dict[SDFGState, SDFGSta return False +def _find_next_block(block: ControlFlow) -> Optional[ControlFlow]: + """ + Returns the immediate successor control flow block. + """ + # Find block in parent + parent = block.parent + if parent is None: + return None + ind = next(i for i, b in enumerate(parent.children) if b is block) + if ind == len(parent.children) - 1 or isinstance(parent, (IfScope, IfElseChain, SwitchCaseScope)): + # If last block, or other children are not reachable from current node (branches), + # recursively continue upwards + return _find_next_block(parent) + return parent.children[ind + 1] + + +def _reset_block_parents(block: ControlFlow): + """ + Fixes block parents after processing. + """ + for child in block.children: + child.parent = block + _reset_block_parents(child) + + def _structured_control_flow_traversal(sdfg: SDFG, start: SDFGState, ptree: Dict[SDFGState, SDFGState], @@ -645,7 +681,7 @@ def _structured_control_flow_traversal(sdfg: SDFG, """ def make_empty_block(): - return GeneralBlock(dispatch_state, [], [], [], [], [], True) + return GeneralBlock(dispatch_state, parent_block, [], [], [], [], [], True) # Traverse states in custom order visited = set() if visited is None else visited @@ -657,7 +693,7 @@ def make_empty_block(): if node in visited or node is stop: continue visited.add(node) - stateblock = SingleState(dispatch_state, node) + stateblock = SingleState(dispatch_state, parent_block, node) oe = sdfg.out_edges(node) if len(oe) == 0: # End state @@ -708,12 +744,14 @@ def make_empty_block(): if (len(oe) == 2 and oe[0].data.condition_sympy() == sp.Not(oe[1].data.condition_sympy())): # If without else if oe[0].dst is mergestate: - branch_block = IfScope(dispatch_state, sdfg, node, oe[1].data.condition, cblocks[oe[1]]) + branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[1].data.condition, + cblocks[oe[1]]) elif oe[1].dst is mergestate: - branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]]) + branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition, + cblocks[oe[0]]) else: - branch_block = IfScope(dispatch_state, sdfg, node, oe[0].data.condition, cblocks[oe[0]], - cblocks[oe[1]]) + branch_block = IfScope(dispatch_state, parent_block, sdfg, node, oe[0].data.condition, + cblocks[oe[0]], cblocks[oe[1]]) else: # If there are 2 or more edges (one is not the negation of the # other): @@ -721,10 +759,10 @@ def make_empty_block(): if switch: # If all edges are of form "x == y" for a single x and # integer y, it is a switch/case - branch_block = SwitchCaseScope(dispatch_state, sdfg, node, switch[0], switch[1]) + branch_block = SwitchCaseScope(dispatch_state, parent_block, sdfg, node, switch[0], switch[1]) else: # Otherwise, create if/else if/.../else goto exit chain - branch_block = IfElseChain(dispatch_state, sdfg, node, + branch_block = IfElseChain(dispatch_state, parent_block, sdfg, node, [(e.data.condition, cblocks[e] if e in cblocks else make_empty_block()) for e in oe]) # End of branch classification @@ -739,11 +777,11 @@ def make_empty_block(): loop_exit = None scope = None if ptree[oe[0].dst] == node and ptree[oe[1].dst] != node: - scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state) + scope = _loop_from_structure(sdfg, node, oe[0], oe[1], back_edges, dispatch_state, parent_block) body_start = oe[0].dst loop_exit = oe[1].dst elif ptree[oe[1].dst] == node and ptree[oe[0].dst] != node: - scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state) + scope = _loop_from_structure(sdfg, node, oe[1], oe[0], back_edges, dispatch_state, parent_block) body_start = oe[1].dst loop_exit = oe[0].dst @@ -836,7 +874,8 @@ def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState if len(common_frontier) == 1: branch_merges[state] = next(iter(common_frontier)) - root_block = GeneralBlock(dispatch_state, [], [], [], [], [], True) + root_block = GeneralBlock(dispatch_state, None, [], [], [], [], [], True) _structured_control_flow_traversal(sdfg, sdfg.start_state, ptree, branch_merges, back_edges, dispatch_state, root_block) + _reset_block_parents(root_block) return root_block diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 9ee5c2ef17..dfdbbb392b 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -471,7 +471,7 @@ def dispatch_state(state: SDFGState) -> str: # If disabled, generate entire graph as general control flow block states_topological = list(sdfg.topological_sort(sdfg.start_state)) last = states_topological[-1] - cft = cflow.GeneralBlock(dispatch_state, + cft = cflow.GeneralBlock(dispatch_state, None, [cflow.SingleState(dispatch_state, s, s is last) for s in states_topological], [], [], [], [], False) diff --git a/tests/codegen/control_flow_detection_test.py b/tests/codegen/control_flow_detection_test.py index 99d6a39b29..982140f7ed 100644 --- a/tests/codegen/control_flow_detection_test.py +++ b/tests/codegen/control_flow_detection_test.py @@ -120,6 +120,33 @@ def test_single_outedge_branch(): assert np.allclose(res, 2) +def test_extraneous_goto(): + + @dace.program + def tester(a: dace.float64[20]): + if a[0] < 0: + a[1] = 1 + a[2] = 1 + + sdfg = tester.to_sdfg(simplify=True) + assert 'goto' not in sdfg.generate_code()[0].code + + +def test_extraneous_goto_nested(): + + @dace.program + def tester(a: dace.float64[20]): + if a[0] < 0: + if a[0] < 1: + a[1] = 1 + else: + a[1] = 2 + a[2] = 1 + + sdfg = tester.to_sdfg(simplify=True) + assert 'goto' not in sdfg.generate_code()[0].code + + if __name__ == '__main__': test_for_loop_detection() test_invalid_for_loop_detection() @@ -128,3 +155,5 @@ def test_single_outedge_branch(): test_edge_sympy_function('TrueFalse') test_edge_sympy_function('SwitchCase') test_single_outedge_branch() + test_extraneous_goto() + test_extraneous_goto_nested()