From 79cf2ff96d58766327587eddcc9bb7e10fb702a6 Mon Sep 17 00:00:00 2001 From: alexnick83 <31545860+alexnick83@users.noreply.github.com> Date: Thu, 30 Nov 2023 16:45:44 +0100 Subject: [PATCH] In-out connector's global source when connector becomes out-only at outer SDFG scopes. (#1463) Adds utility-method support for the case of an in-out nested SDFG connector that is out-only at outer SDFG scopes. --- dace/sdfg/utils.py | 5 +++ tests/sdfg/validation/nested_sdfg_test.py | 48 +++++++++++++++++++++-- 2 files changed, 50 insertions(+), 3 deletions(-) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index d0f1a67ab9..1405901802 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1847,6 +1847,11 @@ def get_global_memlet_path_src(sdfg: SDFG, state: SDFGState, edge: MultiConnecto if len(pedges) > 0: pedge = pedges[0] return get_global_memlet_path_src(psdfg, pstate, pedge) + else: + pedges = list(pstate.out_edges_by_connector(pnode, src.data)) + if len(pedges) > 0: + pedge = pedges[0] + return get_global_memlet_path_dst(psdfg, pstate, pedge) return src diff --git a/tests/sdfg/validation/nested_sdfg_test.py b/tests/sdfg/validation/nested_sdfg_test.py index 100568507e..67ed8ab2a8 100644 --- a/tests/sdfg/validation/nested_sdfg_test.py +++ b/tests/sdfg/validation/nested_sdfg_test.py @@ -1,5 +1,6 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import dace +import numpy as np def test_inout_connector_validation_success(): @@ -33,6 +34,48 @@ def test_inout_connector_validation_success(): return +def test_inout_connector_validation_success_2(): + + sdfg = dace.SDFG("test_inout_connector_validation_success_2") + sdfg.add_array("A", [1], dace.int32) + + nsdfg_0 = dace.SDFG("nested_sdfg_0") + nsdfg_0.add_array("B", [1], dace.int32) + + nsdfg_1 = dace.SDFG("nested_sdfg_1") + nsdfg_1.add_array("C", [1], dace.int32) + + nstate = nsdfg_1.add_state() + read_c = nstate.add_access("C") + write_c = nstate.add_access("C") + tasklet = nstate.add_tasklet("tasklet", {"__inp"}, {"__out"}, "__out = __inp + 5") + nstate.add_edge(read_c, None, tasklet, '__inp', dace.Memlet.from_array('C', nsdfg_1.arrays['C'])) + nstate.add_edge(tasklet, '__out', write_c, None, dace.Memlet.from_array('C', nsdfg_1.arrays['C'])) + + nstate = nsdfg_0.add_state() + tasklet_0 = nstate.add_tasklet("tasklet_00", {}, {"__out"}, "__out = 3") + write_b_0 = nstate.add_access("B") + tasklet_1 = nstate.add_nested_sdfg(nsdfg_1, nsdfg_0, {"C"}, {"C"}) + write_b_1 = nstate.add_access("B") + nstate.add_edge(tasklet_0, '__out', write_b_0, None, dace.Memlet.from_array('B', nsdfg_0.arrays['B'])) + nstate.add_edge(write_b_0, None, tasklet_1, 'C', dace.Memlet.from_array('B', nsdfg_0.arrays['B'])) + nstate.add_edge(tasklet_1, 'C', write_b_1, None, dace.Memlet.from_array('B', nsdfg_0.arrays['B'])) + + state = sdfg.add_state() + tasklet = state.add_nested_sdfg(nsdfg_0, sdfg, {}, {"B"}) + write_a = state.add_access("A") + state.add_edge(tasklet, 'B', write_a, None, dace.Memlet.from_array('A', sdfg.arrays['A'])) + + try: + sdfg.validate() + except dace.sdfg.InvalidSDFGError: + assert False, "SDFG should validate" + + A = np.array([1], dtype=np.int32) + sdfg(A=A) + assert A[0] == 8 + + def test_inout_connector_validation_fail(): sdfg = dace.SDFG("test_inout_connector_validation_fail") @@ -79,7 +122,6 @@ def mystate(state, src, dst): # output path (tasklet[b]->dst) state.add_memlet_path(tasklet, dst_node, src_conn='b', memlet=dace.Memlet(data=dst, subset='0')) - sub_sdfg = dace.SDFG('nested_sub') sub_sdfg.add_scalar('sA', dace.float32) sub_sdfg.add_scalar('sB', dace.float32, transient=True) @@ -92,7 +134,6 @@ def mystate(state, src, dst): sub_sdfg.add_edge(state0, state1, dace.InterstateEdge()) - state = sdfg.add_state('s0') me, mx = state.add_map('mymap', dict(k='0:2')) nsdfg = state.add_nested_sdfg(sub_sdfg, sdfg, {'sA'}, {'sC'}) @@ -101,7 +142,7 @@ def mystate(state, src, dst): state.add_memlet_path(Ain, me, nsdfg, memlet=dace.Memlet(data='A', subset='k'), dst_conn='sA') state.add_memlet_path(nsdfg, mx, Aout, memlet=dace.Memlet(data='A', subset='k'), src_conn='sC') - + try: sdfg.validate() except dace.sdfg.InvalidSDFGError: @@ -112,5 +153,6 @@ def mystate(state, src, dst): if __name__ == "__main__": test_inout_connector_validation_success() + test_inout_connector_validation_success_2() test_inout_connector_validation_fail() test_nested_sdfg_with_transient_connector()