-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a function
add_scope_connectors()
to the Map{Entry, Exit}
n…
…odes. For a Map one must always perform two calls (one to `Node.add_in_connector()` and `Node.add_out_connector()`), both calls are different, but highly related. This commit added a new function to the `Map{Entry, Exit}` classes that allows to combine these two calls into one.
- Loading branch information
1 parent
896a1e1
commit 22ecf32
Showing
3 changed files
with
77 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
import dace | ||
|
||
def test_add_scope_connectors(): | ||
sdfg = dace.SDFG("add_scope_connectors_sdfg") | ||
state = sdfg.add_state(is_start_block=True) | ||
me: dace.nodes.MapEntry | ||
mx: dace.nodes.MapExit | ||
me, mx = state.add_map("test_map", ndrange={"__i0": "0:10"}) | ||
assert all( | ||
len(mn.in_connectors) == 0 and len(mn.out_connectors) == 0 | ||
for mn in [me, mx] | ||
) | ||
me.add_in_connector("IN_T", dtype=dace.float64) | ||
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0 | ||
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0 | ||
|
||
# Because there is already an `IN_T` this call will fail. | ||
assert not me.add_scope_connectors("T") | ||
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0 | ||
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0 | ||
|
||
# Now it will work, because we specify force, however, the current type for `IN_T` will be overridden. | ||
assert me.add_scope_connectors("T", force=True) | ||
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"].type is None | ||
assert len(me.out_connectors) == 1 and me.out_connectors["OUT_T"].type is None | ||
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0 | ||
|
||
# Now tries to the full adding. | ||
assert mx.add_scope_connectors("B", dtype=dace.int64) | ||
assert len(mx.in_connectors) == 1 and mx.in_connectors["IN_B"] is dace.int64 | ||
assert len(mx.out_connectors) == 1 and mx.out_connectors["OUT_B"] is dace.int64 | ||
|
||
|
||
if __name__ == "__main__": | ||
test_add_scope_connectors() |