Skip to content

Commit

Permalink
Sync
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Sep 19, 2023
1 parent ad14839 commit 4744d08
Show file tree
Hide file tree
Showing 21 changed files with 326 additions and 285 deletions.
4 changes: 2 additions & 2 deletions dace/frontend/python/newast.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states
from dace.memlet import Memlet
from dace.properties import LambdaProperty, CodeBlock
from dace.sdfg import SDFG, SDFGState, ControlFlowGraph, ControlFlowBlock, LoopScopeBlock, ScopeBlock
from dace.sdfg import SDFG, SDFGState, ControlFlowBlock, LoopScopeBlock, ScopeBlock
from dace.sdfg.replace import replace_datadesc_names
from dace.symbolic import pystr_to_symbolic, inequal_symbols

Expand Down Expand Up @@ -2170,7 +2170,7 @@ def _recursive_visit(self,
body: List[ast.AST],
name: str,
lineno: int,
parent: ControlFlowGraph,
parent: ScopeBlock,
unconnected_last_block=True,
extra_symbols=None) -> Tuple[SDFGState, SDFGState, SDFGState, bool]:
""" Visits a subtree of the AST, creating special states before and after the visit. Returns the previous state,
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved.
from dace.sdfg.sdfg import SDFG, InterstateEdge, LogicalGroup

from dace.sdfg.state import SDFGState, ControlFlowBlock, ControlFlowGraph, ScopeBlock, LoopScopeBlock, BranchScopeBlock
from dace.sdfg.state import SDFGState, ControlFlowBlock, ScopeBlock, LoopScopeBlock, BranchScopeBlock

from dace.sdfg.scope import (scope_contains_scope, is_devicelevel_gpu, devicelevel_block_size, ScopeSubgraphView)

Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,7 +1332,7 @@ def propagate_memlet(dfg_state,
if memlet.is_empty():
return Memlet()

sdfg = dfg_state.parent
sdfg = dfg_state.sdfg
scope_node_symbols = set(conn for conn in entry_node.in_connectors if not conn.startswith('IN_'))
defined_vars = [
symbolic.pystr_to_symbolic(s) for s in (dfg_state.symbols_defined_at(entry_node).keys()
Expand Down
2 changes: 1 addition & 1 deletion dace/sdfg/replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]):
sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname]
del sdfg.constants_prop[aname]

for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False):
for cf in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False):
# Replace in interstate edges
for e in cf.edges():
e.data.replace_dict(repl, replace_keys=False)
Expand Down
113 changes: 5 additions & 108 deletions dace/sdfg/sdfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1110,7 +1110,7 @@ def remove_data(self, name, validate=True):

# Verify that there are no access nodes that use this data
if validate:
for state in self.nodes():
for state in self.states():
for node in state.nodes():
if isinstance(node, nd.AccessNode) and node.data == name:
raise ValueError(f"Cannot remove data descriptor "
Expand Down Expand Up @@ -1222,9 +1222,9 @@ def remove_node(self, node: SDFGState):
self._cached_start_block = None
return super().remove_node(node)

def states(self):
""" Alias that returns the nodes (states) in this SDFG. """
return self.nodes()
def states(self) -> Iterator[SDFGState]:
""" Returns the states in this SDFG, recursing into state scope blocks. """
return self.all_states_recursive()

def arrays_recursive(self):
""" Iterate over all arrays in this SDFG, including arrays within
Expand Down Expand Up @@ -1507,46 +1507,6 @@ def from_file(filename: str) -> 'SDFG':

# Dynamic SDFG creation API
##############################
def add_state(self, label=None, is_start_block=False) -> 'SDFGState':
return super().add_state(label, is_start_block)

def add_state_before(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState':
""" Adds a new SDFG state before an existing state, reconnecting
predecessors to it instead.
:param state: The state to prepend the new state before.
:param label: State label.
:param is_start_state: If True, resets SDFG starting state to this
state.
:return: A new SDFGState object.
"""
new_state = self.add_state(label, is_start_state)
# Reconnect
for e in self.in_edges(state):
self.remove_edge(e)
self.add_edge(e.src, new_state, e.data)
# Add unconditional connection between the new state and the current
self.add_edge(new_state, state, InterstateEdge())
return new_state

def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState':
""" Adds a new SDFG state after an existing state, reconnecting
it to the successors instead.
:param state: The state to append the new state after.
:param label: State label.
:param is_start_state: If True, resets SDFG starting state to this
state.
:return: A new SDFGState object.
"""
new_state = self.add_state(label, is_start_state)
# Reconnect
for e in self.out_edges(state):
self.remove_edge(e)
self.add_edge(new_state, e.dst, e.data)
# Add unconditional connection between the current and the new state
self.add_edge(state, new_state, InterstateEdge())
return new_state

def _find_new_name(self, name: str):
""" Tries to find a new name by adding an underscore and a number. """
Expand Down Expand Up @@ -1993,69 +1953,6 @@ def add_rdistrarray(self, array_a: str, array_b: str):
self.append_exit_code(self._rdistrarrays[rdistrarray_name].exit_code(self))
return rdistrarray_name

def add_loop(
self,
before_state,
after_state,
loop_var: str,
initialize_expr: str,
condition_expr: str,
increment_expr: str,
inverted: bool = False,
):
"""
Helper function that adds a looping state machine around a
given state (or sequence of states).
:param before_state: The state after which the loop should
begin, or None if the loop is the first
state (creates an empty state).
:param loop_state: The state that begins the loop. See also
``loop_end_state`` if the loop is multi-state.
:param after_state: The state that should be invoked after
the loop ends, or None if the program
should terminate (creates an empty state).
:param loop_var: A name of an inter-state variable to use
for the loop. If None, ``initialize_expr``
and ``increment_expr`` must be None.
:param initialize_expr: A string expression that is assigned
to ``loop_var`` before the loop begins.
If None, does not define an expression.
:param condition_expr: A string condition that occurs every
loop iteration. If None, loops forever
(undefined behavior).
:param increment_expr: A string expression that is assigned to
``loop_var`` after every loop iteration.
If None, does not define an expression.
:param loop_end_state: If the loop wraps multiple states, the
state where the loop iteration ends.
If None, sets the end state to
``loop_state`` as well.
:return: A 3-tuple of (``before_state``, generated loop guard state,
``after_state``).
"""
# Argument checks
if loop_var is None and (initialize_expr or increment_expr):
raise ValueError("Cannot initalize or increment an empty loop variable")

loop_scope = LoopScopeBlock(loop_var=loop_var,
initialize_expr=initialize_expr,
update_expr=increment_expr,
condition_expr=condition_expr,
inverted=inverted)

# Handling empty states
if before_state is None:
before_state = self.add_state()
if after_state is None:
after_state = self.add_state()

self.add_node(loop_scope)
self.add_edge(before_state, loop_scope)
self.add_edge(loop_scope, after_state)

return before_state, loop_scope, after_state

# SDFG queries
##############################

Expand Down Expand Up @@ -2254,7 +2151,7 @@ def __call__(self, *args, **kwargs):
def fill_scope_connectors(self):
""" Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit
nodes) according to data on the memlets. """
for cf in self.all_cfgs_recursive():
for cf in self.all_state_scopes_recursive():
for block in cf.nodes():
if isinstance(block, SDFGState):
block.fill_scope_connectors()
Expand Down
Loading

0 comments on commit 4744d08

Please sign in to comment.