Skip to content

Commit

Permalink
In-out connector's global source when connector becomes out-only at o…
Browse files Browse the repository at this point in the history
…uter SDFG scopes. (#1463)

Adds utility-method support for the case of an in-out nested SDFG
connector that is out-only at outer SDFG scopes.
  • Loading branch information
alexnick83 authored Nov 30, 2023
1 parent edbf49f commit 79cf2ff
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 3 deletions.
5 changes: 5 additions & 0 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1847,6 +1847,11 @@ def get_global_memlet_path_src(sdfg: SDFG, state: SDFGState, edge: MultiConnecto
if len(pedges) > 0:
pedge = pedges[0]
return get_global_memlet_path_src(psdfg, pstate, pedge)
else:
pedges = list(pstate.out_edges_by_connector(pnode, src.data))
if len(pedges) > 0:
pedge = pedges[0]
return get_global_memlet_path_dst(psdfg, pstate, pedge)
return src


Expand Down
48 changes: 45 additions & 3 deletions tests/sdfg/validation/nested_sdfg_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.
import dace
import numpy as np


def test_inout_connector_validation_success():
Expand Down Expand Up @@ -33,6 +34,48 @@ def test_inout_connector_validation_success():
return


def test_inout_connector_validation_success_2():

sdfg = dace.SDFG("test_inout_connector_validation_success_2")
sdfg.add_array("A", [1], dace.int32)

nsdfg_0 = dace.SDFG("nested_sdfg_0")
nsdfg_0.add_array("B", [1], dace.int32)

nsdfg_1 = dace.SDFG("nested_sdfg_1")
nsdfg_1.add_array("C", [1], dace.int32)

nstate = nsdfg_1.add_state()
read_c = nstate.add_access("C")
write_c = nstate.add_access("C")
tasklet = nstate.add_tasklet("tasklet", {"__inp"}, {"__out"}, "__out = __inp + 5")
nstate.add_edge(read_c, None, tasklet, '__inp', dace.Memlet.from_array('C', nsdfg_1.arrays['C']))
nstate.add_edge(tasklet, '__out', write_c, None, dace.Memlet.from_array('C', nsdfg_1.arrays['C']))

nstate = nsdfg_0.add_state()
tasklet_0 = nstate.add_tasklet("tasklet_00", {}, {"__out"}, "__out = 3")
write_b_0 = nstate.add_access("B")
tasklet_1 = nstate.add_nested_sdfg(nsdfg_1, nsdfg_0, {"C"}, {"C"})
write_b_1 = nstate.add_access("B")
nstate.add_edge(tasklet_0, '__out', write_b_0, None, dace.Memlet.from_array('B', nsdfg_0.arrays['B']))
nstate.add_edge(write_b_0, None, tasklet_1, 'C', dace.Memlet.from_array('B', nsdfg_0.arrays['B']))
nstate.add_edge(tasklet_1, 'C', write_b_1, None, dace.Memlet.from_array('B', nsdfg_0.arrays['B']))

state = sdfg.add_state()
tasklet = state.add_nested_sdfg(nsdfg_0, sdfg, {}, {"B"})
write_a = state.add_access("A")
state.add_edge(tasklet, 'B', write_a, None, dace.Memlet.from_array('A', sdfg.arrays['A']))

try:
sdfg.validate()
except dace.sdfg.InvalidSDFGError:
assert False, "SDFG should validate"

A = np.array([1], dtype=np.int32)
sdfg(A=A)
assert A[0] == 8


def test_inout_connector_validation_fail():

sdfg = dace.SDFG("test_inout_connector_validation_fail")
Expand Down Expand Up @@ -79,7 +122,6 @@ def mystate(state, src, dst):
# output path (tasklet[b]->dst)
state.add_memlet_path(tasklet, dst_node, src_conn='b', memlet=dace.Memlet(data=dst, subset='0'))


sub_sdfg = dace.SDFG('nested_sub')
sub_sdfg.add_scalar('sA', dace.float32)
sub_sdfg.add_scalar('sB', dace.float32, transient=True)
Expand All @@ -92,7 +134,6 @@ def mystate(state, src, dst):

sub_sdfg.add_edge(state0, state1, dace.InterstateEdge())


state = sdfg.add_state('s0')
me, mx = state.add_map('mymap', dict(k='0:2'))
nsdfg = state.add_nested_sdfg(sub_sdfg, sdfg, {'sA'}, {'sC'})
Expand All @@ -101,7 +142,7 @@ def mystate(state, src, dst):

state.add_memlet_path(Ain, me, nsdfg, memlet=dace.Memlet(data='A', subset='k'), dst_conn='sA')
state.add_memlet_path(nsdfg, mx, Aout, memlet=dace.Memlet(data='A', subset='k'), src_conn='sC')

try:
sdfg.validate()
except dace.sdfg.InvalidSDFGError:
Expand All @@ -112,5 +153,6 @@ def mystate(state, src, dst):

if __name__ == "__main__":
test_inout_connector_validation_success()
test_inout_connector_validation_success_2()
test_inout_connector_validation_fail()
test_nested_sdfg_with_transient_connector()

0 comments on commit 79cf2ff

Please sign in to comment.