Skip to content

Commit

Permalink
Address more comments
Browse files Browse the repository at this point in the history
  • Loading branch information
phschaad committed Dec 9, 2024
1 parent c4e78d7 commit 3a2b342
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 50 deletions.
19 changes: 9 additions & 10 deletions dace/sdfg/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,11 +726,11 @@ def update_if_not_none(dic, update):
defined_syms[str(sym)] = sym.dtype

# Add inter-state symbols
if isinstance(sdfg.start_block, LoopRegion):
if isinstance(sdfg.start_block, AbstractControlFlowRegion):
update_if_not_none(defined_syms, sdfg.start_block.new_symbols(defined_syms))
for edge in sdfg.all_interstate_edges():
update_if_not_none(defined_syms, edge.data.new_symbols(sdfg, defined_syms))
if isinstance(edge.dst, LoopRegion):
if isinstance(edge.dst, AbstractControlFlowRegion):
update_if_not_none(defined_syms, edge.dst.new_symbols(defined_syms))

# Add scope symbols all the way to the subgraph
Expand Down Expand Up @@ -2722,6 +2722,12 @@ def inline(self) -> Tuple[bool, Any]:

return False, None

def new_symbols(self, symbols: dict) -> Dict[str, dtypes.typeclass]:
"""
Returns a mapping between the symbol defined by this control flow region and its type, if it exists.
"""
return {}

###################################################################
# CFG API methods

Expand Down Expand Up @@ -3306,9 +3312,6 @@ def _used_symbols_internal(self,
return free_syms, defined_syms, used_before_assignment

def new_symbols(self, symbols) -> Dict[str, dtypes.typeclass]:
"""
Returns a mapping between the symbol defined by this loop and its type, if it exists.
"""
# Avoid cyclic import
from dace.codegen.tools.type_inference import infer_expr_type
from dace.transformation.passes.analysis import loop_analysis
Expand Down Expand Up @@ -3402,11 +3405,7 @@ def add_branch(self, condition: Optional[CodeBlock], branch: ControlFlowRegion):
branch.sdfg = self.sdfg

def remove_branch(self, branch: ControlFlowRegion):
filtered_branches = []
for c, b in self._branches:
if b is not branch:
filtered_branches.append((c, b))
self._branches = filtered_branches
self._branches = [(c, b) for c, b in self._branches if b is not branch]

def _used_symbols_internal(self,
all_symbols: bool,
Expand Down
2 changes: 1 addition & 1 deletion dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,7 +492,7 @@ def apply(self, graph: ControlFlowRegion, sdfg: sd.SDFG):

# Direct edges among source and sink access nodes must pass through a tasklet.
# We first gather them and handle them later.
direct_edges: Set[gr.Edge[memlet.Memlet]] = set()
direct_edges: Set[gr.MultiConnectorEdge[memlet.Memlet]] = set()
for n1 in source_nodes:
if not isinstance(n1, nodes.AccessNode):
continue
Expand Down
23 changes: 14 additions & 9 deletions dace/transformation/passes/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,19 @@ def should_reapply(self, modified: ppl.Modifies) -> bool:
# If access nodes were modified, reapply
return modified & ppl.Modifies.AccessNodes

def _get_loop_region_readset(self, loop: LoopRegion, arrays: Set[str]) -> Set[str]:
readset = set()
exprs = { loop.loop_condition.as_string }
update_stmt = loop_analysis.get_update_assignment(loop)
init_stmt = loop_analysis.get_init_assignment(loop)
if update_stmt:
exprs.add(update_stmt)
if init_stmt:
exprs.add(init_stmt)
for expr in exprs:
readset |= symbolic.free_symbols_and_functions(expr) & arrays
return readset

def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str], Set[str]]]:
"""
:return: A dictionary mapping each control flow block to a tuple of its (read, written) data descriptors.
Expand All @@ -263,15 +276,7 @@ def apply_pass(self, top_sdfg: SDFG, _) -> Dict[ControlFlowBlock, Tuple[Set[str]
if state.out_degree(anode) > 0:
readset.add(anode.data)
if isinstance(block, LoopRegion):
exprs = set([ block.loop_condition.as_string ])
update_stmt = loop_analysis.get_update_assignment(block)
init_stmt = loop_analysis.get_init_assignment(block)
if update_stmt:
exprs.add(update_stmt)
if init_stmt:
exprs.add(init_stmt)
for expr in exprs:
readset |= symbolic.free_symbols_and_functions(expr) & arrays
readset |= self._get_loop_region_readset(block, arrays)
elif isinstance(block, ConditionalBlock):
for cond, _ in block.branches:
if cond is not None:
Expand Down
52 changes: 28 additions & 24 deletions dace/transformation/passes/constant_propagation.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,6 @@ def _add_nested_datanames(name: str, desc: data.Structure):
if isinstance(v, data.Structure):
_add_nested_datanames(f'{name}.{k}', v)
elif isinstance(v, data.ContainerArray):
# TODO: How are we handling this?
pass
arrays.add(f'{name}.{k}')

Expand All @@ -91,23 +90,23 @@ def _add_nested_datanames(name: str, desc: data.Structure):
_add_nested_datanames(name, desc)

# Trace all constants and symbols through blocks
in_consts: BlockConstsT = { sdfg: initial_symbols }
pre_consts: BlockConstsT = {}
post_consts: BlockConstsT = {}
out_consts: BlockConstsT = {}
self._collect_constants_for_region(sdfg, arrays, in_consts, pre_consts, post_consts, out_consts)
in_constants: BlockConstsT = { sdfg: initial_symbols }
pre_constants: BlockConstsT = {}
post_constants: BlockConstsT = {}
out_constants: BlockConstsT = {}
self._collect_constants_for_region(sdfg, arrays, in_constants, pre_constants, post_constants, out_constants)

# Keep track of replaced and ambiguous symbols
symbols_replaced: Dict[str, Any] = {}
remaining_unknowns: Set[str] = set()

# Collect symbols from symbol-dependent data descriptors
# If there can be multiple values over the SDFG, the symbols are not propagated
desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_consts)
desc_symbols, multivalue_desc_symbols = self._find_desc_symbols(sdfg, in_constants)

# Replace constants per state
for block, mapping in optional_progressbar(in_consts.items(), 'Propagating constants', n=len(in_consts),
progress=self.progress):
for block, mapping in optional_progressbar(in_constants.items(), 'Propagating constants',
n=len(in_constants), progress=self.progress):
if block is sdfg:
continue

Expand All @@ -120,7 +119,8 @@ def _add_nested_datanames(name: str, desc: data.Structure):
}
out_mapping = {
k: v
for k, v in out_consts[block].items() if v is not _UnknownValue and k not in multivalue_desc_symbols
for k, v in out_constants[block].items()
if v is not _UnknownValue and k not in multivalue_desc_symbols
}

if mapping:
Expand All @@ -139,20 +139,7 @@ def _add_nested_datanames(name: str, desc: data.Structure):
e.data.replace_dict(out_mapping, replace_keys=False)

if isinstance(block, LoopRegion):
if block in post_consts and post_consts[block] is not None:
if block.update_statement is not None and (block.inverted and block.update_before_condition or
not block.inverted):
# Replace the RHS of the update experssion
post_mapping = {
k: v
for k, v in post_consts[block].items()
if v is not _UnknownValue and k not in multivalue_desc_symbols
}
update_stmt = block.update_statement
updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code]
for update in updates:
astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update)
block.update_statement.code = updates
self._propagate_loop(block, post_constants, multivalue_desc_symbols)

# Gather initial propagated symbols
result = {k: v for k, v in symbols_replaced.items() if k not in remaining_unknowns}
Expand Down Expand Up @@ -205,6 +192,23 @@ def _add_nested_datanames(name: str, desc: data.Structure):
def report(self, pass_retval: Set[str]) -> str:
return f'Propagated {len(pass_retval)} constants.'

def _propagate_loop(self, loop: LoopRegion, post_constants: BlockConstsT,
multivalue_desc_symbols: Set[str]) -> None:
if loop in post_constants and post_constants[loop] is not None:
if loop.update_statement is not None and (loop.inverted and loop.update_before_condition or
not loop.inverted):
# Replace the RHS of the update experssion
post_mapping = {
k: v
for k, v in post_constants[loop].items()
if v is not _UnknownValue and k not in multivalue_desc_symbols
}
update_stmt = loop.update_statement
updates = update_stmt.code if isinstance(update_stmt.code, list) else [update_stmt.code]
for update in updates:
astutils.ASTReplaceAssignmentRHS(post_mapping).visit(update)
loop.update_statement.code = updates

def _collect_constants_for_conditional(self, conditional: ConditionalBlock, arrays: Set[str],
in_const_dict: BlockConstsT, pre_const_dict: BlockConstsT,
post_const_dict: BlockConstsT, out_const_dict: BlockConstsT) -> None:
Expand Down
12 changes: 6 additions & 6 deletions dace/transformation/passes/dead_state_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock,
raise InvalidSDFGNodeError('Conditional block detected, where else branch is not the last branch')
break
# If an unconditional branch is found, ignore all other branches that follow this one.
if cond.as_string.strip() == '1' or self._is_truthy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg):
if cond.as_string.strip() == '1' or self._is_definitely_true(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg):
unconditional = branch
break
if unconditional is not None:
Expand All @@ -177,7 +177,7 @@ def _find_dead_branches(self, block: ConditionalBlock) -> List[Tuple[CodeBlock,
else:
# Check if any branches are certainly never taken.
for cond, branch in block.branches:
if cond is not None and self._is_falsy(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg):
if cond is not None and self._is_definitely_false(symbolic.pystr_to_symbolic(cond.as_string), block.sdfg):
dead_branches.append([cond, branch])

return dead_branches
Expand All @@ -195,9 +195,9 @@ def is_definitely_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool:
return True

# Evaluate condition
return self._is_truthy(edge.condition_sympy(), sdfg)
return self._is_definitely_true(edge.condition_sympy(), sdfg)

def _is_truthy(self, cond: sp.Basic, sdfg: SDFG) -> bool:
def _is_definitely_true(self, cond: sp.Basic, sdfg: SDFG) -> bool:
if cond == True or cond == sp.Not(sp.logic.boolalg.BooleanFalse(), evaluate=False):
return True

Expand All @@ -215,9 +215,9 @@ def is_definitely_not_taken(self, edge: InterstateEdge, sdfg: SDFG) -> bool:
return False

# Evaluate condition
return self._is_falsy(edge.condition_sympy(), sdfg)
return self._is_definitely_false(edge.condition_sympy(), sdfg)

def _is_falsy(self, cond: sp.Basic, sdfg: SDFG) -> bool:
def _is_definitely_false(self, cond: sp.Basic, sdfg: SDFG) -> bool:
if cond == False or cond == sp.Not(sp.logic.boolalg.BooleanTrue(), evaluate=False):
return True

Expand Down

0 comments on commit 3a2b342

Please sign in to comment.