Skip to content

Commit

Permalink
Merge branch 'users/phschaad/adapt_passes' into users/phschaad/cf_blo…
Browse files Browse the repository at this point in the history
…ck_data_deps
  • Loading branch information
phschaad committed Nov 13, 2024
2 parents 49106de + 40d4a12 commit f7b2543
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
17 changes: 10 additions & 7 deletions dace/codegen/targets/framecode.py
Original file line number Diff line number Diff line change
Expand Up @@ -936,13 +936,16 @@ def generate_code(self,

if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None:
if not cfr.loop_variable in interstate_symbols:
l_end = loop_analysis.get_loop_end(cfr)
l_start = loop_analysis.get_init_assignment(cfr)
l_step = loop_analysis.get_loop_stride(cfr)
sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols),
infer_expr_type(l_step, global_symbols),
infer_expr_type(l_end, global_symbols))
interstate_symbols[cfr.loop_variable] = sym_type
if cfr.loop_variable in global_symbols:
interstate_symbols[cfr.loop_variable] = global_symbols[cfr.loop_variable]
else:
l_end = loop_analysis.get_loop_end(cfr)
l_start = loop_analysis.get_init_assignment(cfr)
l_step = loop_analysis.get_loop_stride(cfr)
sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols),
infer_expr_type(l_step, global_symbols),
infer_expr_type(l_end, global_symbols))
interstate_symbols[cfr.loop_variable] = sym_type
if not cfr.loop_variable in global_symbols:
global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable]

Expand Down
4 changes: 4 additions & 0 deletions dace/sdfg/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1557,6 +1557,10 @@ def _tswds_cf_region(
for _, b in region.branches:
yield from _tswds_cf_region(sdfg, b, symbols, recursive)
return
elif isinstance(region, LoopRegion):
# Add the own loop variable to the defined symbols, if present.
loop_syms = region.new_symbols(symbols)
symbols.update({k: v for k, v in loop_syms.items() if v is not None})

# Add symbols from inter-state edges along the state machine
start_region = region.start_block
Expand Down
12 changes: 11 additions & 1 deletion dace/transformation/interstate/loop_to_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import sympy as sp
from typing import Dict, List, Set

from dace import data as dt, memlet, nodes, sdfg as sd, symbolic, subsets, properties
from dace import data as dt, dtypes, memlet, nodes, sdfg as sd, symbolic, subsets, properties
from dace.codegen.tools.type_inference import infer_expr_type
from dace.sdfg import graph as gr, nodes
from dace.sdfg import SDFG, SDFGState
from dace.sdfg import utils as sdutil
Expand Down Expand Up @@ -94,6 +95,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False):
if start is None or end is None or step is None or itervar is None:
return False

sset = {}
sset.update(sdfg.symbols)
sset.update(sdfg.arrays)
t = dtypes.result_type_of(infer_expr_type(start, sset), infer_expr_type(step, sset), infer_expr_type(end, sset))
# We may only convert something to map if the bounds are all integer-derived types. Otherwise most map schedules
# except for sequential would be invalid.
if not t in dtypes.INTEGER_TYPES:
return False

# Loops containing break, continue, or returns may not be turned into a map.
for blk in self.loop.all_control_flow_blocks():
if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)):
Expand Down

0 comments on commit f7b2543

Please sign in to comment.