From 837800c92bedd19b67643f4fd64fb4d79a5e10ea Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 28 Sep 2022 12:34:21 +0200 Subject: [PATCH 01/16] Fixed constant-propagation for cases where the start-state is a scope's guard. --- .../passes/constant_propagation.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index 08f258514b..d1630a8f6f 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -171,15 +171,22 @@ def collect_constants(self, # Traverse SDFG topologically for state in optional_progressbar(sdfg.topological_sort(start_state), 'Collecting constants', sdfg.number_of_nodes(), self.progress): - if state in result: + # NOTE: We must always check the start-state regardless if there are initial symbols. This is necessary + # when the start-state is a scope's guard instead of a special initialization state, i.e., when the start- + # state has incoming edges that may involve the initial symbols. See also: + # `tests.passes.constant_propagation_test.test_for_with_external_init_nested_start_with_guard`` + if state in result and state is not start_state: continue # Get predecessors in_edges = sdfg.in_edges(state) if len(in_edges) == 1: # Special case, propagate as-is - result[state] = {} + if state not in result: # Condition evaluates to False when state is the start-state + result[state] = {} + # First the prior state - self._propagate(result[state], result[in_edges[0].src]) + if in_edges[0].src in result: # Condition evaluates to False when state is the start-state + self._propagate(result[state], result[in_edges[0].src]) # Then assignments on the incoming edge self._propagate(result[state], self._data_independent_assignments(in_edges[0].data, arrays)) @@ -205,7 +212,8 @@ def collect_constants(self, else: assignments[aname] = aval - result[state] = {} + if state not in result: # Condition may evaluate to False when state is the start-state + result[state] = {} self._propagate(result[state], assignments) return result From 3f3072d56bd019d6d9efd14f77eca9d3aca84daf Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 28 Sep 2022 12:34:41 +0200 Subject: [PATCH 02/16] Added test for the previous commit's fix. --- tests/passes/constant_propagation_test.py | 46 ++++++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index fab842897c..fc22dd7f96 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -356,7 +356,7 @@ def test_for_with_external_init_nested(): N = dace.symbol('N') - sdfg = dace.SDFG('for_with_external_init') + sdfg = dace.SDFG('for_with_external_init_nested') sdfg.add_array('A', (N, ), dace.int32) init = sdfg.add_state('init', is_start_state=True) main = sdfg.add_state('main') @@ -393,6 +393,49 @@ def test_for_with_external_init_nested(): assert np.allclose(val1, ref) +def test_for_with_external_init_nested_start_with_guard(): + """ + This test differs from the one above in lacking an initialization SDFGState in the NestedSDFG. Instead, the guard + of the nested for-loop is explicitly set as the start-state of the NestedSDFG. + """ + + N = dace.symbol('N') + + sdfg = dace.SDFG('for_with_external_init_nested_start_with_guard') + sdfg.add_array('A', (N, ), dace.int32) + init = sdfg.add_state('init', is_start_state=True) + main = sdfg.add_state('main') + sdfg.add_edge(init, main, dace.InterstateEdge(assignments={'i': '1'})) + + nsdfg = dace.SDFG('nested_sdfg') + nsdfg.add_array('inner_A', (N,), dace.int32) + nguard = nsdfg.add_state('nested_guard', is_start_state=True) + nbody = nsdfg.add_state('nested_body') + nexit = nsdfg.add_state('nested_exit') + nsdfg.add_edge(nguard, nbody, dace.InterstateEdge(condition='i <= N')) + nsdfg.add_edge(nbody, nguard, dace.InterstateEdge(assignments={'i': 'i+1'})) + nsdfg.add_edge(nguard, nexit, dace.InterstateEdge(condition='i > N')) + + na = nbody.add_access('inner_A') + nt = nbody.add_tasklet('tasklet', {}, {'__out'}, '__out = i-1') + nbody.add_edge(nt, '__out', na, None, dace.Memlet('inner_A[i-1]')) + + a = main.add_access('A') + t = main.add_nested_sdfg(nsdfg, None, {}, {'inner_A'}, {'N': 'N', 'i': 'i'}) + main.add_edge(t, 'inner_A', a, None, dace.Memlet.from_array('A', sdfg.arrays['A'])) + + sdfg.validate() + + ref = np.arange(10, dtype=np.int32) + val0 = np.ndarray((10, ), dtype=np.int32) + sdfg(A=val0, N=10) + assert np.allclose(val0, ref) + ConstantPropagation().apply_pass(sdfg, {}) + val1 = np.ndarray((10, ), dtype=np.int32) + sdfg(A=val1, N=10) + assert np.allclose(val1, ref) + + if __name__ == '__main__': test_simple_constants() test_nested_constants() @@ -408,3 +451,4 @@ def test_for_with_external_init_nested(): test_allocation_varying(True) test_for_with_external_init() test_for_with_external_init_nested() + test_for_with_external_init_nested_start_with_guard() From ee4b856d26d52da1242cee7d239a46343d077c72 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 28 Sep 2022 14:35:44 +0200 Subject: [PATCH 03/16] Added test for MapFission issues with direct copies among AccessNodes. --- tests/transformations/mapfission_test.py | 48 ++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/transformations/mapfission_test.py b/tests/transformations/mapfission_test.py index beeb64a9bb..917077a0e8 100644 --- a/tests/transformations/mapfission_test.py +++ b/tests/transformations/mapfission_test.py @@ -394,6 +394,54 @@ def map_with_if_2(A: dace.int32[10]): val1 = np.ndarray((10, ), dtype=np.int32) sdfg(A=val1) self.assertTrue(np.array_equal(val1, ref)) + + def test_array_copy_outside_scope(self): + + """ + This test checks for two issues occuring when MapFission applies on a NestedSDFG with a state-subgraph + containing copies among AccessNodes. In such cases, these copies may end up outside the scope of the generated + Maps (after MapFssion), potentially leading to the following errors: + 1. The memlet subset corresponding to a NestedSDFG connector (input/output) may have its dimensionality + erroneously increased. + 2. The memlet subset corresponding to a NestedSDFG connector (input/output) may not be propagated even if it uses + the Map's parameters. + """ + + sdfg = dace.SDFG('array_copy_outside_scope') + iname, _ = sdfg.add_array('inp', (10,), dtype=dace.int32) + oname, _ = sdfg.add_array('out', (10,), dtype=dace.int32) + + nsdfg = dace.SDFG('nested_sdfg') + niname, nidesc = nsdfg.add_array('ninp', (1,), dtype=dace.int32) + ntname, ntdesc = nsdfg.add_scalar('ntmp', dtype=dace.int32, transient=True) + noname, nodesc = nsdfg.add_array('nout', (1,), dtype=dace.int32) + + nstate = nsdfg.add_state('nmain') + ninode = nstate.add_access(niname) + ntnode = nstate.add_access(ntname) + nonode = nstate.add_access(noname) + tasklet = nstate.add_tasklet('tasklet', {'__inp'}, {'__out'}, '__out = __inp + 1') + nstate.add_edge(ninode, None, tasklet, '__inp', dace.Memlet.from_array(niname, nidesc)) + nstate.add_edge(tasklet, '__out', ntnode, None, dace.Memlet.from_array(ntname, ntdesc)) + nstate.add_nedge(ntnode, nonode, dace.Memlet.from_array(noname, nodesc)) + + state = sdfg.add_state('main') + inode = state.add_access(iname) + onode = state.add_access(oname) + me, mx = state.add_map('map', {'i': '0:10'}) + snode = state.add_nested_sdfg(nsdfg, None, {'ninp'}, {'nout'}) + state.add_memlet_path(inode, me, snode, memlet=dace.Memlet(data=iname, subset='i'), dst_conn='ninp') + state.add_memlet_path(snode, mx, onode, memlet=dace.Memlet(data=oname, subset='i'), src_conn='nout') + + # Issue no. 1 will be caught by validation after MapFission + sdfg.apply_transformations(MapFission) + + # Issue no. 2 will be caught by code-generation due to `i` existing in a memlet outside the Map's scope. + A = np.arange(10, dtype=np.int32) + B = np.empty((10,), dtype=np.int32) + sdfg(inp=A, out=B) + assert np.array_equal(A, B) + if __name__ == '__main__': From 748ea4410d92971f920ecedaffdef238d653646a Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 28 Sep 2022 14:36:44 +0200 Subject: [PATCH 04/16] MapFission: fixed erroneous increased subset dimensionality for memlets that don't need it. --- dace/transformation/dataflow/map_fission.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/dace/transformation/dataflow/map_fission.py b/dace/transformation/dataflow/map_fission.py index 0e2dd868d5..940a5f5960 100644 --- a/dace/transformation/dataflow/map_fission.py +++ b/dace/transformation/dataflow/map_fission.py @@ -455,9 +455,21 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): for edge in state.all_edges(node): for e in state.memlet_tree(edge): # Prepend map dimensions to memlet - e.data.subset = subsets.Range([(pystr_to_symbolic(d) - r[0], pystr_to_symbolic(d) - r[0], 1) - for d, r in zip(outer_map.params, outer_map.range)] + - e.data.subset.ranges) + # NOTE: Do this only for the subset corresponding to `node.data`. If the edge is copying + # to/from another AccessNode, the other data may not need extra dimensions. For example, see + # `test.transformations.mapfission_test.MapFissionTest.test_array_copy_outside_scope`. + if e.data.data == node.data: + if e.data.subset: + e.data.subset = subsets.Range([(pystr_to_symbolic(d) - r[0], + pystr_to_symbolic(d) - r[0], 1) + for d, r in zip(outer_map.params, outer_map.range)] + + e.data.subset.ranges) + else: + if e.data.other_subset: + e.data.other_subset = subsets.Range( + [(pystr_to_symbolic(d) - r[0], pystr_to_symbolic(d) - r[0], 1) + for d, r in zip(outer_map.params, outer_map.range)] + + e.data.other_subset.ranges) # If nested SDFG, reconnect nodes around map and modify memlets if self.expr_index == 1: From 2c97a90d61275d6345741a06de195dfeb84c26f5 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 28 Sep 2022 14:48:22 +0200 Subject: [PATCH 05/16] MapFission: Fixed issue where memles of edges directly connecting AccessNodes are not propagated properly. --- dace/transformation/dataflow/map_fission.py | 20 +++++++++++++++++++- tests/transformations/mapfission_test.py | 2 +- 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/dace/transformation/dataflow/map_fission.py b/dace/transformation/dataflow/map_fission.py index 940a5f5960..9021e29203 100644 --- a/dace/transformation/dataflow/map_fission.py +++ b/dace/transformation/dataflow/map_fission.py @@ -8,7 +8,7 @@ from dace.sdfg import nodes, graph as gr from dace.sdfg import utils as sdutil from dace.sdfg.graph import OrderedDiGraph -from dace.sdfg.propagation import propagate_memlets_state +from dace.sdfg.propagation import propagate_memlets_state, propagate_subset from dace.symbolic import pystr_to_symbolic from dace.transformation import transformation, helpers from typing import List, Optional, Tuple @@ -413,6 +413,14 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): # Correct connectors and memlets in nested SDFGs to account for # missing outside map if self.expr_index == 1: + + # NOTE: In the following scope dictionary, we mark the new MapEntries as existing in their own scope. + # This makes it easier to detect edges that are outside the new Map scopes (after MapFission). + scope_dict = state.scope_dict() + for k, v in scope_dict.items(): + if isinstance(k, nodes.MapEntry) and k in new_map_entries and v is None: + scope_dict[k] = k + to_correct = ([(e, e.src) for e in external_edges_entry] + [(e, e.dst) for e in external_edges_exit]) corrected_nodes = set() for edge, node in to_correct: @@ -442,6 +450,12 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): for e in state.memlet_tree(internal_edge): e.data.subset.offset(desc.offset, False) e.data.subset = helpers.unsqueeze_memlet(e.data, outer_edge.data).subset + # NOTE: If the edge is outside of the new Map scope, then try to propagate it. This is + # needed for edges directly connecting AccessNodes, because the standard memlet + # propagation will stop at the first AccessNode outside the Map scope. For example, see + # `test.transformations.mapfission_test.MapFissionTest.test_array_copy_outside_scope`. + if not (scope_dict[e.src] and scope_dict[e.dst]): + e.data = propagate_subset([e.data], desc, outer_map.params, outer_map.range) # Only after offsetting memlets we can modify the # overall offset @@ -498,3 +512,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): # Remove outer map graph.remove_nodes_from([map_entry, map_exit]) + + # NOTE: It is better to manually call memlet propagation here to ensure that all subsets are properly updated. + # This can solve issues when, e.g., applying MapFission through `SDFG.apply_transformations_repeated`. + propagate_memlets_state(sdfg, graph) diff --git a/tests/transformations/mapfission_test.py b/tests/transformations/mapfission_test.py index 917077a0e8..a7faa9c882 100644 --- a/tests/transformations/mapfission_test.py +++ b/tests/transformations/mapfission_test.py @@ -440,7 +440,7 @@ def test_array_copy_outside_scope(self): A = np.arange(10, dtype=np.int32) B = np.empty((10,), dtype=np.int32) sdfg(inp=A, out=B) - assert np.array_equal(A, B) + assert np.array_equal(A+1, B) From cc9da27446b484d5c03961e389eaa4167ad3ea64 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 28 Sep 2022 22:14:40 +0200 Subject: [PATCH 06/16] Made test for GPUTransform issue with GPU device-scheduled NestedSDFGs that are not inside a GPU kernel. --- tests/transformations/gpu_transform_test.py | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tests/transformations/gpu_transform_test.py b/tests/transformations/gpu_transform_test.py index 91403c0457..706eeaabc1 100644 --- a/tests/transformations/gpu_transform_test.py +++ b/tests/transformations/gpu_transform_test.py @@ -2,6 +2,8 @@ """ Unit tests for the GPU to-device transformation. """ import dace +import numpy as np +import pytest from dace.transformation.interstate import GPUTransformSDFG @@ -26,5 +28,37 @@ def program(A: dace.float64[20, 20]): assert desc.lifetime is not dace.AllocationLifetime.SDFG +@pytest.mark.gpu +def test_scalar_to_symbol_in_nested_sdfg(): + """ + GPUTransformSDFG will automatically create copy-out states for GPU scalars that are used in host-side interstate + edges. However, this process may only be applied in top-level SDFGs and not in NestedSDFGs that have GPU-device + schedule but are not part of a single GPU kernel, leading to illegal memory accesses. + """ + + @dace.program + def nested_program(a: dace.int32, out: dace.int32[10]): + for i in range(10): + if a < 5: + out[i] = 0 + a *= 2 + else: + out[i] = 10 + a /= 2 + + @dace.program + def main_program(a: dace.int32): + out = np.ndarray((10,), dtype=np.int32) + nested_program(a, out) + return out + + sdfg = main_program.to_sdfg(simplify=False) + sdfg.apply_transformations(GPUTransformSDFG) + out = np.empty((10,), dtype=np.int32) + sdfg(a=4, out=out) + assert np.array_equal(out, np.array([0, 10] * 5, dtype=np.int32)) + + if __name__ == '__main__': test_toplevel_transient_lifetime() + test_scalar_to_symbol_in_nested_sdfg() From b6387ca4e5fb366bedd4d794112fcdee1bee590c Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Wed, 28 Sep 2022 22:15:37 +0200 Subject: [PATCH 07/16] Refactored GPUTransformSDFG to call itself recursively for (some) NestedSDFGs. --- .../interstate/gpu_transform_sdfg.py | 99 ++++++++++--------- 1 file changed, 54 insertions(+), 45 deletions(-) diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 3a47c74c9a..a6389da8d9 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -32,7 +32,7 @@ def _recursive_out_check(node, state, gpu_scalars): scalset = scalset.union(sset) scalout = scalout and ssout continue - if desc.shape == (1,): # Pseudo-scalar + if desc.shape == (1, ): # Pseudo-scalar scalout = False sset, ssout = _recursive_out_check(last_edge.dst, state, gpu_scalars) scalset = scalset.union(sset) @@ -66,7 +66,7 @@ def _recursive_in_check(node, state, gpu_scalars): scalset = scalset.union(sset) scalout = scalout and ssout continue - if desc.shape == (1,): # Pseudo-scalar + if desc.shape == (1, ): # Pseudo-scalar scalout = False sset, ssout = _recursive_in_check(last_edge.src, state, gpu_scalars) scalset = scalset.union(sset) @@ -81,10 +81,6 @@ def _recursive_in_check(node, state, gpu_scalars): return scalset, scalout -def _codenode_condition(node): - return isinstance(node, (nodes.LibraryNode, nodes.NestedSDFG)) and node.schedule == dtypes.ScheduleType.GPU_Default - - @make_properties class GPUTransformSDFG(transformation.MultiStateTransformation): """ Implements the GPUTransformSDFG transformation. @@ -305,33 +301,63 @@ def apply(self, _, sdfg: sd.SDFG): ####################################################### # Step 5: Collect free tasklets and check for scalars that have to be moved to the GPU + # Also recursively call GPUTransformSDFG on NestedSDFGs that have GPU device schedule but are not actually + # inside a GPU kernel. gpu_scalars = {} + nsdfgs = [] changed = True # Iterates over Tasklets that not inside a GPU kernel. Such Tasklets must be moved inside a GPU kernel only # if they write to GPU memory. The check takes into account the fact that GPU kernels can read host-based # Scalars, but cannot write to them. while changed: changed = False - for node, state in sdfg.all_nodes_recursive(): - if isinstance(node, nodes.Tasklet): - if node in global_code_nodes[state]: - continue - if state.entry_node(node) is None and not scope.is_devicelevel_gpu_kernel( - state.parent, state, node): - scalars, scalar_output = _recursive_out_check(node, state, gpu_scalars) - sset, ssout = _recursive_in_check(node, state, gpu_scalars) - scalars = scalars.union(sset) - scalar_output = scalar_output and ssout - csdfg = state.parent - # If the tasklet is not adjacent only to scalars or it is in a GPU scope. - # The latter includes NestedSDFGs that have a GPU-Device schedule but are not in a GPU kernel. - if (not scalar_output - or (csdfg.parent is not None - and csdfg.parent_nsdfg_node.schedule == dtypes.ScheduleType.GPU_Default)): - global_code_nodes[state].append(node) - gpu_scalars.update({k: None for k in scalars}) - changed = True + for state in sdfg.states(): + for node in state.nodes(): + # Handle NestedSDFGs later. + if isinstance(node, nodes.NestedSDFG): + if state.entry_node(node) is None and not scope.is_devicelevel_gpu_kernel( + state.parent, state, node): + nsdfgs.append((node, state)) + elif isinstance(node, nodes.Tasklet): + if node in global_code_nodes[state]: + continue + if state.entry_node(node) is None and not scope.is_devicelevel_gpu_kernel( + state.parent, state, node): + scalars, scalar_output = _recursive_out_check(node, state, gpu_scalars) + sset, ssout = _recursive_in_check(node, state, gpu_scalars) + scalars = scalars.union(sset) + scalar_output = scalar_output and ssout + csdfg = state.parent + # If the tasklet is not adjacent only to scalars or it is in a GPU scope. + # The latter includes NestedSDFGs that have a GPU-Device schedule but are not in a GPU kernel. + if (not scalar_output + or (csdfg.parent is not None + and csdfg.parent_nsdfg_node.schedule == dtypes.ScheduleType.GPU_Default)): + global_code_nodes[state].append(node) + gpu_scalars.update({k: None for k in scalars}) + changed = True + + # Apply GPUTransformSDFG recursively to NestedSDFGs. + for node, state in nsdfgs: + excl_copyin = set() + for e in state.in_edges(node): + src = state.memlet_path(e)[0].src + if isinstance(src, nodes.AccessNode) and sdfg.arrays[src.data].storage in gpu_storage: + excl_copyin.add(e.dst_conn) + node.sdfg.arrays[e.dst_conn].storage = sdfg.arrays[src.data].storage + excl_copyout = set() + for e in state.out_edges(node): + dst = state.memlet_path(e)[-1].dst + if isinstance(dst, nodes.AccessNode) and sdfg.arrays[dst.data].storage in gpu_storage: + excl_copyout.add(e.src_conn) + node.sdfg.arrays[e.src_conn].storage = sdfg.arrays[dst.data].storage + # TODO: Do we want to copy here the options from the top-level SDFG? + node.sdfg.apply_transformations( + GPUTransformSDFG, { + 'exclude_copyin': ','.join([str(n) for n in excl_copyin]), + 'exclude_copyout': ','.join([str(n) for n in excl_copyout]) + }) ####################################################### # Step 6: Modify transient data storage @@ -350,26 +376,9 @@ def apply(self, _, sdfg: sd.SDFG): if sdict[node] is None and nodedesc.storage not in gpu_storage: - # Ensure that scalars not already GPU-marked are actually used in a GPU scope. + # Scalars were already checked. if isinstance(nodedesc, data.Scalar) and not node.data in gpu_scalars: - used_in_gpu_scope = False - for e in state.in_edges(node): - if _codenode_condition(state.memlet_path(e)[0].src): - used_in_gpu_scope = True - break - if not used_in_gpu_scope: - for e in state.out_edges(node): - if _codenode_condition(state.memlet_path(e)[-1].dst): - used_in_gpu_scope = True - break - if not used_in_gpu_scope: - continue - for e in state.all_edges(node): - for node in (e.src, e.dst): - if isinstance(node, nodes.Tasklet): - if (state.entry_node(node) is None and not scope.is_devicelevel_gpu( - state.parent, state, node, with_gpu_default=True)): - global_code_nodes[state].append(node) + continue # NOTE: the cloned arrays match too but it's the same storage so we don't care nodedesc.storage = dtypes.StorageType.GPU_Global @@ -470,5 +479,5 @@ def apply(self, _, sdfg: sd.SDFG): # Step 9: Simplify if not self.simplify: return - + sdfg.simplify() From 0b597741bfaa2f8835a348a2e83b5f4c6a49c469 Mon Sep 17 00:00:00 2001 From: Alexandros Nikolaos Ziogas Date: Thu, 29 Sep 2022 11:02:06 +0200 Subject: [PATCH 08/16] Fixed test. --- tests/transformations/gpu_transform_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/transformations/gpu_transform_test.py b/tests/transformations/gpu_transform_test.py index 706eeaabc1..d6814273a6 100644 --- a/tests/transformations/gpu_transform_test.py +++ b/tests/transformations/gpu_transform_test.py @@ -54,8 +54,7 @@ def main_program(a: dace.int32): sdfg = main_program.to_sdfg(simplify=False) sdfg.apply_transformations(GPUTransformSDFG) - out = np.empty((10,), dtype=np.int32) - sdfg(a=4, out=out) + out = sdfg(a=4) assert np.array_equal(out, np.array([0, 10] * 5, dtype=np.int32)) From bd92660aea4e4b19d4a6229e91865f5e2743e8e1 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 10 Oct 2022 11:08:56 +0200 Subject: [PATCH 09/16] Add symbol-based loop to map tests --- tests/transformations/loop_to_map_test.py | 93 +++++++++++++++++++++-- 1 file changed, 85 insertions(+), 8 deletions(-) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 2a2b51beab..08c785d282 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -1,9 +1,12 @@ # Copyright 2020-2020 ETH Zurich and the DaCe authors. All rights reserved. import argparse -import dace -import numpy as np import os import tempfile + +import numpy as np +import pytest + +import dace from dace.sdfg import nodes from dace.transformation.interstate import LoopToMap @@ -291,8 +294,8 @@ def test_interstate_dep(): def test_need_for_tasklet(): sdfg = dace.SDFG('needs_tasklet') - aname, _ = sdfg.add_array('A', (10,), dace.int32) - bname, _ = sdfg.add_array('B', (10,), dace.int32) + aname, _ = sdfg.add_array('A', (10, ), dace.int32) + bname, _ = sdfg.add_array('B', (10, ), dace.int32) body = sdfg.add_state('body') _, _, _ = sdfg.add_loop(None, body, None, 'i', '0', 'i < 10', 'i + 1', None) anode = body.add_access(aname) @@ -305,11 +308,11 @@ def test_need_for_tasklet(): if isinstance(n, nodes.Tasklet): found = True break - + assert found A = np.arange(10, dtype=np.int32) - B = np.empty((10,), dtype=np.int32) + B = np.empty((10, ), dtype=np.int32) sdfg(A=A, B=B) assert np.array_equal(B, np.arange(9, -1, -1, dtype=np.int32)) @@ -332,7 +335,7 @@ def test_need_for_transient(): if isinstance(n, nodes.AccessNode) and n.data not in (aname, bname): found = True break - + assert found A = np.arange(100, dtype=np.int32).reshape(10, 10).copy() @@ -341,8 +344,77 @@ def test_need_for_transient(): for i in range(10): start = i * 10 - assert np.array_equal(B[i], np.arange(start + 9, start -1, -1, dtype=np.int32)) + assert np.array_equal(B[i], np.arange(start + 9, start - 1, -1, dtype=np.int32)) + + +def test_symbol_race(): + + # Adapted from npbench's crc16 test + # https://github.com/spcl/npbench/blob/main/npbench/benchmarks/crc16/crc16_dace.py + poly: dace.uint16 = 0x8408 + + @dace.program + def tester(data: dace.int32[20]): + crc: dace.uint16 = 0xFFFF + for i in range(20): + b = data[i] + cur_byte = 0xFF & b + for _ in range(0, 8): + if (crc & 0x0001) ^ (cur_byte & 0x0001): + crc = (crc >> 1) ^ poly + else: + crc >>= 1 + cur_byte >>= 1 + crc = (~crc & 0xFFFF) + crc = (crc << 8) | ((crc >> 8) & 0xFF) + + sdfg = tester.to_sdfg(simplify=True) + assert sdfg.apply_transformations(LoopToMap) == 0 + + +def test_symbol_write_before_read(): + sdfg = dace.SDFG('tester') + init = sdfg.add_state(is_start_state=True) + body_start = sdfg.add_state() + body = sdfg.add_state() + body_end = sdfg.add_state() + sdfg.add_loop(init, body_start, None, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) + + # Internal loop structure + sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(j='0'))) + sdfg.add_edge(body, body_end, dace.InterstateEdge(assignments=dict(j='j + 1'))) + + assert sdfg.apply_transformations(LoopToMap) == 1 + + +def test_symbol_array_mix(): + pass + + +@pytest.mark.parametrize('overwrite', (False, True)) +def test_internal_symbol_used_outside(overwrite): + sdfg = dace.SDFG('tester') + init = sdfg.add_state(is_start_state=True) + body_start = sdfg.add_state() + body = sdfg.add_state() + body_end = sdfg.add_state() + after = sdfg.add_state() + sdfg.add_loop(init, body_start, after, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) + + # Internal loop structure + sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(j='0'))) + sdfg.add_edge(body, body_end, dace.InterstateEdge(assignments=dict(j='j + 1'))) + + # Use after + after_1 = sdfg.add_state() + after_1.add_tasklet('use', {}, {}, 'printf("%d\\n", j)') + + if overwrite: + sdfg.add_edge(after, after_1, dace.InterstateEdge(assignments=dict(j='5'))) + else: + sdfg.add_edge(after, after_1, dace.InterstateEdge()) + assert sdfg.apply_transformations(LoopToMap) == (1 if overwrite else 0) if __name__ == "__main__": @@ -365,3 +437,8 @@ def test_need_for_transient(): test_interstate_dep() test_need_for_tasklet() test_need_for_transient() + test_symbol_race() + test_symbol_write_before_read() + test_symbol_array_mix() + test_internal_symbol_used_outside(False) + test_internal_symbol_used_outside(True) From 3afe91b085fd7fec98a1a6e62bd0b5008d36d24c Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 10 Oct 2022 13:42:17 +0200 Subject: [PATCH 10/16] One more test --- tests/transformations/loop_to_map_test.py | 36 ++++++++++++++++++----- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 08c785d282..688a8a9080 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -13,8 +13,7 @@ def make_sdfg(with_wcr, map_in_guard, reverse_loop, use_variable, assign_after, log_path): - sdfg = dace.SDFG(f"loop_to_map_test_{with_wcr}_{map_in_guard}_" - f"{reverse_loop}_{use_variable}_{assign_after}") + sdfg = dace.SDFG(f"loop_to_map_test_{with_wcr}_{map_in_guard}_{reverse_loop}_{use_variable}_{assign_after}") sdfg.set_global_code("#include \n#include ") init = sdfg.add_state("init") @@ -165,7 +164,6 @@ def test_loop_to_map_variable_reassigned(n=None): def test_output_copy(): - @dace.program def l2mtest_copy(A: dace.float64[20, 20]): for i in range(1, 20): @@ -185,7 +183,6 @@ def l2mtest_copy(A: dace.float64[20, 20]): def test_output_accumulate(): - @dace.program def l2mtest_accumulate(A: dace.float64[20, 20]): for i in range(1, 20): @@ -243,7 +240,6 @@ def detect_greater(i: _[0:size]): def test_empty_loop(): - @dace.program def empty_loop(): for i in range(10): @@ -387,8 +383,31 @@ def test_symbol_write_before_read(): assert sdfg.apply_transformations(LoopToMap) == 1 -def test_symbol_array_mix(): - pass +@pytest.mark.parametrize('overwrite', (False, True)) +def test_symbol_array_mix(overwrite): + sdfg = dace.SDFG('tester') + sdfg.add_transient('tmp', [1], dace.float64) + sdfg.add_symbol('sym', dace.float64) + init = sdfg.add_state(is_start_state=True) + body_start = sdfg.add_state() + body = sdfg.add_state() + body_end = sdfg.add_state() + after = sdfg.add_state() + sdfg.add_loop(init, body_start, after, 'i', '0', 'i < 20', 'i + 1', loop_end_state=body_end) + + sdfg.out_edges(init)[0].data.assignments['sym'] = '0.0' + + # Internal loop structure + t = body_start.add_tasklet('def', {}, {'o'}, 'o = i') + body_start.add_edge(t, 'o', body_start.add_write('tmp'), None, dace.Memlet('tmp')) + + if overwrite: + sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(sym='tmp'))) + else: + sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(sym='sym + tmp'))) + sdfg.add_edge(body, body_end, dace.InterstateEdge(assignments=dict(sym='sym + 1.0'))) + + assert sdfg.apply_transformations(LoopToMap) == (1 if overwrite else 0) @pytest.mark.parametrize('overwrite', (False, True)) @@ -439,6 +458,7 @@ def test_internal_symbol_used_outside(overwrite): test_need_for_transient() test_symbol_race() test_symbol_write_before_read() - test_symbol_array_mix() + test_symbol_array_mix(False) + test_symbol_array_mix(True) test_internal_symbol_used_outside(False) test_internal_symbol_used_outside(True) From 63c6356b45b0e4929a657ef2b17c86ee1fba17b7 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 10 Oct 2022 13:58:48 +0200 Subject: [PATCH 11/16] More info in validation exceptions --- dace/frontend/python/newast.py | 1 + dace/sdfg/validation.py | 36 ++++++++++++++++++++++------------ 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index e67bb2b412..cb6a620adb 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2427,6 +2427,7 @@ def visit_Continue(self, node: ast.Continue): def visit_If(self, node: ast.If): # Add a guard state self._add_state('if_guard') + self.last_state.debuginfo = self.current_lineinfo # Generate conditions cond, cond_else = self._visit_test(node.test) diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 2ceade4b5c..8bf8f870eb 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -141,7 +141,9 @@ def validate_sdfg(sdfg: 'dace.sdfg.SDFG', references: Set[int] = None): undef_syms = set(edge.data.free_symbols) - set(symbols.keys()) if len(undef_syms) > 0: eid = sdfg.edge_id(edge) - raise InvalidSDFGInterstateEdgeError("Undefined symbols in edge: %s" % undef_syms, sdfg, eid) + raise InvalidSDFGInterstateEdgeError( + f'Undefined symbols in edge: {undef_syms}. Add those with ' + '`sdfg.add_symbol()` or define outside with `dace.symbol()`', sdfg, eid) # Validate inter-state edge names issyms = edge.data.new_symbols(sdfg, symbols) @@ -231,8 +233,7 @@ def validate_state(state: 'dace.sdfg.SDFGState', raise InvalidSDFGError("Invalid state name", sdfg, state_id) if state._parent != sdfg: - raise InvalidSDFGError("State does not point to the correct " - "parent", sdfg, state_id) + raise InvalidSDFGError("State does not point to the correct " "parent", sdfg, state_id) # Unreachable ######################################## @@ -618,7 +619,6 @@ def validate_state(state: 'dace.sdfg.SDFGState', class InvalidSDFGError(Exception): """ A class of exceptions thrown when SDFG validation fails. """ - def __init__(self, message: str, sdfg: 'SDFG', state_id: int): self.message = message self.sdfg = sdfg @@ -641,8 +641,7 @@ def _getlineinfo(self, obj) -> str: if lineinfo.start_line >= 0: if lineinfo.start_column > 0: - return (f'File "{lineinfo.filename}", line {lineinfo.start_line}, ' - f'column {lineinfo.start_column}') + return (f'File "{lineinfo.filename}", line {lineinfo.start_line}, ' f'column {lineinfo.start_column}') return f'File "{lineinfo.filename}", line {lineinfo.start_line}' return f'File "{lineinfo.filename}"' @@ -670,7 +669,6 @@ def __str__(self): class InvalidSDFGInterstateEdgeError(InvalidSDFGError): """ Exceptions of invalid inter-state edges in an SDFG. """ - def __init__(self, message: str, sdfg: 'SDFG', edge_id: int): self.message = message self.sdfg = sdfg @@ -687,15 +685,31 @@ def __str__(self): str(e.src), str(e.dst), ) + locinfo_src = self._getlineinfo(e.src) + locinfo_dst = self._getlineinfo(e.dst) + else: + edgestr = '' + locinfo_src = locinfo_dst = '' + + if locinfo_src or locinfo_dst: + if locinfo_src == locinfo_dst: + locinfo = f'at {locinfo_src}' + elif locinfo_src and not locinfo_dst: + locinfo = f'at {locinfo_src}' + elif locinfo_dst and not locinfo_src: + locinfo = f'at {locinfo_src}' + else: + locinfo = f'between\n {locinfo_src}\n and\n {locinfo_dst}' + + locinfo = f'\nOriginating from source code {locinfo}' else: - edgestr = "" + locinfo = '' - return "%s%s" % (self.message, edgestr) + return f'{self.message}{edgestr}{locinfo}' class InvalidSDFGNodeError(InvalidSDFGError): """ Exceptions of invalid nodes in an SDFG state. """ - def __init__(self, message: str, sdfg: 'SDFG', state_id: int, node_id: int): self.message = message self.sdfg = sdfg @@ -729,14 +743,12 @@ class NodeNotExpandedError(InvalidSDFGNodeError): Exception that is raised whenever a library node was not expanded before code generation. """ - def __init__(self, sdfg: 'SDFG', state_id: int, node_id: int): super().__init__('Library node not expanded', sdfg, state_id, node_id) class InvalidSDFGEdgeError(InvalidSDFGError): """ Exceptions of invalid edges in an SDFG state. """ - def __init__(self, message: str, sdfg: 'SDFG', state_id: int, edge_id: int): self.message = message self.sdfg = sdfg From c25a0cf1f7bd9b0e76fc685d2bd64142774ecc41 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 10 Oct 2022 13:59:12 +0200 Subject: [PATCH 12/16] Add test --- tests/transformations/loop_to_map_test.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 688a8a9080..c57882c77b 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -342,6 +342,22 @@ def test_need_for_transient(): start = i * 10 assert np.array_equal(B[i], np.arange(start + 9, start - 1, -1, dtype=np.int32)) +def test_iteration_variable_used_outside(): + N = dace.symbol("N", dace.int32) + + @dace.program + def tester(A: dace.float64[N], output: dace.float64[1]): + i = -1 + + for i in range(N): + A[i] += 1 + + if i > 10: + output[0] = 1.0 + + sdfg = tester.to_sdfg(simplify=True) + assert sdfg.apply_transformations(LoopToMap) == 0 + def test_symbol_race(): @@ -456,6 +472,7 @@ def test_internal_symbol_used_outside(overwrite): test_interstate_dep() test_need_for_tasklet() test_need_for_transient() + test_iteration_variable_used_outside() test_symbol_race() test_symbol_write_before_read() test_symbol_array_mix(False) From 584b4924fd22c3b971f0551fa1e60472a115b57f Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 10 Oct 2022 22:40:49 +0200 Subject: [PATCH 13/16] Add tests --- tests/transformations/loop_to_map_test.py | 29 ++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index c57882c77b..b2940b259d 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -351,7 +351,7 @@ def tester(A: dace.float64[N], output: dace.float64[1]): for i in range(N): A[i] += 1 - + if i > 10: output[0] = 1.0 @@ -425,6 +425,31 @@ def test_symbol_array_mix(overwrite): assert sdfg.apply_transformations(LoopToMap) == (1 if overwrite else 0) +@pytest.mark.parametrize('parallel', (False, True)) +def test_symbol_array_mix_2(parallel): + sdfg = dace.SDFG('tester') + sdfg.add_array('A', [20], dace.float64) + sdfg.add_array('B', [20], dace.float64) + sdfg.add_symbol('sym', dace.float64) + init = sdfg.add_state(is_start_state=True) + body_start = sdfg.add_state() + body_end = sdfg.add_state() + after = sdfg.add_state() + sdfg.add_loop(init, body_start, after, 'i', '1', 'i < 20', 'i + 1', loop_end_state=body_end) + + sdfg.out_edges(init)[0].data.assignments['sym'] = '0.0' + + # Internal loop structure + if not parallel: + t = body_start.add_tasklet('def', {}, {'o'}, 'o = i') + body_start.add_edge(t, 'o', body_start.add_write('A'), None, dace.Memlet('A[i]')) + + sdfg.add_edge(body_start, body_end, dace.InterstateEdge(assignments=dict(sym='A[i - 1]'))) + t = body_start.add_tasklet('use', {}, {'o'}, 'o = sym') + body_start.add_edge(t, 'o', body_start.add_write('B'), None, dace.Memlet('B[i]')) + + assert sdfg.apply_transformations(LoopToMap) == (1 if parallel else 0) + @pytest.mark.parametrize('overwrite', (False, True)) def test_internal_symbol_used_outside(overwrite): @@ -477,5 +502,7 @@ def test_internal_symbol_used_outside(overwrite): test_symbol_write_before_read() test_symbol_array_mix(False) test_symbol_array_mix(True) + test_symbol_array_mix_2(False) + test_symbol_array_mix_2(True) test_internal_symbol_used_outside(False) test_internal_symbol_used_outside(True) From 12732c5a3033e76c369cd9e9e25be48d0daba797 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 10 Oct 2022 23:13:13 +0200 Subject: [PATCH 14/16] Consider symbols in LoopToMap --- dace/sdfg/sdfg.py | 15 +++- dace/transformation/interstate/loop_to_map.py | 87 ++++++++++++------- 2 files changed, 67 insertions(+), 35 deletions(-) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 82ef220cbe..471296cd42 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -196,15 +196,22 @@ def condition_sympy(self): self._cond_sympy = symbolic.pystr_to_symbolic(self.condition.as_string) return self._cond_sympy - @property - def free_symbols(self) -> Set[str]: - """ Returns a set of symbols used in this edge's properties. """ + def read_symbols(self) -> Set[str]: + """ + Returns a set of symbols read in this edge (including symbols in the condition and assignment values). + """ # Symbols in conditions and assignments result = set(map(str, dace.symbolic.symbols_in_ast(self.condition.code[0]))) for assign in self.assignments.values(): result |= symbolic.free_symbols_and_functions(assign) - return result - set(self.assignments.keys()) + return result + + @property + def free_symbols(self) -> Set[str]: + """ Returns a set of symbols used in this edge's properties. """ + return self.read_symbols() - set(self.assignments.keys()) + def replace_dict(self, repl: Dict[str, str], replace_keys=True) -> None: """ diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 8f1bfe5d39..ac9dfeb920 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -114,25 +114,41 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi if symbolic.contains_sympy_functions(expr): return False + in_order_states = list(cfg.stateorder_topological_sort(sdfg)) + loop_begin_idx = in_order_states.index(begin) + loop_end_idx = in_order_states.index(body_end) + + if loop_end_idx < loop_begin_idx: # Malformed loop + return False + # Find all loop-body states - states = set() - to_visit = [begin] - while to_visit: - state = to_visit.pop(0) - for _, dst, _ in sdfg.out_edges(state): - if dst not in states and dst is not guard: - to_visit.append(dst) - states.add(state) + states: List[SDFGState] = in_order_states[loop_begin_idx:loop_end_idx + 1] assert (body_end in states) - write_set = set() + write_set: Set[str] = set() for state in states: _, wset = state.read_and_write_sets() write_set |= wset + # Collect symbol reads and writes from inter-state assignments + symbols_that_may_be_used: Set[str] = {itervar} + used_before_assignment: Set[str] = set() + for state in states: + for e in sdfg.out_edges(state): + # Collect read-before-assigned symbols (this works because the states are always in order, + # see above call to `stateorder_topological_sort`) + read_symbols = e.data.read_symbols() + read_symbols -= symbols_that_may_be_used + used_before_assignment |= read_symbols + # If symbol was read before it is assigned, the loop cannot be parallel + if e.data.assignments.keys() & used_before_assignment: + return False + + symbols_that_may_be_used |= e.data.assignments.keys() + # Get access nodes from other states to isolate local loop variables - other_access_nodes = set() + other_access_nodes: Set[str] = set() for state in sdfg.nodes(): if state in states: continue @@ -141,7 +157,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi for state in states: other_access_nodes |= set(n.data for n in state.data_nodes() if not sdfg.arrays[n.data].transient) - write_memlets = defaultdict(list) + write_memlets: Dict[str, List[memlet.Memlet]] = defaultdict(list) itersym = symbolic.pystr_to_symbolic(itervar) a = sp.Wild('a', exclude=[itersym]) @@ -185,7 +201,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi return False # Consider reads in inter-state edges (could be in assignments or in condition) - isread_set = set() + isread_set: Set[memlet.Memlet] = set() for s in states: for e in sdfg.all_edges(s): isread_set |= set(e.data.get_read_memlets(sdfg.arrays)) @@ -195,26 +211,34 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi mmlt.subset): return False - # Check that the iteration variable is not used on other edges or states - # before it is reassigned - prior_states = True - for state in cfg.stateorder_topological_sort(sdfg): - # Skip all states up to guard - if prior_states: - if state is begin: - prior_states = False - continue - # We do not need to check the loop-body states - if state in states: - continue - if itervar in state.free_symbols: - return False - # Don't continue in this direction, as the variable has - # now been reassigned - # TODO: Handle case of subset of out_edges - if all(itervar in e.data.assignments for e in sdfg.out_edges(state)): + # Check that the iteration variable and other symbols are not used on other edges or states + # before they are reassigned + for state in in_order_states[loop_end_idx + 1:]: + # Don't continue in this direction, as all loop symbols have been reassigned + if not symbols_that_may_be_used: break + # Check state contents + if symbols_that_may_be_used & state.free_symbols: + return False + + # Check inter-state edges + reassigned_symbols: Set[str] = None + for e in sdfg.out_edges(state): + if symbols_that_may_be_used & e.data.read_symbols(): + return False + + # Check for symbols that are set by all outgoing edges + # TODO: Handle case of subset of out_edges + if reassigned_symbols is None: + reassigned_symbols = set(e.data.assignments.keys()) + else: + reassigned_symbols &= e.data.assignments.keys() + + # Remove reassigned symbols + if reassigned_symbols is not None: + symbols_that_may_be_used -= reassigned_symbols + return True def test_read_memlet(self, sdfg: SDFG, itersym: symbolic.SymbolicType, itervar: str, start: symbolic.SymbolicType, @@ -390,7 +414,8 @@ def apply(self, _, sdfg: sd.SDFG): # Fix SDFG symbols for sym in sdfg.free_symbols - fsymbols: - del sdfg.symbols[sym] + if sym in sdfg.symbols: + del sdfg.symbols[sym] for sym, dtype in nsymbols.items(): nsdfg.symbols[sym] = dtype From 2f6f06d85cf21b8226c5d57e5b8dc3182b94fd60 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 11 Oct 2022 02:55:19 +0200 Subject: [PATCH 15/16] Fix state-order traversal w.r.t. if branches without else --- dace/sdfg/analysis/cfg.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index 20b61446df..2c782e9800 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -219,12 +219,10 @@ def _stateorder_topological_sort(sdfg: SDFG, """ # Traverse states in custom order visited = visited or set() - if stop is not None: - visited.add(stop) stack = [start] while stack: node = stack.pop() - if node in visited: + if node in visited or node is stop: continue yield node @@ -265,6 +263,9 @@ def _stateorder_topological_sort(sdfg: SDFG, mergestate = stop for branch in oe: + if branch.dst is mergestate: + # If we hit the merge state (if without else), defer to end of branch traversal + continue for s in _stateorder_topological_sort(sdfg, branch.dst, ptree, @@ -273,8 +274,7 @@ def _stateorder_topological_sort(sdfg: SDFG, visited=visited): yield s visited.add(s) - if mergestate != stop: - stack.append(mergestate) + stack.append(mergestate) def stateorder_topological_sort(sdfg: SDFG) -> Iterator[SDFGState]: From 42cd81fd441f3b14e50431aadfc44038b8cb3464 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 11 Oct 2022 10:52:49 +0200 Subject: [PATCH 16/16] Minor fix --- dace/sdfg/analysis/cfg.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index 2c782e9800..f2926e7bc3 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -225,6 +225,7 @@ def _stateorder_topological_sort(sdfg: SDFG, if node in visited or node is stop: continue yield node + visited.add(node) oe = sdfg.out_edges(node) if len(oe) == 0: # End state