From 56ef20cbd1052c833ded7064286f280f6e4e17ae Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 15 Nov 2024 08:31:48 -0800 Subject: [PATCH] Make scalar to symbol promotion robust to node order in state --- dace/sdfg/analysis/schedule_tree/treenodes.py | 2 ++ .../transformation/passes/scalar_to_symbol.py | 4 ++- tests/passes/scalar_to_symbol_test.py | 30 +++++++++++++++++++ 3 files changed, 35 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/analysis/schedule_tree/treenodes.py b/dace/sdfg/analysis/schedule_tree/treenodes.py index 3b447fa15a..dabd436b56 100644 --- a/dace/sdfg/analysis/schedule_tree/treenodes.py +++ b/dace/sdfg/analysis/schedule_tree/treenodes.py @@ -41,6 +41,8 @@ def __init__(self, children: Optional[List['ScheduleTreeNode']] = None): if self.children: for child in children: child.parent = self + self.containers = {} + self.symbols = {} def as_string(self, indent: int = 0): if not self.children: diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 33712c8a1c..a37729ca7c 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -522,6 +522,8 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in array_names] for node in scalar_nodes: + if node not in state: + continue symname = array_names[node.data] for out_edge in state.out_edges(node): for e in state.memlet_tree(out_edge): @@ -649,7 +651,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] # Step 2: Assignment tasklets for node in scalar_nodes: - if state.in_degree(node) == 0: + if node not in state or state.in_degree(node) == 0: continue in_edge = state.in_edges(node)[0] input = in_edge.src diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 7fdfbdf737..36decceba2 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -729,6 +729,35 @@ def test_double_index_bug(): assert getattr(sympy_node, "name", None) != "indices" +def test_reversed_order(): + """ + Tests a failure reported in issue #1727. + """ + sdfg = dace.SDFG('tester') + sdfg.add_array('inputs', [1], dace.int32) + sdfg.add_transient('a', [1], dace.int32) + sdfg.add_transient('b', [1], dace.int32) + sdfg.add_array('output', [1], dace.int32) + initstate = sdfg.add_state() + state = sdfg.add_state_after(initstate) + finistate = sdfg.add_state_after(state) + + # Note the order here + w = state.add_write('b') + t = state.add_tasklet('assign', {'inp'}, {'out'}, 'out = inp') + r = state.add_read('a') + state.add_edge(t, 'out', w, None, dace.Memlet('b')) + state.add_edge(r, None, t, 'inp', dace.Memlet('a')) + + initstate.add_nedge(initstate.add_read('inputs'), initstate.add_write('a'), dace.Memlet('inputs')) + finistate.add_nedge(finistate.add_read('b'), finistate.add_write('output'), dace.Memlet('output')) + + sdfg.validate() + promoted = scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) + assert promoted == {'a', 'b'} + sdfg.compile() + + if __name__ == '__main__': test_find_promotable() test_promote_simple() @@ -753,3 +782,4 @@ def test_double_index_bug(): test_ternary_expression(False) test_ternary_expression(True) test_double_index_bug() + test_reversed_order()