From 4976e1635f350c644afaf210e36667f88def34bc Mon Sep 17 00:00:00 2001 From: Roman Cattaneo Date: Wed, 4 Dec 2024 17:31:16 +0100 Subject: [PATCH 1/2] Cleanup unused imports (#1803) Found two unused imports and an unused variable. From other code, I inferred that you follow the convention to prefix unused (return) variables with underscores. I noticed that running `yapf` (as suggested in the contribution guidelines) modified the files in places that I didn't touch. I thus separated the pure formatting changes in the first commit. As argued in PR https://github.com/spcl/dace/pull/1731, I think it would be beneficial (for the project) to enforce formatting as part of the CI. @phschaad not sure if you got to discuss this in the weekly DaCe meeting. I started a discussion page https://github.com/spcl/dace/discussions/1804 and I'm happy to contribute a corresponding workflow early next year (assuming we agree). --------- Co-authored-by: Roman Cattaneo <> --- dace/frontend/python/newast.py | 32 ++++++++++++++++++++------------ dace/frontend/python/parser.py | 22 ++++++++++++---------- 2 files changed, 32 insertions(+), 22 deletions(-) diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index 1cbb8e67c9..d2813371c9 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -1319,7 +1319,7 @@ def _views_to_data(state: SDFGState, nodes: List[dace.nodes.AccessNode]) -> List self.sdfg.replace_dict(repl_dict) propagate_states(self.sdfg) - for state, memlet, inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()): + for state, memlet, _inner_indices in itertools.chain(self.inputs.values(), self.outputs.values()): if state is not None and state.dynamic_executions: memlet.dynamic = True @@ -2366,8 +2366,11 @@ def visit_For(self, node: ast.For): init_expr='%s = %s' % (indices[0], astutils.unparse(ast_ranges[0][0])), update_expr=incr[indices[0]], inverted=False) - _, first_subblock, _, _ = self._recursive_visit(node.body, f'for_{node.lineno}', node.lineno, - extra_symbols=extra_syms, parent=loop_region, + _, first_subblock, _, _ = self._recursive_visit(node.body, + f'for_{node.lineno}', + node.lineno, + extra_symbols=extra_syms, + parent=loop_region, unconnected_last_block=False) loop_region.start_block = loop_region.node_id(first_subblock) self._connect_break_blocks(loop_region) @@ -2449,7 +2452,10 @@ def visit_While(self, node: ast.While): loop_region = self._add_loop_region(loop_cond, label=f'while_{node.lineno}', inverted=False) # Parse body - self._recursive_visit(node.body, f'while_{node.lineno}', node.lineno, parent=loop_region, + self._recursive_visit(node.body, + f'while_{node.lineno}', + node.lineno, + parent=loop_region, unconnected_last_block=False) if test_region is not None: @@ -2540,7 +2546,6 @@ def _has_loop_ancestor(self, node: ControlFlowBlock) -> bool: node = node.parent_graph return False - def visit_Break(self, node: ast.Break): if not self._has_loop_ancestor(self.cfg_target): raise DaceSyntaxError(self, node, "Break block outside loop region") @@ -2572,8 +2577,7 @@ def visit_If(self, node: ast.If): # Process 'else'/'elif' statements if len(node.orelse) > 0: - else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', - sdfg=self.sdfg) + else_body = ControlFlowRegion(f'{cond_block.label}_else_{node.orelse[0].lineno}', sdfg=self.sdfg) cond_block.add_branch(None, else_body) # Visit recursively self._recursive_visit(node.orelse, 'else', node.lineno, else_body, False) @@ -2934,7 +2938,6 @@ def _add_aug_assignment(self, wsqueezed = [i for i in range(len(wtarget_subset)) if i not in wsqz] rsqueezed = [i for i in range(len(rtarget_subset)) if i not in rsqz] - if (boolarr or indirect_indices or (sqz_wsub.size() == sqz_osub.size() and sqz_wsub.size() == sqz_rsub.size())): map_range = {i: rng for i, rng in all_idx_tuples} @@ -3358,8 +3361,11 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): new_data, rng = None, None dtype_keys = tuple(dtypes.dtype_to_typeclass().keys()) - if not (result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or - (isinstance(result, str) and any(result in x for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))): + if not ( + result in self.sdfg.symbols or symbolic.issymbolic(result) or isinstance(result, dtype_keys) or + (isinstance(result, str) and any( + result in x + for x in [self.sdfg.arrays, self.sdfg._pgrids, self.sdfg._subarrays, self.sdfg._rdistrarrays]))): raise DaceSyntaxError( self, node, "In assignments, the rhs may only be " "data, numerical/boolean constants " @@ -3467,7 +3473,9 @@ def _visit_assign(self, node, node_target, op, dtype=None, is_return=False): cname = self.sdfg.find_new_constant(f'__ind{i}_{true_name}') self.sdfg.add_constant(cname, carr) # Add constant to descriptor repository - self.sdfg.add_array(cname, carr.shape, dtypes.dtype_to_typeclass(carr.dtype.type), + self.sdfg.add_array(cname, + carr.shape, + dtypes.dtype_to_typeclass(carr.dtype.type), transient=True) if numpy.array(arr).dtype == numpy.bool_: boolarr = cname @@ -4769,7 +4777,7 @@ def visit_With(self, node: ast.With, is_async=False): evald = astutils.evalnode(node.items[0].context_expr, self.globals) if hasattr(evald, "name"): named_region_name: str = evald.name - else: + else: named_region_name = f"Named Region {node.lineno}" named_region = NamedRegion(named_region_name, debuginfo=self.current_lineinfo) self.cfg_target.add_node(named_region) diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index 20018effd0..b65e7c227d 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -59,9 +59,10 @@ def _get_locals_and_globals(f): result.update(f.__globals__) # grab the free variables (i.e. locals) if f.__closure__ is not None: - result.update( - {k: v - for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__])}) + result.update({ + k: v + for k, v in zip(f.__code__.co_freevars, [_get_cell_contents_or_none(x) for x in f.__closure__]) + }) return result @@ -142,6 +143,7 @@ def infer_symbols_from_datadescriptor(sdfg: SDFG, class DaceProgram(pycommon.SDFGConvertible): """ A data-centric program object, obtained by decorating a function with ``@dace.program``. """ + def __init__(self, f, args, @@ -405,9 +407,10 @@ def _create_sdfg_args(self, sdfg: SDFG, args: Tuple[Any], kwargs: Dict[str, Any] # Update arguments with symbols in data shapes result.update( - infer_symbols_from_datadescriptor( - sdfg, {k: create_datadescriptor(v) - for k, v in result.items() if k not in self.constant_args})) + infer_symbols_from_datadescriptor(sdfg, { + k: create_datadescriptor(v) + for k, v in result.items() if k not in self.constant_args + })) return result def __call__(self, *args, **kwargs): @@ -487,9 +490,6 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF :param validate: If True, validates the resulting SDFG after creation. :return: The generated SDFG object. """ - # Avoid import loop - from dace.transformation.passes import scalar_to_symbol as scal2sym - from dace.transformation import helpers as xfh # Obtain DaCe program as SDFG sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) @@ -812,7 +812,9 @@ def get_program_hash(self, *args, **kwargs) -> cached_program.ProgramCacheKey: _, key = self._load_sdfg(None, *args, **kwargs) return key - def _generate_pdp(self, args: Tuple[Any], kwargs: Dict[str, Any], + def _generate_pdp(self, + args: Tuple[Any], + kwargs: Dict[str, Any], simplify: Optional[bool] = None) -> Tuple[SDFG, bool]: """ Generates the parsed AST representation of a DaCe program. From 896a1e189b92bec5c78fa9e2f0f199c074561d40 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Thu, 12 Dec 2024 15:50:56 +0100 Subject: [PATCH 2/2] Complete Transition to Control Flow Regions (#1676) This PR completes the transition to hierarchical control flow regions in DaCe. By nature of the significance in change that is brought through the transition to hierarchical control flow regions, this PR is rather substantial. An exhaustive listing of all adaptations is not feasible, but the most important changes and adaptations are listed below: - [x] Change the default of the Python frontend to generate SDFGs using experimental CFG blocks. A subsequent PR will remove the option to _not_ use experimental CFG blocks entirely, but this was left to a separate PR to avoid growing this one even more than it already has. - [x] The option to write a pass or transformation that is _not_ compatible with experimental blocks has been removed, forcing new transformations and passes to consider them in their design. - [x] Simplifications to loop related transformations by adapting explicit loop regions. - [x] Add a new pass base type, `ControlFlowRegionPass`: This pass works like `StatePass` or `ScopePass`, and can be extended to write a pass that applies recursively to each control flow region of an SDFG. An option can be set to either apply bottom-up or top-down. - [x] A pass has been added to dead code elimination to prune empty or falsy conditional branches. - [x] Include a control flow raising pass in the simplification pipeline, ensuring that even SDFGs generated without the explicit use of experimental blocks are raised to the new style SDFGs. - [x] Adapt all passes and transformations currently in main DaCe to work with SDFGs containing experimental CFG blocks. - [x] Almost all transformations and analyses now _expect_ that experimental blocks are used for regular / reducible control flow, meaning some control flow analyses have been ditched to improve overall performance and reliability of DaCe, and remove redundancy. - [x] Ensure all compiler backends correctly handle experimental blocks. - [x] Adapt state propagation into a separate pass that has been made to use experimental blocks. Legacy state propagation has been left in for now, including tests that ensure it works as intended, to avoid making this PR even larger. However, it is planned to remove this in a subsequent PR soon. - [x] A block fusion pass has been added to the simplification pipeline. This operates similar to StateFusion, but fuses no-op general control flow blocks (empty states or control flow blocks) with other control flow blocks. This reduces the number of nodes and edges in CFGs further. - [x] Numerous bugfixes with respect to experimental blocks and analyses around them, thanks to the ability to now run the entire CI pipeline with them. Note: The FV3 integration test fails and will continue to fail with this PR, since GT4Py cartesian, which is used by PyFV3, does not consider experimental blocks in their design. Since DaCe v1.0.0 will be released _without_ this PR in it, my suggestion is to limit the application of the FV3 integration and regression tests to PRs which are made to a specific v1.0.0 maintenance branch, which is used for fixes to v1.0.0. --------- Co-authored-by: Philip Mueller Co-authored-by: Tal Ben-Nun --- .github/workflows/pyFV3-ci.yml | 11 +- .gitignore | 2 + dace/codegen/control_flow.py | 67 ++- .../codegen/instrumentation/data/data_dump.py | 66 +-- dace/codegen/instrumentation/gpu_events.py | 80 +-- dace/codegen/instrumentation/likwid.py | 44 +- dace/codegen/instrumentation/papi.py | 91 +-- dace/codegen/instrumentation/provider.py | 39 +- dace/codegen/instrumentation/timer.py | 42 +- dace/codegen/targets/cpu.py | 18 +- dace/codegen/targets/cuda.py | 24 +- dace/codegen/targets/fpga.py | 22 +- dace/codegen/targets/framecode.py | 58 +- dace/codegen/targets/snitch.py | 4 +- dace/frontend/fortran/fortran_parser.py | 18 +- dace/frontend/python/astutils.py | 35 +- dace/frontend/python/interface.py | 7 +- dace/frontend/python/newast.py | 6 +- dace/frontend/python/parser.py | 9 +- dace/sdfg/analysis/cfg.py | 9 +- dace/sdfg/analysis/cutout.py | 2 +- .../analysis/schedule_tree/sdfg_to_tree.py | 5 +- .../analysis/writeset_underapproximation.py | 2 +- dace/sdfg/graph.py | 2 +- dace/sdfg/infer_types.py | 1 - dace/sdfg/nodes.py | 4 +- dace/sdfg/performance_evaluation/helpers.py | 4 +- dace/sdfg/propagation.py | 40 +- dace/sdfg/replace.py | 69 +-- dace/sdfg/sdfg.py | 26 +- dace/sdfg/state.py | 455 ++++++++++----- dace/sdfg/utils.py | 322 +++++++---- dace/sdfg/validation.py | 24 +- dace/transformation/__init__.py | 2 +- dace/transformation/auto/auto_optimize.py | 5 +- dace/transformation/dataflow/__init__.py | 2 +- .../dataflow/double_buffering.py | 7 +- dace/transformation/dataflow/map_fission.py | 29 +- dace/transformation/dataflow/map_for_loop.py | 25 +- dace/transformation/dataflow/mpi.py | 6 +- .../transformation/dataflow/otf_map_fusion.py | 13 +- dace/transformation/dataflow/tiling.py | 2 +- dace/transformation/helpers.py | 418 +++++++------- dace/transformation/interstate/__init__.py | 1 + .../transformation/interstate/block_fusion.py | 102 ++++ .../interstate/fpga_transform_sdfg.py | 15 +- .../interstate/fpga_transform_state.py | 36 +- .../interstate/gpu_transform_sdfg.py | 162 +++--- .../interstate/loop_detection.py | 2 +- .../transformation/interstate/loop_lifting.py | 4 +- .../transformation/interstate/loop_peeling.py | 157 ++---- dace/transformation/interstate/loop_to_map.py | 493 ++++++---------- dace/transformation/interstate/loop_unroll.py | 161 +++--- .../interstate/move_assignment_outside_if.py | 60 +- .../interstate/move_loop_into_map.py | 116 ++-- .../interstate/multistate_inline.py | 66 ++- .../transformation/interstate/sdfg_nesting.py | 40 +- .../interstate/state_elimination.py | 35 +- .../transformation/interstate/state_fusion.py | 2 +- .../state_fusion_with_happens_before.py | 41 +- .../interstate/trivial_loop_elimination.py | 72 +-- dace/transformation/pass_pipeline.py | 85 ++- .../passes/analysis/__init__.py | 1 + .../passes/analysis/analysis.py | 509 ++++++++++++----- .../passes/analysis/loop_analysis.py | 27 +- .../passes/array_elimination.py | 16 +- .../passes/consolidate_edges.py | 4 +- .../passes/constant_propagation.py | 332 ++++++++--- .../passes/dead_dataflow_elimination.py | 27 +- .../passes/dead_state_elimination.py | 136 +++-- dace/transformation/passes/fusion_inline.py | 87 ++- dace/transformation/passes/optional_arrays.py | 42 +- .../transformation/passes/pattern_matching.py | 30 +- dace/transformation/passes/prune_symbols.py | 42 +- .../passes/reference_reduction.py | 15 +- dace/transformation/passes/scalar_fission.py | 4 +- .../transformation/passes/scalar_to_symbol.py | 12 +- .../simplification/control_flow_raising.py | 132 ++++- .../prune_empty_conditional_branches.py | 72 +++ dace/transformation/passes/simplify.py | 51 +- dace/transformation/passes/symbol_ssa.py | 18 +- dace/transformation/passes/transient_reuse.py | 4 +- dace/transformation/subgraph/composite.py | 21 +- dace/transformation/subgraph/expansion.py | 5 +- .../subgraph/gpu_persistent_fusion.py | 39 +- .../transformation/subgraph/stencil_tiling.py | 26 +- .../subgraph/subgraph_fusion.py | 23 +- dace/transformation/testing.py | 6 +- dace/transformation/transformation.py | 26 +- dace/viewer/webclient | 2 +- doc/frontend/parsing.rst | 8 +- tests/codegen/control_flow_detection_test.py | 5 +- tests/codegen/data_instrumentation_test.py | 14 +- tests/constant_array_test.py | 8 +- tests/fortran/array_test.py | 22 +- tests/fortran/fortran_language_test.py | 2 +- tests/fortran/fortran_loops_test.py | 2 +- tests/inlining_test.py | 22 +- tests/passes/constant_propagation_test.py | 18 +- tests/passes/dead_code_elimination_test.py | 62 ++- tests/passes/scalar_fission_test.py | 31 +- tests/passes/scalar_to_symbol_test.py | 45 +- ...calar_write_shadow_scopes_analysis_test.py | 77 ++- .../control_flow_raising_test.py | 43 +- .../prune_empty_conditional_branches_test.py | 105 ++++ .../symbol_write_scopes_analysis_test.py | 4 +- .../writeset_underapproximation_test.py | 25 + tests/python_frontend/augassign_wcr_test.py | 12 +- .../conditional_regions_test.py | 6 +- .../python_frontend/function_regions_test.py | 23 +- tests/python_frontend/loop_regions_test.py | 48 +- .../multiple_nested_sdfgs_test.py | 4 +- tests/python_frontend/named_region_test.py | 25 +- tests/schedule_tree/nesting_test.py | 7 +- tests/schedule_tree/schedule_test.py | 8 +- tests/sdfg/conditional_region_test.py | 1 + tests/sdfg/control_flow_inline_test.py | 16 +- tests/sdfg/free_symbols_test.py | 4 +- tests/sdfg/loop_region_test.py | 12 +- tests/sdfg/schedule_inference_test.py | 24 +- tests/sdfg/state_test.py | 8 +- tests/sdfg/work_depth_test.py | 25 +- tests/state_propagation_test.py | 136 ++++- .../interstate/loop_lifting_test.py | 8 +- tests/transformations/loop_detection_test.py | 3 +- .../transformations/loop_manipulation_test.py | 16 +- tests/transformations/loop_to_map_test.py | 77 ++- .../move_assignment_outside_if_test.py | 67 ++- .../move_loop_into_map_test.py | 525 +++++++++--------- tests/transformations/nest_subgraph_test.py | 8 +- tests/transformations/redundant_copy_test.py | 1 - 131 files changed, 4288 insertions(+), 2745 deletions(-) create mode 100644 dace/transformation/interstate/block_fusion.py create mode 100644 dace/transformation/passes/simplification/prune_empty_conditional_branches.py create mode 100644 tests/passes/simplification/prune_empty_conditional_branches_test.py diff --git a/.github/workflows/pyFV3-ci.yml b/.github/workflows/pyFV3-ci.yml index 852b887cdb..2f587e9894 100644 --- a/.github/workflows/pyFV3-ci.yml +++ b/.github/workflows/pyFV3-ci.yml @@ -1,12 +1,17 @@ name: NASA/NOAA pyFV3 repository build test +# Temporarily disabled for main, and instead applied to a specific DaCe v1 maintenance branch (v1/maintenance). Once +# the FV3 bridge has been adapted to DaCe v1, this will need to be reverted back to apply to main. on: push: - branches: [ main, ci-fix ] + #branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] pull_request: - branches: [ main, ci-fix ] + #branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] merge_group: - branches: [ main, ci-fix ] + #branches: [ main, ci-fix ] + branches: [ v1/maintenance, ci-fix ] defaults: run: diff --git a/.gitignore b/.gitignore index 7209622916..03c801a68f 100644 --- a/.gitignore +++ b/.gitignore @@ -141,6 +141,8 @@ src.VC.VC.opendb # DaCe .dacecache/ +# Ignore dacecache if added as a symlink +.dacecache out.sdfg *.out results.log diff --git a/dace/codegen/control_flow.py b/dace/codegen/control_flow.py index f5559984e7..5928cc71f2 100644 --- a/dace/codegen/control_flow.py +++ b/dace/codegen/control_flow.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Various classes to facilitate the code generation of structured control flow elements (e.g., ``for``, ``if``, ``while``) from state machines in SDFGs. @@ -62,8 +62,8 @@ import sympy as sp from dace import dtypes from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, - ReturnBlock, SDFGState) +from dace.sdfg.state import (BreakBlock, ConditionalBlock, ContinueBlock, ControlFlowBlock, ControlFlowRegion, + LoopRegion, ReturnBlock, SDFGState) from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.graph import Edge from dace.properties import CodeBlock @@ -200,7 +200,10 @@ class BreakCFBlock(ControlFlow): block: BreakBlock def as_cpp(self, codegen, symbols) -> str: - return 'break;\n' + cfg = self.block.parent_graph + expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label) + expr += 'break;\n' + return expr @property def first_block(self) -> BreakBlock: @@ -214,7 +217,10 @@ class ContinueCFBlock(ControlFlow): block: ContinueBlock def as_cpp(self, codegen, symbols) -> str: - return 'continue;\n' + cfg = self.block.parent_graph + expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label) + expr += 'continue;\n' + return expr @property def first_block(self) -> ContinueBlock: @@ -228,7 +234,10 @@ class ReturnCFBlock(ControlFlow): block: ReturnBlock def as_cpp(self, codegen, symbols) -> str: - return 'return;\n' + cfg = self.block.parent_graph + expr = '__state_{}_{}:;\n'.format(cfg.cfg_id, self.block.label) + expr += 'return;\n' + return expr @property def first_block(self) -> ReturnBlock: @@ -316,7 +325,13 @@ def as_cpp(self, codegen, symbols) -> str: # One unconditional edge if (len(out_edges) == 1 and out_edges[0].data.is_unconditional()): continue - expr += f'goto __state_exit_{sdfg.cfg_id};\n' + if self.region: + expr += f'goto __state_exit_{self.region.cfg_id};\n' + else: + expr += f'goto __state_exit_{sdfg.cfg_id};\n' + + if self.region and not isinstance(self.region, SDFG): + expr += f'__state_exit_{self.region.cfg_id}:;\n' return expr @@ -536,10 +551,14 @@ def as_cpp(self, codegen, symbols) -> str: expr = '' if self.loop.update_statement and self.loop.init_statement and self.loop.loop_variable: - init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + lsyms = {} + lsyms.update(symbols) + if codegen.dispatcher.defined_vars.has(self.loop.loop_variable) and not self.loop.loop_variable in lsyms: + lsyms[self.loop.loop_variable] = codegen.dispatcher.defined_vars.get(self.loop.loop_variable)[1] + init = unparse_interstate_edge(self.loop.init_statement.code[0], sdfg, codegen=codegen, symbols=lsyms) init = init.strip(';') - update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=symbols) + update = unparse_interstate_edge(self.loop.update_statement.code[0], sdfg, codegen=codegen, symbols=lsyms) update = update.strip(';') if self.loop.inverted: @@ -571,6 +590,8 @@ def as_cpp(self, codegen, symbols) -> str: expr += _clean_loop_body(self.body.as_cpp(codegen, symbols)) expr += '\n}\n' + expr += f'__state_exit_{self.loop.cfg_id}:;\n' + return expr @property @@ -1018,21 +1039,16 @@ def _structured_control_flow_traversal_with_regions(cfg: ControlFlowRegion, start: Optional[ControlFlowBlock] = None, stop: Optional[ControlFlowBlock] = None, generate_children_of: Optional[ControlFlowBlock] = None, - branch_merges: Optional[Dict[ControlFlowBlock, - ControlFlowBlock]] = None, ptree: Optional[Dict[ControlFlowBlock, ControlFlowBlock]] = None, visited: Optional[Set[ControlFlowBlock]] = None): - if branch_merges is None: - branch_merges = cfg_analysis.branch_merges(cfg) - if ptree is None: ptree = cfg_analysis.block_parent_tree(cfg, with_loops=False) start = start if start is not None else cfg.start_block - def make_empty_block(): + def make_empty_block(region): return GeneralBlock(dispatch_state, parent_block, - last_block=False, region=None, elements=[], gotos_to_ignore=[], + last_block=False, region=region, elements=[], gotos_to_ignore=[], gotos_to_break=[], gotos_to_continue=[], assignments_to_ignore=[], sequential=True) # Traverse states in custom order @@ -1059,18 +1075,18 @@ def make_empty_block(): cfg_block = GeneralConditionalScope(dispatch_state, parent_block, False, node, []) for cond, branch in node.branches: if branch is not None: - body = make_empty_block() + body = make_empty_block(branch) body.parent = cfg_block _structured_control_flow_traversal_with_regions(branch, dispatch_state, body) cfg_block.branch_bodies.append((cond, body)) elif isinstance(node, ControlFlowRegion): if isinstance(node, LoopRegion): - body = make_empty_block() + body = make_empty_block(node) cfg_block = GeneralLoopScope(dispatch_state, parent_block, False, node, body) body.parent = cfg_block _structured_control_flow_traversal_with_regions(node, dispatch_state, body) else: - cfg_block = make_empty_block() + cfg_block = make_empty_block(node) cfg_block.region = node _structured_control_flow_traversal_with_regions(node, dispatch_state, cfg_block) @@ -1095,13 +1111,14 @@ def make_empty_block(): return visited - {stop} -def structured_control_flow_tree_with_regions(sdfg: SDFG, dispatch_state: Callable[[SDFGState], str]) -> ControlFlow: +def structured_control_flow_tree_with_regions(cfg: ControlFlowRegion, + dispatch_state: Callable[[SDFGState], str]) -> ControlFlow: """ - Returns a structured control-flow tree (i.e., with constructs such as branches and loops) from an SDFG based on the + Returns a structured control-flow tree (i.e., with constructs such as branches and loops) from a CFG based on the control flow regions it contains. - :param sdfg: The SDFG to iterate over. - :return: Control-flow block representing the entire SDFG. + :param cfg: The graph to iterate over. + :return: Control-flow block representing the entire graph. """ root_block = GeneralBlock(dispatch_state=dispatch_state, parent=None, @@ -1113,7 +1130,7 @@ def structured_control_flow_tree_with_regions(sdfg: SDFG, dispatch_state: Callab gotos_to_break=[], assignments_to_ignore=[], sequential=True) - _structured_control_flow_traversal_with_regions(sdfg, dispatch_state, root_block) + _structured_control_flow_traversal_with_regions(cfg, dispatch_state, root_block) _reset_block_parents(root_block) return root_block @@ -1127,7 +1144,7 @@ def structured_control_flow_tree(sdfg: SDFG, dispatch_state: Callable[[SDFGState :param sdfg: The SDFG to iterate over. :return: Control-flow block representing the entire SDFG. """ - if sdfg.root_sdfg.using_experimental_blocks: + if sdfg.root_sdfg.using_explicit_control_flow: return structured_control_flow_tree_with_regions(sdfg, dispatch_state) # Avoid import loops diff --git a/dace/codegen/instrumentation/data/data_dump.py b/dace/codegen/instrumentation/data/data_dump.py index 5fc487f94d..e8c6236a01 100644 --- a/dace/codegen/instrumentation/data/data_dump.py +++ b/dace/codegen/instrumentation/data/data_dump.py @@ -1,10 +1,10 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. -from dace import config, data as dt, dtypes, registry, SDFG +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from dace import data as dt, dtypes, registry, SDFG from dace.sdfg import nodes, is_devicelevel_gpu from dace.codegen.prettycode import CodeIOStream from dace.codegen.instrumentation.provider import InstrumentationProvider from dace.sdfg.scope import is_devicelevel_fpga -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.codegen import common from dace.codegen import cppunparse from dace.codegen.targets import cpp @@ -101,7 +101,8 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod if sdfg.parent is None: sdfg.append_exit_code('delete __state->serializer;\n') - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream): + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream): if state.symbol_instrument == dtypes.DataInstrumentationType.No_Instrumentation: return @@ -119,17 +120,17 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea condition_preamble = f'if ({cond_string})' + ' {' condition_postamble = '}' - state_id = sdfg.node_id(state) - local_stream.write(condition_preamble, sdfg, state_id) + state_id = cfg.node_id(state) + local_stream.write(condition_preamble, cfg, state_id) defined_symbols = state.defined_symbols() for sym, _ in defined_symbols.items(): local_stream.write( - f'__state->serializer->save_symbol("{sym}", "{state_id}", {cpp.sym2cpp(sym)});\n', sdfg, state_id + f'__state->serializer->save_symbol("{sym}", "{state_id}", {cpp.sym2cpp(sym)});\n', cfg, state_id ) - local_stream.write(condition_postamble, sdfg, state_id) + local_stream.write(condition_postamble, cfg, state_id) - def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream): + def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.AccessNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream): from dace.codegen.dispatcher import DefinedType # Avoid import loop if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node): @@ -159,9 +160,9 @@ def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, oute ptrname = '&' + ptrname # Create UUID - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(node) - uuid = f'{sdfg.cfg_id}_{state_id}_{node_id}' + uuid = f'{cfg.cfg_id}_{state_id}_{node_id}' # Get optional pre/postamble for instrumenting device data preamble, postamble = '', '' @@ -174,13 +175,13 @@ def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, oute strides = ', '.join(cpp.sym2cpp(s) for s in desc.strides) # Write code - inner_stream.write(condition_preamble, sdfg, state_id, node_id) - inner_stream.write(preamble, sdfg, state_id, node_id) + inner_stream.write(condition_preamble, cfg, state_id, node_id) + inner_stream.write(preamble, cfg, state_id, node_id) inner_stream.write( f'__state->serializer->save({ptrname}, {cpp.sym2cpp(desc.total_size - desc.start_offset)}, ' - f'"{node.data}", "{uuid}", {shape}, {strides});\n', sdfg, state_id, node_id) - inner_stream.write(postamble, sdfg, state_id, node_id) - inner_stream.write(condition_postamble, sdfg, state_id, node_id) + f'"{node.data}", "{uuid}", {shape}, {strides});\n', cfg, state_id, node_id) + inner_stream.write(postamble, cfg, state_id, node_id) + inner_stream.write(condition_postamble, cfg, state_id, node_id) @registry.autoregister_params(type=dtypes.DataInstrumentationType.Restore) @@ -216,7 +217,8 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod if sdfg.parent is None: sdfg.append_exit_code('delete __state->serializer;\n') - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream): + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, + global_stream: CodeIOStream): if state.symbol_instrument == dtypes.DataInstrumentationType.No_Instrumentation: return @@ -234,18 +236,18 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea condition_preamble = f'if ({cond_string})' + ' {' condition_postamble = '}' - state_id = sdfg.node_id(state) - local_stream.write(condition_preamble, sdfg, state_id) + state_id = state.block_id + local_stream.write(condition_preamble, cfg, state_id) defined_symbols = state.defined_symbols() for sym, sym_type in defined_symbols.items(): local_stream.write( f'{cpp.sym2cpp(sym)} = __state->serializer->restore_symbol<{sym_type.ctype}>("{sym}", "{state_id}");\n', - sdfg, state_id + cfg, state_id ) - local_stream.write(condition_postamble, sdfg, state_id) + local_stream.write(condition_postamble, cfg, state_id) - def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream): + def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.AccessNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream): from dace.codegen.dispatcher import DefinedType # Avoid import loop if is_devicelevel_gpu(sdfg, state, node) or is_devicelevel_fpga(sdfg, state, node): @@ -275,21 +277,21 @@ def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.AccessNode, ou ptrname = '&' + ptrname # Create UUID - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(node) - uuid = f'{sdfg.cfg_id}_{state_id}_{node_id}' + uuid = f'{cfg.cfg_id}_{state_id}_{node_id}' # Get optional pre/postamble for instrumenting device data preamble, postamble = '', '' if desc.storage == dtypes.StorageType.GPU_Global: - self._setup_gpu_runtime(sdfg, global_stream) + self._setup_gpu_runtime(cfg, global_stream) preamble, postamble, ptrname = self._generate_copy_to_device(node, desc, ptrname) # Write code - inner_stream.write(condition_preamble, sdfg, state_id, node_id) - inner_stream.write(preamble, sdfg, state_id, node_id) + inner_stream.write(condition_preamble, cfg, state_id, node_id) + inner_stream.write(preamble, cfg, state_id, node_id) inner_stream.write( f'__state->serializer->restore({ptrname}, {cpp.sym2cpp(desc.total_size - desc.start_offset)}, ' - f'"{node.data}", "{uuid}");\n', sdfg, state_id, node_id) - inner_stream.write(postamble, sdfg, state_id, node_id) - inner_stream.write(condition_postamble, sdfg, state_id, node_id) + f'"{node.data}", "{uuid}");\n', cfg, state_id, node_id) + inner_stream.write(postamble, cfg, state_id, node_id) + inner_stream.write(condition_postamble, cfg, state_id, node_id) diff --git a/dace/codegen/instrumentation/gpu_events.py b/dace/codegen/instrumentation/gpu_events.py index cfd5a1cbb3..6e0d483a43 100644 --- a/dace/codegen/instrumentation/gpu_events.py +++ b/dace/codegen/instrumentation/gpu_events.py @@ -6,7 +6,7 @@ from dace.codegen import common from dace.codegen.instrumentation.provider import InstrumentationProvider from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState @registry.autoregister_params(type=dtypes.InstrumentationType.GPU_Events) @@ -53,8 +53,8 @@ def _record_event(self, id, stream): streamstr = f'__state->gpu_context->streams[{stream}]' return '%sEventRecord(__dace_ev_%s, %s);' % (self.backend, id, streamstr) - def _report(self, timer_name: str, sdfg: SDFG = None, state: SDFGState = None, node: nodes.Node = None): - idstr = self._idstr(sdfg, state, node) + def _report(self, timer_name: str, cfg: ControlFlowRegion = None, state: SDFGState = None, node: nodes.Node = None): + idstr = self._idstr(cfg, state, node) state_id = -1 node_id = -1 @@ -73,12 +73,12 @@ def _report(self, timer_name: str, sdfg: SDFG = None, state: SDFGState = None, n id=idstr, timer_name=timer_name, backend=self.backend, - cfg_id=sdfg.cfg_id, + cfg_id=cfg.cfg_id, state_id=state_id, node_id=node_id) # Code generation hooks - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) # Create GPU events for each instrumented scope in the state @@ -86,84 +86,84 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea if isinstance(node, (nodes.CodeNode, nodes.EntryNode)): s = (self._get_sobj(node) if isinstance(node, nodes.EntryNode) else node) if s.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = self._idstr(sdfg, state, node) - local_stream.write(self._create_event('b' + idstr), sdfg, state_id, node) - local_stream.write(self._create_event('e' + idstr), sdfg, state_id, node) + idstr = self._idstr(cfg, state, node) + local_stream.write(self._create_event('b' + idstr), cfg, state_id, node) + local_stream.write(self._create_event('e' + idstr), cfg, state_id, node) # Create and record a CUDA/HIP event for the entire state if state.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = 'b' + self._idstr(sdfg, state, None) - local_stream.write(self._create_event(idstr), sdfg, state_id) - local_stream.write(self._record_event(idstr, 0), sdfg, state_id) - idstr = 'e' + self._idstr(sdfg, state, None) - local_stream.write(self._create_event(idstr), sdfg, state_id) + idstr = 'b' + self._idstr(cfg, state, None) + local_stream.write(self._create_event(idstr), cfg, state_id) + local_stream.write(self._record_event(idstr, 0), cfg, state_id) + idstr = 'e' + self._idstr(cfg, state, None) + local_stream.write(self._create_event(idstr), cfg, state_id) - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) # Record and measure state stream event if state.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = self._idstr(sdfg, state, None) - local_stream.write(self._record_event('e' + idstr, 0), sdfg, state_id) - local_stream.write(self._report('State %s' % state.label, sdfg, state), sdfg, state_id) - local_stream.write(self._destroy_event('b' + idstr), sdfg, state_id) - local_stream.write(self._destroy_event('e' + idstr), sdfg, state_id) + idstr = self._idstr(cfg, state, None) + local_stream.write(self._record_event('e' + idstr, 0), cfg, state_id) + local_stream.write(self._report('State %s' % state.label, cfg, state), cfg, state_id) + local_stream.write(self._destroy_event('b' + idstr), cfg, state_id) + local_stream.write(self._destroy_event('e' + idstr), cfg, state_id) # Destroy CUDA/HIP events for scopes in the state for node in state.nodes(): if isinstance(node, (nodes.CodeNode, nodes.EntryNode)): s = (self._get_sobj(node) if isinstance(node, nodes.EntryNode) else node) if s.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = self._idstr(sdfg, state, node) - local_stream.write(self._destroy_event('b' + idstr), sdfg, state_id, node) - local_stream.write(self._destroy_event('e' + idstr), sdfg, state_id, node) + idstr = self._idstr(cfg, state, node) + local_stream.write(self._destroy_event('b' + idstr), cfg, state_id, node) + local_stream.write(self._destroy_event('e' + idstr), cfg, state_id, node) - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.GPU_Events: if s.schedule != dtypes.ScheduleType.GPU_Device: raise TypeError('GPU Event instrumentation only applies to ' 'GPU_Device map scopes') - idstr = 'b' + self._idstr(sdfg, state, node) + idstr = 'b' + self._idstr(cfg, state, node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: state_id = state.parent_graph.node_id(state) entry_node = state.entry_node(node) s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.GPU_Events: - idstr = 'e' + self._idstr(sdfg, state, entry_node) + idstr = 'e' + self._idstr(cfg, state, entry_node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) - outer_stream.write(self._report('%s %s' % (type(s).__name__, s.label), sdfg, state, entry_node), sdfg, + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) + outer_stream.write(self._report('%s %s' % (type(s).__name__, s.label), cfg, state, entry_node), cfg, state_id, node) - def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if (not isinstance(node, nodes.CodeNode) or is_devicelevel_gpu(sdfg, state, node)): return # Only run for host nodes # TODO(later): Implement "clock64"-based GPU counters if node.instrument == dtypes.InstrumentationType.GPU_Events: state_id = state.parent_graph.node_id(state) - idstr = 'b' + self._idstr(sdfg, state, node) + idstr = 'b' + self._idstr(cfg, state, node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) - def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if (not isinstance(node, nodes.Tasklet) or is_devicelevel_gpu(sdfg, state, node)): return # Only run for host nodes # TODO(later): Implement "clock64"-based GPU counters if node.instrument == dtypes.InstrumentationType.GPU_Events: state_id = state.parent_graph.node_id(state) - idstr = 'e' + self._idstr(sdfg, state, node) + idstr = 'e' + self._idstr(cfg, state, node) stream = getattr(node, '_cuda_stream', -1) - outer_stream.write(self._record_event(idstr, stream), sdfg, state_id, node) - outer_stream.write(self._report('%s %s' % (type(node).__name__, node.label), sdfg, state, node), sdfg, + outer_stream.write(self._record_event(idstr, stream), cfg, state_id, node) + outer_stream.write(self._report('%s %s' % (type(node).__name__, node.label), cfg, state, node), cfg, state_id, node) diff --git a/dace/codegen/instrumentation/likwid.py b/dace/codegen/instrumentation/likwid.py index 8d1c9e3b71..bd9ffe63a7 100644 --- a/dace/codegen/instrumentation/likwid.py +++ b/dace/codegen/instrumentation/likwid.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Implements the LIKWID counter performance instrumentation provider. Used for collecting CPU performance counters. """ @@ -15,7 +15,7 @@ from dace.config import Config from dace.sdfg import nodes from dace.sdfg.sdfg import SDFG -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.transformation import helpers as xfh @@ -213,13 +213,13 @@ def on_sdfg_end(self, sdfg, local_stream, global_stream): ''' self.codegen._exitcode.write(exit_code, sdfg) - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_CPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -250,13 +250,13 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea ''' local_stream.write(marker_code) - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_CPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -269,8 +269,8 @@ def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, ''' local_stream.write(marker_code) - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used or node.instrument != dace.InstrumentationType.LIKWID_CPU: return @@ -279,7 +279,7 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou elif node.schedule not in LIKWIDInstrumentationCPU.perf_whitelist_schedules: raise TypeError("Unsupported schedule on scope") - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(node) region = f"scope_{cfg_id}_{state_id}_{node_id}" @@ -296,13 +296,13 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou ''' outer_stream.write(marker_code) - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: entry_node = state.entry_node(node) if not self._likwid_used or entry_node.instrument != dace.InstrumentationType.LIKWID_CPU: return - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(entry_node) region = f"scope_{cfg_id}_{state_id}_{node_id}" @@ -405,13 +405,13 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod ''' self.codegen._exitcode.write(exit_code, sdfg) - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_GPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -428,13 +428,13 @@ def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStrea ''' local_stream.write(marker_code) - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used: return if state.instrument == dace.InstrumentationType.LIKWID_GPU: - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = -1 region = f"state_{cfg_id}_{state_id}_{node_id}" @@ -444,8 +444,8 @@ def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, ''' local_stream.write(marker_code) - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: if not self._likwid_used or node.instrument != dace.InstrumentationType.LIKWID_GPU: return @@ -454,7 +454,7 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou elif node.schedule not in LIKWIDInstrumentationGPU.perf_whitelist_schedules: raise TypeError("Unsupported schedule on scope") - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(node) region = f"scope_{cfg_id}_{state_id}_{node_id}" @@ -471,13 +471,13 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou ''' outer_stream.write(marker_code) - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: entry_node = state.entry_node(node) if not self._likwid_used or entry_node.instrument != dace.InstrumentationType.LIKWID_GPU: return - cfg_id = state.parent_graph.cfg_id + cfg_id = cfg.cfg_id state_id = state.block_id node_id = state.node_id(entry_node) region = f"scope_{cfg_id}_{state_id}_{node_id}" diff --git a/dace/codegen/instrumentation/papi.py b/dace/codegen/instrumentation/papi.py index 4885611408..ac2f6aafb7 100644 --- a/dace/codegen/instrumentation/papi.py +++ b/dace/codegen/instrumentation/papi.py @@ -103,19 +103,19 @@ def on_sdfg_end(self, sdfg, local_stream, global_stream): local_stream.write('__perf_store.flush();', sdfg) - def on_state_begin(self, sdfg, state, local_stream, global_stream): + def on_state_begin(self, sdfg, cfg, state, local_stream, global_stream): if not self._papi_used: return if state.instrument == dace.InstrumentationType.PAPI_Counters: - uid = _unified_id(-1, sdfg.node_id(state)) + uid = _unified_id(-1, cfg.node_id(state)) local_stream.write("__perf_store.markSuperSectionStart(%d);" % uid) - def on_copy_begin(self, sdfg, state, src_node, dst_node, edge, local_stream, global_stream, copy_shape, src_strides, - dst_strides): + def on_copy_begin(self, sdfg, cfg, state, src_node, dst_node, edge, local_stream, global_stream, copy_shape, + src_strides, dst_strides): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) memlet = edge.data # For perfcounters, we have to make sure that: @@ -153,7 +153,7 @@ def on_copy_begin(self, sdfg, state, src_node, dst_node, edge, local_stream, glo # would be a section with 1 entry)) local_stream.write( self.perf_section_start_string(node_id, copy_size, copy_size), - sdfg, + cfg, state_id, [src_node, dst_node], ) @@ -169,34 +169,34 @@ def on_copy_begin(self, sdfg, state, src_node, dst_node, edge, local_stream, glo unique_id=unique_cpy_id, size=copy_size, ), - sdfg, + cfg, state_id, [src_node, dst_node], ) - def on_copy_end(self, sdfg, state, src_node, dst_node, edge, local_stream, global_stream): + def on_copy_end(self, sdfg, cfg, state, src_node, dst_node, edge, local_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(dst_node) if self.perf_should_instrument: unique_cpy_id = self._unique_counter local_stream.write( "__perf_cpy_%d_%d.leaveCritical(__vs_cpy_%d_%d);" % (node_id, unique_cpy_id, node_id, unique_cpy_id), - sdfg, + cfg, state_id, [src_node, dst_node], ) self.perf_should_instrument = False - def on_node_begin(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_begin(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) unified_id = _unified_id(state.node_id(node), state_id) perf_should_instrument = (node.instrument == dace.InstrumentationType.PAPI_Counters @@ -207,25 +207,25 @@ def on_node_begin(self, sdfg, state, node, outer_stream, inner_stream, global_st if isinstance(node, nodes.Tasklet): inner_stream.write( "dace::perf::%s __perf_%s;\n" % (self.perf_counter_string(), node.label), - sdfg, + cfg, state_id, node, ) inner_stream.write( 'auto& __perf_vs_%s = __perf_store.getNewValueSet(__perf_%s, ' ' %d, PAPI_thread_id(), 0);\n' % (node.label, node.label, unified_id), - sdfg, + cfg, state_id, node, ) - inner_stream.write("__perf_%s.enterCritical();\n" % node.label, sdfg, state_id, node) + inner_stream.write("__perf_%s.enterCritical();\n" % node.label, cfg, state_id, node) - def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_end(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) node_id = state.node_id(node) unified_id = _unified_id(node_id, state_id) @@ -234,7 +234,7 @@ def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stre if not PAPIInstrumentation.has_surrounding_perfcounters(node, state): inner_stream.write( "__perf_%s.leaveCritical(__perf_vs_%s);" % (node.label, node.label), - sdfg, + cfg, state_id, node, ) @@ -242,21 +242,21 @@ def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stre # Add bytes moved inner_stream.write( "__perf_store.addBytesMoved(%s);" % - PAPIUtils.get_tasklet_byte_accesses(node, state, sdfg, state_id), sdfg, state_id, node) + PAPIUtils.get_tasklet_byte_accesses(node, state, sdfg, cfg, state_id), cfg, state_id, node) - def on_scope_entry(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return if isinstance(node, nodes.MapEntry): - return self.on_map_entry(sdfg, state, node, outer_stream, inner_stream) + return self.on_map_entry(sdfg, cfg, state, node, outer_stream, inner_stream) elif isinstance(node, nodes.ConsumeEntry): - return self.on_consume_entry(sdfg, state, node, outer_stream, inner_stream) + return self.on_consume_entry(sdfg, cfg, state, node, outer_stream, inner_stream) raise TypeError - def on_scope_exit(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_exit(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not self._papi_used: return - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) entry_node = state.entry_node(node) if not self.should_instrument_entry(entry_node): return @@ -265,11 +265,11 @@ def on_scope_exit(self, sdfg, state, node, outer_stream, inner_stream, global_st perf_end_string = self.perf_counter_end_measurement_string(unified_id) # Inner part - inner_stream.write(perf_end_string, sdfg, state_id, node) + inner_stream.write(perf_end_string, cfg, state_id, node) - def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): + def on_map_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream): dfg = state.scope_subgraph(node) - state_id = sdfg.node_id(state) + state_id = cfg.node_id(state) if node.map.instrument != dace.InstrumentationType.PAPI_Counters: return @@ -280,7 +280,7 @@ def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): result = outer_stream - input_size: str = PAPIUtils.get_memory_input_size(node, sdfg, state_id) + input_size: str = PAPIUtils.get_memory_input_size(node, sdfg, cfg, state_id) # Emit supersection if possible result.write(self.perf_get_supersection_start_string(node, dfg, unified_id)) @@ -288,7 +288,7 @@ def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): if not self.should_instrument_entry(node): return - size = PAPIUtils.accumulate_byte_movement(node, node, dfg, sdfg, state_id) + size = PAPIUtils.accumulate_byte_movement(node, node, dfg, sdfg, cfg, state_id) size = sym2cpp(sp.simplify(size)) result.write(self.perf_section_start_string(unified_id, size, input_size)) @@ -299,10 +299,10 @@ def on_map_entry(self, sdfg, state, node, outer_stream, inner_stream): map_name = node.map.params[-1] - result.write(self.perf_counter_start_measurement_string(unified_id, map_name), sdfg, state_id, node) + result.write(self.perf_counter_start_measurement_string(unified_id, map_name), cfg, state_id, node) - def on_consume_entry(self, sdfg, state, node, outer_stream, inner_stream): - state_id = sdfg.node_id(state) + def on_consume_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream): + state_id = cfg.node_id(state) unified_id = _unified_id(state.node_id(node), state_id) # Outer part @@ -312,18 +312,18 @@ def on_consume_entry(self, sdfg, state, node, outer_stream, inner_stream): # Mark the SuperSection start (if possible) result.write( self.perf_get_supersection_start_string(node, state, unified_id), - sdfg, + cfg, state_id, node, ) # Mark the section start with zeros (due to dynamic accesses) - result.write(self.perf_section_start_string(unified_id, "0", "0"), sdfg, state_id, node) + result.write(self.perf_section_start_string(unified_id, "0", "0"), cfg, state_id, node) # Generate a thread affinity locker result.write( "dace::perf::ThreadLockProvider __perf_tlp_%d;\n" % unified_id, - sdfg, + cfg, state_id, node, ) @@ -343,7 +343,7 @@ def on_consume_entry(self, sdfg, state, node, outer_stream, inner_stream): "__perf_tlp_{id}.getAndIncreaseCounter()".format(id=unified_id), core_str="dace::perf::getThreadID()", ), - sdfg, + cfg, state_id, node, ) @@ -605,8 +605,8 @@ def get_memlet_byte_size(sdfg: dace.SDFG, memlet: Memlet): return memlet.volume * memdata.dtype.bytes @staticmethod - def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: DataflowGraphView): - scope_dict = sdfg.node(state_id).scope_dict() + def get_out_memlet_costs(sdfg: dace.SDFG, cfg, state_id: int, node: nodes.Node, dfg: DataflowGraphView): + scope_dict = cfg.node(state_id).scope_dict() out_costs = 0 for edge in dfg.out_edges(node): @@ -639,6 +639,7 @@ def get_out_memlet_costs(sdfg: dace.SDFG, state_id: int, node: nodes.Node, dfg: def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, dfg: DataflowGraphView, sdfg: dace.SDFG, + cfg, state_id: int) -> str: """ Get the amount of bytes processed by `tasklet`. The formula is sum(inedges * size) + sum(outedges * size) """ @@ -649,7 +650,7 @@ def get_tasklet_byte_accesses(tasklet: nodes.CodeNode, for ie in in_edges: in_accum.append(PAPIUtils.get_memlet_byte_size(sdfg, ie.data)) - out_accum.append(PAPIUtils.get_out_memlet_costs(sdfg, state_id, tasklet, dfg)) + out_accum.append(PAPIUtils.get_out_memlet_costs(sdfg, cfg, state_id, tasklet, dfg)) # Merge full = in_accum @@ -663,7 +664,7 @@ def get_parents(outermost_node: nodes.Node, node: nodes.Node, sdfg: dace.SDFG, s parent = None # Because dfg is only a subgraph view, it does not contain the entry # node for a given entry. This O(n) solution is suboptimal - for state in sdfg.nodes(): + for state in sdfg.states(): s_d = state.scope_dict() try: scope = s_d[node] @@ -681,8 +682,8 @@ def get_parents(outermost_node: nodes.Node, node: nodes.Node, sdfg: dace.SDFG, s return PAPIUtils.get_parents(outermost_node, parent, sdfg, state_id) + [parent] @staticmethod - def get_memory_input_size(node, sdfg, state_id) -> str: - curr_state = sdfg.nodes()[state_id] + def get_memory_input_size(node, sdfg, cfg, state_id) -> str: + curr_state = cfg.node(state_id) input_size = 0 for edge in curr_state.in_edges(node): @@ -696,7 +697,7 @@ def get_memory_input_size(node, sdfg, state_id) -> str: return sym2cpp(input_size) @staticmethod - def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, state_id): + def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, cfg, state_id): itvars = dict() # initialize an empty dict @@ -711,7 +712,7 @@ def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, if len(children) > 0: size = 0 for x in children: - size = size + PAPIUtils.accumulate_byte_movement(outermost_node, x, dfg, sdfg, state_id) + size = size + PAPIUtils.accumulate_byte_movement(outermost_node, x, dfg, sdfg, cfg, state_id) return size else: @@ -740,7 +741,7 @@ def accumulate_byte_movement(outermost_node, node, dfg: DataflowGraphView, sdfg, return 0 # We can ignore this. elif isinstance(node, Tasklet): return itcount * symbolic.pystr_to_symbolic( - PAPIUtils.get_tasklet_byte_accesses(node, dfg, sdfg, state_id)) + PAPIUtils.get_tasklet_byte_accesses(node, dfg, sdfg, cfg, state_id)) elif isinstance(node, nodes.AccessNode): return 0 else: diff --git a/dace/codegen/instrumentation/provider.py b/dace/codegen/instrumentation/provider.py index a3748b241b..9374ed60dd 100644 --- a/dace/codegen/instrumentation/provider.py +++ b/dace/codegen/instrumentation/provider.py @@ -60,34 +60,37 @@ def on_sdfg_end(self, sdfg: SDFG, local_stream: CodeIOStream, global_stream: Cod """ pass - def on_state_begin(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the beginning of SDFG state code generation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param local_stream: Code generator for the in-function code. :param global_stream: Code generator for global (external) code. """ pass - def on_state_end(self, sdfg: SDFG, state: SDFGState, local_stream: CodeIOStream, + def on_state_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the end of SDFG state code generation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param local_stream: Code generator for the in-function code. :param global_stream: Code generator for global (external) code. """ pass - def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_entry(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.EntryNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the beginning of a scope (on generating an EntryNode). :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The EntryNode object from which code is generated. :param outer_stream: Code generator for the internal code before @@ -98,11 +101,12 @@ def on_scope_entry(self, sdfg: SDFG, state: SDFGState, node: nodes.EntryNode, ou """ pass - def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_scope_exit(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.ExitNode, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the end of a scope (on generating an ExitNode). :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The ExitNode object from which code is generated. :param outer_stream: Code generator for the internal code after @@ -113,12 +117,13 @@ def on_scope_exit(self, sdfg: SDFG, state: SDFGState, node: nodes.ExitNode, oute """ pass - def on_copy_begin(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_node: nodes.Node, - edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, global_stream: CodeIOStream, - copy_shape, src_strides, dst_strides) -> None: + def on_copy_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream, copy_shape, src_strides, dst_strides) -> None: """ Event called at the beginning of generating a copy operation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param src_node: The source node of the copy. :param dst_node: The destination node of the copy. @@ -131,11 +136,13 @@ def on_copy_begin(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_ """ pass - def on_copy_end(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_node: nodes.Node, - edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_copy_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, src_node: nodes.Node, + dst_node: nodes.Node, edge: MultiConnectorEdge[Memlet], local_stream: CodeIOStream, + global_stream: CodeIOStream) -> None: """ Event called at the end of generating a copy operation. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param src_node: The source node of the copy. :param dst_node: The destination node of the copy. @@ -145,11 +152,12 @@ def on_copy_end(self, sdfg: SDFG, state: SDFGState, src_node: nodes.Node, dst_no """ pass - def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_begin(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the beginning of generating a node. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The generated node. :param outer_stream: Code generator for the internal code before @@ -160,11 +168,12 @@ def on_node_begin(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_st """ pass - def on_node_end(self, sdfg: SDFG, state: SDFGState, node: nodes.Node, outer_stream: CodeIOStream, - inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: + def on_node_end(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, node: nodes.Node, + outer_stream: CodeIOStream, inner_stream: CodeIOStream, global_stream: CodeIOStream) -> None: """ Event called at the end of generating a node. :param sdfg: The generated SDFG object. + :param cfg: The generated Control Flow Region object. :param state: The generated SDFGState object. :param node: The generated node. :param outer_stream: Code generator for the internal code after diff --git a/dace/codegen/instrumentation/timer.py b/dace/codegen/instrumentation/timer.py index a13e50faca..fea2cf70ea 100644 --- a/dace/codegen/instrumentation/timer.py +++ b/dace/codegen/instrumentation/timer.py @@ -16,24 +16,24 @@ def on_sdfg_begin(self, sdfg, local_stream, global_stream, codegen): sdfg.append_global_code('\n#include ', None) if sdfg.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(local_stream, sdfg) + self.on_tbegin(local_stream, sdfg, sdfg) def on_sdfg_end(self, sdfg, local_stream, global_stream): if sdfg.instrument == dtypes.InstrumentationType.Timer: - self.on_tend('SDFG %s' % sdfg.name, local_stream, sdfg) + self.on_tend('SDFG %s' % sdfg.name, local_stream, sdfg, sdfg) - def on_tbegin(self, stream: CodeIOStream, sdfg=None, state=None, node=None): - idstr = self._idstr(sdfg, state, node) + def on_tbegin(self, stream: CodeIOStream, sdfg=None, cfg=None, state=None, node=None): + idstr = self._idstr(cfg, state, node) stream.write('auto __dace_tbegin_%s = std::chrono::high_resolution_clock::now();' % idstr) - def on_tend(self, timer_name: str, stream: CodeIOStream, sdfg=None, state=None, node=None): - idstr = self._idstr(sdfg, state, node) + def on_tend(self, timer_name: str, stream: CodeIOStream, sdfg=None, cfg=None, state=None, node=None): + idstr = self._idstr(cfg, state, node) state_id = -1 node_id = -1 if state is not None: - state_id = sdfg.node_id(state) + state_id = state.block_id if node is not None: node_id = state.node_id(node) @@ -41,16 +41,16 @@ def on_tend(self, timer_name: str, stream: CodeIOStream, sdfg=None, state=None, unsigned long int __dace_ts_start_{id} = std::chrono::duration_cast(__dace_tbegin_{id}.time_since_epoch()).count(); unsigned long int __dace_ts_end_{id} = std::chrono::duration_cast(__dace_tend_{id}.time_since_epoch()).count(); __state->report.add_completion("{timer_name}", "Timer", __dace_ts_start_{id}, __dace_ts_end_{id}, {cfg_id}, {state_id}, {node_id});''' - .format(timer_name=timer_name, id=idstr, cfg_id=sdfg.cfg_id, state_id=state_id, node_id=node_id)) + .format(timer_name=timer_name, id=idstr, cfg_id=cfg.cfg_id, state_id=state_id, node_id=node_id)) # Code generation hooks - def on_state_begin(self, sdfg, state, local_stream, global_stream): + def on_state_begin(self, sdfg, cfg, state, local_stream, global_stream): if state.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(local_stream, sdfg, state) + self.on_tbegin(local_stream, sdfg, cfg, state) - def on_state_end(self, sdfg, state, local_stream, global_stream): + def on_state_end(self, sdfg, cfg, state, local_stream, global_stream): if state.instrument == dtypes.InstrumentationType.Timer: - self.on_tend('State %s' % state.label, local_stream, sdfg, state) + self.on_tend('State %s' % state.label, local_stream, sdfg, cfg, state) def _get_sobj(self, node): # Get object behind scope @@ -59,26 +59,26 @@ def _get_sobj(self, node): else: return node.map - def on_scope_entry(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_entry(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(outer_stream, sdfg, state, node) + self.on_tbegin(outer_stream, sdfg, cfg, state, node) - def on_scope_exit(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_scope_exit(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): entry_node = state.entry_node(node) s = self._get_sobj(node) if s.instrument == dtypes.InstrumentationType.Timer: - self.on_tend('%s %s' % (type(s).__name__, s.label), outer_stream, sdfg, state, entry_node) + self.on_tend('%s %s' % (type(s).__name__, s.label), outer_stream, sdfg, cfg, state, entry_node) - def on_node_begin(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_begin(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not isinstance(node, CodeNode): return if node.instrument == dtypes.InstrumentationType.Timer: - self.on_tbegin(outer_stream, sdfg, state, node) + self.on_tbegin(outer_stream, sdfg, cfg, state, node) - def on_node_end(self, sdfg, state, node, outer_stream, inner_stream, global_stream): + def on_node_end(self, sdfg, cfg, state, node, outer_stream, inner_stream, global_stream): if not isinstance(node, CodeNode): return if node.instrument == dtypes.InstrumentationType.Timer: - idstr = self._idstr(sdfg, state, node) - self.on_tend('%s %s' % (type(node).__name__, idstr), outer_stream, sdfg, state, node) + idstr = self._idstr(cfg, state, node) + self.on_tend('%s %s' % (type(node).__name__, idstr), outer_stream, sdfg, cfg, state, node) diff --git a/dace/codegen/targets/cpu.py b/dace/codegen/targets/cpu.py index 9ba202757e..c7d05de5a3 100644 --- a/dace/codegen/targets/cpu.py +++ b/dace/codegen/targets/cpu.py @@ -854,7 +854,7 @@ def _emit_copy( # Instrumentation: Pre-copy for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_copy_begin(sdfg, state_dfg, src_node, dst_node, edge, stream, None, copy_shape, + instr.on_copy_begin(sdfg, cfg, state_dfg, src_node, dst_node, edge, stream, None, copy_shape, src_strides, dst_strides) nc = True @@ -916,7 +916,7 @@ def _emit_copy( # Instrumentation: Post-copy for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_copy_end(sdfg, state_dfg, src_node, dst_node, edge, stream, None) + instr.on_copy_end(sdfg, cfg, state_dfg, src_node, dst_node, edge, stream, None) ############################################################# ########################################################################### @@ -1506,7 +1506,7 @@ def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra # Instrumentation: Pre-tasklet instr = self._dispatcher.instrumentation[node.instrument] if instr is not None: - instr.on_node_begin(sdfg, state_dfg, node, outer_stream_begin, inner_stream, function_stream) + instr.on_node_begin(sdfg, cfg, state_dfg, node, outer_stream_begin, inner_stream, function_stream) inner_stream.write("\n ///////////////////\n", cfg, state_id, node) @@ -1535,7 +1535,7 @@ def _generate_Tasklet(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra # Instrumentation: Post-tasklet if instr is not None: - instr.on_node_end(sdfg, state_dfg, node, outer_stream_end, inner_stream, function_stream) + instr.on_node_end(sdfg, cfg, state_dfg, node, outer_stream_end, inner_stream, function_stream) callsite_stream.write(outer_stream_begin.getvalue(), cfg, state_id, node) callsite_stream.write('{', cfg, state_id, node) @@ -1710,7 +1710,7 @@ def _generate_NestedSDFG( # If the SDFG has a unique name, use it sdfg_label = node.unique_name else: - sdfg_label = "%s_%d_%d_%d" % (node.sdfg.name, sdfg.cfg_id, state_id, dfg.node_id(node)) + sdfg_label = "%s_%d_%d_%d" % (node.sdfg.name, cfg.cfg_id, state_id, dfg.node_id(node)) code_already_generated = False if unique_functions and not inline: @@ -1856,7 +1856,7 @@ def _generate_MapEntry( # Instrumentation: Pre-scope instr = self._dispatcher.instrumentation[node.map.instrument] if instr is not None: - instr.on_scope_entry(sdfg, state_dfg, node, callsite_stream, inner_stream, function_stream) + instr.on_scope_entry(sdfg, cfg, state_dfg, node, callsite_stream, inner_stream, function_stream) # TODO: Refactor to generate_scope_preamble once a general code # generator (that CPU inherits from) is implemented @@ -1970,7 +1970,7 @@ def _generate_MapExit(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSubgra # Instrumentation: Post-scope instr = self._dispatcher.instrumentation[node.map.instrument] if instr is not None and not is_devicelevel_gpu(sdfg, state_dfg, node): - instr.on_scope_exit(sdfg, state_dfg, node, outer_stream, callsite_stream, function_stream) + instr.on_scope_exit(sdfg, cfg, state_dfg, node, outer_stream, callsite_stream, function_stream) self.generate_scope_postamble(sdfg, dfg, state_id, function_stream, outer_stream, callsite_stream) @@ -2155,7 +2155,7 @@ def _generate_AccessNode(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSub # Instrumentation: Pre-node instr = self._dispatcher.instrumentation[node.instrument] if instr is not None: - instr.on_node_begin(sdfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) + instr.on_node_begin(sdfg, cfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) sdict = state_dfg.scope_dict() for edge in state_dfg.in_edges(node): @@ -2198,7 +2198,7 @@ def _generate_AccessNode(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg: StateSub # Instrumentation: Post-node if instr is not None: - instr.on_node_end(sdfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) + instr.on_node_end(sdfg, cfg, state_dfg, node, callsite_stream, callsite_stream, function_stream) # Methods for subclasses to override diff --git a/dace/codegen/targets/cuda.py b/dace/codegen/targets/cuda.py index 6425f01688..fd9840fdf0 100644 --- a/dace/codegen/targets/cuda.py +++ b/dace/codegen/targets/cuda.py @@ -157,8 +157,8 @@ def preprocess(self, sdfg: SDFG) -> None: # Find GPU<->GPU strided copies that cannot be represented by a single copy command from dace.transformation.dataflow import CopyToMap for e, state in list(sdfg.all_edges_recursive()): - nsdfg = state.parent if isinstance(e.src, nodes.AccessNode) and isinstance(e.dst, nodes.AccessNode): + nsdfg = state.parent if (e.src.desc(nsdfg).storage == dtypes.StorageType.GPU_Global and e.dst.desc(nsdfg).storage == dtypes.StorageType.GPU_Global): copy_shape, src_strides, dst_strides, _, _ = memlet_copy_to_absolute_strides( @@ -775,7 +775,7 @@ def increment(streams): state_streams = [] state_subsdfg_events = [] - for state in sdfg.nodes(): + for state in sdfg.states(): # Start by annotating source nodes source_nodes = state.source_nodes() @@ -873,7 +873,7 @@ def increment(streams): # Compute maximal number of events by counting edges (within the same # state) that point from one stream to another state_events = [] - for i, state in enumerate(sdfg.nodes()): + for i, state in enumerate(sdfg.states()): events = state_subsdfg_events[i] for e in state.edges(): @@ -1306,7 +1306,7 @@ def generate_state(self, # Invoke all instrumentation providers for instr in self._frame._dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_end(sdfg, state, callsite_stream, function_stream) + instr.on_state_end(sdfg, cfg, state, callsite_stream, function_stream) def generate_devicelevel_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, function_stream: CodeIOStream, callsite_stream: CodeIOStream) -> None: @@ -1435,8 +1435,7 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub create_grid_barrier = True self.create_grid_barrier = create_grid_barrier - kernel_name = '%s_%d_%d_%d' % (scope_entry.map.label, sdfg.cfg_id, sdfg.node_id(state), - state.node_id(scope_entry)) + kernel_name = '%s_%d_%d_%d' % (scope_entry.map.label, cfg.cfg_id, state.block_id, state.node_id(scope_entry)) # Comprehend grid/block dimensions from scopes grid_dims, block_dims, tbmap, dtbmap, _ = self.get_kernel_dimensions(dfg_scope) @@ -1496,9 +1495,10 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # Instrumentation for kernel scope instr = self._dispatcher.instrumentation[scope_entry.map.instrument] if instr is not None: - instr.on_scope_entry(sdfg, state, scope_entry, callsite_stream, self.scope_entry_stream, self._globalcode) + instr.on_scope_entry(sdfg, cfg, state, scope_entry, callsite_stream, self.scope_entry_stream, + self._globalcode) outer_stream = CodeIOStream() - instr.on_scope_exit(sdfg, state, scope_exit, outer_stream, self.scope_exit_stream, self._globalcode) + instr.on_scope_exit(sdfg, cfg, state, scope_exit, outer_stream, self.scope_exit_stream, self._globalcode) # Redefine constant arguments and rename arguments to device counterparts # TODO: This (const behavior and code below) is all a hack. @@ -1587,7 +1587,7 @@ def generate_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: StateSub # Write kernel prototype self._localcode.write( '__global__ void %s %s(%s) {\n' % - (launch_bounds, kernel_name, ', '.join(kernel_args_typed + extra_kernel_args_typed)), sdfg, state_id, node) + (launch_bounds, kernel_name, ', '.join(kernel_args_typed + extra_kernel_args_typed)), cfg, state_id, node) # Write constant expressions in GPU code self._frame.generate_constants(sdfg, self._localcode) @@ -2009,7 +2009,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S bidx = krange.coord_at(dsym) # handle dynamic map inputs - for e in dace.sdfg.dynamic_map_inputs(sdfg.states()[state_id], dfg_scope.source_nodes()[0]): + for e in dace.sdfg.dynamic_map_inputs(cfg.node(state_id), dfg_scope.source_nodes()[0]): kernel_stream.write( self._cpu_codegen.memlet_definition(sdfg, e.data, False, e.dst_conn, e.dst.in_connectors[e.dst_conn]), cfg, state_id, @@ -2032,7 +2032,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S expr = _topy(bidx[i]).replace('__DAPB%d' % i, block_expr) - kernel_stream.write(f'{tidtype.ctype} {varname} = {expr};', sdfg, state_id, node) + kernel_stream.write(f'{tidtype.ctype} {varname} = {expr};', cfg, state_id, node) self._dispatcher.defined_vars.add(varname, DefinedType.Scalar, tidtype.ctype) # Delinearize beyond the third dimension @@ -2059,7 +2059,7 @@ def generate_kernel_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, dfg_scope: S assert CUDACodeGen._in_device_code is False CUDACodeGen._in_device_code = True self._kernel_map = node - self._kernel_state = sdfg.node(state_id) + self._kernel_state = cfg.node(state_id) self._block_dims = block_dims self._grid_dims = grid_dims diff --git a/dace/codegen/targets/fpga.py b/dace/codegen/targets/fpga.py index 61ba9f95ad..197f515e52 100644 --- a/dace/codegen/targets/fpga.py +++ b/dace/codegen/targets/fpga.py @@ -421,10 +421,11 @@ def find_rtl_tasklet(self, subgraph: ScopeSubgraphView): ''' for n in subgraph.nodes(): if isinstance(n, dace.nodes.NestedSDFG): - for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.start_state): - node = self.find_rtl_tasklet(sg) - if node: - return node + if len(n.sdfg.nodes()) == 1 and isinstance(n.sdfg.nodes()[0], SDFGState): + for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.start_state): + node = self.find_rtl_tasklet(sg) + if node: + return node elif isinstance(n, dace.nodes.Tasklet) and n.language == dace.dtypes.Language.SystemVerilog: return n return None @@ -438,9 +439,10 @@ def is_multi_pumped_subgraph(self, subgraph: ScopeSubgraphView): ''' for n in subgraph.nodes(): if isinstance(n, dace.nodes.NestedSDFG): - for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.start_state): - if self.is_multi_pumped_subgraph(sg): - return True + if len(n.sdfg.nodes()) == 1 and isinstance(n.sdfg.nodes()[0], SDFGState): + for sg in dace.sdfg.concurrent_subgraphs(n.sdfg.nodes()[0]): + if self.is_multi_pumped_subgraph(sg): + return True elif isinstance(n, dace.nodes.MapEntry) and n.schedule == dace.ScheduleType.FPGA_Multi_Pumped: return True return False @@ -1105,7 +1107,7 @@ def generate_nested_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: dace. self._dispatcher.dispatch_subgraph(sdfg, cfg, sg, - sdfg.node_id(state), + cfg.node_id(state), function_stream, callsite_stream, skip_entry_node=False) @@ -1720,7 +1722,7 @@ def _emit_copy(self, sdfg: SDFG, cfg: ControlFlowRegion, state_id: int, src_node raise NotImplementedError("Reads from shift registers only supported from tasklets.") # Try to turn into degenerate/strided ND copies - state_dfg = sdfg.nodes()[state_id] + state_dfg = cfg.node(state_id) copy_shape, src_strides, dst_strides, src_expr, dst_expr = (cpp.memlet_copy_to_absolute_strides( self._dispatcher, sdfg, state_dfg, edge, src_node, dst_node, packed_types=True)) @@ -1974,7 +1976,7 @@ def _is_innermost(self, scope, scope_dict, sdfg): return False to_search += scope_dict[x] elif isinstance(x, dace.sdfg.nodes.NestedSDFG): - for state in x.sdfg: + for state in x.sdfg.states(): if not self._is_innermost(state.nodes(), state.scope_children(), x.sdfg): return False return True diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index 0b8fa739fe..f760715ef9 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import collections import copy import re @@ -9,6 +9,7 @@ import dace from dace import config, data, dtypes +from dace import symbolic from dace.cli import progress from dace.codegen import control_flow as cflow from dace.codegen import dispatcher as disp @@ -16,12 +17,11 @@ from dace.codegen.common import codeblock_to_cpp, sym2cpp from dace.codegen.targets.target import TargetCodeGenerator from dace.codegen.tools.type_inference import infer_expr_type -from dace.frontend.python import astutils from dace.sdfg import SDFG, SDFGState, nodes from dace.sdfg import scope as sdscope from dace.sdfg import utils from dace.sdfg.analysis import cfg as cfg_analysis -from dace.sdfg.state import ControlFlowRegion, LoopRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion from dace.transformation.passes.analysis import StateReachability, loop_analysis @@ -423,7 +423,7 @@ def generate_state(self, # Invoke all instrumentation providers for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_begin(sdfg, state, callsite_stream, global_stream) + instr.on_state_begin(sdfg, cfg, state, callsite_stream, global_stream) ##################### # Create dataflow graph for state's children. @@ -470,7 +470,7 @@ def generate_state(self, # Invoke all instrumentation providers for instr in self._dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_end(sdfg, state, callsite_stream, global_stream) + instr.on_state_end(sdfg, cfg, state, callsite_stream, global_stream) def generate_states(self, sdfg: SDFG, global_stream: CodeIOStream, callsite_stream: CodeIOStream) -> Set[SDFGState]: states_generated = set() @@ -485,7 +485,7 @@ def dispatch_state(state: SDFGState) -> str: states_generated.add(state) # For sanity check return stream.getvalue() - if sdfg.root_sdfg.recheck_using_experimental_blocks(): + if sdfg.root_sdfg.recheck_using_explicit_control_flow(): # Use control flow blocks embedded in the SDFG to generate control flow. cft = cflow.structured_control_flow_tree_with_regions(sdfg, dispatch_state) elif config.Config.get_bool('optimizer', 'detect_control_flow'): @@ -692,10 +692,14 @@ def determine_allocation_lifetime(self, top_sdfg: SDFG): curstate: SDFGState = None multistate = False - # Does the array appear in inter-state edges? + # Does the array appear in inter-state edges or loop / conditional block conditions etc.? for isedge in sdfg.all_interstate_edges(): if name in self.free_symbols(isedge.data): multistate = True + for cfg in sdfg.all_control_flow_regions(): + block_syms = cfg.used_symbols(all_symbols=True, with_contents=False) + if name in block_syms: + multistate = True for state in sdfg.states(): if multistate: @@ -848,8 +852,8 @@ def deallocate_arrays_in_scope(self, sdfg: SDFG, cfg: ControlFlowRegion, scope: desc = node.desc(tsdfg) - self._dispatcher.dispatch_deallocate(tsdfg, cfg, state, state_id, node, desc, function_stream, - callsite_stream) + self._dispatcher.dispatch_deallocate(tsdfg, state.parent_graph, state, state_id, node, desc, + function_stream, callsite_stream) def generate_code(self, sdfg: SDFG, @@ -919,20 +923,17 @@ def generate_code(self, global_symbols.update(symbols) if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: - init_assignment = cfr.init_statement.code[0] - update_assignment = cfr.update_statement.code[0] - if isinstance(init_assignment, astutils.ast.Assign): - init_assignment = init_assignment.value - if isinstance(update_assignment, astutils.ast.Assign): - update_assignment = update_assignment.value if not cfr.loop_variable in interstate_symbols: - l_end = loop_analysis.get_loop_end(cfr) - l_start = loop_analysis.get_init_assignment(cfr) - l_step = loop_analysis.get_loop_stride(cfr) - sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), - infer_expr_type(l_step, global_symbols), - infer_expr_type(l_end, global_symbols)) - interstate_symbols[cfr.loop_variable] = sym_type + if cfr.loop_variable in global_symbols: + interstate_symbols[cfr.loop_variable] = global_symbols[cfr.loop_variable] + else: + l_end = loop_analysis.get_loop_end(cfr) + l_start = loop_analysis.get_init_assignment(cfr) + l_step = loop_analysis.get_loop_stride(cfr) + sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), + infer_expr_type(l_step, global_symbols), + infer_expr_type(l_end, global_symbols)) + interstate_symbols[cfr.loop_variable] = sym_type if not cfr.loop_variable in global_symbols: global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] @@ -1042,21 +1043,20 @@ def generate_code(self, return (generated_header, clean_code, self._dispatcher.used_targets, self._dispatcher.used_environments) -def _get_dominator_and_postdominator(cfg: ControlFlowRegion, accesses: List[Tuple[SDFGState, nodes.AccessNode]]): +def _get_dominator_and_postdominator(sdfg: SDFG, accesses: List[Tuple[SDFGState, nodes.AccessNode]]): """ Gets the closest common dominator and post-dominator for a list of states. Used for determining allocation of data used in branched states. """ - # Get immediate dominators - idom = nx.immediate_dominators(cfg.nx, cfg.start_block) - alldoms = cfg_analysis.all_dominators(cfg, idom) + alldoms: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = collections.defaultdict(lambda: set()) + allpostdoms: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = collections.defaultdict(lambda: set()) + idom: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} + ipostdom: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} + utils.get_control_flow_block_dominators(sdfg, idom, alldoms, ipostdom, allpostdoms) states = [a for a, _ in accesses] data_name = accesses[0][1].data - # Get immediate post-dominators - ipostdom, allpostdoms = utils.postdominators(cfg, return_alldoms=True) - # All dominators and postdominators include the states themselves for state in states: alldoms[state].add(state) diff --git a/dace/codegen/targets/snitch.py b/dace/codegen/targets/snitch.py index 5a62ca2995..bcdcb61941 100644 --- a/dace/codegen/targets/snitch.py +++ b/dace/codegen/targets/snitch.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from typing import Union import dace @@ -201,7 +201,7 @@ def generate_state(self, sdfg: SDFG, cfg: ControlFlowRegion, state: SDFGState, g # Invoke all instrumentation providers for instr in self.dispatcher.instrumentation.values(): if instr is not None: - instr.on_state_begin(sdfg, state, callsite_stream, global_stream) + instr.on_state_begin(sdfg, cfg, state, callsite_stream, global_stream) ##################### # Create dataflow graph for state's children. diff --git a/dace/frontend/fortran/fortran_parser.py b/dace/frontend/fortran/fortran_parser.py index 1cdecc99a8..6b14f63edd 100644 --- a/dace/frontend/fortran/fortran_parser.py +++ b/dace/frontend/fortran/fortran_parser.py @@ -29,7 +29,7 @@ class AST_translator: """ This class is responsible for translating the internal AST into a SDFG. """ - def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_experimental_cfg_blocks: bool = False): + def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_explicit_cf: bool = False): """ :ast: The internal fortran AST to be used for translation :source: The source file name from which the AST was generated @@ -69,7 +69,7 @@ def __init__(self, ast: ast_components.InternalFortranAst, source: str, use_expe ast_internal_classes.Allocate_Stmt_Node: self.allocate2sdfg, ast_internal_classes.Break_Node: self.break2sdfg, } - self.use_experimental_cfg_blocks = use_experimental_cfg_blocks + self.use_explicit_cf = use_explicit_cf def get_dace_type(self, type): """ @@ -271,7 +271,7 @@ def forstmt2sdfg(self, node: ast_internal_classes.For_Stmt_Node, sdfg: SDFG, cfg :param sdfg: The SDFG to which the node should be translated """ - if not self.use_experimental_cfg_blocks: + if not self.use_explicit_cf: declloop = False name = "FOR_l_" + str(node.line_number[0]) + "_c_" + str(node.line_number[1]) begin_state = ast_utils.add_simple_state_to_sdfg(self, cfg, "Begin" + name) @@ -1103,7 +1103,7 @@ def create_sdfg_from_string( source_string: str, sdfg_name: str, normalize_offsets: bool = False, - use_experimental_cfg_blocks: bool = False + use_explicit_cf: bool = False ): """ Creates an SDFG from a fortran file in a string @@ -1133,7 +1133,7 @@ def create_sdfg_from_string( program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program, normalize_offsets).visit(program) - ast2sdfg = AST_translator(own_ast, __file__, use_experimental_cfg_blocks) + ast2sdfg = AST_translator(own_ast, __file__, use_explicit_cf) sdfg = SDFG(sdfg_name) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg @@ -1148,11 +1148,11 @@ def create_sdfg_from_string( sdfg.parent_sdfg = None sdfg.parent_nsdfg_node = None sdfg.reset_cfg_list() - sdfg.using_experimental_blocks = use_experimental_cfg_blocks + sdfg.using_explicit_control_flow = use_explicit_cf return sdfg -def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_blocks: bool = False): +def create_sdfg_from_fortran_file(source_string: str, use_explicit_cf: bool = False): """ Creates an SDFG from a fortran file :param source_string: The fortran file name @@ -1180,11 +1180,11 @@ def create_sdfg_from_fortran_file(source_string: str, use_experimental_cfg_block program = ast_transforms.ForDeclarer().visit(program) program = ast_transforms.IndexExtractor(program).visit(program) - ast2sdfg = AST_translator(own_ast, __file__, use_experimental_cfg_blocks) + ast2sdfg = AST_translator(own_ast, __file__, use_explicit_cf) sdfg = SDFG(source_string) ast2sdfg.top_level = program ast2sdfg.globalsdfg = sdfg ast2sdfg.translate(program, sdfg) - sdfg.using_experimental_blocks = use_experimental_cfg_blocks + sdfg.using_explicit_control_flow = use_explicit_cf return sdfg diff --git a/dace/frontend/python/astutils.py b/dace/frontend/python/astutils.py index 425e94cd9f..4e6aa68651 100644 --- a/dace/frontend/python/astutils.py +++ b/dace/frontend/python/astutils.py @@ -1,9 +1,8 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Various AST parsing utilities for DaCe. """ import ast import astunparse import copy -from collections import OrderedDict from io import StringIO import inspect import numbers @@ -12,7 +11,7 @@ import sys from typing import Any, Dict, List, Optional, Set, Union -from dace import dtypes, symbolic +from dace import symbolic if sys.version_info >= (3, 8): @@ -587,6 +586,36 @@ def visit_keyword(self, node: ast.keyword): return self.generic_visit(node) +class FindAssignment(ast.NodeVisitor): + + assignments: Dict[str, str] + multiple: bool + + def __init__(self): + self.assignments = {} + self.multiple = False + + def visit_Assign(self, node: ast.Assign) -> Any: + for tgt in node.targets: + if isinstance(tgt, ast.Name): + if tgt.id in self.assignments: + self.multiple = True + self.assignments[tgt.id] = unparse(node.value) + return self.generic_visit(node) + + +class ASTReplaceAssignmentRHS(ast.NodeVisitor): + + repl_visitor: ASTFindReplace + + def __init__(self, repl: Dict[str, str]): + self.repl_visitor = ASTFindReplace(repl) + + def visit_Assign(self, node: ast.Assign) -> Any: + self.repl_visitor.visit(node.value) + return self.generic_visit(node) + + class RemoveSubscripts(ast.NodeTransformer): def __init__(self, keywords: Set[str]): self.keywords = keywords diff --git a/dace/frontend/python/interface.py b/dace/frontend/python/interface.py index 14164054d3..6fb92077b7 100644 --- a/dace/frontend/python/interface.py +++ b/dace/frontend/python/interface.py @@ -44,7 +44,7 @@ def program(f: F, recompile: bool = True, distributed_compilation: bool = False, constant_functions=False, - use_experimental_cfg_blocks=False, + use_explicit_cf=True, **kwargs) -> Callable[..., parser.DaceProgram]: """ Entry point to a data-centric program. For methods and ``classmethod``s, use @@ -69,8 +69,7 @@ def program(f: F, not depend on internal variables are constant. This will hardcode their return values into the resulting program. - :param use_experimental_cfg_blocks: If True, makes use of experimental CFG blocks susch as loop and conditional - regions. + :param use_explicit_cf: If True, makes use of explicit control flow constructs. :note: If arguments are defined with type hints, the program can be compiled ahead-of-time with ``.compile()``. """ @@ -87,7 +86,7 @@ def program(f: F, regenerate_code=regenerate_code, recompile=recompile, distributed_compilation=distributed_compilation, - use_experimental_cfg_blocks=use_experimental_cfg_blocks) + use_explicit_cf=use_explicit_cf) function = program diff --git a/dace/frontend/python/newast.py b/dace/frontend/python/newast.py index d2813371c9..bb6e28f163 100644 --- a/dace/frontend/python/newast.py +++ b/dace/frontend/python/newast.py @@ -2566,7 +2566,7 @@ def visit_If(self, node: ast.If): # Add conditional region cond_block = ConditionalBlock(f'if_{node.lineno}') - self.cfg_target.add_node(cond_block) + self.cfg_target.add_node(cond_block, ensure_unique_name=True) self._on_block_added(cond_block) if_body = ControlFlowRegion(cond_block.label + '_body', sdfg=self.sdfg) @@ -4527,7 +4527,7 @@ def visit_Call(self, node: ast.Call, create_callbacks=False): else: name = "call" call_region = FunctionCallRegion(label=f"{name}_{node.lineno}", arguments=[]) - self.cfg_target.add_node(call_region) + self.cfg_target.add_node(call_region, ensure_unique_name=True) self._on_block_added(call_region) previous_last_cfg_target = self.last_cfg_target previous_target = self.cfg_target @@ -4780,7 +4780,7 @@ def visit_With(self, node: ast.With, is_async=False): else: named_region_name = f"Named Region {node.lineno}" named_region = NamedRegion(named_region_name, debuginfo=self.current_lineinfo) - self.cfg_target.add_node(named_region) + self.cfg_target.add_node(named_region, ensure_unique_name=True) self._on_block_added(named_region) self._recursive_visit(node.body, "init_named", node.lineno, named_region, unconnected_last_block=False) return diff --git a/dace/frontend/python/parser.py b/dace/frontend/python/parser.py index b65e7c227d..0faa2e36ce 100644 --- a/dace/frontend/python/parser.py +++ b/dace/frontend/python/parser.py @@ -156,7 +156,7 @@ def __init__(self, recompile: bool = True, distributed_compilation: bool = False, method: bool = False, - use_experimental_cfg_blocks: bool = False): + use_explicit_cf: bool = True): from dace.codegen import compiled_sdfg # Avoid import loops self.f = f @@ -176,7 +176,7 @@ def __init__(self, self.recreate_sdfg = recreate_sdfg self.regenerate_code = regenerate_code self.recompile = recompile - self.use_experimental_cfg_blocks = use_experimental_cfg_blocks + self.use_explicit_cf = use_explicit_cf self.distributed_compilation = distributed_compilation self.global_vars = _get_locals_and_globals(f) @@ -494,11 +494,10 @@ def _parse(self, args, kwargs, simplify=None, save=False, validate=False) -> SDF # Obtain DaCe program as SDFG sdfg, cached = self._generate_pdp(args, kwargs, simplify=simplify) - if not self.use_experimental_cfg_blocks: + if not self.use_explicit_cf: for nsdfg in sdfg.all_sdfgs_recursive(): - sdutils.inline_conditional_blocks(nsdfg) sdutils.inline_control_flow_regions(nsdfg) - sdfg.using_experimental_blocks = self.use_experimental_cfg_blocks + sdfg.using_explicit_control_flow = self.use_explicit_cf sdfg.reset_cfg_list() diff --git a/dace/sdfg/analysis/cfg.py b/dace/sdfg/analysis/cfg.py index c96ef5aff0..eb9aea0e2b 100644 --- a/dace/sdfg/analysis/cfg.py +++ b/dace/sdfg/analysis/cfg.py @@ -377,10 +377,11 @@ def blockorder_topological_sort(cfg: ControlFlowRegion, elif isinstance(block, ConditionalBlock): if not ignore_nonstate_blocks: yield block - for _, branch in block.branches: - if not ignore_nonstate_blocks: - yield branch - yield from blockorder_topological_sort(branch, recursive, ignore_nonstate_blocks) + if recursive: + for _, branch in block.branches: + if not ignore_nonstate_blocks: + yield branch + yield from blockorder_topological_sort(branch, recursive, ignore_nonstate_blocks) elif isinstance(block, SDFGState): yield block else: diff --git a/dace/sdfg/analysis/cutout.py b/dace/sdfg/analysis/cutout.py index ec95157989..20e8f8d6df 100644 --- a/dace/sdfg/analysis/cutout.py +++ b/dace/sdfg/analysis/cutout.py @@ -574,7 +574,7 @@ def _transformation_determine_affected_nodes( if transformation.cfg_id >= 0 and target_sdfg.cfg_list: target_sdfg = target_sdfg.cfg_list[transformation.cfg_id] - subgraph = transformation.get_subgraph(target_sdfg) + subgraph = transformation.subgraph_view(target_sdfg) for n in subgraph.nodes(): affected_nodes.add(n) diff --git a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py index 9357ca3db9..e0bc95ad34 100644 --- a/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py +++ b/dace/sdfg/analysis/schedule_tree/sdfg_to_tree.py @@ -652,7 +652,10 @@ def as_schedule_tree(sdfg: SDFG, in_place: bool = False, toplevel: bool = True) ############################# # Create initial tree from CFG - cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') + if sdfg.using_explicit_control_flow: + cfg: cf.ControlFlow = cf.structured_control_flow_tree_with_regions(sdfg, lambda _: '') + else: + cfg: cf.ControlFlow = cf.structured_control_flow_tree(sdfg, lambda _: '') # Traverse said tree (also into states) to create the schedule tree def totree(node: cf.ControlFlow, parent: cf.GeneralBlock = None) -> List[tn.ScheduleTreeNode]: diff --git a/dace/sdfg/analysis/writeset_underapproximation.py b/dace/sdfg/analysis/writeset_underapproximation.py index a0f84e93a6..0426cb0942 100644 --- a/dace/sdfg/analysis/writeset_underapproximation.py +++ b/dace/sdfg/analysis/writeset_underapproximation.py @@ -685,7 +685,7 @@ class UnderapproximateWritesDict: Tuple[SDFGState, SDFGState, List[SDFGState], str, subsets.Range]] = field(default_factory=dict) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class UnderapproximateWrites(ppl.Pass): # Dictionary mapping each edge to a copy of the memlet of that edge with its write set underapproximated. diff --git a/dace/sdfg/graph.py b/dace/sdfg/graph.py index b8dcfd4e8f..4b3bd82721 100644 --- a/dace/sdfg/graph.py +++ b/dace/sdfg/graph.py @@ -680,7 +680,7 @@ def add_edge(self, src: NodeT, dst: NodeT, data: EdgeT = None): def remove_node(self, node: NodeT): try: - for edge in itertools.chain(self.in_edges(node), self.out_edges(node)): + for edge in self.all_edges(node): self.remove_edge(edge) del self._nodes[node] self._nx.remove_node(node) diff --git a/dace/sdfg/infer_types.py b/dace/sdfg/infer_types.py index c05708670e..940114bbe2 100644 --- a/dace/sdfg/infer_types.py +++ b/dace/sdfg/infer_types.py @@ -1,7 +1,6 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from dace import data, dtypes -from dace.codegen.tools import type_inference from dace.memlet import Memlet from dace.sdfg import SDFG, SDFGState, nodes, validation from dace.sdfg import nodes diff --git a/dace/sdfg/nodes.py b/dace/sdfg/nodes.py index 2b75a73ffa..eb073f4319 100644 --- a/dace/sdfg/nodes.py +++ b/dace/sdfg/nodes.py @@ -1418,8 +1418,8 @@ def expand(self, sdfg, state, *args, **kwargs) -> str: if implementation not in self.implementations.keys(): raise KeyError("Unknown implementation for node {}: {}".format(type(self).__name__, implementation)) transformation_type = type(self).implementations[implementation] - cfg_id = sdfg.cfg_id - state_id = sdfg.nodes().index(state) + cfg_id = state.parent_graph.cfg_id + state_id = state.block_id subgraph = {transformation_type._match_node: state.node_id(self)} transformation: ExpandTransformation = transformation_type() transformation.setup_match(sdfg, cfg_id, state_id, subgraph, 0) diff --git a/dace/sdfg/performance_evaluation/helpers.py b/dace/sdfg/performance_evaluation/helpers.py index 552e2917cc..0272101562 100644 --- a/dace/sdfg/performance_evaluation/helpers.py +++ b/dace/sdfg/performance_evaluation/helpers.py @@ -34,9 +34,9 @@ def get_uuid(element, state=None): if isinstance(element, SDFG): return ids_to_string(element.cfg_id) elif isinstance(element, SDFGState): - return ids_to_string(element.parent.cfg_id, element.parent.node_id(element)) + return ids_to_string(element.parent_graph.cfg_id, element.block_id) elif isinstance(element, nodes.Node): - return ids_to_string(state.parent.cfg_id, state.parent.node_id(state), state.node_id(element)) + return ids_to_string(state.parent_graph.cfg_id, state.block_id, state.node_id(element)) else: return ids_to_string(-1) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index f048389421..2983ec3c63 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -9,7 +9,7 @@ import itertools import warnings from collections import deque -from typing import List, Set +from typing import TYPE_CHECKING, List, Set import sympy from sympy import Symbol, ceiling @@ -22,6 +22,11 @@ from dace.symbolic import issymbolic, pystr_to_symbolic, simplify +if TYPE_CHECKING: + from dace.sdfg import SDFG + from dace.sdfg.state import SDFGState + + @registry.make_registry class MemletPattern(object): """ @@ -561,7 +566,7 @@ def propagate(self, array, expressions, node_range): return subsets.Range(rng) -def _annotate_loop_ranges(sdfg, unannotated_cycle_states): +def _annotate_loop_ranges(sdfg: 'SDFG', unannotated_cycle_states): """ Annotate each valid for loop construct with its loop variable ranges. @@ -670,8 +675,8 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): loop_states = sdutils.dfs_conditional(sdfg, sources=[begin], condition=lambda _, child: child != guard) for v in loop_states: - v.ranges[itervar] = subsets.Range([rng]) - guard.ranges[itervar] = subsets.Range([rng]) + v.ranges[str(itervar)] = subsets.Range([rng]) + guard.ranges[str(itervar)] = subsets.Range([rng]) condition_edges[guard] = sdfg.edges_between(guard, begin)[0] guard.is_loop_guard = True guard.itvar = itervar @@ -682,7 +687,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): return condition_edges -def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: +def propagate_states(sdfg: 'SDFG', concretize_dynamic_unbounded: bool = False) -> None: """ Annotate the states of an SDFG with the number of executions. @@ -739,6 +744,15 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: :note: This operates on the SDFG in-place. """ + if sdfg.using_explicit_control_flow: + # Avoid cyclic imports + from dace.transformation.pass_pipeline import Pipeline + from dace.transformation.passes.analysis import StatePropagation + + state_prop_pipeline = Pipeline([StatePropagation()]) + state_prop_pipeline.apply_pass(sdfg, {}) + return + # We import here to avoid cyclic imports. from dace.sdfg import InterstateEdge from dace.sdfg.analysis import cfg @@ -948,7 +962,7 @@ def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: sdfg.remove_node(temp_exit_state) -def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): +def propagate_memlets_nested_sdfg(parent_sdfg: 'SDFG', parent_state: 'SDFGState', nsdfg_node: nodes.NestedSDFG): """ Propagate memlets out of a nested sdfg. @@ -980,7 +994,7 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): # the corresponding memlets and use them to calculate the memlet volume and # subset corresponding to the outside memlet attached to that connector. # This is passed out via `border_memlets` and propagated along from there. - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.data_nodes(): for direction in border_memlets: if (node.label not in border_memlets[direction]): @@ -1139,20 +1153,18 @@ def propagate_memlets_nested_sdfg(parent_sdfg, parent_state, nsdfg_node): oedge.data.dynamic = True -def reset_state_annotations(sdfg): +def reset_state_annotations(sdfg: 'SDFG'): """ Resets the state (loop-related) annotations of an SDFG. :note: This operation is shallow (does not go into nested SDFGs). """ - for state in sdfg.nodes(): + for state in sdfg.states(): state.executions = 0 state.dynamic_executions = True state.ranges = {} - state.is_loop_guard = False - state.itervar = None -def propagate_memlets_sdfg(sdfg): +def propagate_memlets_sdfg(sdfg: 'SDFG'): """ Propagates memlets throughout an entire given SDFG. :note: This is an in-place operation on the SDFG. @@ -1160,13 +1172,13 @@ def propagate_memlets_sdfg(sdfg): # Reset previous annotations first reset_state_annotations(sdfg) - for state in sdfg.nodes(): + for state in sdfg.states(): propagate_memlets_state(sdfg, state) propagate_states(sdfg) -def propagate_memlets_state(sdfg, state): +def propagate_memlets_state(sdfg: 'SDFG', state: 'SDFGState'): """ Propagates memlets throughout one SDFG state. :param sdfg: The SDFG in which the state is situated. diff --git a/dace/sdfg/replace.py b/dace/sdfg/replace.py index e3bea0b807..2f9ead4dcd 100644 --- a/dace/sdfg/replace.py +++ b/dace/sdfg/replace.py @@ -1,9 +1,9 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains functionality to perform find-and-replace of symbols in SDFGs. """ import re import warnings -from typing import TYPE_CHECKING, Any, Dict, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, Optional import sympy as sp @@ -96,6 +96,38 @@ def replace(subgraph: 'StateSubgraphView', name: str, new_name: str): replace_dict(subgraph, {name: new_name}) +def replace_in_codeblock(codeblock: properties.CodeBlock, repl: Dict[str, str], node: Optional[Any] = None): + code = codeblock.code + if isinstance(code, str) and code: + lang = codeblock.language + if lang is dtypes.Language.CPP: # Replace in C++ code + prefix = '' + tokenized = tokenize_cpp.findall(code) + active_replacements = set() + for name, new_name in repl.items(): + if name not in tokenized: + continue + # Use local variables and shadowing to replace + replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n' + prefix = replacement + prefix + active_replacements.add(name) + + if prefix: + codeblock.code = prefix + code + if node and isinstance(node, dace.nodes.Tasklet): + # Ignore replaced symbols since they no longer exist as reads + node.ignored_symbols = node.ignored_symbols.union(active_replacements) + + else: + warnings.warn('Replacement of %s with %s was not made ' + 'for string tasklet code of language %s' % (name, new_name, lang)) + + elif codeblock.code is not None: + afr = ASTFindReplace(repl) + for stmt in codeblock.code: + afr.visit(stmt) + + def replace_properties_dict(node: Any, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None): @@ -127,35 +159,7 @@ def replace_properties_dict(node: Any, if hasattr(node, 'in_connectors'): reduced_repl -= set(node.in_connectors.keys()) | set(node.out_connectors.keys()) reduced_repl = {k: repl[k] for k in reduced_repl} - code = propval.code - if isinstance(code, str) and code: - lang = propval.language - if lang is dtypes.Language.CPP: # Replace in C++ code - prefix = '' - tokenized = tokenize_cpp.findall(code) - active_replacements = set() - for name, new_name in reduced_repl.items(): - if name not in tokenized: - continue - # Use local variables and shadowing to replace - replacement = f'auto {name} = {cppunparse.pyexpr2cpp(new_name)};\n' - prefix = replacement + prefix - active_replacements.add(name) - - if prefix: - propval.code = prefix + code - if isinstance(node, dace.nodes.Tasklet): - # Ignore replaced symbols since they no longer exist as reads - node.ignored_symbols = node.ignored_symbols.union(active_replacements) - - else: - warnings.warn('Replacement of %s with %s was not made ' - 'for string tasklet code of language %s' % (name, new_name, lang)) - - elif propval.code is not None: - afr = ASTFindReplace(reduced_repl) - for stmt in propval.code: - afr.visit(stmt) + replace_in_codeblock(propval, reduced_repl, node) elif (isinstance(propclass, properties.DictProperty) and pname == 'symbol_mapping'): # Symbol mappings for nested SDFGs for symname, sym_mapping in propval.items(): @@ -196,3 +200,6 @@ def replace_datadesc_names(sdfg: 'dace.SDFG', repl: Dict[str, str]): for edge in block.edges(): if edge.data.data in repl: edge.data.data = repl[edge.data.data] + + # Replace in loop or branch conditions: + cf.replace_meta_accesses(repl) diff --git a/dace/sdfg/sdfg.py b/dace/sdfg/sdfg.py index e3aee6ca51..4a141aef12 100644 --- a/dace/sdfg/sdfg.py +++ b/dace/sdfg/sdfg.py @@ -460,8 +460,8 @@ class SDFG(ControlFlowRegion): desc='Mapping between callback name and its original callback ' '(for when the same callback is used with a different signature)') - using_experimental_blocks = Property(dtype=bool, default=False, - desc="Whether the SDFG contains experimental control flow blocks") + using_explicit_control_flow = Property(dtype=bool, default=False, + desc="Whether the SDFG contains explicit control flow constructs") def __init__(self, name: str, @@ -1306,7 +1306,8 @@ def _used_symbols_internal(self, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -1327,7 +1328,8 @@ def _used_symbols_internal(self, keep_defined_in_mapping=keep_defined_in_mapping, defined_syms=defined_syms, free_syms=free_syms, - used_before_assignment=used_before_assignment) + used_before_assignment=used_before_assignment, + with_contents=with_contents) def get_all_toplevel_symbols(self) -> Set[str]: """ @@ -1364,7 +1366,7 @@ def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set = set() write_set = set() for state in self.states(): - for edge in self.in_edges(state): + for edge in state.parent_graph.in_edges(state): read_set |= edge.data.free_symbols & self.arrays.keys() # Get dictionaries of subsets read and written from each state rs, ws = state._read_and_write_sets() @@ -1522,7 +1524,7 @@ def transients(self): result = {} tstate = {} - for (i, state) in enumerate(self.nodes()): + for (i, state) in enumerate(self.states()): scope_dict = state.scope_dict() for node in state.nodes(): if isinstance(node, nd.AccessNode) and node.desc(self).transient: @@ -2842,14 +2844,14 @@ def make_array_memlet(self, array: str): """ return dace.Memlet.from_array(array, self.data(array)) - def recheck_using_experimental_blocks(self) -> bool: - found_experimental_block = False + def recheck_using_explicit_control_flow(self) -> bool: + found_explicit_cf_block = False for node, graph in self.root_sdfg.all_nodes_recursive(): if isinstance(graph, ControlFlowRegion) and not isinstance(graph, SDFG): - found_experimental_block = True + found_explicit_cf_block = True break if isinstance(node, ControlFlowBlock) and not isinstance(node, SDFGState): - found_experimental_block = True + found_explicit_cf_block = True break - self.root_sdfg.using_experimental_blocks = found_experimental_block - return found_experimental_block + self.root_sdfg.using_explicit_control_flow = found_explicit_cf_block + return found_explicit_cf_block diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index 6c36ea108d..30640306cd 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -8,13 +8,12 @@ import inspect import itertools import warnings -from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Union, +from typing import (TYPE_CHECKING, Any, AnyStr, Callable, Dict, Iterable, Iterator, List, Optional, Set, Tuple, Type, Union, overload) -import sympy - import dace from dace.frontend.python import astutils +from dace.sdfg.replace import replace_in_codeblock import dace.serialize from dace import data as dt from dace import dtypes @@ -22,11 +21,11 @@ from dace import serialize from dace import subsets as sbs from dace import symbolic -from dace.properties import (CodeBlock, DebugInfoProperty, DictProperty, EnumProperty, Property, SubsetProperty, SymbolicProperty, - CodeProperty, make_properties) +from dace.properties import (CodeBlock, DebugInfoProperty, DictProperty, EnumProperty, Property, SubsetProperty, + SymbolicProperty, CodeProperty, make_properties) from dace.sdfg import nodes as nd -from dace.sdfg.graph import (MultiConnectorEdge, NodeNotFoundError, OrderedMultiDiConnectorGraph, SubgraphView, - OrderedDiGraph, Edge, generate_element_id) +from dace.sdfg.graph import (MultiConnectorEdge, NodeNotFoundError, OrderedMultiDiConnectorGraph, SubgraphView, OrderedDiGraph, Edge, + generate_element_id) from dace.sdfg.propagation import propagate_memlet from dace.sdfg.validation import validate_state from dace.subsets import Range, Subset @@ -217,13 +216,18 @@ def edges_by_connector(self, node: nd.Node, connector: AnyStr) -> Iterable[Multi # Query, subgraph, and replacement methods @abc.abstractmethod - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Set[str]: """ Returns a set of symbol names that are used in the graph. :param all_symbols: If False, only returns symbols that are needed as arguments (only used in generated code). :param keep_defined_in_mapping: If True, symbols defined in inter-state edges that are in the symbol mapping will be removed from the set of defined symbols. + :param with_contents: Compute the symbols used including the ones used by the contents of the graph. If set to + False, only symbols used on the BlockGraphView itself are returned. The latter may + include symbols used in the conditions of conditional blocks, loops, etc. Defaults to + True. """ return set() @@ -647,7 +651,11 @@ def is_leaf_memlet(self, e): return False return True - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Set[str]: + if not with_contents: + return set() + state = self.graph if isinstance(self, SubgraphView) else self sdfg = state.sdfg new_symbols = set() @@ -695,17 +703,6 @@ def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) new_symbols.update(set(sdfg.constants.keys())) return freesyms - new_symbols - @property - def free_symbols(self) -> Set[str]: - """ - Returns a set of symbol names that are used, but not defined, in - this graph view (SDFG state or subgraph thereof). - - :note: Assumes that the graph is valid (i.e., without undefined or - overlapping symbols). - """ - return self.used_symbols(all_symbols=True) - def defined_symbols(self) -> Dict[str, dt.Data]: """ Returns a dictionary that maps currently-defined symbols in this SDFG @@ -728,8 +725,12 @@ def update_if_not_none(dic, update): defined_syms[str(sym)] = sym.dtype # Add inter-state symbols - for edge in sdfg.dfs_edges(sdfg.start_state): + if isinstance(sdfg.start_block, AbstractControlFlowRegion): + update_if_not_none(defined_syms, sdfg.start_block.new_symbols(defined_syms)) + for edge in sdfg.all_interstate_edges(): update_if_not_none(defined_syms, edge.data.new_symbols(sdfg, defined_syms)) + if isinstance(edge.dst, AbstractControlFlowRegion): + update_if_not_none(defined_syms, edge.dst.new_symbols(defined_syms)) # Add scope symbols all the way to the subgraph sdict = state.scope_dict() @@ -1029,6 +1030,18 @@ def nodes(self) -> List['ControlFlowBlock']: def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: ... + @overload + def in_edges(self, node: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + + @overload + def out_edges(self, node: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + + @overload + def all_edges(self, node: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + ... + ################################################################### # Traversal methods @@ -1103,11 +1116,14 @@ def _used_symbols_internal(self, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: raise NotImplementedError() - def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False) -> Set[str]: - return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping)[0] + def used_symbols(self, all_symbols: bool, keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Set[str]: + return self._used_symbols_internal(all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, + with_contents=with_contents)[0] def read_and_write_sets(self) -> Tuple[Set[AnyStr], Set[AnyStr]]: read_set = set() @@ -1146,6 +1162,8 @@ def all_transients(self) -> List[str]: def replace(self, name: str, new_name: str): for n in self.nodes(): n.replace(name, new_name) + for e in self.edges(): + e.data.replace(name, new_name) def replace_dict(self, repl: Dict[str, str], @@ -1177,6 +1195,12 @@ class ControlFlowBlock(BlockGraphView, abc.ABC): pre_conditions = DictProperty(key_type=str, value_type=list, desc='Pre-conditions for this block') post_conditions = DictProperty(key_type=str, value_type=list, desc='Post-conditions for this block') invariant_conditions = DictProperty(key_type=str, value_type=list, desc='Invariant conditions for this block') + ranges = DictProperty(key_type=str, value_type=Range, default={}, + desc='Variable ranges across this block, typically within loops') + + executions = SymbolicProperty(default=0, + desc="The number of times this block gets executed (0 stands for unbounded)") + dynamic_executions = Property(dtype=bool, default=True, desc="The number of executions of this block is dynamic") _label: str @@ -1203,6 +1227,9 @@ def nodes(self): def edges(self): return [] + def sub_regions(self) -> List['AbstractControlFlowRegion']: + return [] + def set_default_lineinfo(self, lineinfo: dace.dtypes.DebugInfo): """ Sets the default source line information to be lineinfo, or None to @@ -1249,7 +1276,7 @@ def __deepcopy__(self, memo): result = cls.__new__(cls) memo[id(self)] = result for k, v in self.__dict__.items(): - if k in ('_parent_graph', '_sdfg', 'guid'): # Skip derivative attributes and GUID + if k in ('_parent_graph', '_sdfg', '_cfg_list', 'guid'): # Skip derivative attributes and GUID continue setattr(result, k, copy.deepcopy(v, memo)) @@ -1311,17 +1338,6 @@ class SDFGState(OrderedMultiDiConnectorGraph[nd.Node, mm.Memlet], ControlFlowBlo symbol_instrument_condition = CodeProperty(desc="Condition under which to trigger the symbol instrumentation", default=CodeBlock("1", language=dtypes.Language.CPP)) - executions = SymbolicProperty(default=0, - desc="The number of times this state gets " - "executed (0 stands for unbounded)") - dynamic_executions = Property(dtype=bool, default=True, desc="The number of executions of this state " - "is dynamic") - - ranges = DictProperty(key_type=symbolic.symbol, - value_type=Range, - default={}, - desc='Variable ranges, typically within loops') - location = DictProperty(key_type=str, value_type=symbolic.pystr_to_symbolic, desc='Full storage location identifier (e.g., rank, GPU ID)') @@ -2565,19 +2581,55 @@ def sdfg(self) -> 'SDFG': @make_properties -class ControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, - ControlFlowBlock): +class AbstractControlFlowRegion(OrderedDiGraph[ControlFlowBlock, 'dace.sdfg.InterstateEdge'], ControlGraphView, + ControlFlowBlock, abc.ABC): + """ + Abstract superclass to represent all kinds of control flow regions in an SDFG. + This is consequently one of the three main classes of control flow graph nodes, which include ``ControlFlowBlock``s, + ``SDFGState``s, and nested ``AbstractControlFlowRegion``s. An ``AbstractControlFlowRegion`` can further be either a + region that directly contains a control flow graph (``ControlFlowRegion``s and subclasses thereof), or something + that acts like and has the same utilities as a control flow region, including the same API, but is itself not + directly a single graph. An example of this is the ``ConditionalBlock``, which acts as a single control flow region + to the outside, but contains multiple actual graphs (one per branch). As such, there are very few but important + differences between the subclasses of ``AbstractControlFlowRegion``s, such as how traversals are performed, how many + start blocks there are, etc. + """ - def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None): + def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, + parent: Optional['AbstractControlFlowRegion'] = None): OrderedDiGraph.__init__(self) ControlGraphView.__init__(self) - ControlFlowBlock.__init__(self, label, sdfg) + ControlFlowBlock.__init__(self, label, sdfg, parent) self._labels: Set[str] = set() self._start_block: Optional[int] = None self._cached_start_block: Optional[ControlFlowBlock] = None self._cfg_list: List['ControlFlowRegion'] = [self] + def get_meta_codeblocks(self) -> List[CodeBlock]: + """ + Get a list of codeblocks used by the control flow region. + This may include things such as loop control statements or conditions for branching etc. + """ + return [] + + def get_meta_read_memlets(self) -> List[mm.Memlet]: + """ + Get read memlets used by the control flow region itself, such as in condition checks for conditional blocks, or + in loop conditions for loops etc. + """ + return [] + + def replace_meta_accesses(self, replacements: dict) -> None: + """ + Replace accesses to specific data containers in reads or writes performed by the control flow region itself in + meta accesses, such as in condition checks for conditional blocks or in loop conditions for loops, etc. + + :param replacements: A dictionary mapping the current data container names to the names of data containers with + which accesses to them should be replaced. + """ + pass + @property def root_sdfg(self) -> 'SDFG': from dace.sdfg.sdfg import SDFG # Avoid import loop @@ -2585,7 +2637,7 @@ def root_sdfg(self) -> 'SDFG': raise RuntimeError('Root CFG is not of type SDFG') return self.cfg_list[0] - def reset_cfg_list(self) -> List['ControlFlowRegion']: + def reset_cfg_list(self) -> List['AbstractControlFlowRegion']: """ Reset the CFG list when changes have been made to the SDFG's CFG tree. This collects all control flow graphs recursively and propagates the collection to all CFGs as the new CFG list. @@ -2635,29 +2687,37 @@ def state(self, state_id: int) -> SDFGState: raise TypeError(f'The node with id {state_id} is not an SDFGState') return node - def inline(self) -> Tuple[bool, Any]: + def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: """ Inlines the control flow region into its parent control flow region (if it exists). + :param lower_returns: Whether or not to remove explicit return blocks when inlining where possible. Defaults to + False. :return: True if the inlining succeeded, false otherwise. """ parent = self.parent_graph if parent: # Add all region states and make sure to keep track of all the ones that need to be connected in the end. - to_connect: Set[SDFGState] = set() + to_connect: Set[ControlFlowBlock] = set() + ends_context: Set[ControlFlowBlock] = set() block_to_state_map: Dict[ControlFlowBlock, SDFGState] = dict() for node in self.nodes(): node.label = self.label + '_' + node.label - if isinstance(node, ReturnBlock) and isinstance(parent, dace.SDFG): + if isinstance(node, ReturnBlock) and lower_returns and isinstance(parent, dace.SDFG): # If a return block is being inlined into an SDFG, convert it into a regular state. Otherwise it # remains as-is. newnode = parent.add_state(node.label) block_to_state_map[node] = newnode + if self.out_degree(node) == 0: + to_connect.add(newnode) + ends_context.add(newnode) else: parent.add_node(node, ensure_unique_name=True) - if self.out_degree(node) == 0 and not isinstance(node, (BreakBlock, ContinueBlock, ReturnBlock)): + if self.out_degree(node) == 0: to_connect.add(node) + if isinstance(node, (BreakBlock, ContinueBlock, ReturnBlock)): + ends_context.add(node) # Add all region edges. for edge in self.edges(): @@ -2679,16 +2739,10 @@ def inline(self) -> Tuple[bool, Any]: parent.remove_edge(a_edge) for node in to_connect: - parent.add_edge(node, end_state, dace.InterstateEdge()) - else: - # TODO: Move this to dead state elimination. - dead_blocks = [succ for succ in parent.successors(self) if parent.in_degree(succ) == 1] - while dead_blocks: - layer = list(dead_blocks) - dead_blocks.clear() - for u in layer: - dead_blocks.extend([succ for succ in parent.successors(u) if parent.in_degree(succ) == 1]) - parent.remove_node(u) + if node in ends_context: + parent.add_edge(node, end_state, dace.InterstateEdge(condition='False')) + else: + parent.add_edge(node, end_state, dace.InterstateEdge()) # Remove the original control flow region (self) from the parent graph. parent.remove_node(self) @@ -2699,6 +2753,12 @@ def inline(self) -> Tuple[bool, Any]: return False, None + def new_symbols(self, symbols: Dict[str, dtypes.typeclass]) -> Dict[str, dtypes.typeclass]: + """ + Returns a mapping between the symbol defined by this control flow region and its type, if it exists. + """ + return {} + ################################################################### # CFG API methods @@ -2747,9 +2807,13 @@ def add_node(self, self._cached_start_block = None node.parent_graph = self if isinstance(self, dace.SDFG): - node.sdfg = self + sdfg = self else: - node.sdfg = self.sdfg + sdfg = self.sdfg + node.sdfg = sdfg + if isinstance(node, AbstractControlFlowRegion): + for n in node.all_control_flow_blocks(): + n.sdfg = self.sdfg start_block = is_start_block if is_start_state is not None: warnings.warn('is_start_state is deprecated, use is_start_block instead', DeprecationWarning) @@ -2825,23 +2889,27 @@ def add_state_after(self, ################################################################### # Traversal methods - def all_control_flow_regions(self, recursive=False, load_ext=False) -> Iterator['ControlFlowRegion']: + def all_control_flow_regions(self, recursive=False, load_ext=False, + parent_first=True) -> Iterator['AbstractControlFlowRegion']: """ Iterate over this and all nested control flow regions. """ - yield self + if parent_first: + yield self for block in self.nodes(): if isinstance(block, SDFGState) and recursive: for node in block.nodes(): if isinstance(node, nd.NestedSDFG): if node.sdfg: - yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext) + yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext, + parent_first=parent_first) elif load_ext: node.load_external(block) - yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext) - elif isinstance(block, ControlFlowRegion): - yield from block.all_control_flow_regions(recursive=recursive, load_ext=load_ext) - elif isinstance(block, ConditionalBlock): - for _, branch in block.branches: - yield from branch.all_control_flow_regions(recursive=recursive, load_ext=load_ext) + yield from node.sdfg.all_control_flow_regions(recursive=recursive, load_ext=load_ext, + parent_first=parent_first) + elif isinstance(block, AbstractControlFlowRegion): + yield from block.all_control_flow_regions(recursive=recursive, load_ext=load_ext, + parent_first=parent_first) + if not parent_first: + yield self def all_sdfgs_recursive(self, load_ext=False) -> Iterator['SDFG']: """ Iterate over this and all nested SDFGs. """ @@ -2854,11 +2922,8 @@ def all_states(self) -> Iterator[SDFGState]: for block in self.nodes(): if isinstance(block, SDFGState): yield block - elif isinstance(block, ControlFlowRegion): + elif isinstance(block, AbstractControlFlowRegion): yield from block.all_states() - elif isinstance(block, ConditionalBlock): - for _, region in block.branches: - yield from region.all_states() def all_control_flow_blocks(self, recursive=False) -> Iterator[ControlFlowBlock]: """ Iterate over all control flow blocks in this control flow graph. """ @@ -2880,45 +2945,45 @@ def _used_symbols_internal(self, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment - try: - ordered_blocks = self.bfs_nodes(self.start_block) - except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) - ordered_blocks = self.nodes() - - for block in ordered_blocks: - state_symbols = set() - if isinstance(block, (ControlFlowRegion, ConditionalBlock)): - b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal(all_symbols, - defined_syms, - free_syms, - used_before_assignment, - keep_defined_in_mapping) - free_syms |= b_free_syms - defined_syms |= b_defined_syms - used_before_assignment |= b_used_before_syms - state_symbols = b_free_syms - else: - state_symbols = block.used_symbols(all_symbols, keep_defined_in_mapping) - free_syms |= state_symbols - - # Add free inter-state symbols - for e in self.out_edges(block): - # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols by - # subracting the (true) free symbols from the edge's assignment keys. This way we can correctly - # compute the symbols that are used before being assigned. - efsyms = e.data.used_symbols(all_symbols) - # collect symbols representing data containers - dsyms = {sym for sym in efsyms if sym in self.sdfg.arrays} - for d in dsyms: - efsyms |= {str(sym) for sym in self.sdfg.arrays[d].used_symbols(all_symbols)} - defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) - used_before_assignment.update(efsyms - defined_syms) - free_syms |= efsyms + if with_contents: + try: + ordered_blocks = self.bfs_nodes(self.start_block) + except ValueError: # Failsafe (e.g., for invalid or empty SDFGs) + ordered_blocks = self.nodes() + + for block in ordered_blocks: + state_symbols = set() + if isinstance(block, (ControlFlowRegion, ConditionalBlock)): + b_free_syms, b_defined_syms, b_used_before_syms = block._used_symbols_internal( + all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping, + with_contents) + free_syms |= b_free_syms + defined_syms |= b_defined_syms + used_before_assignment |= b_used_before_syms + state_symbols = b_free_syms + else: + state_symbols = block.used_symbols(all_symbols, keep_defined_in_mapping, with_contents) + free_syms |= state_symbols + + # Add free inter-state symbols + for e in self.out_edges(block): + # NOTE: First we get the true InterstateEdge free symbols, then we compute the newly defined symbols + # by subracting the (true) free symbols from the edge's assignment keys. This way we can correctly + # compute the symbols that are used before being assigned. + efsyms = e.data.used_symbols(all_symbols) + # collect symbols representing data containers + dsyms = {sym for sym in efsyms if sym in self.sdfg.arrays} + for d in dsyms: + efsyms |= {str(sym) for sym in self.sdfg.arrays[d].used_symbols(all_symbols)} + defined_syms |= set(e.data.assignments.keys()) - (efsyms | state_symbols) + used_before_assignment.update(efsyms - defined_syms) + free_syms |= efsyms # Remove symbols that were used before they were assigned. defined_syms -= used_before_assignment @@ -3028,7 +3093,19 @@ def start_block(self, block_id): if block_id < 0 or block_id >= self.number_of_nodes(): raise ValueError('Invalid state ID') self._start_block = block_id - self._cached_start_block = self.node(block_id) + self._cached_start_block = None + + +@make_properties +class ControlFlowRegion(AbstractControlFlowRegion): + """ + A ``ControlFlowRegion`` represents a control flow graph node that itself contains a control flow graph. + This can be an arbitrary control flow graph, but may also be a specific type of control flow region with additional + semantics, such as a loop or a function call. + """ + + def __init__(self, label = '', sdfg = None, parent = None): + super().__init__(label, sdfg, parent) @make_properties @@ -3112,10 +3189,12 @@ def __init__(self, self.inverted = inverted self.update_before_condition = update_before_condition - def inline(self) -> Tuple[bool, Any]: + def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: """ Inlines the loop region into its parent control flow region. + :param lower_returns: Whether or not to remove explicit return blocks when inlining where possible. Defaults to + False. :return: True if the inlining succeeded, false otherwise. """ parent = self.parent_graph @@ -3141,9 +3220,10 @@ def inline(self) -> Tuple[bool, Any]: # and return are inlined correctly. def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: for block in region.nodes(): - if (isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalBlock)) and not isinstance(block, LoopRegion): + if ((isinstance(block, ControlFlowRegion) or isinstance(block, ConditionalBlock)) + and not isinstance(block, LoopRegion)): recursive_inline_cf_regions(block) - block.inline() + block.inline(lower_returns=lower_returns) recursive_inline_cf_regions(self) # Add all boilerplate loop states necessary for the structure. @@ -3232,12 +3312,38 @@ def recursive_inline_cf_regions(region: ControlFlowRegion) -> None: return True, (init_state, guard_state, end_state) + def get_meta_codeblocks(self): + codes = [self.loop_condition] + if self.init_statement: + codes.append(self.init_statement) + if self.update_statement: + codes.append(self.update_statement) + return codes + + def get_meta_read_memlets(self) -> List[mm.Memlet]: + # Avoid cyclic imports. + from dace.sdfg.sdfg import memlets_in_ast + read_memlets = memlets_in_ast(self.loop_condition.code[0], self.sdfg.arrays) + if self.init_statement: + read_memlets.extend(memlets_in_ast(self.init_statement.code[0], self.sdfg.arrays)) + if self.update_statement: + read_memlets.extend(memlets_in_ast(self.update_statement.code[0], self.sdfg.arrays)) + return read_memlets + + def replace_meta_accesses(self, replacements): + replace_in_codeblock(self.loop_condition, replacements) + if self.init_statement: + replace_in_codeblock(self.init_statement, replacements) + if self.update_statement: + replace_in_codeblock(self.update_statement, replacements) + def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -3247,10 +3353,12 @@ def _used_symbols_internal(self, free_syms |= self.init_statement.get_free_symbols() if self.update_statement is not None: free_syms |= self.update_statement.get_free_symbols() - free_syms |= self.loop_condition.get_free_symbols() + cond_free_syms = self.loop_condition.get_free_symbols() + if self.loop_variable and self.loop_variable in cond_free_syms: + cond_free_syms.remove(self.loop_variable) b_free_symbols, b_defined_symbols, b_used_before_assignment = super()._used_symbols_internal( - all_symbols, keep_defined_in_mapping=keep_defined_in_mapping) + all_symbols, keep_defined_in_mapping=keep_defined_in_mapping, with_contents=with_contents) outside_defined = defined_syms - used_before_assignment used_before_assignment |= ((b_used_before_assignment - {self.loop_variable}) - outside_defined) free_syms |= b_free_symbols @@ -3258,21 +3366,41 @@ def _used_symbols_internal(self, defined_syms -= used_before_assignment free_syms -= defined_syms + free_syms |= cond_free_syms return free_syms, defined_syms, used_before_assignment + def new_symbols(self, symbols) -> Dict[str, dtypes.typeclass]: + # Avoid cyclic import + from dace.codegen.tools.type_inference import infer_expr_type + from dace.transformation.passes.analysis import loop_analysis + + if self.init_statement and self.loop_variable: + alltypes = copy.copy(symbols) + alltypes.update({k: v.dtype for k, v in self.sdfg.arrays.items()}) + l_end = loop_analysis.get_loop_end(self) + l_start = loop_analysis.get_init_assignment(self) + l_step = loop_analysis.get_loop_stride(self) + inferred_type = dtypes.result_type_of(infer_expr_type(l_start, alltypes), + infer_expr_type(l_step, alltypes), + infer_expr_type(l_end, alltypes)) + init_rhs = loop_analysis.get_init_assignment(self) + if self.loop_variable not in symbolic.free_symbols_and_functions(init_rhs): + return {self.loop_variable: inferred_type} + return {} + def replace_dict(self, repl: Dict[str, str], symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, replace_in_graph: bool = True, replace_keys: bool = True): if replace_keys: - from dace.sdfg.replace import replace_properties_dict - replace_properties_dict(self, repl, symrepl) - if self.loop_variable and self.loop_variable in repl: self.loop_variable = repl[self.loop_variable] + from dace.sdfg.replace import replace_properties_dict + replace_properties_dict(self, repl, symrepl) + super().replace_dict(repl, symrepl, replace_in_graph) def add_break(self, label=None) -> BreakBlock: @@ -3312,7 +3440,7 @@ def has_return(self) -> bool: @make_properties -class ConditionalBlock(ControlFlowBlock, ControlGraphView): +class ConditionalBlock(AbstractControlFlowRegion): _branches: List[Tuple[Optional[CodeBlock], ControlFlowRegion]] @@ -3320,6 +3448,14 @@ def __init__(self, label: str = '', sdfg: Optional['SDFG'] = None, parent: Optio super().__init__(label, sdfg, parent) self._branches = [] + def sub_regions(self): + return [b for _, b in self.branches] + + def replace_meta_accesses(self, replacements): + for c, _ in self.branches: + if c is not None: + replace_in_codeblock(c, replacements) + def __str__(self): return self._label @@ -3332,21 +3468,35 @@ def branches(self) -> List[Tuple[Optional[CodeBlock], ControlFlowRegion]]: def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): self._branches.append([condition, branch]) - branch.parent_graph = self.parent_graph + branch.parent_graph = self branch.sdfg = self.sdfg - def nodes(self) -> List['ControlFlowBlock']: - return [node for _, node in self._branches if node is not None] - - def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: - return [] + def remove_branch(self, branch: ControlFlowRegion): + self._branches = [(c, b) for c, b in self._branches if b is not branch] + + def get_meta_codeblocks(self): + codes = [] + for c, _ in self.branches: + if c is not None: + codes.append(c) + return codes + + def get_meta_read_memlets(self) -> List[mm.Memlet]: + # Avoid cyclic imports. + from dace.sdfg.sdfg import memlets_in_ast + read_memlets = [] + for c, _ in self.branches: + if c is not None: + read_memlets.extend(memlets_in_ast(c.code[0], self.sdfg.arrays)) + return read_memlets def _used_symbols_internal(self, all_symbols: bool, defined_syms: Optional[Set] = None, free_syms: Optional[Set] = None, used_before_assignment: Optional[Set] = None, - keep_defined_in_mapping: bool = False) -> Tuple[Set[str], Set[str], Set[str]]: + keep_defined_in_mapping: bool = False, + with_contents: bool = True) -> Tuple[Set[str], Set[str], Set[str]]: defined_syms = set() if defined_syms is None else defined_syms free_syms = set() if free_syms is None else free_syms used_before_assignment = set() if used_before_assignment is None else used_before_assignment @@ -3355,7 +3505,7 @@ def _used_symbols_internal(self, if condition is not None: free_syms |= condition.get_free_symbols(defined_syms) b_free_symbols, b_defined_symbols, b_used_before_assignment = region._used_symbols_internal( - all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping) + all_symbols, defined_syms, free_syms, used_before_assignment, keep_defined_in_mapping, with_contents) free_syms |= b_free_symbols defined_syms |= b_defined_symbols used_before_assignment |= b_used_before_assignment @@ -3370,12 +3520,17 @@ def replace_dict(self, symrepl: Optional[Dict[symbolic.SymbolicType, symbolic.SymbolicType]] = None, replace_in_graph: bool = True, replace_keys: bool = True): + # Avoid circular imports + from dace.sdfg.replace import replace_in_codeblock + if replace_keys: from dace.sdfg.replace import replace_properties_dict replace_properties_dict(self, repl, symrepl) - for _, region in self._branches: + for cond, region in self._branches: region.replace_dict(repl, symrepl, replace_in_graph) + if cond is not None: + replace_in_codeblock(cond, repl) def to_json(self, parent=None): json = super().to_json(parent) @@ -3396,15 +3551,17 @@ def from_json(cls, json_obj, context=None): for condition, region in json_obj['branches']: if condition is not None: - ret._branches.append((CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context))) + ret.add_branch(CodeBlock.from_json(condition), ControlFlowRegion.from_json(region, context)) else: - ret._branches.append((None, ControlFlowRegion.from_json(region, context))) + ret.add_branch(None, ControlFlowRegion.from_json(region, context)) return ret - - def inline(self) -> Tuple[bool, Any]: + + def inline(self, lower_returns: bool = False) -> Tuple[bool, Any]: """ Inlines the conditional region into its parent control flow region. + :param lower_returns: Whether or not to remove explicit return blocks when inlining where possible. Defaults to + False. :return: True if the inlining succeeded, false otherwise. """ parent = self.parent_graph @@ -3438,6 +3595,7 @@ def inline(self) -> Tuple[bool, Any]: parent.add_node(region) parent.add_edge(guard_state, region, InterstateEdge(condition=condition)) parent.add_edge(region, end_state, InterstateEdge()) + region.inline(lower_returns=lower_returns) if full_cond_expression is not None: negative_full_cond = astutils.negate_expr(full_cond_expression) negative_cond = CodeBlock([negative_full_cond]) @@ -3447,7 +3605,8 @@ def inline(self) -> Tuple[bool, Any]: if else_branch is not None: parent.add_node(else_branch) parent.add_edge(guard_state, else_branch, InterstateEdge(condition=negative_cond)) - parent.add_edge(region, end_state, InterstateEdge()) + parent.add_edge(else_branch, end_state, InterstateEdge()) + else_branch.inline(lower_returns=lower_returns) else: parent.add_edge(guard_state, end_state, InterstateEdge(condition=negative_cond)) @@ -3458,6 +3617,42 @@ def inline(self) -> Tuple[bool, Any]: return True, (guard_state, end_state) + # Abstract control flow region overrides + + @property + def start_block(self): + return None + + @start_block.setter + def start_block(self, _): + pass + + # Graph API overrides. + + def node_id(self, node: 'ControlFlowBlock') -> int: + try: + return next(i for i, (_, b) in enumerate(self._branches) if b is node) + except StopIteration: + raise NodeNotFoundError(node) + + def nodes(self) -> List['ControlFlowBlock']: + return [node for _, node in self._branches] + + def number_of_nodes(self): + return len(self._branches) + + def edges(self) -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + + def in_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + + def out_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + + def all_edges(self, _: 'ControlFlowBlock') -> List[Edge['dace.sdfg.InterstateEdge']]: + return [] + @make_properties class NamedRegion(ControlFlowRegion): diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 623a82c5bf..26b6629a81 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -11,14 +11,15 @@ import dace.sdfg.nodes from dace.codegen import compiled_sdfg as csdfg from dace.sdfg.graph import MultiConnectorEdge -from dace.sdfg.sdfg import SDFG +from dace.sdfg.sdfg import SDFG, InterstateEdge from dace.sdfg.nodes import Node, NestedSDFG -from dace.sdfg.state import ConditionalBlock, SDFGState, StateSubgraphView, LoopRegion, ControlFlowRegion +from dace.sdfg.state import (AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, SDFGState, + StateSubgraphView, LoopRegion, ControlFlowRegion) from dace.sdfg.scope import ScopeSubgraphView from dace.sdfg import nodes as nd, graph as gr, propagation from dace import config, data as dt, dtypes, memlet as mm, subsets as sbs from dace.cli.progress import optional_progressbar -from typing import Any, Callable, Dict, Generator, List, Optional, Set, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, List, Optional, Set, Sequence, Tuple, Type, Union def node_path_graph(*args) -> gr.OrderedDiGraph: @@ -1235,7 +1236,8 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> shows progress bar. :return: The total number of states fused. """ - from dace.transformation.interstate import StateFusion # Avoid import loop + from dace.transformation.interstate import StateFusion, BlockFusion # Avoid import loop + if progress is None and not config.Config.get_bool('progress'): progress = False @@ -1268,20 +1270,33 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> progress = True pbar = tqdm(total=fusible_states, desc='Fusing states', initial=counter) - if (u in skip_nodes or v in skip_nodes or not isinstance(v, SDFGState) - or not isinstance(u, SDFGState)): + if u in skip_nodes or v in skip_nodes: continue - candidate = {StateFusion.first_state: u, StateFusion.second_state: v} - sf = StateFusion() - sf.setup_match(cfg, cfg.cfg_id, -1, candidate, 0, override=True) - if sf.can_be_applied(cfg, 0, sd, permissive=permissive): - sf.apply(cfg, sd) - applied += 1 - counter += 1 - if progress: - pbar.update(1) - skip_nodes.add(u) - skip_nodes.add(v) + + if isinstance(u, SDFGState) and isinstance(v, SDFGState): + candidate = {StateFusion.first_state: u, StateFusion.second_state: v} + sf = StateFusion() + sf.setup_match(cfg, cfg.cfg_id, -1, candidate, 0, override=True) + if sf.can_be_applied(cfg, 0, sd, permissive=permissive): + sf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) + else: + candidate = {BlockFusion.first_block: u, BlockFusion.second_block: v} + bf = BlockFusion() + bf.setup_match(cfg, cfg.cfg_id, -1, candidate, 0, override=True) + if bf.can_be_applied(cfg, 0, sd, permissive=permissive): + bf.apply(cfg, sd) + applied += 1 + counter += 1 + if progress: + pbar.update(1) + skip_nodes.add(u) + skip_nodes.add(v) if applied == 0: break if progress: @@ -1289,20 +1304,17 @@ def fuse_states(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> return counter -def inline_loop_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, LoopRegion)] - count = 0 - - for _block in optional_progressbar(reversed(blocks), title='Inlining Loops', n=len(blocks), progress=progress): - block: LoopRegion = _block - if block.inline()[0]: - count += 1 - - return count - - -def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ControlFlowRegion)] +def inline_control_flow_regions(sdfg: SDFG, types: Optional[List[Type[AbstractControlFlowRegion]]] = None, + ignore_region_types: Optional[List[Type[AbstractControlFlowRegion]]] = None, + progress: bool = None, lower_returns: bool = False, + eliminate_dead_states: bool = False) -> int: + if types: + blocks = [n for n, _ in sdfg.all_nodes_recursive() if type(n) in types] + elif ignore_region_types: + blocks = [n for n, _ in sdfg.all_nodes_recursive() + if isinstance(n, AbstractControlFlowRegion) and type(n) not in ignore_region_types] + else: + blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, AbstractControlFlowRegion)] count = 0 for _block in optional_progressbar(reversed(blocks), @@ -1310,23 +1322,17 @@ def inline_control_flow_regions(sdfg: SDFG, permissive: bool = False, progress: n=len(blocks), progress=progress): block: ControlFlowRegion = _block - if block.inline()[0]: + # Control flow regions where the parent is a conditional block are not inlined. + if block.parent_graph and type(block.parent_graph) == ConditionalBlock: + continue + if block.inline(lower_returns=lower_returns)[0]: count += 1 + if eliminate_dead_states: + # Avoid cyclic imports. + from dace.transformation.passes.dead_state_elimination import DeadStateElimination + DeadStateElimination().apply_pass(sdfg, {}) - return count - - -def inline_conditional_blocks(sdfg: SDFG, permissive: bool = False, progress: bool = None) -> int: - blocks = [n for n, _ in sdfg.all_nodes_recursive() if isinstance(n, ConditionalBlock)] - count = 0 - - for _block in optional_progressbar(reversed(blocks), - title='Inlining conditional blocks', - n=len(blocks), - progress=progress): - block: ConditionalBlock = _block - if block.inline()[0]: - count += 1 + sdfg.reset_cfg_list() return count @@ -1455,11 +1461,12 @@ def get_next_nonempty_states(sdfg: SDFG, state: SDFGState) -> Set[SDFGState]: result: Set[SDFGState] = set() # Traverse children until states are not empty - for succ in sdfg.successors(state): - result |= set(dfs_conditional(sdfg, sources=[succ], condition=lambda parent, _: parent.is_empty())) + for succ in state.parent_graph.successors(state): + result |= set(dfs_conditional(state.parent_graph, sources=[succ], + condition=lambda parent, _: parent.number_of_nodes() == 0)) # Filter out empty states - result = {s for s in result if not s.is_empty()} + result = {s for s in result if not s.number_of_nodes() == 0} return result @@ -1550,42 +1557,48 @@ def _traverse(scope: Node, symbols: Dict[str, dtypes.typeclass]): def _tswds_cf_region( sdfg: SDFG, - region: ControlFlowRegion, + cfg: AbstractControlFlowRegion, symbols: Dict[str, dtypes.typeclass], recursive: bool = False) -> Generator[Tuple[SDFGState, Node, Dict[str, dtypes.typeclass]], None, None]: - # Add symbols from inter-state edges along the state machine - start_region = region.start_block - visited = set() - visited_edges = set() - for edge in region.dfs_edges(start_region): - # Source -> inter-state definition -> Destination - visited_edges.add(edge) - # Source - if edge.src not in visited: - visited.add(edge.src) - if isinstance(edge.src, SDFGState): - yield from _tswds_state(sdfg, edge.src, {}, recursive) - elif isinstance(edge.src, ControlFlowRegion): - yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) - - # Add edge symbols into defined symbols - issyms = edge.data.new_symbols(sdfg, symbols) - symbols.update({k: v for k, v in issyms.items() if v is not None}) - - # Destination - if edge.dst not in visited: - visited.add(edge.dst) - if isinstance(edge.dst, SDFGState): - yield from _tswds_state(sdfg, edge.dst, symbols, recursive) - elif isinstance(edge.dst, ControlFlowRegion): - yield from _tswds_cf_region(sdfg, edge.dst, symbols, recursive) - - # If there is only one state, the DFS will miss it - if start_region not in visited: - if isinstance(start_region, SDFGState): - yield from _tswds_state(sdfg, start_region, symbols, recursive) - elif isinstance(start_region, ControlFlowRegion): - yield from _tswds_cf_region(sdfg, start_region, symbols, recursive) + sub_regions = cfg.sub_regions() or [cfg] + for region in sub_regions: + # Add symbols newly defined by this region, if present. + region_symbols = region.new_symbols(symbols) + symbols.update({k: v for k, v in region_symbols.items() if v is not None}) + + # Add symbols from inter-state edges along the state machine + start_region = region.start_block + visited = set() + visited_edges = set() + for edge in region.dfs_edges(start_region): + # Source -> inter-state definition -> Destination + visited_edges.add(edge) + # Source + if edge.src not in visited: + visited.add(edge.src) + if isinstance(edge.src, SDFGState): + yield from _tswds_state(sdfg, edge.src, {}, recursive) + elif isinstance(edge.src, AbstractControlFlowRegion): + yield from _tswds_cf_region(sdfg, edge.src, symbols, recursive) + + # Add edge symbols into defined symbols + issyms = edge.data.new_symbols(sdfg, symbols) + symbols.update({k: v for k, v in issyms.items() if v is not None}) + + # Destination + if edge.dst not in visited: + visited.add(edge.dst) + if isinstance(edge.dst, SDFGState): + yield from _tswds_state(sdfg, edge.dst, symbols, recursive) + elif isinstance(edge.dst, AbstractControlFlowRegion): + yield from _tswds_cf_region(sdfg, edge.dst, symbols, recursive) + + # If there is only one state, the DFS will miss it + if start_region not in visited: + if isinstance(start_region, SDFGState): + yield from _tswds_state(sdfg, start_region, symbols, recursive) + elif isinstance(start_region, AbstractControlFlowRegion): + yield from _tswds_cf_region(sdfg, start_region, symbols, recursive) def traverse_sdfg_with_defined_symbols( @@ -1630,41 +1643,44 @@ def is_fpga_kernel(sdfg, state): return at_least_one_fpga_array +CFBlockDictT = Dict[ControlFlowBlock, ControlFlowBlock] + + def postdominators( - sdfg: SDFG, + cfg: ControlFlowRegion, return_alldoms: bool = False -) -> Optional[Union[Dict[SDFGState, SDFGState], Tuple[Dict[SDFGState, SDFGState], Dict[SDFGState, Set[SDFGState]]]]]: +) -> Optional[Union[CFBlockDictT, Tuple[CFBlockDictT, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]]]: """ - Return the immediate postdominators of an SDFG. This may require creating new nodes and removing them, which - happens in-place on the SDFG. + Return the immediate postdominators of a CFG. This may require creating new nodes and removing them, which + happens in-place on the CFG. - :param sdfg: The SDFG to generate the postdominators from. + :param cfg: The CFG to generate the postdominators from. :param return_alldoms: If True, returns the "all postdominators" dictionary as well. :return: Immediate postdominators, or a 2-tuple of (ipostdom, allpostdoms) if ``return_alldoms`` is True. """ - from dace.sdfg.analysis import cfg + from dace.sdfg.analysis import cfg as cfg_analysis # Get immediate post-dominators - sink_nodes = sdfg.sink_nodes() + sink_nodes = cfg.sink_nodes() if len(sink_nodes) > 1: - sink = sdfg.add_state() + sink = cfg.add_state() for snode in sink_nodes: - sdfg.add_edge(snode, sink, dace.InterstateEdge()) + cfg.add_edge(snode, sink, dace.InterstateEdge()) elif len(sink_nodes) == 0: return None else: sink = sink_nodes[0] - ipostdom: Dict[SDFGState, SDFGState] = nx.immediate_dominators(sdfg._nx.reverse(), sink) + ipostdom: CFBlockDictT = nx.immediate_dominators(cfg._nx.reverse(), sink) if return_alldoms: - allpostdoms = cfg.all_dominators(sdfg, ipostdom) + allpostdoms = cfg_analysis.all_dominators(cfg, ipostdom) retval = (ipostdom, allpostdoms) else: retval = ipostdom # If a new sink was added for post-dominator computation, remove it if len(sink_nodes) > 1: - sdfg.remove_node(sink) + cfg.remove_node(sink) return retval @@ -1960,3 +1976,119 @@ def get_global_memlet_path_dst(sdfg: SDFG, state: SDFGState, edge: MultiConnecto pedge = pedges[0] return get_global_memlet_path_dst(psdfg, pstate, pedge) return dst + + +def get_control_flow_block_dominators(sdfg: SDFG, + idom: Optional[Dict[ControlFlowBlock, ControlFlowBlock]] = None, + all_dom: Optional[Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = None, + ipostdom: Optional[Dict[ControlFlowBlock, ControlFlowBlock]] = None, + all_postdom: Optional[Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = None): + """ + Find the dominator and postdominator relationship between control flow blocks of an SDFG. + This transitively computes the domination relationship across control flow regions, as if the SDFG were to be + inlined entirely. + + :param idom: A dictionary in which to store immediate dominator relationships. Not computed if None. + :param all_dom: A dictionary in which to store all dominator relationships. Not computed if None. + :param ipostdom: A dictionary in which to store immediate postdominator relationships. Not computed if None. + :param all_postdom: A dictionary in which to all postdominator relationships. Not computed if None. + """ + # Avoid cyclic import + from dace.sdfg.analysis import cfg as cfg_analysis + + if idom is not None or all_dom is not None: + added_sinks: Dict[AbstractControlFlowRegion, SDFGState] = {} + if idom is None: + idom = {} + for cfg in sdfg.all_control_flow_regions(parent_first=True): + if isinstance(cfg, ConditionalBlock): + continue + sinks = cfg.sink_nodes() + if len(sinks) > 1: + added_sinks[cfg] = cfg.add_state() + for s in sinks: + cfg.add_edge(s, added_sinks[cfg], InterstateEdge()) + idom.update(nx.immediate_dominators(cfg.nx, cfg.start_block)) + # Compute the transitive relationship of immediate dominators: + # - For every start state in a control flow region, the immediate dominator is the immediate dominator of the + # parent control flow region. + # - If the immediate dominator is a conditional or a loop, change the immediate dominator to be the immediate + # dominator of that loop or conditional. + # - If the immediate dominator is any other control flow region, change the immediate dominator to be the + # immediate dominator of that region's end / exit - or a virtual one if no single one exists. + for k, _ in idom.items(): + if k.parent_graph is not sdfg and k is k.parent_graph.start_block: + next_dom = idom[k.parent_graph] + while next_dom.parent_graph is not sdfg and next_dom is next_dom.parent_graph.start_block: + next_dom = idom[next_dom.parent_graph] + idom[k] = next_dom + changed = True + while changed: + changed = False + for k, v in idom.items(): + if isinstance(v, AbstractControlFlowRegion): + if isinstance(v, (LoopRegion, ConditionalBlock)): + idom[k] = idom[v] + else: + if v in added_sinks: + idom[k] = idom[added_sinks[v]] + else: + idom[k] = v.sink_nodes()[0] + if idom[k] is not v: + changed = True + + for cf, v in added_sinks.items(): + cf.remove_node(v) + + if all_dom is not None: + all_dom.update(cfg_analysis.all_dominators(sdfg, idom)) + + if ipostdom is not None or all_postdom is not None: + added_sinks: Dict[AbstractControlFlowRegion, SDFGState] = {} + sinks_per_cfg: Dict[AbstractControlFlowRegion, ControlFlowBlock] = {} + if ipostdom is None: + ipostdom = {} + + for cfg in sdfg.all_control_flow_regions(parent_first=True): + if isinstance(cfg, ConditionalBlock): + continue + # Get immediate post-dominators + sink_nodes = cfg.sink_nodes() + if len(sink_nodes) > 1: + sink = cfg.add_state() + added_sinks[cfg] = sink + sinks_per_cfg[cfg] = sink + for snode in sink_nodes: + cfg.add_edge(snode, sink, dace.InterstateEdge()) + elif len(sink_nodes) == 0: + return None + else: + sink = sink_nodes[0] + sinks_per_cfg[cfg] = sink + ipostdom.update(nx.immediate_dominators(cfg._nx.reverse(), sink)) + + # Compute the transitive relationship of immediate postdominators, similar to how it works for immediate + # dominators, but inverse. + for k, _ in ipostdom.items(): + if k.parent_graph is not sdfg and k is sinks_per_cfg[k.parent_graph]: + next_pdom = ipostdom[k.parent_graph] + while next_pdom.parent_graph is not sdfg and next_pdom is sinks_per_cfg[next_pdom.parent_graph]: + next_pdom = ipostdom[next_pdom.parent_graph] + ipostdom[k] = next_pdom + changed = True + while changed: + changed = False + for k, v in ipostdom.items(): + if isinstance(v, AbstractControlFlowRegion): + if isinstance(v, (LoopRegion, ConditionalBlock)): + ipostdom[k] = ipostdom[v] + else: + ipostdom[k] = v.start_block + if ipostdom[k] is not v: + changed = True + + for cf, v in added_sinks.items(): + cf.remove_node(v) + + if all_postdom is not None: + all_postdom.update(cfg_analysis.all_dominators(sdfg, ipostdom)) diff --git a/dace/sdfg/validation.py b/dace/sdfg/validation.py index 4cb6856415..b030d85466 100644 --- a/dace/sdfg/validation.py +++ b/dace/sdfg/validation.py @@ -39,6 +39,7 @@ def validate_control_flow_region(sdfg: 'SDFG', symbols: dict, references: Set[int] = None, **context: bool): + from dace.sdfg.state import SDFGState, ControlFlowRegion, ConditionalBlock, LoopRegion from dace.sdfg.scope import is_in_scope from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, SDFGState @@ -75,8 +76,15 @@ def validate_control_flow_region(sdfg: 'SDFG', if isinstance(edge.src, SDFGState): validate_state(edge.src, region.node_id(edge.src), sdfg, symbols, initialized_transients, references, **context) + elif isinstance(edge.src, ConditionalBlock): + for _, r in edge.src.branches: + if r is not None: + validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.src, ControlFlowRegion): - validate_control_flow_region(sdfg, edge.src, initialized_transients, symbols, references, **context) + lsyms = copy.copy(symbols) + if isinstance(edge.src, LoopRegion) and not edge.src.loop_variable in lsyms: + lsyms[edge.src.loop_variable] = None + validate_control_flow_region(sdfg, edge.src, initialized_transients, lsyms, references, **context) ########################################## # Edge @@ -139,7 +147,10 @@ def validate_control_flow_region(sdfg: 'SDFG', if r is not None: validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(edge.dst, ControlFlowRegion): - validate_control_flow_region(sdfg, edge.dst, initialized_transients, symbols, references, **context) + lsyms = copy.copy(symbols) + if isinstance(edge.dst, LoopRegion) and not edge.dst.loop_variable in lsyms: + lsyms[edge.dst.loop_variable] = None + validate_control_flow_region(sdfg, edge.dst, initialized_transients, lsyms, references, **context) # End of block DFS # If there is only one block, the DFS will miss it @@ -147,8 +158,15 @@ def validate_control_flow_region(sdfg: 'SDFG', if isinstance(start_block, SDFGState): validate_state(start_block, region.node_id(start_block), sdfg, symbols, initialized_transients, references, **context) + elif isinstance(start_block, ConditionalBlock): + for _, r in start_block.branches: + if r is not None: + validate_control_flow_region(sdfg, r, initialized_transients, symbols, references, **context) elif isinstance(start_block, ControlFlowRegion): - validate_control_flow_region(sdfg, start_block, initialized_transients, symbols, references, **context) + lsyms = copy.copy(symbols) + if isinstance(start_block, LoopRegion) and not start_block.loop_variable in lsyms: + lsyms[start_block.loop_variable] = None + validate_control_flow_region(sdfg, start_block, initialized_transients, lsyms, references, **context) # Validate all inter-state edges (including self-loops not found by DFS) for eid, edge in enumerate(region.edges()): diff --git a/dace/transformation/__init__.py b/dace/transformation/__init__.py index 0b27542ca6..7f1a1fb064 100644 --- a/dace/transformation/__init__.py +++ b/dace/transformation/__init__.py @@ -1,4 +1,4 @@ from .transformation import (PatternNode, PatternTransformation, SingleStateTransformation, MultiStateTransformation, SubgraphTransformation, ExpandTransformation, - experimental_cfg_block_compatible, single_level_sdfg_only) + explicit_cf_compatible, single_level_sdfg_only) from .pass_pipeline import Pass, Pipeline, FixedPointPipeline diff --git a/dace/transformation/auto/auto_optimize.py b/dace/transformation/auto/auto_optimize.py index 0a1371ae68..0c74842634 100644 --- a/dace/transformation/auto/auto_optimize.py +++ b/dace/transformation/auto/auto_optimize.py @@ -577,8 +577,6 @@ def auto_optimize(sdfg: SDFG, sdfg.apply_transformations_repeated(TrivialMapElimination, validate=validate, validate_all=validate_all) while transformed: sdfg.simplify(validate=False, validate_all=validate_all) - for s in sdfg.cfg_list: - xfh.split_interstate_edges(s) l2ms = sdfg.apply_transformations_repeated((LoopToMap, RefineNestedAccess), validate=False, validate_all=validate_all) @@ -598,6 +596,7 @@ def auto_optimize(sdfg: SDFG, # fuse subgraphs greedily sdfg.simplify() + sdfg.reset_cfg_list() greedy_fuse(sdfg, device=device, validate_all=validate_all) @@ -668,6 +667,8 @@ def auto_optimize(sdfg: SDFG, print("Specializing the SDFG for symbols", known_symbols) sdfg.specialize(known_symbols) + sdfg.reset_cfg_list() + # Validate at the end if validate or validate_all: sdfg.validate() diff --git a/dace/transformation/dataflow/__init__.py b/dace/transformation/dataflow/__init__.py index 6fa274f041..e12ee8e1a9 100644 --- a/dace/transformation/dataflow/__init__.py +++ b/dace/transformation/dataflow/__init__.py @@ -5,7 +5,7 @@ from .mapreduce import MapReduceFusion, MapWCRFusion from .map_expansion import MapExpansion from .map_collapse import MapCollapse -from .map_for_loop import MapToForLoop, MapToForLoopRegion +from .map_for_loop import MapToForLoop from .map_interchange import MapInterchange from .map_dim_shuffle import MapDimShuffle from .map_fusion import MapFusion diff --git a/dace/transformation/dataflow/double_buffering.py b/dace/transformation/dataflow/double_buffering.py index bb42aa57ac..e0bc76818d 100644 --- a/dace/transformation/dataflow/double_buffering.py +++ b/dace/transformation/dataflow/double_buffering.py @@ -127,8 +127,8 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add initial reads to initial nested state - initial_state: sd.SDFGState = nsdfg_node.sdfg.start_state - initial_state.label = '%s_init' % map_entry.map.label + loop_block = nsdfg_node.sdfg.start_block + initial_state = nsdfg_node.sdfg.add_state_before(loop_block, '%s_init' % map_entry.map.label) for edge in edges_to_replace: initial_state.add_node(edge.src) rnode = edge.src @@ -151,8 +151,7 @@ def apply(self, graph: sd.SDFGState, sdfg: sd.SDFG): ############################## # Add the main state's contents to the last state, modifying # memlets appropriately. - final_state: sd.SDFGState = nsdfg_node.sdfg.sink_nodes()[0] - final_state.label = '%s_final_computation' % map_entry.map.label + final_state = nsdfg_node.sdfg.add_state_after(loop_block, '%s_final_computation' % map_entry.map.label) dup_nstate = copy.deepcopy(nstate) final_state.add_nodes_from(dup_nstate.nodes()) for e in dup_nstate.edges(): diff --git a/dace/transformation/dataflow/map_fission.py b/dace/transformation/dataflow/map_fission.py index 89e3d2d90f..9f40a36b4d 100644 --- a/dace/transformation/dataflow/map_fission.py +++ b/dace/transformation/dataflow/map_fission.py @@ -1,19 +1,21 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Map Fission transformation. """ from copy import deepcopy as dcpy from collections import defaultdict -from dace import registry, sdfg as sd, memlet as mm, subsets, data as dt +from dace import sdfg as sd, memlet as mm, subsets, data as dt from dace.codegen import control_flow as cf +from dace.properties import CodeBlock from dace.sdfg import nodes, graph as gr from dace.sdfg import utils as sdutil -from dace.sdfg.graph import OrderedDiGraph from dace.sdfg.propagation import propagate_memlets_state, propagate_subset +from dace.sdfg.state import ConditionalBlock, LoopRegion from dace.symbolic import pystr_to_symbolic from dace.transformation import transformation, helpers from typing import List, Optional, Tuple +@transformation.explicit_cf_compatible class MapFission(transformation.SingleStateTransformation): """ Implements the MapFission transformation. Map fission refers to subsuming a map scope into its internal subgraph, @@ -64,7 +66,7 @@ def _components(subgraph: gr.SubgraphView) -> List[Tuple[nodes.Node, nodes.Node] return ns @staticmethod - def _border_arrays(sdfg, parent, subgraph): + def _border_arrays(sdfg: sd.SDFG, parent, subgraph): """ Returns a set of array names that are local to the fission subgraph. """ nested = isinstance(parent, sd.SDFGState) @@ -123,12 +125,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Get NestedSDFG control flow components nsdfg_node.sdfg.reset_cfg_list() - cf_comp = helpers.find_sdfg_control_flow(nsdfg_node.sdfg) - if len(cf_comp) == 1: - child = list(cf_comp.values())[0][1] - conditions = [] - if isinstance(child, (cf.ForScope, cf.WhileScope, cf.IfScope)): - conditions.append(child.condition if isinstance(child, (cf.ForScope, cf.IfScope)) else child.test) + if len(nsdfg_node.sdfg.nodes()) == 1: + child = nsdfg_node.sdfg.nodes()[0] + conditions: List[CodeBlock] = [] + if isinstance(child, LoopRegion): + conditions.append(child.loop_condition) + elif isinstance(child, ConditionalBlock): + for c, _ in child.branches: + if c is not None: + conditions.append(c) for cond in conditions: if any(p in cond.get_free_symbols() for p in map_node.map.params): return False @@ -138,7 +143,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False if any(p in cond.get_free_symbols() for p in map_node.map.params): return False - helpers.nest_sdfg_control_flow(nsdfg_node.sdfg, cf_comp) + helpers.nest_sdfg_control_flow(nsdfg_node.sdfg) subgraphs = list(nsdfg_node.sdfg.nodes()) @@ -175,7 +180,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): # Find all nodes not in subgraph not_subgraph = set(n.data for n in graph.nodes() if n not in snodes and isinstance(n, nodes.AccessNode)) not_subgraph.update( - set(n.data for s in sdfg.nodes() if s != graph for n in s.nodes() + set(n.data for s in sdfg.states() if s != graph for n in s.nodes() if isinstance(n, nodes.AccessNode))) for _, component_out in components: diff --git a/dace/transformation/dataflow/map_for_loop.py b/dace/transformation/dataflow/map_for_loop.py index d7148fc651..0cd872d97d 100644 --- a/dace/transformation/dataflow/map_for_loop.py +++ b/dace/transformation/dataflow/map_for_loop.py @@ -12,7 +12,7 @@ from typing import Tuple, Optional -class MapToForLoopRegion(transformation.SingleStateTransformation): +class MapToForLoop(transformation.SingleStateTransformation): """ Implements the Map to for-loop transformation. Takes a map and enforces a sequential schedule by transforming it into a loop region. Creates a nested SDFG, if @@ -112,27 +112,6 @@ def replace_param(param): sdfg.reset_cfg_list() # Ensure the SDFG is marked as containing CFG regions - sdfg.root_sdfg.using_experimental_blocks = True - - return node, nstate - - -class MapToForLoop(MapToForLoopRegion): - """ Implements the Map to for-loop transformation. - - Takes a map and enforces a sequential schedule by transforming it into - a state-machine of a for-loop. Creates a nested SDFG, if necessary. - """ - - before_state: SDFGState - guard: SDFGState - after_state: SDFGState - - def apply(self, graph: SDFGState, sdfg: SDFG) -> Tuple[nodes.NestedSDFG, SDFGState]: - node, nstate = super().apply(graph, sdfg) - _, (self.before_state, self.guard, self.after_state) = self.loop_region.inline() - - sdfg.reset_cfg_list() - sdfg.recheck_using_experimental_blocks() + sdfg.root_sdfg.using_explicit_control_flow = True return node, nstate diff --git a/dace/transformation/dataflow/mpi.py b/dace/transformation/dataflow/mpi.py index c44c21e9b9..e838c61648 100644 --- a/dace/transformation/dataflow/mpi.py +++ b/dace/transformation/dataflow/mpi.py @@ -102,7 +102,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): rangeexpr = str(map_entry.map.range.num_elements()) stripmine_subgraph = {StripMining.map_entry: self.subgraph[MPITransformMap.map_entry]} - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id stripmine = StripMining() stripmine.setup_match(sdfg, cfg_id, self.state_id, stripmine_subgraph, self.expr_index) stripmine.dim_idx = -1 @@ -128,7 +128,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): LocalStorage.node_a: graph.node_id(outer_map), LocalStorage.node_b: self.subgraph[MPITransformMap.map_entry] } - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id in_local_storage = InLocalStorage() in_local_storage.setup_match(sdfg, cfg_id, self.state_id, in_local_storage_subgraph, self.expr_index) in_local_storage.array = e.data.data @@ -146,7 +146,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): LocalStorage.node_a: graph.node_id(in_map_exit), LocalStorage.node_b: graph.node_id(out_map_exit) } - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id outlocalstorage = OutLocalStorage() outlocalstorage.setup_match(sdfg, cfg_id, self.state_id, outlocalstorage_subgraph, self.expr_index) outlocalstorage.array = name diff --git a/dace/transformation/dataflow/otf_map_fusion.py b/dace/transformation/dataflow/otf_map_fusion.py index a793d1e679..aa339b75c5 100644 --- a/dace/transformation/dataflow/otf_map_fusion.py +++ b/dace/transformation/dataflow/otf_map_fusion.py @@ -478,8 +478,11 @@ def advanced_replace(subgraph: StateSubgraphView, s: str, s_: str) -> None: elif isinstance(node, nodes.NestedSDFG): for nsdfg in node.sdfg.all_sdfgs_recursive(): nsdfg.replace(s, s_) - for nstate in nsdfg.nodes(): - for nnode in nstate.nodes(): - if isinstance(nnode, nodes.MapEntry): - params = [s_ if p == s else p for p in nnode.map.params] - nnode.map.params = params + for cfg in nsdfg.all_control_flow_regions(): + cfg.replace(s, s_) + for nblock in cfg.nodes(): + if isinstance(nblock, SDFGState): + for nnode in nblock.nodes(): + if isinstance(nnode, nodes.MapEntry): + params = [s_ if p == s else p for p in nnode.map.params] + nnode.map.params = params diff --git a/dace/transformation/dataflow/tiling.py b/dace/transformation/dataflow/tiling.py index 8a6d75f4db..aa26786c6e 100644 --- a/dace/transformation/dataflow/tiling.py +++ b/dace/transformation/dataflow/tiling.py @@ -56,7 +56,7 @@ def apply(self, graph: SDFGState, sdfg: SDFG): from dace.transformation.dataflow.map_collapse import MapCollapse from dace.transformation.dataflow.strip_mining import StripMining stripmine_subgraph = {StripMining.map_entry: self.subgraph[MapTiling.map_entry]} - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id last_map_entry = None removed_maps = 0 diff --git a/dace/transformation/helpers.py b/dace/transformation/helpers.py index 7824030e5c..b703dd402d 100644 --- a/dace/transformation/helpers.py +++ b/dace/transformation/helpers.py @@ -4,15 +4,16 @@ import itertools from networkx import MultiDiGraph -from dace.sdfg.state import ControlFlowRegion +from dace.properties import CodeBlock +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion, ReturnBlock from dace.subsets import Range, Subset, union import dace.subsets as subsets -from typing import Dict, List, Optional, Tuple, Set, Union +from typing import Dict, Iterable, List, Optional, Tuple, Set, Union from dace import data, dtypes, symbolic from dace.codegen import control_flow as cf from dace.sdfg import nodes, utils -from dace.sdfg.graph import SubgraphView, MultiConnectorEdge +from dace.sdfg.graph import Edge, SubgraphView, MultiConnectorEdge from dace.sdfg.scope import ScopeSubgraphView, ScopeTree from dace.sdfg import SDFG, SDFGState, InterstateEdge from dace.sdfg import graph @@ -30,10 +31,13 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS """ # Nest states - states = subgraph.nodes() + blocks: List[ControlFlowBlock] = subgraph.nodes() return_state = None - if len(states) > 1: + if len(blocks) > 1 or isinstance(blocks[0], AbstractControlFlowRegion): + # Avoid cyclic imports + from dace.transformation.passes.analysis import loop_analysis + graph: ControlFlowRegion = blocks[0].parent_graph if start is not None: source_node = start else: @@ -48,6 +52,30 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS raise NotImplementedError sink_node = sink_nodes[0] + all_blocks: List[ControlFlowBlock] = [] + is_edges: List[Edge[InterstateEdge]] = [] + for b in blocks: + if isinstance(b, AbstractControlFlowRegion): + all_blocks.append(b) + for nb in b.all_control_flow_blocks(): + all_blocks.append(nb) + for e in b.all_interstate_edges(): + is_edges.append(e) + else: + all_blocks.append(b) + states: List[SDFGState] = [b for b in all_blocks if isinstance(b, SDFGState)] + for src in blocks: + for dst in blocks: + for edge in graph.edges_between(src, dst): + is_edges.append(edge) + return_blocks: Set[ReturnBlock] = set([b for b in all_blocks if isinstance(b, ReturnBlock)]) + if len(return_blocks) > 0: + did_return_inner = '_did_ret_from_nsdfg' + did_return_inner = sdfg._find_new_name(did_return_inner) + sdfg.add_scalar(did_return_inner, dtypes.int32, transient=True) + else: + did_return_inner = None + # Find read/write sets read_set, write_set = set(), set() for state in states: @@ -67,12 +95,21 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS if e.data.data and e.data.data in sdfg.arrays: write_set.add(e.data.data) # Add data from edges - for src in states: - for dst in states: - for edge in sdfg.edges_between(src, dst): - for s in edge.data.free_symbols: - if s in sdfg.arrays: - read_set.add(s) + for edge in is_edges: + for s in edge.data.free_symbols: + if s in sdfg.arrays: + read_set.add(s) + for blk in all_blocks: + if isinstance(blk, ConditionalBlock): + for c, _ in blk.branches: + if c is not None: + for s in c.get_free_symbols(): + if s in sdfg.arrays: + read_set.add(s) + elif isinstance(blk, LoopRegion): + for s in blk.loop_condition.get_free_symbols(): + if s in sdfg.arrays: + read_set.add(s) # Find NestedSDFG's unique data rw_set = read_set | write_set @@ -82,7 +119,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS continue found = False for state in sdfg.states(): - if state in states: + if state in blocks: continue for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.data == name): @@ -98,7 +135,7 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # Find defined subgraph symbols defined_symbols = set() strictly_defined_symbols = set() - for e in subgraph.edges(): + for e in is_edges: defined_symbols.update(set(e.data.assignments.keys())) for k, v in e.data.assignments.items(): try: @@ -107,22 +144,65 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS except AttributeError: # `symbolic.pystr_to_symbolic` may return bool, which doesn't have attribute `args` pass + for b in all_blocks: + if isinstance(b, LoopRegion) and b.loop_variable: + defined_symbols.add(b.loop_variable) + if b.loop_variable not in sdfg.symbols: + if b.init_statement: + init_assignment = loop_analysis.get_init_assignment(b) + if b.loop_variable not in {str(s) for s in symbolic.pystr_to_symbolic(init_assignment).args}: + strictly_defined_symbols.add(b.loop_variable) + else: + strictly_defined_symbols.add(b.loop_variable) + + return_state = new_state = graph.add_state('nested_sdfg_parent') + + # If there is a return that is being nested in, a conditional return is added right after the new nested SDFG + # which will be taken if the inner, nested return was hit. + ret_cond = None + if len(return_blocks) > 0: + ret_cond = ConditionalBlock('return_' + sdfg.label + '_from_nested', sdfg, graph) + graph.add_node(ret_cond, ensure_unique_name=True) + ret_branch = ControlFlowRegion('return_' + sdfg.label + '_from_nested_body', sdfg, ret_cond) + ret_block = ReturnBlock('return', sdfg, ret_branch) + ret_branch.add_node(ret_block) + ret_cond.add_branch(CodeBlock(did_return_inner), ret_branch) - return_state = new_state = sdfg.add_state('nested_sdfg_parent') nsdfg = SDFG("nested_sdfg", constants=sdfg.constants_prop, parent=new_state) nsdfg.add_node(source_node, is_start_state=True) - nsdfg.add_nodes_from([s for s in states if s is not source_node]) - for s in states: - s.parent = nsdfg + nsdfg.add_nodes_from([s for s in blocks if s is not source_node]) for e in subgraph.edges(): nsdfg.add_edge(e.src, e.dst, e.data) - for e in sdfg.in_edges(source_node): - sdfg.add_edge(e.src, new_state, e.data) - for e in sdfg.out_edges(sink_node): - sdfg.add_edge(new_state, e.dst, e.data) + # Annotate any transitions to return blocks in the inner, nested SDFG by first setting the added transient + # scalar to 1 / true to detect that the inner SDFG returned. + if len(return_blocks) > 0: + for blk in nsdfg.all_control_flow_blocks(): + if blk in return_blocks: + pre_state = blk.parent_graph.add_state_before(blk) + did_ret_tasklet = pre_state.add_tasklet('__did_ret_set', {}, {'out'}, 'out = 1') + did_ret_access = pre_state.add_access(did_return_inner) + pre_state.add_edge(did_ret_tasklet, 'out', did_ret_access, None, Memlet(did_return_inner + '[0]')) + write_set.add(did_return_inner) + + if ret_cond is not None: + pre_state = graph.add_state('before_nested_sdfg_parent') + for e in graph.in_edges(source_node): + graph.add_edge(e.src, pre_state, e.data) + did_ret_tasklet = pre_state.add_tasklet('__did_ret_init', {}, {'out'}, 'out = 0') + did_ret_access = pre_state.add_access(did_return_inner) + pre_state.add_edge(did_ret_tasklet, 'out', did_ret_access, None, Memlet(did_return_inner + '[0]')) + graph.add_edge(pre_state, new_state, InterstateEdge()) + graph.add_edge(new_state, ret_cond, InterstateEdge()) + for e in graph.out_edges(sink_node): + graph.add_edge(ret_cond, e.dst, e.data) + else: + for e in graph.in_edges(source_node): + graph.add_edge(e.src, new_state, e.data) + for e in graph.out_edges(sink_node): + graph.add_edge(new_state, e.dst, e.data) - sdfg.remove_nodes_from(states) + graph.remove_nodes_from(blocks) # Add NestedSDFG arrays for name in read_set | write_set: @@ -139,8 +219,11 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS ndefined_symbols = set() out_mapping = {} out_state = None - for e in nsdfg.edges(): + for e in nsdfg.all_interstate_edges(): ndefined_symbols.update(set(e.data.assignments.keys())) + for b in all_blocks: + if isinstance(b, LoopRegion) and b.loop_variable is not None and b.loop_variable != '' and b.init_statement: + ndefined_symbols.add(b.loop_variable) if ndefined_symbols: out_state = nsdfg.add_state('symbolic_output') nsdfg.add_edge(sink_node, out_state, InterstateEdge()) @@ -160,7 +243,8 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # Add NestedSDFG node fsymbols = sdfg.symbols.keys() | nsdfg.free_symbols - fsymbols.update(defined_symbols - strictly_defined_symbols) + fsymbols.update(defined_symbols) + fsymbols = fsymbols - strictly_defined_symbols mapping = {s: s for s in fsymbols} cnode = new_state.add_nested_sdfg(nsdfg, None, read_set, write_set, mapping) for s in strictly_defined_symbols: @@ -177,188 +261,29 @@ def nest_sdfg_subgraph(sdfg: SDFG, subgraph: SubgraphView, start: Optional[SDFGS # Part (2) if out_state is not None: - extra_state = sdfg.add_state('symbolic_output') - for e in sdfg.out_edges(new_state): - sdfg.add_edge(extra_state, e.dst, e.data) - sdfg.remove_edge(e) - sdfg.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping)) + extra_state = graph.add_state('symbolic_output') + for e in graph.out_edges(new_state): + graph.add_edge(extra_state, e.dst, e.data) + graph.remove_edge(e) + graph.add_edge(new_state, extra_state, InterstateEdge(assignments=out_mapping)) new_state = extra_state else: - return_state = states[0] + return_state = blocks[0] return return_state -def _copy_state(sdfg: SDFG, - state: SDFGState, - before: bool = True, - states: Optional[Set[SDFGState]] = None) -> SDFGState: - """ - Duplicates a state, placing the copy before or after (see param before) the original and redirecting a subset of its - edges (see param state). The state is expected to be a scope's source or sink state and this method facilitates the - nesting of SDFG subgraphs where the state may be part of multiple scopes. - - :param state: The SDFGState to copy. - :param before: True if the copy should be placed before the original. - :param states: A collection of SDFGStates that should be considered for edge redirection. - :return: The SDFGState copy. - """ - - state_copy = copy.deepcopy(state) - state_copy._label += '_copy' - state_copy.parent = sdfg - sdfg.add_node(state_copy) - - in_conditions = [] - for e in sdfg.in_edges(state): - if states and e.src not in states: - continue - sdfg.add_edge(e.src, state_copy, e.data) - sdfg.remove_edge(e) - if not e.data.is_unconditional(): - in_conditions.append(e.data.condition.as_string) - - out_conditions = [] - for e in sdfg.out_edges(state): - if states and e.dst not in states: - continue - sdfg.add_edge(state_copy, e.dst, e.data) - sdfg.remove_edge(e) - if not e.data.is_unconditional(): - out_conditions.append(e.data.condition.as_string) - - if before: - condition = None - if in_conditions: - condition = 'or'.join([f"({c})" for c in in_conditions]) - sdfg.add_edge(state_copy, state, InterstateEdge(condition=condition)) - else: - condition = None - # NOTE: The following should be unecessary for preserving program semantics. Therefore we comment it out to - # avoid the overhead of evaluating the condition. - # if out_conditions: - # condition = 'or'.join([f"({c})" for c in out_conditions]) - sdfg.add_edge(state, state_copy, InterstateEdge(condition=condition)) - - return state_copy - - -def find_sdfg_control_flow(sdfg: SDFG) -> Dict[SDFGState, Set[SDFGState]]: - """ - Partitions the SDFG to subgraphs that can be nested independently of each other. The method does not nest the - subgraphs but alters the SDFG; (1) interstate edges are split, (2) scope source/sink states that belong to multiple - scopes are duplicated (see _copy_state). - - :param sdfg: The SDFG to be partitioned. - :return: The found subgraphs in the form of a dictionary where the keys are the start state of the subgraphs and the - values are the sets of SDFGStates contained withing each subgraph. - """ - - split_interstate_edges(sdfg) - - # Create a unique sink state to avoid issues with finding control flow. - sink_states = sdfg.sink_nodes() - if len(sink_states) > 1: - new_sink = sdfg.add_state('common_sink') - for s in sink_states: - sdfg.add_edge(s, new_sink, InterstateEdge()) - - ipostdom = utils.postdominators(sdfg) - cft = cf.structured_control_flow_tree(sdfg, None) - - # Iterate over the SDFG's control flow scopes and create for each an SDFG subraph. These subgraphs must be disjoint, - # so we duplicate SDFGStates that appear in more than one scopes (guards and exits of loops and conditionals). - components = {} - visited = {} # Dict[SDFGState, bool]: True if SDFGState in Scope (non-SingleState) - for i, child in enumerate(cft.children): - if isinstance(child, cf.BasicCFBlock): - if child.state in visited: - continue - components[child.state] = (set([child.state]), child) - visited[child.state] = False - elif isinstance(child, (cf.ForScope, cf.WhileScope)): - guard = child.guard - fexit = None - condition = child.condition if isinstance(child, cf.ForScope) else child.test - for e in sdfg.out_edges(guard): - if e.data.condition != condition: - fexit = e.dst - break - if fexit is None: - raise ValueError("Cannot find for-scope's exit states.") - - states = set(utils.dfs_conditional(sdfg, [guard], lambda p, _: p is not fexit)) - - if guard in visited: - if visited[guard]: - guard_copy = _copy_state(sdfg, guard, False, states) - guard.remove_nodes_from(guard.nodes()) - states.remove(guard) - states.add(guard_copy) - guard = guard_copy - else: - del components[guard] - del visited[guard] - - if not (i == len(cft.children) - 2 and isinstance(cft.children[i + 1], cf.BasicCFBlock) - and cft.children[i + 1].state is fexit): - fexit_copy = _copy_state(sdfg, fexit, True, states) - fexit.remove_nodes_from(fexit.nodes()) - states.remove(fexit) - states.add(fexit_copy) - - components[guard] = (states, child) - visited.update({s: True for s in states}) - elif isinstance(child, (cf.IfScope, cf.IfElseChain)): - guard = child.branch_block - ifexit = ipostdom[guard] - - states = set(utils.dfs_conditional(sdfg, [guard], lambda p, _: p is not ifexit)) - - if guard in visited: - if visited[guard]: - guard_copy = _copy_state(sdfg, guard, False, states) - guard.remove_nodes_from(guard.nodes()) - states.remove(guard) - states.add(guard_copy) - guard = guard_copy - else: - del components[guard] - del visited[guard] - - if not (i == len(cft.children) - 2 and isinstance(cft.children[i + 1], cf.BasicCFBlock) - and cft.children[i + 1].state is ifexit): - ifexit_copy = _copy_state(sdfg, ifexit, True, states) - ifexit.remove_nodes_from(ifexit.nodes()) - states.remove(ifexit) - states.add(ifexit_copy) - - components[guard] = (states, child) - visited.update({s: True for s in states}) - else: - raise ValueError(f"Unsupported control flow class {type(child)}") - - return components - - -def nest_sdfg_control_flow(sdfg: SDFG, components=None): +def nest_sdfg_control_flow(sdfg: SDFG): """ Partitions the SDFG to subgraphs and nests them. :param sdfg: The SDFG to be partitioned. - :param components: An existing partition of the SDFG. """ - - components = components or find_sdfg_control_flow(sdfg) - - num_components = len(components) - - if num_components < 2: - return - - for i, (start, (component, _)) in enumerate(components.items()): - nest_sdfg_subgraph(sdfg, graph.SubgraphView(sdfg, component), start) + for nd in sdfg.nodes(): + if isinstance(nd, AbstractControlFlowRegion): + nest_sdfg_subgraph(sdfg, SubgraphView(sdfg, [nd])) + sdfg.reset_cfg_list() def nest_state_subgraph(sdfg: SDFG, @@ -1069,7 +994,7 @@ def constant_symbols(sdfg: SDFG) -> Set[str]: :param sdfg: The input SDFG. :return: A set of symbol names that remain constant throughout the SDFG. """ - interstate_symbols = {k for e in sdfg.edges() for k in e.data.assignments.keys()} + interstate_symbols = {k for e in sdfg.all_interstate_edges() for k in e.data.assignments.keys()} return set(sdfg.symbols) - interstate_symbols @@ -1214,7 +1139,7 @@ def traverse(state: SDFGState, treenode: ScopeTree): snodes = state.scope_children()[treenode.entry] for node in snodes: if isinstance(node, nodes.NestedSDFG): - for nstate in node.sdfg.nodes(): + for nstate in node.sdfg.states(): ntree = nstate.scope_tree()[None] ntree.state = nstate treenode.children.append(ntree) @@ -1429,8 +1354,8 @@ def can_run_state_on_fpga(state: SDFGState): return False # Streams have strict conditions due to code generator limitations - if (isinstance(node, nodes.AccessNode) and isinstance(graph.parent.arrays[node.data], data.Stream)): - nodedesc = graph.parent.arrays[node.data] + if (isinstance(node, nodes.AccessNode) and isinstance(graph.sdfg.arrays[node.data], data.Stream)): + nodedesc = graph.sdfg.arrays[node.data] sdict = graph.scope_dict() if nodedesc.storage in [ dtypes.StorageType.CPU_Heap, dtypes.StorageType.CPU_Pinned, dtypes.StorageType.CPU_ThreadLocal @@ -1442,7 +1367,7 @@ def can_run_state_on_fpga(state: SDFGState): return False # Arrays of streams cannot have symbolic size on FPGA - if symbolic.issymbolic(nodedesc.total_size, graph.parent.constants): + if symbolic.issymbolic(nodedesc.total_size, graph.sdfg.constants): return False # Streams cannot be unbounded on FPGA @@ -1562,3 +1487,98 @@ def make_map_internal_write_external(sdfg: SDFG, state: SDFGState, map_exit: nod memlet=Memlet(data=sink.data, subset=copy.deepcopy(subset), other_subset=copy.deepcopy(subset))) + + +def all_isedges_between(src: ControlFlowBlock, dst: ControlFlowBlock) -> Iterable[Edge[InterstateEdge]]: + """ + Helper function that generates an iterable of all edges potentially encountered between two control flow blocks. + """ + if src.sdfg is not dst.sdfg: + raise RuntimeError('Blocks reside in different SDFGs') + + if src.parent_graph is dst.parent_graph: + # Simple case where both blocks reside in the same graph: + edges = set() + for p in src.parent_graph.all_simple_paths(src, dst, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion): + edges.update(e.dst.all_interstate_edges()) + return edges + else: + # In the case where the two blocks are not in the same graph, we follow this procedure: + # 1. Collect the list of control flow regions on the direct path between the source and destination: + # a) Determine the 'lowest common parent' region + # b) Determine the list of parents of the source before the common parent is reached + # c) Determine the list of parents of the destination before the common parent is reached. + # 2. In each of the parents of the source, add all edges from the source or the next parent until the + # end(s) of each region to the result + # 3. In each of the destination's parents, add all edges from the start block on until the destination + # or next parent to the result. + # 4. In the lowest common parent region, find all edge paths between the next parent regions for both + # the source and destination. + # Note that for each edge, if the destination is a control flow region, any edges inside of it may also + # be on the path and consequently also need to be added. + edges = set() + + # Step 1.a): Find the lowest common parent region. + common_regions = set() + pivot_graph = src.parent_graph + all_parent_regions_src = [pivot_graph] + while not isinstance(pivot_graph, SDFG): + pivot_graph = pivot_graph.parent_graph + all_parent_regions_src.append(pivot_graph) + pivot_graph = dst.parent_graph + all_parent_regions_dst = [pivot_graph] + while not isinstance(pivot_graph, SDFG): + pivot_graph = pivot_graph.parent_graph + all_parent_regions_dst.append(pivot_graph) + if pivot_graph in all_parent_regions_src: + common_regions.add(pivot_graph) + + # Step 1.b) and 1.c): Determine the list of parents involved in the path for the source and destination. + involved_src: List[ControlFlowRegion] = [] + involved_dst: List[ControlFlowRegion] = [] + common_parent: ControlFlowRegion = None + for r in all_parent_regions_src: + if r not in common_regions: + involved_src.append(r) + else: + common_parent = r + break + for r in all_parent_regions_dst: + if r not in common_regions: + involved_dst.append(r) + else: + if r is not common_parent: + raise RuntimeError('No common parent found') + break + + # Step 2 + src_pivot = src + for r in involved_src: + for sink in r.sink_nodes(): + for p in r.all_simple_paths(src_pivot, sink, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion): + edges.update(e.dst.all_interstate_edges()) + src_pivot = r + # Step 3 + dst_pivot = dst + for r in involved_dst: + for p in r.all_simple_paths(r.start_block, dst_pivot, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion): + edges.update(e.dst.all_interstate_edges()) + dst_pivot = r + + # Step 4 + for p in common_parent.all_simple_paths(src_pivot, dst_pivot, as_edges=True): + for e in p: + edges.add(e) + if isinstance(e.dst, ControlFlowRegion) and not e.dst is dst_pivot: + edges.update(e.dst.all_interstate_edges()) + + return edges diff --git a/dace/transformation/interstate/__init__.py b/dace/transformation/interstate/__init__.py index b8bcc716e6..a53152e09c 100644 --- a/dace/transformation/interstate/__init__.py +++ b/dace/transformation/interstate/__init__.py @@ -1,6 +1,7 @@ # Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. """ This module initializes the inter-state transformations package.""" +from .block_fusion import BlockFusion from .state_fusion import StateFusion from .state_fusion_with_happens_before import StateFusionExtended from .state_elimination import (EndStateElimination, StartStateElimination, StateAssignElimination, diff --git a/dace/transformation/interstate/block_fusion.py b/dace/transformation/interstate/block_fusion.py new file mode 100644 index 0000000000..cf180ad771 --- /dev/null +++ b/dace/transformation/interstate/block_fusion.py @@ -0,0 +1,102 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from dace.sdfg import utils as sdutil +from dace.sdfg.state import AbstractControlFlowRegion, ControlFlowBlock, ControlFlowRegion, SDFGState +from dace.transformation import transformation + + +@transformation.explicit_cf_compatible +class BlockFusion(transformation.MultiStateTransformation): + """ Implements the block-fusion transformation. + + Block-fusion takes two control flow blocks that are connected through a single edge, where either one or both + blocks are 'no-op' control flow blocks, and fuses them into one. + """ + + first_block = transformation.PatternNode(ControlFlowBlock) + second_block = transformation.PatternNode(ControlFlowBlock) + + @staticmethod + def annotates_memlets(): + return False + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.first_block, cls.second_block)] + + def _is_noop(self, block: ControlFlowBlock) -> bool: + if isinstance(block, SDFGState): + return block.is_empty() + elif type(block) == ControlFlowBlock: + return True + return False + + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # First block must have only one unconditional output edge (with dst the second block). + out_edges = graph.out_edges(self.first_block) + if len(out_edges) != 1 or out_edges[0].dst is not self.second_block or not out_edges[0].data.is_unconditional(): + return False + # Inversely, the second block may only have one input edge, with src being the first block. + in_edges_second = graph.in_edges(self.second_block) + if len(in_edges_second) != 1 or in_edges_second[0].src is not self.first_block: + return False + + # Ensure that either that both blocks are fusable blocks, meaning that at least one of the two blocks must be + # a 'no-op' block. That can be an empty SDFGState or a general control flow block without further semantics + # (no loop, conditional, break, continue, control flow region, etc.). + if not self._is_noop(self.first_block) and not self._is_noop(self.second_block): + return False + + # The interstate edge may have assignments if there are input edges to the first block that can absorb them. + in_edges = graph.in_edges(self.first_block) + if out_edges[0].data.assignments: + if not in_edges: + return False + # If the first block is a control flow region, no absorption is possible. + if isinstance(self.first_block, AbstractControlFlowRegion): + return False + # Fail if symbol is set before the block to fuse + new_assignments = set(out_edges[0].data.assignments.keys()) + if any((new_assignments & set(e.data.assignments.keys())) for e in in_edges): + return False + # Fail if symbol is used in the dataflow of that block + if len(new_assignments & self.first_block.free_symbols) > 0: + return False + # Fail if symbols assigned on the first edge are free symbols on the second edge + symbols_used = set(out_edges[0].data.free_symbols) + for e in in_edges: + if e.data.assignments.keys() & symbols_used: + return False + # Also fail in the inverse; symbols assigned on the second edge are free symbols on the first edge + if new_assignments & set(e.data.free_symbols): + return False + + # There can be no block that has output edges pointing to both the first and the second block. Such a case will + # produce a multi-graph. + for src, _, _ in in_edges: + for _, dst, _ in graph.out_edges(src): + if dst == self.second_block: + return False + return True + + def apply(self, graph: ControlFlowRegion, sdfg): + first_is_start = graph.start_block is self.first_block + connecting_edge = graph.edges_between(self.first_block, self.second_block)[0] + assignments_to_absorb = connecting_edge.data.assignments + graph.remove_edge(connecting_edge) + for ie in graph.in_edges(self.first_block): + if assignments_to_absorb: + ie.data.assignments.update(assignments_to_absorb) + + if self._is_noop(self.first_block): + # We remove the first block and let the second one remain. + for ie in graph.in_edges(self.first_block): + graph.add_edge(ie.src, self.second_block, ie.data) + if first_is_start: + graph.start_block = self.second_block.block_id + graph.remove_node(self.first_block) + else: + # We remove the second block and let the first one remain. + for oe in graph.out_edges(self.second_block): + graph.add_edge(self.first_block, oe.dst, oe.data) + graph.remove_node(self.second_block) diff --git a/dace/transformation/interstate/fpga_transform_sdfg.py b/dace/transformation/interstate/fpga_transform_sdfg.py index ac4672d892..09a6ee2aa8 100644 --- a/dace/transformation/interstate/fpga_transform_sdfg.py +++ b/dace/transformation/interstate/fpga_transform_sdfg.py @@ -1,14 +1,15 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains inter-state transformations of an SDFG to run on an FPGA. """ import networkx as nx from dace import properties +from dace.sdfg.sdfg import SDFG from dace.transformation import transformation @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class FPGATransformSDFG(transformation.MultiStateTransformation): """ Implements the FPGATransformSDFG transformation, which takes an entire SDFG and transforms it into an FPGA-capable SDFG. """ @@ -28,20 +29,20 @@ def expressions(cls): # Match anything return [nx.DiGraph()] - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + def can_be_applied(self, graph, expr_index, sdfg: SDFG, permissive=False): # Avoid import loops from dace.transformation.interstate import FPGATransformState # Condition match depends on matching FPGATransformState for each state - for state_id, state in enumerate(sdfg.nodes()): + for state in sdfg.states(): fps = FPGATransformState() - fps.setup_match(sdfg, graph.cfg_id, -1, {FPGATransformState.state: state_id}, 0) - if not fps.can_be_applied(sdfg, expr_index, sdfg): + fps.setup_match(sdfg, state.parent_graph.cfg_id, -1, {FPGATransformState.state: state.block_id}, 0) + if not fps.can_be_applied(state.parent_graph, expr_index, sdfg): return False return True - def apply(self, _, sdfg): + def apply(self, _, sdfg: SDFG): # Avoid import loops from dace.transformation.interstate import NestSDFG from dace.transformation.interstate import FPGATransformState diff --git a/dace/transformation/interstate/fpga_transform_state.py b/dace/transformation/interstate/fpga_transform_state.py index 60a2a33001..dc888d8c33 100644 --- a/dace/transformation/interstate/fpga_transform_state.py +++ b/dace/transformation/interstate/fpga_transform_state.py @@ -1,15 +1,17 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains inter-state transformations of an SDFG to run on an FPGA. """ import copy import dace -from dace import data, memlet, dtypes, registry, sdfg as sd, subsets +from dace import memlet, dtypes, sdfg as sd, subsets from dace.sdfg import nodes from dace.sdfg import utils as sdutil +from dace.sdfg.sdfg import SDFG +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.transformation import transformation, helpers as xfh -def fpga_update(sdfg, state, depth): +def fpga_update(sdfg: SDFG, state: SDFGState, depth: int): scope_dict = state.scope_dict() for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.desc(sdfg).storage == dtypes.StorageType.Default): @@ -25,11 +27,11 @@ def fpga_update(sdfg, state, depth): if (hasattr(node, "schedule") and node.schedule == dace.dtypes.ScheduleType.Default): node.schedule = dace.dtypes.ScheduleType.FPGA_Device if isinstance(node, nodes.NestedSDFG): - for s in node.sdfg.nodes(): + for s in node.sdfg.states(): fpga_update(node.sdfg, s, depth + 1) -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class FPGATransformState(transformation.MultiStateTransformation): """ Implements the FPGATransformState transformation. """ @@ -76,7 +78,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): state = self.state # Find source/sink (data) nodes that are relevant outside this FPGA @@ -96,16 +98,12 @@ def apply(self, _, sdfg): # Input nodes may also be nodes with WCR memlets # We have to recur across nested SDFGs to find them wcr_input_nodes = set() - stack = [] - parent_sdfg = {state: sdfg} # Map states to their parent SDFG - for node, graph in state.all_nodes_recursive(): - if isinstance(graph, dace.SDFG): - parent_sdfg[node] = graph + for node, node_parent_graph in state.all_nodes_recursive(): if isinstance(node, dace.sdfg.nodes.AccessNode): - for e in graph.in_edges(node): + for e in node_parent_graph.in_edges(node): if e.data.wcr is not None: - trace = dace.sdfg.trace_nested_access(node, graph, parent_sdfg[graph]) + trace = dace.sdfg.trace_nested_access(node, node_parent_graph, node_parent_graph.sdfg) for node_trace, memlet_trace, state_trace, sdfg_trace in trace: # Find the name of the accessed node in our scope if state_trace == state and sdfg_trace == sdfg: @@ -158,9 +156,9 @@ def apply(self, _, sdfg): sdutil.change_edge_src(state, node, fpga_node) state.remove_node(node) - sdfg.add_node(pre_state) - sdutil.change_edge_dest(sdfg, state, pre_state) - sdfg.add_edge(pre_state, state, sd.InterstateEdge()) + graph.add_node(pre_state) + sdutil.change_edge_dest(graph, state, pre_state) + graph.add_edge(pre_state, state, sd.InterstateEdge()) if output_nodes: @@ -200,9 +198,9 @@ def apply(self, _, sdfg): sdutil.change_edge_dest(state, node, fpga_node) state.remove_node(node) - sdfg.add_node(post_state) - sdutil.change_edge_src(sdfg, state, post_state) - sdfg.add_edge(state, post_state, sd.InterstateEdge()) + graph.add_node(post_state) + sdutil.change_edge_src(graph, state, post_state) + graph.add_edge(state, post_state, sd.InterstateEdge()) # propagate memlet info from a nested sdfg for src, src_conn, dst, dst_conn, mem in state.edges(): diff --git a/dace/transformation/interstate/gpu_transform_sdfg.py b/dace/transformation/interstate/gpu_transform_sdfg.py index 53296529a5..49a2e16227 100644 --- a/dace/transformation/interstate/gpu_transform_sdfg.py +++ b/dace/transformation/interstate/gpu_transform_sdfg.py @@ -1,15 +1,17 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains inter-state transformations of an SDFG to run on the GPU. """ -from dace import data, memlet, dtypes, registry, sdfg as sd, symbolic, subsets as sbs, propagate_memlets_sdfg +from dace import data, memlet, dtypes, sdfg as sd, subsets as sbs, propagate_memlets_sdfg from dace.sdfg import nodes, scope from dace.sdfg import utils as sdutil +from dace.sdfg.replace import replace_in_codeblock +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, LoopRegion, SDFGState from dace.transformation import transformation, helpers as xfh from dace.properties import ListProperty, Property, make_properties from collections import defaultdict from copy import deepcopy as dc from sympy import floor -from typing import Dict +from typing import Dict, List, Set, Tuple gpu_storage = [dtypes.StorageType.GPU_Global, dtypes.StorageType.GPU_Shared, dtypes.StorageType.CPU_Pinned] @@ -83,7 +85,7 @@ def _recursive_in_check(node, state, gpu_scalars): @make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class GPUTransformSDFG(transformation.MultiStateTransformation): """ Implements the GPUTransformSDFG transformation. @@ -150,7 +152,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if isinstance(node, (nodes.ConsumeEntry, nodes.ConsumeExit)): return False - for state in sdfg.nodes(): + for state in sdfg.states(): schildren = state.scope_children() for node in schildren[None]: # If two top-level tasklets are connected with a code->code @@ -180,8 +182,8 @@ def apply(self, _, sdfg: sd.SDFG): # Step 0: SDFG metadata # Find all input and output data descriptors - input_nodes = [] - output_nodes = [] + input_nodes: List[Tuple[str, data.Data]] = [] + output_nodes: List[Tuple[str, data.Data]] = [] global_code_nodes: Dict[sd.SDFGState, nodes.Tasklet] = defaultdict(list) if self.host_maps is None: self.host_maps = [] @@ -192,13 +194,13 @@ def apply(self, _, sdfg: sd.SDFG): propagate_memlets_sdfg(sdfg) # Input and ouputs of all host_maps need to be marked as host_data - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.EntryNode) and node.guid in self.host_maps: accesses = self._get_marked_inputs_and_outputs(state, node) self.host_data.extend(accesses) - for state in sdfg.nodes(): + for state in sdfg.states(): sdict = state.scope_dict() for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.desc(sdfg).transient == False): @@ -222,8 +224,8 @@ def apply(self, _, sdfg: sd.SDFG): if (e.data.data not in input_nodes and sdfg.arrays[e.data.data].transient == False): input_nodes.append((e.data.data, sdfg.arrays[e.data.data])) - start_state = sdfg.start_state - end_states = sdfg.sink_nodes() + start_block = sdfg.start_block + end_blocks = sdfg.sink_nodes() ####################################################### # Step 1: Create cloned GPU arrays and replace originals @@ -262,7 +264,7 @@ def apply(self, _, sdfg: sd.SDFG): found_full_write = False full_subset = sbs.Range.from_array(onode) try: - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.data == onodename): for e in state.in_edges(node): @@ -283,20 +285,24 @@ def apply(self, _, sdfg: sd.SDFG): if not found_full_write: input_nodes.append((onodename, onode)) - for edge in sdfg.edges(): - memlets = edge.data.get_read_memlets(sdfg.arrays) - for mem in memlets: - if sdfg.arrays[mem.data].storage == dtypes.StorageType.GPU_Global: - data_already_on_gpu[mem.data] = None + check_memlets: List[memlet.Memlet] = [] + for edge in sdfg.all_interstate_edges(): + check_memlets.extend(edge.data.get_read_memlets(sdfg.arrays)) + for blk in sdfg.all_control_flow_blocks(): + if isinstance(blk, AbstractControlFlowRegion): + check_memlets.extend(blk.get_meta_read_memlets()) + for mem in check_memlets: + if sdfg.arrays[mem.data].storage == dtypes.StorageType.GPU_Global: + data_already_on_gpu[mem.data] = None # Replace nodes - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and node.data in cloned_arrays): node.data = cloned_arrays[node.data] # Replace memlets - for state in sdfg.nodes(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.data in cloned_arrays: edge.data.data = cloned_arrays[edge.data.data] @@ -306,7 +312,7 @@ def apply(self, _, sdfg: sd.SDFG): excluded_copyin = self.exclude_copyin.split(',') copyin_state = sdfg.add_state(sdfg.label + '_copyin') - sdfg.add_edge(copyin_state, start_state, sd.InterstateEdge()) + sdfg.add_edge(copyin_state, start_block, sd.InterstateEdge()) for nname, desc in dtypes.deduplicate(input_nodes): if nname in excluded_copyin or nname not in cloned_arrays: @@ -322,7 +328,7 @@ def apply(self, _, sdfg: sd.SDFG): excluded_copyout = self.exclude_copyout.split(',') copyout_state = sdfg.add_state(sdfg.label + '_copyout') - for state in end_states: + for state in end_blocks: sdfg.add_edge(state, copyout_state, sd.InterstateEdge()) for nname, desc in dtypes.deduplicate(output_nodes): @@ -338,8 +344,8 @@ def apply(self, _, sdfg: sd.SDFG): ####################################################### # Step 4: Change all top-level maps and library nodes to GPU schedule - gpu_nodes = set() - for state in sdfg.nodes(): + gpu_nodes: Set[Tuple[SDFGState, nodes.Node]] = set() + for state in sdfg.states(): sdict = state.scope_dict() for node in state.nodes(): if sdict[node] is None: @@ -381,7 +387,7 @@ def apply(self, _, sdfg: sd.SDFG): # inside a GPU kernel. gpu_scalars = {} - nsdfgs = [] + nsdfgs: List[Tuple[nodes.NestedSDFG, SDFGState]] = [] changed = True # Iterates over Tasklets that not inside a GPU kernel. Such Tasklets must be moved inside a GPU kernel only # if they write to GPU memory. The check takes into account the fact that GPU kernels can read host-based @@ -440,7 +446,7 @@ def apply(self, _, sdfg: sd.SDFG): const_syms = xfh.constant_symbols(sdfg) - for state in sdfg.nodes(): + for state in sdfg.states(): sdict = state.scope_dict() for node in state.nodes(): if isinstance(node, nodes.AccessNode) and node.desc(sdfg).transient: @@ -530,63 +536,77 @@ def apply(self, _, sdfg: sd.SDFG): cloned_data = set(cloned_arrays.keys()).union(gpu_scalars.keys()).union(data_already_on_gpu.keys()) - for state in list(sdfg.nodes()): + def _create_copy_out(arrays_used: Set[str]) -> Dict[str, str]: + # Add copy-out nodes + name_mapping = {} + for nname in arrays_used: + # Handle GPU scalars + if nname in gpu_scalars: + hostname = gpu_scalars[nname] + if not hostname: + desc = sdfg.arrays[nname].clone() + desc.storage = dtypes.StorageType.CPU_Heap + desc.transient = True + hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) + gpu_scalars[nname] = hostname + else: + desc = sdfg.arrays[hostname] + devicename = nname + elif nname in data_already_on_gpu: + hostname = data_already_on_gpu[nname] + if not hostname: + desc = sdfg.arrays[nname].clone() + desc.storage = dtypes.StorageType.CPU_Heap + desc.transient = True + hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) + data_already_on_gpu[nname] = hostname + else: + desc = sdfg.arrays[hostname] + devicename = nname + else: + desc = sdfg.arrays[nname] + hostname = nname + devicename = cloned_arrays[nname] + + src_array = nodes.AccessNode(devicename, debuginfo=desc.debuginfo) + dst_array = nodes.AccessNode(hostname, debuginfo=desc.debuginfo) + co_state.add_node(src_array) + co_state.add_node(dst_array) + co_state.add_nedge(src_array, dst_array, + memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg))) + name_mapping[devicename] = hostname + return name_mapping + + for block in list(sdfg.all_control_flow_blocks()): arrays_used = set() - for e in sdfg.out_edges(state): + for e in block.parent_graph.out_edges(block): # Used arrays = intersection between symbols and cloned data arrays_used.update(set(e.data.free_symbols) & cloned_data) # Create a state and copy out used arrays if len(arrays_used) > 0: - - co_state = sdfg.add_state(state.label + '_icopyout') + co_state = block.parent_graph.add_state(block.label + '_icopyout') # Reconnect outgoing edges to after interim copyout state - for e in sdfg.out_edges(state): - sdutil.change_edge_src(sdfg, state, co_state) + for e in block.parent_graph.out_edges(block): + sdutil.change_edge_src(block.parent_graph, block, co_state) # Add unconditional edge to interim state - sdfg.add_edge(state, co_state, sd.InterstateEdge()) - - # Add copy-out nodes - for nname in arrays_used: - - # Handle GPU scalars - if nname in gpu_scalars: - hostname = gpu_scalars[nname] - if not hostname: - desc = sdfg.arrays[nname].clone() - desc.storage = dtypes.StorageType.CPU_Heap - desc.transient = True - hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) - gpu_scalars[nname] = hostname - else: - desc = sdfg.arrays[hostname] - devicename = nname - elif nname in data_already_on_gpu: - hostname = data_already_on_gpu[nname] - if not hostname: - desc = sdfg.arrays[nname].clone() - desc.storage = dtypes.StorageType.CPU_Heap - desc.transient = True - hostname = sdfg.add_datadesc('host_' + nname, desc, find_new_name=True) - data_already_on_gpu[nname] = hostname - else: - desc = sdfg.arrays[hostname] - devicename = nname - else: - desc = sdfg.arrays[nname] - hostname = nname - devicename = cloned_arrays[nname] - - src_array = nodes.AccessNode(devicename, debuginfo=desc.debuginfo) - dst_array = nodes.AccessNode(hostname, debuginfo=desc.debuginfo) - co_state.add_node(src_array) - co_state.add_node(dst_array) - co_state.add_nedge(src_array, dst_array, - memlet.Memlet.from_array(dst_array.data, dst_array.desc(sdfg))) - for e in sdfg.out_edges(co_state): + block.parent_graph.add_edge(block, co_state, sd.InterstateEdge()) + mapping = _create_copy_out(arrays_used) + for devicename, hostname in mapping.items(): + for e in block.parent_graph.out_edges(co_state): e.data.replace(devicename, hostname, False) + for block in list(sdfg.all_control_flow_blocks()): + arrays_used = set(block.used_symbols(all_symbols=True, with_contents=False)) & cloned_data + + # Create a state and copy out used arrays + if len(arrays_used) > 0: + co_state = block.parent_graph.add_state_before(block, block.label + '_icopyout') + mapping = _create_copy_out(arrays_used) + for devicename, hostname in mapping.items(): + block.replace_meta_accesses({devicename: hostname}) + # Step 9: Simplify if not self.simplify: return diff --git a/dace/transformation/interstate/loop_detection.py b/dace/transformation/interstate/loop_detection.py index 8081447132..fbd627eeeb 100644 --- a/dace/transformation/interstate/loop_detection.py +++ b/dace/transformation/interstate/loop_detection.py @@ -12,7 +12,7 @@ # NOTE: This class extends PatternTransformation directly in order to not show up in the matches -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class DetectLoop(transformation.PatternTransformation): """ Detects a for-loop construct from an SDFG. """ diff --git a/dace/transformation/interstate/loop_lifting.py b/dace/transformation/interstate/loop_lifting.py index 072c2519ed..746910964c 100644 --- a/dace/transformation/interstate/loop_lifting.py +++ b/dace/transformation/interstate/loop_lifting.py @@ -8,7 +8,7 @@ @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class LoopLifting(DetectLoop, transformation.MultiStateTransformation): def can_be_applied(self, graph: transformation.ControlFlowRegion, expr_index: int, sdfg: transformation.SDFG, @@ -95,5 +95,5 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): for n in full_body: graph.remove_node(n) - sdfg.root_sdfg.using_experimental_blocks = True + sdfg.root_sdfg.using_explicit_control_flow = True sdfg.reset_cfg_list() diff --git a/dace/transformation/interstate/loop_peeling.py b/dace/transformation/interstate/loop_peeling.py index c2e50cd37a..94174ab309 100644 --- a/dace/transformation/interstate/loop_peeling.py +++ b/dace/transformation/interstate/loop_peeling.py @@ -1,47 +1,35 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -""" Loop unroll transformation """ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +""" Loop peeling transformation """ import sympy as sp -from typing import Optional +from typing import List, Optional from dace import sdfg as sd +from dace import symbolic from dace.sdfg.state import ControlFlowRegion from dace.properties import Property, make_properties, CodeBlock -from dace.sdfg import graph as gr -from dace.sdfg import utils as sdutil from dace.symbolic import pystr_to_symbolic -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from dace.transformation.interstate.loop_unroll import LoopUnroll -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.passes.analysis import loop_analysis +from dace.transformation.transformation import explicit_cf_compatible @make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class LoopPeeling(LoopUnroll): """ - Splits the first `count` iterations of a state machine for-loop into - multiple, separate states. + Splits the first `count` iterations of loop into multiple, separate control flow regions (one per iteration). """ begin = Property( dtype=bool, default=True, - desc='If True, peels loop from beginning (first `count` ' - 'iterations), otherwise peels last `count` iterations.', + desc='If True, peels loop from beginning (first `count` iterations), otherwise peels last `count` iterations.', ) def can_be_applied(self, graph, expr_index, sdfg, permissive=False): if not super().can_be_applied(graph, expr_index, sdfg, permissive): return False - - guard = self.loop_guard - begin = self.loop_begin - - # If loop cannot be detected, fail - found = find_for_loop(sdfg, guard, begin) - if found is None: - return False - return True def _modify_cond(self, condition, var, step): @@ -77,90 +65,55 @@ def _modify_cond(self, condition, var, step): return res def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): - #################################################################### # Obtain loop information - begin: sd.SDFGState = self.loop_begin - after_state: sd.SDFGState = self.exit_state - - # Obtain iteration variable, range, and stride - condition_edge = self.loop_condition_edge() - not_condition_edge = self.loop_exit_edge() - itervar, rng, loop_struct = self.loop_information() - - # Get loop states - loop_states = self.loop_body() - first_id = loop_states.index(begin) - last_state = loop_struct[1] - last_id = loop_states.index(last_state) - loop_subgraph = gr.SubgraphView(graph, loop_states) - - #################################################################### - # Transform + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + stride = loop_analysis.get_loop_stride(self.loop) + is_symbolic = any([symbolic.issymbolic(r) for r in (start, end)]) if self.begin: - # If begin, change initialization assignment and prepend states before - # guard - init_edges = [] - before_states = loop_struct[0] - for before_state in before_states: - init_edge = self.loop_init_edge() - init_edge.data.assignments[itervar] = str(rng[0] + self.count * rng[2]) - init_edges.append(init_edge) - append_states = before_states - - # Add `count` states, each with instantiated iteration variable + # Create states for loop subgraph + peeled_iterations: List[ControlFlowRegion] = [] for i in range(self.count): - # Instantiate loop states with iterate value - state_name: str = 'start_' + itervar + str(i * rng[2]) - state_name = state_name.replace('-', 'm').replace('+', 'p').replace('*', 'M').replace('/', 'D') - new_states = self.instantiate_loop( - sdfg, - loop_states, - loop_subgraph, - itervar, - rng[0] + i * rng[2], - state_name, - ) - - # Connect states to before the loop with unconditional edges - for append_state in append_states: - graph.add_edge(append_state, new_states[first_id], sd.InterstateEdge()) - append_states = [new_states[last_id]] - - # Reconnect edge to guard state from last peeled iteration - for append_state in append_states: - if append_state not in before_states: - for init_edge in init_edges: - graph.remove_edge(init_edge) - graph.add_edge(append_state, init_edge.dst, init_edges[0].data) + # Instantiate loop contents as a new control flow region with iterate value. + current_index = start + (i * stride) + iteration_region = self.instantiate_loop_iteration(graph, self.loop, current_index, + str(i) if is_symbolic else None) + + # Connect iterations with unconditional edges + if len(peeled_iterations) > 0: + graph.add_edge(peeled_iterations[-1], iteration_region, sd.InterstateEdge()) + peeled_iterations.append(iteration_region) + + # Connect the peeled iterations to the remainder of the loop and adjust the remaining iteration bounds. + if peeled_iterations: + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, peeled_iterations[0], ie.data) + graph.remove_edge(ie) + graph.add_edge(peeled_iterations[-1], self.loop, sd.InterstateEdge()) + + new_start = symbolic.evaluate(start + (self.count * stride), sdfg.constants) + self.loop.init_statement = CodeBlock(f'{self.loop.loop_variable} = {new_start}') else: - # If begin, change initialization assignment and prepend states before - # guard - itervar_sym = pystr_to_symbolic(itervar) - condition_edge.data.condition = CodeBlock(self._modify_cond(condition_edge.data.condition, itervar, rng[2])) - not_condition_edge.data.condition = CodeBlock( - self._modify_cond(not_condition_edge.data.condition, itervar, rng[2])) - prepend_state = after_state - - # Add `count` states, each with instantiated iteration variable + # Create states for loop subgraph + peeled_iterations: List[ControlFlowRegion] = [] for i in reversed(range(self.count)): - # Instantiate loop states with iterate value - state_name: str = 'end_' + itervar + str(-i * rng[2]) - state_name = state_name.replace('-', 'm').replace('+', 'p').replace('*', 'M').replace('/', 'D') - new_states = self.instantiate_loop( - sdfg, - loop_states, - loop_subgraph, - itervar, - itervar_sym + i * rng[2], - state_name, - ) - - # Connect states to before the loop with unconditional edges - graph.add_edge(new_states[last_id], prepend_state, sd.InterstateEdge()) - prepend_state = new_states[first_id] - - # Reconnect edge to guard state from last peeled iteration - if prepend_state != after_state: - graph.remove_edge(not_condition_edge) - graph.add_edge(not_condition_edge.src, prepend_state, not_condition_edge.data) + # Instantiate loop contents as a new control flow region with iterate value. + current_index = pystr_to_symbolic(self.loop.loop_variable) + (i * stride) + iteration_region = self.instantiate_loop_iteration(graph, self.loop, current_index, + str(i) if is_symbolic else None) + + # Connect iterations with unconditional edges + if len(peeled_iterations) > 0: + graph.add_edge(iteration_region, peeled_iterations[-1], sd.InterstateEdge()) + peeled_iterations.append(iteration_region) + + # Connect the peeled iterations to the remainder of the loop and adjust the remaining iteration bounds. + if peeled_iterations: + for oe in graph.out_edges(self.loop): + graph.add_edge(peeled_iterations[0], oe.dst, oe.data) + graph.remove_edge(oe) + graph.add_edge(self.loop, peeled_iterations[-1], sd.InterstateEdge()) + + new_cond = CodeBlock(self._modify_cond(self.loop.loop_condition, self.loop.loop_variable, stride)) + self.loop.loop_condition = new_cond diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 58b9d51c59..9b1c460372 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -1,23 +1,21 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop to map transformation """ from collections import defaultdict import copy -import itertools import sympy as sp -import networkx as nx -from typing import Dict, List, Optional, Set, Tuple +from typing import Dict, List, Set -from dace import data as dt, dtypes, memlet, nodes, registry, sdfg as sd, symbolic, subsets -from dace.properties import Property, make_properties, CodeBlock +from dace import data as dt, dtypes, memlet, nodes, sdfg as sd, symbolic, subsets, properties +from dace.codegen.tools.type_inference import infer_expr_type from dace.sdfg import graph as gr, nodes -from dace.sdfg import SDFG, SDFGState, InterstateEdge +from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil -from dace.sdfg.analysis import cfg -from dace.frontend.python.astutils import ASTFindReplace -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) +from dace.sdfg.analysis import cfg as cfg_analysis +from dace.sdfg.state import BreakBlock, ContinueBlock, ControlFlowRegion, LoopRegion, ReturnBlock import dace.transformation.helpers as helpers from dace.transformation import transformation as xf +from dace.transformation.passes.analysis import loop_analysis def _check_range(subset, a, itersym, b, step): @@ -74,70 +72,59 @@ def _sanitize_by_index(indices: Set[int], subset: subsets.Subset) -> subsets.Ran return type(subset)([t for i, t in enumerate(subset) if i in indices]) -@make_properties -@xf.single_level_sdfg_only -class LoopToMap(DetectLoop, xf.MultiStateTransformation): - """Convert a control flow loop into a dataflow map. Currently only supports - the simple case where there is no overlap between inputs and outputs in - the body of the loop, and where the loop body only consists of a single - state. +@properties.make_properties +@xf.explicit_cf_compatible +class LoopToMap(xf.MultiStateTransformation): + """ + Convert a control flow loop into a dataflow map. Currently only supports the simple case where there is no overlap + between inputs and outputs in the body of the loop, and where the loop body only consists of a single state. """ - itervar = Property( - dtype=str, - allow_none=True, - default=None, - desc='The name of the iteration variable (optional).', - ) - - def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissive: bool = False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False + loop = xf.PatternNode(LoopRegion) - begin = self.loop_begin + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # Guard state should not contain any dataflow - if expr_index <= 1: - guard = self.loop_guard - if len(guard.nodes()) != 0: - return False + def can_be_applied(self, graph, expr_index, sdfg, permissive = False): + # If loop information cannot be determined, fail. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) + itervar = self.loop.loop_variable + if start is None or end is None or step is None or itervar is None: + return False - # If loop cannot be detected, fail - found = self.loop_information(itervar=self.itervar) - if not found: + sset = {} + sset.update(sdfg.symbols) + sset.update(sdfg.arrays) + t = dtypes.result_type_of(infer_expr_type(start, sset), infer_expr_type(step, sset), infer_expr_type(end, sset)) + # We may only convert something to map if the bounds are all integer-derived types. Otherwise most map schedules + # except for sequential would be invalid. + if not t in dtypes.INTEGER_TYPES: return False - itervar, (start, end, step), (_, body_end) = found + # Loops containing break, continue, or returns may not be turned into a map. + for blk in self.loop.all_control_flow_blocks(): + if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)): + return False - # We cannot handle symbols read from data containers unless they are - # scalar + # We cannot handle symbols read from data containers unless they are scalar. for expr in (start, end, step): if symbolic.contains_sympy_functions(expr): return False - in_order_states = list(cfg.blockorder_topological_sort(sdfg)) - loop_begin_idx = in_order_states.index(begin) - loop_end_idx = in_order_states.index(body_end) - - if loop_end_idx < loop_begin_idx: # Malformed loop - return False - - # Find all loop-body states - states: List[SDFGState] = self.loop_body() - - assert (body_end in states) - - write_set: Set[str] = set() - for state in states: - _, wset = state.read_and_write_sets() - write_set |= wset + _, write_set = self.loop.read_and_write_sets() + loop_states = set(self.loop.all_states()) + all_loop_blocks = set(self.loop.all_control_flow_blocks()) # Collect symbol reads and writes from inter-state assignments + in_order_loop_blocks = list(cfg_analysis.blockorder_topological_sort(self.loop, recursive=True, + ignore_nonstate_blocks=False)) symbols_that_may_be_used: Set[str] = {itervar} used_before_assignment: Set[str] = set() - for state in states: - for e in sdfg.out_edges(state): + for block in in_order_loop_blocks: + for e in block.parent_graph.out_edges(block): # Collect read-before-assigned symbols (this works because the states are always in order, # see above call to `blockorder_topological_sort`) read_symbols = e.data.read_symbols() @@ -159,12 +146,12 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Get access nodes from other states to isolate local loop variables other_access_nodes: Set[str] = set() - for state in sdfg.nodes(): - if state in states: + for state in sdfg.states(): + if state in loop_states: continue other_access_nodes |= set(n.data for n in state.data_nodes() if sdfg.arrays[n.data].transient) # Add non-transient nodes from loop state - for state in states: + for state in loop_states: other_access_nodes |= set(n.data for n in state.data_nodes() if not sdfg.arrays[n.data].transient) write_memlets: Dict[str, List[memlet.Memlet]] = defaultdict(list) @@ -173,7 +160,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi a = sp.Wild('a', exclude=[itersym]) b = sp.Wild('b', exclude=[itersym]) - for state in states: + for state in loop_states: for dn in state.data_nodes(): if dn.data not in other_access_nodes: continue @@ -196,7 +183,7 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi write_memlets[dn.data].append(e.data) # After looping over relevant writes, consider reads that may overlap - for state in states: + for state in loop_states: for dn in state.data_nodes(): if dn.data not in other_access_nodes: continue @@ -212,31 +199,47 @@ def can_be_applied(self, graph: SDFGState, expr_index: int, sdfg: SDFG, permissi # Consider reads in inter-state edges (could be in assignments or in condition) isread_set: Set[memlet.Memlet] = set() - for s in states: - for e in sdfg.all_edges(s): - isread_set |= set(e.data.get_read_memlets(sdfg.arrays)) + for e in self.loop.all_interstate_edges(): + isread_set |= set(e.data.get_read_memlets(sdfg.arrays)) for mmlt in isread_set: if mmlt.data in write_memlets: if not self.test_read_memlet(sdfg, None, None, itersym, itervar, start, end, step, write_memlets, mmlt, mmlt.subset): return False - # Check that the iteration variable and other symbols are not used on other edges or states - # before they are reassigned - for state in in_order_states[loop_begin_idx + 1:]: - if state in states: + # Check that the iteration variable and other symbols are not used on other edges or blocks before they are + # reassigned. + in_order_blocks = list(cfg_analysis.blockorder_topological_sort(sdfg, recursive=True, + ignore_nonstate_blocks=False)) + # First check the outgoing edges of the loop itself. + reassigned_symbols: Set[str] = None + for oe in graph.out_edges(self.loop): + if symbols_that_may_be_used & oe.data.read_symbols(): + return False + # Check for symbols that are set by all outgoing edges + # TODO: Handle case of subset of out_edges + if reassigned_symbols is None: + reassigned_symbols = set(oe.data.assignments.keys()) + else: + reassigned_symbols &= oe.data.assignments.keys() + # Remove reassigned symbols + if reassigned_symbols is not None: + symbols_that_may_be_used -= reassigned_symbols + loop_idx = in_order_blocks.index(self.loop) + for block in in_order_blocks[loop_idx + 1:]: + if block in all_loop_blocks: continue # Don't continue in this direction, as all loop symbols have been reassigned if not symbols_that_may_be_used: break # Check state contents - if symbols_that_may_be_used & state.free_symbols: + if symbols_that_may_be_used & block.free_symbols: return False # Check inter-state edges - reassigned_symbols: Set[str] = None - for e in sdfg.out_edges(state): + reassigned_symbols = None + for e in block.parent_graph.out_edges(block): if symbols_that_may_be_used & e.data.read_symbols(): return False @@ -346,217 +349,116 @@ def _is_array_thread_local(self, name: str, itervar: str, sdfg: SDFG, states: Li return False return True - def apply(self, _, sdfg: sd.SDFG): + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): from dace.sdfg.propagation import align_memlet # Obtain loop information - itervar, (start, end, step), (_, body_end) = self.loop_information(itervar=self.itervar) - states = self.loop_body() - body: sd.SDFGState = self.loop_begin - exit_state = self.exit_state - entry_edge = self.loop_condition_edge() - init_edge = self.loop_init_edge() - after_edge = self.loop_exit_edge() - condition_edge = self.loop_condition_edge() - increment_edge = self.loop_increment_edge() + itervar = self.loop.loop_variable + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) nsdfg = None # Nest loop-body states - if len(states) > 1: - - # Find read/write sets - read_set, write_set = set(), set() - for state in states: - rset, wset = state.read_and_write_sets() - read_set |= rset - write_set |= wset - # Add to write set also scalars between tasklets - for src_node in state.nodes(): - if not isinstance(src_node, nodes.Tasklet): - continue - for dst_node in state.nodes(): - if src_node is dst_node: - continue - if not isinstance(dst_node, nodes.Tasklet): - continue - for e in state.edges_between(src_node, dst_node): - if e.data.data and e.data.data in sdfg.arrays: - write_set.add(e.data.data) - # Add data from edges - for src in states: - for dst in states: - for edge in sdfg.edges_between(src, dst): - for s in edge.data.free_symbols: - if s in sdfg.arrays: - read_set.add(s) - - # Find NestedSDFG's unique data - rw_set = read_set | write_set - unique_set = set() - for name in rw_set: - if not sdfg.arrays[name].transient: + states = set(self.loop.all_states()) + # Find read/write sets + read_set, write_set = set(), set() + for state in self.loop.all_states(): + rset, wset = state.read_and_write_sets() + read_set |= rset + write_set |= wset + # Add to write set also scalars between tasklets + for src_node in state.nodes(): + if not isinstance(src_node, nodes.Tasklet): continue - found = False - for state in sdfg.states(): - if state in states: + for dst_node in state.nodes(): + if src_node is dst_node: continue - for node in state.nodes(): - if (isinstance(node, nodes.AccessNode) and node.data == name): - found = True - break - if not found and self._is_array_thread_local(name, itervar, sdfg, states): - unique_set.add(name) - - # Find NestedSDFG's connectors - read_set = {n for n in read_set if n not in unique_set or not sdfg.arrays[n].transient} - write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient} - - # Create NestedSDFG and add all loop-body states and edges - # Also, find defined symbols in NestedSDFG - fsymbols = set(sdfg.free_symbols) - new_body = sdfg.add_state('single_state_body') - nsdfg = SDFG("loop_body", constants=sdfg.constants_prop, parent=new_body) - nsdfg.add_node(body, is_start_state=True) - body.parent = nsdfg - nexit_state = nsdfg.add_state('exit') - nsymbols = dict() - for state in states: - if state is body: - continue - nsdfg.add_node(state) - state.parent = nsdfg - for state in states: - if state is body: + if not isinstance(dst_node, nodes.Tasklet): + continue + for e in state.edges_between(src_node, dst_node): + if e.data.data and e.data.data in sdfg.arrays: + write_set.add(e.data.data) + # Add data from edges + for edge in self.loop.all_interstate_edges(): + for s in edge.data.free_symbols: + if s in sdfg.arrays: + read_set.add(s) + + # Find NestedSDFG's / Loop's unique data + rw_set = read_set | write_set + unique_set = set() + for name in rw_set: + if not sdfg.arrays[name].transient: + continue + found = False + for state in sdfg.states(): + if state in states: continue - for src, dst, data in sdfg.in_edges(state): - nsymbols.update({s: sdfg.symbols[s] for s in data.assignments.keys() if s in sdfg.symbols}) - nsdfg.add_edge(src, dst, data) - nsdfg.add_edge(body_end, nexit_state, InterstateEdge()) - - increment_edge = None - - # Specific instructions for loop type - if self.expr_index <= 1: # Natural loop with guard - guard = self.loop_guard - - # Move guard -> body edge to guard -> new_body - for e in sdfg.edges_between(guard, body): - sdfg.remove_edge(e) - condition_edge = sdfg.add_edge(e.src, new_body, e.data) - # Move body_end -> guard edge to new_body -> guard - for e in sdfg.edges_between(body_end, guard): - sdfg.remove_edge(e) - increment_edge = sdfg.add_edge(new_body, e.dst, e.data) - - - elif 1 < self.expr_index <= 3 or 5 <= self.expr_index <= 7: # Rotated loop - entrystate = self.entry_state - latch = self.loop_latch - - # Move entry edge to entry -> new_body - for src, dst, data, in sdfg.edges_between(entrystate, body): - init_edge = sdfg.add_edge(src, new_body, data) - - # Move body_end -> latch to new_body -> latch - for src, dst, data in sdfg.edges_between(latch, exit_state): - after_edge = sdfg.add_edge(new_body, dst, data) - - elif self.expr_index == 4: # Self-loop - entrystate = self.entry_state - - # Move entry edge to entry -> new_body - for src, dst, data in sdfg.edges_between(entrystate, body): - init_edge = sdfg.add_edge(src, new_body, data) - for src, dst, data in sdfg.edges_between(body, exit_state): - after_edge = sdfg.add_edge(new_body, dst, data) - - - # Delete loop-body states and edges from parent SDFG - sdfg.remove_nodes_from(states) - - # Add NestedSDFG arrays - for name in read_set | write_set: - nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name]) - nsdfg.arrays[name].transient = False - for name in unique_set: - nsdfg.arrays[name] = sdfg.arrays[name] - del sdfg.arrays[name] - - # Add NestedSDFG node - cnode = new_body.add_nested_sdfg(nsdfg, None, read_set, write_set) - if sdfg.parent: - for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items(): - if s not in cnode.symbol_mapping: - cnode.symbol_mapping[s] = m - nsdfg.add_symbol(s, sdfg.symbols[s]) - for name in read_set: - r = new_body.add_read(name) - new_body.add_edge(r, None, cnode, name, memlet.Memlet.from_array(name, sdfg.arrays[name])) - for name in write_set: - w = new_body.add_write(name) - new_body.add_edge(cnode, name, w, None, memlet.Memlet.from_array(name, sdfg.arrays[name])) - - # Fix SDFG symbols - for sym in sdfg.free_symbols - fsymbols: - if sym in sdfg.symbols: - del sdfg.symbols[sym] - for sym, dtype in nsymbols.items(): - nsdfg.symbols[sym] = dtype - - # Change body state reference - body = new_body + for node in state.nodes(): + if (isinstance(node, nodes.AccessNode) and node.data == name): + found = True + break + if not found and self._is_array_thread_local(name, itervar, sdfg, states): + unique_set.add(name) + + # Find NestedSDFG's connectors + read_set = {n for n in read_set if n not in unique_set or not sdfg.arrays[n].transient} + write_set = {n for n in write_set if n not in unique_set or not sdfg.arrays[n].transient} + + # Create NestedSDFG and add the loop contents to it. Gaher symbols defined in the NestedSDFG. + fsymbols = set(sdfg.free_symbols) + body = graph.add_state_before(self.loop, 'single_state_body') + nsdfg = SDFG('loop_body', constants=sdfg.constants_prop, parent=body) + nsdfg.add_node(self.loop.start_block, is_start_block=True) + nsymbols = dict() + for block in self.loop.nodes(): + if block is self.loop.start_block: + continue + nsdfg.add_node(block) + for e in self.loop.edges(): + nsymbols.update({s: sdfg.symbols[s] for s in e.data.assignments.keys() if s in sdfg.symbols}) + nsdfg.add_edge(e.src, e.dst, e.data) + + # Add NestedSDFG arrays + for name in read_set | write_set: + nsdfg.arrays[name] = copy.deepcopy(sdfg.arrays[name]) + nsdfg.arrays[name].transient = False + for name in unique_set: + nsdfg.arrays[name] = sdfg.arrays[name] + del sdfg.arrays[name] + + # Add NestedSDFG node + cnode = body.add_nested_sdfg(nsdfg, body, read_set, write_set) + if sdfg.parent: + for s, m in sdfg.parent_nsdfg_node.symbol_mapping.items(): + if s not in cnode.symbol_mapping: + cnode.symbol_mapping[s] = m + nsdfg.add_symbol(s, sdfg.symbols[s]) + for name in read_set: + r = body.add_read(name) + body.add_edge(r, None, cnode, name, memlet.Memlet.from_array(name, sdfg.arrays[name])) + for name in write_set: + w = body.add_write(name) + body.add_edge(cnode, name, w, None, memlet.Memlet.from_array(name, sdfg.arrays[name])) + + # Fix SDFG symbols + for sym in sdfg.free_symbols - fsymbols: + if sym in sdfg.symbols: + del sdfg.symbols[sym] + for sym, dtype in nsymbols.items(): + nsdfg.symbols[sym] = dtype if (step < 0) == True: - # If step is negative, we have to flip start and end to produce a - # correct map with a positive increment + # If step is negative, we have to flip start and end to produce a correct map with a positive increment. start, end, step = end, start, -step - reentry_assignments = {k: v for k, v in condition_edge.data.assignments.items() if k != itervar} - - # If necessary, make a nested SDFG with assignments - symbols_to_remove = set() - if len(reentry_assignments) > 0: - nsdfg = helpers.nest_state_subgraph(sdfg, body, gr.SubgraphView(body, body.nodes())) - for sym in entry_edge.data.free_symbols: - if sym in nsdfg.symbol_mapping or sym in nsdfg.in_connectors: - continue - if sym in sdfg.symbols: - nsdfg.symbol_mapping[sym] = symbolic.pystr_to_symbolic(sym) - nsdfg.sdfg.add_symbol(sym, sdfg.symbols[sym]) - elif sym in sdfg.arrays: - if sym in nsdfg.sdfg.arrays: - raise NotImplementedError - rnode = body.add_read(sym) - nsdfg.add_in_connector(sym) - desc = copy.deepcopy(sdfg.arrays[sym]) - desc.transient = False - nsdfg.sdfg.add_datadesc(sym, desc) - body.add_edge(rnode, None, nsdfg, sym, memlet.Memlet(sym)) - for name, desc in nsdfg.sdfg.arrays.items(): - if desc.transient and not self._is_array_thread_local(name, itervar, nsdfg.sdfg, nsdfg.sdfg.states()): - odesc = copy.deepcopy(desc) - sdfg.arrays[name] = odesc - desc.transient = False - wnode = body.add_access(name) - nsdfg.add_out_connector(name) - body.add_edge(nsdfg, name, wnode, None, memlet.Memlet.from_array(name, odesc)) - - nstate = nsdfg.sdfg.node(0) - init_state = nsdfg.sdfg.add_state_before(nstate) - nisedge = nsdfg.sdfg.edges_between(init_state, nstate)[0] - nisedge.data.assignments = reentry_assignments - symbols_to_remove = set(nisedge.data.assignments.keys()) - for k in nisedge.data.assignments.keys(): - if k in nsdfg.symbol_mapping: - del nsdfg.symbol_mapping[k] - condition_edge.data.assignments = {} - source_nodes = body.source_nodes() sink_nodes = body.sink_nodes() # Check intermediate notes - intermediate_nodes = [] + intermediate_nodes: List[nodes.AccessNode] = [] for node in body.nodes(): if isinstance(node, nodes.AccessNode) and body.in_degree(node) > 0 and node not in sink_nodes: # Scalars written without WCR must be thread-local @@ -590,7 +492,7 @@ def apply(self, _, sdfg: sd.SDFG): # Direct edges among source and sink access nodes must pass through a tasklet. # We first gather them and handle them later. - direct_edges = set() + direct_edges: Set[gr.MultiConnectorEdge[memlet.Memlet]] = set() for n1 in source_nodes: if not isinstance(n1, nodes.AccessNode): continue @@ -623,7 +525,7 @@ def apply(self, _, sdfg: sd.SDFG): body.add_edge_pair(exit, e.src, n, new_memlet, internal_connector=e.src_conn) else: body.add_nedge(n, exit, memlet.Memlet()) - intermediate_sinks = {} + intermediate_sinks: Dict[str, nodes.AccessNode] = {} for n in intermediate_nodes: if isinstance(sdfg.arrays[n.data], dt.View): continue @@ -636,8 +538,8 @@ def apply(self, _, sdfg: sd.SDFG): # Here we handle the direct edges among source and sink access nodes. for e in direct_edges: - src = e.src.data - dst = e.dst.data + src: str = e.src.data + dst: str = e.dst.data if e.data.subset.num_elements() == 1: t = body.add_tasklet(f"{n1}_{n2}", {'__inp'}, {'__out'}, "__out = __inp") src_conn, dst_conn = '__out', '__inp' @@ -667,44 +569,19 @@ def apply(self, _, sdfg: sd.SDFG): if not source_nodes and not sink_nodes: body.add_nedge(entry, exit, memlet.Memlet()) - # Get rid of the loop exit condition edge (it will be readded below) - if self.expr_index not in (5, 6, 7): - sdfg.remove_edge(after_edge) - - # Remove the assignment on the edge to the guard - for e in [init_edge, increment_edge]: - if e is None: - continue - if itervar in e.data.assignments: - del e.data.assignments[itervar] - - # Remove the condition on the entry edge - condition_edge.data.condition = CodeBlock("1") - - # Get rid of backedge to guard - if increment_edge is not None: - sdfg.remove_edge(increment_edge) + # Redirect outgoing edges connected to the loop to connect to the body state instead. + for e in graph.out_edges(self.loop): + graph.add_edge(body, e.dst, e.data) + # Delete the loop and connected edges. + graph.remove_node(self.loop) - # Route body directly to after state, maintaining any other assignments - # it might have had - sdfg.add_edge(body, exit_state, sd.InterstateEdge(assignments=after_edge.data.assignments)) - - # If this had made the iteration variable a free symbol, we can remove - # it from the SDFG symbols + # If this had made the iteration variable a free symbol, we can remove it from the SDFG symbols if itervar in sdfg.free_symbols: sdfg.remove_symbol(itervar) - for sym in symbols_to_remove: - if sym in sdfg.symbols and helpers.is_symbol_unused(sdfg, sym): - sdfg.remove_symbol(sym) - - # Reset all nested SDFG parent pointers - if nsdfg is not None: - if isinstance(nsdfg, nodes.NestedSDFG): - nsdfg = nsdfg.sdfg - - for nstate in nsdfg.nodes(): - for nnode in nstate.nodes(): - if isinstance(nnode, nodes.NestedSDFG): - nnode.sdfg.parent_nsdfg_node = nnode - nnode.sdfg.parent = nstate - nnode.sdfg.parent_sdfg = nsdfg + + sdfg.reset_cfg_list() + for n, p in sdfg.all_nodes_recursive(): + if isinstance(n, nodes.NestedSDFG): + n.sdfg.parent = p + n.sdfg.parent_nsdfg_node = n + n.sdfg.parent_sdfg = p.sdfg diff --git a/dace/transformation/interstate/loop_unroll.py b/dace/transformation/interstate/loop_unroll.py index 663745c0d6..a23777c749 100644 --- a/dace/transformation/interstate/loop_unroll.py +++ b/dace/transformation/interstate/loop_unroll.py @@ -1,135 +1,116 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Loop unroll transformation """ import copy -from typing import List +from typing import List, Optional from dace import sdfg as sd, symbolic from dace.properties import Property, make_properties -from dace.sdfg import graph as gr from dace.sdfg import utils as sdutil -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ControlFlowRegion, LoopRegion from dace.frontend.python.astutils import ASTFindReplace -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from dace.transformation import transformation as xf +from dace.transformation.passes.analysis import loop_analysis @make_properties -@xf.experimental_cfg_block_compatible -class LoopUnroll(DetectLoop, xf.MultiStateTransformation): - """ Unrolls a state machine for-loop into multiple states """ +@xf.explicit_cf_compatible +class LoopUnroll(xf.MultiStateTransformation): + """ Unrolls a for-loop into multiple individual control flow regions """ + + loop = xf.PatternNode(LoopRegion) count = Property( dtype=int, default=0, - desc='Number of iterations to unroll, or zero for all ' - 'iterations (loop must be constant-sized for 0)', + desc='Number of iterations to unroll, or zero for all iterations (loop must be constant-sized for 0)', ) - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False + inline_iterations = Property(dtype=bool, default=True, + desc="Whether or not to inline individual iterations' CFGs after unrolling") - found = self.loop_information() + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # If loop cannot be detected, fail - if not found: + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # If loop information cannot be determined, fail. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) + itervar = self.loop.loop_variable + if start is None or end is None or step is None or itervar is None: return False - _, rng, _ = found # If loop stride is not specialized or constant-sized, fail - if symbolic.issymbolic(rng[2], sdfg.constants): + if symbolic.issymbolic(step, sdfg.constants): return False # If loop range diff is not constant-sized, fail - if symbolic.issymbolic(rng[1] - rng[0], sdfg.constants): + if symbolic.issymbolic(end - start, sdfg.constants): return False return True def apply(self, graph: ControlFlowRegion, sdfg): - # Obtain loop information - begin: sd.SDFGState = self.loop_begin - after_state: sd.SDFGState = self.exit_state - - # Obtain iteration variable, range, and stride, together with the last - # state(s) before the loop and the last loop state. - itervar, rng, loop_struct = self.loop_information() - # Loop must be fully unrollable for now. if self.count != 0: - raise NotImplementedError # TODO(later) + raise NotImplementedError # TODO(later) - # Get loop states - loop_states = self.loop_body() - first_id = loop_states.index(begin) - last_state = loop_struct[1] - last_id = loop_states.index(last_state) - loop_subgraph = gr.SubgraphView(graph, loop_states) + # Obtain loop information + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + stride = loop_analysis.get_loop_stride(self.loop) try: - start, end, stride = (r for r in rng) stride = symbolic.evaluate(stride, sdfg.constants) loop_diff = int(symbolic.evaluate(end - start + 1, sdfg.constants)) - is_symbolic = any([symbolic.issymbolic(r) for r in rng[:2]]) + is_symbolic = any([symbolic.issymbolic(r) for r in (start, end)]) except TypeError: raise TypeError('Loop difference and strides cannot be symbolic.') - # Create states for loop subgraph - unrolled_states = [] + # Create states for loop subgraph + unrolled_iterations: List[ControlFlowRegion] = [] for i in range(0, loop_diff, stride): + # Instantiate loop contents as a new control flow region with iterate value. current_index = start + i - # Instantiate loop states with iterate value - new_states = self.instantiate_loop(sdfg, loop_states, loop_subgraph, itervar, current_index, - str(i) if is_symbolic else None) + iteration_region = self.instantiate_loop_iteration(graph, self.loop, current_index, + str(i) if is_symbolic else None) # Connect iterations with unconditional edges - if len(unrolled_states) > 0: - graph.add_edge(unrolled_states[-1][1], new_states[first_id], sd.InterstateEdge()) - - unrolled_states.append((new_states[first_id], new_states[last_id])) - - # Get any assignments that might be on the edge to the after state - after_assignments = self.loop_exit_edge().data.assignments - - # Connect new states to before and after states without conditions - if unrolled_states: - before_states = loop_struct[0] - for before_state in before_states: - graph.add_edge(before_state, unrolled_states[0][0], sd.InterstateEdge()) - graph.add_edge(unrolled_states[-1][1], after_state, sd.InterstateEdge(assignments=after_assignments)) - - # Remove old states from SDFG - guard_or_latch = self.loop_meta_states() - graph.remove_nodes_from(guard_or_latch + loop_states) - - def instantiate_loop( - self, - sdfg: sd.SDFG, - loop_states: List[sd.SDFGState], - loop_subgraph: gr.SubgraphView, - itervar: str, - value: symbolic.SymbolicType, - state_suffix=None, - ): - # Using to/from JSON copies faster than deepcopy (which will also - # copy the parent SDFG) - new_states = [sd.SDFGState.from_json(s.to_json(), context={'sdfg': sdfg}) for s in loop_states] - - # Replace iterate with value in each state - for state in new_states: - state.label = state.label + '_' + itervar + '_' + (state_suffix if state_suffix is not None else str(value)) - state.replace(itervar, value) - - graph = loop_states[0].parent_graph - # Add subgraph to original SDFG - for edge in loop_subgraph.edges(): - src = new_states[loop_states.index(edge.src)] - dst = new_states[loop_states.index(edge.dst)] - + if len(unrolled_iterations) > 0: + graph.add_edge(unrolled_iterations[-1], iteration_region, sd.InterstateEdge()) + unrolled_iterations.append(iteration_region) + + if unrolled_iterations: + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, unrolled_iterations[0], ie.data) + for oe in graph.out_edges(self.loop): + graph.add_edge(unrolled_iterations[-1], oe.dst, oe.data) + + # Remove old loop. + graph.remove_node(self.loop) + + if self.inline_iterations: + for it in unrolled_iterations: + it.inline() + + def instantiate_loop_iteration(self, graph: ControlFlowRegion, loop: LoopRegion, value: symbolic.SymbolicType, + label_suffix: Optional[str] = None) -> ControlFlowRegion: + it_label = loop.label + '_' + loop.loop_variable + (label_suffix if label_suffix is not None else str(value)) + iteration_region = ControlFlowRegion(it_label, graph.sdfg, graph) + graph.add_node(iteration_region) + block_map = {} + for block in loop.nodes(): + # Using to/from JSON copies faster than deepcopy. + new_block = sd.SDFGState.from_json(block.to_json(), context={'sdfg': graph.sdfg}) + block_map[block] = new_block + new_block.replace(loop.loop_variable, value) + iteration_region.add_node(new_block, is_start_block=(block is loop.start_block)) + for edge in loop.edges(): + src = block_map[edge.src] + dst = block_map[edge.dst] # Replace conditions in subgraph edges - data: sd.InterstateEdge = copy.deepcopy(edge.data) + data = copy.deepcopy(edge.data) if not data.is_unconditional(): - ASTFindReplace({itervar: str(value)}).visit(data.condition) - - graph.add_edge(src, dst, data) + ASTFindReplace({loop.loop_variable: str(value)}).visit(data.condition) + iteration_region.add_edge(src, dst, data) - return new_states + return iteration_region diff --git a/dace/transformation/interstate/move_assignment_outside_if.py b/dace/transformation/interstate/move_assignment_outside_if.py index 3b101818ca..8cfaa591d7 100644 --- a/dace/transformation/interstate/move_assignment_outside_if.py +++ b/dace/transformation/interstate/move_assignment_outside_if.py @@ -1,58 +1,51 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Transformation to move assignments outside if statements to potentially avoid warp divergence. Speedup gained is questionable. """ import ast +from typing import Dict, List, Tuple import sympy as sp from dace import sdfg as sd -from dace.sdfg import graph as gr -from dace.sdfg.nodes import Tasklet, AccessNode +from dace.sdfg import graph as gr, utils as sdutil, nodes as nd +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion +from dace.symbolic import pystr_to_symbolic from dace.transformation import transformation -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class MoveAssignmentOutsideIf(transformation.MultiStateTransformation): - if_guard = transformation.PatternNode(sd.SDFGState) - if_stmt = transformation.PatternNode(sd.SDFGState) - else_stmt = transformation.PatternNode(sd.SDFGState) + conditional = transformation.PatternNode(ConditionalBlock) @classmethod def expressions(cls): - sdfg = gr.OrderedDiGraph() - sdfg.add_nodes_from([cls.if_guard, cls.if_stmt, cls.else_stmt]) - sdfg.add_edge(cls.if_guard, cls.if_stmt, sd.InterstateEdge()) - sdfg.add_edge(cls.if_guard, cls.else_stmt, sd.InterstateEdge()) - return [sdfg] + return [sdutil.node_path_graph(cls.conditional)] def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # The if-guard can only have two outgoing edges: to the if and to the else part - guard_outedges = graph.out_edges(self.if_guard) - if len(guard_outedges) != 2: + # The conditional can only have two branches, with conditions either being negations of one another, or the + # second branch being an 'else' branch. + if len(self.conditional.branches) != 2: return False - - # Outgoing edges must be a negation of each other - if guard_outedges[0].data.condition_sympy() != (sp.Not(guard_outedges[1].data.condition_sympy())): - return False - - # The if guard should either have zero or one incoming edge - if len(sdfg.in_edges(self.if_guard)) > 1: + fcond = self.conditional.branches[0][0] + scond = self.conditional.branches[1][0] + if (fcond is None or (scond is not None and + (pystr_to_symbolic(fcond.as_string)) != sp.Not(pystr_to_symbolic(scond.as_string)))): return False # set of the variables which get a const value assigned assigned_const = set() # Dict which collects all AccessNodes for each variable together with its state - access_nodes = {} + access_nodes: Dict[str, List[Tuple[nd.AccessNode, sd.SDFGState]]] = {} # set of the variables which are only written to self.write_only_values = set() # Dictionary which stores additional information for the variables which are written only self.assign_context = {} - for state in [self.if_stmt, self.else_stmt]: + for state in self.conditional.all_states(): for node in state.nodes(): - if isinstance(node, Tasklet): + if isinstance(node, nd.Tasklet): # If node is a tasklet, check if assigns a constant value assigns_const = True for code_stmt in node.code.code: @@ -60,10 +53,10 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): assigns_const = False if assigns_const: for edge in state.out_edges(node): - if isinstance(edge.dst, AccessNode): + if isinstance(edge.dst, nd.AccessNode): assigned_const.add(edge.dst.data) - self.assign_context[edge.dst.data] = {"state": state, "tasklet": node} - elif isinstance(node, AccessNode): + self.assign_context[edge.dst.data] = {'state': state, 'tasklet': node} + elif isinstance(node, nd.AccessNode): if node.data not in access_nodes: access_nodes[node.data] = [] access_nodes[node.data].append((node, state)) @@ -92,14 +85,14 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return False return True - def apply(self, _, sdfg: sd.SDFG): + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # create a new state before the guard state where the zero assignment happens - new_assign_state = sdfg.add_state_before(self.if_guard, label="const_assignment_state") + new_assign_state = graph.add_state_before(self.conditional, label='const_assignment_state') # Move all the Tasklets together with the AccessNode for value in self.write_only_values: - state = self.assign_context[value]["state"] - tasklet = self.assign_context[value]["tasklet"] + state: sd.SDFGState = self.assign_context[value]['state'] + tasklet: nd.Tasklet = self.assign_context[value]['tasklet'] new_assign_state.add_node(tasklet) for edge in state.out_edges(tasklet): state.remove_edge(edge) @@ -110,5 +103,4 @@ def apply(self, _, sdfg: sd.SDFG): state.remove_node(tasklet) # Remove the state if it was emptied if state.is_empty(): - sdfg.remove_node(state) - return sdfg + graph.remove_node(state) diff --git a/dace/transformation/interstate/move_loop_into_map.py b/dace/transformation/interstate/move_loop_into_map.py index 29a9906fe0..de898c8f5c 100644 --- a/dace/transformation/interstate/move_loop_into_map.py +++ b/dace/transformation/interstate/move_loop_into_map.py @@ -1,18 +1,19 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Moves a loop around a map into the map """ import copy +from dace.sdfg.state import ControlFlowRegion, LoopRegion, SDFGState import dace.transformation.helpers as helpers import networkx as nx from dace.sdfg.scope import ScopeTree -from dace import data as dt, Memlet, nodes, sdfg as sd, subsets as sbs, symbolic, symbol -from dace.properties import CodeBlock -from dace.sdfg import nodes, propagation +from dace import Memlet, nodes, sdfg as sd, subsets as sbs, symbolic, symbol +from dace.sdfg import nodes, propagation, utils as sdutil from dace.transformation import transformation -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) from sympy import diff from typing import List, Set, Tuple +from dace.transformation.passes.analysis import loop_analysis + def fold(memlet_subset_ranges, itervar, lower, upper): return [(r[0].replace(symbol(itervar), lower), r[1].replace(symbol(itervar), upper), r[2]) @@ -23,32 +24,34 @@ def offset(memlet_subset_ranges, value): return (memlet_subset_ranges[0] + value, memlet_subset_ranges[1] + value, memlet_subset_ranges[2]) -@transformation.single_level_sdfg_only -class MoveLoopIntoMap(DetectLoop, transformation.MultiStateTransformation): +@transformation.explicit_cf_compatible +class MoveLoopIntoMap(transformation.MultiStateTransformation): """ Moves a loop around a map into the map """ - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False + loop = transformation.PatternNode(LoopRegion) - # Obtain loop information - body: sd.SDFGState = self.loop_begin + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # Obtain iteration variable, range, and stride - loop_info = self.loop_information() - if not loop_info: + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # If loop information cannot be determined, fail. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + step = loop_analysis.get_loop_stride(self.loop) + itervar = self.loop.loop_variable + if start is None or end is None or step is None or itervar is None: return False - itervar, (start, end, step), (_, body_end) = loop_info if step not in [-1, 1]: return False # Body must contain a single state - if body != body_end: + if len(self.loop.nodes()) != 1 or not isinstance(self.loop.nodes()[0], SDFGState): return False + body: SDFGState = self.loop.nodes()[0] # Body must have only a single connected component # NOTE: This is a strict check that can be potentially relaxed. @@ -153,14 +156,9 @@ def test_subset_dependency(subset: sbs.Subset, mparams: Set[int]) -> Tuple[bool, return True - def apply(self, _, sdfg: sd.SDFG): - # Obtain loop information - body: sd.SDFGState = self.loop_begin - - # Obtain iteration variable, range, and stride - itervar, (start, end, step), _ = self.loop_information() - - forward_loop = step > 0 + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): + body: sd.SDFGState = self.loop.nodes()[0] + itervar = self.loop.loop_variable for node in body.nodes(): if isinstance(node, nodes.MapEntry): @@ -171,50 +169,27 @@ def apply(self, _, sdfg: sd.SDFG): # nest map's content in sdfg map_subgraph = body.scope_subgraph(map_entry, include_entry=False, include_exit=False) nsdfg = helpers.nest_state_subgraph(sdfg, body, map_subgraph, full_data=True) + nested_state: SDFGState = nsdfg.sdfg.nodes()[0] # replicate loop in nested sdfg - new_before, new_guard, new_after = nsdfg.sdfg.add_loop( - before_state=None, - loop_state=nsdfg.sdfg.nodes()[0], - loop_end_state=None, - after_state=None, - loop_var=itervar, - initialize_expr=f'{start}', - condition_expr=f'{itervar} <= {end}' if forward_loop else f'{itervar} >= {end}', - increment_expr=f'{itervar} + {step}' if forward_loop else f'{itervar} - {abs(step)}') - - # remove outer loop - before_guard_edge = nsdfg.sdfg.edges_between(new_before, new_guard)[0] - for e in nsdfg.sdfg.out_edges(new_guard): - if e.dst is new_after: - guard_after_edge = e - else: - guard_body_edge = e - - if self.expr_index <= 1: - guard = self.loop_guard - for body_inedge in sdfg.in_edges(body): - if body_inedge.src is guard: - guard_body_edge.data.assignments.update(body_inedge.data.assignments) - sdfg.remove_edge(body_inedge) - for body_outedge in sdfg.out_edges(body): - sdfg.remove_edge(body_outedge) - for guard_inedge in sdfg.in_edges(guard): - before_guard_edge.data.assignments.update(guard_inedge.data.assignments) - guard_inedge.data.assignments = {} - sdfg.add_edge(guard_inedge.src, body, guard_inedge.data) - sdfg.remove_edge(guard_inedge) - for guard_outedge in sdfg.out_edges(guard): - if guard_outedge.dst is body: - guard_body_edge.data.assignments.update(guard_outedge.data.assignments) - else: - guard_after_edge.data.assignments.update(guard_outedge.data.assignments) - guard_outedge.data.condition = CodeBlock("1") - sdfg.add_edge(body, guard_outedge.dst, guard_outedge.data) - sdfg.remove_edge(guard_outedge) - sdfg.remove_node(guard) - else: # Rotated or self loops - raise NotImplementedError('MoveLoopIntoMap not implemented for rotated and self-loops') + inner_loop = LoopRegion(self.loop.label, + self.loop.loop_condition, + self.loop.loop_variable, + self.loop.init_statement, + self.loop.update_statement, + self.loop.inverted, + nsdfg, + self.loop.update_before_condition) + inner_loop.add_node(nested_state, is_start_block=True) + nsdfg.sdfg.remove_node(nested_state) + nsdfg.sdfg.add_node(inner_loop, is_start_block=True) + + graph.add_node(body, is_start_block=(graph.start_block is self.loop)) + for ie in graph.in_edges(self.loop): + graph.add_edge(ie.src, body, ie.data) + for oe in graph.out_edges(self.loop): + graph.add_edge(body, oe.dst, oe.data) + graph.remove_node(self.loop) if itervar in nsdfg.symbol_mapping: del nsdfg.symbol_mapping[itervar] @@ -254,9 +229,12 @@ def apply(self, _, sdfg: sd.SDFG): if helpers.is_symbol_unused(sdfg, s): sdfg.remove_symbol(s) + sdfg.reset_cfg_list() + from dace.transformation.interstate import RefineNestedAccess transformation = RefineNestedAccess() - transformation.setup_match(sdfg, 0, sdfg.node_id(body), {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0) + transformation.setup_match(sdfg, body.parent_graph.cfg_id, body.block_id, + {RefineNestedAccess.nsdfg: body.node_id(nsdfg)}, 0) transformation.apply(body, sdfg) # Second propagation for refined accesses. diff --git a/dace/transformation/interstate/multistate_inline.py b/dace/transformation/interstate/multistate_inline.py index 84b8a82c75..89f0edcea9 100644 --- a/dace/transformation/interstate/multistate_inline.py +++ b/dace/transformation/interstate/multistate_inline.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Inline multi-state SDFGs. """ from copy import deepcopy as dc @@ -14,11 +14,11 @@ from dace.transformation import transformation, helpers from dace.properties import make_properties from dace import data -from dace.sdfg.state import StateSubgraphView +from dace.sdfg.state import LoopRegion, ReturnBlock, StateSubgraphView @make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class InlineMultistateSDFG(transformation.SingleStateTransformation): """ Inlines a multi-state nested SDFG into a top-level SDFG. This only happens @@ -74,7 +74,7 @@ def replfunc(mapping): return all(istr == ostr for istr, ostr in zip(istrides, ostrides)) - def can_be_applied(self, state: SDFGState, expr_index, sdfg, permissive=False): + def can_be_applied(self, state: SDFGState, expr_index, sdfg: SDFG, permissive=False): nested_sdfg = self.nested_sdfg if nested_sdfg.no_inline: return False @@ -135,6 +135,14 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): nsdfg_node = self.nested_sdfg nsdfg: SDFG = nsdfg_node.sdfg + # If the nested SDFG contains returns, ensure they are inlined first. + has_return = False + for blk in nsdfg.all_control_flow_blocks(): + if isinstance(blk, ReturnBlock): + has_return = True + if has_return: + sdutil.inline_control_flow_regions(nsdfg, lower_returns=True) + if nsdfg_node.schedule != dtypes.ScheduleType.Default: infer_types.set_default_schedule_and_storage_types(nsdfg, [nsdfg_node.schedule]) @@ -153,14 +161,14 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sdfg._callback_mapping.update(nsdfg.callback_mapping) # Environments - for nstate in nsdfg.nodes(): + for nstate in nsdfg.states(): for node in nstate.nodes(): if isinstance(node, nodes.CodeNode): node.environments |= nsdfg_node.environments # Symbols outer_symbols = {str(k): v for k, v in sdfg.symbols.items()} - for ise in sdfg.edges(): + for ise in sdfg.all_interstate_edges(): outer_symbols.update(ise.data.new_symbols(sdfg, outer_symbols)) # Isolate nsdfg in a separate state @@ -195,12 +203,20 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # Collect and modify interstate edges as necessary outer_assignments = set() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): outer_assignments |= e.data.assignments.keys() + for b in sdfg.all_control_flow_blocks(): + if isinstance(b, LoopRegion): + if b.loop_variable is not None: + outer_assignments.add(b.loop_variable) inner_assignments = set() - for e in nsdfg.edges(): + for e in nsdfg.all_interstate_edges(): inner_assignments |= e.data.assignments.keys() + for b in nsdfg.all_control_flow_blocks(): + if isinstance(b, LoopRegion): + if b.loop_variable is not None: + inner_assignments.add(b.loop_variable) allnames = set(outer_symbols.keys()) | set(sdfg.arrays.keys()) assignments_to_replace = inner_assignments & (outer_assignments | allnames) @@ -220,7 +236,7 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): # All transients become transients of the parent (if data already # exists, find new name) - for nstate in nsdfg.nodes(): + for nstate in nsdfg.states(): for node in nstate.nodes(): if isinstance(node, nodes.AccessNode): datadesc = nsdfg.arrays[node.data] @@ -268,8 +284,8 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): symbolic.safe_replace(repldict, lambda m: replace_datadesc_names(nsdfg, m), value_as_string=True) # Make unique names for states - statenames = set(s.label for s in sdfg.nodes()) - for nstate in nsdfg.nodes(): + statenames = set(s.label for s in sdfg.states()) + for nstate in nsdfg.states(): if nstate.label in statenames: newname = data.find_new_name(nstate.label, statenames) statenames.add(newname) @@ -278,11 +294,11 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): ####################################################### # Add nested SDFG states into top-level SDFG - outer_start_state = sdfg.start_state + outer_start_state = outer_state.parent_graph.start_block - sdfg.add_nodes_from(nsdfg.nodes()) + outer_state.parent_graph.add_nodes_from(nsdfg.nodes()) for ise in nsdfg.edges(): - sdfg.add_edge(ise.src, ise.dst, ise.data) + outer_state.parent_graph.add_edge(ise.src, ise.dst, ise.data) ####################################################### # Reconnect inlined SDFG @@ -291,25 +307,25 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): sinks = nsdfg.sink_nodes() # Reconnect state machine - for e in sdfg.in_edges(nsdfg_state): - sdfg.add_edge(e.src, source, e.data) - for e in sdfg.out_edges(nsdfg_state): + for e in outer_state.parent_graph.in_edges(nsdfg_state): + outer_state.parent_graph.add_edge(e.src, source, e.data) + for e in outer_state.parent_graph.out_edges(nsdfg_state): for sink in sinks: - sdfg.add_edge(sink, e.dst, dc(e.data)) + outer_state.parent_graph.add_edge(sink, e.dst, dc(e.data)) # Redirect sink incoming edges with a `False` condition to e.dst (return statements) - for e2 in sdfg.in_edges(sink): + for e2 in outer_state.parent_graph.in_edges(sink): if e2.data.condition_sympy() == False: - sdfg.add_edge(e2.src, e.dst, InterstateEdge()) + outer_state.parent_graph.add_edge(e2.src, e.dst, InterstateEdge()) # Modify start state as necessary if outer_start_state is nsdfg_state: - sdfg.start_state = sdfg.node_id(source) + outer_state.parent_graph.start_block = outer_state.parent_graph.node_id(source) # TODO: Modify memlets by offsetting # Replace nested SDFG parents with new SDFG - for nstate in nsdfg.nodes(): - nstate.parent = sdfg + for nstate in nsdfg.states(): + nstate.sdfg = sdfg for node in nstate.nodes(): if isinstance(node, nodes.NestedSDFG): node.sdfg.parent_sdfg = sdfg @@ -317,8 +333,8 @@ def apply(self, outer_state: SDFGState, sdfg: SDFG): ####################################################### # Remove nested SDFG and state - sdfg.remove_node(nsdfg_state) + outer_state.parent_graph.remove_node(nsdfg_state) - sdfg._cfg_list = sdfg.reset_cfg_list() + sdfg.reset_cfg_list() return nsdfg.nodes() diff --git a/dace/transformation/interstate/sdfg_nesting.py b/dace/transformation/interstate/sdfg_nesting.py index 7b64e49869..31e751bb6a 100644 --- a/dace/transformation/interstate/sdfg_nesting.py +++ b/dace/transformation/interstate/sdfg_nesting.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ SDFG nesting transformation. """ import ast @@ -16,13 +16,14 @@ from dace.sdfg.graph import MultiConnectorEdge, SubgraphView from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil, infer_types, propagation +from dace.sdfg.state import LoopRegion from dace.transformation import transformation, helpers from dace.properties import make_properties, Property from dace import data @make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class InlineSDFG(transformation.SingleStateTransformation): """ Inlines a single-state nested SDFG into a top-level SDFG. @@ -99,7 +100,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): nested_sdfg = self.nested_sdfg if nested_sdfg.no_inline: return False - if len(nested_sdfg.sdfg.nodes()) != 1: + if len(nested_sdfg.sdfg.nodes()) != 1 or not isinstance(nested_sdfg.sdfg.nodes()[0], SDFGState): return False # Ensure every connector has one incoming/outgoing edge and that it @@ -154,7 +155,7 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): out_data[dst.data] = e.src_conn rem_inpconns = dc(in_connectors) rem_outconns = dc(out_connectors) - nstate = nested_sdfg.sdfg.node(0) + nstate: SDFGState = nested_sdfg.sdfg.nodes()[0] for node in nstate.nodes(): if isinstance(node, nodes.AccessNode): if node.data in rem_inpconns: @@ -317,7 +318,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): symbolic.safe_replace(nsdfg_node.symbol_mapping, nsdfg.replace_dict) # Access nodes that need to be reshaped - reshapes: Set(str) = set() + reshapes: Set[str] = set() for aname, array in nsdfg.arrays.items(): if array.transient: continue @@ -737,11 +738,10 @@ def _modify_reshape_data(self, reshapes: Set[str], repldict: Dict[str, str], new @make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class InlineTransients(transformation.SingleStateTransformation): """ - Inlines all transient arrays that are not used anywhere else into a - nested SDFG. + Inlines all transient arrays that are not used anywhere else into a nested SDFG. """ nsdfg = transformation.PatternNode(nodes.NestedSDFG) @@ -784,7 +784,7 @@ def _candidates(sdfg: SDFG, graph: SDFGState, nsdfg: nodes.NestedSDFG) -> Dict[s return candidates # Check for uses in other states - for state in sdfg.nodes(): + for state in sdfg.states(): if state is graph: continue for node in state.data_nodes(): @@ -881,7 +881,7 @@ def visit_Subscript(self, node: ast.Subscript) -> ast.Subscript: @make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class RefineNestedAccess(transformation.SingleStateTransformation): """ Reduces memlet shape when a memlet is connected to a nested SDFG, but not @@ -923,7 +923,7 @@ def _candidates( in_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {} out_candidates: Dict[str, Tuple[Memlet, SDFGState, Set[int]]] = {} ignore = set() - for nstate in nsdfg.sdfg.nodes(): + for nstate in nsdfg.sdfg.states(): for dnode in nstate.data_nodes(): if nsdfg.sdfg.arrays[dnode.data].transient: continue @@ -970,7 +970,7 @@ def _candidates( in_candidates[e.data.data] = (e.data, nstate, set(range(len(e.data.subset)))) # Check read memlets in interstate edges for candidates - for e in nsdfg.sdfg.edges(): + for e in nsdfg.sdfg.all_interstate_edges(): for m in e.data.get_read_memlets(nsdfg.sdfg.arrays): # If more than one unique element detected, remove from candidates if m.data in in_candidates: @@ -1035,7 +1035,8 @@ def _check_cand(candidates, outer_edges): # If there are any symbols here that are not defined # in "defined_symbols" - missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), list(indices)) - set(nsdfg.symbol_mapping.keys())) + missing_symbols = (memlet.get_free_symbols_by_indices(list(indices), + list(indices)) - set(nsdfg.symbol_mapping.keys())) if missing_symbols: ignore.add(cname) continue @@ -1078,13 +1079,13 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], if aname in refined: continue # Refine internal memlets - for nstate in nsdfg.nodes(): + for nstate in nsdfg.states(): for e in nstate.edges(): if e.data.data == aname: e.data.subset.offset(refine.subset, True, indices) # Refine accesses in interstate edges refiner = ASTRefiner(aname, refine.subset, nsdfg, indices) - for isedge in nsdfg.edges(): + for isedge in nsdfg.all_interstate_edges(): for k, v in isedge.data.assignments.items(): vast = ast.parse(v) refiner.visit(vast) @@ -1105,7 +1106,7 @@ def _offset_refine(torefine: Dict[str, Tuple[Memlet, Set[int]]], @make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class NestSDFG(transformation.MultiStateTransformation): """ Implements SDFG Nesting, taking an SDFG as an input and creating a nested SDFG node from it. """ @@ -1135,7 +1136,7 @@ def apply(self, _, sdfg: SDFG) -> nodes.NestedSDFG: outputs = {} transients = {} - for state in nested_sdfg.nodes(): + for state in nested_sdfg.states(): # Input and output nodes are added as input and output nodes of the nested SDFG for node in state.nodes(): if (isinstance(node, nodes.AccessNode) and not node.desc(nested_sdfg).transient): @@ -1256,7 +1257,7 @@ def apply(self, _, sdfg: SDFG) -> nodes.NestedSDFG: nested_sdfg.arrays[newarrname].transient = False # Update memlets - for state in nested_sdfg.nodes(): + for state in nested_sdfg.states(): for _, edge in enumerate(state.edges()): _, _, _, _, mem = edge src = state.memlet_path(edge)[0].src @@ -1288,6 +1289,9 @@ def apply(self, _, sdfg: SDFG) -> nodes.NestedSDFG: for e in nested_sdfg.edges(): defined_syms |= set(e.data.new_symbols(sdfg, {}).keys()) + for blk in nested_sdfg.all_control_flow_blocks(): + if isinstance(blk, LoopRegion): + defined_syms |= set(blk.new_symbols({}).keys()) defined_syms |= set(nested_sdfg.constants.keys()) diff --git a/dace/transformation/interstate/state_elimination.py b/dace/transformation/interstate/state_elimination.py index 2640e30ccc..94619576bf 100644 --- a/dace/transformation/interstate/state_elimination.py +++ b/dace/transformation/interstate/state_elimination.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ State elimination transformations """ import networkx as nx @@ -8,11 +8,12 @@ from dace.properties import CodeBlock from dace.sdfg import nodes, SDFG, SDFGState from dace.sdfg import utils as sdutil +from dace.sdfg.sdfg import InterstateEdge from dace.sdfg.state import ControlFlowRegion from dace.transformation import transformation -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class EndStateElimination(transformation.MultiStateTransformation): """ End-state elimination removes a redundant state that has one incoming edge @@ -60,7 +61,7 @@ def apply(self, graph, sdfg): sdfg.remove_symbol(sym) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StartStateElimination(transformation.MultiStateTransformation): """ Start-state elimination removes a redundant state that has one outgoing edge @@ -77,7 +78,7 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): state = self.start_state # The transformation applies only to nested SDFGs - if not graph.parent: + if not isinstance(graph, SDFG) or not graph.parent: return False # Only empty states can be eliminated @@ -133,7 +134,7 @@ def _assignments_to_consider(sdfg, edge, is_constant=False): return assignments_to_consider -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StateAssignElimination(transformation.MultiStateTransformation): """ State assign elimination removes all assignments into the final state @@ -222,7 +223,7 @@ def _str_repl(s, d): symbolic.safe_replace(repl_dict, lambda m: _str_repl(sdfg, m)) -def _alias_assignments(sdfg, edge): +def _alias_assignments(sdfg: SDFG, edge: InterstateEdge): assignments_to_consider = {} for var, assign in edge.assignments.items(): if assign in sdfg.symbols or (assign in sdfg.arrays and isinstance(sdfg.arrays[assign], dt.Scalar)): @@ -230,7 +231,7 @@ def _alias_assignments(sdfg, edge): return assignments_to_consider -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class SymbolAliasPromotion(transformation.MultiStateTransformation): """ SymbolAliasPromotion moves inter-state assignments that create symbolic @@ -298,12 +299,12 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph: ControlFlowRegion, sdfg: SDFG): fstate = self.first_state sstate = self.second_state - edge = sdfg.edges_between(fstate, sstate)[0].data - in_edge = sdfg.in_edges(fstate)[0].data + edge = graph.edges_between(fstate, sstate)[0].data + in_edge = graph.in_edges(fstate)[0].data to_consider = _alias_assignments(sdfg, edge) @@ -335,7 +336,7 @@ def apply(self, _, sdfg): in_edge.assignments[k] = v -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class HoistState(transformation.SingleStateTransformation): """ Move a state out of a nested SDFG """ nsdfg = transformation.PatternNode(nodes.NestedSDFG) @@ -359,6 +360,8 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): return False if nsdfg.sdfg.start_state.number_of_nodes() != 0: return False + if any([not isinstance(x, SDFGState) for x in nsdfg.sdfg.nodes()]): + return False # Must have at least two states with a hoistable source state if nsdfg.sdfg.number_of_nodes() < 2: @@ -428,8 +431,8 @@ def can_be_applied(self, graph: SDFGState, expr_index, sdfg, permissive=False): def apply(self, state: SDFGState, sdfg: SDFG): nsdfg: nodes.NestedSDFG = self.nsdfg - new_state = sdfg.add_state_before(state) - isedge = sdfg.edges_between(new_state, state)[0] + new_state = state.parent_graph.add_state_before(state) + isedge = state.parent_graph.edges_between(new_state, state)[0] # Find relevant symbol and data descriptor mapping mapping: Dict[str, str] = {} @@ -438,7 +441,7 @@ def apply(self, state: SDFGState, sdfg: SDFG): mapping.update({k: next(iter(state.out_edges_by_connector(nsdfg, k))).data.data for k in nsdfg.out_connectors}) # Get internal state and interstate edge - source_state = nsdfg.sdfg.start_state + source_state: SDFGState = nsdfg.sdfg.start_state nisedge = nsdfg.sdfg.out_edges(source_state)[0] # Add state contents (nodes) @@ -489,7 +492,7 @@ def replfunc(m): nsdfg.sdfg.start_state = nsdfg.sdfg.node_id(nisedge.dst) -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class TrueConditionElimination(transformation.MultiStateTransformation): """ If a state transition condition is always true, removes condition from edge. @@ -525,7 +528,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: SDFG): edge.data.condition = CodeBlock("1") -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FalseConditionElimination(transformation.MultiStateTransformation): """ If a state transition condition is always false, removes edge. diff --git a/dace/transformation/interstate/state_fusion.py b/dace/transformation/interstate/state_fusion.py index dbdf7642bd..7e3dc6916b 100644 --- a/dace/transformation/interstate/state_fusion.py +++ b/dace/transformation/interstate/state_fusion.py @@ -32,7 +32,7 @@ def top_level_nodes(state: SDFGState): return state.scope_children()[None] -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StateFusion(transformation.MultiStateTransformation): """ Implements the state-fusion transformation. diff --git a/dace/transformation/interstate/state_fusion_with_happens_before.py b/dace/transformation/interstate/state_fusion_with_happens_before.py index 408f5a76f2..c358a131f6 100644 --- a/dace/transformation/interstate/state_fusion_with_happens_before.py +++ b/dace/transformation/interstate/state_fusion_with_happens_before.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ State fusion transformation """ from typing import Dict, List, Set @@ -9,7 +9,7 @@ from dace.config import Config from dace.sdfg import nodes from dace.sdfg import utils as sdutil -from dace.sdfg.state import SDFGState +from dace.sdfg.state import ControlFlowRegion, SDFGState from dace.transformation import transformation @@ -31,7 +31,7 @@ def top_level_nodes(state: SDFGState): return state.scope_children()[None] -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class StateFusionExtended(transformation.MultiStateTransformation): """ Implements the state-fusion transformation extended to fuse states with RAW and WAW dependencies. An empty memlet is used to represent a dependency between two subgraphs with RAW and WAW dependencies. @@ -461,33 +461,33 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg): + def apply(self, graph: ControlFlowRegion, sdfg): first_state: SDFGState = self.first_state second_state: SDFGState = self.second_state # Remove interstate edge(s) - edges = sdfg.edges_between(first_state, second_state) + edges = graph.edges_between(first_state, second_state) for edge in edges: if edge.data.assignments: - for src, dst, other_data in sdfg.in_edges(first_state): + for src, dst, other_data in graph.in_edges(first_state): other_data.assignments.update(edge.data.assignments) - sdfg.remove_edge(edge) + graph.remove_edge(edge) # Special case 1: first state is empty if first_state.is_empty(): - sdutil.change_edge_dest(sdfg, first_state, second_state) - sdfg.remove_node(first_state) - if sdfg.start_state == first_state: - sdfg.start_state = sdfg.node_id(second_state) + sdutil.change_edge_dest(graph, first_state, second_state) + graph.remove_node(first_state) + if graph.start_block == first_state: + graph.start_block = graph.node_id(second_state) return # Special case 2: second state is empty if second_state.is_empty(): - sdutil.change_edge_src(sdfg, second_state, first_state) - sdutil.change_edge_dest(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + sdutil.change_edge_dest(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) return # Normal case: both states are not empty @@ -495,7 +495,6 @@ def apply(self, _, sdfg): # Find source/sink (data) nodes first_input = [node for node in first_state.source_nodes() if isinstance(node, nodes.AccessNode)] first_output = [node for node in first_state.sink_nodes() if isinstance(node, nodes.AccessNode)] - second_input = [node for node in second_state.source_nodes() if isinstance(node, nodes.AccessNode)] top2 = top_level_nodes(second_state) @@ -585,7 +584,7 @@ def apply(self, _, sdfg): merged_nodes.add(n) # Redirect edges and remove second state - sdutil.change_edge_src(sdfg, second_state, first_state) - sdfg.remove_node(second_state) - if sdfg.start_state == second_state: - sdfg.start_state = sdfg.node_id(first_state) + sdutil.change_edge_src(graph, second_state, first_state) + graph.remove_node(second_state) + if graph.start_block == second_state: + graph.start_block = graph.node_id(first_state) diff --git a/dace/transformation/interstate/trivial_loop_elimination.py b/dace/transformation/interstate/trivial_loop_elimination.py index 411d9ff07d..e948cba7ba 100644 --- a/dace/transformation/interstate/trivial_loop_elimination.py +++ b/dace/transformation/interstate/trivial_loop_elimination.py @@ -1,33 +1,39 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Eliminates trivial loop """ from dace import sdfg as sd -from dace.properties import CodeBlock +from dace.sdfg import utils as sdutil +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import ControlFlowRegion, LoopRegion from dace.transformation import helpers, transformation -from dace.transformation.interstate.loop_detection import (DetectLoop, find_for_loop) +from dace.transformation.passes.analysis import loop_analysis -@transformation.single_level_sdfg_only -class TrivialLoopElimination(DetectLoop, transformation.MultiStateTransformation): +@transformation.explicit_cf_compatible +class TrivialLoopElimination(transformation.MultiStateTransformation): """ Eliminates loops with a single loop iteration. """ - def can_be_applied(self, graph, expr_index, sdfg, permissive=False): - # Is this even a loop - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False + loop = transformation.PatternNode(LoopRegion) + + @classmethod + def expressions(cls): + return [sdutil.node_path_graph(cls.loop)] - # Obtain iteration variable, range, and stride - loop_info = self.loop_information() - if not loop_info: + def can_be_applied(self, graph, expr_index, sdfg, permissive=False): + # Check if this is a for-loop with known range. + start = loop_analysis.get_init_assignment(self.loop) + end = loop_analysis.get_loop_end(self.loop) + stride = loop_analysis.get_loop_stride(self.loop) + if start is None or end is None or stride is None: return False - _, (start, end, step), _ = loop_info + # Check if this is a trivial loop. try: - if step > 0 and start + step < end + 1: + if stride > 0 and start + stride < end + 1: return False - if step < 0 and start + step > end - 1: + if stride < 0 and start + stride > end - 1: return False except: # if the relation can't be determined it's not a trivial loop @@ -35,28 +41,26 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive=False): return True - def apply(self, _, sdfg: sd.SDFG): - # Obtain loop information + def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # Obtain iteration variable, range and stride - itervar, (start, end, step), (_, body_end) = self.loop_information() - states = self.loop_body() - - for state in states: - state.replace(itervar, start) - - # Remove loop - sdfg.remove_edge(self.loop_increment_edge()) + itervar = self.loop.loop_variable + start = loop_analysis.get_init_assignment(self.loop) - init_edge = self.loop_init_edge() - init_edge.data.assignments = {} - sdfg.add_edge(init_edge.src, self.loop_begin, init_edge.data) - sdfg.remove_edge(init_edge) + self.loop.replace(itervar, start) - exit_edge = self.loop_exit_edge() - exit_edge.data.condition = CodeBlock("1") - sdfg.add_edge(body_end, exit_edge.dst, exit_edge.data) - sdfg.remove_edge(exit_edge) + # Add the loop contents to the parent graph. + graph.add_node(self.loop.start_block) + for e in graph.in_edges(self.loop): + graph.add_edge(e.src, self.loop.start_block, e.data) + sink = graph.add_state(self.loop.label + '_sink') + for n in self.loop.sink_nodes(): + graph.add_edge(n, sink, InterstateEdge()) + for e in graph.out_edges(self.loop): + graph.add_edge(sink, e.dst, e.data) + for e in self.loop.edges(): + graph.add_edge(e.src, e.dst, e.data) - sdfg.remove_nodes_from(self.loop_meta_states()) + # Remove loop and if necessary also the loop variable. + graph.remove_node(self.loop) if itervar in sdfg.symbols and helpers.is_symbol_unused(sdfg, itervar): sdfg.remove_symbol(itervar) diff --git a/dace/transformation/pass_pipeline.py b/dace/transformation/pass_pipeline.py index 9a8154df90..bca7626b85 100644 --- a/dace/transformation/pass_pipeline.py +++ b/dace/transformation/pass_pipeline.py @@ -10,6 +10,8 @@ from typing import Any, Dict, Iterator, List, Optional, Set, Type, Union from dataclasses import dataclass +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion + class Modifies(Flag): """ @@ -132,6 +134,9 @@ def subclasses_recursive(cls) -> Set[Type['Pass']]: return result + def set_opts(self, opts: Dict[str, Any]) -> None: + pass + @properties.make_properties class VisitorPass(Pass): """ @@ -257,6 +262,60 @@ def apply(self, state: SDFGState, pipeline_results: Dict[str, Any]) -> Optional[ raise NotImplementedError +@properties.make_properties +class ControlFlowRegionPass(Pass): + """ + A specialized Pass type that applies to each control flow region separately, buttom up. Such a pass is realized by + implementing the ``apply`` method, which accepts a single control flow region, and assumes the pass was already + applied to each control flow region nested inside of that. + + :see: Pass + """ + + CATEGORY: str = 'Helper' + + apply_to_conditionals = properties.Property(dtype=bool, default=False, + desc='Whether or not to apply to conditional blocks. If false, do ' + + 'not apply to conditional blocks, but only their children.') + top_down = properties.Property(dtype=bool, default=False, + desc='Whether or not to apply top down (i.e., parents before children)') + + def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[int, Optional[Any]]]: + """ + Applies the pass to control flow regions of the given SDFG by calling ``apply`` on each region. + + :param sdfg: The SDFG to apply the pass to. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: A dictionary of ``{cfg_id: return value}`` for visited regions with a non-None return value, or None + if nothing was returned. + """ + result = {} + for region in sdfg.all_control_flow_regions(recursive=True, parent_first=self.top_down): + if isinstance(region, ConditionalBlock) and not self.apply_to_conditionals: + continue + retval = self.apply(region, pipeline_results) + if retval is not None: + result[region.cfg_id] = retval + + if not result: + return None + return result + + def apply(self, region: ControlFlowRegion, pipeline_results: Dict[str, Any]) -> Optional[Any]: + """ + Applies this pass on the given control flow region. + + :param state: The control flow region to apply the pass to. + :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass + results as ``{Pass subclass name: returned object from pass}``. If not run in a + pipeline, an empty dictionary is expected. + :return: Some object if pass was applied, or None if nothing changed. + """ + raise NotImplementedError + + @properties.make_properties class ScopePass(Pass): """ @@ -494,35 +553,9 @@ def apply_subpass(self, sdfg: SDFG, p: Pass, state: Dict[str, Any]) -> Optional[ :param state: The pipeline results state. :return: The pass return value. """ - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(p, '__experimental_cfg_block_compatible__') or - p.__experimental_cfg_block_compatible__ == False): - warnings.warn(p.__class__.__name__ + ' is not being applied due to incompatibility with ' + - 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + - 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + - 'True. If ' + p.__class__.__name__ + ' is compatible with experimental blocks, ' + - 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + - '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + - 'for more information.') - return None - return p.apply_pass(sdfg, state) def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Any]]: - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(self, '__experimental_cfg_block_compatible__') or - self.__experimental_cfg_block_compatible__ == False): - warnings.warn('Pipeline ' + self.__class__.__name__ + ' is being skipped due to incompatibility with ' + - 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + - 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + - 'True. If ' + self.__class__.__name__ + ' is compatible with experimental blocks, ' + - 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + - '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + - 'for more information.') - return None - state = pipeline_results retval = {} self._modified = Modifies.Nothing diff --git a/dace/transformation/passes/analysis/__init__.py b/dace/transformation/passes/analysis/__init__.py index 5bc1f6e3f3..b0b39f3c4b 100644 --- a/dace/transformation/passes/analysis/__init__.py +++ b/dace/transformation/passes/analysis/__init__.py @@ -1 +1,2 @@ from .analysis import * +from .loop_analysis import * diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index cc0d95c1d8..94c24399ee 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -2,24 +2,30 @@ from collections import defaultdict, deque from dataclasses import dataclass -from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion + +import sympy + +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion +from dace.subsets import Range from dace.transformation import pass_pipeline as ppl, transformation from dace import SDFG, SDFGState, properties, InterstateEdge, Memlet, data as dt, symbolic from dace.sdfg.graph import Edge from dace.sdfg import nodes as nd, utils as sdutil -from dace.sdfg.analysis import cfg +from dace.sdfg.analysis import cfg as cfg_analysis from dace.sdfg.propagation import align_memlet -from typing import Dict, Set, Tuple, Any, Optional, Union +from typing import Dict, Iterable, List, Set, Tuple, Any, Optional, Union import networkx as nx from networkx.algorithms import shortest_paths as nxsp +from dace.transformation.passes.analysis import loop_analysis + WriteScopeDict = Dict[str, Dict[Optional[Tuple[SDFGState, nd.AccessNode]], - Set[Tuple[SDFGState, Union[nd.AccessNode, InterstateEdge]]]]] -SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], SDFGState]]]] + Set[Union[Tuple[SDFGState, nd.AccessNode], Tuple[ControlFlowBlock, InterstateEdge]]]]] +SymbolScopeDict = Dict[str, Dict[Edge[InterstateEdge], Set[Union[Edge[InterstateEdge], ControlFlowBlock]]]] @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class StateReachability(ppl.Pass): """ Evaluates state reachability (which other states can be executed after each state). @@ -58,7 +64,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_res: Dict) -> Dict[int, Dict[SDFGS @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ControlFlowBlockReachability(ppl.Pass): """ Evaluates control flow block reachability (which control flow block can be executed after each control flow block) @@ -103,7 +109,7 @@ def _region_closure(self, region: ControlFlowRegion, def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]]: """ :return: For each control flow region, a dictionary mapping each control flow block to its other reachable - control flow blocks in the same region. + control flow blocks. """ single_level_reachable: Dict[int, Dict[ControlFlowBlock, Set[ControlFlowBlock]]] = defaultdict( lambda: defaultdict(set) @@ -113,7 +119,12 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ # The implementation below is faster # tc: nx.DiGraph = nx.transitive_closure(sdfg.nx) for n, v in reachable_nodes(cfg.nx): - single_level_reachable[cfg.cfg_id][n] = set(v) + reach = set() + for nd in v: + reach.add(nd) + if isinstance(nd, AbstractControlFlowRegion): + reach.update(nd.all_control_flow_blocks()) + single_level_reachable[cfg.cfg_id][n] = reach if isinstance(cfg, LoopRegion): single_level_reachable[cfg.cfg_id][n].update(cfg.nodes()) @@ -126,7 +137,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[ControlFlowBlock, Set[ result: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(set) for block in cfg.nodes(): for reached in single_level_reachable[block.parent_graph.cfg_id][block]: - if isinstance(reached, ControlFlowRegion): + if isinstance(reached, AbstractControlFlowRegion): result[block].update(reached.all_control_flow_blocks()) result[block].add(reached) if block.parent_graph is not sdfg: @@ -184,10 +195,10 @@ def reachable_nodes(G): @properties.make_properties -@transformation.experimental_cfg_block_compatible -class SymbolAccessSets(ppl.Pass): +@transformation.explicit_cf_compatible +class SymbolAccessSets(ppl.ControlFlowRegionPass): """ - Evaluates symbol access sets (which symbols are read/written in each state or interstate edge). + Evaluates symbol access sets (which symbols are read/written in each control flow block or interstate edge). """ CATEGORY: str = 'Analysis' @@ -199,33 +210,25 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If anything was modified, reapply return modified & ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Symbols | ppl.Modifies.Nodes - def apply_pass(self, top_sdfg: SDFG, - _) -> Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]]: - """ - :return: A dictionary mapping each state and interstate edge to a tuple of its (read, written) symbols. - """ - top_result: Dict[int, Dict[Union[SDFGState, Edge[InterstateEdge]], Tuple[Set[str], Set[str]]]] = {} - for sdfg in top_sdfg.all_sdfgs_recursive(): - for cfg in sdfg.all_control_flow_regions(): - adesc = set(sdfg.arrays.keys()) - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for block in cfg.nodes(): - if isinstance(block, SDFGState): - # No symbols may be written to inside states. - result[block] = (block.free_symbols, set()) - for oedge in cfg.out_edges(block): - edge_readset = oedge.data.read_symbols() - adesc - edge_writeset = set(oedge.data.assignments.keys()) - result[oedge] = (edge_readset, edge_writeset) - top_result[cfg.cfg_id] = result - return top_result + def apply(self, region: ControlFlowRegion, _) -> Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], + Tuple[Set[str], Set[str]]]: + adesc = set(region.sdfg.arrays.keys()) + result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} + for block in region.nodes(): + # No symbols may be written to inside blocks. + result[block] = (block.free_symbols, set()) + for oedge in region.out_edges(block): + edge_readset = oedge.data.read_symbols() - adesc + edge_writeset = set(oedge.data.assignments.keys()) + result[oedge] = (edge_readset, edge_writeset) + return result @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class AccessSets(ppl.Pass): """ - Evaluates memory access sets (which arrays/data descriptors are read/written in each state). + Evaluates memory access sets (which arrays/data descriptors are read/written in each control flow block). """ CATEGORY: str = 'Analysis' @@ -234,40 +237,65 @@ def modifies(self) -> ppl.Modifies: return ppl.Modifies.Nothing def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply + # If access nodes were modified, reapply return modified & ppl.Modifies.AccessNodes - def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]]: + def _get_loop_region_readset(self, loop: LoopRegion, arrays: Set[str]) -> Set[str]: + readset = set() + exprs = { loop.loop_condition.as_string } + update_stmt = loop_analysis.get_update_assignment(loop) + init_stmt = loop_analysis.get_init_assignment(loop) + if update_stmt: + exprs.add(update_stmt) + if init_stmt: + exprs.add(init_stmt) + for expr in exprs: + readset |= symbolic.free_symbols_and_functions(expr) & arrays + return readset + + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]: """ - :return: A dictionary mapping each state to a tuple of its (read, written) data descriptors. + :return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors. """ - top_result: Dict[int, Dict[SDFGState, Tuple[Set[str], Set[str]]]] = {} + result: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = {} for sdfg in top_sdfg.all_sdfgs_recursive(): - result: Dict[SDFGState, Tuple[Set[str], Set[str]]] = {} - for state in sdfg.states(): + arrays: Set[str] = set(sdfg.arrays.keys()) + for block in sdfg.all_control_flow_blocks(): readset, writeset = set(), set() - for anode in state.data_nodes(): - if state.in_degree(anode) > 0: - writeset.add(anode.data) - if state.out_degree(anode) > 0: - readset.add(anode.data) - - result[state] = (readset, writeset) + if isinstance(block, SDFGState): + for anode in block.data_nodes(): + if block.in_degree(anode) > 0: + writeset.add(anode.data) + if block.out_degree(anode) > 0: + readset.add(anode.data) + elif isinstance(block, AbstractControlFlowRegion): + for state in block.all_states(): + for anode in state.data_nodes(): + if state.in_degree(anode) > 0: + writeset.add(anode.data) + if state.out_degree(anode) > 0: + readset.add(anode.data) + if isinstance(block, LoopRegion): + readset |= self._get_loop_region_readset(block, arrays) + elif isinstance(block, ConditionalBlock): + for cond, _ in block.branches: + if cond is not None: + readset |= symbolic.free_symbols_and_functions(cond.as_string) & arrays + + result[block] = (readset, writeset) # Edges that read from arrays add to both ends' access sets anames = sdfg.arrays.keys() - for e in sdfg.edges(): + for e in sdfg.all_interstate_edges(): fsyms = e.data.free_symbols & anames if fsyms: result[e.src][0].update(fsyms) result[e.dst][0].update(fsyms) - - top_result[sdfg.cfg_id] = result - return top_result + return result @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FindAccessStates(ppl.Pass): """ For each data descriptor, creates a set of states in which access nodes of that data are used. @@ -306,7 +334,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[SDFGState]]]: @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FindAccessNodes(ppl.Pass): """ For each data descriptor, creates a dictionary mapping states to all read and write access nodes with the given @@ -343,9 +371,10 @@ def apply_pass(self, top_sdfg: SDFG, @properties.make_properties -class SymbolWriteScopes(ppl.Pass): +@transformation.explicit_cf_compatible +class SymbolWriteScopes(ppl.ControlFlowRegionPass): """ - For each symbol, create a dictionary mapping each writing interstate edge to that symbol to the set of interstate + For each symbol, create a dictionary mapping each interstate edge writing to that symbol to the set of interstate edges and states reading that symbol that are dominated by that write. """ @@ -355,17 +384,16 @@ def modifies(self) -> ppl.Modifies: return ppl.Modifies.Nothing def should_reapply(self, modified: ppl.Modifies) -> bool: - # If anything was modified, reapply - return modified & ppl.Modifies.Symbols | ppl.Modifies.States | ppl.Modifies.Edges | ppl.Modifies.Nodes + return modified & ppl.Modifies.Symbols | ppl.Modifies.CFG | ppl.Modifies.Edges | ppl.Modifies.Nodes def depends_on(self): - return {SymbolAccessSets, StateReachability} + return {SymbolAccessSets, ControlFlowBlockReachability} - def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[InterstateEdge]], - state_idom: Dict[SDFGState, SDFGState]) -> Optional[Edge[InterstateEdge]]: - last_state: SDFGState = read if isinstance(read, SDFGState) else read.src + def _find_dominating_write(self, sym: str, read: Union[ControlFlowBlock, Edge[InterstateEdge]], + block_idom: Dict[ControlFlowBlock, ControlFlowBlock]) -> Optional[Edge[InterstateEdge]]: + last_block: ControlFlowBlock = read if isinstance(read, ControlFlowBlock) else read.src - in_edges = last_state.parent.in_edges(last_state) + in_edges = last_block.parent_graph.in_edges(last_block) deg = len(in_edges) if deg == 0: return None @@ -373,9 +401,9 @@ def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[Interstat return in_edges[0] write_isedge = None - n_state = state_idom[last_state] if state_idom[last_state] != last_state else None - while n_state is not None and write_isedge is None: - oedges = n_state.parent.out_edges(n_state) + n_block = block_idom[last_block] if block_idom[last_block] != last_block else None + while n_block is not None and write_isedge is None: + oedges = n_block.parent_graph.out_edges(n_block) odeg = len(oedges) if odeg == 1: if any([sym == k for k in oedges[0].data.assignments.keys()]): @@ -383,70 +411,68 @@ def _find_dominating_write(self, sym: str, read: Union[SDFGState, Edge[Interstat else: dom_edge = None for cand in oedges: - if nxsp.has_path(n_state.parent.nx, cand.dst, last_state): + if nxsp.has_path(n_block.parent_graph.nx, cand.dst, last_block): if dom_edge is not None: dom_edge = None break elif any([sym == k for k in cand.data.assignments.keys()]): dom_edge = cand write_isedge = dom_edge - n_state = state_idom[n_state] if state_idom[n_state] != n_state else None + n_block = block_idom[n_block] if block_idom[n_block] != n_block else None return write_isedge - def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[int, SymbolScopeDict]: - top_result: Dict[int, SymbolScopeDict] = dict() - - for sdfg in sdfg.all_sdfgs_recursive(): - result: SymbolScopeDict = defaultdict(lambda: defaultdict(lambda: set())) - - idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) - all_doms = cfg.all_dominators(sdfg, idom) - symbol_access_sets: Dict[Union[SDFGState, Edge[InterstateEdge]], - Tuple[Set[str], - Set[str]]] = pipeline_results[SymbolAccessSets.__name__][sdfg.cfg_id] - state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.cfg_id] - - for read_loc, (reads, _) in symbol_access_sets.items(): - for sym in reads: - dominating_write = self._find_dominating_write(sym, read_loc, idom) - result[sym][dominating_write].add(read_loc if isinstance(read_loc, SDFGState) else read_loc) - - # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, - # then merge A and its scope into B's scope. - to_remove = set() - for sym in result.keys(): - for write, accesses in result[sym].items(): - if write is None: - continue - dominators = all_doms[write.dst] - reach = state_reach[write.dst] - for dom in dominators: - iedges = dom.parent.in_edges(dom) - if len(iedges) == 1 and iedges[0] in result[sym]: - other_accesses = result[sym][iedges[0]] - coarsen = False - for a_state_or_edge in other_accesses: - if isinstance(a_state_or_edge, SDFGState): - if a_state_or_edge in reach: - coarsen = True - break - else: - if a_state_or_edge.src in reach: - coarsen = True - break - if coarsen: - other_accesses.update(accesses) - other_accesses.add(write) - to_remove.add((sym, write)) - result[sym][write] = set() - for sym, write in to_remove: - del result[sym][write] + def apply(self, region, pipeline_results) -> SymbolScopeDict: + result: SymbolScopeDict = defaultdict(lambda: defaultdict(lambda: set())) - top_result[sdfg.cfg_id] = result - return top_result + idom = nx.immediate_dominators(region.nx, region.start_block) + all_doms = cfg_analysis.all_dominators(region, idom) + + b_reach: Dict[ControlFlowBlock, + Set[ControlFlowBlock]] = pipeline_results[ControlFlowBlockReachability.__name__][region.cfg_id] + symbol_access_sets: Dict[Union[ControlFlowBlock, Edge[InterstateEdge]], + Tuple[Set[str], Set[str]]] = pipeline_results[SymbolAccessSets.__name__][region.cfg_id] + + for read_loc, (reads, _) in symbol_access_sets.items(): + for sym in reads: + dominating_write = self._find_dominating_write(sym, read_loc, idom) + result[sym][dominating_write].add(read_loc) + + # If any write A is dominated by another write B and any reads in B's scope are also reachable by A, then merge + # A and its scope into B's scope. + to_remove = set() + for sym in result.keys(): + for write, accesses in result[sym].items(): + if write is None: + continue + dominators = all_doms[write.dst] + reach = b_reach[write.dst] + for dom in dominators: + iedges = dom.parent_graph.in_edges(dom) + if len(iedges) == 1 and iedges[0] in result[sym]: + other_accesses = result[sym][iedges[0]] + coarsen = False + for a_state_or_edge in other_accesses: + if isinstance(a_state_or_edge, SDFGState): + if a_state_or_edge in reach: + coarsen = True + break + else: + if a_state_or_edge.src in reach: + coarsen = True + break + if coarsen: + other_accesses.update(accesses) + other_accesses.add(write) + to_remove.add((sym, write)) + result[sym][write] = set() + for sym, write in to_remove: + del result[sym][write] + + return result @properties.make_properties +@transformation.explicit_cf_compatible class ScalarWriteShadowScopes(ppl.Pass): """ For each scalar or array of size 1, create a dictionary mapping writes to that data container to the set of reads @@ -463,17 +489,18 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.States def depends_on(self): - return {AccessSets, FindAccessNodes, StateReachability} + return {AccessSets, FindAccessNodes, ControlFlowBlockReachability} def _find_dominating_write(self, desc: str, - state: SDFGState, + block: ControlFlowBlock, read: Union[nd.AccessNode, InterstateEdge], access_nodes: Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]], - state_idom: Dict[SDFGState, SDFGState], - access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]], + idom_dict: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]], + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]], no_self_shadowing: bool = False) -> Optional[Tuple[SDFGState, nd.AccessNode]]: if isinstance(read, nd.AccessNode): + state: SDFGState = block # If the read is also a write, it shadows itself. iedges = state.in_edges(read) if len(iedges) > 0 and any(not e.data.is_empty() for e in iedges) and not no_self_shadowing: @@ -489,24 +516,31 @@ def _find_dominating_write(self, closest_candidate = cand if closest_candidate is not None: return (state, closest_candidate) - elif isinstance(read, InterstateEdge): + elif isinstance(read, InterstateEdge) and isinstance(block, SDFGState): # Attempt to find a shadowing write in the current state. # TODO: Can this be done more efficiently? closest_candidate = None - write_nodes = access_nodes[desc][state][1] + write_nodes = access_nodes[desc][block][1] for cand in write_nodes: - if closest_candidate is None or nxsp.has_path(state._nx, closest_candidate, cand): + if closest_candidate is None or nxsp.has_path(block._nx, closest_candidate, cand): closest_candidate = cand if closest_candidate is not None: - return (state, closest_candidate) + return (block, closest_candidate) - # Find the dominating write state if the current state is not the dominating write state. + # Find the dominating write state if the current block is not the dominating write state. write_state = None - nstate = state_idom[state] if state_idom[state] != state else None - while nstate is not None and write_state is None: - if desc in access_sets[nstate][1]: - write_state = nstate - nstate = state_idom[nstate] if state_idom[nstate] != nstate else None + pivot_block = block + region = block.parent_graph + while region is not None and write_state is None: + nblock = idom_dict[region][pivot_block] if idom_dict[region][pivot_block] != block else None + while nblock is not None and write_state is None: + if isinstance(nblock, SDFGState) and desc in access_sets[nblock][1]: + write_state = nblock + nblock = idom_dict[region][nblock] if idom_dict[region][nblock] != nblock else None + # No dominating write found in the current control flow graph, check one further up. + if write_state is None: + pivot_block = region + region = region.parent_graph # Find a dominating write in the write state, i.e., the 'last' write to the data container. if write_state is not None: @@ -530,33 +564,52 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i """ top_result: Dict[int, WriteScopeDict] = dict() + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = pipeline_results[AccessSets.__name__] + for sdfg in top_sdfg.all_sdfgs_recursive(): result: WriteScopeDict = defaultdict(lambda: defaultdict(lambda: set())) - idom = nx.immediate_dominators(sdfg.nx, sdfg.start_state) - all_doms = cfg.all_dominators(sdfg, idom) - access_sets: Dict[SDFGState, Tuple[Set[str], - Set[str]]] = pipeline_results[AccessSets.__name__][sdfg.cfg_id] + idom_dict: Dict[ControlFlowRegion, Dict[ControlFlowBlock, ControlFlowBlock]] = {} + all_doms_transitive: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = defaultdict(lambda: set()) + for cfg in sdfg.all_control_flow_regions(): + if isinstance(cfg, ConditionalBlock): + idom_dict[cfg] = {b: b for _, b in cfg.branches} + all_doms = {b: set([b]) for _, b in cfg.branches} + else: + idom_dict[cfg] = nx.immediate_dominators(cfg.nx, cfg.start_block) + all_doms = cfg_analysis.all_dominators(cfg, idom_dict[cfg]) + + # Since all_control_flow_regions goes top-down in the graph hierarchy, we can build a transitive + # closure of all dominators her. + for k in all_doms.keys(): + all_doms_transitive[k].update(all_doms[k]) + all_doms_transitive[k].add(cfg) + all_doms_transitive[k].update(all_doms_transitive[cfg]) + access_nodes: Dict[str, Dict[SDFGState, Tuple[Set[nd.AccessNode], Set[nd.AccessNode]]]] = pipeline_results[ FindAccessNodes.__name__][sdfg.cfg_id] - state_reach: Dict[SDFGState, Set[SDFGState]] = pipeline_results[StateReachability.__name__][sdfg.cfg_id] + + block_reach: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = pipeline_results[ + ControlFlowBlockReachability.__name__ + ] anames = sdfg.arrays.keys() for desc in sdfg.arrays: desc_states_with_nodes = set(access_nodes[desc].keys()) for state in desc_states_with_nodes: for read_node in access_nodes[desc][state][0]: - write = self._find_dominating_write(desc, state, read_node, access_nodes, idom, access_sets) + write = self._find_dominating_write(desc, state, read_node, access_nodes, idom_dict, + access_sets) result[desc][write].add((state, read_node)) # Ensure accesses to interstate edges are also considered. - for state, accesses in access_sets.items(): + for block, accesses in access_sets.items(): if desc in accesses[0]: - out_edges = sdfg.out_edges(state) + out_edges = block.parent_graph.out_edges(block) for oedge in out_edges: syms = oedge.data.free_symbols & anames if desc in syms: - write = self._find_dominating_write(desc, state, oedge.data, access_nodes, idom, + write = self._find_dominating_write(desc, block, oedge.data, access_nodes, idom_dict, access_sets) - result[desc][write].add((state, oedge.data)) + result[desc][write].add((block, oedge.data)) # Take care of any write nodes that have not been assigned to a scope yet, i.e., writes that are not # dominating any reads and are thus not part of the results yet. for state in desc_states_with_nodes: @@ -566,7 +619,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i state, write_node, access_nodes, - idom, + idom_dict, access_sets, no_self_shadowing=True) result[desc][write].add((state, write_node)) @@ -578,8 +631,8 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i if write is None: continue write_state, write_node = write - dominators = all_doms[write_state] - reach = state_reach[write_state] + dominators = all_doms_transitive[write_state] + reach = block_reach[write_state.parent_graph.cfg_id][write_state] for other_write, other_accesses in result[desc].items(): if other_write is not None and other_write[1] is write_node and other_write[0] is write_state: continue @@ -598,7 +651,7 @@ def apply_pass(self, top_sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[i @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class AccessRanges(ppl.Pass): """ For each data descriptor, finds all memlets used to access it (read/write ranges). @@ -636,7 +689,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Memlet]]]: @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class FindReferenceSources(ppl.Pass): """ For each Reference data descriptor, finds all memlets used to set it. If a Tasklet was used @@ -725,7 +778,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[int, Dict[str, Set[Union[Memlet, @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class DeriveSDFGConstraints(ppl.Pass): CATEGORY: str = 'Analysis' @@ -758,3 +811,163 @@ def apply_pass(self, sdfg: SDFG, _) -> Tuple[Dict[str, Set[str]], Dict[str, Set[ invariants: Dict[str, Set[str]] = {} self._derive_parameter_datasize_constraints(sdfg, invariants) return {}, invariants, {} + + +@transformation.explicit_cf_compatible +class StatePropagation(ppl.ControlFlowRegionPass): + """ + Analyze a control flow region to determine the number of times each block inside of it is executed in the form of a + symbolic expression, or a concrete number where possible. + Each control flow block is marked with a symbolic expression for the number of executions, and a boolean flag to + indicate whether the number of executions is dynamic or not. A combination of dynamic being set to true and the + number of executions being 0 indicates that the number of executions is dynamically unbounded. + Additionally, the pass annotates each block with a `ranges` property, which indicates for loop variables defined + at that block what range of values the variable may take on. + Note: This path directly annotates the graph. + This pass supersedes `dace.sdfg.propagation.propagate_states` and is based on its algorithm, with significant + simplifications thanks to the use of control flow regions. + """ + + CATEGORY: str = 'Analysis' + + def __init__(self): + super().__init__() + self.top_down = True + self.apply_to_conditionals = True + + def depends_on(self): + return {ControlFlowBlockReachability} + + def _propagate_in_cfg(self, cfg: ControlFlowRegion, reachable: Dict[ControlFlowBlock, Set[ControlFlowBlock]], + starting_executions: int, starting_dynamic_executions: bool): + visited_blocks: Set[ControlFlowBlock] = set() + traversal_q: deque[Tuple[ControlFlowBlock, int, bool, List[str]]] = deque() + traversal_q.append((cfg.start_block, starting_executions, starting_dynamic_executions, [])) + while traversal_q: + (block, proposed_executions, proposed_dynamic, itvar_stack) = traversal_q.pop() + out_edges = cfg.out_edges(block) + if block in visited_blocks: + # This block has already been visited, meaning there are multiple paths towards this block. + if proposed_executions == 0 and proposed_dynamic: + block.executions = 0 + block.dynamic_executions = True + else: + block.executions = sympy.Max(block.executions, proposed_executions).doit() + block.dynamic_executions = (block.dynamic_executions or proposed_dynamic) + elif proposed_dynamic and proposed_executions == 0: + # We're propagating a dynamic unbounded number of executions, which always gets propagated + # unconditionally. Propagate to all children. + visited_blocks.add(block) + block.executions = proposed_executions + block.dynamic_executions = proposed_dynamic + # This gets pushed through to all children unconditionally. + if len(out_edges) > 0: + for oedge in out_edges: + traversal_q.append((oedge.dst, proposed_executions, proposed_dynamic, itvar_stack)) + else: + # If the state hasn't been visited yet and we're not propagating a dynamic unbounded number of + # executions, we calculate the number of executions for the next state(s) and continue propagating. + visited_blocks.add(block) + block.executions = proposed_executions + block.dynamic_executions = proposed_dynamic + if len(out_edges) == 1: + # Continue with the only child state. + if not out_edges[0].data.is_unconditional(): + # If the transition to the child state is based on a condition, this state could be an implicit + # exit state. The child state's number of executions is thus only given as an upper bound and + # marked as dynamic. + proposed_dynamic = True + traversal_q.append((out_edges[0].dst, proposed_executions, proposed_dynamic, itvar_stack)) + elif len(out_edges) > 1: + # Conditional split + for oedge in out_edges: + traversal_q.append((oedge.dst, block.executions, True, itvar_stack)) + + # Check if the CFG contains any cycles. Any cycles left in the graph (after control flow raising) are + # irreducible control flow and thus lead to a dynamically unbounded number of executions. Mark any block + # inside and reachable from any block inside the cycle as dynamically unbounded, irrespectively of what it was + # marked as before. + cycles: Iterable[Iterable[ControlFlowBlock]] = cfg.find_cycles() + for cycle in cycles: + for blk in cycle: + blk.executions = 0 + blk.dynamic_executions = True + for reached in reachable[blk]: + reached.executions = 0 + blk.dynamic_executions = True + + def apply(self, region, pipeline_results) -> None: + if isinstance(region, ConditionalBlock): + # In a conditional block, each branch is executed up to as many times as the conditional block itself is. + # TODO(later): We may be able to derive ranges here based on the branch conditions too. + for _, b in region.branches: + b.executions = region.executions + b.dynamic_executions = True + b.ranges = region.ranges + else: + if isinstance(region, SDFG): + # The root SDFG is executed exactly once, any other, nested SDFG is executed as many times as the parent + # state is. + if region is region.root_sdfg: + region.executions = 1 + region.dynamic_executions = False + elif region.parent: + region.executions = region.parent.executions + region.dynamic_executions = region.parent.dynamic_executions + + # Clear existing annotations. + for blk in region.nodes(): + blk.executions = 0 + blk.dynamic_executions = True + blk.ranges = region.ranges + + # Determine the number of executions for the start block within this region. In the case of loops, this + # is dependent on the number of loop iterations - where they can be determined. Where they may not be + # determined, the number of iterations is assumed to be dynamically unbounded. For any other control flow + # region, the start block is executed as many times as the region itself is. + starting_execs = region.executions + starting_dynamic = region.dynamic_executions + if isinstance(region, LoopRegion): + # If inside a loop, add range information if possible. + start = loop_analysis.get_init_assignment(region) + stop = loop_analysis.get_loop_end(region) + stride = loop_analysis.get_loop_stride(region) + if start is not None and stop is not None and stride is not None and region.loop_variable: + # This inequality needs to be checked exactly like this due to constraints in sympy/symbolic + # expressions, do not simplify! + if (stride < 0) == True: + rng = (stop, start, -stride) + else: + rng = (start, stop, stride) + for blk in region.nodes(): + blk.ranges[str(region.loop_variable)] = Range([rng]) + + # Get surrounding iteration variables for the case of nested loops. + itvar_stack = [] + par = region.parent_graph + while par is not None and not isinstance(par, SDFG): + if isinstance(par, LoopRegion) and par.loop_variable: + itvar_stack.append(par.loop_variable) + par = par.parent_graph + + # Calculate the number of loop executions. + # This resolves ranges based on the order of iteration variables from surrounding loops. + loop_executions = sympy.ceiling(((stop + 1) - start) / stride) + for outer_itvar_string in itvar_stack: + outer_range = region.ranges[outer_itvar_string] + outer_start = outer_range[0][0] + outer_stop = outer_range[0][1] + outer_stride = outer_range[0][2] + outer_itvar = symbolic.pystr_to_symbolic(outer_itvar_string) + exec_repl = loop_executions.subs({outer_itvar: (outer_itvar * outer_stride + outer_start)}) + sum_rng = (outer_itvar, 0, sympy.ceiling((outer_stop - outer_start) / outer_stride)) + loop_executions = sympy.Sum(exec_repl, sum_rng) + starting_execs = loop_executions.doit() + starting_dynamic = region.dynamic_executions + else: + starting_execs = 0 + starting_dynamic = True + + # Propagate the number of executions. + self._propagate_in_cfg(region, pipeline_results[ControlFlowBlockReachability.__name__][region.cfg_id], + starting_execs, starting_dynamic) diff --git a/dace/transformation/passes/analysis/loop_analysis.py b/dace/transformation/passes/analysis/loop_analysis.py index 3d15f73c73..69a77422e8 100644 --- a/dace/transformation/passes/analysis/loop_analysis.py +++ b/dace/transformation/passes/analysis/loop_analysis.py @@ -3,8 +3,7 @@ Various analyses concerning LopoRegions, and utility functions to get information about LoopRegions for other passes. """ -import ast -from typing import Any, Dict, Optional +from typing import Dict, Optional from dace.frontend.python import astutils import sympy @@ -13,28 +12,12 @@ from dace.sdfg.state import LoopRegion -class FindAssignment(ast.NodeVisitor): - - assignments: Dict[str, str] - multiple: bool - - def __init__(self): - self.assignments = {} - self.multiple = False - - def visit_Assign(self, node: ast.Assign) -> Any: - for tgt in node.targets: - if isinstance(tgt, ast.Name): - if tgt.id in self.assignments: - self.multiple = True - self.assignments[tgt.id] = astutils.unparse(node.value) - return self.generic_visit(node) - - def get_loop_end(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: """ Parse a loop region to identify the end value of the iteration variable under normal loop termination (no break). """ + if loop.loop_variable is None or loop.loop_variable == '': + return None end: Optional[symbolic.SymbolicType] = None a = sympy.Wild('a') condition = symbolic.pystr_to_symbolic(loop.loop_condition.as_string) @@ -68,7 +51,7 @@ def get_init_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: init_codes_list = init_stmt.code if isinstance(init_stmt.code, list) else [init_stmt.code] assignments: Dict[str, str] = {} for code in init_codes_list: - visitor = FindAssignment() + visitor = astutils.FindAssignment() visitor.visit(code) if visitor.multiple: return None @@ -94,7 +77,7 @@ def get_update_assignment(loop: LoopRegion) -> Optional[symbolic.SymbolicType]: update_codes_list = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] assignments: Dict[str, str] = {} for code in update_codes_list: - visitor = FindAssignment() + visitor = astutils.FindAssignment() visitor.visit(code) if visitor.multiple: return None diff --git a/dace/transformation/passes/array_elimination.py b/dace/transformation/passes/array_elimination.py index 46411478d5..fd472336e0 100644 --- a/dace/transformation/passes/array_elimination.py +++ b/dace/transformation/passes/array_elimination.py @@ -13,7 +13,7 @@ @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class ArrayElimination(ppl.Pass): """ Merges and removes arrays and their corresponding accesses. This includes redundant array copies, unnecessary views, @@ -48,7 +48,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[S # Traverse SDFG backwards try: - state_order = list(cfg.blockorder_topological_sort(sdfg)) + state_order = list(cfg.blockorder_topological_sort(sdfg, recursive=True, ignore_nonstate_blocks=True)) except KeyError: return None for state in reversed(state_order): @@ -132,14 +132,14 @@ def remove_redundant_views(self, sdfg: SDFG, state: SDFGState, access_nodes: Dic """ removed_nodes: Set[nodes.AccessNode] = set() xforms = [RemoveSliceView()] - state_id = sdfg.node_id(state) + state_id = state.block_id for nodeset in access_nodes.values(): for anode in list(nodeset): for xform in xforms: # Quick path to setup match candidate = {type(xform).view: anode} - xform.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) + xform.setup_match(sdfg, state.parent_graph.cfg_id, state_id, candidate, 0, override=True) # Try to apply if xform.can_be_applied(state, 0, sdfg): @@ -154,7 +154,7 @@ def remove_redundant_copies(self, sdfg: SDFG, state: SDFGState, removable_data: Removes access nodes that represent redundant copies and/or views. """ removed_nodes: Set[nodes.AccessNode] = set() - state_id = sdfg.node_id(state) + state_id = state.block_id # Transformations that remove the first access node xforms_first: List[SingleStateTransformation] = [RedundantWriteSlice(), UnsqueezeViewRemove(), RedundantArray()] @@ -184,7 +184,8 @@ def remove_redundant_copies(self, sdfg: SDFG, state: SDFGState, removable_data: for xform in xforms_first: # Quick path to setup match candidate = {type(xform).in_array: anode, type(xform).out_array: succ} - xform.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) + xform.setup_match(sdfg, state.parent_graph.cfg_id, state_id, candidate, 0, + override=True) # Try to apply if xform.can_be_applied(state, 0, sdfg): @@ -204,7 +205,8 @@ def remove_redundant_copies(self, sdfg: SDFG, state: SDFGState, removable_data: for xform in xforms_second: # Quick path to setup match candidate = {type(xform).in_array: pred, type(xform).out_array: anode} - xform.setup_match(sdfg, sdfg.cfg_id, state_id, candidate, 0, override=True) + xform.setup_match(sdfg, state.parent_graph.cfg_id, state_id, candidate, 0, + override=True) # Try to apply if xform.can_be_applied(state, 0, sdfg): diff --git a/dace/transformation/passes/consolidate_edges.py b/dace/transformation/passes/consolidate_edges.py index 5b1aae2621..94cd29b6ae 100644 --- a/dace/transformation/passes/consolidate_edges.py +++ b/dace/transformation/passes/consolidate_edges.py @@ -5,11 +5,11 @@ from dace import SDFG, properties from typing import Optional -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class ConsolidateEdges(ppl.Pass): """ Removes extraneous edges with memlets that refer to the same data containers within the same scope. diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index bfa0928415..24c35edcc9 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -1,11 +1,12 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import ast from dataclasses import dataclass from dace.frontend.python import astutils -from dace.sdfg.analysis import cfg +from dace.sdfg.analysis import cfg as cfg_analysis from dace.sdfg.sdfg import InterstateEdge -from dace.sdfg import nodes, utils as sdutil +from dace.sdfg import nodes +from dace.sdfg.state import AbstractControlFlowRegion, ConditionalBlock, ControlFlowBlock, ControlFlowRegion, LoopRegion from dace.transformation import pass_pipeline as ppl, transformation from dace.cli.progress import optional_progressbar from dace import data, SDFG, SDFGState, dtypes, symbolic, properties @@ -17,9 +18,13 @@ class _UnknownValue: pass +ConstsT = Dict[str, Any] +BlockConstsT = Dict[ControlFlowBlock, ConstsT] + + @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class ConstantPropagation(ppl.Pass): """ Propagates constants and symbols that were assigned to one value forward through the SDFG, reducing @@ -42,7 +47,7 @@ def should_apply(self, sdfg: SDFG) -> bool: """ Fast check (O(m)) whether the pass should early-exit without traversing the SDFG. """ - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges(): # If there are no assignments, there are no constants to propagate if len(edge.data.assignments) == 0: continue @@ -69,8 +74,27 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = if not initial_symbols and not self.should_apply(sdfg): result = {} else: - # Trace all constants and symbols through states - per_state_constants: Dict[SDFGState, Dict[str, Any]] = self.collect_constants(sdfg, initial_symbols) + arrays: Set[str] = set(sdfg.arrays.keys() | sdfg.constants_prop.keys()) + + # Add nested data to arrays + def _add_nested_datanames(name: str, desc: data.Structure): + for k, v in desc.members.items(): + if isinstance(v, data.Structure): + _add_nested_datanames(f'{name}.{k}', v) + elif isinstance(v, data.ContainerArray): + pass + arrays.add(f'{name}.{k}') + + for name, desc in sdfg.arrays.items(): + if isinstance(desc, data.Structure): + _add_nested_datanames(name, desc) + + # Trace all constants and symbols through blocks + in_constants: BlockConstsT = { sdfg: initial_symbols } + pre_constants: BlockConstsT = {} + post_constants: BlockConstsT = {} + out_constants: BlockConstsT = {} + self._collect_constants_for_region(sdfg, arrays, in_constants, pre_constants, post_constants, out_constants) # Keep track of replaced and ambiguous symbols symbols_replaced: Dict[str, Any] = {} @@ -78,13 +102,14 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = # Collect symbols from symbol-dependent data descriptors # If there can be multiple values over the SDFG, the symbols are not propagated - desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, per_state_constants) + desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_constants) # Replace constants per state - for state, mapping in optional_progressbar(per_state_constants.items(), - 'Propagating constants', - n=len(per_state_constants), - progress=self.progress): + for block, mapping in optional_progressbar(in_constants.items(), 'Propagating constants', + n=len(in_constants), progress=self.progress): + if block is sdfg: + continue + remaining_unknowns.update( {k for k, v in mapping.items() if v is _UnknownValue or k in multivalue_desc_symbols}) @@ -92,17 +117,29 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = k: v for k, v in mapping.items() if v is not _UnknownValue and k not in multivalue_desc_symbols } - if not mapping: - continue + out_mapping = { + k: v + for k, v in out_constants[block].items() + if v is not _UnknownValue and k not in multivalue_desc_symbols + } - # Update replaced symbols for later replacements - symbols_replaced.update(mapping) + if mapping: + # Update replaced symbols for later replacements + symbols_replaced.update(mapping) - # Replace in state contents - state.replace_dict(mapping) - # Replace in outgoing edges as well - for e in sdfg.out_edges(state): - e.data.replace_dict(mapping, replace_keys=False) + if isinstance(block, SDFGState): + # Replace in state contents + block.replace_dict(mapping) + elif isinstance(block, AbstractControlFlowRegion): + block.replace_dict(mapping, replace_in_graph=False, replace_keys=False) + + if out_mapping: + # Replace in outgoing edges as well + for e in block.parent_graph.out_edges(block): + e.data.replace_dict(out_mapping, replace_keys=False) + + if isinstance(block, LoopRegion): + self._propagate_loop(block, post_constants, multivalue_desc_symbols) # Gather initial propagated symbols result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} @@ -114,7 +151,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = replace_keys=False) # Remove constant symbol assignments in interstate edges - for edge in sdfg.edges(): + for edge in sdfg.all_interstate_edges(): intersection = result & edge.data.assignments.keys() for sym in intersection: del edge.data.assignments[sym] @@ -134,7 +171,7 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = sid = sdfg.cfg_id result = set((sid, sym) for sym in result) - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): nested_id = node.sdfg.cfg_id @@ -155,59 +192,160 @@ def apply_pass(self, sdfg: SDFG, _, initial_symbols: Optional[Dict[str, Any]] = def report(self, pass_retval: Set[str]) -> str: return f'Propagated {len(pass_retval)} constants.' - def collect_constants(self, - sdfg: SDFG, - initial_symbols: Optional[Dict[str, Any]] = None) -> Dict[SDFGState, Dict[str, Any]]: + def _propagate_loop(self, loop: LoopRegion, post_constants: BlockConstsT, + multivalue_desc_symbols: Set[str]) -> None: + if loop in post_constants and post_constants[loop] is not None: + if loop.update_statement is not None and (loop.inverted and loop.update_before_condition or + not loop.inverted): + # Replace the RHS of the update experssion + post_mapping = { + k: v + for k, v in post_constants[loop].items() + if v is not _UnknownValue and k not in multivalue_desc_symbols + } + update_stmt = loop.update_statement + updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + for update in updates: + astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update) + loop.update_statement.code = updates + + def _collect_constants_for_conditional(self, conditional: ConditionalBlock, arrays: Set[str], + in_const_dict: BlockConstsT, pre_const_dict: BlockConstsT, + post_const_dict: BlockConstsT, out_const_dict: BlockConstsT) -> None: """ - Finds all constants and constant-assigned symbols in the SDFG for each state. - - :param sdfg: The SDFG to traverse. - :param initial_symbols: If not None, sets values of initial symbols. - :return: A dictionary mapping an SDFG state to a mapping of constants and their corresponding values. + Collect the constants for and inside of a conditional region. + Recursively collects constants inside of nested regions. + + :param conditional: The conditional region to traverse. + :param arrays: A set of data descriptors in the SDFG. + :param in_const_dict: Dictionary mapping each control flow block to the set of constants observed right before + the block is executed. Populated by this function. + :param pre_const_dict: Dictionary mapping each control flow block to the set of constants observed before its + contents are executed. Populated by this function. + :param post_const_dict: Dictionary mapping each control flow block to the set of constants observed after its + contents are executed. Populated by this function. + :param out_const_dict: Dictionary mapping each control flow block to the set of constants observed right after + the block is executed. Populated by this function. """ - arrays: Set[str] = set(sdfg.arrays.keys() | sdfg.constants_prop.keys()) - result: Dict[SDFGState, Dict[str, Any]] = {} - - # Add nested data to arrays - def _add_nested_datanames(name: str, desc: data.Structure): - for k, v in desc.members.items(): - if isinstance(v, data.Structure): - _add_nested_datanames(f'{name}.{k}', v) - elif isinstance(v, data.ContainerArray): - # TODO: How are we handling this? - pass - arrays.add(f'{name}.{k}') - - for name, desc in sdfg.arrays.items(): - if isinstance(desc, data.Structure): - _add_nested_datanames(name, desc) + in_consts = in_const_dict[conditional] + # First, collect all constants for each of the branches. + for _, branch in conditional.branches: + in_const_dict[branch] = in_consts + self._collect_constants_for_region(branch, arrays, in_const_dict, pre_const_dict, post_const_dict, + out_const_dict) + # Second, determine the 'post constants' (constants at the end of the conditional region) as an intersection + # between the output constants of each of the branches. + post_consts = {} + post_consts_intersection = None + has_else = False + for cond, branch in conditional.branches: + if post_consts_intersection is None: + post_consts_intersection = set(out_const_dict[branch].keys()) + else: + post_consts_intersection &= set(out_const_dict[branch].keys()) + if cond is None: + has_else = True + for _, branch in conditional.branches: + for k, v in out_const_dict[branch].items(): + if k in post_consts_intersection: + if k not in post_consts: + post_consts[k] = v + elif post_consts[k] != _UnknownValue and post_consts[k] != v: + post_consts[k] = _UnknownValue + else: + post_consts[k] = _UnknownValue + post_const_dict[conditional] = post_consts + + # Finally, determine the conditional region's output constants. + if has_else: + # If there is an else, at least one branch will certainly be taken, so the output constants are the region's + # post constants. + out_const_dict[conditional] = post_consts + else: + # No else branch is present, so it is possible that no branch is executed. In this case the out constants + # are the intersection between the in constants and the post constants. + out_consts = in_consts.copy() + for k, v in post_consts.items(): + if k not in out_consts: + out_consts[k] = _UnknownValue + elif out_consts[k] != _UnknownValue and out_consts[k] != v: + out_consts[k] = _UnknownValue + out_const_dict[conditional] = out_consts + + def _assignments_in_loop(self, loop: LoopRegion) -> Set[str]: + assignments_within = set() + for e in loop.all_interstate_edges(): + for k in e.data.assignments.keys(): + assignments_within.add(k) + if loop.loop_variable is not None: + assignments_within.add(loop.loop_variable) + return assignments_within + + def _collect_constants_for_region(self, cfg: ControlFlowRegion, arrays: Set[str], in_const_dict: BlockConstsT, + pre_const_dict: BlockConstsT, post_const_dict: BlockConstsT, + out_const_dict: BlockConstsT) -> None: + """ + Finds all constants and constant-assigned symbols in the control flow graph for each block. + Recursively collects constants for nested control flow regions. + + :param cfg: The CFG to traverse. + :param arrays: A set of data descriptors in the SDFG. + :param in_const_dict: Dictionary mapping each control flow block to the set of constants observed right before + the block is executed. Populated by this function. + :param pre_const_dict: Dictionary mapping each control flow block to the set of constants observed before its + contents are executed. Populated by this function. + :param post_const_dict: Dictionary mapping each control flow block to the set of constants observed after its + contents are executed. Populated by this function. + :param out_const_dict: Dictionary mapping each control flow block to the set of constants observed right after + the block is executed. Populated by this function. + """ + # Given the 'in constants', i.e., the constants for before the current region is executed, compute the 'pre + # constants', i.e., the set of constants seen inside the region when executing. + if cfg in in_const_dict: + in_const = in_const_dict[cfg] + if isinstance(cfg, LoopRegion): + # In the case of a loop, the 'pre constants' are equivalent to the 'in constants', with the exception + # of values that may at any point be re-assigned inside the loop, since that assignment would carry over + # into the next iteration (including increments to the loop variable, if present). + assigned_in_loop = self._assignments_in_loop(cfg) + pre_const = { k: (v if k not in assigned_in_loop else _UnknownValue) for k, v in in_const.items() } + else: + # In any other case, the 'pre constants' are equivalent to the 'in constants'. + pre_const = {} + pre_const.update(in_const) + else: + # No 'in constants' for the current region - so initialize to nothing. + pre_const = {} + pre_const_dict[cfg] = pre_const + in_const = {} + pre_const_dict[cfg] = pre_const # Process: - # * Collect constants in topologically ordered states + # * Collect constants in topologically ordered blocks # * Propagate forward symbols forward and edge assignments # * If value is ambiguous (not the same), set value to UNKNOWN # * Repeat until no update is performed - start_state = sdfg.start_state - if initial_symbols: - result[start_state] = {} - result[start_state].update(initial_symbols) + start_block = cfg.start_block + if pre_const: + in_const_dict[start_block] = {} + in_const_dict[start_block].update(pre_const) redo = True while redo: redo = False - # Traverse SDFG topologically - for state in optional_progressbar(cfg.blockorder_topological_sort(sdfg), 'Collecting constants', - sdfg.number_of_nodes(), self.progress): - + # Traverse CFG topologically + for block in optional_progressbar(cfg_analysis.blockorder_topological_sort(cfg, recursive=False), + 'Collecting constants for ' + cfg.label, cfg.number_of_nodes(), + self.progress): # Get predecessors - in_edges = sdfg.in_edges(state) + in_edges = cfg.in_edges(block) assignments = {} for edge in in_edges: # If source was already visited, use its propagated constants constants: Dict[str, Any] = {} - if edge.src in result: - constants.update(result[edge.src]) + if edge.src in out_const_dict: + constants.update(out_const_dict[edge.src]) # Update constants with incoming edge self._propagate(constants, self._data_independent_assignments(edge.data, arrays)) @@ -217,16 +355,15 @@ def _add_nested_datanames(name: str, desc: data.Structure): # If a symbol appearing in the replacing expression of a constant is modified, # the constant is not valid anymore if ((aname in assignments and aval != assignments[aname]) or - symbolic.free_symbols_and_functions(aval) & edge.data.assignments.keys()): + symbolic.free_symbols_and_functions(aval) & edge.data.assignments.keys()): assignments[aname] = _UnknownValue else: assignments[aname] = aval - for edge in sdfg.out_edges(state): + for edge in cfg.out_edges(block): for aname, aval in assignments.items(): - # If the specific replacement would result in the value - # being both used and reassigned on the same inter-state - # edge, remove it from consideration. + # If the specific replacement would result in the value being both used and reassigned on the + # same inter-state edge, remove it from consideration. replacements = symbolic.free_symbols_and_functions(aval) used_in_assignments = { k @@ -236,12 +373,67 @@ def _add_nested_datanames(name: str, desc: data.Structure): if reassignments and (used_in_assignments - reassignments): assignments[aname] = _UnknownValue - if state not in result: # Condition may evaluate to False when state is the start-state - result[state] = {} - redo |= self._propagate(result[state], assignments) - - return result - + if isinstance(block, LoopRegion): + # Any constants before a loop that may be overwritten inside the loop cannot be assumed as constants + # for the loop itself. + assigned_in_loop = self._assignments_in_loop(block) + for k in assignments.keys(): + if k in assigned_in_loop: + assignments[k] = _UnknownValue + + if block not in in_const_dict: + in_const_dict[block] = {} + if assignments: + redo |= self._propagate(in_const_dict[block], assignments) + + if isinstance(block, ControlFlowRegion): + self._collect_constants_for_region(block, arrays, in_const_dict, pre_const_dict, post_const_dict, + out_const_dict) + elif isinstance(block, ConditionalBlock): + self._collect_constants_for_conditional(block, arrays, in_const_dict, pre_const_dict, + post_const_dict, out_const_dict) + else: + # Simple case, no change in constants through this block (states and other basic blocks). + pre_const_dict[block] = in_const_dict[block].copy() + post_const_dict[block] = in_const_dict[block].copy() + out_const_dict[block] = in_const_dict[block].copy() + + # For all sink nodes, compute the overlapping set of constants between them, making sure all constants in the + # resulting intersection are actually constants (i.e., all blocks see the same constant value for them). This + # resulting overlap forms the 'post constants' of this CFG. + post_consts = {} + post_consts_intersection = None + sinks = cfg.sink_nodes() + for sink in sinks: + if post_consts_intersection is None: + post_consts_intersection = set(out_const_dict[sink].keys()) + else: + post_consts_intersection &= set(out_const_dict[sink].keys()) + for sink in sinks: + for k, v in out_const_dict[sink].items(): + if k in post_consts_intersection: + if k not in post_consts: + post_consts[k] = v + elif post_consts[k] != _UnknownValue and post_consts[k] != v: + post_consts[k] = _UnknownValue + else: + post_consts[k] = _UnknownValue + post_const_dict[cfg] = post_consts + + out_consts = {} + if isinstance(cfg, LoopRegion): + # For a loop we can not determine if it is being executed and how many times it would be executed. The 'out + # constants' are thus formed from the intersection of the loop's 'in constants' and 'post constants'. + out_consts.update(in_const) + for k, v in post_consts.items(): + if k not in out_consts: + out_consts[k] = _UnknownValue + elif out_consts[k] != _UnknownValue and out_consts[k] != v: + out_consts[k] = _UnknownValue + else: + out_consts.update(post_consts) + out_const_dict[cfg] = out_consts + def _find_desc_symbols(self, sdfg: SDFG, constants: Dict[SDFGState, Dict[str, Any]]) -> Tuple[Set[str], Set[str]]: """ Finds constant symbols that data descriptors (e.g., arrays) depend on. diff --git a/dace/transformation/passes/dead_dataflow_elimination.py b/dace/transformation/passes/dead_dataflow_elimination.py index 856924abd2..908150d5e2 100644 --- a/dace/transformation/passes/dead_dataflow_elimination.py +++ b/dace/transformation/passes/dead_dataflow_elimination.py @@ -11,6 +11,7 @@ from dace.sdfg import utils as sdutil from dace.sdfg.analysis import cfg from dace.sdfg import infer_types +from dace.sdfg.state import ControlFlowBlock from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap @@ -19,8 +20,8 @@ @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.single_level_sdfg_only -class DeadDataflowElimination(ppl.Pass): +@transformation.explicit_cf_compatible +class DeadDataflowElimination(ppl.ControlFlowRegionPass): """ Removes unused computations from SDFG states. Traverses the graph backwards, removing any computations that result in transient descriptors @@ -41,12 +42,12 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If dataflow or states changed, new dead code may be exposed - return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.States) + return modified & (ppl.Modifies.Nodes | ppl.Modifies.Edges | ppl.Modifies.CFG) def depends_on(self) -> Set[Type[ppl.Pass]]: - return {ap.StateReachability, ap.AccessSets} + return {ap.ControlFlowBlockReachability, ap.AccessSets} - def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[SDFGState, Set[str]]]: + def apply(self, region, pipeline_results): """ Removes unreachable dataflow throughout SDFG states. @@ -57,15 +58,19 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D :return: A dictionary mapping states to removed data descriptor names, or None if nothing changed. """ # Depends on the following analysis passes: - # * State reachability - # * Read/write access sets per state - reachable: Dict[SDFGState, Set[SDFGState]] = pipeline_results['StateReachability'][sdfg.cfg_id] - access_sets: Dict[SDFGState, Tuple[Set[str], Set[str]]] = pipeline_results['AccessSets'][sdfg.cfg_id] + # * Control flow block reachability + # * Read/write access sets per block + sdfg = region if isinstance(region, SDFG) else region.sdfg + reachable: Dict[ControlFlowBlock, Set[ControlFlowBlock]] = pipeline_results[ + ap.ControlFlowBlockReachability.__name__ + ][region.cfg_id] + access_sets: Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]] = pipeline_results[ap.AccessSets.__name__] result: Dict[SDFGState, Set[str]] = defaultdict(set) - # Traverse SDFG backwards + # Traverse region backwards try: - state_order = list(cfg.blockorder_topological_sort(sdfg)) + state_order: List[SDFGState] = list(cfg.blockorder_topological_sort(region, recursive=False, + ignore_nonstate_blocks=True)) except KeyError: return None for state in reversed(state_order): diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index 43239fe9af..cc7e262e4d 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -1,18 +1,19 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import collections import sympy as sp -from typing import Optional, Set, Tuple, Union +from typing import List, Optional, Set, Tuple, Union from dace import SDFG, InterstateEdge, SDFGState, symbolic, properties from dace.properties import CodeBlock from dace.sdfg.graph import Edge -from dace.sdfg.validation import InvalidSDFGInterstateEdgeError +from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.validation import InvalidSDFGInterstateEdgeError, InvalidSDFGNodeError from dace.transformation import pass_pipeline as ppl, transformation @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class DeadStateElimination(ppl.Pass): """ Removes all unreachable states (e.g., due to a branch that will never be taken) from an SDFG. @@ -25,7 +26,7 @@ def modifies(self) -> ppl.Modifies: def should_reapply(self, modified: ppl.Modifies) -> bool: # If connectivity or any edges were changed, some more states might be dead - return modified & (ppl.Modifies.InterstateEdges | ppl.Modifies.States) + return modified & ppl.Modifies.CFG def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[InterstateEdge]]]]: """ @@ -38,42 +39,74 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Union[SDFGState, Edge[Inters :param initial_symbols: If not None, sets values of initial symbols. :return: A set of the removed states, or None if nothing was changed. """ - # Mark dead states and remove them - dead_states, dead_edges, annotated = self.find_dead_states(sdfg, set_unconditional_edges=True) - - for e in dead_edges: - sdfg.remove_edge(e) - sdfg.remove_nodes_from(dead_states) + result: Set[Union[ControlFlowBlock, InterstateEdge]] = set() + removed_regions: Set[ControlFlowRegion] = set() + annotated = None + for cfg in list(sdfg.all_control_flow_regions()): + if cfg in removed_regions or isinstance(cfg, ConditionalBlock): + continue - result = dead_states | dead_edges + # Mark dead blocks and remove them + dead_blocks, dead_edges, annotated = self.find_dead_control_flow(cfg, set_unconditional_edges=True) + for e in dead_edges: + cfg.remove_edge(e) + for block in dead_blocks: + cfg.remove_node(block) + if isinstance(block, ControlFlowRegion): + removed_regions.add(block) + + region_result = dead_blocks | dead_edges + result |= region_result + + for node in cfg.nodes(): + if isinstance(node, ConditionalBlock): + dead_branches = self._find_dead_branches(node) + if len(dead_branches) < len(node.branches): + for _, b in dead_branches: + result.add(b) + node.remove_branch(b) + # If only an 'else' is left over, inline it. + if len(node.branches) == 1 and node.branches[0][0] is None: + branch = node.branches[0][1] + node.parent_graph.add_node(branch) + for ie in cfg.in_edges(node): + cfg.add_edge(ie.src, branch, ie.data) + for oe in cfg.out_edges(node): + cfg.add_edge(branch, oe.dst, oe.data) + result.add(node) + cfg.remove_node(node) + else: + result.add(node) + cfg.remove_node(node) if not annotated: return result or None else: return result or set() # Return an empty set if edges were annotated - def find_dead_states( + def find_dead_control_flow( self, - sdfg: SDFG, - set_unconditional_edges: bool = True) -> Tuple[Set[SDFGState], Set[Edge[InterstateEdge]], bool]: + cfg: ControlFlowRegion, + set_unconditional_edges: bool = True) -> Tuple[Set[ControlFlowBlock], Set[Edge[InterstateEdge]], bool]: """ - Finds "dead" (unreachable) states in an SDFG. A state is deemed unreachable if it is: + Finds "dead" (unreachable) control flow in a CFG. A block is deemed unreachable if it is: - * Unreachable from the starting state + * Unreachable from the starting block * Conditions leading to it will always evaluate to False - * There is another unconditional (always True) inter-state edge that leads to another state + * There is another unconditional (always True) inter-state edge that leads to another block - :param sdfg: The SDFG to traverse. + :param cfg: The CFG to traverse. :param set_unconditional_edges: If True, conditions of edges evaluated as unconditional are removed. - :return: A 3-tuple of (unreachable states, unreachable edges, were edges annotated). + :return: A 3-tuple of (unreachable blocks, unreachable edges, were edges annotated). """ - visited: Set[SDFGState] = set() + sdfg = cfg.sdfg if cfg.sdfg is not None else cfg + visited: Set[ControlFlowBlock] = set() dead_edges: Set[Edge[InterstateEdge]] = set() edges_annotated = False # Run a modified BFS where definitely False edges are not traversed, or if there is an - # unconditional edge the rest are not. The inverse of the visited states is the dead set. - queue = collections.deque([sdfg.start_state]) + # unconditional edge the rest are not. The inverse of the visited blocks is the dead set. + queue = collections.deque([cfg.start_block]) while len(queue) > 0: node = queue.popleft() if node in visited: @@ -82,13 +115,13 @@ def find_dead_states( # First, check for unconditional edges unconditional = None - for e in sdfg.out_edges(node): + for e in cfg.out_edges(node): # If an unconditional edge is found, ignore all other outgoing edges if self.is_definitely_taken(e.data, sdfg): # If more than one unconditional outgoing edge exist, fail with Invalid SDFG if unconditional is not None: - raise InvalidSDFGInterstateEdgeError('Multiple unconditional edges leave the same state', sdfg, - sdfg.edge_id(e)) + raise InvalidSDFGInterstateEdgeError('Multiple unconditional edges leave the same block', cfg, + cfg.edge_id(e)) unconditional = e if set_unconditional_edges and not e.data.is_unconditional(): # Annotate edge as unconditional @@ -101,7 +134,7 @@ def find_dead_states( continue if unconditional is not None: # Unconditional edge exists, skip traversal # Remove other (now never taken) edges from graph - for e in sdfg.out_edges(node): + for e in cfg.out_edges(node): if e is not unconditional: dead_edges.add(e) @@ -109,7 +142,7 @@ def find_dead_states( # End of unconditional check # Check outgoing edges normally - for e in sdfg.out_edges(node): + for e in cfg.out_edges(node): next_node = e.dst # Test for edges that definitely evaluate to False @@ -122,7 +155,32 @@ def find_dead_states( queue.append(next_node) # Dead states are states that are not live (i.e., visited) - return set(sdfg.nodes()) - visited, dead_edges, edges_annotated + return set(cfg.nodes()) - visited, dead_edges, edges_annotated + + def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, ControlFlowRegion]]: + dead_branches = [] + unconditional = None + for i, (cond, branch) in enumerate(block.branches): + if cond is None: + if not i == len(block.branches) - 1: + raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch') + break + # If an unconditional branch is found, ignore all other branches that follow this one. + if cond.as_string.strip() == '1' or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + unconditional = branch + break + if unconditional is not None: + # Remove other (now never taken) branches + for cond, branch in block.branches: + if branch is not unconditional: + dead_branches.append([cond, branch]) + else: + # Check if any branches are certainly never taken. + for cond, branch in block.branches: + if cond is not None and self._is_definitely_false(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + dead_branches.append([cond, branch]) + + return dead_branches def report(self, pass_retval: Set[Union[SDFGState, Edge[InterstateEdge]]]) -> str: if pass_retval is not None and not pass_retval: @@ -137,13 +195,15 @@ def is_definitely_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return True # Evaluate condition - scond = edge.condition_sympy() - if scond == True or scond == sp.Not(sp.logic.boolalg.BooleanFalse(), evaluate=False): + return self._is_definitely_true(edge.condition_sympy(), sdfg) + + def _is_definitely_true(self, cond: sp.Basic, sdfg: SDFG) -> bool: + if cond == True or cond == sp.Not(sp.logic.boolalg.BooleanFalse(), evaluate=False): return True # Evaluate non-optional arrays - scond = symbolic.evaluate_optional_arrays(scond, sdfg) - if scond == True: + cond = symbolic.evaluate_optional_arrays(cond, sdfg) + if cond == True: return True # Indeterminate or False condition @@ -155,13 +215,15 @@ def is_definitely_not_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return False # Evaluate condition - scond = edge.condition_sympy() - if scond == False or scond == sp.Not(sp.logic.boolalg.BooleanTrue(), evaluate=False): + return self._is_definitely_false(edge.condition_sympy(), sdfg) + + def _is_definitely_false(self, cond: sp.Basic, sdfg: SDFG) -> bool: + if cond == False or cond == sp.Not(sp.logic.boolalg.BooleanTrue(), evaluate=False): return True # Evaluate non-optional arrays - scond = symbolic.evaluate_optional_arrays(scond, sdfg) - if scond == False: + cond = symbolic.evaluate_optional_arrays(cond, sdfg) + if cond == False: return True # Indeterminate or True condition diff --git a/dace/transformation/passes/fusion_inline.py b/dace/transformation/passes/fusion_inline.py index 934073240b..a873bf0888 100644 --- a/dace/transformation/passes/fusion_inline.py +++ b/dace/transformation/passes/fusion_inline.py @@ -8,14 +8,15 @@ from dace import SDFG, properties from dace.sdfg import nodes -from dace.sdfg.utils import fuse_states, inline_sdfgs +from dace.sdfg.state import ConditionalBlock, FunctionCallRegion, LoopRegion, NamedRegion +from dace.sdfg.utils import fuse_states, inline_control_flow_regions, inline_sdfgs from dace.transformation import pass_pipeline as ppl -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible @dataclass(unsafe_hash=True) @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class FuseStates(ppl.Pass): """ Fuses all possible states of an SDFG (and all sub-SDFGs). @@ -52,6 +53,7 @@ def report(self, pass_retval: int) -> str: @dataclass(unsafe_hash=True) @properties.make_properties +@explicit_cf_compatible class InlineSDFGs(ppl.Pass): """ Inlines all possible nested SDFGs (and sub-SDFGs). @@ -89,7 +91,84 @@ def report(self, pass_retval: int) -> str: @dataclass(unsafe_hash=True) @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible +class InlineControlFlowRegions(ppl.Pass): + """ + Inlines all control flow regions. + """ + + CATEGORY: str = 'Simplification' + + progress = properties.Property(dtype=bool, + default=None, + allow_none=True, + desc='Whether to print progress, or None for default (print after 5 seconds).') + + no_inline_loops = properties.Property(dtype=bool, default=True, desc='Whether to prevent inlining loops.') + no_inline_conditional = properties.Property(dtype=bool, default=True, + desc='Whether to prevent inlining conditional blocks.') + no_inline_function_call_regions = properties.Property(dtype=bool, default=True, + desc='Whether to prevent inlining function call regions.') + no_inline_named_regions = properties.Property(dtype=bool, default=True, + desc='Whether to prevent inlining named control flow regions.') + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & (ppl.Modifies.NestedSDFGs | ppl.Modifies.States) + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.States | ppl.Modifies.NestedSDFGs + + def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[int]: + """ + Inlines all possible nested SDFGs (and all sub-SDFGs). + + :param sdfg: The SDFG to transform. + + :return: The total number of states fused, or None if did not apply. + """ + ignore_region_types = [] + if self.no_inline_loops: + ignore_region_types.append(LoopRegion) + if self.no_inline_conditional: + ignore_region_types.append(ConditionalBlock) + if self.no_inline_named_regions: + ignore_region_types.append(NamedRegion) + if self.no_inline_function_call_regions: + ignore_region_types.append(FunctionCallRegion) + if len(ignore_region_types) < 1: + ignore_region_types = None + + inlined = 0 + while True: + inlined_in_iteration = inline_control_flow_regions(sdfg, None, ignore_region_types, self.progress) + if inlined_in_iteration < 1: + break + inlined += inlined_in_iteration + + if inlined: + sdfg.reset_cfg_list() + return inlined + return None + + def report(self, pass_retval: int) -> str: + return f'Inlined {pass_retval} regions.' + + def set_opts(self, opts): + opt_keys = [ + 'no_inline_loops', + 'no_inline_conditional', + 'no_inline_function_call_regions', + 'no_inline_named_regions', + ] + + for k in opt_keys: + if k in opts: + setattr(self, k, opts[k]) + + +@dataclass(unsafe_hash=True) +@properties.make_properties +@explicit_cf_compatible class FixNestedSDFGReferences(ppl.Pass): """ Fixes nested SDFG references to parent state/SDFG/node diff --git a/dace/transformation/passes/optional_arrays.py b/dace/transformation/passes/optional_arrays.py index f52ee5af43..6f96f0f53f 100644 --- a/dace/transformation/passes/optional_arrays.py +++ b/dace/transformation/passes/optional_arrays.py @@ -2,14 +2,15 @@ from typing import Dict, Iterator, Optional, Set, Tuple -from dace import SDFG, SDFGState, data, properties +from dace import SDFG, data, properties from dace.sdfg import nodes from dace.sdfg import utils as sdutil +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion, LoopRegion, SDFGState from dace.transformation import pass_pipeline as ppl, transformation @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class OptionalArrayInference(ppl.Pass): """ Infers the ``optional`` property of arrays, i.e., if they can be given None, throughout the SDFG and all nested @@ -63,7 +64,7 @@ def apply_pass(self, arr.optional = parent_arrays[aname] # Change unconditionally-accessed arrays to non-optional - for state in self.traverse_unconditional_states(sdfg): + for state in self.traverse_unconditional_blocks(sdfg, recursive=True): for anode in state.data_nodes(): desc = anode.desc(sdfg) if isinstance(desc, data.Array) and desc.optional is None: @@ -71,7 +72,7 @@ def apply_pass(self, result.add((cfg_id, anode.data)) # Propagate information to nested SDFGs - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): # Create information about parent arrays @@ -96,27 +97,38 @@ def apply_pass(self, return result or None - def traverse_unconditional_states(self, sdfg: SDFG) -> Iterator[SDFGState]: + def traverse_unconditional_blocks(self, cfg: ControlFlowRegion, + recursive: bool = True, + produce_nonstate: bool = False) -> Iterator[ControlFlowBlock]: """ - Traverse SDFG and keep track of whether the state is executed unconditionally. + Traverse CFG and keep track of whether the block is executed unconditionally. """ - ipostdom = sdutil.postdominators(sdfg) - curstate = sdfg.start_state - out_degree = sdfg.out_degree(curstate) + ipostdom = sdutil.postdominators(cfg) + curblock = cfg.start_block + out_degree = cfg.out_degree(curblock) while out_degree > 0: - yield curstate + if produce_nonstate: + yield curblock + elif isinstance(curblock, SDFGState): + yield curblock + + if recursive and isinstance(curblock, ControlFlowRegion) and not isinstance(curblock, LoopRegion): + yield from self.traverse_unconditional_blocks(curblock, recursive, produce_nonstate) if out_degree == 1: # Unconditional, continue to next state - curstate = sdfg.successors(curstate)[0] + curblock = cfg.successors(curblock)[0] elif out_degree > 1: # Conditional branch # Conditional code follows, use immediate post-dominator for next unconditional state - curstate = ipostdom[curstate] + curblock = ipostdom[curblock] # Compute new out degree - if curstate in sdfg.nodes(): - out_degree = sdfg.out_degree(curstate) + if curblock in cfg.nodes(): + out_degree = cfg.out_degree(curblock) else: out_degree = 0 # Yield final state - yield curstate + if produce_nonstate: + yield curblock + elif isinstance(curblock, SDFGState): + yield curblock def report(self, pass_retval: Set[Tuple[int, str]]) -> str: return f'Inferred {len(pass_retval)} optional arrays.' diff --git a/dace/transformation/passes/pattern_matching.py b/dace/transformation/passes/pattern_matching.py index faa011f7d9..7aa16633fd 100644 --- a/dace/transformation/passes/pattern_matching.py +++ b/dace/transformation/passes/pattern_matching.py @@ -98,16 +98,16 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Dict[str, # For every transformation in the list, find first match and apply for xform in self.transformations: - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(xform, '__experimental_cfg_block_compatible__') or - xform.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(xform, '__explicit_cf_compatible__') or + xform.__explicit_cf_compatible__ == False): warnings.warn('Pattern matching is skipping transformation ' + xform.__class__.__name__ + ' due to incompatibility with experimental control flow blocks. If the ' + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + - 'not have `SDFG.using_experimental_blocks` set to True. If ' + + 'not have `SDFG.using_explicit_control_flow` set to True. If ' + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') continue @@ -218,16 +218,16 @@ def _apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any], apply_once: while applied_anything: applied_anything = False for xform in xforms: - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(xform, '__experimental_cfg_block_compatible__') or - xform.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(xform, '__explicit_cf_compatible__') or + xform.__explicit_cf_compatible__ == False): warnings.warn('Pattern matching is skipping transformation ' + xform.__class__.__name__ + ' due to incompatibility with experimental control flow blocks. If the ' + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + - 'not have `SDFG.using_experimental_blocks` set to True. If ' + + 'not have `SDFG.using_explicit_control_flow` set to True. If ' + xform.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') continue @@ -410,16 +410,16 @@ def _try_to_match_transformation(graph: Union[ControlFlowRegion, SDFGState], col for oname, oval in opts.items(): setattr(match, oname, oval) - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(match, '__experimental_cfg_block_compatible__') or - match.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(match, '__explicit_cf_compatible__') or + match.__explicit_cf_compatible__ == False): warnings.warn('Pattern matching is skipping transformation ' + match.__class__.__name__ + ' due to incompatibility with experimental control flow blocks. If the ' + 'SDFG does not contain experimental blocks, ensure the top level SDFG does ' + - 'not have `SDFG.using_experimental_blocks` set to True. If ' + + 'not have `SDFG.using_explicit_control_flow` set to True. If ' + match.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') return None diff --git a/dace/transformation/passes/prune_symbols.py b/dace/transformation/passes/prune_symbols.py index 3b3940f804..a01d903a1d 100644 --- a/dace/transformation/passes/prune_symbols.py +++ b/dace/transformation/passes/prune_symbols.py @@ -6,12 +6,13 @@ from dace import SDFG, dtypes, properties, symbolic from dace.sdfg import nodes +from dace.sdfg.state import SDFGState from dace.transformation import pass_pipeline as ppl, transformation @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class RemoveUnusedSymbols(ppl.Pass): """ Prunes unused symbols from the SDFG symbol repository (``sdfg.symbols``) and interstate edges. @@ -64,7 +65,7 @@ def apply_pass(self, sdfg: SDFG, _) -> Optional[Set[Tuple[int, str]]]: sid = sdfg.cfg_id result = set((sid, sym) for sym in result) - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.NestedSDFG): old_symbols = self.symbols @@ -90,27 +91,28 @@ def used_symbols(self, sdfg: SDFG) -> Set[str]: for desc in sdfg.arrays.values(): result |= set(map(str, desc.free_symbols)) - for state in sdfg.nodes(): - result |= state.free_symbols + for block in sdfg.all_control_flow_blocks(): + result |= block.free_symbols # In addition to the standard free symbols, we are conservative with other tasklet languages by # tokenizing their code. Since this is intersected with `sdfg.symbols`, keywords such as "if" are # ok to include - for node in state.nodes(): - if isinstance(node, nodes.Tasklet): - if node.code.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_global.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_init.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - if node.code_exit.language != dtypes.Language.Python: - result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), - node.ignored_symbols) - - for e in sdfg.edges(): + if isinstance(block, SDFGState): + for node in block.nodes(): + if isinstance(node, nodes.Tasklet): + if node.code.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_global.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_global.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_init.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_init.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + if node.code_exit.language != dtypes.Language.Python: + result |= symbolic.symbols_in_code(node.code_exit.as_string, sdfg.symbols.keys(), + node.ignored_symbols) + + for e in sdfg.all_interstate_edges(): result |= e.data.free_symbols return result diff --git a/dace/transformation/passes/reference_reduction.py b/dace/transformation/passes/reference_reduction.py index dc5ae1eb7d..6418a6025b 100644 --- a/dace/transformation/passes/reference_reduction.py +++ b/dace/transformation/passes/reference_reduction.py @@ -5,13 +5,13 @@ from dace import SDFG, SDFGState, data, properties, Memlet from dace.sdfg import nodes -from dace.sdfg.analysis import cfg from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation.helpers import all_isedges_between from dace.transformation.passes import analysis as ap @properties.make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class ReferenceToView(ppl.Pass): """ Replaces Reference data descriptors that are only set to one source with views. @@ -135,13 +135,10 @@ def find_candidates( # Filter self and unreachable states if other_state is state or other_state not in reachable_states[state]: continue - for path in sdfg.all_simple_paths(state, other_state, as_edges=True): - for e in path: - # The symbol was modified/reassigned in one of the paths, skip - if fsyms & e.data.assignments.keys(): - result.remove(cand) - break - if cand not in result: + for e in all_isedges_between(state, other_state): + # The symbol was modified/reassigned in one of the paths, skip + if fsyms & e.data.assignments.keys(): + result.remove(cand) break if cand not in result: break diff --git a/dace/transformation/passes/scalar_fission.py b/dace/transformation/passes/scalar_fission.py index f691a861d7..8d88f2752b 100644 --- a/dace/transformation/passes/scalar_fission.py +++ b/dace/transformation/passes/scalar_fission.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from typing import Any, Dict, Optional, Set @@ -8,7 +8,7 @@ from dace.transformation.passes import analysis as ap -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class ScalarFission(ppl.Pass): """ Fission transient scalars or arrays of size 1 that are dominated by a write into separate data containers. diff --git a/dace/transformation/passes/scalar_to_symbol.py b/dace/transformation/passes/scalar_to_symbol.py index a37729ca7c..43cd45146d 100644 --- a/dace/transformation/passes/scalar_to_symbol.py +++ b/dace/transformation/passes/scalar_to_symbol.py @@ -21,9 +21,10 @@ from dace.sdfg import utils as sdutils from dace.sdfg.replace import replace_properties_dict from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import ConditionalBlock, LoopRegion from dace.transformation import helpers as xfh from dace.transformation import pass_pipeline as passes -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible class AttributedCallDetector(ast.NodeVisitor): @@ -233,6 +234,8 @@ def find_promotable_scalars(sdfg: sd.SDFG, transients_only: bool = True, integer interstate_symbols = set() for edge in sdfg.all_interstate_edges(): interstate_symbols |= edge.data.free_symbols + for reg in sdfg.all_control_flow_regions(): + interstate_symbols |= reg.used_symbols(all_symbols=True, with_contents=False) for candidate in (candidates - interstate_symbols): if integers_only and sdfg.arrays[candidate].dtype not in dtypes.INTEGER_TYPES: candidates.remove(candidate) @@ -598,7 +601,7 @@ def translate_cpp_tasklet_to_python(code: str): @dataclass(unsafe_hash=True) @props.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class ScalarToSymbolPromotion(passes.Pass): CATEGORY: str = 'Simplification' @@ -735,6 +738,11 @@ def apply_pass(self, sdfg: SDFG, _: Dict[Any, Any]) -> Set[str]: # should work for all Python versions. assignment = cleanup_re[scalar].sub(scalar, assignment.strip()) ise.assignments[aname] = assignment + for reg in sdfg.all_control_flow_regions(): + meta_codes = reg.get_meta_codeblocks() + for cd in meta_codes: + for stmt in cd.code: + promo.visit(stmt) # Step 7: Indirection remove_symbol_indirection(sdfg) diff --git a/dace/transformation/passes/simplification/control_flow_raising.py b/dace/transformation/passes/simplification/control_flow_raising.py index abe305f12c..81b9c6b0eb 100644 --- a/dace/transformation/passes/simplification/control_flow_raising.py +++ b/dace/transformation/passes/simplification/control_flow_raising.py @@ -1,18 +1,24 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -from typing import Optional, Tuple +import ast +from typing import List, Optional, Tuple + import networkx as nx +import sympy + from dace import properties +from dace.frontend.python import astutils from dace.sdfg.analysis import cfg as cfg_analysis from dace.sdfg.sdfg import SDFG, InterstateEdge -from dace.sdfg.state import ConditionalBlock, ControlFlowBlock, ControlFlowRegion +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, ReturnBlock from dace.sdfg.utils import dfs_conditional -from dace.transformation import pass_pipeline as ppl, transformation +from dace.transformation import pass_pipeline as ppl +from dace.transformation import transformation from dace.transformation.interstate.loop_lifting import LoopLifting @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class ControlFlowRaising(ppl.Pass): """ Raises all detectable control flow that can be expressed with native SDFG structures, such as loops and branching. @@ -20,21 +26,102 @@ class ControlFlowRaising(ppl.Pass): CATEGORY: str = 'Simplification' + raise_sink_node_returns = properties.Property( + dtype=bool, + default=False, + desc='Whether or not to lift sink nodes in an SDFG context to explicit return blocks.') + def modifies(self) -> ppl.Modifies: return ppl.Modifies.CFG def should_reapply(self, modified: ppl.Modifies) -> bool: return modified & ppl.Modifies.CFG + def _lift_returns(self, sdfg: SDFG) -> int: + """ + Make any implicit early program exits explicit by inserting return blocks. + An implicit early program exit is a control flow block with not at least one unconditional edge leading out of + it, or where there is no 'catchall' condition that negates all other conditions. For any such transition, if + the condition(s) is / are not met, the SDFG halts. + This method detects such situations and inserts an explicit transition to a return block for each such missing + unconditional edge or 'catchall' condition. Note that this is only performed on the top-level control flow + region, i.e., the SDFG itself. Any implicit early stops inside nested regions only end the context of that + region, and not the entire SDFG. + + :param sdfg: The SDFG in which to lift returns + :returns: The number of return blocks lifted + """ + returns_lifted = 0 + for nd in sdfg.nodes(): + # Existing returns can be skipped. + if isinstance(nd, ReturnBlock): + continue + + # First check if there is an unconditional outgoing edge. + has_unconditional = False + full_cond_expression: Optional[List[ast.AST]] = None + oedges = sdfg.out_edges(nd) + for oe in oedges: + if oe.data.is_unconditional(): + has_unconditional = True + break + else: + if full_cond_expression is None: + full_cond_expression = oe.data.condition.code[0] + else: + full_cond_expression = astutils.and_expr(full_cond_expression, oe.data.condition.code[0]) + # If there is no unconditional outgoing edge, there may be a catchall that is the negation of all other + # conditions. + # NOTE: Checking that for the general case is expensive. For now, we check it for the case of two outgoing + # edges, where the two edges are a negation of one another, which is cheap. In any other case, an + # explicit return is added with the negation of everything. This is conservative and always correct, + # but may insert a stray (and unreachable) return in rare cases. That case should hardly ever occur + # and does not lead to any negative side effects. + if has_unconditional: + insert_return = False + else: + if len(oedges) == 2 and oedges[0].data.condition_sympy() == sympy.Not(oedges[1].data.condition_sympy()): + insert_return = False + else: + insert_return = True + + if insert_return: + if full_cond_expression is None: + # If there is no condition, there are no outgoing edges - so this is already an explicit program + # exit by being a sink node. + if self.raise_sink_node_returns: + ret_block = ReturnBlock(sdfg.name + '_return') + sdfg.add_node(ret_block, ensure_unique_name=True) + sdfg.add_edge(nd, ret_block, InterstateEdge()) + returns_lifted += 1 + else: + ret_block = ReturnBlock(nd.label + '_return') + sdfg.add_node(ret_block, ensure_unique_name=True) + catchall_condition_expression = astutils.negate_expr(full_cond_expression) + ret_edge = InterstateEdge(condition=properties.CodeBlock([catchall_condition_expression])) + sdfg.add_edge(nd, ret_block, ret_edge) + returns_lifted += 1 + + return returns_lifted + def _lift_conditionals(self, sdfg: SDFG) -> int: cfgs = list(sdfg.all_control_flow_regions()) n_cond_regions_pre = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) for region in cfgs: - sinks = region.sink_nodes() - dummy_exit = region.add_state('__DACE_DUMMY') - for s in sinks: - region.add_edge(s, dummy_exit, InterstateEdge()) + if isinstance(region, ConditionalBlock): + continue + + # If there are multiple sinks, create a dummy exit node for finding branch merges. If there is at least one + # non-return block sink, do not count return blocks as sink nodes. Doing so could cause branches to inter- + # connect unnecessarily, thus preventing lifting. + non_return_sinks = [s for s in region.sink_nodes() if not isinstance(s, ReturnBlock)] + sinks = non_return_sinks if len(non_return_sinks) > 0 else region.sink_nodes() + dummy_exit = None + if len(sinks) > 1: + dummy_exit = region.add_state('__DACE_DUMMY') + for s in sinks: + region.add_edge(s, dummy_exit, InterstateEdge()) idom = nx.immediate_dominators(region.nx, region.start_block) alldoms = cfg_analysis.all_dominators(region, idom) branch_merges = cfg_analysis.branch_merges(region, idom, alldoms) @@ -58,12 +145,13 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: conditional.add_branch(oe.data.condition, branch) if oe.dst is merge_block: # Empty branch. + branch.add_state('noop') + graph.remove_edge(oe) continue branch_nodes = set(dfs_conditional(graph, [oe.dst], lambda _, x: x is not merge_block)) branch_start = branch.add_state(branch_name + '_start', is_start_block=True) branch.add_nodes_from(branch_nodes) - branch_end = branch.add_state(branch_name + '_end') branch.add_edge(branch_start, oe.dst, InterstateEdge(assignments=oe.data.assignments)) added = set() for e in graph.all_edges(*branch_nodes): @@ -72,25 +160,39 @@ def _lift_conditionals(self, sdfg: SDFG) -> int: if e is oe: continue elif e.dst is merge_block: - branch.add_edge(e.src, branch_end, e.data) + if e.data.assignments or not e.data.is_unconditional(): + branch.add_edge(e.src, branch.add_state(branch_name + '_end'), e.data) else: branch.add_edge(e.src, e.dst, e.data) graph.remove_nodes_from(branch_nodes) # Connect to the end of the branch / what happens after. - if merge_block is not dummy_exit: + if dummy_exit is None or merge_block is not dummy_exit: graph.add_edge(conditional, merge_block, InterstateEdge()) - region.remove_node(dummy_exit) + if dummy_exit is not None: + region.remove_node(dummy_exit) n_cond_regions_post = len([x for x in sdfg.all_control_flow_blocks() if isinstance(x, ConditionalBlock)]) - return n_cond_regions_post - n_cond_regions_pre + lifted = n_cond_regions_post - n_cond_regions_pre + if lifted: + sdfg.root_sdfg.using_explicit_control_flow = True + return lifted - def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int]]: + def apply_pass(self, top_sdfg: SDFG, _) -> Optional[Tuple[int, int, int]]: + lifted_returns = 0 lifted_loops = 0 lifted_branches = 0 for sdfg in top_sdfg.all_sdfgs_recursive(): + lifted_returns += self._lift_returns(sdfg) lifted_loops += sdfg.apply_transformations_repeated([LoopLifting], validate_all=False, validate=False) lifted_branches += self._lift_conditionals(sdfg) if lifted_branches == 0 and lifted_loops == 0: return None - return lifted_loops, lifted_branches + top_sdfg.reset_cfg_list() + return lifted_returns, lifted_loops, lifted_branches + + def report(self, pass_retval: Optional[Tuple[int, int, int]]): + if pass_retval and any([x > 0 for x in pass_retval]): + return f'Lifted {pass_retval[0]} returns, {pass_retval[1]} loops, and {pass_retval[2]} conditional blocks' + else: + return 'No control flow lifted' diff --git a/dace/transformation/passes/simplification/prune_empty_conditional_branches.py b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py new file mode 100644 index 0000000000..d7bd397830 --- /dev/null +++ b/dace/transformation/passes/simplification/prune_empty_conditional_branches.py @@ -0,0 +1,72 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + +from typing import Optional +from dace import properties +from dace.frontend.python import astutils +from dace.sdfg.sdfg import InterstateEdge +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, SDFGState +from dace.transformation import pass_pipeline as ppl, transformation + + +@properties.make_properties +@transformation.explicit_cf_compatible +class PruneEmptyConditionalBranches(ppl.ControlFlowRegionPass): + """ + Prunes empty (or no-op) conditional branches from conditional blocks. + """ + + CATEGORY: str = 'Simplification' + + def __init__(self): + super().__init__() + self.apply_to_conditionals = True + + def modifies(self) -> ppl.Modifies: + return ppl.Modifies.CFG + + def should_reapply(self, modified: ppl.Modifies) -> bool: + return modified & ppl.Modifies.CFG + + def apply(self, region: ControlFlowRegion, _) -> Optional[int]: + if not isinstance(region, ConditionalBlock): + return None + removed_branches = 0 + all_branches = region.branches + has_else = all_branches[-1][0] is None + new_else_cond = None + for cond, branch in all_branches: + branch_nodes = branch.nodes() + if (len(branch_nodes) == 0 or (len(branch_nodes) == 1 and isinstance(branch_nodes[0], SDFGState) and + len(branch_nodes[0].nodes()) == 0)): + # Found a branch we can eliminate. + if has_else and branch is not all_branches[-1][1]: + # If this conditional has an else branch and that is not the branch being eliminated, we need to + # change that else branch to a conditional else-if branch that negates the current branch's + # condition. + negated_condition = astutils.negate_expr(cond.code[0]) + if new_else_cond is None: + new_else_cond = properties.CodeBlock([negated_condition]) + else: + combined_cond = astutils.and_expr(negated_condition, new_else_cond.code[0]) + new_else_cond = properties.CodeBlock([combined_cond]) + region.remove_branch(branch) + else: + # Simple case, eliminate the branch. + region.remove_branch(branch) + removed_branches += 1 + # If the else branch remains, make sure it now has the new negate-all condition. + if new_else_cond is not None and region.branches[-1][0] is None: + region._branches[-1] = (new_else_cond, region._branches[-1][1]) + + if len(region.branches) == 0: + # The conditional has become entirely empty, remove it. + replacement_node_before = region.parent_graph.add_state_before(region) + replacement_node_after = region.parent_graph.add_state_before(region) + region.parent_graph.add_edge(replacement_node_before, replacement_node_after, InterstateEdge()) + region.parent_graph.remove_node(region) + + if removed_branches > 0: + region.reset_cfg_list() + return removed_branches + return None + diff --git a/dace/transformation/passes/simplify.py b/dace/transformation/passes/simplify.py index 81e8e88362..bfd22ebaf3 100644 --- a/dace/transformation/passes/simplify.py +++ b/dace/transformation/passes/simplify.py @@ -1,6 +1,6 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from dataclasses import dataclass -from typing import Any, Dict, Optional, Set +from typing import Any, Dict, List, Optional, Set import warnings from dace import SDFG, config, properties @@ -11,20 +11,25 @@ from dace.transformation.passes.constant_propagation import ConstantPropagation from dace.transformation.passes.dead_dataflow_elimination import DeadDataflowElimination from dace.transformation.passes.dead_state_elimination import DeadStateElimination -from dace.transformation.passes.fusion_inline import FuseStates, InlineSDFGs +from dace.transformation.passes.fusion_inline import FuseStates, InlineControlFlowRegions, InlineSDFGs from dace.transformation.passes.optional_arrays import OptionalArrayInference from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion from dace.transformation.passes.prune_symbols import RemoveUnusedSymbols from dace.transformation.passes.reference_reduction import ReferenceToView +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches SIMPLIFY_PASSES = [ InlineSDFGs, + InlineControlFlowRegions, ScalarToSymbolPromotion, + ControlFlowRaising, FuseStates, OptionalArrayInference, ConstantPropagation, DeadDataflowElimination, DeadStateElimination, + PruneEmptyConditionalBranches, RemoveUnusedSymbols, ReferenceToView, ArrayElimination, @@ -43,7 +48,7 @@ @dataclass(unsafe_hash=True) @properties.make_properties -@transformation.experimental_cfg_block_compatible +@transformation.explicit_cf_compatible class SimplifyPass(ppl.FixedPointPipeline): """ A pipeline that simplifies an SDFG by applying a series of simplification passes. @@ -58,15 +63,23 @@ class SimplifyPass(ppl.FixedPointPipeline): skip = properties.SetProperty(element_type=str, default=set(), desc='Set of pass names to skip.') verbose = properties.Property(dtype=bool, default=False, desc='Whether to print reports after every pass.') + no_inline_function_call_regions = properties.Property(dtype=bool, default=False, + desc='Whether to prevent inlining function call regions.') + no_inline_named_regions = properties.Property(dtype=bool, default=False, + desc='Whether to prevent inlining named control flow regions.') + def __init__(self, validate: bool = False, validate_all: bool = False, skip: Optional[Set[str]] = None, - verbose: bool = False): + verbose: bool = False, + no_inline_function_call_regions: bool = False, + no_inline_named_regions: bool = False, + pass_options: Optional[Dict[str, Any]] = None): if skip: - passes = [p() for p in SIMPLIFY_PASSES if p.__name__ not in skip] + passes: List[ppl.Pass] = [p() for p in SIMPLIFY_PASSES if p.__name__ not in skip] else: - passes = [p() for p in SIMPLIFY_PASSES] + passes: List[ppl.Pass] = [p() for p in SIMPLIFY_PASSES] super().__init__(passes=passes) self.validate = validate @@ -77,19 +90,31 @@ def __init__(self, else: self.verbose = verbose + self.no_inline_function_call_regions = no_inline_function_call_regions + self.no_inline_named_regions = no_inline_named_regions + + pass_opts = { + 'no_inline_function_call_regions': self.no_inline_function_call_regions, + 'no_inline_named_regions': self.no_inline_named_regions, + } + if pass_options: + pass_opts.update(pass_options) + for p in passes: + p.set_opts(pass_opts) + def apply_subpass(self, sdfg: SDFG, p: ppl.Pass, state: Dict[str, Any]): """ Apply a pass from the pipeline. This method is meant to be overridden by subclasses. """ - if sdfg.root_sdfg.using_experimental_blocks: - if (not hasattr(p, '__experimental_cfg_block_compatible__') or - p.__experimental_cfg_block_compatible__ == False): + if sdfg.root_sdfg.using_explicit_control_flow: + if (not hasattr(p, '__explicit_cf_compatible__') or + p.__explicit_cf_compatible__ == False): warnings.warn(p.__class__.__name__ + ' is not being applied due to incompatibility with ' + 'experimental control flow blocks. If the SDFG does not contain experimental blocks, ' + - 'ensure the top level SDFG does not have `SDFG.using_experimental_blocks` set to ' + + 'ensure the top level SDFG does not have `SDFG.using_explicit_control_flow` set to ' + 'True. If ' + p.__class__.__name__ + ' is compatible with experimental blocks, ' + 'please annotate it with the class decorator ' + - '`@dace.transformation.experimental_cfg_block_compatible`. see ' + + '`@dace.transformation.explicit_cf_compatible`. see ' + '`https://github.com/spcl/dace/wiki/Experimental-Control-Flow-Blocks` ' + 'for more information.') return None @@ -103,6 +128,8 @@ def apply_subpass(self, sdfg: SDFG, p: ppl.Pass, state: Dict[str, Any]): ret = ret or None else: ret = p.apply_pass(sdfg, state) + if ret is not None: + sdfg.reset_cfg_list() if self.verbose: if ret is not None: diff --git a/dace/transformation/passes/symbol_ssa.py b/dace/transformation/passes/symbol_ssa.py index fa59f88df7..da0d1cdbb1 100644 --- a/dace/transformation/passes/symbol_ssa.py +++ b/dace/transformation/passes/symbol_ssa.py @@ -1,14 +1,15 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from collections import defaultdict from typing import Any, Dict, Optional, Set -from dace import SDFG, SDFGState +from dace import SDFG +from dace.sdfg.state import ControlFlowBlock from dace.transformation import pass_pipeline as ppl, transformation from dace.transformation.passes import analysis as ap -@transformation.single_level_sdfg_only -class StrictSymbolSSA(ppl.Pass): +@transformation.explicit_cf_compatible +class StrictSymbolSSA(ppl.ControlFlowRegionPass): """ Perform an SSA transformation on all symbols in the SDFG in a strict manner, i.e., without introducing phi nodes. """ @@ -24,19 +25,20 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: def depends_on(self): return {ap.SymbolWriteScopes} - def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[Dict[str, Set[str]]]: + def apply(self, region, pipeline_results) -> Optional[Dict[str, Set[str]]]: """ Rename symbols in a restricted SSA manner. - :param sdfg: The SDFG to modify. + :param region: The control flow region to modify. :param pipeline_results: If in the context of a ``Pipeline``, a dictionary that is populated with prior Pass results as ``{Pass subclass name: returned object from pass}``. If not run in a pipeline, an empty dictionary is expected. :return: A dictionary mapping the original name to a set of all new names created for each symbol. """ results: Dict[str, Set[str]] = defaultdict(lambda: set()) + sdfg = region if isinstance(region, SDFG) else region.sdfg - symbol_scope_dict: ap.SymbolScopeDict = pipeline_results[ap.SymbolWriteScopes.__name__][sdfg.cfg_id] + symbol_scope_dict: ap.SymbolScopeDict = pipeline_results[ap.SymbolWriteScopes.__name__][region.cfg_id] for name, scope_dict in symbol_scope_dict.items(): # If there is only one scope, don't do anything. @@ -58,7 +60,7 @@ def apply_pass(self, sdfg: SDFG, pipeline_results: Dict[str, Any]) -> Optional[D # Replace all dominated reads. for read in shadowed_reads: - if isinstance(read, SDFGState): + if isinstance(read, ControlFlowBlock): read.replace(name, newname) else: if read not in scope_dict: diff --git a/dace/transformation/passes/transient_reuse.py b/dace/transformation/passes/transient_reuse.py index 805ddadff4..99e41d724f 100644 --- a/dace/transformation/passes/transient_reuse.py +++ b/dace/transformation/passes/transient_reuse.py @@ -6,11 +6,11 @@ from dace import SDFG, properties from dace.sdfg import nodes from dace.transformation import pass_pipeline as ppl -from dace.transformation.transformation import experimental_cfg_block_compatible +from dace.transformation.transformation import explicit_cf_compatible @properties.make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class TransientReuse(ppl.Pass): """ Reduces memory consumption by reusing allocated transient array memory. Only modifies arrays that can safely be diff --git a/dace/transformation/subgraph/composite.py b/dace/transformation/subgraph/composite.py index e25ccd192a..fb09fcb51e 100644 --- a/dace/transformation/subgraph/composite.py +++ b/dace/transformation/subgraph/composite.py @@ -3,6 +3,7 @@ Subgraph Fusion - Stencil Tiling Transformation """ +from dace.sdfg.state import SDFGState, StateSubgraphView from dace.transformation.subgraph import SubgraphFusion, MultiExpansion from dace.transformation.subgraph.stencil_tiling import StencilTiling from dace.transformation.subgraph import helpers @@ -18,7 +19,7 @@ @make_properties -@transformation.single_level_sdfg_only +@transformation.explicit_cf_compatible class CompositeFusion(transformation.SubgraphTransformation): """ MultiExpansion + SubgraphFusion in one Transformation Additional StencilTiling is also possible as a canonicalizing @@ -46,8 +47,8 @@ class CompositeFusion(transformation.SubgraphTransformation): expansion_split = Property(desc="Allow MultiExpansion to split up maps, if enabled", dtype=bool, default=True) - def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: - graph = subgraph.graph + def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: + graph: SDFGState = subgraph.graph if self.allow_expansion == True: subgraph_fusion = SubgraphFusion() subgraph_fusion.setup_match(subgraph) @@ -63,9 +64,16 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: graph_indices = [i for (i, n) in enumerate(graph.nodes()) if n in subgraph] sdfg_copy = copy.deepcopy(sdfg) sdfg_copy.reset_cfg_list() - graph_copy = sdfg_copy.nodes()[sdfg.nodes().index(graph)] + par_graph_copy = None + for cfr in sdfg_copy.all_control_flow_regions(): + if cfr.guid == graph.parent_graph.guid: + par_graph_copy = cfr + break + if not par_graph_copy: + return False + graph_copy = par_graph_copy.node(graph.block_id) subgraph_copy = SubgraphView(graph_copy, [graph_copy.nodes()[i] for i in graph_indices]) - expansion.cfg_id = sdfg_copy.cfg_id + expansion.cfg_id = par_graph_copy.cfg_id ##sdfg_copy.apply_transformations(MultiExpansion, states=[graph]) #expansion = MultiExpansion() @@ -99,9 +107,6 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: def apply(self, sdfg): subgraph = self.subgraph_view(sdfg) graph = subgraph.graph - scope_dict = graph.scope_dict() - map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph, scope_dict) - first_entry = next(iter(map_entries)) if self.allow_expansion: expansion = MultiExpansion() diff --git a/dace/transformation/subgraph/expansion.py b/dace/transformation/subgraph/expansion.py index aa182e8c80..b013627d2e 100644 --- a/dace/transformation/subgraph/expansion.py +++ b/dace/transformation/subgraph/expansion.py @@ -6,6 +6,7 @@ from dace.sdfg import nodes from dace.sdfg import replace, SDFG, dynamic_map_inputs from dace.sdfg.graph import SubgraphView +from dace.sdfg.state import SDFGState, StateSubgraphView from dace.transformation import transformation from dace.properties import make_properties, Property from dace.transformation.subgraph import helpers @@ -58,12 +59,12 @@ class MultiExpansion(transformation.SubgraphTransformation): allow_offset = Property(dtype=bool, desc="Offset ranges to zero", default=True) - def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: + def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: # get lowest scope maps of subgraph # grab first node and see whether all nodes are in the same graph # (or nested sdfgs therein) - graph = subgraph.graph + graph: SDFGState = subgraph.graph # next, get all the maps by obtaining a copy (for potential offsets) map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph) diff --git a/dace/transformation/subgraph/gpu_persistent_fusion.py b/dace/transformation/subgraph/gpu_persistent_fusion.py index ff4812d0af..a9a75d6bc7 100644 --- a/dace/transformation/subgraph/gpu_persistent_fusion.py +++ b/dace/transformation/subgraph/gpu_persistent_fusion.py @@ -1,11 +1,11 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import copy import dace -from dace import dtypes, nodes, registry, Memlet +from dace import nodes, Memlet from dace.sdfg import SDFG, SDFGState, InterstateEdge from dace.dtypes import StorageType, ScheduleType from dace.properties import Property, make_properties -from dace.sdfg.utils import concurrent_subgraphs +from dace.sdfg.state import AbstractControlFlowRegion, LoopRegion from dace.sdfg.graph import SubgraphView from dace.transformation.transformation import SubgraphTransformation @@ -68,12 +68,18 @@ class GPUPersistentKernel(SubgraphTransformation): @staticmethod def can_be_applied(sdfg: SDFG, subgraph: SubgraphView): - if not set(subgraph.nodes()).issubset(set(sdfg.nodes())): return False + subgraph_blocks = set() + for nd in subgraph.nodes(): + subgraph_blocks.add(nd) + if isinstance(nd, AbstractControlFlowRegion): + subgraph_blocks.update(nd.all_control_flow_blocks()) + subgraph_states = set([blk for blk in subgraph_blocks if isinstance(blk, SDFGState)]) + # All states need to be GPU states - for state in subgraph: + for state in subgraph_states: if not GPUPersistentKernel.is_gpu_state(sdfg, state): return False @@ -114,6 +120,12 @@ def can_be_applied(sdfg: SDFG, subgraph: SubgraphView): def apply(self, sdfg: SDFG): subgraph = self.subgraph_view(sdfg) + subgraph_blocks = set() + for nd in subgraph.nodes(): + subgraph_blocks.add(nd) + if isinstance(nd, AbstractControlFlowRegion): + subgraph_blocks.update(nd.all_control_flow_blocks()) + entry_states_in, entry_states_out = self.get_entry_states(sdfg, subgraph) _, exit_states_out = self.get_exit_states(sdfg, subgraph) @@ -181,13 +193,22 @@ def apply(self, sdfg: SDFG): new_symbols.add(k) if k in sdfg.symbols and k not in kernel_sdfg.symbols: kernel_sdfg.add_symbol(k, sdfg.symbols[k]) + for blk in subgraph_blocks: + if isinstance(blk, AbstractControlFlowRegion): + for k, v in blk.new_symbols(sdfg.symbols).items(): + new_symbols.add(k) + if k not in kernel_sdfg.symbols: + kernel_sdfg.add_symbol(k, v) # Setting entry node in nested SDFG if no entry guard was created if entry_guard_state is None: kernel_sdfg.start_state = kernel_sdfg.node_id(entry_state_in) - for state in subgraph: - state.parent = kernel_sdfg + for nd in subgraph: + nd.sdfg = kernel_sdfg + if isinstance(nd, AbstractControlFlowRegion): + for n in nd.all_control_flow_blocks(): + n.sdfg = kernel_sdfg # remove the now nested nodes from the outer sdfg and make sure the # launch state is properly connected to remaining states @@ -203,7 +224,7 @@ def apply(self, sdfg: SDFG): sdfg.add_edge(launch_state, exit_state_out, InterstateEdge()) # Handle data for kernel - kernel_data = set(node.data for state in kernel_sdfg for node in state.nodes() + kernel_data = set(node.data for state in kernel_sdfg.states() for node in state.nodes() if isinstance(node, nodes.AccessNode)) other_data = set(node.data for state in other_states for node in state.nodes() if isinstance(node, nodes.AccessNode)) @@ -230,7 +251,7 @@ def apply(self, sdfg: SDFG): kernel_args_write = set() for data in kernel_args: data_accesses_read_only = [ - state.in_degree(node) == 0 for state in kernel_sdfg for node in state + state.in_degree(node) == 0 for state in kernel_sdfg.states() for node in state if isinstance(node, nodes.AccessNode) and node.data == data ] if all(data_accesses_read_only): diff --git a/dace/transformation/subgraph/stencil_tiling.py b/dace/transformation/subgraph/stencil_tiling.py index 1ba86252c4..29989292be 100644 --- a/dace/transformation/subgraph/stencil_tiling.py +++ b/dace/transformation/subgraph/stencil_tiling.py @@ -2,12 +2,11 @@ """ This module contains classes and functions that implement the orthogonal stencil tiling transformation. """ -import math - import dace -from dace import dtypes, registry, symbolic +from dace import dtypes, symbolic from dace.properties import make_properties, Property, ShapeProperty from dace.sdfg import nodes +from dace.sdfg.state import SDFGState from dace.transformation import transformation from dace.sdfg.propagation import _propagate_node @@ -15,7 +14,6 @@ from dace.transformation.dataflow.map_expansion import MapExpansion from dace.transformation.dataflow.map_collapse import MapCollapse from dace.transformation.dataflow.strip_mining import StripMining -from dace.transformation.interstate.loop_unroll import LoopUnroll from dace.transformation.interstate.loop_detection import DetectLoop from dace.transformation.subgraph import SubgraphFusion @@ -305,8 +303,8 @@ def can_be_applied(sdfg, subgraph) -> bool: return True def apply(self, sdfg): - graph = sdfg.node(self.state_id) subgraph = self.subgraph_view(sdfg) + graph: SDFGState = subgraph.graph map_entries = helpers.get_outermost_scope_maps(sdfg, graph, subgraph) result = StencilTiling.topology(sdfg, graph, map_entries) @@ -430,7 +428,7 @@ def apply(self, sdfg): stripmine_subgraph = {StripMining.map_entry: graph.node_id(map_entry)} - cfg_id = sdfg.cfg_id + cfg_id = graph.parent_graph.cfg_id last_map_entry = None original_schedule = map_entry.schedule self.tile_sizes = [] @@ -557,7 +555,7 @@ def apply(self, sdfg): if l > 1: subgraph = {MapExpansion.map_entry: graph.node_id(map_entry)} trafo_expansion = MapExpansion() - trafo_expansion.setup_match(sdfg, sdfg.cfg_id, sdfg.nodes().index(graph), subgraph, 0) + trafo_expansion.setup_match(sdfg, graph.parent_graph.cfg_id, graph.block_id, subgraph, 0) trafo_expansion.apply(graph, sdfg) maps = [map_entry] for _ in range(l - 1): @@ -568,21 +566,15 @@ def apply(self, sdfg): # MapToForLoop subgraph = {MapToForLoop.map_entry: graph.node_id(map)} trafo_for_loop = MapToForLoop() - trafo_for_loop.setup_match(sdfg, sdfg.cfg_id, sdfg.nodes().index(graph), subgraph, 0) + trafo_for_loop.setup_match(sdfg, graph.parent_graph.cfg_id, graph.block_id, subgraph, 0) trafo_for_loop.apply(graph, sdfg) nsdfg = trafo_for_loop.nsdfg # LoopUnroll + # Prevent circular import + from dace.transformation.interstate.loop_unroll import LoopUnroll - guard = trafo_for_loop.guard - end = trafo_for_loop.after_state - begin = next(e.dst for e in nsdfg.out_edges(guard) if e.dst != end) - - subgraph = { - DetectLoop.loop_guard: nsdfg.node_id(guard), - DetectLoop.loop_begin: nsdfg.node_id(begin), - DetectLoop.exit_state: nsdfg.node_id(end) - } + subgraph = { LoopUnroll.loop: trafo_for_loop.loop_region.block_id } transformation = LoopUnroll() transformation.setup_match(nsdfg, nsdfg.cfg_id, -1, subgraph, 0) transformation.apply(nsdfg, nsdfg) diff --git a/dace/transformation/subgraph/subgraph_fusion.py b/dace/transformation/subgraph/subgraph_fusion.py index 1ff286b85c..6b78e7276c 100644 --- a/dace/transformation/subgraph/subgraph_fusion.py +++ b/dace/transformation/subgraph/subgraph_fusion.py @@ -1,24 +1,21 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This module contains classes that implement subgraph fusion. """ import dace import networkx as nx -from dace import dtypes, registry, symbolic, subsets, data -from dace.sdfg import nodes, utils, replace, SDFG, scope_contains_scope -from dace.sdfg.graph import SubgraphView -from dace.sdfg.scope import ScopeTree +from dace import dtypes, symbolic, subsets, data +from dace.sdfg import nodes, SDFG from dace.memlet import Memlet +from dace.sdfg.state import SDFGState, StateSubgraphView from dace.transformation import transformation from dace.properties import EnumProperty, ListProperty, make_properties, Property -from dace.symbolic import overapproximate -from dace.sdfg.propagation import propagate_memlets_sdfg, propagate_memlet, propagate_memlets_scope, _propagate_node +from dace.sdfg.propagation import _propagate_node from dace.transformation.subgraph import helpers -from dace.transformation.dataflow import RedundantArray -from dace.sdfg.utils import consolidate_edges_scope, get_view_node +from dace.sdfg.utils import consolidate_edges_scope from dace.transformation.helpers import find_contiguous_subsets from copy import deepcopy as dcpy -from typing import List, Union, Tuple +from typing import List, Tuple import warnings import dace.libraries.standard as stdlib @@ -74,7 +71,7 @@ class SubgraphFusion(transformation.SubgraphTransformation): desc="A list of array names to treat as non-transients and not compress", ) - def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: + def can_be_applied(self, sdfg: SDFG, subgraph: StateSubgraphView) -> bool: """ Fusible if @@ -89,7 +86,7 @@ def can_be_applied(self, sdfg: SDFG, subgraph: SubgraphView) -> bool: 4. Check for any disjoint accesses of arrays. """ # get graph - graph = subgraph.graph + graph: SDFGState = subgraph.graph for node in subgraph.nodes(): if node not in graph.nodes(): return False @@ -626,7 +623,7 @@ def determine_compressible_nodes(sdfg: dace.sdfg.SDFG, # do a full global search and count each data from each intermediate node scope_dict = graph.scope_dict() - for state in sdfg.nodes(): + for state in sdfg.states(): for node in state.nodes(): if isinstance(node, nodes.AccessNode) and node.data in data_intermediate: # add them to the counter set in all cases diff --git a/dace/transformation/testing.py b/dace/transformation/testing.py index 79738c9ec3..dea1be2a9b 100644 --- a/dace/transformation/testing.py +++ b/dace/transformation/testing.py @@ -6,6 +6,7 @@ import traceback from dace.sdfg import SDFG +from dace.sdfg.state import ControlFlowRegion from dace.transformation.optimizer import Optimizer @@ -68,8 +69,9 @@ def _optimize_recursive(self, sdfg: SDFG, depth: int): print(' ' * depth, type(match).__name__, '- ', end='', file=self.stdout) - tsdfg: SDFG = new_sdfg.cfg_list[match.cfg_id] - tgraph = tsdfg.node(match.state_id) if match.state_id >= 0 else tsdfg + tcfg: ControlFlowRegion = new_sdfg.cfg_list[match.cfg_id] + tsdfg = tcfg.sdfg if not isinstance(tcfg, SDFG) else tcfg + tgraph = tcfg.node(match.state_id) if match.state_id >= 0 else tcfg match._sdfg = tsdfg match.apply(tgraph, tsdfg) diff --git a/dace/transformation/transformation.py b/dace/transformation/transformation.py index 727ec5555b..8c11c5d200 100644 --- a/dace/transformation/transformation.py +++ b/dace/transformation/transformation.py @@ -1,4 +1,4 @@ -# Copyright 2019-2022 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ This file contains classes that describe data-centric transformations. @@ -20,10 +20,10 @@ import abc import copy -from dace import dtypes, serialize +from dace import serialize from dace.dtypes import ScheduleType from dace.sdfg import SDFG, SDFGState -from dace.sdfg.state import ControlFlowRegion +from dace.sdfg.state import ControlFlowBlock, ControlFlowRegion from dace.sdfg import nodes as nd, graph as gr, utils as sdutil, propagation, infer_types, state as st from dace.properties import make_properties, Property, DictProperty, SetProperty from dace.transformation import pass_pipeline as ppl @@ -34,8 +34,8 @@ PassT = TypeVar('PassT', bound=ppl.Pass) -def experimental_cfg_block_compatible(cls: PassT) -> PassT: - cls.__experimental_cfg_block_compatible__ = True +def explicit_cf_compatible(cls: PassT) -> PassT: + cls.__explicit_cf_compatible__ = True return cls @@ -339,7 +339,7 @@ def _can_be_applied_and_apply( # Check that all keyword arguments are nodes and if interstate or not sample_node = next(iter(where.values())) - if isinstance(sample_node, SDFGState): + if isinstance(sample_node, ControlFlowBlock): graph = sample_node.parent_graph state_id = -1 cfg_id = graph.cfg_id @@ -506,7 +506,7 @@ def from_json(json_obj: Dict[str, Any], context: Dict[str, Any] = None) -> 'Patt @make_properties -@experimental_cfg_block_compatible +@explicit_cf_compatible class SingleStateTransformation(PatternTransformation, abc.ABC): """ Base class for pattern-matching transformations that find matches within a single SDFG state. @@ -811,8 +811,7 @@ def setup_match(self, subgraph: Union[Set[int], gr.SubgraphView], cfg_id: int = self.subgraph = set(subgraph.graph.node_id(n) for n in subgraph.nodes()) if isinstance(subgraph.graph, SDFGState): - sdfg = subgraph.graph.parent - self.cfg_id = sdfg.cfg_id + self.cfg_id = subgraph.graph.parent_graph.cfg_id self.state_id = subgraph.graph.block_id elif isinstance(subgraph.graph, SDFG): self.cfg_id = subgraph.graph.cfg_id @@ -824,13 +823,6 @@ def setup_match(self, subgraph: Union[Set[int], gr.SubgraphView], cfg_id: int = self.cfg_id = cfg_id self.state_id = state_id - def get_subgraph(self, sdfg: SDFG) -> gr.SubgraphView: - sdfg = sdfg.cfg_list[self.cfg_id] - if self.state_id == -1: - return gr.SubgraphView(sdfg, list(map(sdfg.node, self.subgraph))) - state = sdfg.node(self.state_id) - return st.StateSubgraphView(state, list(map(state.node, self.subgraph))) - @classmethod def subclasses_recursive(cls) -> Set[Type['PatternTransformation']]: """ @@ -1069,7 +1061,7 @@ def blocksafe_wrapper(tgt, *args, **kwargs): sdfg = get_sdfg_arg(tgt, *args) if sdfg and isinstance(sdfg, SDFG): root_sdfg: SDFG = sdfg.cfg_list[0] - if not root_sdfg.using_experimental_blocks: + if not root_sdfg.using_explicit_control_flow: return vanilla_method(tgt, *args, **kwargs) else: warnings.warn('Skipping ' + function_name + ' from ' + cls.__name__ + diff --git a/dace/viewer/webclient b/dace/viewer/webclient index 64861bbc05..f8f3e9d352 160000 --- a/dace/viewer/webclient +++ b/dace/viewer/webclient @@ -1 +1 @@ -Subproject commit 64861bbc054c62bc6cb3f8525bfc4703d6c5e364 +Subproject commit f8f3e9d352ad28794ecddf94fbb04d888083f6fa diff --git a/doc/frontend/parsing.rst b/doc/frontend/parsing.rst index 7adc415497..d909cd7deb 100644 --- a/doc/frontend/parsing.rst +++ b/doc/frontend/parsing.rst @@ -169,7 +169,7 @@ Example: :alt: Generated SDFG for-loop for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, this will utilize +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, this will utilize :class:`~dace.sdfg.state.LoopRegion`s instead of the explicit state machine depicted above. :func:`~dace.frontend.python.newast.ProgramVisitor.visit_While` @@ -191,7 +191,7 @@ Parses `while `_ statement :alt: Generated SDFG while-loop for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, this will utilize +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, this will utilize :class:`~dace.sdfg.state.LoopRegion`s instead of the explicit state machine depicted above. :func:`~dace.frontend.python.newast.ProgramVisitor.visit_Break` @@ -214,7 +214,7 @@ behaves as an if-else statement. This is also evident from the generated dataflo :alt: Generated SDFG for-loop with a break statement for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, loops are +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, loops are represented with :class:`~dace.sdfg.state.LoopRegion`s, and a break is represented with a special :class:`~dace.sdfg.state.LoopRegion.BreakState`. @@ -238,7 +238,7 @@ of `continue` makes the ``A[i] = i`` statement unreachable. This is also evident :alt: Generated SDFG for-loop with a continue statement for the above Data-Centric Python program If the :class:`~dace.frontend.python.parser.DaceProgram`'s -:attr:`~dace.frontend.python.parser.DaceProgram.use_experimental_cfg_blocks` attribute is set to true, loops are +:attr:`~dace.frontend.python.parser.DaceProgram.use_explicit_control_flow` attribute is set to true, loops are represented with :class:`~dace.sdfg.state.LoopRegion`s, and a continue is represented with a special :class:`~dace.sdfg.state.LoopRegion.ContinueState`. diff --git a/tests/codegen/control_flow_detection_test.py b/tests/codegen/control_flow_detection_test.py index e97f7db77b..aaf0e11d42 100644 --- a/tests/codegen/control_flow_detection_test.py +++ b/tests/codegen/control_flow_detection_test.py @@ -1,5 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. -from math import exp +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import pytest import dace @@ -66,7 +65,7 @@ def looptest(): sdfg: dace.SDFG = looptest.to_sdfg(simplify=True) if dace.Config.get_bool('optimizer', 'detect_control_flow'): - assert 'for (' in sdfg.generate_code()[0].code + assert 'while (' in sdfg.generate_code()[0].code A = looptest() A_ref = np.array([0, 0, 2, 0, 4, 0, 6, 0, 8, 0], dtype=np.int32) diff --git a/tests/codegen/data_instrumentation_test.py b/tests/codegen/data_instrumentation_test.py index b254a204b5..aef9c83df3 100644 --- a/tests/codegen/data_instrumentation_test.py +++ b/tests/codegen/data_instrumentation_test.py @@ -317,12 +317,9 @@ def dinstr(A: dace.float64[20]): assert len(dreport.keys()) == 1 assert 'i' in dreport.keys() - assert len(dreport['i']) == 22 - desired = list(range(1, 19)) - s_idx = dreport['i'].index(1) - e_idx = dreport['i'].index(18) - assert np.allclose(dreport['i'][s_idx:e_idx+1], desired) - assert 19 in dreport['i'] + assert len(dreport['i']) == 19 + desired = list(range(0, 19)) + assert np.allclose(dreport['i'], desired) @pytest.mark.datainstrument @@ -356,7 +353,10 @@ def dinstr(A: dace.float64[20]): for i in range(j): A[i] = 0 - sdfg = dinstr.to_sdfg(simplify=True) + # Simplification is turned off to avoid killing the initial start state, since symbol instrumentation can for now + # only be triggered on SDFG states. + # TODO(later): Make it so symbols can be instrumented on any Control flow block + sdfg = dinstr.to_sdfg(simplify=False) sdfg.start_state.symbol_instrument = dace.DataInstrumentationType.Save A = np.ones((20, )) sdfg(A, j=15) diff --git a/tests/constant_array_test.py b/tests/constant_array_test.py index 69444768af..d8067f524e 100644 --- a/tests/constant_array_test.py +++ b/tests/constant_array_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. from __future__ import print_function import argparse @@ -112,12 +112,10 @@ def test(a: dace.float64[10]): sdfg = test.to_sdfg(simplify=False) sdfg.apply_transformations_repeated([StateFusion, RedundantArray, RedundantSecondArray]) - state = sdfg.node(0) # modify cst to be a dace constant: the python frontend adds an assignment tasklet - n = [n for n in state.nodes() if isinstance(n, nodes.AccessNode) and n.data == 'cst'][0] - for pred in state.predecessors(n): - state.remove_node(pred) + assign_state = sdfg.node(0) + sdfg.remove_node(assign_state) sdfg.add_constant('cst', 1.0, sdfg.arrays['cst']) diff --git a/tests/fortran/array_test.py b/tests/fortran/array_test.py index a8ece680a6..d5b8c5d669 100644 --- a/tests/fortran/array_test.py +++ b/tests/fortran/array_test.py @@ -1,22 +1,13 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. -from fparser.common.readfortran import FortranStringReader -from fparser.common.readfortran import FortranFileReader -from fparser.two.parser import ParserFactory -import sys, os import numpy as np -import pytest -from dace import SDFG, SDFGState, nodes, dtypes, data, subsets, symbolic +from dace import dtypes, symbolic from dace.frontend.fortran import fortran_parser -from fparser.two.symbol_table import SymbolTable from dace.sdfg import utils as sdutil from dace.sdfg.nodes import AccessNode -import dace.frontend.fortran.ast_components as ast_components -import dace.frontend.fortran.ast_transforms as ast_transforms -import dace.frontend.fortran.ast_utils as ast_utils -import dace.frontend.fortran.ast_internal_classes as ast_internal_classes +from dace.sdfg.state import LoopRegion def test_fortran_frontend_array_access(): @@ -199,9 +190,10 @@ def test_fortran_frontend_memlet_in_map_test(): """ sdfg = fortran_parser.create_sdfg_from_string(test_string, "memlet_range_test") sdfg.simplify() - # Expect that start is begin of for loop -> only one out edge to guard defining iterator variable - assert len(sdfg.out_edges(sdfg.start_state)) == 1 - iter_var = symbolic.symbol(list(sdfg.out_edges(sdfg.start_state)[0].data.assignments.keys())[0]) + # Expect that the start block is a loop + loop = sdfg.nodes()[0] + assert isinstance(loop, LoopRegion) + iter_var = symbolic.pystr_to_symbolic(loop.loop_variable) for state in sdfg.states(): if len(state.nodes()) > 1: diff --git a/tests/fortran/fortran_language_test.py b/tests/fortran/fortran_language_test.py index 0a87baa4da..840f0bda0e 100644 --- a/tests/fortran/fortran_language_test.py +++ b/tests/fortran/fortran_language_test.py @@ -1,4 +1,4 @@ -# Copyright 2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import numpy as np diff --git a/tests/fortran/fortran_loops_test.py b/tests/fortran/fortran_loops_test.py index 4d4c259f07..b18a5e36e8 100644 --- a/tests/fortran/fortran_loops_test.py +++ b/tests/fortran/fortran_loops_test.py @@ -29,7 +29,7 @@ def test_fortran_frontend_loop_region_basic_loop(): ENDDO end SUBROUTINE loop_test_function """ - sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, use_experimental_cfg_blocks=True) + sdfg = fortran_parser.create_sdfg_from_string(test_string, test_name, use_explicit_cf=True) a_test = np.full([10, 10], 2, order="F", dtype=np.float64) b_test = np.full([10, 10], 3, order="F", dtype=np.float64) diff --git a/tests/inlining_test.py b/tests/inlining_test.py index 251c85e7bc..3ff56b45a7 100644 --- a/tests/inlining_test.py +++ b/tests/inlining_test.py @@ -1,5 +1,6 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. import dace +from dace.sdfg.state import FunctionCallRegion, NamedRegion from dace.transformation.interstate import InlineSDFG, StateFusion from dace.libraries import blas from dace.library import change_default @@ -134,7 +135,7 @@ def outerprog(A: dace.float64[20]): from dace.transformation.interstate import InlineMultistateSDFG sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (4, 5) + assert sdfg.number_of_nodes() in (1, 2) sdfg(A) assert np.allclose(A, expected) @@ -145,14 +146,14 @@ def test_multistate_inline_samename(): @dace.program def nested(A: dace.float64[20]): for i in range(5): - A[i] += A[i - 1] + A[i + 1] += A[i] @dace.program def outerprog(A: dace.float64[20]): for i in range(5): nested(A) - sdfg = outerprog.to_sdfg(simplify=True) + sdfg = outerprog.to_sdfg(simplify=False) A = np.random.rand(20) expected = np.copy(A) @@ -160,7 +161,8 @@ def outerprog(A: dace.float64[20]): from dace.transformation.interstate import InlineMultistateSDFG sdfg.apply_transformations(InlineMultistateSDFG) - assert sdfg.number_of_nodes() in (7, 8) + sdfg.simplify() + assert sdfg.number_of_nodes() == 1 sdfg(A) assert np.allclose(A, expected) @@ -193,8 +195,11 @@ def outerprog(A: dace.float64[20], B: dace.float64[20]): b = 2 * a sdfg = outerprog.to_sdfg(simplify=False) + for cf in sdfg.all_control_flow_regions(): + if isinstance(cf, (FunctionCallRegion, NamedRegion)): + cf.inline() sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) - assert len(sdfg.states()) == 1 + assert len(sdfg.nodes()) == 1 A = np.random.rand(20) B = np.random.rand(20) @@ -229,9 +234,12 @@ def outerprog(A: dace.float64[10], B: dace.float64[10], C: dace.float64[10]): c = 2 * a sdfg = outerprog.to_sdfg(simplify=False) + for cf in sdfg.all_control_flow_regions(): + if isinstance(cf, (FunctionCallRegion, NamedRegion)): + cf.inline() dace.propagate_memlets_sdfg(sdfg) sdfg.apply_transformations_repeated((StateFusion, InlineSDFG)) - assert len(sdfg.states()) == 1 + assert len(sdfg.nodes()) == 1 assert len([node for node in sdfg.start_state.data_nodes()]) == 3 A = np.random.rand(10) diff --git a/tests/passes/constant_propagation_test.py b/tests/passes/constant_propagation_test.py index acb1033554..48ae6f5b91 100644 --- a/tests/passes/constant_propagation_test.py +++ b/tests/passes/constant_propagation_test.py @@ -2,6 +2,7 @@ import pytest import dace +from dace.sdfg.state import LoopRegion from dace.transformation.passes.constant_propagation import ConstantPropagation, _UnknownValue from dace.transformation.passes.scalar_to_symbol import ScalarToSymbolPromotion import numpy as np @@ -69,7 +70,9 @@ def program(a: dace.float64[20]): ScalarToSymbolPromotion().apply_pass(sdfg, {}) ConstantPropagation().apply_pass(sdfg, {}) - assert set(sdfg.symbols.keys()) == {'i'} + for node in sdfg.all_control_flow_regions(): + if isinstance(node, LoopRegion): + assert node.loop_variable == 'i' # Test tasklets for node, _ in sdfg.all_nodes_recursive(): if isinstance(node, dace.nodes.Tasklet): @@ -91,7 +94,9 @@ def program(a: dace.float64[20]): ScalarToSymbolPromotion().apply_pass(sdfg, {}) ConstantPropagation().apply_pass(sdfg, {}) - assert set(sdfg.symbols.keys()) == {'i'} + for node in sdfg.all_control_flow_regions(): + if isinstance(node, LoopRegion): + assert node.loop_variable == 'i' # Test tasklets i_found = 0 @@ -118,7 +123,10 @@ def program(a: dace.float64[20, 20]): ScalarToSymbolPromotion().apply_pass(sdfg, {}) ConstantPropagation().apply_pass(sdfg, {}) - assert set(sdfg.symbols.keys()) == {'i', 'j'} + assert 'j' in sdfg.symbols + for node in sdfg.all_control_flow_regions(): + if isinstance(node, LoopRegion): + assert node.loop_variable == 'i' # Test memlet last_state = sdfg.sink_nodes()[0] @@ -187,7 +195,9 @@ def test_complex_case(): sdfg.add_edge(usei, merge, dace.InterstateEdge(assignments={'j': 'j+1'})) sdfg.add_edge(merge, last, dace.InterstateEdge('j >= 2')) - propagated = ConstantPropagation().collect_constants(sdfg) #, reachability + propagated = {} + arrays = set(sdfg.arrays.keys() | sdfg.constants_prop.keys()) + ConstantPropagation()._collect_constants_for_region(sdfg, arrays, propagated, {}, {}, {}) assert len(propagated[init]) == 0 assert propagated[branch2]['i'] == '7' assert propagated[guard]['i'] is _UnknownValue diff --git a/tests/passes/dead_code_elimination_test.py b/tests/passes/dead_code_elimination_test.py index 1832ad8321..231ccac84f 100644 --- a/tests/passes/dead_code_elimination_test.py +++ b/tests/passes/dead_code_elimination_test.py @@ -4,6 +4,8 @@ import numpy as np import pytest import dace +from dace.properties import CodeBlock +from dace.sdfg.state import ConditionalBlock, ControlFlowRegion, LoopRegion from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.dead_state_elimination import DeadStateElimination from dace.transformation.passes.dead_dataflow_elimination import DeadDataflowElimination @@ -65,6 +67,60 @@ def test_dse_edge_condition_with_integer_as_boolean_regression(): assert res is None +def test_dse_inside_loop(): + sdfg = dace.SDFG('dse_inside_loop') + sdfg.add_symbol('a', dace.int32) + loop = LoopRegion('loop', 'i < 10', 'i', 'i = 0', 'i = i + 1') + start = sdfg.add_state(is_start_block=True) + sdfg.add_node(loop) + end = sdfg.add_state() + sdfg.add_edge(start, loop, dace.InterstateEdge()) + sdfg.add_edge(loop, end, dace.InterstateEdge()) + s = loop.add_state(is_start_block=True) + s1 = loop.add_state() + s2 = loop.add_state() + s3 = loop.add_state() + e = loop.add_state() + loop.add_edge(s, s1, dace.InterstateEdge('a > 0')) + loop.add_edge(s, s2, dace.InterstateEdge('a >= a')) # Always True + loop.add_edge(s, s3, dace.InterstateEdge('a < 0')) + loop.add_edge(s1, e, dace.InterstateEdge()) + loop.add_edge(s2, e, dace.InterstateEdge()) + loop.add_edge(s3, e, dace.InterstateEdge()) + + DeadStateElimination().apply_pass(sdfg, {}) + assert set(sdfg.states()) == {start, s, s2, e, end} + + +def test_dse_inside_loop_conditional(): + sdfg = dace.SDFG('dse_inside_loop') + sdfg.add_symbol('a', dace.int32) + loop = LoopRegion('loop', 'i < 10', 'i', 'i = 0', 'i = i + 1') + start = sdfg.add_state(is_start_block=True) + sdfg.add_node(loop) + end = sdfg.add_state() + sdfg.add_edge(start, loop, dace.InterstateEdge()) + sdfg.add_edge(loop, end, dace.InterstateEdge()) + s = loop.add_state(is_start_block=True) + cond_block = ConditionalBlock('cond', sdfg, loop) + loop.add_node(cond_block) + b1 = ControlFlowRegion('b1', sdfg) + b1.add_state() + cond_block.add_branch(CodeBlock('a > 0'), b1) + b2 = ControlFlowRegion('b2', sdfg) + s2 = b2.add_state() + cond_block.add_branch(CodeBlock('a >= a'), b2) + b3 = ControlFlowRegion('b3', sdfg) + b3.add_state() + cond_block.add_branch(CodeBlock('a < 0'), b3) + e = loop.add_state() + loop.add_edge(s, cond_block, dace.InterstateEdge()) + loop.add_edge(cond_block, e, dace.InterstateEdge()) + + DeadStateElimination().apply_pass(sdfg, {}) + assert set(sdfg.states()) == {start, s, s2, e, end} + + def test_dde_simple(): @dace.program @@ -229,12 +285,12 @@ def dce_tester(a: dace.float64[20], b: dace.float64[20]): sdfg = dce_tester.to_sdfg(simplify=False) result = Pipeline([DeadDataflowElimination(), DeadStateElimination()]).apply_pass(sdfg, {}) sdfg.simplify() - assert sdfg.number_of_nodes() <= 5 + assert sdfg.number_of_nodes() <= 4 # Check that arrays were removed assert all('c' not in [n.data for n in state.data_nodes()] for state in sdfg.nodes()) assert any('f' in [n.data for n in rstate if isinstance(n, dace.nodes.AccessNode)] - for rstate in result['DeadDataflowElimination'].values()) + for rstate in result[DeadDataflowElimination.__name__][0].values()) def test_dce_callback(): @@ -328,6 +384,8 @@ def test_dce_add_type_hint_of_variable(dtype): test_dse_simple() test_dse_unconditional() test_dse_edge_condition_with_integer_as_boolean_regression() + test_dse_inside_loop() + test_dse_inside_loop_conditional() test_dde_simple() test_dde_libnode() test_dde_access_node_in_scope(False) diff --git a/tests/passes/scalar_fission_test.py b/tests/passes/scalar_fission_test.py index adf66f5b1d..f8c59b8f4d 100644 --- a/tests/passes/scalar_fission_test.py +++ b/tests/passes/scalar_fission_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar fission pass. """ import pytest @@ -6,9 +6,12 @@ import dace from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.scalar_fission import ScalarFission +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches -def test_scalar_fission(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_fission(with_raising): """ Test the scalar fission pass. This heavily relies on the scalar write shadow scopes pass, which is tested separately. @@ -95,6 +98,9 @@ def test_scalar_fission(): sdfg.add_edge(loop_2_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarFission()]) pipeline.apply_pass(sdfg, {}) @@ -107,7 +113,8 @@ def test_scalar_fission(): assert all([n.data == list(tmp1_edge.assignments.values())[0] for n in [tmp1_write, loop1_read_tmp]]) assert all([n.data == list(tmp2_edge.assignments.values())[0] for n in [tmp2_write, loop2_read_tmp]]) -def test_branch_subscopes_nofission(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_branch_subscopes_nofission(with_raising): sdfg = dace.SDFG('branch_subscope_fission') sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [2], dace.int32) @@ -185,11 +192,15 @@ def test_branch_subscopes_nofission(): right_after.add_edge(a10, None, t6, 'b', dace.Memlet('B[0]')) right_after.add_edge(t6, 'c', a11, None, dace.Memlet('C[0]')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + Pipeline([ScalarFission()]).apply_pass(sdfg, {}) assert set(sdfg.arrays.keys()) == {'A', 'B', 'C'} -def test_branch_subscopes_fission(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_branch_subscopes_fission(with_raising): sdfg = dace.SDFG('branch_subscope_fission') sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [2], dace.int32) @@ -277,11 +288,17 @@ def test_branch_subscopes_fission(): merge_1.add_edge(a13, None, t8, 'b', dace.Memlet('B[0]')) merge_1.add_edge(t8, 'c', a14, None, dace.Memlet('C[0]')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + Pipeline([ScalarFission()]).apply_pass(sdfg, {}) assert set(sdfg.arrays.keys()) == {'A', 'B', 'C', 'B_0', 'B_1'} if __name__ == '__main__': - test_scalar_fission() - test_branch_subscopes_nofission() - test_branch_subscopes_fission() + test_scalar_fission(False) + test_branch_subscopes_nofission(False) + test_branch_subscopes_fission(False) + test_scalar_fission(True) + test_branch_subscopes_nofission(True) + test_branch_subscopes_fission(True) diff --git a/tests/passes/scalar_to_symbol_test.py b/tests/passes/scalar_to_symbol_test.py index 36decceba2..9d8791182e 100644 --- a/tests/passes/scalar_to_symbol_test.py +++ b/tests/passes/scalar_to_symbol_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar to symbol promotion functionality. """ import dace +from dace.sdfg.state import ConditionalBlock, LoopRegion from dace.transformation.passes import scalar_to_symbol from dace.transformation import transformation as xf, interstate as isxf from dace.transformation.interstate import loop_detection as ld @@ -188,15 +189,21 @@ def testprog6(A: dace.float64[20, 20]): sdfg: dace.SDFG = testprog6.to_sdfg(simplify=False) assert scalar_to_symbol.find_promotable_scalars(sdfg) == {'j'} scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) - sdfg.apply_transformations_repeated(isxf.StateFusion) + sdfg.apply_transformations_repeated([isxf.StateFusion, isxf.BlockFusion]) - # There should be 4 states: - # [empty] --j=A[1, 1]--> [A->MapEntry->Tasklet->MapExit->A] --> [empty] - # \--------------------------------------------/ - assert sdfg.number_of_nodes() == 4 - ctr = collections.Counter(s.number_of_nodes() for s in sdfg) - assert ctr[0] == 3 - assert ctr[5] == 1 + # There should be 2 states: + # [empty] --j=A[1, 1]--> [Conditional] + assert sdfg.number_of_nodes() == 2 + # The conditional should contain one branch, with one state, with a single map from A->A inside of it. + cond = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + cond = n + break + assert cond is not None + assert len(cond.branches) == 1 + assert len(cond.branches[0][1].nodes()) == 1 + assert len(cond.branches[0][1].nodes()[0].nodes()) == 5 # Program should produce correct result A = np.random.rand(20, 20) @@ -235,24 +242,6 @@ def testprog7(A: dace.float64[20, 20]): assert np.allclose(A, expected) -class LoopTester(ld.DetectLoop, xf.MultiStateTransformation): - """ Tester method that sets loop index on a guard state. """ - - def can_be_applied(self, graph, expr_index, sdfg, permissive): - if not super().can_be_applied(graph, expr_index, sdfg, permissive): - return False - guard = self.loop_guard - if hasattr(guard, '_LOOPINDEX'): - return False - return True - - def apply(self, graph: dace.SDFGState, sdfg: dace.SDFG): - guard = self.loop_guard - edge = sdfg.in_edges(guard)[0] - loopindex = next(iter(edge.data.assignments.keys())) - guard._LOOPINDEX = loopindex - - def test_promote_loop(): """ Loop promotion. """ N = dace.symbol('N') @@ -269,7 +258,7 @@ def testprog8(A: dace.float32[20, 20]): assert 'i' in scalar_to_symbol.find_promotable_scalars(sdfg) scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) sdfg.simplify() - assert sdfg.apply_transformations_repeated(LoopTester) == 1 + assert any(isinstance(n, LoopRegion) for n in sdfg.nodes()) def test_promote_loops(): @@ -294,7 +283,7 @@ def testprog9(A: dace.float32[20, 20]): assert 'k' in scalars scalar_to_symbol.ScalarToSymbolPromotion().apply_pass(sdfg, {}) sdfg.simplify() - assert sdfg.apply_transformations_repeated(LoopTester) == 3 + assert any(isinstance(n, LoopRegion) for n in sdfg.nodes()) def test_promote_indirection(): diff --git a/tests/passes/scalar_write_shadow_scopes_analysis_test.py b/tests/passes/scalar_write_shadow_scopes_analysis_test.py index b833a12a94..78704bca60 100644 --- a/tests/passes/scalar_write_shadow_scopes_analysis_test.py +++ b/tests/passes/scalar_write_shadow_scopes_analysis_test.py @@ -1,14 +1,16 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the scalar write shadowing analysis pass. """ import pytest - import dace from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis import ScalarWriteShadowScopes +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches -def test_scalar_write_shadow_split(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_split(with_raising): """ Test the scalar write shadow scopes pass with writes dominating reads across state. """ @@ -90,6 +92,9 @@ def test_scalar_write_shadow_split(): sdfg.add_edge(loop_2_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -106,7 +111,8 @@ def test_scalar_write_shadow_split(): } -def test_scalar_write_shadow_fused(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_fused(with_raising): """ Test the scalar write shadow scopes pass with writes dominating reads in the same state. """ @@ -176,6 +182,9 @@ def test_scalar_write_shadow_fused(): sdfg.add_edge(loop_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -186,7 +195,8 @@ def test_scalar_write_shadow_fused(): assert results[0]['B'][None] == {(loop_1, b1_read), (loop_2, b2_read), (loop_1, b1_write), (loop_2, b2_write)} -def test_scalar_write_shadow_interstate_self(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_interstate_self(with_raising): """ Tests the scalar write shadow pass with interstate edge reads being shadowed by the state they're originating from. """ @@ -270,6 +280,9 @@ def test_scalar_write_shadow_interstate_self(): sdfg.add_edge(loop_2_2, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -286,7 +299,8 @@ def test_scalar_write_shadow_interstate_self(): } -def test_scalar_write_shadow_interstate_pred(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_scalar_write_shadow_interstate_pred(with_raising): """ Tests the scalar write shadow pass with interstate edge reads being shadowed by a predecessor state. """ @@ -374,6 +388,9 @@ def test_scalar_write_shadow_interstate_pred(): sdfg.add_edge(loop_2_3, guard_2, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard_2, end_state, dace.InterstateEdge(condition='i >= (N - 1)')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + # Test the pass. pipeline = Pipeline([ScalarWriteShadowScopes()]) results = pipeline.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -390,7 +407,8 @@ def test_scalar_write_shadow_interstate_pred(): } -def test_loop_fake_shadow(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_loop_fake_shadow(with_raising): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -432,13 +450,17 @@ def test_loop_fake_shadow(): sdfg.add_edge(loop2, guard, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard, end, dace.InterstateEdge(condition='i >= 10')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] assert res[0]['A'][(init, init_access)] == {(loop, loop_access), (loop2, loop2_access), (end, end_access)} -def test_loop_fake_complex_shadow(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_loop_fake_complex_shadow(with_raising): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -472,13 +494,17 @@ def test_loop_fake_complex_shadow(): sdfg.add_edge(loop2, guard, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard, end, dace.InterstateEdge(condition='i >= 10')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] assert res[0]['A'][(init, init_access)] == {(loop, loop_access), (loop2, loop2_access)} -def test_loop_real_shadow(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_loop_real_shadow(with_raising): sdfg = dace.SDFG('loop_fake_shadow') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -514,6 +540,9 @@ def test_loop_real_shadow(): sdfg.add_edge(loop2, guard, dace.InterstateEdge(assignments={'i': 'i + 1'})) sdfg.add_edge(guard, end, dace.InterstateEdge(condition='i >= 10')) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -521,7 +550,8 @@ def test_loop_real_shadow(): assert res[0]['A'][(loop2, loop2_access)] == {(loop2, loop2_access)} -def test_dominationless_write_branch(): +@pytest.mark.parametrize('with_raising', (False, True)) +def test_dominationless_write_branch(with_raising): sdfg = dace.SDFG('dominationless_write_branch') sdfg.add_array('A', [1], dace.float64, transient=True) sdfg.add_array('B', [1], dace.float64) @@ -558,6 +588,9 @@ def test_dominationless_write_branch(): sdfg.add_edge(guard, merge, dace.InterstateEdge(condition='B[0] >= 10')) sdfg.add_edge(left, merge, dace.InterstateEdge()) + if with_raising: + Pipeline([ControlFlowRaising(), PruneEmptyConditionalBranches()]).apply_pass(sdfg, {}) + ppl = Pipeline([ScalarWriteShadowScopes()]) res = ppl.apply_pass(sdfg, {})[ScalarWriteShadowScopes.__name__] @@ -566,11 +599,19 @@ def test_dominationless_write_branch(): if __name__ == '__main__': - test_scalar_write_shadow_split() - test_scalar_write_shadow_fused() - test_scalar_write_shadow_interstate_self() - test_scalar_write_shadow_interstate_pred() - test_loop_fake_shadow() - test_loop_fake_complex_shadow() - test_loop_real_shadow() - test_dominationless_write_branch() + test_scalar_write_shadow_split(False) + test_scalar_write_shadow_fused(False) + test_scalar_write_shadow_interstate_self(False) + test_scalar_write_shadow_interstate_pred(False) + test_loop_fake_shadow(False) + test_loop_fake_complex_shadow(False) + test_loop_real_shadow(False) + test_dominationless_write_branch(False) + test_scalar_write_shadow_split(True) + test_scalar_write_shadow_fused(True) + test_scalar_write_shadow_interstate_self(True) + test_scalar_write_shadow_interstate_pred(True) + test_loop_fake_shadow(True) + test_loop_fake_complex_shadow(True) + test_loop_real_shadow(True) + test_dominationless_write_branch(True) diff --git a/tests/passes/simplification/control_flow_raising_test.py b/tests/passes/simplification/control_flow_raising_test.py index 53e01df12f..8b22446974 100644 --- a/tests/passes/simplification/control_flow_raising_test.py +++ b/tests/passes/simplification/control_flow_raising_test.py @@ -1,13 +1,16 @@ # Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import pytest import dace import numpy as np from dace.sdfg.state import ConditionalBlock -from dace.transformation.pass_pipeline import FixedPointPipeline, Pipeline +from dace.sdfg.utils import inline_control_flow_regions +from dace.transformation.pass_pipeline import FixedPointPipeline from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising -def test_dataflow_if_check(): +@pytest.mark.parametrize('lowered_returns', [False, True]) +def test_dataflow_if_check(lowered_returns: bool): @dace.program def dataflow_if_check(A: dace.int32[10], i: dace.int64): @@ -19,10 +22,12 @@ def dataflow_if_check(A: dace.int32[10], i: dace.int64): sdfg = dataflow_if_check.to_sdfg() + # To test raising, we inline the control flow generated by the frontend. + inline_control_flow_regions(sdfg, lower_returns=lowered_returns) + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) ppl = FixedPointPipeline([ControlFlowRaising()]) - ppl.__experimental_cfg_block_compatible__ = True ppl.apply_pass(sdfg, {}) assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -36,7 +41,8 @@ def dataflow_if_check(A: dace.int32[10], i: dace.int64): assert sdfg(A, 6)[0] == 0 -def test_nested_if_chain(): +@pytest.mark.parametrize('lowered_returns', [False, True]) +def test_nested_if_chain(lowered_returns: bool): @dace.program def nested_if_chain(i: dace.int64): @@ -56,8 +62,16 @@ def nested_if_chain(i: dace.int64): sdfg = nested_if_chain.to_sdfg() + # To test raising, we inline the control flow generated by the frontend. + inline_control_flow_regions(sdfg, lower_returns=lowered_returns) + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.apply_pass(sdfg, {}) + + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + assert nested_if_chain(0)[0] == 0 assert nested_if_chain(2)[0] == 1 assert nested_if_chain(4)[0] == 2 @@ -65,7 +79,8 @@ def nested_if_chain(i: dace.int64): assert nested_if_chain(15)[0] == 4 -def test_elif_chain(): +@pytest.mark.parametrize('lowered_returns', [False, True]) +def test_elif_chain(lowered_returns: bool): @dace.program def elif_chain(i: dace.int64): @@ -80,9 +95,16 @@ def elif_chain(i: dace.int64): else: return 4 - elif_chain.use_experimental_cfg_blocks = True sdfg = elif_chain.to_sdfg() + # To test raising, we inline the control flow generated by the frontend. + inline_control_flow_regions(sdfg, lower_returns=lowered_returns) + + assert not any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) + + ppl = FixedPointPipeline([ControlFlowRaising()]) + ppl.apply_pass(sdfg, {}) + assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) assert elif_chain(0)[0] == 0 @@ -93,6 +115,9 @@ def elif_chain(i: dace.int64): if __name__ == '__main__': - test_dataflow_if_check() - test_nested_if_chain() - test_elif_chain() + test_dataflow_if_check(False) + test_dataflow_if_check(True) + test_nested_if_chain(False) + test_nested_if_chain(True) + test_elif_chain(False) + test_elif_chain(True) diff --git a/tests/passes/simplification/prune_empty_conditional_branches_test.py b/tests/passes/simplification/prune_empty_conditional_branches_test.py new file mode 100644 index 0000000000..dc25cdc670 --- /dev/null +++ b/tests/passes/simplification/prune_empty_conditional_branches_test.py @@ -0,0 +1,105 @@ +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + + +import numpy as np +import dace +from dace.sdfg.state import ConditionalBlock +from dace.transformation.passes.simplification.prune_empty_conditional_branches import PruneEmptyConditionalBranches + + +def test_prune_empty_else(): + N = dace.symbol('N') + + @dace.program + def prune_empty_else(A: dace.int32[N]): + A[:] = 0 + if N == 32: + for i in range(N): + A[i] = 1 + else: + A[:] = 0 + + sdfg = prune_empty_else.to_sdfg(simplify=False) + + conditional: ConditionalBlock = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + + assert len(conditional.branches) == 2 + + conditional._branches[-1][0] = None + else_branch = conditional._branches[-1][1] + else_branch.remove_nodes_from(else_branch.nodes()) + else_branch.add_state('empty') + + res = PruneEmptyConditionalBranches().apply_pass(sdfg, {}) + + assert res[conditional.cfg_id] == 1 + assert len(conditional.branches) == 1 + + N1 = 32 + N2 = 31 + A1 = np.zeros((N1,), dtype=np.int32) + A2 = np.zeros((N2,), dtype=np.int32) + verif1 = np.full((N1,), 1, dtype=np.int32) + verif2 = np.zeros((N2,), dtype=np.int32) + + sdfg(A1, N=N1) + sdfg(A2, N=N2) + + assert np.allclose(A1, verif1) + assert np.allclose(A2, verif2) + + +def test_prune_empty_if_with_else(): + N = dace.symbol('N') + + @dace.program + def prune_empty_if_with_else(A: dace.int32[N]): + A[:] = 0 + if N == 32: + for i in range(N): + A[i] = 2 + else: + A[:] = 1 + + sdfg = prune_empty_if_with_else.to_sdfg(simplify=False) + + conditional: ConditionalBlock = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + + assert len(conditional.branches) == 2 + + conditional._branches[-1][0] = None + if_branch = conditional._branches[0][1] + if_branch.remove_nodes_from(if_branch.nodes()) + if_branch.add_state('empty') + + res = PruneEmptyConditionalBranches().apply_pass(sdfg, {}) + + assert res[conditional.cfg_id] == 1 + assert len(conditional.branches) == 1 + assert conditional.branches[0][0] is not None + + N1 = 32 + N2 = 31 + A1 = np.zeros((N1,), dtype=np.int32) + A2 = np.zeros((N2,), dtype=np.int32) + verif1 = np.zeros((N1,), dtype=np.int32) + verif2 = np.full((N2,), 1, dtype=np.int32) + + sdfg(A1, N=N1) + sdfg(A2, N=N2) + + assert np.allclose(A1, verif1) + assert np.allclose(A2, verif2) + + +if __name__ == '__main__': + test_prune_empty_else() + test_prune_empty_if_with_else() diff --git a/tests/passes/symbol_write_scopes_analysis_test.py b/tests/passes/symbol_write_scopes_analysis_test.py index 8450841729..0f3207a262 100644 --- a/tests/passes/symbol_write_scopes_analysis_test.py +++ b/tests/passes/symbol_write_scopes_analysis_test.py @@ -1,8 +1,6 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Tests the symbol write scopes analysis pass. """ -import pytest - import dace from dace.transformation.pass_pipeline import Pipeline from dace.transformation.passes.analysis import SymbolWriteScopes, SymbolScopeDict diff --git a/tests/passes/writeset_underapproximation_test.py b/tests/passes/writeset_underapproximation_test.py index 96df87b5e7..b92aee12da 100644 --- a/tests/passes/writeset_underapproximation_test.py +++ b/tests/passes/writeset_underapproximation_test.py @@ -3,6 +3,7 @@ from typing import Dict import dace from dace.sdfg.analysis.writeset_underapproximation import UnderapproximateWrites, UnderapproximateWritesDict +from dace.sdfg.utils import inline_control_flow_regions from dace.subsets import Range from dace.transformation.pass_pipeline import Pipeline @@ -307,6 +308,9 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -331,6 +335,9 @@ def loop(A: dace.float64[N, M]): sdfg = loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) results = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -456,6 +463,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] write_approx = result[sdfg.cfg_id].approximation @@ -491,6 +501,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -524,6 +537,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -558,6 +574,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) result: Dict[int, UnderapproximateWritesDict] = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -824,6 +843,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] @@ -854,6 +876,9 @@ def nested_loop(A: dace.float64[M, N]): sdfg = nested_loop.to_sdfg(simplify=True) + # NOTE: Until the analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + pipeline = Pipeline([UnderapproximateWrites()]) result = pipeline.apply_pass(sdfg, {})[UnderapproximateWrites.__name__] diff --git a/tests/python_frontend/augassign_wcr_test.py b/tests/python_frontend/augassign_wcr_test.py index 46c0dd8802..1931d18577 100644 --- a/tests/python_frontend/augassign_wcr_test.py +++ b/tests/python_frontend/augassign_wcr_test.py @@ -60,8 +60,8 @@ def test_augassign_wcr(): with dace.config.set_temporary('frontend', 'avoid_wcr', value=True): test_sdfg = augassign_wcr.to_sdfg(simplify=False) wcr_count = 0 - for sdfg in test_sdfg.cfg_list: - for state in sdfg.nodes(): + for sdfg in test_sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.wcr: wcr_count += 1 @@ -81,8 +81,8 @@ def test_augassign_wcr2(): with dace.config.set_temporary('frontend', 'avoid_wcr', value=True): test_sdfg = augassign_wcr2.to_sdfg(simplify=False) wcr_count = 0 - for sdfg in test_sdfg.cfg_list: - for state in sdfg.nodes(): + for sdfg in test_sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.wcr: wcr_count += 1 @@ -105,8 +105,8 @@ def test_augassign_wcr3(): with dace.config.set_temporary('frontend', 'avoid_wcr', value=True): test_sdfg = augassign_wcr3.to_sdfg(simplify=False) wcr_count = 0 - for sdfg in test_sdfg.cfg_list: - for state in sdfg.nodes(): + for sdfg in test_sdfg.all_sdfgs_recursive(): + for state in sdfg.states(): for edge in state.edges(): if edge.data.wcr: wcr_count += 1 diff --git a/tests/python_frontend/conditional_regions_test.py b/tests/python_frontend/conditional_regions_test.py index 07e214653c..6a917a13c3 100644 --- a/tests/python_frontend/conditional_regions_test.py +++ b/tests/python_frontend/conditional_regions_test.py @@ -15,7 +15,7 @@ def dataflow_if_check(A: dace.int32[10], i: dace.int64): return 10 return 100 - dataflow_if_check.use_experimental_cfg_blocks = True + dataflow_if_check.use_explicit_cf = True sdfg = dataflow_if_check.to_sdfg() assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -47,7 +47,7 @@ def nested_if_chain(i: dace.int64): else: return 4 - nested_if_chain.use_experimental_cfg_blocks = True + nested_if_chain.use_explicit_cf = True sdfg = nested_if_chain.to_sdfg() assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) @@ -74,7 +74,7 @@ def elif_chain(i: dace.int64): else: return 4 - elif_chain.use_experimental_cfg_blocks = True + elif_chain.use_explicit_cf = True sdfg = elif_chain.to_sdfg() assert any(isinstance(x, ConditionalBlock) for x in sdfg.nodes()) diff --git a/tests/python_frontend/function_regions_test.py b/tests/python_frontend/function_regions_test.py index c5c9b4ac6f..5d5082a92e 100644 --- a/tests/python_frontend/function_regions_test.py +++ b/tests/python_frontend/function_regions_test.py @@ -3,6 +3,7 @@ import numpy as np import dace from dace.sdfg.state import FunctionCallRegion +from dace.transformation.passes.simplify import SimplifyPass def test_function_call(): N = dace.symbol("N") @@ -11,9 +12,9 @@ def func(A: dace.float64[N]): @dace.program def prog(I: dace.float64[N]): return func(I) - prog.use_experimental_cfg_blocks = True - sdfg = prog.to_sdfg() - call_region: FunctionCallRegion = sdfg.nodes()[1] + sdfg = prog.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + call_region: FunctionCallRegion = sdfg.nodes()[0] assert call_region.arguments == {'A': 'I'} assert sdfg(np.array([+1], dtype=np.float64), N=1) == 15 assert sdfg(np.array([-1], dtype=np.float64), N=1) == 5 @@ -26,13 +27,13 @@ def func(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]): def prog(E: dace.float64[N], F: dace.float64[N], G: dace.float64[N]): func(A=E, B=F, C=G) func(A=G, B=E, C=E) - prog.use_experimental_cfg_blocks = True E = np.array([1]) F = np.array([2]) G = np.array([3]) - sdfg = prog.to_sdfg(E=E, F=F, G=G, N=1) - call1: FunctionCallRegion = sdfg.nodes()[1] - call2: FunctionCallRegion = sdfg.nodes()[2] + sdfg = prog.to_sdfg(E=E, F=F, G=G, N=1, simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + call1: FunctionCallRegion = sdfg.nodes()[0] + call2: FunctionCallRegion = sdfg.nodes()[1] assert call1.arguments == {'A': 'E', 'B': 'F', 'C': 'G'} assert call2.arguments == {'A': 'G', 'B': 'E', 'C': 'E'} @@ -44,10 +45,10 @@ def func(A: dace.float64[N], B: dace.float64[N], C: dace.float64[N]): def prog(): func(A=np.array([1]), B=np.array([2]), C=np.array([3])) func(A=np.array([3]), B=np.array([1]), C=np.array([1])) - prog.use_experimental_cfg_blocks = True - sdfg = prog.to_sdfg(N=1) - call1: FunctionCallRegion = sdfg.nodes()[1] - call2: FunctionCallRegion = sdfg.nodes()[2] + sdfg = prog.to_sdfg(N=1, simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + call1: FunctionCallRegion = sdfg.nodes()[0] + call2: FunctionCallRegion = sdfg.nodes()[1] assert call1.arguments == {'A': '__tmp0', 'B': '__tmp1', 'C': '__tmp2'} assert call2.arguments == {'A': '__tmp4', 'B': '__tmp5', 'C': '__tmp6'} diff --git a/tests/python_frontend/loop_regions_test.py b/tests/python_frontend/loop_regions_test.py index cb7fa30fd4..1047f770da 100644 --- a/tests/python_frontend/loop_regions_test.py +++ b/tests/python_frontend/loop_regions_test.py @@ -15,7 +15,7 @@ def for_loop(): def test_for_loop(): - for_loop.use_experimental_cfg_blocks = True + for_loop.use_explicit_cf = True sdfg = for_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -39,7 +39,7 @@ def for_loop_with_break_continue(): def test_for_loop_with_break_continue(): - for_loop_with_break_continue.use_experimental_cfg_blocks = True + for_loop_with_break_continue.use_explicit_cf = True sdfg = for_loop_with_break_continue.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -68,7 +68,7 @@ def nested_for_loop(): def test_nested_for_loop(): - nested_for_loop.use_experimental_cfg_blocks = True + nested_for_loop.use_explicit_cf = True sdfg = nested_for_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -92,7 +92,7 @@ def while_loop(): def test_while_loop(): - while_loop.use_experimental_cfg_blocks = True + while_loop.use_explicit_cf = True sdfg = while_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -118,7 +118,7 @@ def while_loop_with_break_continue(): def test_while_loop_with_break_continue(): - while_loop_with_break_continue.use_experimental_cfg_blocks = True + while_loop_with_break_continue.use_explicit_cf = True sdfg = while_loop_with_break_continue.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -151,7 +151,7 @@ def nested_while_loop(): def test_nested_while_loop(): - nested_while_loop.use_experimental_cfg_blocks = True + nested_while_loop.use_explicit_cf = True sdfg = nested_while_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -184,7 +184,7 @@ def nested_for_while_loop(): def test_nested_for_while_loop(): - nested_for_while_loop.use_experimental_cfg_blocks = True + nested_for_while_loop.use_explicit_cf = True sdfg = nested_for_while_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -217,7 +217,7 @@ def nested_while_for_loop(): def test_nested_while_for_loop(): - nested_while_for_loop.use_experimental_cfg_blocks = True + nested_while_for_loop.use_explicit_cf = True sdfg = nested_while_for_loop.to_sdfg() assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) @@ -244,7 +244,7 @@ def map_with_break_continue(): def test_map_with_break_continue(): try: - map_with_break_continue.use_experimental_cfg_blocks = True + map_with_break_continue.use_explicit_cf = True map_with_break_continue() except Exception as e: if isinstance(e, DaceSyntaxError): @@ -266,7 +266,7 @@ def test_nested_map_for_loop(): for i in range(10): for j in range(10): ref[i, j] = i * 10 + j - nested_map_for_loop.use_experimental_cfg_blocks = True + nested_map_for_loop.use_explicit_cf = True val = nested_map_for_loop() assert (np.array_equal(val, ref)) @@ -287,7 +287,7 @@ def test_nested_map_for_for_loop(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_map_for_for_loop.use_experimental_cfg_blocks = True + nested_map_for_for_loop.use_explicit_cf = True val = nested_map_for_for_loop() assert (np.array_equal(val, ref)) @@ -308,7 +308,7 @@ def test_nested_for_map_for_loop(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_for_map_for_loop.use_experimental_cfg_blocks = True + nested_for_map_for_loop.use_explicit_cf = True val = nested_for_map_for_loop() assert (np.array_equal(val, ref)) @@ -332,7 +332,7 @@ def test_nested_map_for_loop_with_tasklet(): for i in range(10): for j in range(10): ref[i, j] = i * 10 + j - nested_map_for_loop_with_tasklet.use_experimental_cfg_blocks = True + nested_map_for_loop_with_tasklet.use_explicit_cf = True val = nested_map_for_loop_with_tasklet() assert (np.array_equal(val, ref)) @@ -358,7 +358,7 @@ def test_nested_map_for_for_loop_with_tasklet(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_map_for_for_loop_with_tasklet.use_experimental_cfg_blocks = True + nested_map_for_for_loop_with_tasklet.use_explicit_cf = True val = nested_map_for_for_loop_with_tasklet() assert (np.array_equal(val, ref)) @@ -384,7 +384,7 @@ def test_nested_for_map_for_loop_with_tasklet(): for j in range(10): for k in range(10): ref[i, j, k] = i * 100 + j * 10 + k - nested_for_map_for_loop_with_tasklet.use_experimental_cfg_blocks = True + nested_for_map_for_loop_with_tasklet.use_explicit_cf = True val = nested_for_map_for_loop_with_tasklet() assert (np.array_equal(val, ref)) @@ -404,7 +404,7 @@ def test_nested_map_for_loop_2(): for i in range(10): for j in range(10): ref[i, j] = 2 + i * 10 + j - nested_map_for_loop_2.use_experimental_cfg_blocks = True + nested_map_for_loop_2.use_explicit_cf = True val = nested_map_for_loop_2(B) assert (np.array_equal(val, ref)) @@ -430,7 +430,7 @@ def test_nested_map_for_loop_with_tasklet_2(): for i in range(10): for j in range(10): ref[i, j] = 2 + i * 10 + j - nested_map_for_loop_with_tasklet_2.use_experimental_cfg_blocks = True + nested_map_for_loop_with_tasklet_2.use_explicit_cf = True val = nested_map_for_loop_with_tasklet_2(B) assert (np.array_equal(val, ref)) @@ -449,7 +449,7 @@ def test_nested_map_with_symbol(): for i in range(10): for j in range(i, 10): ref[i, j] = i * 10 + j - nested_map_with_symbol.use_experimental_cfg_blocks = True + nested_map_with_symbol.use_explicit_cf = True val = nested_map_with_symbol() assert (np.array_equal(val, ref)) @@ -477,7 +477,7 @@ def for_else(A: dace.float64[20]): for_else.f(expected_1) for_else.f(expected_2) - for_else.use_experimental_cfg_blocks = True + for_else.use_explicit_cf = True for_else(A) assert np.allclose(A, expected_1) @@ -500,7 +500,7 @@ def while_else(A: dace.float64[2]): A[1] = 1.0 A[1] = 1.0 - while_else.use_experimental_cfg_blocks = True + while_else.use_explicit_cf = True A = np.array([0.0, 0.0]) expected = np.array([5.0, 1.0]) @@ -523,7 +523,7 @@ def branch_in_for(cond: dace.int32): def test_branch_in_for(): - branch_in_for.use_experimental_cfg_blocks = True + branch_in_for.use_explicit_cf = True sdfg = branch_in_for.to_sdfg(simplify=False) assert len(sdfg.source_nodes()) == 1 @@ -540,7 +540,7 @@ def branch_in_while(cond: dace.int32): def test_branch_in_while(): - branch_in_while.use_experimental_cfg_blocks = True + branch_in_while.use_explicit_cf = True sdfg = branch_in_while.to_sdfg(simplify=False) assert len(sdfg.source_nodes()) == 1 @@ -553,7 +553,7 @@ def for_with_return(A: dace.int32[10]): return 1 return 0 - for_with_return.use_experimental_cfg_blocks = True + for_with_return.use_explicit_cf = True sdfg = for_with_return.to_sdfg() A = np.full((10,), 1).astype(np.int32) @@ -578,7 +578,7 @@ def for_while_with_return(A: dace.int32[10, 10]): j += 1 return 0 - for_while_with_return.use_experimental_cfg_blocks = True + for_while_with_return.use_explicit_cf = True sdfg = for_while_with_return.to_sdfg() A = np.full((10,10), 1).astype(np.int32) diff --git a/tests/python_frontend/multiple_nested_sdfgs_test.py b/tests/python_frontend/multiple_nested_sdfgs_test.py index fc1d9f852b..722342dbfe 100644 --- a/tests/python_frontend/multiple_nested_sdfgs_test.py +++ b/tests/python_frontend/multiple_nested_sdfgs_test.py @@ -68,8 +68,8 @@ def multiple_nested_sdfgs(input: dace.float32[2, 2], output: dace.float32[2, 2]) sdfg = multiple_nested_sdfgs.to_sdfg(simplify=False) state = None - for node in sdfg.nodes(): - if re.fullmatch(r"out_tmp_div_sum_\d+_call.*", node.label): + for node in sdfg.states(): + if re.fullmatch(r"call_out_tmp_div_sum_\d+.*", node.label): assert state is None, "Two states match the regex, cannot decide which one should be used" state = node assert state is not None diff --git a/tests/python_frontend/named_region_test.py b/tests/python_frontend/named_region_test.py index f9be206bca..593fde5c0f 100644 --- a/tests/python_frontend/named_region_test.py +++ b/tests/python_frontend/named_region_test.py @@ -3,6 +3,7 @@ import numpy as np import dace from dace.sdfg.state import NamedRegion +from dace.transformation.passes.simplify import SimplifyPass def test_named_region_no_name(): @@ -11,21 +12,21 @@ def func(A: dace.float64[1]): with dace.named: A[0] = 20 return A - func.use_experimental_cfg_blocks = True - sdfg = func.to_sdfg() - named_region = sdfg.reset_cfg_list()[1] + sdfg = func.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + named_region = sdfg.nodes()[0] assert isinstance(named_region, NamedRegion) A = np.zeros(shape=(1,)) - assert func(A) == 20 + assert sdfg(A) == 20 def test_named_region_with_name(): @dace.program def func(): with dace.named("my named region"): pass - func.use_experimental_cfg_blocks = True - sdfg = func.to_sdfg() - named_region: NamedRegion = sdfg.reset_cfg_list()[1] + sdfg = func.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + named_region: NamedRegion = sdfg.nodes()[0] assert named_region.label == "my named region" def test_nested_named_regions(): @@ -35,13 +36,13 @@ def func(): with dace.named("middle region"): with dace.named("inner region"): pass - func.use_experimental_cfg_blocks = True - sdfg = func.to_sdfg() - outer: NamedRegion = sdfg.nodes()[1] + sdfg = func.to_sdfg(simplify=False) + SimplifyPass(no_inline_function_call_regions=True, no_inline_named_regions=True).apply_pass(sdfg, {}) + outer: NamedRegion = sdfg.nodes()[0] assert outer.label == "outer region" - middle: NamedRegion = outer.nodes()[1] + middle: NamedRegion = outer.nodes()[0] assert middle.label == "middle region" - inner: NamedRegion = middle.nodes()[1] + inner: NamedRegion = middle.nodes()[0] assert inner.label == "inner region" if __name__ == "__main__": diff --git a/tests/schedule_tree/nesting_test.py b/tests/schedule_tree/nesting_test.py index 161f15d6c1..8361ecb149 100644 --- a/tests/schedule_tree/nesting_test.py +++ b/tests/schedule_tree/nesting_test.py @@ -5,6 +5,7 @@ import dace from dace.sdfg.analysis.schedule_tree import treenodes as tn from dace.sdfg.analysis.schedule_tree.sdfg_to_tree import as_schedule_tree +from dace.sdfg.utils import inline_control_flow_regions from dace.transformation.dataflow import RemoveSliceView import pytest @@ -63,7 +64,8 @@ def tester(A: dace.float64[N, N]): if simplified: assert [type(n) - for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.ForScope, tn.TaskletNode] + for n in stree.preorder_traversal()][1:] == [tn.MapScope, tn.MapScope, tn.GeneralLoopScope, + tn.TaskletNode] tasklet: tn.TaskletNode = list(stree.preorder_traversal())[-1] @@ -127,6 +129,7 @@ def tester(a: dace.float64[40], b: dace.float64[40]): nester(b[1:21], a[10:30]) sdfg = tester.to_sdfg(simplify=False) + inline_control_flow_regions(sdfg) sdfg.apply_transformations_repeated(RemoveSliceView) stree = as_schedule_tree(sdfg) @@ -150,6 +153,7 @@ def tester(a: dace.float64[40]): nester(a[1:21], a[10:30]) sdfg = tester.to_sdfg(simplify=False) + inline_control_flow_regions(sdfg) sdfg.apply_transformations_repeated(RemoveSliceView) stree = as_schedule_tree(sdfg) @@ -176,6 +180,7 @@ def tester(a: dace.float64[N, N]): nester1(a[:, 1]) sdfg = tester.to_sdfg(simplify=simplify) + inline_control_flow_regions(sdfg) stree = as_schedule_tree(sdfg) # Simplifying yields a different SDFG due to views, so testing is slightly different diff --git a/tests/schedule_tree/schedule_test.py b/tests/schedule_tree/schedule_test.py index 1bf2962cb3..295e5f6bce 100644 --- a/tests/schedule_tree/schedule_test.py +++ b/tests/schedule_tree/schedule_test.py @@ -27,14 +27,14 @@ def matmul(A: dace.float32[10, 10], B: dace.float32[10, 10], C: dace.float32[10, assert len(stree.children) == 1 # for fornode = stree.children[0] - assert isinstance(fornode, tn.ForScope) + assert isinstance(fornode, tn.GeneralLoopScope) assert len(fornode.children) == 1 # map mapnode = fornode.children[0] assert isinstance(mapnode, tn.MapScope) assert len(mapnode.children) == 2 # copy, for copynode, fornode = mapnode.children assert isinstance(copynode, tn.CopyNode) - assert isinstance(fornode, tn.ForScope) + assert isinstance(fornode, tn.GeneralLoopScope) assert len(fornode.children) == 1 # tasklet tasklet = fornode.children[0] assert isinstance(tasklet, tn.TaskletNode) @@ -80,7 +80,7 @@ def main(a: dace.float64[20, 10]): assert len(stree.children) == 4 offsets = ['', '5', '10', '15'] for fornode, offset in zip(stree.children, offsets): - assert isinstance(fornode, tn.ForScope) + assert isinstance(fornode, tn.GeneralLoopScope) assert len(fornode.children) == 1 # map mapnode = fornode.children[0] assert isinstance(mapnode, tn.MapScope) @@ -128,7 +128,7 @@ def main(a: dace.float64[20, 10]): sdfg = main.to_sdfg() stree = as_schedule_tree(sdfg) - assert isinstance(stree.children[0], tn.NView) + assert any(isinstance(v, tn.NView) for v in stree.children) def test_irreducible_sub_sdfg(): diff --git a/tests/sdfg/conditional_region_test.py b/tests/sdfg/conditional_region_test.py index 0be40f43d3..38778cba2b 100644 --- a/tests/sdfg/conditional_region_test.py +++ b/tests/sdfg/conditional_region_test.py @@ -43,6 +43,7 @@ def test_serialization(): for j in range(10): cfg = ControlFlowRegion(f'cfg_{j}', sdfg) + cfg.add_state('noop') cond_region.add_branch(CodeBlock(f'i == {j}'), cfg) assert sdfg.is_valid() diff --git a/tests/sdfg/control_flow_inline_test.py b/tests/sdfg/control_flow_inline_test.py index 87af09b9c4..3a4cfd7c13 100644 --- a/tests/sdfg/control_flow_inline_test.py +++ b/tests/sdfg/control_flow_inline_test.py @@ -19,7 +19,7 @@ def test_loop_inlining_regular_for(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong assert len(states) == 8 @@ -41,7 +41,7 @@ def test_loop_inlining_regular_while(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong guard = None @@ -75,7 +75,7 @@ def test_loop_inlining_do_while(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong guard = None @@ -115,7 +115,7 @@ def test_loop_inlining_do_for(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong guard = None @@ -175,7 +175,7 @@ def test_inline_triple_nested_for(): reduce_state.add_edge(tmpnode2, None, red, None, dace.Memlet.simple('tmp', '0:N, 0:M, 0:K')) reduce_state.add_edge(red, None, cnode, None, dace.Memlet.simple('C', '0:N, 0:M')) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) assert len(sdfg.nodes()) == 14 assert not any(isinstance(s, LoopRegion) for s in sdfg.nodes()) @@ -203,7 +203,7 @@ def test_loop_inlining_for_continue_break(): state7 = sdfg.add_state('state7') sdfg.add_edge(loop1, state7, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong assert len(states) == 12 @@ -240,7 +240,7 @@ def test_loop_inlining_multi_assignments(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) states = sdfg.nodes() # Get top-level states only, not all (.states()), in case something went wrong assert len(states) == 8 @@ -282,7 +282,7 @@ def test_loop_inlining_invalid_update_statement(): sdfg.add_edge(state0, loop1, dace.InterstateEdge()) sdfg.add_edge(loop1, state3, dace.InterstateEdge()) - sdutils.inline_loop_blocks(sdfg) + sdutils.inline_control_flow_regions(sdfg, [LoopRegion]) nodes = sdfg.nodes() assert len(nodes) == 3 diff --git a/tests/sdfg/free_symbols_test.py b/tests/sdfg/free_symbols_test.py index 3d162203d1..51cd739000 100644 --- a/tests/sdfg/free_symbols_test.py +++ b/tests/sdfg/free_symbols_test.py @@ -55,7 +55,7 @@ def test_sdfg(): sdfg: dace.SDFG = fsymtest_multistate.to_sdfg() sdfg.simplify() # Test each state separately - for state in sdfg.nodes(): + for state in sdfg.states(): assert (state.free_symbols == {'k', 'N', 'M', 'L'} or state.free_symbols == set()) # The SDFG itself should have another free symbol assert sdfg.free_symbols == {'K', 'M', 'N', 'L'} @@ -67,7 +67,7 @@ def test_constants(): sdfg.add_constant('K', 5) sdfg.add_constant('L', 20) - for state in sdfg.nodes(): + for state in sdfg.states(): assert (state.free_symbols == {'k', 'N', 'M'} or state.free_symbols == set()) assert sdfg.free_symbols == {'M', 'N'} diff --git a/tests/sdfg/loop_region_test.py b/tests/sdfg/loop_region_test.py index dedafb67ba..86b84851e2 100644 --- a/tests/sdfg/loop_region_test.py +++ b/tests/sdfg/loop_region_test.py @@ -8,7 +8,7 @@ def _make_regular_for_loop() -> SDFG: sdfg = dace.SDFG('regular_for') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10', loop_var='i', initialize_expr='i = 0', update_expr='i = i + 1', inverted=False) @@ -27,7 +27,7 @@ def _make_regular_for_loop() -> SDFG: def _make_regular_while_loop() -> SDFG: sdfg = dace.SDFG('regular_while') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10') sdfg.add_array('A', [10], dace.float32) @@ -47,7 +47,7 @@ def _make_regular_while_loop() -> SDFG: def _make_do_while_loop() -> SDFG: sdfg = dace.SDFG('do_while') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) state0 = sdfg.add_state('state0', is_start_block=True) loop1 = LoopRegion(label='loop1', condition_expr='i < 10', inverted=True) @@ -67,7 +67,7 @@ def _make_do_while_loop() -> SDFG: def _make_do_for_loop() -> SDFG: sdfg = dace.SDFG('do_for') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [10], dace.float32) state0 = sdfg.add_state('state0', is_start_block=True) @@ -88,7 +88,7 @@ def _make_do_for_loop() -> SDFG: def _make_do_for_inverted_cond_loop() -> SDFG: sdfg = dace.SDFG('do_for_inverted_cond') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) sdfg.add_array('A', [10], dace.float32) state0 = sdfg.add_state('state0', is_start_block=True) @@ -109,7 +109,7 @@ def _make_do_for_inverted_cond_loop() -> SDFG: def _make_triple_nested_for_loop() -> SDFG: sdfg = dace.SDFG('gemm') - sdfg.using_experimental_blocks = True + sdfg.using_explicit_control_flow = True sdfg.add_symbol('i', dace.int32) sdfg.add_symbol('j', dace.int32) sdfg.add_symbol('k', dace.int32) diff --git a/tests/sdfg/schedule_inference_test.py b/tests/sdfg/schedule_inference_test.py index 1b1b3422d8..8f4fcd6acb 100644 --- a/tests/sdfg/schedule_inference_test.py +++ b/tests/sdfg/schedule_inference_test.py @@ -1,6 +1,7 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Tests for default storage/schedule inference. """ import dace +from dace.sdfg.state import SDFGState from dace.sdfg.validation import InvalidSDFGNodeError from dace.sdfg.infer_types import set_default_schedule_and_storage_types from dace.transformation.helpers import get_parent_map @@ -95,10 +96,10 @@ def top(a: dace.float64[20, 20], b: dace.float64[20, 20]): sdfg = top.to_sdfg(simplify=False) set_default_schedule_and_storage_types(sdfg, None) - for node, state in sdfg.all_nodes_recursive(): - nsdfg = state.parent - if isinstance(node, dace.nodes.AccessNode): - assert node.desc(nsdfg).storage == dace.StorageType.CPU_Heap + for sd in sdfg.all_sdfgs_recursive(): + for state in sd.states(): + for dn in state.data_nodes(): + assert dn.desc(sd).storage == dace.StorageType.CPU_Heap def test_nested_storage_equivalence(): @@ -114,13 +115,13 @@ def top(a: dace.float64[20, 20] @ dace.StorageType.CPU_Heap, b: dace.float64[20, sdfg = top.to_sdfg(simplify=False) set_default_schedule_and_storage_types(sdfg, None) - for node, state in sdfg.all_nodes_recursive(): - nsdfg = state.parent - if isinstance(node, dace.nodes.AccessNode): - if state.out_degree(node) > 0: # Check for a in external and internal scopes - assert node.desc(nsdfg).storage == dace.StorageType.CPU_Heap - elif state.in_degree(node) > 0: # Check for b in external and internal scopes - assert node.desc(nsdfg).storage == dace.StorageType.CPU_Pinned + for sd in sdfg.all_sdfgs_recursive(): + for state in sd.states(): + for dn in state.data_nodes(): + if state.out_degree(dn) > 0: # Check for a in external and internal scopes + assert dn.desc(sd).storage == dace.StorageType.CPU_Heap + elif state.in_degree(dn) > 0: # Check for b in external and internal scopes + assert dn.desc(sd).storage == dace.StorageType.CPU_Pinned def test_ambiguous_schedule(): @@ -171,7 +172,6 @@ def add(a: dace.float32[10, 10] @ dace.StorageType.GPU_Global, test_gpu_schedule_autodetect() test_gpu_schedule_scalar_autodetect() test_gpu_schedule_scalar_autodetect_2() - test_nested_kernel_computation() test_nested_map_in_loop_schedule() test_nested_storage() test_nested_storage_equivalence() diff --git a/tests/sdfg/state_test.py b/tests/sdfg/state_test.py index dc8ede776f..1a8199532c 100644 --- a/tests/sdfg/state_test.py +++ b/tests/sdfg/state_test.py @@ -1,7 +1,7 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +from copy import deepcopy import dace from dace import subsets as sbs -from dace.transformation.helpers import find_sdfg_control_flow def test_read_write_set(): @@ -60,8 +60,8 @@ def double_loop(arr: dace.float32[N]): arr[i] *= 2 sdfg = double_loop.to_sdfg() - find_sdfg_control_flow(sdfg) - sdfg.validate() + copied_sdfg = deepcopy(sdfg) + copied_sdfg.validate() def test_read_and_write_set_filter(): diff --git a/tests/sdfg/work_depth_test.py b/tests/sdfg/work_depth_test.py index e677cca752..5ecda1cb88 100644 --- a/tests/sdfg/work_depth_test.py +++ b/tests/sdfg/work_depth_test.py @@ -1,4 +1,4 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ from typing import Dict, List, Tuple @@ -13,6 +13,7 @@ import sympy as sp import numpy as np +from dace.sdfg.utils import inline_control_flow_regions from dace.transformation.interstate import NestSDFG from dace.transformation.dataflow import MapExpansion @@ -192,11 +193,11 @@ def gemm_library_node_symbolic(x: dc.float64[M, K], y: dc.float64[K, N], z: dc.f 'nested_if_else': (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), 'max_of_positive_symbols': (max_of_positive_symbol, (3 * N**2, 3 * N)), 'multiple_array_sizes': (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), + 'unbounded_while_do': (unbounded_while_do, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), + 'unbounded_nonnegify': (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_8') * N, 2 * sp.Symbol('num_execs_0_8'))), 'break_for_loop': (break_for_loop, (N**2, N)), - 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + 'break_while_loop': (break_while_loop, (sp.Symbol('num_execs_0_7') * N, sp.Symbol('num_execs_0_7'))), 'sequential_ifs': (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)), 'reduction_library_node': (reduction_library_node, (456, sp.log(456))), 'reduction_library_node_symbolic': (reduction_library_node_symbolic, (N, sp.log(N))), @@ -217,6 +218,12 @@ def test_work_depth(test_name): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test.name: sdfg.apply_transformations(MapExpansion) + + # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + for sd in sdfg.all_sdfgs_recursive(): + sd.using_explicit_control_flow = False + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] # substitue each symbol without assumptions. @@ -264,6 +271,12 @@ def test_avg_par(test_name: str): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test_name: sdfg.apply_transformations(MapExpansion) + + # NOTE: Until the W/D Analysis is changed to make use of the new blocks, inline control flow for the analysis. + inline_control_flow_regions(sdfg) + for sd in sdfg.all_sdfgs_recursive(): + sd.using_explicit_control_flow = False + analyze_sdfg(sdfg, w_d_map, get_tasklet_avg_par, [], False) res = w_d_map[get_uuid(sdfg)][0] / w_d_map[get_uuid(sdfg)][1] # substitue each symbol without assumptions. @@ -320,8 +333,8 @@ def test_assumption_system_contradictions(assumptions): for test_name in work_depth_test_cases.keys(): test_work_depth(test_name) - for test, correct in tests_cases_avg_par: - test_avg_par(test, correct) + for test_name in tests_cases_avg_par.keys(): + test_avg_par(test_name) for expr, assums, res in assumptions_tests: test_assumption_system(expr, assums, res) diff --git a/tests/state_propagation_test.py b/tests/state_propagation_test.py index 226775a0e7..2984a7707a 100644 --- a/tests/state_propagation_test.py +++ b/tests/state_propagation_test.py @@ -1,10 +1,12 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. +import pytest from dace.dtypes import Language from dace.properties import CodeProperty, CodeBlock from dace.sdfg.sdfg import InterstateEdge import dace from dace.sdfg.propagation import propagate_states +from dace.transformation.passes.simplification.control_flow_raising import ControlFlowRaising def state_check_executions(state, expected, expected_dynamic=False): @@ -16,7 +18,8 @@ def state_check_executions(state, expected, expected_dynamic=False): raise RuntimeError('Expected static executions, got dynamic') -def test_conditional_fake_merge(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_conditional_fake_merge(with_regions): sdfg = dace.SDFG('fake_merge') state_init = sdfg.add_state('init') @@ -40,13 +43,17 @@ def test_conditional_fake_merge(): sdfg.add_edge(state_c, state_e, InterstateEdge(condition=CodeProperty.from_string('not (j < 10)', language=Language.Python))) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) state_check_executions(state_d, 1, True) state_check_executions(state_e, 1, True) -def test_conditional_full_merge(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_conditional_full_merge(with_regions): sdfg = dace.SDFG('conditional_full_merge') sdfg.add_scalar('a', dace.int32) @@ -71,6 +78,9 @@ def test_conditional_full_merge(): sdfg.add_edge(r_branch, if_merge_2, dace.InterstateEdge()) sdfg.add_edge(if_merge_2, if_merge_1, dace.InterstateEdge()) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. @@ -92,7 +102,8 @@ def test_conditional_full_merge(): state_check_executions(if_merge_1, 1) -def test_while_inside_for(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_while_inside_for(with_regions): sdfg = dace.SDFG('while_inside_for') sdfg.add_symbol('i', dace.int32) @@ -116,13 +127,19 @@ def test_while_inside_for(): sdfg.add_edge(guard_2, loop_2, dace.InterstateEdge(condition=CodeBlock('j < 20'))) sdfg.add_edge(loop_2, guard_2, dace.InterstateEdge()) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # Check the for loop guard, `i in range(20)`. - state_check_executions(guard_1, 21) + if with_regions: + state_check_executions(guard_1, 20) + else: + state_check_executions(guard_1, 21) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. @@ -136,7 +153,8 @@ def test_while_inside_for(): state_check_executions(loop_2, 0, expected_dynamic=True) -def test_for_with_nested_full_merge_branch(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_for_with_nested_full_merge_branch(with_regions): sdfg = dace.SDFG('for_full_merge') sdfg.add_symbol('i', dace.int32) @@ -171,13 +189,19 @@ def test_for_with_nested_full_merge_branch(): sdfg.add_edge(r_branch, if_merge, dace.InterstateEdge()) sdfg.add_edge(if_merge, guard_1, dace.InterstateEdge(assignments={'i': 'i + 1'})) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # For loop, check loop guard, `for i in range(20)`. - state_check_executions(guard_1, 21) + if with_regions: + state_check_executions(guard_1, 20) + else: + state_check_executions(guard_1, 21) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. @@ -190,7 +214,8 @@ def test_for_with_nested_full_merge_branch(): state_check_executions(if_merge, 20) -def test_for_inside_branch(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_for_inside_branch(with_regions): sdfg = dace.SDFG('for_in_branch') state_init = sdfg.add_state('init') @@ -218,15 +243,22 @@ def test_for_inside_branch(): 'j': 'j + 1', })) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) state_check_executions(branch_guard, 1, False) - state_check_executions(loop_guard, 11, True) + if with_regions: + state_check_executions(loop_guard, 10, True) + else: + state_check_executions(loop_guard, 11, True) state_check_executions(loop_state, 10, True) state_check_executions(branch_merge, 1, False) -def test_full_merge_inside_loop(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_full_merge_inside_loop(with_regions): sdfg = dace.SDFG('full_merge_inside_loop') state_init = sdfg.add_state('init') @@ -256,16 +288,23 @@ def test_full_merge_inside_loop(): 'i': 'i + 1', })) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) - state_check_executions(loop_guard, 11, False) + if with_regions: + state_check_executions(loop_guard, 10, False) + else: + state_check_executions(loop_guard, 11, False) state_check_executions(branch_guard, 10, False) state_check_executions(branch_state, 10, True) state_check_executions(branch_merge, 10, False) state_check_executions(loop_end, 1, False) -def test_while_with_nested_full_merge_branch(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_while_with_nested_full_merge_branch(with_regions): sdfg = dace.SDFG('while_full_merge') sdfg.add_scalar('a', dace.int32) @@ -299,6 +338,9 @@ def test_while_with_nested_full_merge_branch(): sdfg.add_edge(r_branch, if_merge, dace.InterstateEdge()) sdfg.add_edge(if_merge, guard_1, dace.InterstateEdge()) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. @@ -318,7 +360,8 @@ def test_while_with_nested_full_merge_branch(): state_check_executions(if_merge, 0, expected_dynamic=True) -def test_3_fold_nested_loop_with_symbolic_bounds(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_3_fold_nested_loop_with_symbolic_bounds(with_regions): N = dace.symbol('N') M = dace.symbol('M') K = dace.symbol('K') @@ -355,34 +398,47 @@ def test_3_fold_nested_loop_with_symbolic_bounds(): sdfg.add_edge(guard_3, loop_3, dace.InterstateEdge(condition=CodeBlock('k < K'))) sdfg.add_edge(loop_3, guard_3, dace.InterstateEdge(assignments={'k': 'k + 1'})) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # 1st level loop, check loop guard, `for i in range(N)`. - state_check_executions(guard_1, N + 1) + if with_regions: + state_check_executions(guard_1, N) + else: + state_check_executions(guard_1, N + 1) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. state_check_executions(loop_1, N) # 2nd level nested loop, check loog guard, `for j in range(M)`. - state_check_executions(guard_2, M * N + N) + if with_regions: + state_check_executions(guard_2, M * N) + else: + state_check_executions(guard_2, M * N + N) # Check loop-end branch. state_check_executions(end_2, N) # Check inside the loop. state_check_executions(loop_2, M * N) # 3rd level nested loop, check loop guard, `for k in range(K)`. - state_check_executions(guard_3, M * N * K + M * N) + if with_regions: + state_check_executions(guard_3, M * N * K) + else: + state_check_executions(guard_3, M * N * K + M * N) # Check loop-end branch. state_check_executions(end_3, M * N) # Check inside the loop. state_check_executions(loop_3, M * N * K) -def test_3_fold_nested_loop(): +@pytest.mark.parametrize('with_regions', [False, True]) +def test_3_fold_nested_loop(with_regions): sdfg = dace.SDFG('nest_3') sdfg.add_symbol('i', dace.int32) @@ -415,27 +471,40 @@ def test_3_fold_nested_loop(): sdfg.add_edge(guard_3, loop_3, dace.InterstateEdge(condition=CodeBlock('k < j'))) sdfg.add_edge(loop_3, guard_3, dace.InterstateEdge(assignments={'k': 'k + 1'})) + if with_regions: + ControlFlowRaising().apply_pass(sdfg, {}) + propagate_states(sdfg) # Check start state. state_check_executions(init_state, 1) # 1st level loop, check loop guard, `for i in range(20)`. - state_check_executions(guard_1, 21) + if with_regions: + state_check_executions(guard_1, 20) + else: + # When using a state-machine-style loop, the guard is executed N+1 times for N loop iterations. + state_check_executions(guard_1, 21) # Check loop-end branch. state_check_executions(end_1, 1) # Check inside the loop. state_check_executions(loop_1, 20) # 2nd level nested loop, check loog guard, `for j in range(i, 20)`. - state_check_executions(guard_2, 230) + if with_regions: + state_check_executions(guard_2, 210) + else: + state_check_executions(guard_2, 230) # Check loop-end branch. state_check_executions(end_2, 20) # Check inside the loop. state_check_executions(loop_2, 210) # 3rd level nested loop, check loop guard, `for k in range(i, j)`. - state_check_executions(guard_3, 1540) + if with_regions: + state_check_executions(guard_3, 1330) + else: + state_check_executions(guard_3, 1540) # Check loop-end branch. state_check_executions(end_3, 210) # Check inside the loop. @@ -443,12 +512,21 @@ def test_3_fold_nested_loop(): if __name__ == "__main__": - test_3_fold_nested_loop() - test_3_fold_nested_loop_with_symbolic_bounds() - test_while_with_nested_full_merge_branch() - test_for_with_nested_full_merge_branch() - test_for_inside_branch() - test_while_inside_for() - test_conditional_full_merge() - test_conditional_fake_merge() - test_full_merge_inside_loop() + test_3_fold_nested_loop(False) + test_3_fold_nested_loop_with_symbolic_bounds(False) + test_while_with_nested_full_merge_branch(False) + test_for_with_nested_full_merge_branch(False) + test_for_inside_branch(False) + test_while_inside_for(False) + test_conditional_full_merge(False) + test_conditional_fake_merge(False) + test_full_merge_inside_loop(False) + test_3_fold_nested_loop(True) + test_3_fold_nested_loop_with_symbolic_bounds(True) + test_while_with_nested_full_merge_branch(True) + test_for_with_nested_full_merge_branch(True) + test_for_inside_branch(True) + test_while_inside_for(True) + test_conditional_full_merge(True) + test_conditional_fake_merge(True) + test_full_merge_inside_loop(True) diff --git a/tests/transformations/interstate/loop_lifting_test.py b/tests/transformations/interstate/loop_lifting_test.py index 20f244621c..676512f5f6 100644 --- a/tests/transformations/interstate/loop_lifting_test.py +++ b/tests/transformations/interstate/loop_lifting_test.py @@ -45,7 +45,7 @@ def test_lift_regular_for_loop(): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) @@ -101,7 +101,7 @@ def test_lift_loop_llvm_canonical(increment_before_condition): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) @@ -158,7 +158,7 @@ def test_lift_loop_llvm_canonical_while(): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) @@ -201,7 +201,7 @@ def test_do_while(): sdfg(A=A_valid, N=N) sdfg.apply_transformations_repeated([LoopLifting]) - assert sdfg.using_experimental_blocks == True + assert sdfg.using_explicit_control_flow == True assert any(isinstance(x, LoopRegion) for x in sdfg.nodes()) sdfg(A=A, N=N) diff --git a/tests/transformations/loop_detection_test.py b/tests/transformations/loop_detection_test.py index 323a27787a..27ab7a3660 100644 --- a/tests/transformations/loop_detection_test.py +++ b/tests/transformations/loop_detection_test.py @@ -19,7 +19,8 @@ def tester(a: dace.float64[20]): for i in range(1, 20): a[i] = a[i - 1] + 1 - sdfg = tester.to_sdfg() + tester.use_explicit_cf = False + sdfg = tester.to_sdfg(simplify=False) xform = CountLoops() assert sdfg.apply_transformations(xform) == 1 itvar, rng, _ = xform.loop_information() diff --git a/tests/transformations/loop_manipulation_test.py b/tests/transformations/loop_manipulation_test.py index dbeed91464..7d87e1d2b9 100644 --- a/tests/transformations/loop_manipulation_test.py +++ b/tests/transformations/loop_manipulation_test.py @@ -27,9 +27,9 @@ def regression(A, B): def test_unroll(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 1 sdfg.apply_transformations(LoopUnroll) - assert len(sdfg.nodes()) == (5 + 2) + assert len(sdfg.nodes()) == 5 * 2 sdfg.simplify() assert len(sdfg.nodes()) == 1 A = np.random.rand(20) @@ -47,11 +47,9 @@ def test_unroll(): def test_peeling_start(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 1 sdfg.apply_transformations(LoopPeeling, dict(count=2)) - assert len(sdfg.nodes()) == 6 - sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 3 A = np.random.rand(20) B = np.random.rand(20) reg = regression(A, B) @@ -67,11 +65,9 @@ def test_peeling_start(): def test_peeling_end(): sdfg: dace.SDFG = tounroll.to_sdfg() sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 1 sdfg.apply_transformations(LoopPeeling, dict(count=2, begin=False)) - assert len(sdfg.nodes()) == 6 - sdfg.simplify() - assert len(sdfg.nodes()) == 4 + assert len(sdfg.nodes()) == 3 A = np.random.rand(20) B = np.random.rand(20) reg = regression(A, B) diff --git a/tests/transformations/loop_to_map_test.py b/tests/transformations/loop_to_map_test.py index 2cab97da78..5f4b5c66f9 100644 --- a/tests/transformations/loop_to_map_test.py +++ b/tests/transformations/loop_to_map_test.py @@ -3,15 +3,15 @@ import copy import os import tempfile -from typing import Tuple import numpy as np import pytest import dace -from dace.sdfg import nodes, propagation +from dace.sdfg import nodes +from dace.sdfg.state import LoopRegion from dace.transformation.interstate import LoopToMap, StateFusion -from dace.transformation.interstate.loop_detection import DetectLoop +from dace.transformation.interstate.loop_lifting import LoopLifting def make_sdfg(with_wcr, map_in_guard, reverse_loop, use_variable, assign_after, log_path): @@ -87,6 +87,8 @@ def make_sdfg(with_wcr, map_in_guard, reverse_loop, use_variable, assign_after, post_tasklet = post.add_tasklet("post", {}, {"e"}, "e = i" if use_variable else "e = N") post.add_memlet_path(post_tasklet, e, src_conn="e", memlet=dace.Memlet("E[0]")) + sdfg.apply_transformations_repeated([LoopLifting]) + return sdfg @@ -285,6 +287,7 @@ def test_interstate_dep(): ref = np.random.randint(0, 10, size=(10, ), dtype=np.int32) val = np.copy(ref) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg(A=ref) assert sdfg.apply_transformations(LoopToMap) == 0 @@ -294,7 +297,8 @@ def test_interstate_dep(): def test_need_for_tasklet(): - + # Note: Since the introduction of loop regions this no longer requires a tasklet, as the nested SDFG is directly + # equivalent to the loop region, including all direct access node to access node copy operations. sdfg = dace.SDFG('needs_tasklet') aname, _ = sdfg.add_array('A', (10, ), dace.int32) bname, _ = sdfg.add_array('B', (10, ), dace.int32) @@ -304,14 +308,8 @@ def test_need_for_tasklet(): bnode = body.add_access(bname) body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='i', other_subset='9 - i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) - found = False - for n, s in sdfg.all_nodes_recursive(): - if isinstance(n, nodes.Tasklet): - found = True - break - - assert found A = np.arange(10, dtype=np.int32) B = np.empty((10, ), dtype=np.int32) @@ -321,7 +319,8 @@ def test_need_for_tasklet(): def test_need_for_transient(): - + # Note: Since the introduction of loop regions this no longer requires a transient, as the nested SDFG is directly + # equivalent to the loop region, including all direct access node to access node copy operations. sdfg = dace.SDFG('needs_transient') aname, _ = sdfg.add_array('A', (10, 10), dace.int32) bname, _ = sdfg.add_array('B', (10, 10), dace.int32) @@ -331,14 +330,8 @@ def test_need_for_transient(): bnode = body.add_access(bname) body.add_nedge(anode, bnode, dace.Memlet(data=aname, subset='0:10, i', other_subset='0:10, 9 - i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) - found = False - for n, s in sdfg.all_nodes_recursive(): - if isinstance(n, nodes.AccessNode) and n.data not in (aname, bname): - found = True - break - - assert found A = np.arange(100, dtype=np.int32).reshape(10, 10).copy() B = np.empty((10, 10), dtype=np.int32) @@ -403,6 +396,7 @@ def test_symbol_write_before_read(): sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(j='0'))) sdfg.add_edge(body, body_end, dace.InterstateEdge(assignments=dict(j='j + 1'))) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == 1 @@ -430,6 +424,7 @@ def test_symbol_array_mix(overwrite): sdfg.add_edge(body_start, body, dace.InterstateEdge(assignments=dict(sym='sym + tmp'))) sdfg.add_edge(body, body_end, dace.InterstateEdge(assignments=dict(sym='sym + 1.0'))) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == (1 if overwrite else 0) @@ -456,6 +451,7 @@ def test_symbol_array_mix_2(parallel): t = body_start.add_tasklet('use', {}, {'o'}, 'o = sym') body_start.add_edge(t, 'o', body_start.add_write('B'), None, dace.Memlet('B[i]')) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == (1 if parallel else 0) @@ -482,6 +478,7 @@ def test_internal_symbol_used_outside(overwrite): else: sdfg.add_edge(after, after_1, dace.InterstateEdge()) + sdfg.apply_transformations_repeated([LoopLifting]) assert sdfg.apply_transformations(LoopToMap) == (1 if overwrite else 0) @@ -511,6 +508,7 @@ def test_shared_local_transient_single_state(): body.add_edge(t1, '__out', anode, None, dace.Memlet(data='A', subset='i')) body.add_edge(anode, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert 'A' in sdfg.arrays @@ -550,6 +548,7 @@ def test_thread_local_transient_single_state(): body.add_edge(anode, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert not ('A' in sdfg.arrays) @@ -588,6 +587,7 @@ def test_shared_local_transient_multi_state(): body1.add_edge(anode1, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body1.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert 'A' in sdfg.arrays @@ -629,6 +629,7 @@ def test_thread_local_transient_multi_state(): body1.add_edge(anode1, None, t2, '__inp', dace.Memlet(data='A', subset='i')) body1.add_edge(t2, '__out', bnode, None, dace.Memlet(data='__return', subset='i')) + sdfg.apply_transformations_repeated([LoopLifting]) sdfg.apply_transformations_repeated(LoopToMap) assert not ('A' in sdfg.arrays) @@ -651,34 +652,25 @@ def nested_loops(A: dace.int32[10, 10, 10], l: dace.int32): sdfg = nested_loops.to_sdfg() - def find_loop(sdfg: dace.SDFG, itervar: str) -> Tuple[dace.SDFGState, dace.SDFGState, dace.SDFGState]: - - guard, begin, fexit = None, None, None - for e in sdfg.edges(): - if itervar in e.data.assignments and e.data.assignments[itervar] == '0': - guard = e.dst - elif e.data.condition.as_string in (f'({itervar} >= 10)', f'(not ({itervar} < 10))'): - fexit = e.dst - assert all(s is not None for s in (guard, fexit)) - - begin = next((e for e in sdfg.out_edges(guard) if e.dst != fexit)).dst - - return guard, begin, fexit + def find_loop(sdfg: dace.SDFG, itervar: str) -> LoopRegion: + for cfg in sdfg.all_control_flow_regions(): + if isinstance(cfg, LoopRegion) and cfg.loop_variable == itervar: + return cfg sdfg0 = copy.deepcopy(sdfg) - i_guard, i_begin, i_exit = find_loop(sdfg0, 'i') - LoopToMap.apply_to(sdfg0, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit) + i_loop = find_loop(sdfg0, 'i') + LoopToMap.apply_to(sdfg0, loop=i_loop) nsdfg = next((sd for sd in sdfg0.all_sdfgs_recursive() if sd.parent is not None)) - j_guard, j_begin, j_exit = find_loop(nsdfg, 'j') - LoopToMap.apply_to(nsdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit) + j_loop = find_loop(nsdfg, 'j') + LoopToMap.apply_to(nsdfg, loop=j_loop) val = np.arange(1000, dtype=np.int32).reshape(10, 10, 10).copy() sdfg(A=val, l=5) assert np.allclose(ref, val) - j_guard, j_begin, j_exit = find_loop(sdfg, 'j') - LoopToMap.apply_to(sdfg, loop_guard=j_guard, loop_begin=j_begin, exit_state=j_exit) + j_loop = find_loop(sdfg, 'j') + LoopToMap.apply_to(sdfg, loop=j_loop) # NOTE: The following fails to apply because of subset A[0:i+1], which is overapproximated. # i_guard, i_begin, i_exit = find_loop(sdfg, 'i') # LoopToMap.apply_to(sdfg, loop_guard=i_guard, loop_begin=i_begin, exit_state=i_exit) @@ -718,7 +710,7 @@ def internal_write(inp0: dace.int32[10], inp1: dace.int32[10], out: dace.int32[1 val = np.empty((10, ), dtype=np.int32) internal_write.f(inp0, inp1, ref) - internal_write(inp0, inp1, val) + sdfg(inp0, inp1, val) assert np.array_equal(val, ref) @@ -741,14 +733,15 @@ def test_rotated_loop_to_map(simplify): sdfg.add_edge(guard, exitstate, dace.InterstateEdge('N <= 0')) sdfg.add_edge(guard, preheader, dace.InterstateEdge('N > 0')) sdfg.add_edge(preheader, body, dace.InterstateEdge(assignments=dict(i=0))) - sdfg.add_edge(body, latch, dace.InterstateEdge()) - sdfg.add_edge(latch, body, dace.InterstateEdge('i < N', assignments=dict(i='i + 1'))) + sdfg.add_edge(body, latch, dace.InterstateEdge(assignments=dict(i='i + 1'))) + sdfg.add_edge(latch, body, dace.InterstateEdge('i < N')) sdfg.add_edge(latch, loopexit, dace.InterstateEdge('i >= N')) sdfg.add_edge(loopexit, exitstate, dace.InterstateEdge()) t = body.add_tasklet('addone', {'inp'}, {'out'}, 'out = inp + 1') body.add_edge(body.add_read('A'), None, t, 'inp', dace.Memlet('A[i]')) body.add_edge(t, 'out', body.add_write('A'), None, dace.Memlet('A[i]')) + sdfg.apply_transformations_repeated([LoopLifting]) if simplify: sdfg.apply_transformations_repeated(StateFusion) @@ -779,6 +772,8 @@ def test_self_loop_to_map(): body.add_edge(body.add_read('A'), None, t, 'inp', dace.Memlet('A[i]')) body.add_edge(t, 'out', body.add_write('A'), None, dace.Memlet('A[i]')) + sdfg.apply_transformations_repeated([LoopLifting]) + assert sdfg.apply_transformations_repeated(LoopToMap) == 1 a = np.random.rand(20) diff --git a/tests/transformations/move_assignment_outside_if_test.py b/tests/transformations/move_assignment_outside_if_test.py index 270fd8f842..13725738e7 100644 --- a/tests/transformations/move_assignment_outside_if_test.py +++ b/tests/transformations/move_assignment_outside_if_test.py @@ -1,5 +1,7 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + import dace +from dace.sdfg.state import ConditionalBlock from dace.transformation.interstate import MoveAssignmentOutsideIf from dace.sdfg import InterstateEdge from dace.memlet import Memlet @@ -35,18 +37,23 @@ def one_variable_simple_test(const_value: int = 0): # Create if-else condition such that either the formula state or the const state is executed sdfg.add_edge(guard, formula_state, InterstateEdge(condition='B[0] < 0.5')) sdfg.add_edge(guard, const_state, InterstateEdge(condition='B[0] >= 0.5')) + sdfg.simplify() sdfg.validate() # Assure transformation is applied assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + sdfg.simplify() # SDFG now starts with a state containing the const_tasklet - assert const_tasklet in sdfg.start_state.nodes() - # The formula state has only one in_edge with the condition - assert len(sdfg.in_edges(formula_state)) == 1 - assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(B[0] < 0.5)' - # All state have at most one out_edge -> there is no if-else branching anymore - for state in sdfg.states(): - assert len(sdfg.out_edges(state)) <= 1 + assert const_tasklet in sdfg.start_block.nodes() + # There should now only be one conditional branch remaining in the entire SDFG. + conditional = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + assert conditional is not None + assert len(conditional.branches) == 1 + assert conditional.branches[0][0].as_string == '(B[0] < 0.5)' def multiple_variable_test(): @@ -89,21 +96,26 @@ def multiple_variable_test(): # Create if-else condition such that either the formula state or the const state is executed sdfg.add_edge(guard, formula_state, InterstateEdge(condition='D[0] < 0.5')) sdfg.add_edge(guard, const_state, InterstateEdge(condition='D[0] >= 0.5')) + sdfg.simplify() sdfg.validate() # Assure transformation is applied assert sdfg.apply_transformations_repeated([MoveAssignmentOutsideIf]) == 1 + sdfg.simplify() # There are no other tasklets in the start state beside the const assignment tasklet as there are no other const # assignments - for node in sdfg.start_state.nodes(): + for node in sdfg.start_block.nodes(): if isinstance(node, Tasklet): assert node == const_tasklet_a or node == const_tasklet_b - # The formula state has only one in_edge with the condition - assert len(sdfg.in_edges(formula_state)) == 1 - assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(D[0] < 0.5)' - # All state have at most one out_edge -> there is no if-else branching anymore - for state in sdfg.states(): - assert len(sdfg.out_edges(state)) <= 1 + # There should now only be one conditional branch remaining in the entire SDFG. + conditional = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + assert conditional is not None + assert len(conditional.branches) == 1 + assert conditional.branches[0][0].as_string == '(D[0] < 0.5)' def multiple_variable_not_all_const_test(): @@ -145,6 +157,7 @@ def multiple_variable_not_all_const_test(): # Create if-else condition such that either the formula state or the const state is executed sdfg.add_edge(guard, formula_state, InterstateEdge(condition='C[0] < 0.5')) sdfg.add_edge(guard, const_state, InterstateEdge(condition='C[0] >= 0.5')) + sdfg.simplify() sdfg.validate() # Assure transformation is applied @@ -154,24 +167,18 @@ def multiple_variable_not_all_const_test(): for node in sdfg.start_state.nodes(): if isinstance(node, Tasklet): assert node == const_tasklet_a or node == const_tasklet_b - # The formula state has only one in_edge with the condition - assert len(sdfg.in_edges(formula_state)) == 1 - assert sdfg.in_edges(formula_state)[0].data.condition.as_string == '(C[0] < 0.5)' - # Guard still has two outgoing edges as if-else pattern still exists - assert len(sdfg.out_edges(guard)) == 2 - # const state now has only const_tasklet_b left plus two access nodes - assert len(const_state.nodes()) == 3 - for node in const_state.nodes(): - if isinstance(node, Tasklet): - assert node == const_tasklet_b + # The conditional should still have two conditional branches + conditional = None + for n in sdfg.nodes(): + if isinstance(n, ConditionalBlock): + conditional = n + break + assert conditional is not None + assert len(conditional.branches) == 2 -def main(): +if __name__ == '__main__': one_variable_simple_test(0) one_variable_simple_test(2) multiple_variable_test() multiple_variable_not_all_const_test() - - -if __name__ == '__main__': - main() diff --git a/tests/transformations/move_loop_into_map_test.py b/tests/transformations/move_loop_into_map_test.py index ad51941cb0..70960f8239 100644 --- a/tests/transformations/move_loop_into_map_test.py +++ b/tests/transformations/move_loop_into_map_test.py @@ -1,7 +1,8 @@ -# Copyright 2019-2021 ETH Zurich and the DaCe authors. All rights reserved. +# Copyright 2019-2024 ETH Zurich and the DaCe authors. All rights reserved. + import dace +from dace.sdfg.state import LoopRegion from dace.transformation.interstate import MoveLoopIntoMap -import unittest import copy import numpy as np @@ -70,261 +71,271 @@ def apply_multiple_times_1(A: dace.float64[10, 10, 10, 10]): A[k, i, j, l] = k * 1000 + i * 100 + j * 10 + l -class MoveLoopIntoMapTest(unittest.TestCase): - - def semantic_eq(self, program): - A1 = np.random.rand(16, 16) - A2 = np.copy(A1) - - sdfg = program.to_sdfg(simplify=True) - sdfg(A1, I=A1.shape[0], J=A1.shape[1]) - - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertGreater(count, 0) - sdfg(A2, I=A2.shape[0], J=A2.shape[1]) - - self.assertTrue(np.allclose(A1, A2)) - - def test_forward_loops_semantic_eq(self): - self.semantic_eq(forward_loop) - - def test_backward_loops_semantic_eq(self): - self.semantic_eq(backward_loop) - - def test_multiple_edges(self): - self.semantic_eq(multiple_edges) - - def test_itervar_in_map_range(self): - sdfg = should_not_apply_1.to_sdfg(simplify=True) - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEqual(count, 0) - - def test_itervar_in_data(self): - sdfg = should_not_apply_2.to_sdfg(simplify=True) - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEqual(count, 0) - - def test_non_injective_index(self): - sdfg = should_not_apply_3.to_sdfg(simplify=True) - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertEqual(count, 0) - - def test_apply_multiple_times(self): - sdfg = apply_multiple_times.to_sdfg(simplify=True) - overall = 0 - count = 1 - while (count > 0): - count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) - overall += count - sdfg.simplify() - - self.assertEqual(overall, 2) - - val = np.zeros((10, 10, 10), dtype=np.float64) - ref = val.copy() - - sdfg(A=val) - apply_multiple_times.f(ref) - - self.assertTrue(np.allclose(val, ref)) - - def test_apply_multiple_times_1(self): - sdfg = apply_multiple_times_1.to_sdfg(simplify=True) - overall = 0 - count = 1 - while (count > 0): - count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) - overall += count - sdfg.simplify() - - self.assertEqual(overall, 2) - - val = np.zeros((10, 10, 10, 10), dtype=np.float64) - ref = val.copy() - - sdfg(A=val) - apply_multiple_times_1.f(ref) - - self.assertTrue(np.allclose(val, ref)) - - def test_more_than_a_map(self): - """ - `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. - - Note that there is actually no dependency, however, the transformation, because it relies - on `SDFGState.read_and_write_sets()` it can not detect this and can thus not be applied. - """ - sdfg = dace.SDFG('more_than_a_map') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - oread = body.add_access('out') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(out=oread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(aread, oread, dace.Memlet.from_array('A', oarr)) - body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) - sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - - count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) - self.assertTrue(count == 0) - - def test_more_than_a_map_1(self): - """ - `out` is written indirectly by the MapExit but is not read and, therefore, does not create a RW dependency. - """ - sdfg = dace.SDFG('more_than_a_map_1') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(A=aread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) - sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertTrue(count > 0) - - A = np.arange(9, dtype=np.float64).reshape(3, 3).copy() - B = np.arange(9, 18, dtype=np.float64).reshape(3, 3).copy() - val = np.empty((3, 3), dtype=np.float64) - sdfg(A=A, B=B, out=val) - - def reference(A, B): - for i in range(10): - tmp = A - B - out = tmp - return out - - ref = reference(A, B) - self.assertTrue(np.allclose(val, ref)) - - def test_more_than_a_map_2(self): - """ `out` is written indirectly by the MapExit with a subset dependent on the loop variable. This creates a RW - dependency. - """ - sdfg = dace.SDFG('more_than_a_map_2') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(A=aread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(twrite, owrite, dace.Memlet('out[k%3, (k+1)%3]', other_subset='(k+1)%3, k%3')) - sdfg.add_loop(None, body, None, 'k', '0', 'k < 10', 'k + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) - - def test_more_than_a_map_3(self): - """ There are more than one connected components in the loop body. The transformation should not apply. """ - sdfg = dace.SDFG('more_than_a_map_3') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(A=aread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) - aread2 = body.add_access('A') - owrite2 = body.add_access('out') - body.add_nedge(aread2, owrite2, dace.Memlet.from_array('out', oarr)) - sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - count = sdfg.apply_transformations(MoveLoopIntoMap) - self.assertFalse(count > 0) - - def test_more_than_a_map_4(self): - """ - The test is very similar to `test_more_than_a_map()`. But a memlet is different - which leads to a RW dependency, which blocks the transformation. - """ - sdfg = dace.SDFG('more_than_a_map') - _, aarr = sdfg.add_array('A', (3, 3), dace.float64) - _, barr = sdfg.add_array('B', (3, 3), dace.float64) - _, oarr = sdfg.add_array('out', (3, 3), dace.float64) - _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) - body = sdfg.add_state('map_state') - aread = body.add_access('A') - oread = body.add_access('out') - bread = body.add_access('B') - twrite = body.add_access('tmp') - owrite = body.add_access('out') - body.add_mapped_tasklet('op', - dict(i='0:3', j='0:3'), - dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), - '__out = __in1 - __in2', - dict(__out=dace.Memlet('tmp[i, j]')), - external_edges=True, - input_nodes=dict(out=oread, B=bread), - output_nodes=dict(tmp=twrite)) - body.add_nedge(aread, oread, dace.Memlet('A[Mod(_, 3), 0:3] -> [Mod(_ + 1, 3), 0:3]', aarr)) - body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) - sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') - - sdfg_args_ref = { - "A": np.array(np.random.rand(3, 3), dtype=np.float64), - "B": np.array(np.random.rand(3, 3), dtype=np.float64), - "out": np.array(np.random.rand(3, 3), dtype=np.float64), - } - sdfg_args_res = copy.deepcopy(sdfg_args_ref) - - # Perform the reference execution - sdfg(**sdfg_args_ref) - - # Apply the transformation and execute the SDFG again. - count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) - sdfg(**sdfg_args_res) - - for name in sdfg_args_ref.keys(): - self.assertTrue( - np.allclose(sdfg_args_ref[name], sdfg_args_res[name]), - f"Miss match for {name}", - ) - self.assertFalse(count > 0) +def _semantic_eq(program): + A1 = np.random.rand(16, 16) + A2 = np.copy(A1) + + sdfg = program.to_sdfg(simplify=True) + sdfg(A1, I=A1.shape[0], J=A1.shape[1]) + + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count > 0 + sdfg(A2, I=A2.shape[0], J=A2.shape[1]) + + assert np.allclose(A1, A2) + +def test_forward_loops_semantic_eq(): + _semantic_eq(forward_loop) + +def test_backward_loops_semantic_eq(): + _semantic_eq(backward_loop) + +def test_multiple_edges(): + _semantic_eq(multiple_edges) + +def test_itervar_in_map_range(): + sdfg = should_not_apply_1.to_sdfg(simplify=True) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_itervar_in_data(): + sdfg = should_not_apply_2.to_sdfg(simplify=True) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_non_injective_index(): + sdfg = should_not_apply_3.to_sdfg(simplify=True) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_apply_multiple_times(): + sdfg = apply_multiple_times.to_sdfg(simplify=True) + overall = 0 + count = 1 + while (count > 0): + count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) + overall += count + sdfg.simplify() + + assert overall == 2 + + val = np.zeros((10, 10, 10), dtype=np.float64) + ref = val.copy() + + sdfg(A=val) + apply_multiple_times.f(ref) + + assert np.allclose(val, ref) + +def test_apply_multiple_times_1(): + sdfg = apply_multiple_times_1.to_sdfg(simplify=True) + overall = 0 + count = 1 + while (count > 0): + count = sdfg.apply_transformations_repeated(MoveLoopIntoMap, permissive=True) + overall += count + sdfg.simplify() + + assert overall == 2 + + val = np.zeros((10, 10, 10, 10), dtype=np.float64) + ref = val.copy() + + sdfg(A=val) + apply_multiple_times_1.f(ref) + + assert np.allclose(val, ref) + +def test_more_than_a_map(): + """ `out` is read and written indirectly by the MapExit, potentially leading to a RW dependency. + + Note that there is actually no dependency, however, the transformation, because it relies + on `SDFGState.read_and_write_sets()` it can not detect this and can thus not be applied. + """ + sdfg = dace.SDFG('more_than_a_map') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', '_ < 10', '_', '_ = 0', '_ = _ + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + oread = body.add_access('out') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(out=oread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(aread, oread, dace.Memlet.from_array('A', aarr)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_more_than_a_map_1(): + """ + `out` is written indirectly by the MapExit but is not read and, therefore, does not create a RW dependency. + """ + sdfg = dace.SDFG('more_than_a_map_1') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', '_ < 10', '_', '_ = 0', '_ = _ + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(A=aread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count > 0 + + A = np.arange(9, dtype=np.float64).reshape(3, 3).copy() + B = np.arange(9, 18, dtype=np.float64).reshape(3, 3).copy() + val = np.empty((3, 3), dtype=np.float64) + sdfg(A=A, B=B, out=val) + + def reference(A, B): + for i in range(10): + tmp = A - B + out = tmp + return out + + ref = reference(A, B) + assert np.allclose(val, ref) + +def test_more_than_a_map_2(): + """ `out` is written indirectly by the MapExit with a subset dependent on the loop variable. This creates a RW + dependency. + """ + sdfg = dace.SDFG('more_than_a_map_2') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', 'k < 10', 'k', 'k = 0', 'k = k + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(A=aread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(twrite, owrite, dace.Memlet('out[k%3, (k+1)%3]', other_subset='(k+1)%3, k%3')) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_more_than_a_map_3(): + """ There are more than one connected components in the loop body. The transformation should not apply. """ + sdfg = dace.SDFG('more_than_a_map_3') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + loop = LoopRegion('myloop', '_ < 10', '_', '_ = 0', '_ = _ + 1') + sdfg.add_node(loop) + body = loop.add_state('map_state') + aread = body.add_access('A') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('A[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(A=aread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + aread2 = body.add_access('A') + owrite2 = body.add_access('out') + body.add_nedge(aread2, owrite2, dace.Memlet.from_array('out', oarr)) + count = sdfg.apply_transformations(MoveLoopIntoMap) + assert count == 0 + +def test_more_than_a_map_4(): + """ + The test is very similar to `test_more_than_a_map()`. But a memlet is different + which leads to a RW dependency, which blocks the transformation. + """ + sdfg = dace.SDFG('more_than_a_map') + _, aarr = sdfg.add_array('A', (3, 3), dace.float64) + _, barr = sdfg.add_array('B', (3, 3), dace.float64) + _, oarr = sdfg.add_array('out', (3, 3), dace.float64) + _, tarr = sdfg.add_array('tmp', (3, 3), dace.float64, transient=True) + body = sdfg.add_state('map_state') + aread = body.add_access('A') + oread = body.add_access('out') + bread = body.add_access('B') + twrite = body.add_access('tmp') + owrite = body.add_access('out') + body.add_mapped_tasklet('op', + dict(i='0:3', j='0:3'), + dict(__in1=dace.Memlet('out[i, j]'), __in2=dace.Memlet('B[i, j]')), + '__out = __in1 - __in2', + dict(__out=dace.Memlet('tmp[i, j]')), + external_edges=True, + input_nodes=dict(out=oread, B=bread), + output_nodes=dict(tmp=twrite)) + body.add_nedge(aread, oread, dace.Memlet('A[Mod(_, 3), 0:3] -> [Mod(_ + 1, 3), 0:3]', aarr)) + body.add_nedge(twrite, owrite, dace.Memlet.from_array('out', oarr)) + sdfg.add_loop(None, body, None, '_', '0', '_ < 10', '_ + 1') + + sdfg_args_ref = { + "A": np.array(np.random.rand(3, 3), dtype=np.float64), + "B": np.array(np.random.rand(3, 3), dtype=np.float64), + "out": np.array(np.random.rand(3, 3), dtype=np.float64), + } + sdfg_args_res = copy.deepcopy(sdfg_args_ref) + + # Perform the reference execution + sdfg(**sdfg_args_ref) + + # Apply the transformation and execute the SDFG again. + count = sdfg.apply_transformations(MoveLoopIntoMap, validate_all=True, validate=True) + sdfg(**sdfg_args_res) + + for name in sdfg_args_ref.keys(): + assert np.allclose(sdfg_args_ref[name], sdfg_args_res[name]) + assert count == 0 + if __name__ == '__main__': - unittest.main() + test_forward_loops_semantic_eq() + test_backward_loops_semantic_eq() + test_multiple_edges() + test_itervar_in_map_range() + test_itervar_in_data() + test_non_injective_index() + test_apply_multiple_times() + test_apply_multiple_times_1() + test_more_than_a_map() + test_more_than_a_map_1() + test_more_than_a_map_2() + test_more_than_a_map_3() + test_more_than_a_map_4() diff --git a/tests/transformations/nest_subgraph_test.py b/tests/transformations/nest_subgraph_test.py index 623b029c3a..a9ed62cbdf 100644 --- a/tests/transformations/nest_subgraph_test.py +++ b/tests/transformations/nest_subgraph_test.py @@ -71,7 +71,7 @@ def symbolic_return(): cft = cf.structured_control_flow_tree(sdfg, None) for_scope = None for i, child in enumerate(cft.children): - if isinstance(child, (cf.ForScope, cf.WhileScope)): + if isinstance(child, (cf.GeneralLoopScope)): for_scope = child break assert for_scope @@ -80,11 +80,9 @@ def symbolic_return(): exit_scope = cft.children[i+1] assert isinstance(exit_scope, cf.BasicCFBlock) - guard = for_scope.guard - fexit = exit_scope.first_block - states = list(utils.dfs_conditional(sdfg, [guard], lambda p, _: p is not fexit)) + states = for_scope.loop.nodes() - nest_sdfg_subgraph(sdfg, SubgraphView(sdfg, states), start=guard) + nest_sdfg_subgraph(sdfg, SubgraphView(for_scope.loop, states)) result = sdfg() val = result[1][0] diff --git a/tests/transformations/redundant_copy_test.py b/tests/transformations/redundant_copy_test.py index 2c753c6fc5..280d5f182a 100644 --- a/tests/transformations/redundant_copy_test.py +++ b/tests/transformations/redundant_copy_test.py @@ -450,7 +450,6 @@ def flip_and_flatten(a, b): if __name__ == '__main__': - test_slicing_with_redundant_arrays() test_in() test_out() test_out_success()