Skip to content

Commit

Permalink
Added a function add_scope_connectors() to the Map{Entry, Exit} n…
Browse files Browse the repository at this point in the history
…odes.

For a Map one must always perform two calls (one to `Node.add_in_connector()` and `Node.add_out_connector()`), both calls are different, but highly related.
This commit added a new function to the `Map{Entry, Exit}` classes that allows to combine these two calls into one.
  • Loading branch information
philip-paul-mueller committed Dec 13, 2024
1 parent 896a1e1 commit 22ecf32
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 1 deletion.
1 change: 0 additions & 1 deletion dace/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,6 @@ def __init__(self, wrapped_type, typename=None):
# Convert python basic types
if isinstance(wrapped_type, str):
try:

if wrapped_type == "bool":
wrapped_type = numpy.bool_
else:
Expand Down
42 changes: 42 additions & 0 deletions dace/sdfg/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,43 @@ def add_out_connector(self, connector_name: str, dtype: dtypes.typeclass = None,
self.out_connectors = connectors
return True

def _add_scope_connectors(
self,
connector_name: str,
dtype: Optional[dtypes.typeclass] = None,
force: bool = False,
) -> None:
""" Adds input and output connector names to `self` in one step.
The function will add an input connector with name `'IN_' + connector_name`
and an output connector with name `'OUT_' + connector_name`.
The function is a shorthand for calling `add_in_connector()` and `add_out_connector()`.
:param connector_name: The base name of the new connectors.
:param dtype: The type of the connectors, or `None` for auto-detect.
:param force: Add connector even if input or output connector of that name already exists.
:return: True if the operation is successful, otherwise False.
"""
in_connector_name = "IN_" + connector_name
out_connector_name = "OUT_" + connector_name
if not force:
if in_connector_name in self.in_connectors or in_connector_name in self.out_connectors:
return False
if out_connector_name in self.in_connectors or out_connector_name in self.out_connectors:
return False
# We force unconditionally because we have performed the tests above.
self.add_in_connector(
connector_name=in_connector_name,
dtype=dtype,
force=True,
)
self.add_out_connector(
connector_name=out_connector_name,
dtype=dtype,
force=True,
)
return True

def remove_in_connector(self, connector_name: str):
""" Removes an input connector from the node.
Expand Down Expand Up @@ -741,6 +778,9 @@ class EntryNode(Node):
def validate(self, sdfg, state):
self.map.validate(sdfg, state, self)

add_scope_connectors = Node._add_scope_connectors



# ------------------------------------------------------------------------------

Expand All @@ -752,6 +792,8 @@ class ExitNode(Node):
def validate(self, sdfg, state):
self.map.validate(sdfg, state, self)

add_scope_connectors = Node._add_scope_connectors


# ------------------------------------------------------------------------------

Expand Down
35 changes: 35 additions & 0 deletions tests/sdfg/nodes_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import dace

def test_add_scope_connectors():
sdfg = dace.SDFG("add_scope_connectors_sdfg")
state = sdfg.add_state(is_start_block=True)
me: dace.nodes.MapEntry
mx: dace.nodes.MapExit
me, mx = state.add_map("test_map", ndrange={"__i0": "0:10"})
assert all(
len(mn.in_connectors) == 0 and len(mn.out_connectors) == 0
for mn in [me, mx]
)
me.add_in_connector("IN_T", dtype=dace.float64)
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0

# Because there is already an `IN_T` this call will fail.
assert not me.add_scope_connectors("T")
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"] is dace.float64 and len(me.out_connectors) == 0
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0

# Now it will work, because we specify force, however, the current type for `IN_T` will be overridden.
assert me.add_scope_connectors("T", force=True)
assert len(me.in_connectors) == 1 and me.in_connectors["IN_T"].type is None
assert len(me.out_connectors) == 1 and me.out_connectors["OUT_T"].type is None
assert len(mx.in_connectors) == 0 and len(mx.out_connectors) == 0

# Now tries to the full adding.
assert mx.add_scope_connectors("B", dtype=dace.int64)
assert len(mx.in_connectors) == 1 and mx.in_connectors["IN_B"] is dace.int64
assert len(mx.out_connectors) == 1 and mx.out_connectors["OUT_B"] is dace.int64


if __name__ == "__main__":
test_add_scope_connectors()

0 comments on commit 22ecf32

Please sign in to comment.