From 4744d088525b3c2b58b956b678d1530d5effeffe Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Tue, 19 Sep 2023 10:45:03 +0200 Subject: [PATCH] Sync --- dace/frontend/python/newast.py | 4 +- dace/sdfg/__init__.py | 2 +- dace/sdfg/propagation.py | 2 +- dace/sdfg/replace.py | 2 +- dace/sdfg/sdfg.py | 113 +------- dace/sdfg/state.py | 255 ++++++++++++------ dace/sdfg/utils.py | 64 +++-- dace/transformation/dataflow/warp_tiling.py | 2 +- dace/transformation/helpers.py | 4 +- .../interstate/multistate_inline.py | 10 +- .../transformation/interstate/scope_inline.py | 6 +- .../transformation/interstate/sdfg_nesting.py | 4 +- .../transformation/interstate/state_fusion.py | 4 +- dace/transformation/pass_pipeline.py | 52 +++- dace/transformation/passes/analysis.py | 2 +- .../passes/dead_state_elimination.py | 26 +- .../transformation/passes/scalar_to_symbol.py | 18 +- dace/transformation/passes/simplify.py | 4 +- dace/transformation/transformation.py | 8 +- tests/transformations/state_fission_test.py | 6 +- .../trivial_loop_elimination_test.py | 23 +- 21 files changed, 326 insertions(+), 285 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 07667cfb83..2fa16182e3 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -31,7 +31,7 @@ from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock -from dace.sdfg import SDFG, SDFGState, ControlFlowGraph, ControlFlowBlock, LoopScopeBlock, ScopeBlock +from dace.sdfg import SDFG, SDFGState, ControlFlowBlock, LoopScopeBlock, ScopeBlock from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -2170,7 +2170,7 @@ def _recursive_visit(self, body: List[ast.AST], name: str, lineno: int, - parent: ControlFlowGraph, + parent: ScopeBlock, unconnected_last_block=True, extra_symbols=None) -> Tuple[SDFGState, SDFGState, SDFGState, bool]: """ Visits a subtree of the AST, creating special states before and after the visit. Returns the previous state, diff --git a/dace/sdfg/__init__.py b/dace/sdfg/__init__.py index 307da66e3e..9f48433bd5 100644 --- a/dace/sdfg/__init__.py +++ b/dace/sdfg/__init__.py @@ -1,7 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from dace.sdfg.sdfg import SDFG, InterstateEdge, LogicalGroup -from dace.sdfg.state import SDFGState, ControlFlowBlock, ControlFlowGraph, ScopeBlock, LoopScopeBlock, BranchScopeBlock +from dace.sdfg.state import SDFGState, ControlFlowBlock, ScopeBlock, LoopScopeBlock, BranchScopeBlock from dace.sdfg.scope import (scope_contains_scope, is_devicelevel_gpu, devicelevel_block_size, ScopeSubgraphView) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 0554775dcd..52060e0edd 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -1332,7 +1332,7 @@ def propagate_memlet(dfg_state, if memlet.is_empty(): return Memlet() - sdfg = dfg_state.parent + sdfg = dfg_state.sdfg scope_node_symbols = set(conn for conn in entry_node.in_connectors if not conn.startswith('IN_')) defined_vars = [ symbolic.pystr_to_symbolic(s) for s in (dfg_state.symbols_defined_at(entry_node).keys() diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index 5ce2b7a45f..2558ffd6ba 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -168,7 +168,7 @@ def replace_datadesc_names(sdfg, repl: Dict[str, str]): sdfg.constants_prop[repl[aname]] = sdfg.constants_prop[aname] del sdfg.constants_prop[aname] - for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False): + for cf in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): # Replace in interstate edges for e in cf.edges(): e.data.replace_dict(repl, replace_keys=False) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index 4d457fca33..50b2023309 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1110,7 +1110,7 @@ def remove_data(self, name, validate=True): # Verify that there are no access nodes that use this data if validate: - for state in self.nodes(): + for state in self.states(): for node in state.nodes(): if isinstance(node, nd.AccessNode) and node.data == name: raise ValueError(f"Cannot remove data descriptor " @@ -1222,9 +1222,9 @@ def remove_node(self, node: SDFGState): self._cached_start_block = None return super().remove_node(node) - def states(self): - """ Alias that returns the nodes (states) in this SDFG. """ - return self.nodes() + def states(self) -> Iterator[SDFGState]: + """ Returns the states in this SDFG, recursing into state scope blocks. """ + return self.all_states_recursive() def arrays_recursive(self): """ Iterate over all arrays in this SDFG, including arrays within @@ -1507,46 +1507,6 @@ def from_file(filename: str) -> 'SDFG': # Dynamic SDFG creation API ############################## - def add_state(self, label=None, is_start_block=False) -> 'SDFGState': - return super().add_state(label, is_start_block) - - def add_state_before(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state before an existing state, reconnecting - predecessors to it instead. - - :param state: The state to prepend the new state before. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.in_edges(state): - self.remove_edge(e) - self.add_edge(e.src, new_state, e.data) - # Add unconditional connection between the new state and the current - self.add_edge(new_state, state, InterstateEdge()) - return new_state - - def add_state_after(self, state: 'SDFGState', label=None, is_start_state=False) -> 'SDFGState': - """ Adds a new SDFG state after an existing state, reconnecting - it to the successors instead. - - :param state: The state to append the new state after. - :param label: State label. - :param is_start_state: If True, resets SDFG starting state to this - state. - :return: A new SDFGState object. - """ - new_state = self.add_state(label, is_start_state) - # Reconnect - for e in self.out_edges(state): - self.remove_edge(e) - self.add_edge(new_state, e.dst, e.data) - # Add unconditional connection between the current and the new state - self.add_edge(state, new_state, InterstateEdge()) - return new_state def _find_new_name(self, name: str): """ Tries to find a new name by adding an underscore and a number. """ @@ -1993,69 +1953,6 @@ def add_rdistrarray(self, array_a: str, array_b: str): self.append_exit_code(self._rdistrarrays[rdistrarray_name].exit_code(self)) return rdistrarray_name - def add_loop( - self, - before_state, - after_state, - loop_var: str, - initialize_expr: str, - condition_expr: str, - increment_expr: str, - inverted: bool = False, - ): - """ - Helper function that adds a looping state machine around a - given state (or sequence of states). - - :param before_state: The state after which the loop should - begin, or None if the loop is the first - state (creates an empty state). - :param loop_state: The state that begins the loop. See also - ``loop_end_state`` if the loop is multi-state. - :param after_state: The state that should be invoked after - the loop ends, or None if the program - should terminate (creates an empty state). - :param loop_var: A name of an inter-state variable to use - for the loop. If None, ``initialize_expr`` - and ``increment_expr`` must be None. - :param initialize_expr: A string expression that is assigned - to ``loop_var`` before the loop begins. - If None, does not define an expression. - :param condition_expr: A string condition that occurs every - loop iteration. If None, loops forever - (undefined behavior). - :param increment_expr: A string expression that is assigned to - ``loop_var`` after every loop iteration. - If None, does not define an expression. - :param loop_end_state: If the loop wraps multiple states, the - state where the loop iteration ends. - If None, sets the end state to - ``loop_state`` as well. - :return: A 3-tuple of (``before_state``, generated loop guard state, - ``after_state``). - """ - # Argument checks - if loop_var is None and (initialize_expr or increment_expr): - raise ValueError("Cannot initalize or increment an empty loop variable") - - loop_scope = LoopScopeBlock(loop_var=loop_var, - initialize_expr=initialize_expr, - update_expr=increment_expr, - condition_expr=condition_expr, - inverted=inverted) - - # Handling empty states - if before_state is None: - before_state = self.add_state() - if after_state is None: - after_state = self.add_state() - - self.add_node(loop_scope) - self.add_edge(before_state, loop_scope) - self.add_edge(loop_scope, after_state) - - return before_state, loop_scope, after_state - # SDFG queries ############################## @@ -2254,7 +2151,7 @@ def __call__(self, *args, **kwargs): def fill_scope_connectors(self): """ Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit nodes) according to data on the memlets. """ - for cf in self.all_cfgs_recursive(): + for cf in self.all_state_scopes_recursive(): for block in cf.nodes(): if isinstance(block, SDFGState): block.fill_scope_connectors() diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 18caf9d80c..857d8afb21 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -968,7 +968,7 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): def __init__(self, label: str='', - parent: Optional['ControlFlowBlock']=None, + parent: Optional['ScopeBlock']=None, sdfg: Optional['dace.SDFG'] = None): super(ControlFlowBlock, self).__init__() self._label = label @@ -2228,98 +2228,20 @@ def __init__(self, graph, subgraph_nodes): @make_properties -class ControlFlowGraph(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView): - - def __init__(self): - super(ControlFlowGraph, self).__init__() - - self._labels: Set[str] = set() - self._start_block: Optional[int] = None - self._cached_start_block: Optional[ControlFlowBlock] = None - - ################################################################### - # Traversal methods - - def all_cfgs_recursive(self, recurse_into_sdfgs=True) -> Iterator['ControlFlowGraph']: - """ Iterate over this and all nested CFGs. """ - yield self - for block in self.nodes(): - if isinstance(block, SDFGState) and recurse_into_sdfgs: - for node in block.nodes(): - if isinstance(node, nd.NestedSDFG): - yield from node.sdfg.all_cfgs_recursive() - elif isinstance(block, ControlFlowGraph): - yield from block.all_cfgs_recursive() - - def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']: - """ Iterate over this and all nested SDFGs. """ - for cfg in self.all_cfgs_recursive(recurse_into_sdfgs=True): - if isinstance(cfg, dace.SDFG): - yield cfg - - def all_states_recursive(self) -> Iterator[SDFGState]: - """ Iterate over all states in this control flow graph. """ - for block in self.nodes(): - if isinstance(block, SDFGState): - yield block - elif isinstance(block, ControlFlowGraph): - yield from block.all_states_recursive() - - def all_control_flow_blocks_recursive(self, recurse_into_sdfgs=True) -> Iterator[ControlFlowBlock]: - """ Iterate over all control flow blocks in this control flow graph. """ - for cfg in self.all_cfgs_recursive(recurse_into_sdfgs=recurse_into_sdfgs): - for block in cfg.nodes(): - yield block - - def all_interstate_edges_recursive(self, recurse_into_sdfgs=True) -> Iterator[Edge['dace.sdfg.InterstateEdge']]: - """ Iterate over all interstate edges in this control flow graph. """ - for cfg in self.all_cfgs_recursive(recurse_into_sdfgs=recurse_into_sdfgs): - for edge in cfg.edges(): - yield edge - - ################################################################### - # Getters & setters, overrides - - @property - def start_block(self): - """ Returns the starting block of this ControlFlowGraph. """ - if self._cached_start_block is not None: - return self._cached_start_block - - source_nodes = self.source_nodes() - if len(source_nodes) == 1: - self._cached_start_block = source_nodes[0] - return source_nodes[0] - # If the starting block is ambiguous allow manual override. - if self._start_block is not None: - self._cached_start_block = self.node(self._start_block) - return self._cached_start_block - raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' - 'please use "is_start_block=True" when adding the ' - 'starting block with "add_state" or "add_node"') - - @start_block.setter - def start_block(self, block_id): - """ Manually sets the starting block of this ControlFlowGraph. - - :param block_id: The node ID (use `node_id(block)`) of the block to set. - """ - if block_id < 0 or block_id >= self.number_of_nodes(): - raise ValueError('Invalid state ID') - self._start_block = block_id - self._cached_start_block = self.node(block_id) - - -@make_properties -class ScopeBlock(ControlFlowGraph, ControlFlowBlock): +class ScopeBlock(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, ControlFlowBlock): def __init__(self, label: str='', - parent: Optional['ControlFlowBlock']=None, + parent: Optional['ScopeBlock']=None, sdfg: Optional['dace.SDFG'] = None): - ControlFlowGraph.__init__(self) + OrderedDiGraph.__init__(self) + ControlGraphView.__init__(self) ControlFlowBlock.__init__(self, label, parent, sdfg) + self._labels: Set[str] = set() + self._start_block: Optional[int] = None + self._cached_start_block: Optional[ControlFlowBlock] = None + def add_edge(self, src: ControlFlowBlock, dst: ControlFlowBlock, data: 'dace.sdfg.InterstateEdge'): """ Adds a new edge to the graph. Must be an InterstateEdge or a subclass thereof. @@ -2359,6 +2281,92 @@ def add_state(self, label=None, is_start_block=False) -> SDFGState: self.add_node(state, is_start_block=is_start_block) return state + def add_state_before(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state before an existing state, reconnecting predecessors to it instead. + + :param state: The state to prepend the new state before. + :param label: State label. + :param is_start_state: If True, resets scope block starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.in_edges(state): + self.remove_edge(e) + self.add_edge(e.src, new_state, e.data) + # Add unconditional connection between the new state and the current + self.add_edge(new_state, state, dace.sdfg.InterstateEdge()) + return new_state + + def add_state_after(self, state: SDFGState, label=None, is_start_state=False) -> SDFGState: + """ Adds a new SDFG state after an existing state, reconnecting it to the successors instead. + + :param state: The state to append the new state after. + :param label: State label. + :param is_start_state: If True, resets SDFG starting state to this state. + :return: A new SDFGState object. + """ + new_state = self.add_state(label, is_start_state) + # Reconnect + for e in self.out_edges(state): + self.remove_edge(e) + self.add_edge(new_state, e.dst, e.data) + # Add unconditional connection between the current and the new state + self.add_edge(state, new_state, dace.sdfg.InterstateEdge()) + return new_state + + def add_loop( + self, + before_state: SDFGState, + after_state: SDFGState, + loop_var: str, + initialize_expr: str, + condition_expr: str, + increment_expr: str, + inverted: bool = False, + ): + """ + Helper function that adds a looping state machine around a given state (or sequence of states). + + :param before_state: The state after which the loop should begin, or None if the loop is the first state + (creates an empty state). + :param loop_state: The state that begins the loop. See also ``loop_end_state`` if the loop is multi-state. + :param after_state: The state that should be invoked after the loop ends, or None if the program should + terminate (creates an empty state). + :param loop_var: A name of an inter-state variable to use for the loop. If None, ``initialize_expr`` and + ``increment_expr`` must be None. + :param initialize_expr: A string expression that is assigned to ``loop_var`` before the loop begins. If None, + does not define an expression. + :param condition_expr: A string condition that occurs every loop iteration. If None, loops forever (undefined + behavior). + :param increment_expr: A string expression that is assigned to ``loop_var`` after every loop iteration. If None, + does not define an expression. + :param loop_end_state: If the loop wraps multiple states, the state where the loop iteration ends. If None, sets + the end state to ``loop_state`` as well. + :return: A 3-tuple of (``before_state``, generated loop guard state, ``after_state``). + """ + # Argument checks + if loop_var is None and (initialize_expr or increment_expr): + raise ValueError("Cannot initalize or increment an empty loop variable") + + loop_scope = LoopScopeBlock(loop_var=loop_var, + initialize_expr=initialize_expr, + update_expr=increment_expr, + condition_expr=condition_expr, + inverted=inverted) + + # Handling empty states + if before_state is None: + before_state = self.add_state() + if after_state is None: + after_state = self.add_state() + + self.add_node(loop_scope) + self.add_edge(before_state, loop_scope) + self.add_edge(loop_scope, after_state) + + return before_state, loop_scope, after_state + @abc.abstractmethod def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() @@ -2394,11 +2402,51 @@ def used_symbols(self, all_symbols: bool) -> Tuple[Set[str], Set[str], Set[str]] return free_syms, defined_syms, used_before_assignment def to_json(self, parent=None): - graph_json = ControlFlowGraph.to_json(self) + graph_json = OrderedDiGraph.to_json(self) block_json = ControlFlowBlock.to_json(self, parent) graph_json.update(block_json) return graph_json + ################################################################### + # Traversal methods + + def all_state_scopes_recursive(self, recurse_into_sdfgs=True) -> Iterator['ScopeBlock']: + """ Iterate over this and all nested state scopes. """ + yield self + for block in self.nodes(): + if isinstance(block, SDFGState) and recurse_into_sdfgs: + for node in block.nodes(): + if isinstance(node, nd.NestedSDFG): + yield from node.sdfg.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs) + elif isinstance(block, ScopeBlock): + yield from block.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs) + + def all_sdfgs_recursive(self) -> Iterator['dace.SDFG']: + """ Iterate over this and all nested SDFGs. """ + for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=True): + if isinstance(cfg, dace.SDFG): + yield cfg + + def all_states_recursive(self) -> Iterator[SDFGState]: + """ Iterate over all states in this control flow graph. """ + for block in self.nodes(): + if isinstance(block, SDFGState): + yield block + elif isinstance(block, ScopeBlock): + yield from block.all_states_recursive() + + def all_control_flow_blocks_recursive(self, recurse_into_sdfgs=True) -> Iterator[ControlFlowBlock]: + """ Iterate over all control flow blocks in this control flow graph. """ + for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs): + for block in cfg.nodes(): + yield block + + def all_interstate_edges_recursive(self, recurse_into_sdfgs=True) -> Iterator[Edge['dace.sdfg.InterstateEdge']]: + """ Iterate over all interstate edges in this control flow graph. """ + for cfg in self.all_state_scopes_recursive(recurse_into_sdfgs=recurse_into_sdfgs): + for edge in cfg.edges(): + yield edge + ################################################################### # Getters & setters, overrides @@ -2408,6 +2456,35 @@ def __str__(self): def __repr__(self) -> str: return f'{self.__class__.__name__} ({self.label})' + @property + def start_block(self): + """ Returns the starting block of this ControlFlowGraph. """ + if self._cached_start_block is not None: + return self._cached_start_block + + source_nodes = self.source_nodes() + if len(source_nodes) == 1: + self._cached_start_block = source_nodes[0] + return source_nodes[0] + # If the starting block is ambiguous allow manual override. + if self._start_block is not None: + self._cached_start_block = self.node(self._start_block) + return self._cached_start_block + raise ValueError('Ambiguous or undefined starting block for ControlFlowGraph, ' + 'please use "is_start_block=True" when adding the ' + 'starting block with "add_state" or "add_node"') + + @start_block.setter + def start_block(self, block_id): + """ Manually sets the starting block of this ControlFlowGraph. + + :param block_id: The node ID (use `node_id(block)`) of the block to set. + """ + if block_id < 0 or block_id >= self.number_of_nodes(): + raise ValueError('Invalid state ID') + self._start_block = block_id + self._cached_start_block = self.node(block_id) + @make_properties class LoopScopeBlock(ScopeBlock): @@ -2424,7 +2501,7 @@ def __init__(self, condition_expr: str, update_expr: str, label: str = '', - parent: Optional[ControlFlowGraph] = None, + parent: Optional[ScopeBlock] = None, sdfg: Optional['dace.SDFG'] = None, inverted: bool = False): super(LoopScopeBlock, self).__init__(label, parent, sdfg) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 2552e1d417..bda6bee8d0 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,7 +13,7 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import SDFGState, StateSubgraphView, LoopScopeBlock +from dace.sdfg.state import SDFGState, StateSubgraphView, LoopScopeBlock, ScopeBlock from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs, symbolic @@ -1206,7 +1206,7 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> counter = 0 if progress is True or progress is None: fusible_states = 0 - for cfg in sdfg.all_cfgs_recursive(): + for cfg in sdfg.all_state_scopes_recursive(): fusible_states += cfg.number_of_edges() if progress is True: @@ -1216,7 +1216,7 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> for sd in sdfg.all_sdfgs_recursive(): id = sd.sdfg_id - for cfg in sd.all_cfgs_recursive(recurse_into_sdfgs=False): + for cfg in sd.all_state_scopes_recursive(recurse_into_sdfgs=False): while True: edges = list(cfg.nx.edges) applied = 0 @@ -1476,31 +1476,26 @@ def _traverse(scope: Node, symbols: Dict[str, dtypes.typeclass]): yield from _traverse(None, symbols) -def traverse_sdfg_with_defined_symbols( - sdfg: SDFG, - recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: - """ - Traverses the SDFG, its states and nodes, yielding the defined symbols and their types at each node. - - :return: A generator that yields tuples of (state, node in state, currently-defined symbols) - """ - # Start with global symbols - symbols = copy.copy(sdfg.symbols) - symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()}) - for desc in sdfg.arrays.values(): - symbols.update({str(s): s.dtype for s in desc.free_symbols}) - +def _tswds_scope_block( + sdfg: SDFG, + scope: ScopeBlock, + symbols: Dict[str, dtypes.typeclass], + recursive: bool, +) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: # Add symbols from inter-state edges along the state machine - start_state = sdfg.start_state + start_block = scope.start_block visited = set() visited_edges = set() - for edge in sdfg.dfs_edges(start_state): + for edge in sdfg.dfs_edges(start_block): # Source -> inter-state definition -> Destination visited_edges.add(edge) # Source if edge.src not in visited: visited.add(edge.src) - yield from _tswds_state(sdfg, edge.src, symbols, recursive) + if isinstance(edge.src, SDFGState): + yield from _tswds_state(sdfg, edge.src, symbols, recursive) + else: + yield from _tswds_scope_block(sdfg, edge.src, symbols, recursive) # Add edge symbols into defined symbols issyms = edge.data.new_symbols(sdfg, symbols) @@ -1509,11 +1504,34 @@ def traverse_sdfg_with_defined_symbols( # Destination if edge.dst not in visited: visited.add(edge.dst) - yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + if isinstance(edge.dst, SDFGState): + yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + else: + yield from _tswds_scope_block(sdfg, edge.dst, symbols, recursive) # If there is only one state, the DFS will miss it - if start_state not in visited: - yield from _tswds_state(sdfg, start_state, symbols, recursive) + if start_block not in visited: + if isinstance(start_block, SDFGState): + yield from _tswds_state(sdfg, start_block, symbols, recursive) + else: + yield from _tswds_scope_block(sdfg, start_block, symbols, recursive) + + +def traverse_sdfg_with_defined_symbols( + sdfg: SDFG, + recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: + """ + Traverses the SDFG, its states and nodes, yielding the defined symbols and their types at each node. + + :return: A generator that yields tuples of (state, node in state, currently-defined symbols) + """ + # Start with global symbols + symbols = copy.copy(sdfg.symbols) + symbols.update({k: dt.create_datadescriptor(v).dtype for k, v in sdfg.constants.items()}) + for desc in sdfg.arrays.values(): + symbols.update({str(s): s.dtype for s in desc.free_symbols}) + + yield from _tswds_scope_block(sdfg, sdfg, symbols, recursive) def is_fpga_kernel(sdfg, state): diff --git a/dace/transformation/dataflow/warp_tiling.py b/dace/transformation/dataflow/warp_tiling.py index 7639794943..810848b677 100644 --- a/dace/transformation/dataflow/warp_tiling.py +++ b/dace/transformation/dataflow/warp_tiling.py @@ -123,7 +123,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG) -> nodes.MapEntry: write = nstate.add_write(name) edge = nstate.add_nedge(read, write, copy.deepcopy(out_edge.data)) edge.data.wcr = None - xfh.state_fission(nsdfg, SubgraphView(nstate, [read, write])) + xfh.state_fission(SubgraphView(nstate, [read, write])) newnode = nstate.add_access(name) nstate.remove_edge(out_edge) diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 16517de9e2..8cd4346861 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -646,7 +646,7 @@ def nest_state_subgraph(sdfg: SDFG, return nested_sdfg -def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] = None) -> SDFGState: +def state_fission(subgraph: graph.SubgraphView, label: Optional[str] = None) -> SDFGState: """ Given a subgraph, adds a new SDFG state before the state that contains it, removes the subgraph from the original state, and connects the two states. @@ -656,7 +656,7 @@ def state_fission(sdfg: SDFG, subgraph: graph.SubgraphView, label: Optional[str] """ state: SDFGState = subgraph.graph - newstate = sdfg.add_state_before(state, label=label) + newstate = state.parent.add_state_before(state, label=label) # Save edges before removing nodes orig_edges = subgraph.edges() diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index d1f92ff628..aa999ae480 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -164,7 +164,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # Symbols outer_symbols = {str(k): v for k, v in sdfg.symbols.items()} - for cf in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False): + for cf in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): for ise in cf.edges(): outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) @@ -189,12 +189,12 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # Collect and modify interstate edges as necessary outer_assignments = set() - for cf in sdfg.all_cfgs_recursive(): + for cf in sdfg.all_state_scopes_recursive(): for e in cf.edges(): outer_assignments |= e.data.assignments.keys() inner_assignments = set() - for cf in nsdfg.all_cfgs_recursive(): + for cf in nsdfg.all_state_scopes_recursive(): for e in cf.edges(): inner_assignments |= e.data.assignments.keys() @@ -237,7 +237,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # All transients become transients of the parent (if data already # exists, find new name) - for cf in nsdfg.all_cfgs_recursive(): + for cf in nsdfg.all_state_scopes_recursive(): for nblock in cf.nodes(): for node in nblock.nodes(): if isinstance(node, nodes.AccessNode): @@ -413,7 +413,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # e.data, outer_edge.data) # Replace nested SDFG parents and SDFG pointers. - for cf in nsdfg.all_cfgs_recursive(): + for cf in nsdfg.all_state_scopes_recursive(): for nblock in cf.nodes(): nblock.parent = cf nblock.sdfg = sdfg diff --git a/dace/transformation/interstate/scope_inline.py b/dace/transformation/interstate/scope_inline.py index 74d9edb647..c05516400a 100644 --- a/dace/transformation/interstate/scope_inline.py +++ b/dace/transformation/interstate/scope_inline.py @@ -4,7 +4,7 @@ from typing import Any, Set, Optional from dace.frontend.python import astutils -from dace.sdfg import SDFG, ControlFlowGraph, InterstateEdge, SDFGState +from dace.sdfg import SDFG, InterstateEdge, SDFGState, ScopeBlock from dace.sdfg import utils as sdutil from dace.sdfg.nodes import CodeBlock from dace.sdfg.state import LoopScopeBlock, ScopeBlock @@ -26,10 +26,10 @@ def annotates_memlets(): def expressions(cls): return [sdutil.node_path_graph(cls.block)] - def can_be_applied(self, graph: ControlFlowGraph, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + def can_be_applied(self, graph: ScopeBlock, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: return True - def apply(self, graph: ControlFlowGraph, sdfg: SDFG) -> Optional[int]: + def apply(self, graph: ScopeBlock, sdfg: SDFG) -> Optional[int]: parent: ScopeBlock = graph internal_start = self.block.start_block diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index df75bb9911..8081472199 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -565,7 +565,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): # Fission state if necessary cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): - helpers.state_fission(state.sdfg, cc) + helpers.state_fission(cc) for edge in removed_out_edges: # Find last access node that refers to this edge try: @@ -580,7 +580,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): cc = utils.weakly_connected_component(state, node) if not any(n in cc for n in subgraph.nodes()): cc2 = SubgraphView(state, [n for n in state.nodes() if n not in cc]) - state = helpers.state_fission(sdfg, cc2) + state = helpers.state_fission(cc2) ####################################################### # Remove nested SDFG node diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index 84748d11eb..6cff57e339 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -10,7 +10,7 @@ from dace.config import Config from dace.sdfg import SDFG, nodes from dace.sdfg import utils as sdutil -from dace.sdfg.state import SDFGState, ControlFlowGraph +from dace.sdfg.state import SDFGState, ScopeBlock from dace.transformation import transformation @@ -455,7 +455,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, graph: Union[ControlFlowGraph, SDFGState], sdfg: SDFG): + def apply(self, graph: Union[ScopeBlock, SDFGState], sdfg: SDFG): first_state: SDFGState = self.first_state second_state: SDFGState = self.second_state diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 4e16bb6207..e1e1e133fa 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -3,7 +3,7 @@ API for SDFG analysis and manipulation Passes, as well as Pipelines that contain multiple dependent passes. """ from dace import properties, serialize -from dace.sdfg import SDFG, SDFGState, graph as gr, nodes, utils as sdutil +from dace.sdfg import SDFG, SDFGState, graph as gr, nodes, utils as sdutil, ScopeBlock from enum import Flag, auto from typing import Any, Dict, Iterator, List, Optional, Set, Type, Union @@ -307,6 +307,56 @@ def apply(self, scope: nodes.EntryNode, state: SDFGState, pipeline_results: Dict raise NotImplementedError +@properties.make_properties +class ControlFlowScopePass(Pass): + """ + A specialized Pass type that applies to each control flow scope (i.e., CFG) separately. Such a pass is + realized by implementing the ``apply`` method, which accepts a CFG and the SDFG it belongs to. + + :see: Pass + """ + + CATEGORY: str = 'Helper' + + def apply_pass( + self, + sdfg: SDFG, + pipeline_results: Dict[str, Any], + ) -> Optional[Dict[nodes.EntryNode, Optional[Any]]]: + """ + Applies the pass to the CFGs of the given SDFG by calling ``apply`` on each CFG. + + :param sdfg: The SDFG to apply the pass to. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: A dictionary of ``{entry node: return value}`` for visited CFGs with a non-None return value, or None + if nothing was returned. + """ + result = {} + for scope_block in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): + retval = self.apply(scope_block, scope_block.sdfg, pipeline_results) + if retval is not None: + result[scope_block] = retval + + if not result: + return None + return result + + def apply(self, scope_block: ScopeBlock, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Any]: + """ + Applies this pass on the given scope. + + :param scope_block: The control flow scope block to apply the pass to. + :param sdfg: The parent SDFG of the given scope. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: Some object if pass was applied, or None if nothing changed. + """ + raise NotImplementedError + + @dataclass @properties.make_properties class Pipeline(Pass): diff --git a/dace/transformation/passes/analysis.py b/dace/transformation/passes/analysis.py index e0a562f143..cc1155a1e8 100644 --- a/dace/transformation/passes/analysis.py +++ b/dace/transformation/passes/analysis.py @@ -68,7 +68,7 @@ def apply_pass( for sdfg in top_sdfg.all_sdfgs_recursive(): adesc = set(sdfg.arrays.keys()) result: Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]] = {} - for cfg in sdfg.all_cfgs_recursive(recurse_into_sdfgs=False): + for cfg in sdfg.all_state_scopes_recursive(recurse_into_sdfgs=False): for block in cfg.nodes(): readset = block.free_symbols # No symbols may be written to inside states. diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index a5ff0ba71a..53755e01ef 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -5,6 +5,7 @@ from typing import Optional, Set, Tuple, Union from dace import SDFG, InterstateEdge, SDFGState, symbolic, properties +from dace.sdfg.state import ScopeBlock from dace.properties import CodeBlock from dace.sdfg.graph import Edge from dace.sdfg.validation import InvalidSDFGInterstateEdgeError @@ -12,7 +13,7 @@ @properties.make_properties -class DeadStateElimination(ppl.Pass): +class DeadStateElimination(ppl.ControlFlowScopePass): """ Removes all unreachable states (e.g., due to a branch that will never be taken) from an SDFG. """ @@ -26,23 +27,23 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If connectivity or any edges were changed, some more states might be dead return modified & (ppl.Modifies.InterstateEdges | ppl.Modifies.States) - def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[InterstateEdge]]]]: + def apply(self, scope_block: ScopeBlock, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[InterstateEdge]]]]: """ Removes unreachable states throughout an SDFG. + :param scope_block: The scope block to modify. :param sdfg: The SDFG to modify. :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass results as ``{Pass subclass name: returned object from pass}``. If not run in a pipeline, an empty dictionary is expected. - :param initial_symbols: If not None, sets values of initial symbols. :return: A set of the removed states, or None if nothing was changed. """ # Mark dead states and remove them - dead_states, dead_edges, annotated = self.find_dead_states(sdfg, set_unconditional_edges=True) + dead_states, dead_edges, annotated = self.find_dead_states(scope_block, sdfg, set_unconditional_edges=True) for e in dead_edges: - sdfg.remove_edge(e) - sdfg.remove_nodes_from(dead_states) + scope_block.remove_edge(e) + scope_block.remove_nodes_from(dead_states) result = dead_states | dead_edges @@ -53,6 +54,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters def find_dead_states( self, + scope_block: ScopeBlock, sdfg: SDFG, set_unconditional_edges: bool = True) -> Tuple[Set[SDFGState], Set[Edge[InterstateEdge]], bool]: """ @@ -72,7 +74,7 @@ def find_dead_states( # Run a modified BFS where definitely False edges are not traversed, or if there is an # unconditional edge the rest are not. The inverse of the visited states is the dead set. - queue = collections.deque([sdfg.start_state]) + queue = collections.deque([scope_block.start_block]) while len(queue) > 0: node = queue.popleft() if node in visited: @@ -81,13 +83,13 @@ def find_dead_states( # First, check for unconditional edges unconditional = None - for e in sdfg.out_edges(node): + for e in scope_block.out_edges(node): # If an unconditional edge is found, ignore all other outgoing edges if self.is_definitely_taken(e.data, sdfg): # If more than one unconditional outgoing edge exist, fail with Invalid SDFG if unconditional is not None: raise InvalidSDFGInterstateEdgeError('Multiple unconditional edges leave the same state', sdfg, - sdfg.edge_id(e)) + scope_block.edge_id(e)) unconditional = e if set_unconditional_edges and not e.data.is_unconditional(): # Annotate edge as unconditional @@ -100,7 +102,7 @@ def find_dead_states( continue if unconditional is not None: # Unconditional edge exists, skip traversal # Remove other (now never taken) edges from graph - for e in sdfg.out_edges(node): + for e in scope_block.out_edges(node): if e is not unconditional: dead_edges.add(e) @@ -108,7 +110,7 @@ def find_dead_states( # End of unconditional check # Check outgoing edges normally - for e in sdfg.out_edges(node): + for e in scope_block.out_edges(node): next_node = e.dst # Test for edges that definitely evaluate to False @@ -121,7 +123,7 @@ def find_dead_states( queue.append(next_node) # Dead states are states that are not live (i.e., visited) - return set(sdfg.nodes()) - visited, dead_edges, edges_annotated + return set(scope_block.nodes()) - visited, dead_edges, edges_annotated def report(self, pass_retval: Set[Union[SDFGState, Edge[InterstateEdge]]]) -> str: if pass_retval is not None and not pass_retval: diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index 124efdaae1..05dace8166 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -95,7 +95,7 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer # Check all occurrences of candidates in SDFG and filter out candidates_seen: Set[str] = set() - for state in sdfg.nodes(): + for state in sdfg.states(): candidates_in_state: Set[str] = set() for node in state.nodes(): @@ -225,7 +225,7 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer # Filter out non-integral symbols that do not appear in inter-state edges interstate_symbols = set() - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False): interstate_symbols |= edge.data.free_symbols for candidate in (candidates - interstate_symbols): if integers_only and sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES: @@ -508,7 +508,7 @@ def remove_scalar_reads(sdfg: sd.SDFG, array_names: Dict[str, str]): replacement symbol name. :note: Operates in-place on the SDFG. """ - for state in sdfg.nodes(): + for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in array_names] for node in scalar_nodes: symname = array_names[node.data] @@ -633,8 +633,8 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: if len(to_promote) == 0: return None - for state in sdfg.nodes(): - scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] + for state in sdfg.states(): + scalar_nodes = [n for n in state.data_nodes() if n.data in to_promote] # Step 2: Assignment tasklets for node in scalar_nodes: if state.in_degree(node) == 0: @@ -645,8 +645,8 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # There is only zero or one incoming edges by definition tasklet_inputs = [e.src for e in state.in_edges(input)] # Step 2.1 - new_state = xfh.state_fission(sdfg, gr.SubgraphView(state, set([input, node] + tasklet_inputs))) - new_isedge: sd.InterstateEdge = sdfg.out_edges(new_state)[0] + new_state = xfh.state_fission(gr.SubgraphView(state, set([input, node] + tasklet_inputs))) + new_isedge: sd.InterstateEdge = state.parent.out_edges(new_state)[0] # Step 2.2 node: nodes.AccessNode = new_state.sink_nodes()[0] input = new_state.in_edges(node)[0].src @@ -683,7 +683,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: remove_scalar_reads(sdfg, {k: k for k in to_promote}) # Step 4: Isolated nodes - for state in sdfg.nodes(): + for state in sdfg.states(): scalar_nodes = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data in to_promote] state.remove_nodes_from([n for n in scalar_nodes if len(state.all_edges(n)) == 0]) @@ -699,7 +699,7 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # Step 6: Inter-state edge cleanup cleanup_re = {s: re.compile(fr'\b{re.escape(s)}\[.*?\]') for s in to_promote} promo = TaskletPromoterDict({k: k for k in to_promote}) - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges_recursive(recurse_into_sdfgs=False): ise: InterstateEdge = edge.data # Condition if not edge.data.is_unconditional(): diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index e6f416cc4d..d5441ff468 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -17,12 +17,12 @@ SIMPLIFY_PASSES = [ InlineSDFGs, - #ScalarToSymbolPromotion, + ScalarToSymbolPromotion, FuseStates, #OptionalArrayInference, #ConstantPropagation, #DeadDataflowElimination, - #DeadStateElimination, + DeadStateElimination, #RemoveUnusedSymbols, #ArrayElimination, #ConsolidateEdges, diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 99f48373c1..bd07b76c90 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -22,7 +22,7 @@ import copy from dace import dtypes, serialize from dace.dtypes import ScheduleType -from dace.sdfg import SDFG, SDFGState, ControlFlowGraph +from dace.sdfg import SDFG, SDFGState, ScopeBlock from dace.sdfg import nodes as nd, graph as gr, utils as sdutil, propagation, infer_types, state as st from dace.properties import make_properties, Property, DictProperty, SetProperty from dace.transformation import pass_pipeline as ppl @@ -108,7 +108,7 @@ def expressions(cls) -> List[gr.SubgraphView]: raise NotImplementedError def can_be_applied(self, - graph: Union[ControlFlowGraph, SDFGState], + graph: Union[ScopeBlock, SDFGState], expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: @@ -126,7 +126,7 @@ def can_be_applied(self, """ raise NotImplementedError - def apply(self, graph: Union[ControlFlowGraph, SDFGState], sdfg: SDFG) -> Union[Any, None]: + def apply(self, graph: Union[ScopeBlock, SDFGState], sdfg: SDFG) -> Union[Any, None]: """ Applies this transformation instance on the matched pattern graph. @@ -500,7 +500,7 @@ def expressions(cls) -> List[gr.SubgraphView]: pass @abc.abstractmethod - def can_be_applied(self, graph: ControlFlowGraph, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + def can_be_applied(self, graph: ScopeBlock, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: """ Returns True if this transformation can be applied on the candidate matched subgraph. :param graph: SDFG object in which the match was found. diff --git a/tests/transformations/state_fission_test.py b/tests/transformations/state_fission_test.py index 7c03fbed89..fb813d700c 100644 --- a/tests/transformations/state_fission_test.py +++ b/tests/transformations/state_fission_test.py @@ -120,17 +120,17 @@ def test_state_fission(): sdfg = make_nested_sdfg_cpu() # state fission - state = sdfg.states()[0] + state = list(sdfg.states())[0] node_x = state.nodes()[0] node_y = state.nodes()[1] node_z = state.nodes()[2] vec_add1 = state.nodes()[3] subg = dace.sdfg.graph.SubgraphView(state, [node_x, node_y, vec_add1, node_z]) - helpers.state_fission(sdfg, subg) + helpers.state_fission(subg) sdfg.validate() - assert (len(sdfg.states()) == 2) + assert (len(list(sdfg.states())) == 2) # run the program vec_add = sdfg.compile() diff --git a/tests/transformations/trivial_loop_elimination_test.py b/tests/transformations/trivial_loop_elimination_test.py index 6f2769f921..20514a3331 100644 --- a/tests/transformations/trivial_loop_elimination_test.py +++ b/tests/transformations/trivial_loop_elimination_test.py @@ -3,7 +3,6 @@ import dace from dace.transformation.interstate import TrivialLoopElimination from dace.symbolic import pystr_to_symbolic -import unittest import numpy as np I = dace.symbol("I") @@ -17,21 +16,19 @@ def trivial_loop(data: dace.float64[I, J]): data[i, j] = data[i, j] + data[i - 1, j] -class TrivialLoopEliminationTest(unittest.TestCase): +def test_semantic_eq(): + A1 = np.random.rand(16, 16) + A2 = np.copy(A1) - def test_semantic_eq(self): - A1 = np.random.rand(16, 16) - A2 = np.copy(A1) + sdfg = trivial_loop.to_sdfg(simplify=False) + sdfg(A1, I=A1.shape[0], J=A1.shape[1]) - sdfg = trivial_loop.to_sdfg(simplify=False) - sdfg(A1, I=A1.shape[0], J=A1.shape[1]) + count = sdfg.apply_transformations(TrivialLoopElimination) + assert (count > 0) + sdfg(A2, I=A1.shape[0], J=A1.shape[1]) - count = sdfg.apply_transformations(TrivialLoopElimination) - self.assertGreater(count, 0) - sdfg(A2, I=A1.shape[0], J=A1.shape[1]) - - self.assertTrue(np.allclose(A1, A2)) + assert np.allclose(A1, A2) if __name__ == '__main__': - unittest.main() + test_semantic_eq()