diff --git a/dace/transformation/passes/analysis/sdfg_statistics.py b/dace/transformation/passes/analysis/sdfg_statistics.py index 7f5b552190..3961495da5 100644 --- a/dace/transformation/passes/analysis/sdfg_statistics.py +++ b/dace/transformation/passes/analysis/sdfg_statistics.py @@ -6,6 +6,7 @@ import dace from dace import SDFG, SDFGState from dace.transformation import pass_pipeline as ppl +from dace.sdfg import utils as sdutil class SDFGStatistics(ppl.Pass): def modifies(self) -> ppl.Modifies: @@ -34,7 +35,7 @@ def camel_to_title(s: str) -> str: elif isinstance(value, defaultdict): formatted_value = "\n" + self._print_defaultdict(value, tab_count + 1)[:-1] else: - raise Exception("Unsupported type") + formatted_value = value retstr += (f"{tab_count * '\t'}{readable_key}: {formatted_value}\n") return retstr @@ -52,7 +53,29 @@ def _count_nested_sdfgs(self, sdfg: SDFG, level: int = 0, counts: Dict[int, int] return counts - def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[Dict[str, Set[str]]]: + def _count_maps(self, sdfg: SDFG): + num_outer_maps = 0 + maps = [] + for s in sdfg.states(): + sdict = s.scope_dict() + for n in s.nodes(): + if isinstance(n, dace.nodes.MapEntry): + if sdict[n] is None: + num_outer_maps += 1 + map_entry = n + map_exit = s.exit_node(n) + num_inner_maps = 0 + inner_map_info = [] + for inner_node in sdutil.dfs_topological_sort(s, map_entry): + if inner_node != map_entry and isinstance(inner_node, dace.nodes.MapEntry): + num_inner_maps += 1 + inner_map_info.append((inner_node.map.label, inner_node.map.range)) + if inner_node == map_exit: + break + maps.append((n.map.label, n.map.range, num_inner_maps, inner_map_info)) + return num_outer_maps, maps + + def apply_pass(self, sdfg: SDFG, _: Dict[str, Any], recursive: bool = False) -> Optional[Dict[str, Set[str]]]: statistics_dict: Dict[str, Set[str]] = defaultdict(lambda: set()) states = sdfg.states() @@ -82,9 +105,26 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[Dict[str, Set[st for k, v in counts.items(): formatted_counts[f"Level {k}"] = str(v) - statistics_dict["number_of_nested_sdfgs_per_depth"] = formatted_counts + if len(formatted_counts) > 0: + statistics_dict["number_of_nested_sdfgs_per_depth"] = formatted_counts + else: + statistics_dict["number_of_nested_sdfgs_per_depth"] = None + statistics_dict["maximum_nested_sdfg_depth"] = {len(formatted_counts)} + num_outer_maps, map_info = self._count_maps(sdfg) + statistics_dict["number_of_outer_maps"] = num_outer_maps + + map_str = "" + # (n.map.label, n.map.range, num_inner_maps, inner_map_info) = map_info + for outer_map, outer_range, num_inner_maps, inner_map_info in map_info: + map_str += f"\t{outer_map}: {outer_range}\n" + map_str += f"\tNum Inner Maps: {num_inner_maps}\n" + map_str += "\n".join([f"\t\t{f}: {s}" for f, s in inner_map_info]) + + statistics_dict["map_information"] = "\n" + map_str if map_str != "" else None + + retstr = self._print_defaultdict(statistics_dict)[:-1] print(retstr)