From 7e4bc3da7154164ca5b1264a86dbd1cf921e6260 Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 13 Nov 2024 12:16:30 +0100 Subject: [PATCH 1/2] Fix loop symbol type inference and loop to map --- dace/codegen/targets/framecode.py | 17 ++++++++++------- dace/transformation/interstate/loop_to_map.py | 12 +++++++++++- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/dace/codegen/targets/framecode.py b/dace/codegen/targets/framecode.py index c0e08cfba7..11a198f119 100644 --- a/dace/codegen/targets/framecode.py +++ b/dace/codegen/targets/framecode.py @@ -936,13 +936,16 @@ def generate_code(self, if isinstance(cfr, LoopRegion) and cfr.loop_variable is not None and cfr.init_statement is not None: if not cfr.loop_variable in interstate_symbols: - l_end = loop_analysis.get_loop_end(cfr) - l_start = loop_analysis.get_init_assignment(cfr) - l_step = loop_analysis.get_loop_stride(cfr) - sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), - infer_expr_type(l_step, global_symbols), - infer_expr_type(l_end, global_symbols)) - interstate_symbols[cfr.loop_variable] = sym_type + if cfr.loop_variable in global_symbols: + interstate_symbols[cfr.loop_variable] = global_symbols[cfr.loop_variable] + else: + l_end = loop_analysis.get_loop_end(cfr) + l_start = loop_analysis.get_init_assignment(cfr) + l_step = loop_analysis.get_loop_stride(cfr) + sym_type = dtypes.result_type_of(infer_expr_type(l_start, global_symbols), + infer_expr_type(l_step, global_symbols), + infer_expr_type(l_end, global_symbols)) + interstate_symbols[cfr.loop_variable] = sym_type if not cfr.loop_variable in global_symbols: global_symbols[cfr.loop_variable] = interstate_symbols[cfr.loop_variable] diff --git a/dace/transformation/interstate/loop_to_map.py b/dace/transformation/interstate/loop_to_map.py index 55327af5fb..9f487f561a 100644 --- a/dace/transformation/interstate/loop_to_map.py +++ b/dace/transformation/interstate/loop_to_map.py @@ -6,7 +6,8 @@ import sympy as sp from typing import Dict, List, Set -from dace import data as dt, memlet, nodes, sdfg as sd, symbolic, subsets, properties +from dace import data as dt, dtypes, memlet, nodes, sdfg as sd, symbolic, subsets, properties +from dace.codegen.tools.type_inference import infer_expr_type from dace.sdfg import graph as gr, nodes from dace.sdfg import SDFG, SDFGState from dace.sdfg import utils as sdutil @@ -94,6 +95,15 @@ def can_be_applied(self, graph, expr_index, sdfg, permissive = False): if start is None or end is None or step is None or itervar is None: return False + sset = {} + sset.update(sdfg.symbols) + sset.update(sdfg.arrays) + t = dtypes.result_type_of(infer_expr_type(start, sset), infer_expr_type(step, sset), infer_expr_type(end, sset)) + # We may only convert something to map if the bounds are all integer-derived types. Otherwise most map schedules + # except for sequential would be invalid. + if not t in dtypes.INTEGER_TYPES: + return False + # Loops containing break, continue, or returns may not be turned into a map. for blk in self.loop.all_control_flow_blocks(): if isinstance(blk, (BreakBlock, ContinueBlock, ReturnBlock)): From 40d4a125c941988a684be21404ffc9ebc440977a Mon Sep 17 00:00:00 2001 From: Philipp Schaad Date: Wed, 13 Nov 2024 15:26:40 +0100 Subject: [PATCH 2/2] Fix traversal for defined symbols --- dace/sdfg/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/dace/sdfg/utils.py b/dace/sdfg/utils.py index 8015c6dd4d..46cdf1fe13 100644 --- a/dace/sdfg/utils.py +++ b/dace/sdfg/utils.py @@ -1557,6 +1557,10 @@ def _tswds_cf_region( for _, b in region.branches: yield from _tswds_cf_region(sdfg, b, symbols, recursive) return + elif isinstance(region, LoopRegion): + # Add the own loop variable to the defined symbols, if present. + loop_syms = region.new_symbols(symbols) + symbols.update({k: v for k, v in loop_syms.items() if v is not None}) # Add symbols from inter-state edges along the state machine start_region = region.start_block