Skip to content

Commit

Permalink
Re-work multistate inline
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Sep 5, 2023
1 parent a320d09 commit 75195fa
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 54 deletions.
6 changes: 3 additions & 3 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -963,7 +963,7 @@ class ControlFlowBlock(BlockGraphView, abc.ABC):
is_collapsed = Property(dtype=bool, desc='Show this block as collapsed', default=False)

_sdfg: Optional['dace.SDFG'] = None
_parent: Optional['ControlFlowBlock'] = None
_parent: Optional['ScopeBlock'] = None
_label: str

def __init__(self,
Expand Down Expand Up @@ -1012,12 +1012,12 @@ def name(self) -> str:
return self._label

@property
def parent(self) -> Optional['ControlFlowBlock']:
def parent(self) -> Optional['ScopeBlock']:
""" Returns the parent block of this block. """
return self._parent

@parent.setter
def parent(self, block: Optional['ControlFlowBlock']):
def parent(self, block: Optional['ScopeBlock']):
self._parent = block

@property
Expand Down
108 changes: 57 additions & 51 deletions dace/transformation/interstate/multistate_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,15 +158,15 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
sdfg.append_exit_code(code.code, loc)

# Environments
for nstate in nsdfg.nodes():
for node in nstate.nodes():
if isinstance(node, nodes.CodeNode):
node.environments |= nsdfg_node.environments
for node, _ in nsdfg.all_nodes_recursive():
if isinstance(node, nodes.CodeNode):
node.environments |= nsdfg_node.environments

# Symbols
outer_symbols = {str(k): v for k, v in sdfg.symbols.items()}
for ise in sdfg.edges():
outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))
for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False):
for ise in cf.edges():
outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols))

# Find original source/destination edges (there is only one edge per
# connector, according to match)
Expand All @@ -189,12 +189,14 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
# Collect and modify interstate edges as necessary

outer_assignments = set()
for e in sdfg.edges():
outer_assignments |= e.data.assignments.keys()
for cf in sdfg.all_cfgs_recursive():
for e in cf.edges():
outer_assignments |= e.data.assignments.keys()

inner_assignments = set()
for e in nsdfg.edges():
inner_assignments |= e.data.assignments.keys()
for cf in nsdfg.all_cfgs_recursive():
for e in cf.edges():
inner_assignments |= e.data.assignments.keys()

allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys())
assignments_to_replace = inner_assignments & (outer_assignments | allnames)
Expand Down Expand Up @@ -235,30 +237,31 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):

# All transients become transients of the parent (if data already
# exists, find new name)
for nstate in nsdfg.nodes():
for node in nstate.nodes():
if isinstance(node, nodes.AccessNode):
datadesc = nsdfg.arrays[node.data]
if node.data not in transients and datadesc.transient:
new_name = node.data
if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants):
new_name = f'{nsdfg.label}_{node.data}'

name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True)
transients[node.data] = name

# All transients of edges between code nodes are also added to parent
for edge in nstate.edges():
if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
if edge.data.data is not None:
datadesc = nsdfg.arrays[edge.data.data]
if edge.data.data not in transients and datadesc.transient:
new_name = edge.data.data
for cf in nsdfg.all_cfgs_recursive():
for nblock in cf.nodes():
for node in nblock.nodes():
if isinstance(node, nodes.AccessNode):
datadesc = nsdfg.arrays[node.data]
if node.data not in transients and datadesc.transient:
new_name = node.data
if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants):
new_name = f'{nsdfg.label}_{edge.data.data}'
new_name = f'{nsdfg.label}_{node.data}'

name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True)
transients[edge.data.data] = name
transients[node.data] = name

# All transients of edges between code nodes are also added to parent
for edge in nblock.edges():
if (isinstance(edge.src, nodes.CodeNode) and isinstance(edge.dst, nodes.CodeNode)):
if edge.data.data is not None:
datadesc = nsdfg.arrays[edge.data.data]
if edge.data.data not in transients and datadesc.transient:
new_name = edge.data.data
if (new_name in sdfg.arrays or new_name in outer_symbols or new_name in sdfg.constants):
new_name = f'{nsdfg.label}_{edge.data.data}'

name = sdfg.add_datadesc(new_name, datadesc, find_new_name=True)
transients[edge.data.data] = name


# All constants (and associated transients) become constants of the parent
Expand Down Expand Up @@ -329,12 +332,12 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
# e.dst_conn, e.data)

# Make unique names for states
statenames = set(s.label for s in sdfg.nodes())
for nstate in nsdfg.nodes():
if nstate.label in statenames:
newname = data.find_new_name(nstate.label, statenames)
statenames.add(newname)
nstate.label = newname
blocknames = set(s.label for s in sdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False))
for nblock in nsdfg.all_control_flow_blocks_recursive(recurse_into_sdfgs=False):
if nblock.label in blocknames:
newname = data.find_new_name(nblock.label, blocknames)
blocknames.add(newname)
nblock.label = newname

#######################################################
# Add nested SDFG states into top-level SDFG
Expand All @@ -352,19 +355,20 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
sinks = nsdfg.sink_nodes()

# Reconnect state machine
for e in sdfg.in_edges(outer_state):
sdfg.add_edge(e.src, source, e.data)
for e in sdfg.out_edges(outer_state):
parent_graph = outer_state.parent
for e in parent_graph.in_edges(outer_state):
parent_graph.add_edge(e.src, source, e.data)
for e in parent_graph.out_edges(outer_state):
for sink in sinks:
sdfg.add_edge(sink, e.dst, dc(e.data))
parent_graph.add_edge(sink, e.dst, dc(e.data))
# Redirect sink incoming edges with a `False` condition to e.dst (return statements)
for e2 in sdfg.in_edges(sink):
for e2 in parent_graph.in_edges(sink):
if e2.data.condition_sympy() == False:
sdfg.add_edge(e2.src, e.dst, InterstateEdge())
parent_graph.add_edge(e2.src, e.dst, InterstateEdge())

# Modify start state as necessary
if outer_start_state is outer_state:
sdfg.start_state = sdfg.node_id(source)
parent_graph.start_block = parent_graph.node_id(source)

# TODO: Modify memlets by offsetting
# If both source and sink nodes are inputs/outputs, reconnect once
Expand Down Expand Up @@ -408,13 +412,15 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG):
# e._data = helpers.unsqueeze_memlet(
# e.data, outer_edge.data)

# Replace nested SDFG parents with new SDFG
for nstate in nsdfg.nodes():
nstate.parent = sdfg
for node in nstate.nodes():
if isinstance(node, nodes.NestedSDFG):
node.sdfg.parent_sdfg = sdfg
node.sdfg.parent_nsdfg_node = node
# Replace nested SDFG parents and SDFG pointers.
for cf in nsdfg.all_cfgs_recursive():
for nblock in cf.nodes():
nblock.parent = cf
nblock.sdfg = sdfg
for node in nblock.nodes():
if isinstance(node, nodes.NestedSDFG):
node.sdfg.parent_sdfg = sdfg
node.sdfg.parent_nsdfg_node = node

#######################################################
# Remove nested SDFG and state
Expand Down

0 comments on commit 75195fa

Please sign in to comment.