From 25ba2de8c466cea295f75ad0cc53f2abdf13afb6 Mon Sep 17 00:00:00 2001 From: Lukas Truemper Date: Sat, 2 Sep 2023 15:34:08 +0200 Subject: [PATCH] TaskletFusion: Fix additional edges in case of none-connectors --- .../transformation/dataflow/tasklet_fusion.py | 2 + tests/transformations/tasklet_fusion_test.py | 44 +++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/dace/transformation/dataflow/tasklet_fusion.py b/dace/transformation/dataflow/tasklet_fusion.py index 99f8f625be..ea1649b51b 100644 --- a/dace/transformation/dataflow/tasklet_fusion.py +++ b/dace/transformation/dataflow/tasklet_fusion.py @@ -249,6 +249,8 @@ def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): t1.language) for in_edge in graph.in_edges(t1): + if len(new_tasklet.in_connectors) > 0 and in_edge.src_conn is None: + continue graph.add_edge(in_edge.src, in_edge.src_conn, new_tasklet, in_edge.dst_conn, in_edge.data) for in_edge in graph.in_edges(t2): diff --git a/tests/transformations/tasklet_fusion_test.py b/tests/transformations/tasklet_fusion_test.py index c7fd6802d5..743010e8c9 100644 --- a/tests/transformations/tasklet_fusion_test.py +++ b/tests/transformations/tasklet_fusion_test.py @@ -213,6 +213,49 @@ def test_map_with_tasklets(language: str, with_data: bool): ref = map_with_tasklets.f(A, B) assert (np.allclose(C, ref)) +def test_none_connector(): + @dace.program + def sdfg_none_connector(A: dace.float32[32], B: dace.float32[32]): + tmp = dace.define_local([32], dace.float32) + for i in dace.map[0:32]: + with dace.tasklet: + a >> tmp[i] + a = 0 + + tmp2 = dace.define_local([32], dace.float32) + for i in dace.map[0:32]: + with dace.tasklet: + a << A[i] + b >> tmp2[i] + b = a + 1 + + + for i in dace.map[0:32]: + with dace.tasklet: + a << tmp[i] + b << tmp2[i] + c >> B[i] + c = a + b + + sdfg = sdfg_none_connector.to_sdfg() + sdfg.simplify() + applied = sdfg.apply_transformations_repeated(MapFusion) + assert applied == 2 + + map_entry = None + for node in sdfg.start_state.nodes(): + if isinstance(node, dace.nodes.MapEntry): + map_entry = node + break + + assert map_entry is not None + assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 1 + + applied = sdfg.apply_transformations_repeated(TaskletFusion) + assert applied == 2 + + assert sdfg.start_state.out_degree(map_entry) == 1 + assert len([edge.src_conn for edge in sdfg.start_state.out_edges(map_entry) if edge.src_conn is None]) == 0 if __name__ == '__main__': test_basic() @@ -224,3 +267,4 @@ def test_map_with_tasklets(language: str, with_data: bool): test_map_with_tasklets(language='Python', with_data=True) test_map_with_tasklets(language='CPP', with_data=False) test_map_with_tasklets(language='CPP', with_data=True) + test_none_connector()