diff --git a/dace/sdfg/state.py b/dace/sdfg/state.py index fd0d5e99f3..1946b19c5b 100644 --- a/dace/sdfg/state.py +++ b/dace/sdfg/state.py @@ -726,11 +726,11 @@ def update_if_not_none(dic, update): defined_syms[str(sym)] = sym.dtype # Add inter-state symbols - if isinstance(sdfg.start_block, LoopRegion): + if isinstance(sdfg.start_block, AbstractControlFlowRegion): update_if_not_none(defined_syms, sdfg.start_block.new_symbols(defined_syms)) for edge in sdfg.all_interstate_edges(): update_if_not_none(defined_syms, edge.data.new_symbols(sdfg, defined_syms)) - if isinstance(edge.dst, LoopRegion): + if isinstance(edge.dst, AbstractControlFlowRegion): update_if_not_none(defined_syms, edge.dst.new_symbols(defined_syms)) # Add scope symbols all the way to the subgraph @@ -2722,6 +2722,12 @@ def inline(self) -> Tuple[bool, Any]: return False, None + def new_symbols(self, symbols: dict) -> Dict[str, dtypes.typeclass]: + """ + Returns a mapping between the symbol defined by this control flow region and its type, if it exists. + """ + return {} + ################################################################### # CFG API methods @@ -3306,9 +3312,6 @@ def _used_symbols_internal(self, return free_syms, defined_syms, used_before_assignment def new_symbols(self, symbols) -> Dict[str, dtypes.typeclass]: - """ - Returns a mapping between the symbol defined by this loop and its type, if it exists. - """ # Avoid cyclic import from dace.codegen.tools.type_inference import infer_expr_type from dace.transformation.passes.analysis import loop_analysis @@ -3402,11 +3405,7 @@ def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion): branch.sdfg = self.sdfg def remove_branch(self, branch: ControlFlowRegion): - filtered_branches = [] - for c, b in self._branches: - if b is not branch: - filtered_branches.append((c, b)) - self._branches = filtered_branches + self._branches = [(c, b) for c, b in self._branches if b is not branch] def _used_symbols_internal(self, all_symbols: bool, diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 9f487f561a..230545eecf 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -492,7 +492,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG): # Direct edges among source and sink access nodes must pass through a tasklet. # We first gather them and handle them later. - direct_edges: Set[gr.Edge[memlet.Memlet]] = set() + direct_edges: Set[gr.MultiConnectorEdge[memlet.Memlet]] = set() for n1 in source_nodes: if not isinstance(n1, nodes.AccessNode): continue diff --git a/dace/transformation/passes/analysis/analysis.py b/dace/transformation/passes/analysis/analysis.py index 2bd90a4f22..ffd8a6134f 100644 --- a/dace/transformation/passes/analysis/analysis.py +++ b/dace/transformation/passes/analysis/analysis.py @@ -240,6 +240,19 @@ def should_reapply(self, modified: ppl.Modifies) -> bool: # If access nodes were modified, reapply return modified & ppl.Modifies.AccessNodes + def _get_loop_region_readset(self, loop: LoopRegion, arrays: Set[str]) -> Set[str]: + readset = set() + exprs = { loop.loop_condition.as_string } + update_stmt = loop_analysis.get_update_assignment(loop) + init_stmt = loop_analysis.get_init_assignment(loop) + if update_stmt: + exprs.add(update_stmt) + if init_stmt: + exprs.add(init_stmt) + for expr in exprs: + readset |= symbolic.free_symbols_and_functions(expr) & arrays + return readset + def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]: """ :return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors. @@ -263,15 +276,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str] if state.out_degree(anode) > 0: readset.add(anode.data) if isinstance(block, LoopRegion): - exprs = set([ block.loop_condition.as_string ]) - update_stmt = loop_analysis.get_update_assignment(block) - init_stmt = loop_analysis.get_init_assignment(block) - if update_stmt: - exprs.add(update_stmt) - if init_stmt: - exprs.add(init_stmt) - for expr in exprs: - readset |= symbolic.free_symbols_and_functions(expr) & arrays + readset |= self._get_loop_region_readset(block, arrays) elif isinstance(block, ConditionalBlock): for cond, _ in block.branches: if cond is not None: diff --git a/dace/transformation/passes/constant_propagation.py b/dace/transformation/passes/constant_propagation.py index d95b72ab28..92183451a2 100644 --- a/dace/transformation/passes/constant_propagation.py +++ b/dace/transformation/passes/constant_propagation.py @@ -82,7 +82,6 @@ def _add_nested_datanames(name: str, desc: data.Structure): if isinstance(v, data.Structure): _add_nested_datanames(f'{name}.{k}', v) elif isinstance(v, data.ContainerArray): - # TODO: How are we handling this? pass arrays.add(f'{name}.{k}') @@ -91,11 +90,11 @@ def _add_nested_datanames(name: str, desc: data.Structure): _add_nested_datanames(name, desc) # Trace all constants and symbols through blocks - in_consts: BlockConstsT = { sdfg: initial_symbols } - pre_consts: BlockConstsT = {} - post_consts: BlockConstsT = {} - out_consts: BlockConstsT = {} - self._collect_constants_for_region(sdfg, arrays, in_consts, pre_consts, post_consts, out_consts) + in_constants: BlockConstsT = { sdfg: initial_symbols } + pre_constants: BlockConstsT = {} + post_constants: BlockConstsT = {} + out_constants: BlockConstsT = {} + self._collect_constants_for_region(sdfg, arrays, in_constants, pre_constants, post_constants, out_constants) # Keep track of replaced and ambiguous symbols symbols_replaced: Dict[str, Any] = {} @@ -103,11 +102,11 @@ def _add_nested_datanames(name: str, desc: data.Structure): # Collect symbols from symbol-dependent data descriptors # If there can be multiple values over the SDFG, the symbols are not propagated - desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_consts) + desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_constants) # Replace constants per state - for block, mapping in optional_progressbar(in_consts.items(), 'Propagating constants', n=len(in_consts), - progress=self.progress): + for block, mapping in optional_progressbar(in_constants.items(), 'Propagating constants', + n=len(in_constants), progress=self.progress): if block is sdfg: continue @@ -120,7 +119,8 @@ def _add_nested_datanames(name: str, desc: data.Structure): } out_mapping = { k: v - for k, v in out_consts[block].items() if v is not _UnknownValue and k not in multivalue_desc_symbols + for k, v in out_constants[block].items() + if v is not _UnknownValue and k not in multivalue_desc_symbols } if mapping: @@ -139,20 +139,7 @@ def _add_nested_datanames(name: str, desc: data.Structure): e.data.replace_dict(out_mapping, replace_keys=False) if isinstance(block, LoopRegion): - if block in post_consts and post_consts[block] is not None: - if block.update_statement is not None and (block.inverted and block.update_before_condition or - not block.inverted): - # Replace the RHS of the update experssion - post_mapping = { - k: v - for k, v in post_consts[block].items() - if v is not _UnknownValue and k not in multivalue_desc_symbols - } - update_stmt = block.update_statement - updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] - for update in updates: - astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update) - block.update_statement.code = updates + self._propagate_loop(block, post_constants, multivalue_desc_symbols) # Gather initial propagated symbols result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns} @@ -205,6 +192,23 @@ def _add_nested_datanames(name: str, desc: data.Structure): def report(self, pass_retval: Set[str]) -> str: return f'Propagated {len(pass_retval)} constants.' + def _propagate_loop(self, loop: LoopRegion, post_constants: BlockConstsT, + multivalue_desc_symbols: Set[str]) -> None: + if loop in post_constants and post_constants[loop] is not None: + if loop.update_statement is not None and (loop.inverted and loop.update_before_condition or + not loop.inverted): + # Replace the RHS of the update experssion + post_mapping = { + k: v + for k, v in post_constants[loop].items() + if v is not _UnknownValue and k not in multivalue_desc_symbols + } + update_stmt = loop.update_statement + updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code] + for update in updates: + astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update) + loop.update_statement.code = updates + def _collect_constants_for_conditional(self, conditional: ConditionalBlock, arrays: Set[str], in_const_dict: BlockConstsT, pre_const_dict: BlockConstsT, post_const_dict: BlockConstsT, out_const_dict: BlockConstsT) -> None: diff --git a/dace/transformation/passes/dead_state_elimination.py b/dace/transformation/passes/dead_state_elimination.py index 23f2a785f5..e9c622d128 100644 --- a/dace/transformation/passes/dead_state_elimination.py +++ b/dace/transformation/passes/dead_state_elimination.py @@ -166,7 +166,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch') break # If an unconditional branch is found, ignore all other branches that follow this one. - if cond.as_string.strip() == '1' or self._is_truthy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + if cond.as_string.strip() == '1' or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): unconditional = branch break if unconditional is not None: @@ -177,7 +177,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock, else: # Check if any branches are certainly never taken. for cond, branch in block.branches: - if cond is not None and self._is_falsy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): + if cond is not None and self._is_definitely_false(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg): dead_branches.append([cond, branch]) return dead_branches @@ -195,9 +195,9 @@ def is_definitely_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return True # Evaluate condition - return self._is_truthy(edge.condition_sympy(), sdfg) + return self._is_definitely_true(edge.condition_sympy(), sdfg) - def _is_truthy(self, cond: sp.Basic, sdfg: SDFG) -> bool: + def _is_definitely_true(self, cond: sp.Basic, sdfg: SDFG) -> bool: if cond == True or cond == sp.Not(sp.logic.boolalg.BooleanFalse(), evaluate=False): return True @@ -215,9 +215,9 @@ def is_definitely_not_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool: return False # Evaluate condition - return self._is_falsy(edge.condition_sympy(), sdfg) + return self._is_definitely_false(edge.condition_sympy(), sdfg) - def _is_falsy(self, cond: sp.Basic, sdfg: SDFG) -> bool: + def _is_definitely_false(self, cond: sp.Basic, sdfg: SDFG) -> bool: if cond == False or cond == sp.Not(sp.logic.boolalg.BooleanTrue(), evaluate=False): return True