diff --git a/dace/codegen/codegen.py b/dace/codegen/codegen.py index b963da4812..cae1644419 100644 --- a/dace/codegen/codegen.py +++ b/dace/codegen/codegen.py @@ -16,7 +16,7 @@ from dace.codegen.targets import cpp, cpu from dace.codegen.instrumentation import InstrumentationProvider -from dace.sdfg.state import SDFGState +from dace.sdfg.state import SDFGState, ScopeBlock def generate_headers(sdfg: SDFG, frame: framecode.DaCeCodeGenerator) -> str: @@ -102,6 +102,8 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator): state: SDFGState = parent nsdfg = state.sdfg frame.targets.add(disp.get_node_dispatcher(nsdfg, state, node)) + elif isinstance(node, ScopeBlock): + frame.targets.add(disp.get_state_scope_dispatcher(parent, node)) # Array allocation if isinstance(node, dace.nodes.AccessNode): @@ -149,7 +151,7 @@ def _get_codegen_targets(sdfg: SDFG, frame: framecode.DaCeCodeGenerator): disp.instrumentation[sdfg.instrument] = provider_mapping[sdfg.instrument] -def generate_code(sdfg, validate=True) -> List[CodeObject]: +def generate_code(sdfg: SDFG, validate=True) -> List[CodeObject]: """ Generates code as a list of code objects for a given SDFG. diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 57767c66b0..fda9df87e6 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -83,7 +83,7 @@ def free_symbols(self, obj: Any): if k in self.fsyms: return self.fsyms[k] if hasattr(obj, 'used_symbols'): - result = obj.used_symbols(all_symbols=False) + result = obj.used_symbols(all_symbols=False)[0] else: result = obj.free_symbols self.fsyms[k] = result @@ -395,9 +395,14 @@ def generate_external_memory_management(self, sdfg: SDFG, callsite_stream: CodeI # Footer callsite_stream.write('}', sdfg) - def generate_state(self, sdfg, state, global_stream, callsite_stream, generate_state_footer=True): + def generate_state(self, + sdfg: SDFG, + state: SDFGState, + global_stream: CodeIOStream, + callsite_stream: CodeIOStream, + generate_state_footer=True) -> None: - sid = sdfg.node_id(state) + sid = state.parent.node_id(state) # Emit internal transient array allocation self.allocate_arrays_in_scope(sdfg, state, global_stream, callsite_stream) @@ -444,7 +449,7 @@ def generate_state(self, sdfg, state, global_stream, callsite_stream, generate_s if instr is not None: instr.on_state_end(sdfg, state, callsite_stream, global_stream) - def generate_states(self, sdfg, global_stream, callsite_stream): + def generate_states(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stream: CodeIOStream): states_generated = set() opbar = progress.OptionalProgressBar(sdfg.number_of_nodes(), title=f'Generating code (SDFG {sdfg.sdfg_id})') @@ -526,8 +531,7 @@ def _can_allocate(self, sdfg: SDFG, state: SDFGState, desc: data.Data, scope: Un def determine_allocation_lifetime(self, top_sdfg: SDFG): """ - Determines where (at which scope/state/SDFG) each data descriptor - will be allocated/deallocated. + Determines where (at which scope/state/SDFG) each data descriptor will be allocated/deallocated. :param top_sdfg: The top-level SDFG to determine for. """ @@ -543,8 +547,7 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): ############################################# # Look for all states in which a scope-allocated array is used in instances: Dict[str, List[Tuple[SDFGState, nodes.AccessNode]]] = collections.defaultdict(list) - array_names = sdfg.arrays.keys( - ) #set(k for k, v in sdfg.arrays.items() if v.lifetime == dtypes.AllocationLifetime.Scope) + array_names = sdfg.arrays.keys() # Iterate topologically to get state-order for state in sdfg.topological_sort(): for node in state.data_nodes(): diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 5cdd10ba98..07667cfb83 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -31,7 +31,7 @@ from dace.sdfg.propagation import propagate_memlet, propagate_subset, propagate_states from dace.memlet import Memlet from dace.properties import LambdaProperty, CodeBlock -from dace.sdfg import SDFG, SDFGState, ControlFlowGraph, ControlFlowBlock, LoopScopeBlock +from dace.sdfg import SDFG, SDFGState, ControlFlowGraph, ControlFlowBlock, LoopScopeBlock, ScopeBlock from dace.sdfg.replace import replace_datadesc_names from dace.symbolic import pystr_to_symbolic, inequal_symbols @@ -1046,8 +1046,8 @@ class ProgramVisitor(ExtNodeVisitor): sdfg: SDFG last_block: ControlFlowBlock - cfg_target: ControlFlowGraph - last_cfg_target: ControlFlowGraph + cfg_target: ScopeBlock + last_cfg_target: ScopeBlock current_state: SDFGState def __init__(self, @@ -1334,8 +1334,8 @@ def _add_block(self, block: ControlFlowBlock): else: self.current_state = block - def _add_state(self, label=None) -> SDFGState: - state = self.cfg_target.add_state(label, False) + def _add_state(self, label=None, is_start=False) -> SDFGState: + state = self.cfg_target.add_state(label, is_start_block=is_start) self._add_block(state) return state @@ -2344,8 +2344,10 @@ def visit_For(self, node: ast.For): initialize_expr=astutils.unparse(ast_ranges[0][0]), update_expr=incr[indices[0]], inverted=False) - self._recursive_visit(node.body, f'for_{node.lineno}', node.lineno, extra_symbols=extra_syms, - parent=loop_scope, unconnected_last_block=False) + _, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}', + node.lineno, extra_symbols=extra_syms, + parent=loop_scope, unconnected_last_block=False) + loop_scope.start_block = loop_scope.node_id(first_subblock) # Handle else clause if node.orelse: diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index 5c93149529..42fb228c3f 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -365,8 +365,7 @@ def sink_nodes(self) -> List[NodeT]: return [n for n in self.nodes() if self.out_degree(n) == 0] def topological_sort(self, source: NodeT = None) -> Sequence[NodeT]: - """Returns nodes in topological order iff the graph contains exactly - one node with no incoming edges.""" + """Returns nodes in topological order iff the graph contains exactly one node with no incoming edges.""" if source is not None: sources = [source] else: diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index ed4f5e068f..d243453f0c 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -61,7 +61,7 @@ def infer_connector_types(sdfg: SDFG): :param sdfg: The SDFG to infer. """ # Loop over states, and in a topological sort over each state's nodes - for state in sdfg.nodes(): + for state in sdfg.all_states_recursive(): for node in dfs_topological_sort(state): # Try to infer input connector type from node type or previous edges for e in state.in_edges(node): @@ -167,7 +167,7 @@ def set_default_schedule_and_storage_types(scope: Union[SDFG, SDFGState, nodes.E if isinstance(scope, SDFG): # Set device for default top-level schedules and storages - for state in scope.nodes(): + for state in scope.all_states_recursive(): set_default_schedule_and_storage_types(state, parent_schedules, use_parent_schedule=use_parent_schedule, diff --git a/dace/sdfg/scope.py b/dace/sdfg/scope.py index 75f8eacf75..2b2dd6a1e0 100644 --- a/dace/sdfg/scope.py +++ b/dace/sdfg/scope.py @@ -122,10 +122,7 @@ def node_id_or_none(node): if node is None: return -1 return state.node_id(node) - res = {} - for k, v in scope_dict.items(): - res[node_id_or_none(k)] = [node_id_or_none(vi) for vi in v] if v is not None else [] - return res + return {node_id_or_none(k): [node_id_or_none(vi) for vi in v] for k, v in scope_dict.items()} def scope_contains_scope(sdict: ScopeDictType, node: NodeType, other_node: NodeType) -> bool: diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index f2ece7e042..4d457fca33 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -1404,22 +1404,21 @@ def transients(self): return result def shared_transients(self, check_toplevel=True) -> List[str]: - """ Returns a list of transient data that appears in more than one - state. """ + """ Returns a list of transient data that appears in more than one state. """ seen = {} shared = [] # If a transient is present in an inter-state edge, it is shared - for interstate_edge in self.edges(): + for interstate_edge in self.all_interstate_edges_recursive(): for sym in interstate_edge.data.free_symbols: if sym in self.arrays and self.arrays[sym].transient: seen[sym] = interstate_edge shared.append(sym) # If transient is accessed in more than one state, it is shared - for state in self.nodes(): - for node in state.nodes(): - if isinstance(node, nd.AccessNode) and node.desc(self).transient: + for state in self.all_states_recursive(): + for node in state.data_nodes(): + if node.desc(self).transient: if (check_toplevel and node.desc(self).toplevel) or (node.data in seen and seen[node.data] != state): shared.append(node.data) @@ -2255,8 +2254,10 @@ def __call__(self, *args, **kwargs): def fill_scope_connectors(self): """ Fills missing scope connectors (i.e., "IN_#"/"OUT_#" on entry/exit nodes) according to data on the memlets. """ - for state in self.nodes(): - state.fill_scope_connectors() + for cf in self.all_cfgs_recursive(): + for block in cf.nodes(): + if isinstance(block, SDFGState): + block.fill_scope_connectors() def predecessor_state_transitions(self, state): """ Yields paths (lists of edges) that the SDFG can pass through @@ -2555,7 +2556,7 @@ def expand_library_nodes(self, recursive=True): including library nodes that expand to library nodes. """ - states = list(self.states()) + states = list(self.all_states_recursive()) while len(states) > 0: state = states.pop() expanded_something = False diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 912dc171c7..18caf9d80c 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -1169,7 +1169,7 @@ def to_json(self, parent=None): from dace.sdfg.scope import _scope_dict_to_ids # Create scope dictionary with a failsafe try: - scope_dict_regular = self.scope_dict() + scope_dict_regular = self.scope_children() scope_dict_ids = _scope_dict_to_ids(self, scope_dict_regular) scope_dict = {k: sorted(v) for k, v in sorted(scope_dict_ids.items())} except (RuntimeError, ValueError): @@ -2271,6 +2271,12 @@ def all_control_flow_blocks_recursive(self, recurse_into_sdfgs=True) -> Iterator for block in cfg.nodes(): yield block + def all_interstate_edges_recursive(self, recurse_into_sdfgs=True) -> Iterator[Edge['dace.sdfg.InterstateEdge']]: + """ Iterate over all interstate edges in this control flow graph. """ + for cfg in self.all_cfgs_recursive(recurse_into_sdfgs=recurse_into_sdfgs): + for edge in cfg.edges(): + yield edge + ################################################################### # Getters & setters, overrides diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index f6081efbcc..2552e1d417 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -13,7 +13,7 @@ from dace.sdfg.graph import MultiConnectorEdge from dace.sdfg.sdfg import SDFG from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import SDFGState, StateSubgraphView +from dace.sdfg.state import SDFGState, StateSubgraphView, LoopScopeBlock from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs, symbolic @@ -1246,6 +1246,31 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> return counter +def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: + # Avoid import loops + from dace.transformation.interstate import LoopScopeInline + + counter = 0 + blocks = [(n, p) for n, p in sdfg.all_nodes_recursive() if isinstance(n, LoopScopeBlock)] + + for block, graph in optional_progressbar(reversed(blocks), title='Inlining Loops', n=len(blocks), progress=progress): + id = block.sdfg.sdfg_id + + # We have to reevaluate every time due to changing IDs + block_id = graph.node_id(block) + + candidate = { + LoopScopeInline.block: block, + } + inliner = LoopScopeInline() + inliner.setup_match(graph, id, block_id, candidate, 0, override=True) + if inliner.can_be_applied(graph, 0, block.sdfg, permissive=permissive): + inliner.apply(graph, block.sdfg) + counter += 1 + + return counter + + def inline_sdfgs(sdfg: SDFG, permissive: bool = False, progress: bool = None, multistate: bool = True) -> int: """ Inlines all possible nested SDFGs (or sub-SDFGs) using an optimized diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index 0bd168751c..f3fc18e273 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -12,6 +12,7 @@ from .loop_unroll import LoopUnroll from .loop_peeling import LoopPeeling from .loop_to_map import LoopToMap +from .scope_inline import LoopScopeInline from .move_loop_into_map import MoveLoopIntoMap from .trivial_loop_elimination import TrivialLoopElimination from .multistate_inline import InlineMultistateSDFG diff --git a/dace/transformation/interstate/scope_inline.py b/dace/transformation/interstate/scope_inline.py new file mode 100644 index 0000000000..89d3a1bc5f --- /dev/null +++ b/dace/transformation/interstate/scope_inline.py @@ -0,0 +1,80 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Inline all scope blocks in SDFGs. """ + +from typing import Any, Set, Optional + +from dace.frontend.python import astutils +from dace.sdfg import SDFG, ControlFlowGraph, InterstateEdge, SDFGState +from dace.sdfg import utils as sdutil +from dace.sdfg.nodes import CodeBlock +from dace.sdfg.state import LoopScopeBlock, ScopeBlock +from dace.transformation import transformation + + +class LoopScopeInline(transformation.MultiStateTransformation): + """ + Inlines a loop scope block into a legacy-style state machine. + """ + + block = transformation.PatternNode(LoopScopeBlock) + + @staticmethod + def annotates_memlets(): + return False + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.block)] + + def can_be_applied(self, graph: ControlFlowGraph, expr_index: int, sdfg: SDFG, permissive: bool = False) -> bool: + return True + + def apply(self, graph: ControlFlowGraph, sdfg: SDFG) -> Optional[int]: + parent: ScopeBlock = graph + + internal_start = self.block.start_block + + # Construct the basic loop state structure. + init_state = parent.add_state(self.block.label + '_init') + for b_edge in parent.in_edges(self.block): + parent.add_edge(b_edge.src, init_state, b_edge.data) + parent.remove_edge(b_edge) + + guard_state = parent.add_state(self.block.label + '_guard') + init_edge = InterstateEdge() + if self.block.init_statement is not None: + init_edge.assignments = { + self.block.loop_variable: self.block.init_statement.as_string.rpartition('=')[2].strip() + } + parent.add_edge(init_state, guard_state, init_edge) + + end_state = parent.add_state(self.block.label + '_end') + parent.add_edge(guard_state, end_state, + InterstateEdge(condition=CodeBlock(astutils.negate_expr(self.block.scope_condition.code)))) + for a_edge in parent.out_edges(self.block): + parent.add_edge(end_state, a_edge.dst, a_edge.data) + parent.remove_edge(a_edge) + + last_loop_state = parent.add_state(self.block.label + '_loop') + loop_edge = InterstateEdge() + if self.block.update_statement is not None: + loop_edge.assignments = { + self.block.loop_variable: self.block.update_statement.as_string.rpartition('=')[2].strip() + } + parent.add_edge(last_loop_state, guard_state, loop_edge) + + to_connect: Set[SDFGState] = set() + for node in self.block.nodes(): + parent.add_node(node) + if self.block.out_degree(node) == 0: + to_connect.add(node) + for edge in self.block.edges(): + parent.add_edge(edge.src, edge.dst, edge.data) + + # Connect the loop states + parent.add_edge(guard_state, internal_start, + InterstateEdge(condition=self.block.scope_condition.as_string)) + for node in to_connect: + parent.add_edge(node, last_loop_state, InterstateEdge()) + + parent.remove_node(self.block)