Skip to content

Commit

Permalink
More statistics
Browse files Browse the repository at this point in the history
  • Loading branch information
ThrudPrimrose committed Dec 3, 2024
1 parent 74dee0d commit 02ed614
Showing 1 changed file with 43 additions and 3 deletions.
46 changes: 43 additions & 3 deletions dace/transformation/passes/analysis/sdfg_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import dace
from dace import SDFG, SDFGState
from dace.transformation import pass_pipeline as ppl
from dace.sdfg import utils as sdutil

class SDFGStatistics(ppl.Pass):
def modifies(self) -> ppl.Modifies:
Expand Down Expand Up @@ -34,7 +35,7 @@ def camel_to_title(s: str) -> str:
elif isinstance(value, defaultdict):
formatted_value = "\n" + self._print_defaultdict(value, tab_count + 1)[:-1]
else:
raise Exception("Unsupported type")
formatted_value = value

retstr += (f"{tab_count * '\t'}{readable_key}: {formatted_value}\n")
return retstr
Expand All @@ -52,7 +53,29 @@ def _count_nested_sdfgs(self, sdfg: SDFG, level: int = 0, counts: Dict[int, int]
return counts


def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[Dict[str, Set[str]]]:
def _count_maps(self, sdfg: SDFG):
num_outer_maps = 0
maps = []
for s in sdfg.states():
sdict = s.scope_dict()
for n in s.nodes():
if isinstance(n, dace.nodes.MapEntry):
if sdict[n] is None:
num_outer_maps += 1
map_entry = n
map_exit = s.exit_node(n)
num_inner_maps = 0
inner_map_info = []
for inner_node in sdutil.dfs_topological_sort(s, map_entry):
if inner_node != map_entry and isinstance(inner_node, dace.nodes.MapEntry):
num_inner_maps += 1
inner_map_info.append((inner_node.map.label, inner_node.map.range))
if inner_node == map_exit:
break
maps.append((n.map.label, n.map.range, num_inner_maps, inner_map_info))
return num_outer_maps, maps

def apply_pass(self, sdfg: SDFG, _: Dict[str, Any], recursive: bool = False) -> Optional[Dict[str, Set[str]]]:
statistics_dict: Dict[str, Set[str]] = defaultdict(lambda: set())

states = sdfg.states()
Expand Down Expand Up @@ -82,9 +105,26 @@ def apply_pass(self, sdfg: SDFG, _: Dict[str, Any]) -> Optional[Dict[str, Set[st
for k, v in counts.items():
formatted_counts[f"Level {k}"] = str(v)

statistics_dict["number_of_nested_sdfgs_per_depth"] = formatted_counts
if len(formatted_counts) > 0:
statistics_dict["number_of_nested_sdfgs_per_depth"] = formatted_counts
else:
statistics_dict["number_of_nested_sdfgs_per_depth"] = None

statistics_dict["maximum_nested_sdfg_depth"] = {len(formatted_counts)}

num_outer_maps, map_info = self._count_maps(sdfg)
statistics_dict["number_of_outer_maps"] = num_outer_maps

map_str = ""
# (n.map.label, n.map.range, num_inner_maps, inner_map_info) = map_info
for outer_map, outer_range, num_inner_maps, inner_map_info in map_info:
map_str += f"\t{outer_map}: {outer_range}\n"
map_str += f"\tNum Inner Maps: {num_inner_maps}\n"
map_str += "\n".join([f"\t\t{f}: {s}" for f, s in inner_map_info])

statistics_dict["map_information"] = "\n" + map_str if map_str != "" else None


retstr = self._print_defaultdict(statistics_dict)[:-1]
print(retstr)

Expand Down

0 comments on commit 02ed614

Please sign in to comment.