From 34df15c0a1f5c6784ea17301989d17d23bf636a1 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Tue, 4 Jul 2023 15:37:44 +0200 Subject: [PATCH 01/18] initial push of work_depth analysis script --- dace/sdfg/work_depth_analysis/helpers.py | 229 ++++++ .../work_depth_analysis.py | 749 ++++++++++++++++++ 2 files changed, 978 insertions(+) create mode 100644 dace/sdfg/work_depth_analysis/helpers.py create mode 100644 dace/sdfg/work_depth_analysis/work_depth_analysis.py diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py new file mode 100644 index 0000000000..28b4741452 --- /dev/null +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -0,0 +1,229 @@ +from dace import SDFG, SDFGState, nodes, serialize +from collections import deque +from typing import List, Dict, Set, Tuple, Optional, Union +import networkx as nx + +NodeT = str +EdgeT = Tuple[NodeT, NodeT] + +class NodeCycle: + + nodes: Set[NodeT] = [] + + def __init__(self, nodes: List[NodeT]) -> None: + self.nodes = set(nodes) + + @property + def length(self) -> int: + return len(self.nodes) + + +UUID_SEPARATOR = '/' + + +def ids_to_string(sdfg_id, state_id=-1, node_id=-1, edge_id=-1): + return (str(sdfg_id) + UUID_SEPARATOR + str(state_id) + UUID_SEPARATOR + + str(node_id) + UUID_SEPARATOR + str(edge_id)) + +def get_uuid(element, state=None): + if isinstance(element, SDFG): + return ids_to_string(element.sdfg_id) + elif isinstance(element, SDFGState): + return ids_to_string(element.parent.sdfg_id, + element.parent.node_id(element)) + elif isinstance(element, nodes.Node): + return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), + state.node_id(element)) + else: + return ids_to_string(-1) + + + + + + + + +def get_domtree( + graph: nx.DiGraph, + start_node: str, + idom: Dict[str, str] = None +): + idom = idom or nx.immediate_dominators(graph, start_node) + + alldominated = { n: set() for n in graph.nodes } + domtree = nx.DiGraph() + + for node, dom in idom.items(): + if node is dom: + continue + domtree.add_edge(dom, node) + alldominated[dom].add(node) + + nextidom = idom[dom] + ndom = nextidom if nextidom != dom else None + + while ndom: + alldominated[ndom].add(node) + nextidom = idom[ndom] + ndom = nextidom if nextidom != ndom else None + + # 'Rank' the tree, i.e., annotate each node with the level it is on. + q = deque() + q.append((start_node, 0)) + while q: + node, level = q.popleft() + domtree.add_node(node, level=level) + for s in domtree.successors(node): + q.append((s, level + 1)) + + return alldominated, domtree + + + + +def backedges( + graph: nx.DiGraph, start: Optional[NodeT], strict: bool = False +) -> Union[Set[EdgeT], Tuple[Set[EdgeT], Set[EdgeT]]]: + '''Find all backedges in a directed graph. + + Note: + This algorithm has an algorithmic complexity of O((|V|+|E|)*C) for a + graph with vertices V, edges E, and C cycles. + + Args: + graph (nx.DiGraph): The graph for which to search backedges. + start (str): Start node of the graph. If no start is provided, a node + with no incoming edges is used as the start. If no such node can + be found, a `ValueError` is raised. + + Returns: + A set of backedges in the graph. + + Raises: + ValueError: If no `start` is provided and the graph contains no nodes + with no incoming edges. 
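+    Example:
+        Illustrative only (this tiny graph is not from the patch): a single
+        loop ``B -> C -> B``, entered from ``A``, yields exactly one backedge.
+
+            >>> import networkx as nx
+            >>> g = nx.DiGraph([('A', 'B'), ('B', 'C'), ('C', 'B')])
+            >>> backedges(g, 'A')
+            {('C', 'B')}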
+ ''' + backedges = set() + eclipsed_backedges = set() + + if start is None: + for node in graph.nodes(): + if graph.in_degree(node) == 0: + start = node + break + if start is None: + raise ValueError( + 'No start node provided and no start node could ' + + 'be determined automatically' + ) + + # Gather all cycles in the graph. Cycles are represented as a sequence of + # nodes. + # O((|V|+|E|)*(C+1)), for C cycles. + all_cycles_nx: List[List[NodeT]] = nx.cycles.simple_cycles(graph) + #all_cycles_nx: List[List[NodeT]] = nx.simple_cycles(graph) + all_cycles: Set[NodeCycle] = set() + for cycle in all_cycles_nx: + all_cycles.add(NodeCycle(cycle)) + + # Construct a dictionary mapping a node to the cycles containing that node. + # O(|V|*|C|) + cycle_map: Dict[NodeT, Set[NodeCycle]] = dict() + for cycle in all_cycles: + for node in cycle.nodes: + try: + cycle_map[node].add(cycle) + except KeyError: + cycle_map[node] = set([cycle]) + + # Do a BFS traversal of the graph to detect the back edges. + # For each node that is part of an (unhandled) cycle, find the longest + # still unhandled cycle and try to use it to find the back edge for it. + bfs_frontier = [start] + visited: Set[NodeT] = set([start]) + handled_cycles: Set[NodeCycle] = set() + unhandled_cycles = all_cycles + while bfs_frontier: + node = bfs_frontier.pop(0) + pred = [p for p in graph.predecessors(node) if p not in visited] + longest_cycles: Dict[NodeT, NodeCycle] = dict() + try: + cycles = cycle_map[node] + remove_cycles = set() + for cycle in cycles: + if cycle not in handled_cycles: + for p in pred: + if p in cycle.nodes: + if p not in longest_cycles: + longest_cycles[p] = cycle + else: + if cycle.length > longest_cycles[p].length: + longest_cycles[p] = cycle + else: + remove_cycles.add(cycle) + for cycle in remove_cycles: + cycles.remove(cycle) + except KeyError: + longest_cycles = dict() + + # For the current node, find the incoming edge which belongs to the + # cycle and has not been visited yet, which indicates a backedge. + node_backedge_candidates: Set[Tuple[EdgeT, NodeCycle]] = set() + for p, longest_cycle in longest_cycles.items(): + handled_cycles.add(longest_cycle) + unhandled_cycles.remove(longest_cycle) + cycle_map[node].remove(longest_cycle) + backedge_candidates = graph.in_edges(node) + for candidate in backedge_candidates: + src = candidate[0] + dst = candidate[0] + if src not in visited and src in longest_cycle.nodes: + node_backedge_candidates.add((candidate, longest_cycle)) + if not strict: + backedges.add(candidate) + + # Make sure that any cycle containing this back edge is + # not evaluated again, i.e., mark as handled. + remove_cycles = set() + for cycle in unhandled_cycles: + if src in cycle.nodes and dst in cycle.nodes: + handled_cycles.add(cycle) + remove_cycles.add(cycle) + for cycle in remove_cycles: + unhandled_cycles.remove(cycle) + + # If strict is set, we only report the longest cycle's back edges for + # any given node, and separately return any other backedges as + # 'eclipsed' backedges. In the case of a while-loop, for example, + # the loop edge is considered a backedge, while a continue inside the + # loop is considered an 'eclipsed' backedge. 
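+        # Illustrative sketch (not part of the original patch): for a CFG
+        #   guard -> body1 -> body2 -> guard   (loop edge)
+        #   body1 -> guard                     (continue edge)
+        #   guard -> exit
+        # the loop edge (body2 -> guard) is the backedge that strict mode
+        # reports, while the continue edge (body1 -> guard) is returned
+        # separately as an 'eclipsed' backedge.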
+ if strict: + longest_candidate: Tuple[EdgeT, NodeCycle] = None + eclipsed_candidates = set() + for be_candidate in node_backedge_candidates: + if longest_candidate is None: + longest_candidate = be_candidate + elif longest_candidate[1].length < be_candidate[1].length: + eclipsed_candidates.add(longest_candidate[0]) + longest_candidate = be_candidate + else: + eclipsed_candidates.add(be_candidate[0]) + if longest_candidate is not None: + backedges.add(longest_candidate[0]) + if eclipsed_candidates: + eclipsed_backedges.update(eclipsed_candidates) + + + # Continue BFS. + for neighbour in graph.successors(node): + if neighbour not in visited: + visited.add(neighbour) + bfs_frontier.append(neighbour) + + if strict: + return backedges, eclipsed_backedges + else: + return backedges + + diff --git a/dace/sdfg/work_depth_analysis/work_depth_analysis.py b/dace/sdfg/work_depth_analysis/work_depth_analysis.py new file mode 100644 index 0000000000..7f1f746b26 --- /dev/null +++ b/dace/sdfg/work_depth_analysis/work_depth_analysis.py @@ -0,0 +1,749 @@ +import argparse +from collections import deque +from dace.sdfg import nodes as nd, propagation, InterstateEdge, utils as sdutil +from dace import SDFG, SDFGState, dtypes +from dace.subsets import Range +from typing import Tuple, Dict +import os +import sympy as sp +import networkx as nx +from copy import deepcopy +from dace.libraries.blas import MatMul, Transpose +from dace.libraries.standard import Reduce +from dace.symbolic import pystr_to_symbolic +import ast +import astunparse +import warnings +from dace.sdfg.graph import Edge + +from dace.sdfg.work_depth_analysis import get_uuid, get_domtree, backedges as get_backedges + + +def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): + # preparation phase: compute dominators, backedges etc + for node in sdfg_nx.nodes(): + if sdfg_nx.in_degree(node) == 0: + start = node + break + if start is None: + raise ValueError('No start node could be determined') + + # sdfg can have multiple end nodes --> not good for postDomTree + # --> add a new end node + artificial_end_node = 'artificial_end_node' + sdfg_nx.add_node(artificial_end_node) + for node in sdfg_nx.nodes(): + if sdfg_nx.out_degree(node) == 0 and node != artificial_end_node: + # this is an end node of the sdfg + sdfg_nx.add_edge(node, artificial_end_node) + + # sanity check: + if sdfg_nx.in_degree(artificial_end_node) == 0: + raise ValueError('No end node could be determined in the SDFG') + + + + iDoms = nx.immediate_dominators(sdfg_nx, start) + allDom, domTree = get_domtree(sdfg_nx, start, iDoms) + + reversed_sdfg_nx = sdfg_nx.reverse() + iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node) + allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) + + backedges = get_backedges(sdfg_nx, start) + backedgesDstDict = {} + for be in backedges: + if be[1] in backedgesDstDict: + backedgesDstDict[be[1]].add(be) + else: + backedgesDstDict[be[1]] = set([be]) + + + nodes_oNodes_exits = [] + + # iterate over all nodes + for node in sdfg_nx.nodes(): + # does any backedge end in node + if node in backedgesDstDict: + inc_backedges = backedgesDstDict[node] + + + # gather all successors of node that are not reached by backedges + successors = [] + for edge in sdfg_nx.out_edges(node): + if not edge in backedges: + successors.append(edge[1]) + + + # if len(inc_backedges) > 1: + # raise ValueError('node has multiple incoming backedges...') + # instead: if multiple incoming backedges, do the below for each backedge + 
for be in inc_backedges: + + + # since node has an incoming backedge, it is either a loop guard or loop tail + # oNode will exactly be the other thing + oNode = be[0] + exitCandidates = set() + for succ in successors: + if succ != oNode and oNode not in allDom[succ]: + exitCandidates.add(succ) + for succ in sdfg_nx.successors(oNode): + if succ != node: + exitCandidates.add(succ) + + if len(exitCandidates) == 0: + raise ValueError('failed to find any exit nodes') + elif len(exitCandidates) > 1: + # // Find the exit candidate that sits highest up in the + # // postdominator tree (i.e., has the lowest level). + # // That must be the exit node (it must post-dominate) + # // everything inside the loop. If there are multiple + # // candidates on the lowest level (i.e., disjoint set of + # // postdominated nodes), there are multiple exit paths, + # // and they all share one level. + cand = exitCandidates.pop() + minSet = set([cand]) + minLevel = nx.get_node_attributes(postDomTree, 'level')[cand] + for cand in exitCandidates: + curr_level = nx.get_node_attributes(postDomTree, 'level')[cand] + if curr_level < minLevel: + # new minimum found + minLevel = curr_level + minSet.clear() + minSet.add(cand) + elif curr_level == minLevel: + # add cand to curr set + minSet.add(cand) + + if len(minSet) > 0: + exitCandidates = minSet + else: + raise ValueError('failed to find exit minSet') + + # now we have a triple (node, oNode, exitCandidates) + nodes_oNodes_exits.append((node, oNode, exitCandidates)) + + return nodes_oNodes_exits + + + +def get_array_size_symbols(sdfg): + symbols = set() + for _, _, arr in sdfg.arrays_recursive(): + for s in arr.shape: + if isinstance(s, sp.Symbol): + symbols.add(s) + return symbols + +def posify_certain_symbols(expr, syms_to_posify, syms_to_nonnegify): + expr = sp.sympify(expr) + nonneg = {s: sp.Dummy(s.name, nonnegative=True, **s.assumptions0) + for s in syms_to_nonnegify if s.is_nonnegative is None} + pos = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) + for s in syms_to_posify if s.is_positive is None} + # merge the two dicts into reps + reps = {**nonneg, **pos} + expr = expr.subs(reps) + return expr.subs({r: s for s, r in reps.items()}) + +def symeval(val, symbols): + first_replacement = { + pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) + for k in symbols.keys() + } + second_replacement = { + pystr_to_symbolic('__REPLSYM_' + k): v + for k, v in symbols.items() + } + return val.subs(first_replacement).subs(second_replacement) + +def count_matmul(node, symbols, state): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') + C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') + result = 2 # Multiply, add + # Batch + if len(C_memlet.data.subset) == 3: + result *= symeval(C_memlet.data.subset.size()[0], symbols) + # M*N + result *= symeval(C_memlet.data.subset.size()[-2], symbols) + result *= symeval(C_memlet.data.subset.size()[-1], symbols) + # K + result *= symeval(A_memlet.data.subset.size()[-1], symbols) + return result + + +def count_reduce(node, symbols, state): + result = 0 + if node.wcr is not None: + result += count_arithmetic_ops_code(node.wcr) + in_memlet = None + in_edges = state.in_edges(node) + if in_edges is not None and len(in_edges) == 1: + in_memlet = in_edges[0] + if in_memlet is not None and in_memlet.data.volume is not None: + result *= in_memlet.data.volume + else: + result = 0 + return result + +bigo = sp.Function('bigo') 
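+# Illustrative sketch (not part of the original patch): `posify_certain_symbols`
+# exists because sympy only applies sign-dependent simplifications once a symbol
+# is known to be positive/nonnegative. The helper below is hypothetical and only
+# demonstrates that effect; it is not used by the analysis itself.
+def _posify_example():
+    N = sp.Symbol('N')
+    generic = sp.sqrt(N**2)                                  # stays as sqrt(N**2)
+    posified = posify_certain_symbols(generic, {N}, set())   # simplifies to N
+    return generic, posified
+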
+PYFUNC_TO_ARITHMETICS = { + 'float': 0, + 'math.exp': 1, + 'math.tanh': 1, + 'math.sqrt': 1, + 'min': 0, + 'max': 0, + 'ceiling': 0, + 'floor': 0, +} +LIBNODES_TO_ARITHMETICS = { + MatMul: count_matmul, + Transpose: lambda *args: 0, + Reduce: count_reduce, +} + + +class ArithmeticCounter(ast.NodeVisitor): + + def __init__(self): + self.count = 0 + + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print('WARNING: Unrecognized python function "%s"' % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + +def count_arithmetic_ops_code(code): + ctr = ArithmeticCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + +class DepthCounter(ast.NodeVisitor): + + def __init__(self): + self.count = 0 + + # TODO: if we have a tasklet like _out = 2 * _in + 500 + # will this then have depth of 2? or not because of instruction level parallelism? + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print('WARNING: Unrecognized python function "%s"' % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + +def count_depth_code(code): + ctr = DepthCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + +def tasklet_work(tasklet_node, state): + if tasklet_node.code.language == dtypes.Language.CPP: + for oedge in state.out_edges(tasklet_node): + return bigo(oedge.data.num_accesses) + + elif tasklet_node.code.language == dtypes.Language.Python: + return count_arithmetic_ops_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Work of tasklets only properly analyzed for Python or CPP. 
For all other ' + 'languages work = 1 will be counted for each tasklet.') + return 1 + +def tasklet_depth(tasklet_node, state): + # if tasklet_node.code.language == dtypes.Language.CPP: + # for oedge in state.out_edges(tasklet_node): + # return bigo(oedge.data.num_accesses) + + if tasklet_node.code.language == dtypes.Language.Python: + return count_depth_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Depth of tasklets only properly analyzed for Python code. For all other ' + 'languages depth = 1 will be counted for each tasklet.') + return 1 + +def get_tasklet_work(node, state): + return tasklet_work(node, state), -1 + +def get_tasklet_work_depth(node, state): + return tasklet_work(node, state), tasklet_depth(node, state) + +def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, syms_to_nonnegify) -> None: + print('Analyzing work and depth of SDFG', sdfg.name) + print('SDFG has', len(sdfg.nodes()), 'states') + print('Calculating work and depth for all states individually...') + + # First determine the work and depth of each state individually. + # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number + # of times the state will be executed. + state_depths: Dict[SDFGState, sp.Expr] = {} + state_works: Dict[SDFGState, sp.Expr] = {} + for state in sdfg.nodes(): + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify) + if state.executions == 0:# or state.executions == sp.zoo: + print('State executions must be statically known exactly or with an upper bound. Offender:', state) + new_symbol = sp.Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(state)}') + state.executions = new_symbol + syms_to_nonnegify |= {new_symbol} + state_works[state] = state_work * state.executions + state_depths[state] = state_depth * state.executions + w_d_map[get_uuid(state)] = (sp.simplify(state_work * state.executions), sp.simplify(state_depth * state.executions)) + + print('Calculating work and depth of the SDFG...') + + + nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) + print(nodes_oNodes_exits) + # Now we need to go over each triple (node, oNode, exits) + # for each triple, we + # - remove edge (oNode, node), i.e. the backward edge + # - for all exits e, add edge (oNode, e). This edge may already exist + + for node, oNode, exits in nodes_oNodes_exits: + sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) + for e in exits: + # TODO: This will probably fail if len(exits) > 1, but in which cases does that even happen? + if len(sdfg.edges_between(oNode, e)) == 0: + # no edge there yet + sdfg.add_edge(oNode, e, InterstateEdge()) + + # Prepare the SDFG for a detph analysis by 'inlining' loops. This removes the edge between the guard and the exit + # state and the edge between the last loop state and the guard, and instead places an edge between the last loop + # state and the exit state. Additionally, construct a dummy exit state and connect every state that has no outgoing + # edges to it. + + + + + + dummy_exit = sdfg.add_state('dummy_exit') + for state in sdfg.nodes(): + """ + if hasattr(state, 'condition_edge') and hasattr(state, 'is_loop_guard') and state.is_loop_guard: + # This is a loop guard. + loop_begin = state.condition_edge.dst + # Determine loop states through a depth first search from the start of the loop. 
Everything reached before + # arriving back at the loop guard is part of the loop. + # TODO: This is hacky. Loops should report the loop states directly. This may fail or behave unexpectedly + # for break/return statements inside of loops. + loop_states = set(sdutil.dfs_conditional(sdfg, sources=[loop_begin], condition=lambda _, s: s != state)) + loop_exit = None + exit_edge = None + loop_end = None + end_edge = None + for iedge in sdfg.in_edges(state): + if iedge.src in loop_states: + end_edge = iedge + loop_end = iedge.src + for oedge in sdfg.out_edges(state): + if oedge.dst not in loop_states: + loop_exit = oedge.dst + exit_edge = oedge + + if loop_exit is None or loop_end is None: + raise RuntimeError('Failed to analyze the depth of a loop starting at', state) + + sdfg.remove_edge(exit_edge) + sdfg.remove_edge(end_edge) + sdfg.add_edge(loop_end, loop_exit, InterstateEdge()) + #""" + + if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: + sdfg.add_edge(state, dummy_exit, InterstateEdge()) + + depth_map: Dict[SDFGState, sp.Expr] = {} + work_map: Dict[SDFGState, sp.Expr] = {} + state_depths[dummy_exit] = sp.sympify(0) + state_works[dummy_exit] = sp.sympify(0) + + # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to + # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions + # have been calculated. + traversal_q = deque() + traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None)) + visited = set() + while traversal_q: + state, depth, work, ie = traversal_q.popleft() + + if ie is not None: + visited.add(ie) + + n_depth = sp.simplify(depth + state_depths[state]) + n_work = sp.simplify(work + state_works[state]) + + if state in depth_map: + depth_map[state] = sp.Max(depth_map[state], n_depth) + else: + depth_map[state] = n_depth + + if state in work_map: + work_map[state] = sp.Max(work_map[state], n_work) + else: + work_map[state] = n_work + + out_edges = sdfg.out_edges(state) + if any(iedge not in visited for iedge in sdfg.in_edges(state)): + pass + else: + for oedge in out_edges: + traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) + + max_depth = depth_map[dummy_exit] + max_work = work_map[dummy_exit] + + print('SDFG', sdfg.name, 'processed') + w_d_map[get_uuid(sdfg)] = (sp.simplify(max_work), sp.simplify(max_depth)) + return sp.simplify(max_work), sp.simplify(max_depth) + + +""" +Analyze the work and depth of a scope. +This works by constructing a proxy graph of the scope and then finding the maximum depth path in that graph between +the source and sink. The proxy graph is constructed to remove any multi-edges between nodes and to remove nodes that +do not contribute to the depth. Additionally, nested scopes are summarized into single nodes. All of this is necessary +to reduce the number of possible paths in the graph, as much as possible, since they all have to be brute-force +enumerated to find the maximum depth path. +:note: This is terribly inefficient and should be improved. +:param state: The state in which the scope to analyze is contained. +:param sym_map: A dictionary mapping symbols to their values. +:param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. +:return: A tuple containing the work and depth of the scope. 
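+:note: As a small illustration (cf. the ``single_map`` test added later in this
+       patch series), a map over ``[0:N]`` whose body is the single tasklet
+       ``z = x + y`` yields work ``N`` and depth ``1``.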
+""" +def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify, entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: + + # find the work / depth of each node + # for maps and nested SDFG, we do it recursively + work = sp.sympify(0) + max_depth = sp.sympify(0) + scope_nodes = state.scope_children()[entry] + scope_exit = None if entry is None else state.exit_node(entry) + for node in scope_nodes: + # add node to map + w_d_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0)) # TODO: do we need this line? + if isinstance(node, nd.EntryNode): + # If the scope contains an entry node, we need to recursively analyze the scope of the entry node first. + # The resulting work/depth are summarized into the entry node + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, node) + # add up work for whole state, but also save work for this sub-scope scope in w_d_map + work += s_work + w_d_map[get_uuid(node, state)] = (s_work, s_depth) + elif node == scope_exit: + pass + elif isinstance(node, nd.Tasklet): + # add up work for whole state, but also save work for this node in w_d_map + t_work, t_depth = analyze_tasklet(node, state) + work += t_work + w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth)) + elif isinstance(node, nd.NestedSDFG): + # Nested SDFGs are recursively analyzed first. + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) + + # add up work for whole state, but also save work for this nested SDFG in w_d_map + work += nsdfg_work + w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) + + if entry is not None: + # If the scope being analyzed is a map, multiply the work by the number of iterations of the map. + if isinstance(entry, nd.MapEntry): + nmap: nd.Map = entry.map + range: Range = nmap.range + n_exec = range.num_elements_exact() + work = work * sp.simplify(n_exec) + else: + print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') + + + # TODO: Kinda ugly if condition... + # only do this if we even analyzed depth of tasklets + max_depth = sp.sympify(0) + if analyze_tasklet == get_tasklet_work_depth: + # Calculate the maximum depth of the scope by finding the 'deepest' path from the source to the sink. 
This is done by + # a BFS in topological order, where each node propagates its current max depth for all incoming paths + traversal_q = deque() + visited = set() + # find all starting nodes + if entry: + # the entry is the starting node + traversal_q.append((entry, sp.sympify(0), None)) + else: + for node in scope_nodes: + if len(state.in_edges(node)) == 0: + # push this node into the deque + traversal_q.append((node, sp.sympify(0), None)) + + + depth_map = {} + + while traversal_q: + node, in_depth, in_edge = traversal_q.popleft() + + if in_edge is not None: + visited.add(in_edge) + + n_depth = sp.simplify(in_depth + w_d_map[get_uuid(node, state)][1]) + + if node in depth_map: + depth_map[node] = sp.Max(depth_map[node], n_depth) + else: + depth_map[node] = n_depth + + out_edges = state.out_edges(node) + # only advance to next node, if all incoming edges have been visited or the current node is the entry (aka starting node) + # if the current node is the exit of the current scope, we stop, such that we don't leave the current scope + if (all(iedge in visited for iedge in state.in_edges(node)) or node == entry) and node != scope_exit: + # if we encounter a nested map, we must not analyze its contents (as they have already been recursively analyzed) + # hence, we continue from the outgoing edges of the corresponding exit + if isinstance(node, nd.EntryNode) and node != entry: + # get the corresponding exit note + exit_node = state.exit_node(node) + # replace out_edges with the out_edges of the scope exit node + out_edges = state.out_edges(exit_node) + for oedge in out_edges: + traversal_q.append((oedge.dst, depth_map[node], oedge)) + if len(out_edges) == 0 or node == scope_exit: + # this is an end node --> update max_depth + max_depth = sp.Max(max_depth, depth_map[node]) + + # summarise work / depth of the whole state in the dictionary + w_d_map[get_uuid(state)] = (sp.simplify(work), sp.simplify(max_depth)) + return sp.simplify(work), sp.simplify(max_depth) + +""" +Analyze the work and depth of a state. +:param state: The state to analyze. +:param sym_map: A dictionary mapping symbols to their values. +:return: A tuple containing the work and depth of the state. +""" +def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify) -> None: + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, None) + return work, depth + + +""" +Analyze the work and depth of an SDFG. +Optionally, a dictionary mapping symbols to their values can be provided to concretize the analysis. +Note that this also significantly speeds up the analysis due to sympy not having to perform the analysis symbolically. +:note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a + condition and an assignment. +:param sdfg: The SDFG to analyze. +:param sym_map: A dictionary mapping symbols to their values. +:return: A tuple containing the work and depth of the SDFG +""" +# def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, str], sym_map: Dict[str, int]) -> Dict[str, Tuple[str, str]]: +def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> Dict[str, Tuple[sp.Expr, sp.Expr]]: + # Run state propagation for all SDFGs recursively. 
This is necessary to determine the number of times each state + # will be executed, or to determine upper bounds for that number (such as in the case of branching) + print('Propagating states...') + for sd in sdfg.all_sdfgs_recursive(): + propagation.propagate_states(sd) + + # deepcopy such that original sdfg not changed + # sdfg = deepcopy(sdfg) + + # Check if the SDFG has any dynamically unbounded executions, i.e., if there are any states that have neither a + # statically known number of executions, nor an upper bound on the number of executions. Warn if this is the case. + print('Checking for dynamically unbounded executions...') + for sd in sdfg.all_sdfgs_recursive(): + if any([s.executions == 0 and s.dynamic_executions for s in sd.nodes()]): + print('WARNING: SDFG has dynamic executions. The analysis may fail in unexpected ways or be inaccurate.') + + syms_to_nonnegify = set() + # Analyze the work and depth of the SDFG. + print('Analyzing SDFG...') + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) + + # TODO: maybe do this posify more often for performance? + array_symbols = get_array_size_symbols(sdfg) + for k, (v_w, v_d) in w_d_map.items(): + v_w = posify_certain_symbols(v_w, array_symbols, syms_to_nonnegify) + v_d = posify_certain_symbols(v_d, array_symbols, syms_to_nonnegify) + w_d_map[k] = (v_w, v_d) + + + + + +def get_work(sdfg_json): + # final version loads sdfg from json + # loaded = load_sdfg_from_json(sdfg_json) + # if loaded['error'] is not None: + # return loaded['error'] + # sdfg = loaded['sdfg'] + + # for now we load simply load from a file + sdfg = SDFG.from_file(sdfg_json) + + + + # try: + work_map = {} + analyze_sdfg(sdfg, work_map, get_tasklet_work) + for k, v, in work_map.items(): + work_map[k] = (str(sp.simplify(v[0]))) + return { + 'workMap': work_map, + } + # except Exception as e: + # return { + # 'error': { + # 'message': 'Failed to analyze work depth', + # 'details': get_exception_message(e), + # }, + # } + + + +def get_work_depth(sdfg_json): + # final version loads sdfg from json + # loaded = load_sdfg_from_json(sdfg_json) + # if loaded['error'] is not None: + # return loaded['error'] + # sdfg = loaded['sdfg'] + + # for now we load simply load from a file + sdfg = SDFG.from_file(sdfg_json) + + + # try: + work_depth_map = {} + analyze_sdfg(sdfg, work_depth_map, get_tasklet_work_depth) + for k, v, in work_depth_map.items(): + work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) + return { + 'workDepthMap': work_depth_map, + } + # except Exception as e: + # return { + # 'error': { + # 'message': 'Failed to analyze work depth', + # 'details': get_exception_message(e), + # }, + # } + + + + + +################################################################################ +# Utility functions for running the analysis from the command line ############# +################################################################################ + +class keyvalue(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, dict()) + for v in values: + k, v = v.split('=') + getattr(namespace, self.dest)[k] = v + + +def main() -> None: + analyze_depth = True + + parser = argparse.ArgumentParser( + 'work_depth_analysis', + usage='python work_depth_analysis.py [-h] filename', + description='Analyze the work/depth of an SDFG.' 
+ ) + + parser.add_argument('filename', type=str, help='The SDFG file to analyze.') + parser.add_argument('--kwargs', nargs='*', help='Define symbols.', action=keyvalue) + + args = parser.parse_args() + + if not os.path.exists(args.filename): + print(args.filename, 'does not exist.') + exit() + + symbols_map = {} + if args.kwargs: + for k, v in args.kwargs.items(): + symbols_map[k] = int(v) + + # TODO: symbols_map maybe not needed + if analyze_depth: + map = get_work_depth(args.filename) + map = map['workDepthMap'] + else: + map = get_work(args.filename) + map = map['workMap'] + + # find uuid of the whole SDFG + sdfg = SDFG.from_file(args.filename) + result = map[get_uuid(sdfg)] + + + print(80*'-') + if isinstance(result, Tuple): + print("Work:\t", result[0]) + print("Depth:\t", result[1]) + else: + print("Work:\t", result) + + print(80*'-') + + + + +if __name__ == '__main__': + main() \ No newline at end of file From e00024748d1d64b5929c67d1c6d5cf0f514488e8 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Tue, 4 Jul 2023 15:42:06 +0200 Subject: [PATCH 02/18] adding tests to work_depth analysis --- .../work_depth_analysis.py | 2 +- .../work_depth_analysis/work_depth_tests.py | 247 ++++++++++++++++++ 2 files changed, 248 insertions(+), 1 deletion(-) create mode 100644 dace/sdfg/work_depth_analysis/work_depth_tests.py diff --git a/dace/sdfg/work_depth_analysis/work_depth_analysis.py b/dace/sdfg/work_depth_analysis/work_depth_analysis.py index 7f1f746b26..ef0cd0e1fb 100644 --- a/dace/sdfg/work_depth_analysis/work_depth_analysis.py +++ b/dace/sdfg/work_depth_analysis/work_depth_analysis.py @@ -16,7 +16,7 @@ import warnings from dace.sdfg.graph import Edge -from dace.sdfg.work_depth_analysis import get_uuid, get_domtree, backedges as get_backedges +from dace.sdfg.work_depth_analysis.helpers import get_uuid, get_domtree, backedges as get_backedges def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): diff --git a/dace/sdfg/work_depth_analysis/work_depth_tests.py b/dace/sdfg/work_depth_analysis/work_depth_tests.py new file mode 100644 index 0000000000..46eb36c3bf --- /dev/null +++ b/dace/sdfg/work_depth_analysis/work_depth_tests.py @@ -0,0 +1,247 @@ +import dace as dc +import numpy as np +from dace.sdfg.work_depth_analysis.work_depth_analysis import analyze_sdfg, get_tasklet_work_depth +from dace.sdfg.work_depth_analysis.helpers import get_uuid +import sympy as sp + +from dace.transformation.interstate import NestSDFG +from dace.transformation.dataflow import MapExpansion + + + + + +N = dc.symbol('N') +M = dc.symbol('M') +K = dc.symbol('K') + + +@dc.program +def single_map(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + z[:] = x + y + +@dc.program +def single_for_loop(x: dc.float64[N], y: dc.float64[N]): + for i in range(N): + x[i] += y[i] + +@dc.program +def if_else(x: dc.int64[1000], y: dc.int64[1000], z: dc.int64[1000], sum: dc.int64[1]): + if x[10] > 50: + z[:] = x + y # 1000 work, 1 depth + else: + for i in range(100): # 100 work, 100 depth + sum += x[i] + +@dc.program +def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): + if x[10] > 50: + z[:] = x + y # N work, 1 depth + else: + for i in range(K): # K work, K depth + sum += x[i] + +@dc.program +def nested_sdfg(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): + single_map(x, y, z) + single_for_loop(x, y) + +@dc.program +def nested_maps(x: dc.float64[N, M], y: dc.float64[N, M], z: dc.float64[N, M]): + z[:, :] = x + y + +@dc.program +def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): + for 
i in range(N): + for j in range(K): + x[i] += y[j] + +@dc.program +def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): + if x[10] > 50: + if x[9] > 50: + z[:] = x + y # N work, 1 depth + z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth + else: + if y[9] > 50: + for i in range(K): + sum += x[i] # K work, K depth + else: + for j in range(M): + sum += x[j] # M work, M depth + z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth + # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth + # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth + +@dc.program +def max_of_positive_symbol(x: dc.float64[N]): + if x[0] > 0: + for i in range(2*N): # work 2*N^2, depth 2*N + x += 1 + else: + for j in range(3*N): # work 3*N^2, depth 3*N + x += 1 + # total is work 3*N^2, depth 3*N without any max + + + +@dc.program +def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], + x2: dc.int64[M], y2: dc.int64[M], z2: dc.int64[M], + x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]): + if x[0] > 0: + z[:] = 2 * x + y # work 2*N, depth 2 + elif x[1] > 0: + z2[:] = 2 * x2 + y2 # work 2*M + 3, depth 5 + z2[0] += 3 + z[1] + z[2] + elif x[2] > 0: + z3[:] = 2 * x3 + y3 # work 2*K, depth 2 + elif x[3] > 0: + z[:] = 3 * x + y + 1 # work 3*N, depth 3 + # --> work= Max(3*N, 2*M, 2*K) and depth = 5 + + +@dc.program +def unbounded_while_do(x: dc.float64[N]): + while x[0] < 100: + x += 1 + +@dc.program +def unbounded_do_while(x: dc.float64[N]): + while True: + x += 1 + if x[0] >= 100: + break + + +@dc.program +def unbounded_nonnegify(x: dc.float64[N]): + while x[0] < 100: + if x[1] < 42: + x += 3*x + else: + x += x + +@dc.program +def continue_for_loop(x:dc.float64[N]): + for i in range(N): + if x[i] > 100: + continue + x += 1 + +@dc.program +def break_for_loop(x:dc.float64[N]): + for i in range(N): + if x[i] > 100: + break + x += 1 + +@dc.program +def break_while_loop(x:dc.float64[N]): + while x[0] > 10: + if x[1] > 100: + break + x += 1 + +# @dc.program +# def continue_for_loop2(x:dc.float64[N]): +# i = 0 +# while True: +# i += 1 +# if i % 2 == 0: +# continue +# x += 1 +# if x[0] > 10: +# break + + +tests = [single_map, + single_for_loop, + if_else, + if_else_sym, + nested_sdfg, + nested_maps, + nested_for_loops, + nested_if_else, + max_of_positive_symbol, + multiple_array_sizes, + unbounded_while_do, + unbounded_do_while, + unbounded_nonnegify, + continue_for_loop, + break_for_loop, + break_while_loop] +# tests = [single_map] +results = [(N, 1), + (N, N), + (1000, 100), + (sp.Max(N, K), sp.Max(K,1)), + (2*N, N + 1), + (N*M, 1), + (N*K, N*K), + (sp.Max(K, M+N, 3*N), sp.Max(K, M+1, 3)), + (3*N**2, 3*N), + (sp.Max(3*N, 2*M + 3, 2*K), 5), + (N*sp.Symbol('num_execs_0_2'), sp.Symbol('num_execs_0_2')), + (N*sp.Symbol('num_execs_0_1'), sp.Symbol('num_execs_0_1')), + (sp.Max(N*sp.Symbol('num_execs_0_5'), 2*N*sp.Symbol('num_execs_0_3')), sp.Max(sp.Symbol('num_execs_0_5'), 2*sp.Symbol('num_execs_0_3'))), + (sp.Symbol('num_execs_0_2')*N, sp.Symbol('num_execs_0_2')), + (N**2, N), + (sp.Symbol('num_execs_0_3')*N, sp.Symbol('num_execs_0_3'))] + + + + + +def test_work_depth(): + good = 0 + failed = 0 + exception = 0 + failed_tests = [] + for test, correct in zip(tests, results): + w_d_map = {} + sdfg = test.to_sdfg()#simplify=False) + if 'nested_sdfg' in test.name: + sdfg.apply_transformations(NestSDFG) + if 'nested_maps' in test.name: + sdfg.apply_transformations(MapExpansion) + # sdfg.view() + # try: + 
analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) + res = w_d_map[get_uuid(sdfg)] + + # check result + if correct == res: + good += 1 + else: + # sdfg.view() + failed += 1 + failed_tests.append(test.name) + print(f'Test {test.name} failed:') + print('correct', correct) + print('result', res) + print() + # except Exception as e: + # print(e) + # failed += 1 + # exception += 1 + + print(100*'-') + print(100*'-') + print(f'Ran {len(tests)} tests. {good} succeeded and {failed} failed ' + f'({exception} of those triggered an exception)') + print(100*'-') + print('failed tests:', failed_tests) + print(100*'-') + + + + + + + + + + +if __name__ == '__main__': + test_work_depth() \ No newline at end of file From 58e19060e9ee23c25fe9aa8d375655b070bbd332 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Tue, 4 Jul 2023 17:18:18 +0200 Subject: [PATCH 03/18] rename work depth analysis --- dace/sdfg/work_depth_analysis/work_depth.py | 749 ++++++++++++++++++ .../work_depth_analysis/work_depth_tests.py | 2 +- 2 files changed, 750 insertions(+), 1 deletion(-) create mode 100644 dace/sdfg/work_depth_analysis/work_depth.py diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py new file mode 100644 index 0000000000..ef0cd0e1fb --- /dev/null +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -0,0 +1,749 @@ +import argparse +from collections import deque +from dace.sdfg import nodes as nd, propagation, InterstateEdge, utils as sdutil +from dace import SDFG, SDFGState, dtypes +from dace.subsets import Range +from typing import Tuple, Dict +import os +import sympy as sp +import networkx as nx +from copy import deepcopy +from dace.libraries.blas import MatMul, Transpose +from dace.libraries.standard import Reduce +from dace.symbolic import pystr_to_symbolic +import ast +import astunparse +import warnings +from dace.sdfg.graph import Edge + +from dace.sdfg.work_depth_analysis.helpers import get_uuid, get_domtree, backedges as get_backedges + + +def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): + # preparation phase: compute dominators, backedges etc + for node in sdfg_nx.nodes(): + if sdfg_nx.in_degree(node) == 0: + start = node + break + if start is None: + raise ValueError('No start node could be determined') + + # sdfg can have multiple end nodes --> not good for postDomTree + # --> add a new end node + artificial_end_node = 'artificial_end_node' + sdfg_nx.add_node(artificial_end_node) + for node in sdfg_nx.nodes(): + if sdfg_nx.out_degree(node) == 0 and node != artificial_end_node: + # this is an end node of the sdfg + sdfg_nx.add_edge(node, artificial_end_node) + + # sanity check: + if sdfg_nx.in_degree(artificial_end_node) == 0: + raise ValueError('No end node could be determined in the SDFG') + + + + iDoms = nx.immediate_dominators(sdfg_nx, start) + allDom, domTree = get_domtree(sdfg_nx, start, iDoms) + + reversed_sdfg_nx = sdfg_nx.reverse() + iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node) + allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) + + backedges = get_backedges(sdfg_nx, start) + backedgesDstDict = {} + for be in backedges: + if be[1] in backedgesDstDict: + backedgesDstDict[be[1]].add(be) + else: + backedgesDstDict[be[1]] = set([be]) + + + nodes_oNodes_exits = [] + + # iterate over all nodes + for node in sdfg_nx.nodes(): + # does any backedge end in node + if node in backedgesDstDict: + inc_backedges = backedgesDstDict[node] + + + # gather all successors of node that are not 
reached by backedges + successors = [] + for edge in sdfg_nx.out_edges(node): + if not edge in backedges: + successors.append(edge[1]) + + + # if len(inc_backedges) > 1: + # raise ValueError('node has multiple incoming backedges...') + # instead: if multiple incoming backedges, do the below for each backedge + for be in inc_backedges: + + + # since node has an incoming backedge, it is either a loop guard or loop tail + # oNode will exactly be the other thing + oNode = be[0] + exitCandidates = set() + for succ in successors: + if succ != oNode and oNode not in allDom[succ]: + exitCandidates.add(succ) + for succ in sdfg_nx.successors(oNode): + if succ != node: + exitCandidates.add(succ) + + if len(exitCandidates) == 0: + raise ValueError('failed to find any exit nodes') + elif len(exitCandidates) > 1: + # // Find the exit candidate that sits highest up in the + # // postdominator tree (i.e., has the lowest level). + # // That must be the exit node (it must post-dominate) + # // everything inside the loop. If there are multiple + # // candidates on the lowest level (i.e., disjoint set of + # // postdominated nodes), there are multiple exit paths, + # // and they all share one level. + cand = exitCandidates.pop() + minSet = set([cand]) + minLevel = nx.get_node_attributes(postDomTree, 'level')[cand] + for cand in exitCandidates: + curr_level = nx.get_node_attributes(postDomTree, 'level')[cand] + if curr_level < minLevel: + # new minimum found + minLevel = curr_level + minSet.clear() + minSet.add(cand) + elif curr_level == minLevel: + # add cand to curr set + minSet.add(cand) + + if len(minSet) > 0: + exitCandidates = minSet + else: + raise ValueError('failed to find exit minSet') + + # now we have a triple (node, oNode, exitCandidates) + nodes_oNodes_exits.append((node, oNode, exitCandidates)) + + return nodes_oNodes_exits + + + +def get_array_size_symbols(sdfg): + symbols = set() + for _, _, arr in sdfg.arrays_recursive(): + for s in arr.shape: + if isinstance(s, sp.Symbol): + symbols.add(s) + return symbols + +def posify_certain_symbols(expr, syms_to_posify, syms_to_nonnegify): + expr = sp.sympify(expr) + nonneg = {s: sp.Dummy(s.name, nonnegative=True, **s.assumptions0) + for s in syms_to_nonnegify if s.is_nonnegative is None} + pos = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) + for s in syms_to_posify if s.is_positive is None} + # merge the two dicts into reps + reps = {**nonneg, **pos} + expr = expr.subs(reps) + return expr.subs({r: s for s, r in reps.items()}) + +def symeval(val, symbols): + first_replacement = { + pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) + for k in symbols.keys() + } + second_replacement = { + pystr_to_symbolic('__REPLSYM_' + k): v + for k, v in symbols.items() + } + return val.subs(first_replacement).subs(second_replacement) + +def count_matmul(node, symbols, state): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') + C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') + result = 2 # Multiply, add + # Batch + if len(C_memlet.data.subset) == 3: + result *= symeval(C_memlet.data.subset.size()[0], symbols) + # M*N + result *= symeval(C_memlet.data.subset.size()[-2], symbols) + result *= symeval(C_memlet.data.subset.size()[-1], symbols) + # K + result *= symeval(A_memlet.data.subset.size()[-1], symbols) + return result + + +def count_reduce(node, symbols, state): + result = 0 + if node.wcr is not None: + result += 
count_arithmetic_ops_code(node.wcr) + in_memlet = None + in_edges = state.in_edges(node) + if in_edges is not None and len(in_edges) == 1: + in_memlet = in_edges[0] + if in_memlet is not None and in_memlet.data.volume is not None: + result *= in_memlet.data.volume + else: + result = 0 + return result + +bigo = sp.Function('bigo') +PYFUNC_TO_ARITHMETICS = { + 'float': 0, + 'math.exp': 1, + 'math.tanh': 1, + 'math.sqrt': 1, + 'min': 0, + 'max': 0, + 'ceiling': 0, + 'floor': 0, +} +LIBNODES_TO_ARITHMETICS = { + MatMul: count_matmul, + Transpose: lambda *args: 0, + Reduce: count_reduce, +} + + +class ArithmeticCounter(ast.NodeVisitor): + + def __init__(self): + self.count = 0 + + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print('WARNING: Unrecognized python function "%s"' % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + +def count_arithmetic_ops_code(code): + ctr = ArithmeticCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + +class DepthCounter(ast.NodeVisitor): + + def __init__(self): + self.count = 0 + + # TODO: if we have a tasklet like _out = 2 * _in + 500 + # will this then have depth of 2? or not because of instruction level parallelism? + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print('WARNING: Unrecognized python function "%s"' % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + +def count_depth_code(code): + ctr = DepthCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + +def tasklet_work(tasklet_node, state): + if tasklet_node.code.language == dtypes.Language.CPP: + for oedge in state.out_edges(tasklet_node): + return bigo(oedge.data.num_accesses) + + elif tasklet_node.code.language == dtypes.Language.Python: + return count_arithmetic_ops_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Work of tasklets only properly analyzed for Python or CPP. 
For all other ' + 'languages work = 1 will be counted for each tasklet.') + return 1 + +def tasklet_depth(tasklet_node, state): + # if tasklet_node.code.language == dtypes.Language.CPP: + # for oedge in state.out_edges(tasklet_node): + # return bigo(oedge.data.num_accesses) + + if tasklet_node.code.language == dtypes.Language.Python: + return count_depth_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Depth of tasklets only properly analyzed for Python code. For all other ' + 'languages depth = 1 will be counted for each tasklet.') + return 1 + +def get_tasklet_work(node, state): + return tasklet_work(node, state), -1 + +def get_tasklet_work_depth(node, state): + return tasklet_work(node, state), tasklet_depth(node, state) + +def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, syms_to_nonnegify) -> None: + print('Analyzing work and depth of SDFG', sdfg.name) + print('SDFG has', len(sdfg.nodes()), 'states') + print('Calculating work and depth for all states individually...') + + # First determine the work and depth of each state individually. + # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number + # of times the state will be executed. + state_depths: Dict[SDFGState, sp.Expr] = {} + state_works: Dict[SDFGState, sp.Expr] = {} + for state in sdfg.nodes(): + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify) + if state.executions == 0:# or state.executions == sp.zoo: + print('State executions must be statically known exactly or with an upper bound. Offender:', state) + new_symbol = sp.Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(state)}') + state.executions = new_symbol + syms_to_nonnegify |= {new_symbol} + state_works[state] = state_work * state.executions + state_depths[state] = state_depth * state.executions + w_d_map[get_uuid(state)] = (sp.simplify(state_work * state.executions), sp.simplify(state_depth * state.executions)) + + print('Calculating work and depth of the SDFG...') + + + nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) + print(nodes_oNodes_exits) + # Now we need to go over each triple (node, oNode, exits) + # for each triple, we + # - remove edge (oNode, node), i.e. the backward edge + # - for all exits e, add edge (oNode, e). This edge may already exist + + for node, oNode, exits in nodes_oNodes_exits: + sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) + for e in exits: + # TODO: This will probably fail if len(exits) > 1, but in which cases does that even happen? + if len(sdfg.edges_between(oNode, e)) == 0: + # no edge there yet + sdfg.add_edge(oNode, e, InterstateEdge()) + + # Prepare the SDFG for a detph analysis by 'inlining' loops. This removes the edge between the guard and the exit + # state and the edge between the last loop state and the guard, and instead places an edge between the last loop + # state and the exit state. Additionally, construct a dummy exit state and connect every state that has no outgoing + # edges to it. + + + + + + dummy_exit = sdfg.add_state('dummy_exit') + for state in sdfg.nodes(): + """ + if hasattr(state, 'condition_edge') and hasattr(state, 'is_loop_guard') and state.is_loop_guard: + # This is a loop guard. + loop_begin = state.condition_edge.dst + # Determine loop states through a depth first search from the start of the loop. 
Everything reached before + # arriving back at the loop guard is part of the loop. + # TODO: This is hacky. Loops should report the loop states directly. This may fail or behave unexpectedly + # for break/return statements inside of loops. + loop_states = set(sdutil.dfs_conditional(sdfg, sources=[loop_begin], condition=lambda _, s: s != state)) + loop_exit = None + exit_edge = None + loop_end = None + end_edge = None + for iedge in sdfg.in_edges(state): + if iedge.src in loop_states: + end_edge = iedge + loop_end = iedge.src + for oedge in sdfg.out_edges(state): + if oedge.dst not in loop_states: + loop_exit = oedge.dst + exit_edge = oedge + + if loop_exit is None or loop_end is None: + raise RuntimeError('Failed to analyze the depth of a loop starting at', state) + + sdfg.remove_edge(exit_edge) + sdfg.remove_edge(end_edge) + sdfg.add_edge(loop_end, loop_exit, InterstateEdge()) + #""" + + if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: + sdfg.add_edge(state, dummy_exit, InterstateEdge()) + + depth_map: Dict[SDFGState, sp.Expr] = {} + work_map: Dict[SDFGState, sp.Expr] = {} + state_depths[dummy_exit] = sp.sympify(0) + state_works[dummy_exit] = sp.sympify(0) + + # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to + # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions + # have been calculated. + traversal_q = deque() + traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None)) + visited = set() + while traversal_q: + state, depth, work, ie = traversal_q.popleft() + + if ie is not None: + visited.add(ie) + + n_depth = sp.simplify(depth + state_depths[state]) + n_work = sp.simplify(work + state_works[state]) + + if state in depth_map: + depth_map[state] = sp.Max(depth_map[state], n_depth) + else: + depth_map[state] = n_depth + + if state in work_map: + work_map[state] = sp.Max(work_map[state], n_work) + else: + work_map[state] = n_work + + out_edges = sdfg.out_edges(state) + if any(iedge not in visited for iedge in sdfg.in_edges(state)): + pass + else: + for oedge in out_edges: + traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) + + max_depth = depth_map[dummy_exit] + max_work = work_map[dummy_exit] + + print('SDFG', sdfg.name, 'processed') + w_d_map[get_uuid(sdfg)] = (sp.simplify(max_work), sp.simplify(max_depth)) + return sp.simplify(max_work), sp.simplify(max_depth) + + +""" +Analyze the work and depth of a scope. +This works by constructing a proxy graph of the scope and then finding the maximum depth path in that graph between +the source and sink. The proxy graph is constructed to remove any multi-edges between nodes and to remove nodes that +do not contribute to the depth. Additionally, nested scopes are summarized into single nodes. All of this is necessary +to reduce the number of possible paths in the graph, as much as possible, since they all have to be brute-force +enumerated to find the maximum depth path. +:note: This is terribly inefficient and should be improved. +:param state: The state in which the scope to analyze is contained. +:param sym_map: A dictionary mapping symbols to their values. +:param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. +:return: A tuple containing the work and depth of the scope. 
+""" +def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify, entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: + + # find the work / depth of each node + # for maps and nested SDFG, we do it recursively + work = sp.sympify(0) + max_depth = sp.sympify(0) + scope_nodes = state.scope_children()[entry] + scope_exit = None if entry is None else state.exit_node(entry) + for node in scope_nodes: + # add node to map + w_d_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0)) # TODO: do we need this line? + if isinstance(node, nd.EntryNode): + # If the scope contains an entry node, we need to recursively analyze the scope of the entry node first. + # The resulting work/depth are summarized into the entry node + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, node) + # add up work for whole state, but also save work for this sub-scope scope in w_d_map + work += s_work + w_d_map[get_uuid(node, state)] = (s_work, s_depth) + elif node == scope_exit: + pass + elif isinstance(node, nd.Tasklet): + # add up work for whole state, but also save work for this node in w_d_map + t_work, t_depth = analyze_tasklet(node, state) + work += t_work + w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth)) + elif isinstance(node, nd.NestedSDFG): + # Nested SDFGs are recursively analyzed first. + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) + + # add up work for whole state, but also save work for this nested SDFG in w_d_map + work += nsdfg_work + w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) + + if entry is not None: + # If the scope being analyzed is a map, multiply the work by the number of iterations of the map. + if isinstance(entry, nd.MapEntry): + nmap: nd.Map = entry.map + range: Range = nmap.range + n_exec = range.num_elements_exact() + work = work * sp.simplify(n_exec) + else: + print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') + + + # TODO: Kinda ugly if condition... + # only do this if we even analyzed depth of tasklets + max_depth = sp.sympify(0) + if analyze_tasklet == get_tasklet_work_depth: + # Calculate the maximum depth of the scope by finding the 'deepest' path from the source to the sink. 
This is done by + # a BFS in topological order, where each node propagates its current max depth for all incoming paths + traversal_q = deque() + visited = set() + # find all starting nodes + if entry: + # the entry is the starting node + traversal_q.append((entry, sp.sympify(0), None)) + else: + for node in scope_nodes: + if len(state.in_edges(node)) == 0: + # push this node into the deque + traversal_q.append((node, sp.sympify(0), None)) + + + depth_map = {} + + while traversal_q: + node, in_depth, in_edge = traversal_q.popleft() + + if in_edge is not None: + visited.add(in_edge) + + n_depth = sp.simplify(in_depth + w_d_map[get_uuid(node, state)][1]) + + if node in depth_map: + depth_map[node] = sp.Max(depth_map[node], n_depth) + else: + depth_map[node] = n_depth + + out_edges = state.out_edges(node) + # only advance to next node, if all incoming edges have been visited or the current node is the entry (aka starting node) + # if the current node is the exit of the current scope, we stop, such that we don't leave the current scope + if (all(iedge in visited for iedge in state.in_edges(node)) or node == entry) and node != scope_exit: + # if we encounter a nested map, we must not analyze its contents (as they have already been recursively analyzed) + # hence, we continue from the outgoing edges of the corresponding exit + if isinstance(node, nd.EntryNode) and node != entry: + # get the corresponding exit note + exit_node = state.exit_node(node) + # replace out_edges with the out_edges of the scope exit node + out_edges = state.out_edges(exit_node) + for oedge in out_edges: + traversal_q.append((oedge.dst, depth_map[node], oedge)) + if len(out_edges) == 0 or node == scope_exit: + # this is an end node --> update max_depth + max_depth = sp.Max(max_depth, depth_map[node]) + + # summarise work / depth of the whole state in the dictionary + w_d_map[get_uuid(state)] = (sp.simplify(work), sp.simplify(max_depth)) + return sp.simplify(work), sp.simplify(max_depth) + +""" +Analyze the work and depth of a state. +:param state: The state to analyze. +:param sym_map: A dictionary mapping symbols to their values. +:return: A tuple containing the work and depth of the state. +""" +def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify) -> None: + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, None) + return work, depth + + +""" +Analyze the work and depth of an SDFG. +Optionally, a dictionary mapping symbols to their values can be provided to concretize the analysis. +Note that this also significantly speeds up the analysis due to sympy not having to perform the analysis symbolically. +:note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a + condition and an assignment. +:param sdfg: The SDFG to analyze. +:param sym_map: A dictionary mapping symbols to their values. +:return: A tuple containing the work and depth of the SDFG +""" +# def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, str], sym_map: Dict[str, int]) -> Dict[str, Tuple[str, str]]: +def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> Dict[str, Tuple[sp.Expr, sp.Expr]]: + # Run state propagation for all SDFGs recursively. 
This is necessary to determine the number of times each state + # will be executed, or to determine upper bounds for that number (such as in the case of branching) + print('Propagating states...') + for sd in sdfg.all_sdfgs_recursive(): + propagation.propagate_states(sd) + + # deepcopy such that original sdfg not changed + # sdfg = deepcopy(sdfg) + + # Check if the SDFG has any dynamically unbounded executions, i.e., if there are any states that have neither a + # statically known number of executions, nor an upper bound on the number of executions. Warn if this is the case. + print('Checking for dynamically unbounded executions...') + for sd in sdfg.all_sdfgs_recursive(): + if any([s.executions == 0 and s.dynamic_executions for s in sd.nodes()]): + print('WARNING: SDFG has dynamic executions. The analysis may fail in unexpected ways or be inaccurate.') + + syms_to_nonnegify = set() + # Analyze the work and depth of the SDFG. + print('Analyzing SDFG...') + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) + + # TODO: maybe do this posify more often for performance? + array_symbols = get_array_size_symbols(sdfg) + for k, (v_w, v_d) in w_d_map.items(): + v_w = posify_certain_symbols(v_w, array_symbols, syms_to_nonnegify) + v_d = posify_certain_symbols(v_d, array_symbols, syms_to_nonnegify) + w_d_map[k] = (v_w, v_d) + + + + + +def get_work(sdfg_json): + # final version loads sdfg from json + # loaded = load_sdfg_from_json(sdfg_json) + # if loaded['error'] is not None: + # return loaded['error'] + # sdfg = loaded['sdfg'] + + # for now we load simply load from a file + sdfg = SDFG.from_file(sdfg_json) + + + + # try: + work_map = {} + analyze_sdfg(sdfg, work_map, get_tasklet_work) + for k, v, in work_map.items(): + work_map[k] = (str(sp.simplify(v[0]))) + return { + 'workMap': work_map, + } + # except Exception as e: + # return { + # 'error': { + # 'message': 'Failed to analyze work depth', + # 'details': get_exception_message(e), + # }, + # } + + + +def get_work_depth(sdfg_json): + # final version loads sdfg from json + # loaded = load_sdfg_from_json(sdfg_json) + # if loaded['error'] is not None: + # return loaded['error'] + # sdfg = loaded['sdfg'] + + # for now we load simply load from a file + sdfg = SDFG.from_file(sdfg_json) + + + # try: + work_depth_map = {} + analyze_sdfg(sdfg, work_depth_map, get_tasklet_work_depth) + for k, v, in work_depth_map.items(): + work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) + return { + 'workDepthMap': work_depth_map, + } + # except Exception as e: + # return { + # 'error': { + # 'message': 'Failed to analyze work depth', + # 'details': get_exception_message(e), + # }, + # } + + + + + +################################################################################ +# Utility functions for running the analysis from the command line ############# +################################################################################ + +class keyvalue(argparse.Action): + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, dict()) + for v in values: + k, v = v.split('=') + getattr(namespace, self.dest)[k] = v + + +def main() -> None: + analyze_depth = True + + parser = argparse.ArgumentParser( + 'work_depth_analysis', + usage='python work_depth_analysis.py [-h] filename', + description='Analyze the work/depth of an SDFG.' 
+ ) + + parser.add_argument('filename', type=str, help='The SDFG file to analyze.') + parser.add_argument('--kwargs', nargs='*', help='Define symbols.', action=keyvalue) + + args = parser.parse_args() + + if not os.path.exists(args.filename): + print(args.filename, 'does not exist.') + exit() + + symbols_map = {} + if args.kwargs: + for k, v in args.kwargs.items(): + symbols_map[k] = int(v) + + # TODO: symbols_map maybe not needed + if analyze_depth: + map = get_work_depth(args.filename) + map = map['workDepthMap'] + else: + map = get_work(args.filename) + map = map['workMap'] + + # find uuid of the whole SDFG + sdfg = SDFG.from_file(args.filename) + result = map[get_uuid(sdfg)] + + + print(80*'-') + if isinstance(result, Tuple): + print("Work:\t", result[0]) + print("Depth:\t", result[1]) + else: + print("Work:\t", result) + + print(80*'-') + + + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/dace/sdfg/work_depth_analysis/work_depth_tests.py b/dace/sdfg/work_depth_analysis/work_depth_tests.py index 46eb36c3bf..e233280143 100644 --- a/dace/sdfg/work_depth_analysis/work_depth_tests.py +++ b/dace/sdfg/work_depth_analysis/work_depth_tests.py @@ -1,6 +1,6 @@ import dace as dc import numpy as np -from dace.sdfg.work_depth_analysis.work_depth_analysis import analyze_sdfg, get_tasklet_work_depth +from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth from dace.sdfg.work_depth_analysis.helpers import get_uuid import sympy as sp From 0b3fdeacfb9778633fd4a4d4d7d0fea7061f6f80 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Tue, 11 Jul 2023 14:07:32 +0200 Subject: [PATCH 04/18] todos added --- dace/sdfg/work_depth_analysis/work_depth.py | 26 +- .../work_depth_analysis.py | 749 ------------------ .../work_depth_analysis/work_depth_tests.py | 2 +- 3 files changed, 17 insertions(+), 760 deletions(-) delete mode 100644 dace/sdfg/work_depth_analysis/work_depth_analysis.py diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index ef0cd0e1fb..2defcd17c0 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -198,7 +198,9 @@ def count_reduce(node, symbols, state): 'max': 0, 'ceiling': 0, 'floor': 0, -} + # TODO: what about: "cos", "sin", "sqrt", "atan2" (source: npbench's arc_distance), "abs" (mandelbrot1), "exp" (mlp) + # tanh (go_fast) +} LIBNODES_TO_ARITHMETICS = { MatMul: count_matmul, Transpose: lambda *args: 0, @@ -225,6 +227,7 @@ def visit_UnaryOp(self, node): def visit_Call(self, node): fname = astunparse.unparse(node.func)[:-1] if fname not in PYFUNC_TO_ARITHMETICS: + # TODO: why do we get this warning? WARNING: Unrecognized python function "dace.float64" (source: npbench's azimint_hist) print('WARNING: Unrecognized python function "%s"' % fname) return self.generic_visit(node) self.count += PYFUNC_TO_ARITHMETICS[fname] @@ -317,6 +320,7 @@ def tasklet_depth(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.Python: return count_depth_code(tasklet_node.code.code) else: + # TODO: improve this # other languages not implemented, count whole tasklet as work of 1 warnings.warn('Depth of tasklets only properly analyzed for Python code. 
For all other ' 'languages depth = 1 will be counted for each tasklet.') @@ -329,9 +333,9 @@ def get_tasklet_work_depth(node, state): return tasklet_work(node, state), tasklet_depth(node, state) def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, syms_to_nonnegify) -> None: - print('Analyzing work and depth of SDFG', sdfg.name) - print('SDFG has', len(sdfg.nodes()), 'states') - print('Calculating work and depth for all states individually...') + # print('Analyzing work and depth of SDFG', sdfg.name) + # print('SDFG has', len(sdfg.nodes()), 'states') + # print('Calculating work and depth for all states individually...') # First determine the work and depth of each state individually. # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number @@ -341,7 +345,7 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana for state in sdfg.nodes(): state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify) if state.executions == 0:# or state.executions == sp.zoo: - print('State executions must be statically known exactly or with an upper bound. Offender:', state) + # print('State executions must be statically known exactly or with an upper bound. Offender:', state) new_symbol = sp.Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(state)}') state.executions = new_symbol syms_to_nonnegify |= {new_symbol} @@ -349,7 +353,7 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths[state] = state_depth * state.executions w_d_map[get_uuid(state)] = (sp.simplify(state_work * state.executions), sp.simplify(state_depth * state.executions)) - print('Calculating work and depth of the SDFG...') + # print('Calculating work and depth of the SDFG...') nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) @@ -593,7 +597,7 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> Dict[str, Tuple[sp.Expr, sp.Expr]]: # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state # will be executed, or to determine upper bounds for that number (such as in the case of branching) - print('Propagating states...') + # print('Propagating states...') for sd in sdfg.all_sdfgs_recursive(): propagation.propagate_states(sd) @@ -602,14 +606,15 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> Di # Check if the SDFG has any dynamically unbounded executions, i.e., if there are any states that have neither a # statically known number of executions, nor an upper bound on the number of executions. Warn if this is the case. - print('Checking for dynamically unbounded executions...') + # print('Checking for dynamically unbounded executions...') for sd in sdfg.all_sdfgs_recursive(): if any([s.executions == 0 and s.dynamic_executions for s in sd.nodes()]): - print('WARNING: SDFG has dynamic executions. The analysis may fail in unexpected ways or be inaccurate.') + pass + # print('WARNING: SDFG has dynamic executions. The analysis may fail in unexpected ways or be inaccurate.') syms_to_nonnegify = set() # Analyze the work and depth of the SDFG. - print('Analyzing SDFG...') + # print('Analyzing SDFG...') sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) # TODO: maybe do this posify more often for performance? 
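The posify step mentioned here (see posify_certain_symbols later in this patch series) relies on sympy's assumption mechanism: symbols are temporarily replaced by positive Dummy symbols so that expressions such as Max(N, 0) can collapse, and the original symbols are substituted back afterwards. A minimal sketch of that idea, with a made-up symbol N and expression:

import sympy as sp

# Made-up example: without assumptions, sympy cannot resolve Max(N, 0).
N = sp.Symbol('N')
expr = sp.Max(N, 0) + 2 * N

# Swap in a positive dummy, simplify, then swap the original symbol back in.
rep = {N: sp.Dummy(N.name, positive=True)}
simplified = sp.simplify(expr.subs(rep))
print(simplified.subs({d: s for s, d in rep.items()}))  # 3*N
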
@@ -618,6 +623,7 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> Di v_w = posify_certain_symbols(v_w, array_symbols, syms_to_nonnegify) v_d = posify_certain_symbols(v_d, array_symbols, syms_to_nonnegify) w_d_map[k] = (v_w, v_d) + sdfg.view() diff --git a/dace/sdfg/work_depth_analysis/work_depth_analysis.py b/dace/sdfg/work_depth_analysis/work_depth_analysis.py deleted file mode 100644 index ef0cd0e1fb..0000000000 --- a/dace/sdfg/work_depth_analysis/work_depth_analysis.py +++ /dev/null @@ -1,749 +0,0 @@ -import argparse -from collections import deque -from dace.sdfg import nodes as nd, propagation, InterstateEdge, utils as sdutil -from dace import SDFG, SDFGState, dtypes -from dace.subsets import Range -from typing import Tuple, Dict -import os -import sympy as sp -import networkx as nx -from copy import deepcopy -from dace.libraries.blas import MatMul, Transpose -from dace.libraries.standard import Reduce -from dace.symbolic import pystr_to_symbolic -import ast -import astunparse -import warnings -from dace.sdfg.graph import Edge - -from dace.sdfg.work_depth_analysis.helpers import get_uuid, get_domtree, backedges as get_backedges - - -def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): - # preparation phase: compute dominators, backedges etc - for node in sdfg_nx.nodes(): - if sdfg_nx.in_degree(node) == 0: - start = node - break - if start is None: - raise ValueError('No start node could be determined') - - # sdfg can have multiple end nodes --> not good for postDomTree - # --> add a new end node - artificial_end_node = 'artificial_end_node' - sdfg_nx.add_node(artificial_end_node) - for node in sdfg_nx.nodes(): - if sdfg_nx.out_degree(node) == 0 and node != artificial_end_node: - # this is an end node of the sdfg - sdfg_nx.add_edge(node, artificial_end_node) - - # sanity check: - if sdfg_nx.in_degree(artificial_end_node) == 0: - raise ValueError('No end node could be determined in the SDFG') - - - - iDoms = nx.immediate_dominators(sdfg_nx, start) - allDom, domTree = get_domtree(sdfg_nx, start, iDoms) - - reversed_sdfg_nx = sdfg_nx.reverse() - iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node) - allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) - - backedges = get_backedges(sdfg_nx, start) - backedgesDstDict = {} - for be in backedges: - if be[1] in backedgesDstDict: - backedgesDstDict[be[1]].add(be) - else: - backedgesDstDict[be[1]] = set([be]) - - - nodes_oNodes_exits = [] - - # iterate over all nodes - for node in sdfg_nx.nodes(): - # does any backedge end in node - if node in backedgesDstDict: - inc_backedges = backedgesDstDict[node] - - - # gather all successors of node that are not reached by backedges - successors = [] - for edge in sdfg_nx.out_edges(node): - if not edge in backedges: - successors.append(edge[1]) - - - # if len(inc_backedges) > 1: - # raise ValueError('node has multiple incoming backedges...') - # instead: if multiple incoming backedges, do the below for each backedge - for be in inc_backedges: - - - # since node has an incoming backedge, it is either a loop guard or loop tail - # oNode will exactly be the other thing - oNode = be[0] - exitCandidates = set() - for succ in successors: - if succ != oNode and oNode not in allDom[succ]: - exitCandidates.add(succ) - for succ in sdfg_nx.successors(oNode): - if succ != node: - exitCandidates.add(succ) - - if len(exitCandidates) == 0: - raise ValueError('failed to find any exit nodes') - elif len(exitCandidates) > 1: - # // 
Find the exit candidate that sits highest up in the - # // postdominator tree (i.e., has the lowest level). - # // That must be the exit node (it must post-dominate) - # // everything inside the loop. If there are multiple - # // candidates on the lowest level (i.e., disjoint set of - # // postdominated nodes), there are multiple exit paths, - # // and they all share one level. - cand = exitCandidates.pop() - minSet = set([cand]) - minLevel = nx.get_node_attributes(postDomTree, 'level')[cand] - for cand in exitCandidates: - curr_level = nx.get_node_attributes(postDomTree, 'level')[cand] - if curr_level < minLevel: - # new minimum found - minLevel = curr_level - minSet.clear() - minSet.add(cand) - elif curr_level == minLevel: - # add cand to curr set - minSet.add(cand) - - if len(minSet) > 0: - exitCandidates = minSet - else: - raise ValueError('failed to find exit minSet') - - # now we have a triple (node, oNode, exitCandidates) - nodes_oNodes_exits.append((node, oNode, exitCandidates)) - - return nodes_oNodes_exits - - - -def get_array_size_symbols(sdfg): - symbols = set() - for _, _, arr in sdfg.arrays_recursive(): - for s in arr.shape: - if isinstance(s, sp.Symbol): - symbols.add(s) - return symbols - -def posify_certain_symbols(expr, syms_to_posify, syms_to_nonnegify): - expr = sp.sympify(expr) - nonneg = {s: sp.Dummy(s.name, nonnegative=True, **s.assumptions0) - for s in syms_to_nonnegify if s.is_nonnegative is None} - pos = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) - for s in syms_to_posify if s.is_positive is None} - # merge the two dicts into reps - reps = {**nonneg, **pos} - expr = expr.subs(reps) - return expr.subs({r: s for s, r in reps.items()}) - -def symeval(val, symbols): - first_replacement = { - pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) - for k in symbols.keys() - } - second_replacement = { - pystr_to_symbolic('__REPLSYM_' + k): v - for k, v in symbols.items() - } - return val.subs(first_replacement).subs(second_replacement) - -def count_matmul(node, symbols, state): - A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') - B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') - C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') - result = 2 # Multiply, add - # Batch - if len(C_memlet.data.subset) == 3: - result *= symeval(C_memlet.data.subset.size()[0], symbols) - # M*N - result *= symeval(C_memlet.data.subset.size()[-2], symbols) - result *= symeval(C_memlet.data.subset.size()[-1], symbols) - # K - result *= symeval(A_memlet.data.subset.size()[-1], symbols) - return result - - -def count_reduce(node, symbols, state): - result = 0 - if node.wcr is not None: - result += count_arithmetic_ops_code(node.wcr) - in_memlet = None - in_edges = state.in_edges(node) - if in_edges is not None and len(in_edges) == 1: - in_memlet = in_edges[0] - if in_memlet is not None and in_memlet.data.volume is not None: - result *= in_memlet.data.volume - else: - result = 0 - return result - -bigo = sp.Function('bigo') -PYFUNC_TO_ARITHMETICS = { - 'float': 0, - 'math.exp': 1, - 'math.tanh': 1, - 'math.sqrt': 1, - 'min': 0, - 'max': 0, - 'ceiling': 0, - 'floor': 0, -} -LIBNODES_TO_ARITHMETICS = { - MatMul: count_matmul, - Transpose: lambda *args: 0, - Reduce: count_reduce, -} - - -class ArithmeticCounter(ast.NodeVisitor): - - def __init__(self): - self.count = 0 - - def visit_BinOp(self, node): - if isinstance(node.op, ast.MatMult): - raise NotImplementedError('MatMult op count requires shape ' - 'inference') - 
self.count += 1 - return self.generic_visit(node) - - def visit_UnaryOp(self, node): - self.count += 1 - return self.generic_visit(node) - - def visit_Call(self, node): - fname = astunparse.unparse(node.func)[:-1] - if fname not in PYFUNC_TO_ARITHMETICS: - print('WARNING: Unrecognized python function "%s"' % fname) - return self.generic_visit(node) - self.count += PYFUNC_TO_ARITHMETICS[fname] - return self.generic_visit(node) - - def visit_AugAssign(self, node): - return self.visit_BinOp(node) - - def visit_For(self, node): - raise NotImplementedError - - def visit_While(self, node): - raise NotImplementedError - -def count_arithmetic_ops_code(code): - ctr = ArithmeticCounter() - if isinstance(code, (tuple, list)): - for stmt in code: - ctr.visit(stmt) - elif isinstance(code, str): - ctr.visit(ast.parse(code)) - else: - ctr.visit(code) - return ctr.count - -class DepthCounter(ast.NodeVisitor): - - def __init__(self): - self.count = 0 - - # TODO: if we have a tasklet like _out = 2 * _in + 500 - # will this then have depth of 2? or not because of instruction level parallelism? - def visit_BinOp(self, node): - if isinstance(node.op, ast.MatMult): - raise NotImplementedError('MatMult op count requires shape ' - 'inference') - self.count += 1 - return self.generic_visit(node) - - def visit_UnaryOp(self, node): - self.count += 1 - return self.generic_visit(node) - - def visit_Call(self, node): - fname = astunparse.unparse(node.func)[:-1] - if fname not in PYFUNC_TO_ARITHMETICS: - print('WARNING: Unrecognized python function "%s"' % fname) - return self.generic_visit(node) - self.count += PYFUNC_TO_ARITHMETICS[fname] - return self.generic_visit(node) - - def visit_AugAssign(self, node): - return self.visit_BinOp(node) - - def visit_For(self, node): - raise NotImplementedError - - def visit_While(self, node): - raise NotImplementedError - -def count_depth_code(code): - ctr = DepthCounter() - if isinstance(code, (tuple, list)): - for stmt in code: - ctr.visit(stmt) - elif isinstance(code, str): - ctr.visit(ast.parse(code)) - else: - ctr.visit(code) - return ctr.count - -def tasklet_work(tasklet_node, state): - if tasklet_node.code.language == dtypes.Language.CPP: - for oedge in state.out_edges(tasklet_node): - return bigo(oedge.data.num_accesses) - - elif tasklet_node.code.language == dtypes.Language.Python: - return count_arithmetic_ops_code(tasklet_node.code.code) - else: - # other languages not implemented, count whole tasklet as work of 1 - warnings.warn('Work of tasklets only properly analyzed for Python or CPP. For all other ' - 'languages work = 1 will be counted for each tasklet.') - return 1 - -def tasklet_depth(tasklet_node, state): - # if tasklet_node.code.language == dtypes.Language.CPP: - # for oedge in state.out_edges(tasklet_node): - # return bigo(oedge.data.num_accesses) - - if tasklet_node.code.language == dtypes.Language.Python: - return count_depth_code(tasklet_node.code.code) - else: - # other languages not implemented, count whole tasklet as work of 1 - warnings.warn('Depth of tasklets only properly analyzed for Python code. 
For all other ' - 'languages depth = 1 will be counted for each tasklet.') - return 1 - -def get_tasklet_work(node, state): - return tasklet_work(node, state), -1 - -def get_tasklet_work_depth(node, state): - return tasklet_work(node, state), tasklet_depth(node, state) - -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, syms_to_nonnegify) -> None: - print('Analyzing work and depth of SDFG', sdfg.name) - print('SDFG has', len(sdfg.nodes()), 'states') - print('Calculating work and depth for all states individually...') - - # First determine the work and depth of each state individually. - # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number - # of times the state will be executed. - state_depths: Dict[SDFGState, sp.Expr] = {} - state_works: Dict[SDFGState, sp.Expr] = {} - for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify) - if state.executions == 0:# or state.executions == sp.zoo: - print('State executions must be statically known exactly or with an upper bound. Offender:', state) - new_symbol = sp.Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(state)}') - state.executions = new_symbol - syms_to_nonnegify |= {new_symbol} - state_works[state] = state_work * state.executions - state_depths[state] = state_depth * state.executions - w_d_map[get_uuid(state)] = (sp.simplify(state_work * state.executions), sp.simplify(state_depth * state.executions)) - - print('Calculating work and depth of the SDFG...') - - - nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) - print(nodes_oNodes_exits) - # Now we need to go over each triple (node, oNode, exits) - # for each triple, we - # - remove edge (oNode, node), i.e. the backward edge - # - for all exits e, add edge (oNode, e). This edge may already exist - - for node, oNode, exits in nodes_oNodes_exits: - sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) - for e in exits: - # TODO: This will probably fail if len(exits) > 1, but in which cases does that even happen? - if len(sdfg.edges_between(oNode, e)) == 0: - # no edge there yet - sdfg.add_edge(oNode, e, InterstateEdge()) - - # Prepare the SDFG for a detph analysis by 'inlining' loops. This removes the edge between the guard and the exit - # state and the edge between the last loop state and the guard, and instead places an edge between the last loop - # state and the exit state. Additionally, construct a dummy exit state and connect every state that has no outgoing - # edges to it. - - - - - - dummy_exit = sdfg.add_state('dummy_exit') - for state in sdfg.nodes(): - """ - if hasattr(state, 'condition_edge') and hasattr(state, 'is_loop_guard') and state.is_loop_guard: - # This is a loop guard. - loop_begin = state.condition_edge.dst - # Determine loop states through a depth first search from the start of the loop. Everything reached before - # arriving back at the loop guard is part of the loop. - # TODO: This is hacky. Loops should report the loop states directly. This may fail or behave unexpectedly - # for break/return statements inside of loops. 
- loop_states = set(sdutil.dfs_conditional(sdfg, sources=[loop_begin], condition=lambda _, s: s != state)) - loop_exit = None - exit_edge = None - loop_end = None - end_edge = None - for iedge in sdfg.in_edges(state): - if iedge.src in loop_states: - end_edge = iedge - loop_end = iedge.src - for oedge in sdfg.out_edges(state): - if oedge.dst not in loop_states: - loop_exit = oedge.dst - exit_edge = oedge - - if loop_exit is None or loop_end is None: - raise RuntimeError('Failed to analyze the depth of a loop starting at', state) - - sdfg.remove_edge(exit_edge) - sdfg.remove_edge(end_edge) - sdfg.add_edge(loop_end, loop_exit, InterstateEdge()) - #""" - - if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: - sdfg.add_edge(state, dummy_exit, InterstateEdge()) - - depth_map: Dict[SDFGState, sp.Expr] = {} - work_map: Dict[SDFGState, sp.Expr] = {} - state_depths[dummy_exit] = sp.sympify(0) - state_works[dummy_exit] = sp.sympify(0) - - # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to - # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions - # have been calculated. - traversal_q = deque() - traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None)) - visited = set() - while traversal_q: - state, depth, work, ie = traversal_q.popleft() - - if ie is not None: - visited.add(ie) - - n_depth = sp.simplify(depth + state_depths[state]) - n_work = sp.simplify(work + state_works[state]) - - if state in depth_map: - depth_map[state] = sp.Max(depth_map[state], n_depth) - else: - depth_map[state] = n_depth - - if state in work_map: - work_map[state] = sp.Max(work_map[state], n_work) - else: - work_map[state] = n_work - - out_edges = sdfg.out_edges(state) - if any(iedge not in visited for iedge in sdfg.in_edges(state)): - pass - else: - for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) - - max_depth = depth_map[dummy_exit] - max_work = work_map[dummy_exit] - - print('SDFG', sdfg.name, 'processed') - w_d_map[get_uuid(sdfg)] = (sp.simplify(max_work), sp.simplify(max_depth)) - return sp.simplify(max_work), sp.simplify(max_depth) - - -""" -Analyze the work and depth of a scope. -This works by constructing a proxy graph of the scope and then finding the maximum depth path in that graph between -the source and sink. The proxy graph is constructed to remove any multi-edges between nodes and to remove nodes that -do not contribute to the depth. Additionally, nested scopes are summarized into single nodes. All of this is necessary -to reduce the number of possible paths in the graph, as much as possible, since they all have to be brute-force -enumerated to find the maximum depth path. -:note: This is terribly inefficient and should be improved. -:param state: The state in which the scope to analyze is contained. -:param sym_map: A dictionary mapping symbols to their values. -:param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. -:return: A tuple containing the work and depth of the scope. 
-""" -def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify, entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: - - # find the work / depth of each node - # for maps and nested SDFG, we do it recursively - work = sp.sympify(0) - max_depth = sp.sympify(0) - scope_nodes = state.scope_children()[entry] - scope_exit = None if entry is None else state.exit_node(entry) - for node in scope_nodes: - # add node to map - w_d_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0)) # TODO: do we need this line? - if isinstance(node, nd.EntryNode): - # If the scope contains an entry node, we need to recursively analyze the scope of the entry node first. - # The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, node) - # add up work for whole state, but also save work for this sub-scope scope in w_d_map - work += s_work - w_d_map[get_uuid(node, state)] = (s_work, s_depth) - elif node == scope_exit: - pass - elif isinstance(node, nd.Tasklet): - # add up work for whole state, but also save work for this node in w_d_map - t_work, t_depth = analyze_tasklet(node, state) - work += t_work - w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth)) - elif isinstance(node, nd.NestedSDFG): - # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) - - # add up work for whole state, but also save work for this nested SDFG in w_d_map - work += nsdfg_work - w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) - - if entry is not None: - # If the scope being analyzed is a map, multiply the work by the number of iterations of the map. - if isinstance(entry, nd.MapEntry): - nmap: nd.Map = entry.map - range: Range = nmap.range - n_exec = range.num_elements_exact() - work = work * sp.simplify(n_exec) - else: - print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') - - - # TODO: Kinda ugly if condition... - # only do this if we even analyzed depth of tasklets - max_depth = sp.sympify(0) - if analyze_tasklet == get_tasklet_work_depth: - # Calculate the maximum depth of the scope by finding the 'deepest' path from the source to the sink. 
This is done by - # a BFS in topological order, where each node propagates its current max depth for all incoming paths - traversal_q = deque() - visited = set() - # find all starting nodes - if entry: - # the entry is the starting node - traversal_q.append((entry, sp.sympify(0), None)) - else: - for node in scope_nodes: - if len(state.in_edges(node)) == 0: - # push this node into the deque - traversal_q.append((node, sp.sympify(0), None)) - - - depth_map = {} - - while traversal_q: - node, in_depth, in_edge = traversal_q.popleft() - - if in_edge is not None: - visited.add(in_edge) - - n_depth = sp.simplify(in_depth + w_d_map[get_uuid(node, state)][1]) - - if node in depth_map: - depth_map[node] = sp.Max(depth_map[node], n_depth) - else: - depth_map[node] = n_depth - - out_edges = state.out_edges(node) - # only advance to next node, if all incoming edges have been visited or the current node is the entry (aka starting node) - # if the current node is the exit of the current scope, we stop, such that we don't leave the current scope - if (all(iedge in visited for iedge in state.in_edges(node)) or node == entry) and node != scope_exit: - # if we encounter a nested map, we must not analyze its contents (as they have already been recursively analyzed) - # hence, we continue from the outgoing edges of the corresponding exit - if isinstance(node, nd.EntryNode) and node != entry: - # get the corresponding exit note - exit_node = state.exit_node(node) - # replace out_edges with the out_edges of the scope exit node - out_edges = state.out_edges(exit_node) - for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[node], oedge)) - if len(out_edges) == 0 or node == scope_exit: - # this is an end node --> update max_depth - max_depth = sp.Max(max_depth, depth_map[node]) - - # summarise work / depth of the whole state in the dictionary - w_d_map[get_uuid(state)] = (sp.simplify(work), sp.simplify(max_depth)) - return sp.simplify(work), sp.simplify(max_depth) - -""" -Analyze the work and depth of a state. -:param state: The state to analyze. -:param sym_map: A dictionary mapping symbols to their values. -:return: A tuple containing the work and depth of the state. -""" -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify) -> None: - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, None) - return work, depth - - -""" -Analyze the work and depth of an SDFG. -Optionally, a dictionary mapping symbols to their values can be provided to concretize the analysis. -Note that this also significantly speeds up the analysis due to sympy not having to perform the analysis symbolically. -:note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a - condition and an assignment. -:param sdfg: The SDFG to analyze. -:param sym_map: A dictionary mapping symbols to their values. -:return: A tuple containing the work and depth of the SDFG -""" -# def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, str], sym_map: Dict[str, int]) -> Dict[str, Tuple[str, str]]: -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> Dict[str, Tuple[sp.Expr, sp.Expr]]: - # Run state propagation for all SDFGs recursively. 
This is necessary to determine the number of times each state - # will be executed, or to determine upper bounds for that number (such as in the case of branching) - print('Propagating states...') - for sd in sdfg.all_sdfgs_recursive(): - propagation.propagate_states(sd) - - # deepcopy such that original sdfg not changed - # sdfg = deepcopy(sdfg) - - # Check if the SDFG has any dynamically unbounded executions, i.e., if there are any states that have neither a - # statically known number of executions, nor an upper bound on the number of executions. Warn if this is the case. - print('Checking for dynamically unbounded executions...') - for sd in sdfg.all_sdfgs_recursive(): - if any([s.executions == 0 and s.dynamic_executions for s in sd.nodes()]): - print('WARNING: SDFG has dynamic executions. The analysis may fail in unexpected ways or be inaccurate.') - - syms_to_nonnegify = set() - # Analyze the work and depth of the SDFG. - print('Analyzing SDFG...') - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) - - # TODO: maybe do this posify more often for performance? - array_symbols = get_array_size_symbols(sdfg) - for k, (v_w, v_d) in w_d_map.items(): - v_w = posify_certain_symbols(v_w, array_symbols, syms_to_nonnegify) - v_d = posify_certain_symbols(v_d, array_symbols, syms_to_nonnegify) - w_d_map[k] = (v_w, v_d) - - - - - -def get_work(sdfg_json): - # final version loads sdfg from json - # loaded = load_sdfg_from_json(sdfg_json) - # if loaded['error'] is not None: - # return loaded['error'] - # sdfg = loaded['sdfg'] - - # for now we load simply load from a file - sdfg = SDFG.from_file(sdfg_json) - - - - # try: - work_map = {} - analyze_sdfg(sdfg, work_map, get_tasklet_work) - for k, v, in work_map.items(): - work_map[k] = (str(sp.simplify(v[0]))) - return { - 'workMap': work_map, - } - # except Exception as e: - # return { - # 'error': { - # 'message': 'Failed to analyze work depth', - # 'details': get_exception_message(e), - # }, - # } - - - -def get_work_depth(sdfg_json): - # final version loads sdfg from json - # loaded = load_sdfg_from_json(sdfg_json) - # if loaded['error'] is not None: - # return loaded['error'] - # sdfg = loaded['sdfg'] - - # for now we load simply load from a file - sdfg = SDFG.from_file(sdfg_json) - - - # try: - work_depth_map = {} - analyze_sdfg(sdfg, work_depth_map, get_tasklet_work_depth) - for k, v, in work_depth_map.items(): - work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) - return { - 'workDepthMap': work_depth_map, - } - # except Exception as e: - # return { - # 'error': { - # 'message': 'Failed to analyze work depth', - # 'details': get_exception_message(e), - # }, - # } - - - - - -################################################################################ -# Utility functions for running the analysis from the command line ############# -################################################################################ - -class keyvalue(argparse.Action): - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, dict()) - for v in values: - k, v = v.split('=') - getattr(namespace, self.dest)[k] = v - - -def main() -> None: - analyze_depth = True - - parser = argparse.ArgumentParser( - 'work_depth_analysis', - usage='python work_depth_analysis.py [-h] filename', - description='Analyze the work/depth of an SDFG.' 
- ) - - parser.add_argument('filename', type=str, help='The SDFG file to analyze.') - parser.add_argument('--kwargs', nargs='*', help='Define symbols.', action=keyvalue) - - args = parser.parse_args() - - if not os.path.exists(args.filename): - print(args.filename, 'does not exist.') - exit() - - symbols_map = {} - if args.kwargs: - for k, v in args.kwargs.items(): - symbols_map[k] = int(v) - - # TODO: symbols_map maybe not needed - if analyze_depth: - map = get_work_depth(args.filename) - map = map['workDepthMap'] - else: - map = get_work(args.filename) - map = map['workMap'] - - # find uuid of the whole SDFG - sdfg = SDFG.from_file(args.filename) - result = map[get_uuid(sdfg)] - - - print(80*'-') - if isinstance(result, Tuple): - print("Work:\t", result[0]) - print("Depth:\t", result[1]) - else: - print("Work:\t", result) - - print(80*'-') - - - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/dace/sdfg/work_depth_analysis/work_depth_tests.py b/dace/sdfg/work_depth_analysis/work_depth_tests.py index e233280143..b6af4daaa7 100644 --- a/dace/sdfg/work_depth_analysis/work_depth_tests.py +++ b/dace/sdfg/work_depth_analysis/work_depth_tests.py @@ -8,7 +8,7 @@ from dace.transformation.dataflow import MapExpansion - +# TODO: add tests for function calls (e.g. reduce) N = dc.symbol('N') From 7f3a997d867800c697b8b9475b6063ad49128b51 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Thu, 20 Jul 2023 09:54:11 +0200 Subject: [PATCH 05/18] code ready for PR --- dace/sdfg/propagation.py | 196 ++++- dace/sdfg/work_depth_analysis/helpers.py | 134 +++- dace/sdfg/work_depth_analysis/work_depth.py | 668 ++++++++---------- .../work_depth_analysis/work_depth_tests.py | 121 ++-- 4 files changed, 634 insertions(+), 485 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 89ba6928c7..d8f1bee850 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -10,7 +10,7 @@ import itertools import functools import sympy -from sympy import ceiling +from sympy import ceiling, Symbol from sympy.concrete.summations import Sum import warnings import networkx as nx @@ -564,8 +564,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): Annotate each valid for loop construct with its loop variable ranges. :param sdfg: The SDFG in which to look. - :param unannotated_cycle_states: List of states in cycles without valid - for loop ranges. + :param unannotated_cycle_states: List of lists. Each sub-list contains the states of one unannotated cycle. """ # We import here to avoid cyclic imports. @@ -652,7 +651,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): res = find_for_loop(sdfg, guard, begin, itervar=itvar) if res is None: # No range detected, mark as unbounded. - unannotated_cycle_states.extend(cycle) + unannotated_cycle_states.append(cycle) else: itervar, rng, _ = res @@ -674,7 +673,192 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): else: # There's no guard state, so this cycle marks all states in it as # dynamically unbounded. - unannotated_cycle_states.extend(cycle) + unannotated_cycle_states.append(cycle) + + +def propagate_states_symbolically(sdfg) -> None: + """ + Idea is like propagate_states, but here we dont have unbounded number of executions. + Instead, we do it symbolically and annotate unbounded loops with symbols "num_exec_{sdfg_id}_{loop_start_state_id}". + + :param sdfg: The SDFG to annotate. + :note: This operates on the SDFG in-place. + """ + + # We import here to avoid cyclic imports. 
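As a sketch of the annotation described in the docstring above (the state ids here are invented), a loop whose trip count is data-dependent ends up being counted with a fresh non-negative symbol instead of being left completely unbounded:

import sympy

# Hypothetical loop starting at state 3 of SDFG 0 with a data-dependent trip count.
num_execs = sympy.Symbol('num_execs_0_3', nonnegative=True)
# Every state inside that loop is then annotated with executions = num_execs, so the
# propagated work/depth expressions stay symbolic, e.g. 5*num_execs_0_3 for a body with work 5.
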
+ from dace.sdfg import InterstateEdge + from dace.transformation.helpers import split_interstate_edges + from dace.sdfg.analysis import cfg + + # Reset the state edge annotations (which may have changed due to transformations) + reset_state_annotations(sdfg) + + # Clean up the state machine by separating combined condition and assignment + # edges. + split_interstate_edges(sdfg) + + # To enable branch annotation, we add a temporary exit state that connects + # to all child-less states. With this, we can use the dominance frontier + # to determine a full-merge state for branches. + temp_exit_state = None + for s in sdfg.nodes(): + if sdfg.out_degree(s) == 0: + if temp_exit_state is None: + temp_exit_state = sdfg.add_state('__dace_brannotate_exit') + sdfg.add_edge(s, temp_exit_state, InterstateEdge()) + + dom_frontier = cfg.acyclic_dominance_frontier(sdfg) + + # Find any valid for loop constructs and annotate the loop ranges. Any other + # cycle should be marked as unannotated. + unannotated_cycle_states = [] + _annotate_loop_ranges(sdfg, unannotated_cycle_states) + + # Keep track of states that fully merge a previous conditional split. We do + # this so we can remove the dynamic executions flag for those states. + full_merge_states = set() + + visited_states = set() + + traversal_q = deque() + traversal_q.append((sdfg.start_state, 1, False, [])) + while traversal_q: + (state, proposed_executions, proposed_dynamic, itvar_stack) = traversal_q.pop() + + out_degree = sdfg.out_degree(state) + out_edges = sdfg.out_edges(state) + + # Check if the traversal reached a state that's already been visited + # (ends traversal), or if the number of executions being propagated is + # dynamic unbounded. Otherwise, continue regular traversal. + if state in visited_states: + # This state has already been visited. + if getattr(state, 'is_loop_guard', False): + # If we encounter a loop guard that's already been visited, + # we've finished traversing a loop and can remove that loop's + # iteration variable from the stack. We additively merge the + # number of executions. + state.executions += proposed_executions + else: + # If we have already visited this state, but it is NOT a loop + # guard, this means that we can reach this state via multiple + # different paths. If so, the number of executions for this + # state is given by the maximum number of executions among each + # of the paths reaching it. If the state additionally completely + # merges a previously branched out state tree, we know that the + # number of executions isn't dynamic anymore. + state.executions = sympy.Max(state.executions, proposed_executions).doit() + if state in full_merge_states: + state.dynamic_executions = False + # TODO: do we need this else here or not? + # else: + # state.dynamic_executions = (state.dynamic_executions or proposed_dynamic) + else: + # If the state hasn't been visited yet, we calculate the number of + # executions for the next state(s) and continue propagating. + visited_states.add(state) + if state in full_merge_states: + # If this state fully merges a conditional branch, this turns + # dynamic executions back off. + proposed_dynamic = False + state.executions = proposed_executions + state.dynamic_executions = proposed_dynamic + + if out_degree == 1: + # Continue with the only child state. + if not out_edges[0].data.is_unconditional(): + # If the transition to the child state is based on a + # condition, this state could be an implicit exit state. 
The + # child state's number of executions is thus only given as + # an upper bound and marked as dynamic. + proposed_dynamic = True + traversal_q.append((out_edges[0].dst, proposed_executions, proposed_dynamic, itvar_stack)) + elif out_degree > 1: + if getattr(state, 'is_loop_guard', False): + itvar = symbolic.symbol(state.itvar) + loop_range = state.ranges[state.itvar] + start = loop_range[0][0] + stop = loop_range[0][1] + stride = loop_range[0][2] + + # Calculate the number of loop executions. + # This resolves ranges based on the order of iteration + # variables pushed on to the stack if we're in a nested + # loop. + loop_executions = ceiling(((stop + 1) - start) / stride) + for outer_itvar_string in reversed(itvar_stack): + outer_range = state.ranges[outer_itvar_string] + outer_start = outer_range[0][0] + outer_stop = outer_range[0][1] + outer_stride = outer_range[0][2] + outer_itvar = symbolic.pystr_to_symbolic(outer_itvar_string) + exec_repl = loop_executions.subs({outer_itvar: (outer_itvar * outer_stride + outer_start)}) + loop_executions = Sum(exec_repl, + (outer_itvar, 0, ceiling((outer_stop - outer_start) / outer_stride))) + loop_executions = loop_executions.doit() + + loop_state = state.condition_edge.dst + end_state = (out_edges[0].dst if out_edges[1].dst == loop_state else out_edges[1].dst) + + traversal_q.append((end_state, state.executions, proposed_dynamic, itvar_stack)) + traversal_q.append((loop_state, loop_executions, proposed_dynamic, itvar_stack + [state.itvar])) + else: + # Conditional split or unannotated loop. + unannotated_loop_edge = None + to_remove = [] + for oedge in out_edges: + for cycle in unannotated_cycle_states: + if oedge.dst in cycle: + # This is an unannotated loop down this branch. + unannotated_loop_edge = oedge + # remove cycle, since it is now annotated with symbol + to_remove.append(cycle) + + for c in to_remove: + unannotated_cycle_states.remove(c) + + if unannotated_loop_edge is not None: + # Traverse as an unbounded loop. + out_edges.remove(unannotated_loop_edge) + + # traverse non-loops states normally + for oedge in out_edges: + traversal_q.append((oedge.dst, state.executions, False, itvar_stack)) + + # Introduce the num_execs symbol and propagate it down the loop. + # These symbols will always be non-negative. + traversal_q.append((unannotated_loop_edge.dst, Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(unannotated_loop_edge.dst)}', nonnegative=True), False, itvar_stack)) + else: + # Traverse as a conditional split. + proposed_executions = state.executions + proposed_dynamic = True + + # Get the dominance frontier for each child state and + # merge them into one common frontier, representing the + # branch's immediate post-dominator. If a state has no + # dominance frontier, add the state itself to the + # frontier. This takes care of the case where a branch + # is fully merged, but one branch contains no states. + common_frontier = set() + for oedge in out_edges: + frontier = dom_frontier[oedge.dst] + if not frontier: + frontier = {oedge.dst} + common_frontier |= frontier + + # Continue traversal for each child. + traversal_q.append((oedge.dst, proposed_executions, proposed_dynamic, itvar_stack)) + + # If the whole branch is not dynamic, and the + # common frontier is exactly one state, we know that + # the branch merges again at that state. + if not state.dynamic_executions and len(common_frontier) == 1: + full_merge_states.add(list(common_frontier)[0]) + + # If we had to create a temporary exit state, we remove it again here. 
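To illustrate the loop-execution count computed above for a nest of two counted loops (outer j in [0, M-1] with stride 1, inner loop over i in [0, j]; the symbols M and j are stand-ins for the real iteration variables):

import sympy as sp
from sympy import ceiling, Sum

M, j = sp.symbols('M j', integer=True)
inner_execs = ceiling(((j + 1) - 0) / 1)   # ceiling(((stop + 1) - start) / stride) at the inner guard
total = Sum(inner_execs, (j, 0, ceiling((M - 1 - 0) / 1))).doit()
print(sp.simplify(total))                  # M**2/2 + M/2, i.e. M*(M + 1)/2 inner-body executions
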
+ if temp_exit_state is not None: + sdfg.remove_node(temp_exit_state) + def propagate_states(sdfg) -> None: @@ -759,6 +943,8 @@ def propagate_states(sdfg) -> None: # cycle should be marked as unannotated. unannotated_cycle_states = [] _annotate_loop_ranges(sdfg, unannotated_cycle_states) + # flatten the list + unannotated_cycle_states = [state for cycle in unannotated_cycle_states for state in cycle] # Keep track of states that fully merge a previous conditional split. We do # this so we can remove the dynamic executions flag for those states. diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py index 28b4741452..b9964db3d1 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -1,4 +1,7 @@ -from dace import SDFG, SDFGState, nodes, serialize +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Helper functions used by the work depth analysis. """ + +from dace import SDFG, SDFGState, nodes from collections import deque from typing import List, Dict, Set, Tuple, Optional, Union import networkx as nx @@ -20,7 +23,6 @@ def length(self) -> int: UUID_SEPARATOR = '/' - def ids_to_string(sdfg_id, state_id=-1, node_id=-1, edge_id=-1): return (str(sdfg_id) + UUID_SEPARATOR + str(state_id) + UUID_SEPARATOR + str(node_id) + UUID_SEPARATOR + str(edge_id)) @@ -37,13 +39,6 @@ def get_uuid(element, state=None): else: return ids_to_string(-1) - - - - - - - def get_domtree( graph: nx.DiGraph, start_node: str, @@ -80,9 +75,7 @@ def get_domtree( return alldominated, domtree - - -def backedges( +def get_backedges( graph: nx.DiGraph, start: Optional[NodeT], strict: bool = False ) -> Union[Set[EdgeT], Tuple[Set[EdgeT], Set[EdgeT]]]: '''Find all backedges in a directed graph. @@ -227,3 +220,120 @@ def backedges( return backedges +def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): + """ + Detects loops in a SDFG. For each loop, it identifies (node, oNode, exit). + We know that there is a backedge from oNode to node that creates the loop and that exit is the exit state of the loop. + + :param sdfg_nx: The networkx representation of a SDFG. 
+ """ + + # preparation phase: compute dominators, backedges etc + for node in sdfg_nx.nodes(): + if sdfg_nx.in_degree(node) == 0: + start = node + break + if start is None: + raise ValueError('No start node could be determined') + + # sdfg can have multiple end nodes --> not good for postDomTree + # --> add a new end node + artificial_end_node = 'artificial_end_node' + sdfg_nx.add_node(artificial_end_node) + for node in sdfg_nx.nodes(): + if sdfg_nx.out_degree(node) == 0 and node != artificial_end_node: + # this is an end node of the sdfg + sdfg_nx.add_edge(node, artificial_end_node) + + # sanity check: + if sdfg_nx.in_degree(artificial_end_node) == 0: + raise ValueError('No end node could be determined in the SDFG') + + # compute dominators and backedges + iDoms = nx.immediate_dominators(sdfg_nx, start) + allDom, domTree = get_domtree(sdfg_nx, start, iDoms) + + reversed_sdfg_nx = sdfg_nx.reverse() + iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node) + allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) + + backedges = get_backedges(sdfg_nx, start) + backedgesDstDict = {} + for be in backedges: + if be[1] in backedgesDstDict: + backedgesDstDict[be[1]].add(be) + else: + backedgesDstDict[be[1]] = set([be]) + + + # This list will be filled with triples (node, oNode, exit), one triple for each loop construct in the SDFG. + # There will always be a backedge from oNode to node. Either node or oNode will be the corresponding loop guard, + # depending on whether it is a while-do or a do-while loop. exit will always be the exit state of the loop. + nodes_oNodes_exits = [] + + # iterate over all nodes + for node in sdfg_nx.nodes(): + # Check if any backedge ends in node. + if node in backedgesDstDict: + inc_backedges = backedgesDstDict[node] + + # gather all successors of node that are not reached by backedges + successors = [] + for edge in sdfg_nx.out_edges(node): + if not edge in backedges: + successors.append(edge[1]) + + + # For each incoming backedge, we want to find oNode and exit. There can be multiple backedges, in case + # we have a continue statement in the original code. But we can handle these backedges normally. + for be in inc_backedges: + # since node has an incoming backedge, it is either a loop guard or loop tail + # oNode will exactly be the other thing + oNode = be[0] + exitCandidates = set() + # search for exit candidates: + # a state is a exit candidate if: + # - it is in successor and it does not dominate oNode (else it dominates + # the last loop state, and hence is inside the loop itself) + # - is is a successor of oNode (but not node) + # This handles both cases of while-do and do-while loops + for succ in successors: + if succ != oNode and oNode not in allDom[succ]: + exitCandidates.add(succ) + for succ in sdfg_nx.successors(oNode): + if succ != node: + exitCandidates.add(succ) + + if len(exitCandidates) == 0: + raise ValueError('failed to find any exit nodes') + elif len(exitCandidates) > 1: + # Find the exit candidate that sits highest up in the + # postdominator tree (i.e., has the lowest level). + # That must be the exit node (it must post-dominate) + # everything inside the loop. If there are multiple + # candidates on the lowest level (i.e., disjoint set of + # postdominated nodes), there are multiple exit paths, + # and they all share one level. 
+ cand = exitCandidates.pop() + minSet = set([cand]) + minLevel = nx.get_node_attributes(postDomTree, 'level')[cand] + for cand in exitCandidates: + curr_level = nx.get_node_attributes(postDomTree, 'level')[cand] + if curr_level < minLevel: + # new minimum found + minLevel = curr_level + minSet.clear() + minSet.add(cand) + elif curr_level == minLevel: + # add cand to curr set + minSet.add(cand) + + if len(minSet) > 0: + exitCandidates = minSet + else: + raise ValueError('failed to find exit minSet') + + # now we have a triple (node, oNode, exitCandidates) + nodes_oNodes_exits.append((node, oNode, exitCandidates)) + + return nodes_oNodes_exits diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index 2defcd17c0..a3c43c9826 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -1,12 +1,15 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Work depth analysis for any input SDFG. Can be used with the DaCe VS Code extension or +from command line as a Python script. """ + import argparse from collections import deque -from dace.sdfg import nodes as nd, propagation, InterstateEdge, utils as sdutil +from dace.sdfg import nodes as nd, propagation, InterstateEdge from dace import SDFG, SDFGState, dtypes from dace.subsets import Range from typing import Tuple, Dict import os import sympy as sp -import networkx as nx from copy import deepcopy from dace.libraries.blas import MatMul, Transpose from dace.libraries.standard import Reduce @@ -14,121 +17,20 @@ import ast import astunparse import warnings -from dace.sdfg.graph import Edge - -from dace.sdfg.work_depth_analysis.helpers import get_uuid, get_domtree, backedges as get_backedges - - -def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): - # preparation phase: compute dominators, backedges etc - for node in sdfg_nx.nodes(): - if sdfg_nx.in_degree(node) == 0: - start = node - break - if start is None: - raise ValueError('No start node could be determined') - - # sdfg can have multiple end nodes --> not good for postDomTree - # --> add a new end node - artificial_end_node = 'artificial_end_node' - sdfg_nx.add_node(artificial_end_node) - for node in sdfg_nx.nodes(): - if sdfg_nx.out_degree(node) == 0 and node != artificial_end_node: - # this is an end node of the sdfg - sdfg_nx.add_edge(node, artificial_end_node) - - # sanity check: - if sdfg_nx.in_degree(artificial_end_node) == 0: - raise ValueError('No end node could be determined in the SDFG') - - - - iDoms = nx.immediate_dominators(sdfg_nx, start) - allDom, domTree = get_domtree(sdfg_nx, start, iDoms) - - reversed_sdfg_nx = sdfg_nx.reverse() - iPostDoms = nx.immediate_dominators(reversed_sdfg_nx, artificial_end_node) - allPostDoms, postDomTree = get_domtree(reversed_sdfg_nx, artificial_end_node, iPostDoms) - - backedges = get_backedges(sdfg_nx, start) - backedgesDstDict = {} - for be in backedges: - if be[1] in backedgesDstDict: - backedgesDstDict[be[1]].add(be) - else: - backedgesDstDict[be[1]] = set([be]) - - - nodes_oNodes_exits = [] - - # iterate over all nodes - for node in sdfg_nx.nodes(): - # does any backedge end in node - if node in backedgesDstDict: - inc_backedges = backedgesDstDict[node] - - - # gather all successors of node that are not reached by backedges - successors = [] - for edge in sdfg_nx.out_edges(node): - if not edge in backedges: - successors.append(edge[1]) - - - # if len(inc_backedges) > 1: - # raise ValueError('node has multiple 
incoming backedges...') - # instead: if multiple incoming backedges, do the below for each backedge - for be in inc_backedges: - - - # since node has an incoming backedge, it is either a loop guard or loop tail - # oNode will exactly be the other thing - oNode = be[0] - exitCandidates = set() - for succ in successors: - if succ != oNode and oNode not in allDom[succ]: - exitCandidates.add(succ) - for succ in sdfg_nx.successors(oNode): - if succ != node: - exitCandidates.add(succ) - - if len(exitCandidates) == 0: - raise ValueError('failed to find any exit nodes') - elif len(exitCandidates) > 1: - # // Find the exit candidate that sits highest up in the - # // postdominator tree (i.e., has the lowest level). - # // That must be the exit node (it must post-dominate) - # // everything inside the loop. If there are multiple - # // candidates on the lowest level (i.e., disjoint set of - # // postdominated nodes), there are multiple exit paths, - # // and they all share one level. - cand = exitCandidates.pop() - minSet = set([cand]) - minLevel = nx.get_node_attributes(postDomTree, 'level')[cand] - for cand in exitCandidates: - curr_level = nx.get_node_attributes(postDomTree, 'level')[cand] - if curr_level < minLevel: - # new minimum found - minLevel = curr_level - minSet.clear() - minSet.add(cand) - elif curr_level == minLevel: - # add cand to curr set - minSet.add(cand) - - if len(minSet) > 0: - exitCandidates = minSet - else: - raise ValueError('failed to find exit minSet') - # now we have a triple (node, oNode, exitCandidates) - nodes_oNodes_exits.append((node, oNode, exitCandidates)) +from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits - return nodes_oNodes_exits - def get_array_size_symbols(sdfg): + """ + Returns all symbols that appear isolated in shapes of the SDFG's arrays. + These symbols can then be assumed to be positive. + + :note: This only works if a symbol appears in isolation, i.e. array A[N]. If we have A[N+1], we cannot assume N to be positive. + :param sdfg: The SDFG in which it searches for symbols. + :return: A set containing symbols which we can assume to be positive. + """ symbols = set() for _, _, arr in sdfg.arrays_recursive(): for s in arr.shape: @@ -136,18 +38,29 @@ def get_array_size_symbols(sdfg): symbols.add(s) return symbols -def posify_certain_symbols(expr, syms_to_posify, syms_to_nonnegify): +def posify_certain_symbols(expr, syms_to_posify): + """ + Takes an expression and evaluates it while assuming that certain symbols are positive. + + :param expr: The expression to evaluate. + :param syms_to_posify: List of symbols we assume to be positive. + :note: This is adapted from the Sympy function posify. + """ + expr = sp.sympify(expr) - nonneg = {s: sp.Dummy(s.name, nonnegative=True, **s.assumptions0) - for s in syms_to_nonnegify if s.is_nonnegative is None} - pos = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) + + reps = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) for s in syms_to_posify if s.is_positive is None} - # merge the two dicts into reps - reps = {**nonneg, **pos} expr = expr.subs(reps) return expr.subs({r: s for s, r in reps.items()}) def symeval(val, symbols): + """ + Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. + + :param val: The expression we are updating. + :param symbols: Dictionary of key value pairs { old_symbol: new_symbol}. 
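+    :return: The expression with all old symbols replaced by their new values. (The intermediate
+        __REPLSYM_ substitution avoids clashes when a new value itself contains one of the old symbols.)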
+ """ first_replacement = { pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys() @@ -158,7 +71,13 @@ def symeval(val, symbols): } return val.subs(first_replacement).subs(second_replacement) -def count_matmul(node, symbols, state): +def evaluate_symbols(base, new): + result = {} + for k, v in new.items(): + result[k] = symeval(v, base) + return result + +def count_work_matmul(node, symbols, state): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') @@ -173,8 +92,7 @@ def count_matmul(node, symbols, state): result *= symeval(A_memlet.data.subset.size()[-1], symbols) return result - -def count_reduce(node, symbols, state): +def count_work_reduce(node, symbols, state): result = 0 if node.wcr is not None: result += count_arithmetic_ops_code(node.wcr) @@ -188,24 +106,53 @@ def count_reduce(node, symbols, state): result = 0 return result +LIBNODES_TO_WORK = { + MatMul: count_work_matmul, + Transpose: lambda *args: 0, + Reduce: count_work_reduce, +} + +def count_depth_matmul(node, symbols, state): + # For now we set it equal to work: see comments in count_depth_reduce just below + return count_work_matmul(node, symbols, state) + +def count_depth_reduce(node, symbols, state): + # depth of reduction is log2 of the work + # TODO: Can we actually assume this? Or is it equal to the work? + # Another thing to consider is that we essetially do NOT count wcr edges as operations for now... + + # return sp.ceiling(sp.log(count_work_reduce(node, symbols, state), 2)) + # set it equal to work for now + return count_work_reduce(node, symbols, state) + + +LIBNODES_TO_DEPTH = { + MatMul: count_depth_matmul, + Transpose: lambda *args: 0, + Reduce: count_depth_reduce, +} + + bigo = sp.Function('bigo') PYFUNC_TO_ARITHMETICS = { 'float': 0, + 'dace.float64': 0, + 'dace.int64': 0, 'math.exp': 1, + 'exp': 1, 'math.tanh': 1, + 'sin': 1, + 'cos': 1, + 'tanh': 1, 'math.sqrt': 1, + 'sqrt': 1, + 'atan2:': 1, 'min': 0, 'max': 0, 'ceiling': 0, 'floor': 0, - # TODO: what about: "cos", "sin", "sqrt", "atan2" (source: npbench's arc_distance), "abs" (mandelbrot1), "exp" (mlp) - # tanh (go_fast) + 'abs': 0 } -LIBNODES_TO_ARITHMETICS = { - MatMul: count_matmul, - Transpose: lambda *args: 0, - Reduce: count_reduce, -} class ArithmeticCounter(ast.NodeVisitor): @@ -227,8 +174,7 @@ def visit_UnaryOp(self, node): def visit_Call(self, node): fname = astunparse.unparse(node.func)[:-1] if fname not in PYFUNC_TO_ARITHMETICS: - # TODO: why do we get this warning? WARNING: Unrecognized python function "dace.float64" (source: npbench's azimint_hist) - print('WARNING: Unrecognized python function "%s"' % fname) + print('WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' % fname) return self.generic_visit(node) self.count += PYFUNC_TO_ARITHMETICS[fname] return self.generic_visit(node) @@ -254,12 +200,10 @@ def count_arithmetic_ops_code(code): return ctr.count class DepthCounter(ast.NodeVisitor): - + # so far this is identical to the ArithmeticCounter above. def __init__(self): self.count = 0 - # TODO: if we have a tasklet like _out = 2 * _in + 500 - # will this then have depth of 2? or not because of instruction level parallelism? 
def visit_BinOp(self, node): if isinstance(node.op, ast.MatMult): raise NotImplementedError('MatMult op count requires shape ' @@ -274,7 +218,7 @@ def visit_UnaryOp(self, node): def visit_Call(self, node): fname = astunparse.unparse(node.func)[:-1] if fname not in PYFUNC_TO_ARITHMETICS: - print('WARNING: Unrecognized python function "%s"' % fname) + print('WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' % fname) return self.generic_visit(node) self.count += PYFUNC_TO_ARITHMETICS[fname] return self.generic_visit(node) @@ -289,7 +233,8 @@ def visit_While(self, node): raise NotImplementedError def count_depth_code(code): - ctr = DepthCounter() + # so far this is the same as the work counter, since work = depth for each tasklet, as we can't assume any parallelism + ctr = ArithmeticCounter() if isinstance(code, (tuple, list)): for stmt in code: ctr.visit(stmt) @@ -313,14 +258,14 @@ def tasklet_work(tasklet_node, state): return 1 def tasklet_depth(tasklet_node, state): - # if tasklet_node.code.language == dtypes.Language.CPP: - # for oedge in state.out_edges(tasklet_node): - # return bigo(oedge.data.num_accesses) - + # TODO: how to get depth of CPP tasklets? + # For now we use depth == work: + if tasklet_node.code.language == dtypes.Language.CPP: + for oedge in state.out_edges(tasklet_node): + return bigo(oedge.data.num_accesses) if tasklet_node.code.language == dtypes.Language.Python: return count_depth_code(tasklet_node.code.code) else: - # TODO: improve this # other languages not implemented, count whole tasklet as work of 1 warnings.warn('Depth of tasklets only properly analyzed for Python code. For all other ' 'languages depth = 1 will be counted for each tasklet.') @@ -332,91 +277,65 @@ def get_tasklet_work(node, state): def get_tasklet_work_depth(node, state): return tasklet_work(node, state), tasklet_depth(node, state) -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, syms_to_nonnegify) -> None: - # print('Analyzing work and depth of SDFG', sdfg.name) - # print('SDFG has', len(sdfg.nodes()), 'states') - # print('Calculating work and depth for all states individually...') +def get_tasklet_avg_par(node, state): + return tasklet_work(node, state), tasklet_depth(node, state) + + +def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols) -> Tuple[sp.Expr, sp.Expr]: + """ + Analyze the work and depth of a given SDFG. + First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. + Lastly, we compute the path with most work and the path with the most depth in order to get the total work depth. + + :param sdfg: The SDFG to analyze. + :param w_d_map: Dictionary which will save the result. + :param analyze_tasklet: Function used to analyze tasklet nodes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :return: A tuple containing the work and depth of the SDFG. + """ + + # First determine the work and depth of each state individually. # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number # of times the state will be executed. 
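+    # For example, a state inside a for-loop with work N and depth 1 that executes K times
+    # contributes K*N work and K depth to the totals computed below.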
state_depths: Dict[SDFGState, sp.Expr] = {} state_works: Dict[SDFGState, sp.Expr] = {} for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify) - if state.executions == 0:# or state.executions == sp.zoo: - # print('State executions must be statically known exactly or with an upper bound. Offender:', state) - new_symbol = sp.Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(state)}') - state.executions = new_symbol - syms_to_nonnegify |= {new_symbol} - state_works[state] = state_work * state.executions - state_depths[state] = state_depth * state.executions - w_d_map[get_uuid(state)] = (sp.simplify(state_work * state.executions), sp.simplify(state_depth * state.executions)) + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols) + state_works[state] = sp.simplify(state_work * state.executions) + state_depths[state] = sp.simplify(state_depth * state.executions) + w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) - # print('Calculating work and depth of the SDFG...') + # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and + # the guard, and instead places an edge between the last loop state and the exit state. + # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. + # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. + # identify all loops in the SDFG nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) - print(nodes_oNodes_exits) - # Now we need to go over each triple (node, oNode, exits) - # for each triple, we - # - remove edge (oNode, node), i.e. the backward edge - # - for all exits e, add edge (oNode, e). This edge may already exist - + + # Now we need to go over each triple (node, oNode, exits). For each triple, we + # - remove edge (oNode, node), i.e. the backward edge + # - for all exits e, add edge (oNode, e). This edge may already exist for node, oNode, exits in nodes_oNodes_exits: sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) for e in exits: - # TODO: This will probably fail if len(exits) > 1, but in which cases does that even happen? if len(sdfg.edges_between(oNode, e)) == 0: # no edge there yet sdfg.add_edge(oNode, e, InterstateEdge()) - # Prepare the SDFG for a detph analysis by 'inlining' loops. This removes the edge between the guard and the exit - # state and the edge between the last loop state and the guard, and instead places an edge between the last loop - # state and the exit state. Additionally, construct a dummy exit state and connect every state that has no outgoing - # edges to it. - - - - - + # add a dummy exit to the SDFG, such that each path ends there. dummy_exit = sdfg.add_state('dummy_exit') for state in sdfg.nodes(): - """ - if hasattr(state, 'condition_edge') and hasattr(state, 'is_loop_guard') and state.is_loop_guard: - # This is a loop guard. - loop_begin = state.condition_edge.dst - # Determine loop states through a depth first search from the start of the loop. Everything reached before - # arriving back at the loop guard is part of the loop. - # TODO: This is hacky. Loops should report the loop states directly. This may fail or behave unexpectedly - # for break/return statements inside of loops. 
- loop_states = set(sdutil.dfs_conditional(sdfg, sources=[loop_begin], condition=lambda _, s: s != state)) - loop_exit = None - exit_edge = None - loop_end = None - end_edge = None - for iedge in sdfg.in_edges(state): - if iedge.src in loop_states: - end_edge = iedge - loop_end = iedge.src - for oedge in sdfg.out_edges(state): - if oedge.dst not in loop_states: - loop_exit = oedge.dst - exit_edge = oedge - - if loop_exit is None or loop_end is None: - raise RuntimeError('Failed to analyze the depth of a loop starting at', state) - - sdfg.remove_edge(exit_edge) - sdfg.remove_edge(end_edge) - sdfg.add_edge(loop_end, loop_exit, InterstateEdge()) - #""" - if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: sdfg.add_edge(state, dummy_exit, InterstateEdge()) - depth_map: Dict[SDFGState, sp.Expr] = {} + # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. work_map: Dict[SDFGState, sp.Expr] = {} + depth_map: Dict[SDFGState, sp.Expr] = {} + # The dummy state has 0 work and depth. state_depths[dummy_exit] = sp.sympify(0) state_works[dummy_exit] = sp.sympify(0) @@ -435,47 +354,73 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana n_depth = sp.simplify(depth + state_depths[state]) n_work = sp.simplify(work + state_works[state]) - if state in depth_map: - depth_map[state] = sp.Max(depth_map[state], n_depth) - else: - depth_map[state] = n_depth - - if state in work_map: - work_map[state] = sp.Max(work_map[state], n_work) + # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one + # single path with the least average parallelsim (of all paths with more than 0 work). + if analyze_tasklet == get_tasklet_avg_par: + if state in depth_map: # and hence als state in work_map + # if current path has 0 depth, we don't do anything. + if n_depth != 0: + # see if we need to update the work and depth of the current state + # we update if avg parallelism of new incoming path is less than current avg parallelism + old_avg_par = sp.simplify(work_map[state] / depth_map[state]) + new_avg_par = sp.simplify(n_work / n_depth) + + if depth_map[state] == 0 or new_avg_par < old_avg_par: + # old value was divided by zero or new path gives actually worse avg par, then we keep new value + depth_map[state] = n_depth + work_map[state] = n_work + else: + depth_map[state] = n_depth + work_map[state] = n_work else: - work_map[state] = n_work - + # search heaviest and deepest path separately + if state in depth_map: # and consequently also in work_map + depth_map[state] = sp.Max(depth_map[state], n_depth) + work_map[state] = sp.Max(work_map[state], n_work) + else: + depth_map[state] = n_depth + work_map[state] = n_work + out_edges = sdfg.out_edges(state) + # only advance after all incoming edges were visited (meaning that current work depth values of state are final). if any(iedge not in visited for iedge in sdfg.in_edges(state)): pass else: for oedge in out_edges: traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) - max_depth = depth_map[dummy_exit] - max_work = work_map[dummy_exit] - - print('SDFG', sdfg.name, 'processed') - w_d_map[get_uuid(sdfg)] = (sp.simplify(max_work), sp.simplify(max_depth)) - return sp.simplify(max_work), sp.simplify(max_depth) - - -""" -Analyze the work and depth of a scope. -This works by constructing a proxy graph of the scope and then finding the maximum depth path in that graph between -the source and sink. 
The proxy graph is constructed to remove any multi-edges between nodes and to remove nodes that -do not contribute to the depth. Additionally, nested scopes are summarized into single nodes. All of this is necessary -to reduce the number of possible paths in the graph, as much as possible, since they all have to be brute-force -enumerated to find the maximum depth path. -:note: This is terribly inefficient and should be improved. -:param state: The state in which the scope to analyze is contained. -:param sym_map: A dictionary mapping symbols to their values. -:param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. -:return: A tuple containing the work and depth of the scope. -""" -def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify, entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: - - # find the work / depth of each node + try: + max_depth = depth_map[dummy_exit] + max_work = work_map[dummy_exit] + except KeyError: + # If we get a KeyError above, this means that the traversal never reached the dummy_exit state. + # This happens if the loops were not properly detected and broken. + raise Exception('Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') + + sdfg_result = (sp.simplify(max_work), sp.simplify(max_depth)) + w_d_map[get_uuid(sdfg)] = sdfg_result + return sdfg_result + + +def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: + """ + Analyze the work and depth of a scope. + This works by traversing through the scope analyzing the work and depth of each encountered node. + Depending on what kind of node we encounter, we do the following: + - EntryNode: Recursively analyze work depth of scope. + - Tasklet: use analyze_tasklet to get work depth of tasklet node. + - NestedSDFG: After translating its local symbols to global symbols, we analyze the nested SDFG recursively. + - LibraryNode: Library nodes are analyzed with special functions depending on their type. + Work inside a state can simply be summed up, but for the depth we need to find the longest path. Since dataflow is a DAG, + this can be done in linear time by traversing the graph in topological order. + + :param state: The state in which the scope to analyze is contained. + :param sym_map: A dictionary mapping symbols to their values. + :param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. + :return: A tuple containing the work and depth of the scope. + """ + + # find the work and depth of each node # for maps and nested SDFG, we do it recursively work = sp.sympify(0) max_depth = sp.sympify(0) @@ -483,15 +428,16 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task scope_exit = None if entry is None else state.exit_node(entry) for node in scope_nodes: # add node to map - w_d_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0)) # TODO: do we need this line? + w_d_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0)) if isinstance(node, nd.EntryNode): - # If the scope contains an entry node, we need to recursively analyze the scope of the entry node first. + # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. 
+            # The resulting work/depth are summarized into the entry node
-            s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, node)
+            s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, node)
             # add up work for whole state, but also save work for this sub-scope in w_d_map
             work += s_work
             w_d_map[get_uuid(node, state)] = (s_work, s_depth)
         elif node == scope_exit:
+            # don't do anything for exit nodes, everything handled already in the corresponding entry node.
             pass
         elif isinstance(node, nd.Tasklet):
             # add up work for whole state, but also save work for this node in w_d_map
@@ -499,12 +445,25 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task
             work += t_work
             w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth))
         elif isinstance(node, nd.NestedSDFG):
+            # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols.
+            # We only want global symbols in our final work depth expressions.
+            nested_syms = {}
+            nested_syms.update(symbols)
+            nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping))
             # Nested SDFGs are recursively analyzed first.
-            nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify)
+            nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms)
 
             # add up work for whole state, but also save work for this nested SDFG in w_d_map
             work += nsdfg_work
             w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth)
+        elif isinstance(node, nd.LibraryNode):
+            lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state)
+            lib_node_depth = -1  # not analyzed
+            if analyze_tasklet != get_tasklet_work:
+                # we are analyzing depth
+                lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state)
+            w_d_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth)
+
 
     if entry is not None:
         # If the scope being analyzed is a map, multiply the work by the number of iterations of the map.
@@ -517,12 +476,13 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task
             print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.')
 
-    # TODO: Kinda ugly if condition...
-    # only do this if we even analyzed depth of tasklets
+    # Work inside a state can simply be summed up. But now we need to find the depth of a state (i.e. longest path).
+    # Since the dataflow graph is a DAG, this can be done in linear time.
     max_depth = sp.sympify(0)
-    if analyze_tasklet == get_tasklet_work_depth:
+    # only do this if we are analyzing depth
+    if analyze_tasklet == get_tasklet_work_depth or analyze_tasklet == get_tasklet_avg_par:
         # Calculate the maximum depth of the scope by finding the 'deepest' path from the source to the sink. This is done by
-        # a BFS in topological order, where each node propagates its current max depth for all incoming paths
+        # a traversal in topological order, where each node propagates its current max depth for all incoming paths.
         traversal_q = deque()
         visited = set()
         # find all starting nodes
@@ -532,12 +492,10 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task
         else:
             for node in scope_nodes:
                 if len(state.in_edges(node)) == 0:
-                    # push this node into the deque
+                    # This node is a start node of the traversal
                     traversal_q.append((node, sp.sympify(0), None))
-
-
+        # this map keeps track of the length of the longest path ending at each state so far seen.
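+        # For example, in a chain of tasklets t1 -> t2 -> t3 with depths 1, 2 and 3, the longest (and only)
+        # path has depth 1 + 2 + 3 = 6, whereas three independent tasklets would only contribute Max(1, 2, 3) = 3.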
depth_map = {} - while traversal_q: node, in_depth, in_edge = traversal_q.popleft() @@ -552,168 +510,89 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task depth_map[node] = n_depth out_edges = state.out_edges(node) - # only advance to next node, if all incoming edges have been visited or the current node is the entry (aka starting node) - # if the current node is the exit of the current scope, we stop, such that we don't leave the current scope + # Only advance to next node, if all incoming edges have been visited or the current node is the entry (aka starting node). + # If the current node is the exit of the scope, we stop, such that we don't leave the scope. if (all(iedge in visited for iedge in state.in_edges(node)) or node == entry) and node != scope_exit: - # if we encounter a nested map, we must not analyze its contents (as they have already been recursively analyzed) - # hence, we continue from the outgoing edges of the corresponding exit + # If we encounter a nested map, we must not analyze its contents (as they have already been recursively analyzed). + # Hence, we continue from the outgoing edges of the corresponding exit. if isinstance(node, nd.EntryNode) and node != entry: - # get the corresponding exit note exit_node = state.exit_node(node) # replace out_edges with the out_edges of the scope exit node out_edges = state.out_edges(exit_node) for oedge in out_edges: traversal_q.append((oedge.dst, depth_map[node], oedge)) if len(out_edges) == 0 or node == scope_exit: - # this is an end node --> update max_depth + # We have reached an end node --> update max_depth max_depth = sp.Max(max_depth, depth_map[node]) - # summarise work / depth of the whole state in the dictionary - w_d_map[get_uuid(state)] = (sp.simplify(work), sp.simplify(max_depth)) - return sp.simplify(work), sp.simplify(max_depth) - -""" -Analyze the work and depth of a state. -:param state: The state to analyze. -:param sym_map: A dictionary mapping symbols to their values. -:return: A tuple containing the work and depth of the state. -""" -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, syms_to_nonnegify) -> None: - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, syms_to_nonnegify, None) + # summarise work / depth of the whole scope in the dictionary + scope_result = (sp.simplify(work), sp.simplify(max_depth)) + w_d_map[get_uuid(state)] = scope_result + return scope_result + + +def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols) -> Tuple[sp.Expr, sp.Expr]: + """ + Analyze the work and depth of a state. + + :param state: The state to analyze. + :param w_d_map: The result will be saved to this map. + :param analyze_tasklet: Function used to analyze tasklet nodes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :return: A tuple containing the work and depth of the state. + """ + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, None) return work, depth -""" -Analyze the work and depth of an SDFG. -Optionally, a dictionary mapping symbols to their values can be provided to concretize the analysis. -Note that this also significantly speeds up the analysis due to sympy not having to perform the analysis symbolically. -:note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a - condition and an assignment. -:param sdfg: The SDFG to analyze. 
-:param sym_map: A dictionary mapping symbols to their values. -:return: A tuple containing the work and depth of the SDFG -""" -# def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, str], sym_map: Dict[str, int]) -> Dict[str, Tuple[str, str]]: -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> Dict[str, Tuple[sp.Expr, sp.Expr]]: - # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state - # will be executed, or to determine upper bounds for that number (such as in the case of branching) - # print('Propagating states...') - for sd in sdfg.all_sdfgs_recursive(): - propagation.propagate_states(sd) +def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> None: + """ + Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. + + :note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a + condition and an assignment. + :param sdfg: The SDFG to analyze. + :param w_d_map: Dictionary of SDFG elements to (work, depth) tuples. Result will be saved in here. + :param analyze_tasklet: The function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + """ # deepcopy such that original sdfg not changed - # sdfg = deepcopy(sdfg) + sdfg = deepcopy(sdfg) - # Check if the SDFG has any dynamically unbounded executions, i.e., if there are any states that have neither a - # statically known number of executions, nor an upper bound on the number of executions. Warn if this is the case. - # print('Checking for dynamically unbounded executions...') + # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state + # will be executed, or to determine upper bounds for that number (such as in the case of branching) for sd in sdfg.all_sdfgs_recursive(): - if any([s.executions == 0 and s.dynamic_executions for s in sd.nodes()]): - pass - # print('WARNING: SDFG has dynamic executions. The analysis may fail in unexpected ways or be inaccurate.') + propagation.propagate_states_symbolically(sd) - syms_to_nonnegify = set() # Analyze the work and depth of the SDFG. - # print('Analyzing SDFG...') - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, syms_to_nonnegify) + symbols = {} + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols) - # TODO: maybe do this posify more often for performance? + # Note: This posify could be done more often to improve performance. array_symbols = get_array_size_symbols(sdfg) for k, (v_w, v_d) in w_d_map.items(): - v_w = posify_certain_symbols(v_w, array_symbols, syms_to_nonnegify) - v_d = posify_certain_symbols(v_d, array_symbols, syms_to_nonnegify) + # The symeval replaces nested SDFG symbols with their global counterparts. 
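+        # Assuming an array-size symbol N appears in the expressions, posify lets sympy treat N as
+        # positive, so terms like Max(N, 0) or Abs(N) simplify to plain N in the reported results.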
+ v_w = posify_certain_symbols(symeval(v_w, symbols), array_symbols) + v_d = posify_certain_symbols(symeval(v_d, symbols), array_symbols) w_d_map[k] = (v_w, v_d) - sdfg.view() - -def get_work(sdfg_json): - # final version loads sdfg from json - # loaded = load_sdfg_from_json(sdfg_json) - # if loaded['error'] is not None: - # return loaded['error'] - # sdfg = loaded['sdfg'] - - # for now we load simply load from a file - sdfg = SDFG.from_file(sdfg_json) - - - - # try: - work_map = {} - analyze_sdfg(sdfg, work_map, get_tasklet_work) - for k, v, in work_map.items(): - work_map[k] = (str(sp.simplify(v[0]))) - return { - 'workMap': work_map, - } - # except Exception as e: - # return { - # 'error': { - # 'message': 'Failed to analyze work depth', - # 'details': get_exception_message(e), - # }, - # } - - - -def get_work_depth(sdfg_json): - # final version loads sdfg from json - # loaded = load_sdfg_from_json(sdfg_json) - # if loaded['error'] is not None: - # return loaded['error'] - # sdfg = loaded['sdfg'] - - # for now we load simply load from a file - sdfg = SDFG.from_file(sdfg_json) - - - # try: - work_depth_map = {} - analyze_sdfg(sdfg, work_depth_map, get_tasklet_work_depth) - for k, v, in work_depth_map.items(): - work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) - return { - 'workDepthMap': work_depth_map, - } - # except Exception as e: - # return { - # 'error': { - # 'message': 'Failed to analyze work depth', - # 'details': get_exception_message(e), - # }, - # } - - - - - ################################################################################ # Utility functions for running the analysis from the command line ############# ################################################################################ -class keyvalue(argparse.Action): - - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, dict()) - for v in values: - k, v = v.split('=') - getattr(namespace, self.dest)[k] = v - - def main() -> None: - analyze_depth = True parser = argparse.ArgumentParser( - 'work_depth_analysis', - usage='python work_depth_analysis.py [-h] filename', + 'work_depth', + usage='python work_depth.py [-h] filename --analyze {work,workDepth,avgPar}', description='Analyze the work/depth of an SDFG.' ) parser.add_argument('filename', type=str, help='The SDFG file to analyze.') - parser.add_argument('--kwargs', nargs='*', help='Define symbols.', action=keyvalue) + parser.add_argument('--analyze', choices=['work', 'workDepth', 'avgPar'], default='workDepth', help='Choose what to analyze. 
Default: workDepth') args = parser.parse_args() @@ -721,34 +600,39 @@ def main() -> None: print(args.filename, 'does not exist.') exit() - symbols_map = {} - if args.kwargs: - for k, v in args.kwargs.items(): - symbols_map[k] = int(v) + if args.analyze == 'workDepth': + analyze_tasklet = get_tasklet_work_depth + elif args.analyze == 'avgPar': + analyze_tasklet = get_tasklet_avg_par + elif args.analyze == 'work': + analyze_tasklet = get_tasklet_work - # TODO: symbols_map maybe not needed - if analyze_depth: - map = get_work_depth(args.filename) - map = map['workDepthMap'] - else: - map = get_work(args.filename) - map = map['workMap'] - - # find uuid of the whole SDFG sdfg = SDFG.from_file(args.filename) - result = map[get_uuid(sdfg)] - + work_depth_map = {} + analyze_sdfg(sdfg, work_depth_map, analyze_tasklet) - print(80*'-') - if isinstance(result, Tuple): - print("Work:\t", result[0]) - print("Depth:\t", result[1]) - else: - print("Work:\t", result) - print(80*'-') + if args.analyze == 'workDepth': + for k, v, in work_depth_map.items(): + work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) + elif args.analyze == 'work': + for k, v, in work_depth_map.items(): + work_depth_map[k] = str(sp.simplify(v[0])) + elif args.analyze == 'avgPar': + for k, v, in work_depth_map.items(): + work_depth_map[k] = str(sp.simplify(v[0] / v[1]) if str(v[1]) != '0' else 0) # work / depth = avg par + result_whole_sdfg = work_depth_map[get_uuid(sdfg)] + print(80*'-') + if args.analyze == 'workDepth': + print("Work:\t", result_whole_sdfg[0]) + print("Depth:\t", result_whole_sdfg[1]) + elif args.analyze == 'work': + print("Work:\t", result_whole_sdfg) + elif args.analyze == 'avgPar': + print("Average Parallelism:\t", result_whole_sdfg) + print(80*'-') if __name__ == '__main__': diff --git a/dace/sdfg/work_depth_analysis/work_depth_tests.py b/dace/sdfg/work_depth_analysis/work_depth_tests.py index b6af4daaa7..8f692787d9 100644 --- a/dace/sdfg/work_depth_analysis/work_depth_tests.py +++ b/dace/sdfg/work_depth_analysis/work_depth_tests.py @@ -1,5 +1,6 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Contains test cases for the work depth analysis. """ import dace as dc -import numpy as np from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth from dace.sdfg.work_depth_analysis.helpers import get_uuid import sympy as sp @@ -8,7 +9,7 @@ from dace.transformation.dataflow import MapExpansion -# TODO: add tests for function calls (e.g. reduce) +# TODO: add tests for library nodes (e.g. 
reduce, matMul) N = dc.symbol('N') @@ -38,7 +39,7 @@ def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1] if x[10] > 50: z[:] = x + y # N work, 1 depth else: - for i in range(K): # K work, K depth + for i in range(K): # K work, K depth sum += x[i] @dc.program @@ -61,7 +62,7 @@ def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64 if x[10] > 50: if x[9] > 50: z[:] = x + y # N work, 1 depth - z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth + z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth else: if y[9] > 50: for i in range(K): @@ -69,9 +70,9 @@ def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64 else: for j in range(M): sum += x[j] # M work, M depth - z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth + z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth - # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth + # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth @dc.program def max_of_positive_symbol(x: dc.float64[N]): @@ -98,7 +99,7 @@ def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], z3[:] = 2 * x3 + y3 # work 2*K, depth 2 elif x[3] > 0: z[:] = 3 * x + y + 1 # work 3*N, depth 3 - # --> work= Max(3*N, 2*M, 2*K) and depth = 5 + # --> work= Max(3*N, 2*M, 2*K) and depth = 5 @dc.program @@ -143,54 +144,24 @@ def break_while_loop(x:dc.float64[N]): break x += 1 -# @dc.program -# def continue_for_loop2(x:dc.float64[N]): -# i = 0 -# while True: -# i += 1 -# if i % 2 == 0: -# continue -# x += 1 -# if x[0] > 10: -# break - - -tests = [single_map, - single_for_loop, - if_else, - if_else_sym, - nested_sdfg, - nested_maps, - nested_for_loops, - nested_if_else, - max_of_positive_symbol, - multiple_array_sizes, - unbounded_while_do, - unbounded_do_while, - unbounded_nonnegify, - continue_for_loop, - break_for_loop, - break_while_loop] -# tests = [single_map] -results = [(N, 1), - (N, N), - (1000, 100), - (sp.Max(N, K), sp.Max(K,1)), - (2*N, N + 1), - (N*M, 1), - (N*K, N*K), - (sp.Max(K, M+N, 3*N), sp.Max(K, M+1, 3)), - (3*N**2, 3*N), - (sp.Max(3*N, 2*M + 3, 2*K), 5), - (N*sp.Symbol('num_execs_0_2'), sp.Symbol('num_execs_0_2')), - (N*sp.Symbol('num_execs_0_1'), sp.Symbol('num_execs_0_1')), - (sp.Max(N*sp.Symbol('num_execs_0_5'), 2*N*sp.Symbol('num_execs_0_3')), sp.Max(sp.Symbol('num_execs_0_5'), 2*sp.Symbol('num_execs_0_3'))), - (sp.Symbol('num_execs_0_2')*N, sp.Symbol('num_execs_0_2')), - (N**2, N), - (sp.Symbol('num_execs_0_3')*N, sp.Symbol('num_execs_0_3'))] - - +tests_cases = [(single_map, (N, 1)), + (single_for_loop, (N, N)), + (if_else, (1000, 100)), + (if_else_sym, (sp.Max(K, N), sp.Max(1, K))), + (nested_sdfg, (2*N, N + 1)), + (nested_maps, (M*N, 1)), + (nested_for_loops, (K*N, K*N)), + (nested_if_else, (sp.Max(K, 3*N, M + N), sp.Max(3, K, M + 1))), + (max_of_positive_symbol, (3*N**2, 3*N)), + (multiple_array_sizes, (sp.Max(2*K, 3*N, 2*M + 3), 5)), + (unbounded_while_do, (sp.Symbol('num_execs_0_2', nonnegative=True)*N, sp.Symbol('num_execs_0_2', nonnegative=True))), + # TODO: why we get this ugly max(1, num_execs) here?? 
+ (unbounded_do_while, (sp.Max(1,sp.Symbol('num_execs_0_1', nonnegative=True))*N, sp.Max(1,sp.Symbol('num_execs_0_1', nonnegative=True)))), + (unbounded_nonnegify, (2*sp.Symbol('num_execs_0_7', nonnegative=True)*N, 2*sp.Symbol('num_execs_0_7', nonnegative=True))), + (continue_for_loop, (sp.Symbol('num_execs_0_6', nonnegative=True)*N, sp.Symbol('num_execs_0_6', nonnegative=True))), + (break_for_loop, (N**2, N)), + (break_while_loop, (sp.Symbol('num_execs_0_5', nonnegative=True)*N, sp.Symbol('num_execs_0_5', nonnegative=True)))] def test_work_depth(): @@ -198,37 +169,35 @@ def test_work_depth(): failed = 0 exception = 0 failed_tests = [] - for test, correct in zip(tests, results): + for test, correct in tests_cases: w_d_map = {} - sdfg = test.to_sdfg()#simplify=False) + sdfg = test.to_sdfg() if 'nested_sdfg' in test.name: sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test.name: sdfg.apply_transformations(MapExpansion) - # sdfg.view() - # try: - analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) - res = w_d_map[get_uuid(sdfg)] - - # check result - if correct == res: - good += 1 - else: - # sdfg.view() + try: + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) + res = w_d_map[get_uuid(sdfg)] + + # check result + if correct == res: + good += 1 + else: + failed += 1 + failed_tests.append(test.name) + print(f'Test {test.name} failed:') + print('correct', correct) + print('result', res) + print() + except Exception as e: + print(e) failed += 1 - failed_tests.append(test.name) - print(f'Test {test.name} failed:') - print('correct', correct) - print('result', res) - print() - # except Exception as e: - # print(e) - # failed += 1 - # exception += 1 + exception += 1 print(100*'-') print(100*'-') - print(f'Ran {len(tests)} tests. {good} succeeded and {failed} failed ' + print(f'Ran {len(tests_cases)} tests. {good} succeeded and {failed} failed ' f'({exception} of those triggered an exception)') print(100*'-') print('failed tests:', failed_tests) From 027d9de01c05465a5f9474d3bcc0417baaa4e46c Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Thu, 20 Jul 2023 09:57:19 +0200 Subject: [PATCH 06/18] yapf for formatting --- dace/sdfg/propagation.py | 8 +- dace/sdfg/work_depth_analysis/helpers.py | 50 +++--- dace/sdfg/work_depth_analysis/work_depth.py | 114 +++++++------- .../work_depth_analysis/work_depth_tests.py | 142 +++++++++--------- 4 files changed, 163 insertions(+), 151 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index d8f1bee850..1e2d67ce21 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -813,7 +813,7 @@ def propagate_states_symbolically(sdfg) -> None: unannotated_loop_edge = oedge # remove cycle, since it is now annotated with symbol to_remove.append(cycle) - + for c in to_remove: unannotated_cycle_states.remove(c) @@ -827,7 +827,10 @@ def propagate_states_symbolically(sdfg) -> None: # Introduce the num_execs symbol and propagate it down the loop. # These symbols will always be non-negative. - traversal_q.append((unannotated_loop_edge.dst, Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(unannotated_loop_edge.dst)}', nonnegative=True), False, itvar_stack)) + traversal_q.append( + (unannotated_loop_edge.dst, + Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(unannotated_loop_edge.dst)}', + nonnegative=True), False, itvar_stack)) else: # Traverse as a conditional split. 
proposed_executions = state.executions @@ -860,7 +863,6 @@ def propagate_states_symbolically(sdfg) -> None: sdfg.remove_node(temp_exit_state) - def propagate_states(sdfg) -> None: """ Annotate the states of an SDFG with the number of executions. diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py index b9964db3d1..a80e769f64 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -9,6 +9,7 @@ NodeT = str EdgeT = Tuple[NodeT, NodeT] + class NodeCycle: nodes: Set[NodeT] = [] @@ -23,30 +24,27 @@ def length(self) -> int: UUID_SEPARATOR = '/' + def ids_to_string(sdfg_id, state_id=-1, node_id=-1, edge_id=-1): - return (str(sdfg_id) + UUID_SEPARATOR + str(state_id) + UUID_SEPARATOR + - str(node_id) + UUID_SEPARATOR + str(edge_id)) + return (str(sdfg_id) + UUID_SEPARATOR + str(state_id) + UUID_SEPARATOR + str(node_id) + UUID_SEPARATOR + + str(edge_id)) + def get_uuid(element, state=None): if isinstance(element, SDFG): return ids_to_string(element.sdfg_id) elif isinstance(element, SDFGState): - return ids_to_string(element.parent.sdfg_id, - element.parent.node_id(element)) + return ids_to_string(element.parent.sdfg_id, element.parent.node_id(element)) elif isinstance(element, nodes.Node): - return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), - state.node_id(element)) + return ids_to_string(state.parent.sdfg_id, state.parent.node_id(state), state.node_id(element)) else: return ids_to_string(-1) - -def get_domtree( - graph: nx.DiGraph, - start_node: str, - idom: Dict[str, str] = None -): + + +def get_domtree(graph: nx.DiGraph, start_node: str, idom: Dict[str, str] = None): idom = idom or nx.immediate_dominators(graph, start_node) - alldominated = { n: set() for n in graph.nodes } + alldominated = {n: set() for n in graph.nodes} domtree = nx.DiGraph() for node, dom in idom.items(): @@ -75,9 +73,9 @@ def get_domtree( return alldominated, domtree -def get_backedges( - graph: nx.DiGraph, start: Optional[NodeT], strict: bool = False -) -> Union[Set[EdgeT], Tuple[Set[EdgeT], Set[EdgeT]]]: +def get_backedges(graph: nx.DiGraph, + start: Optional[NodeT], + strict: bool = False) -> Union[Set[EdgeT], Tuple[Set[EdgeT], Set[EdgeT]]]: '''Find all backedges in a directed graph. Note: @@ -106,16 +104,13 @@ def get_backedges( start = node break if start is None: - raise ValueError( - 'No start node provided and no start node could ' + - 'be determined automatically' - ) + raise ValueError('No start node provided and no start node could ' + 'be determined automatically') # Gather all cycles in the graph. Cycles are represented as a sequence of # nodes. # O((|V|+|E|)*(C+1)), for C cycles. - all_cycles_nx: List[List[NodeT]] = nx.cycles.simple_cycles(graph) - #all_cycles_nx: List[List[NodeT]] = nx.simple_cycles(graph) + all_cycles_nx: List[List[NodeT]] = nx.cycles.simple_cycles(graph) + #all_cycles_nx: List[List[NodeT]] = nx.simple_cycles(graph) all_cycles: Set[NodeCycle] = set() for cycle in all_cycles_nx: all_cycles.add(NodeCycle(cycle)) @@ -207,7 +202,6 @@ def get_backedges( if eclipsed_candidates: eclipsed_backedges.update(eclipsed_candidates) - # Continue BFS. 
for neighbour in graph.successors(node): if neighbour not in visited: @@ -235,7 +229,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): break if start is None: raise ValueError('No start node could be determined') - + # sdfg can have multiple end nodes --> not good for postDomTree # --> add a new end node artificial_end_node = 'artificial_end_node' @@ -264,7 +258,6 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): backedgesDstDict[be[1]].add(be) else: backedgesDstDict[be[1]] = set([be]) - # This list will be filled with triples (node, oNode, exit), one triple for each loop construct in the SDFG. # There will always be a backedge from oNode to node. Either node or oNode will be the corresponding loop guard, @@ -283,7 +276,6 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): if not edge in backedges: successors.append(edge[1]) - # For each incoming backedge, we want to find oNode and exit. There can be multiple backedges, in case # we have a continue statement in the original code. But we can handle these backedges normally. for be in inc_backedges: @@ -293,7 +285,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): exitCandidates = set() # search for exit candidates: # a state is a exit candidate if: - # - it is in successor and it does not dominate oNode (else it dominates + # - it is in successor and it does not dominate oNode (else it dominates # the last loop state, and hence is inside the loop itself) # - is is a successor of oNode (but not node) # This handles both cases of while-do and do-while loops @@ -303,7 +295,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): for succ in sdfg_nx.successors(oNode): if succ != node: exitCandidates.add(succ) - + if len(exitCandidates) == 0: raise ValueError('failed to find any exit nodes') elif len(exitCandidates) > 1: @@ -327,7 +319,7 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): elif curr_level == minLevel: # add cand to curr set minSet.add(cand) - + if len(minSet) > 0: exitCandidates = minSet else: diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index a3c43c9826..ad4f7a842c 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -21,7 +21,6 @@ from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits - def get_array_size_symbols(sdfg): """ Returns all symbols that appear isolated in shapes of the SDFG's arrays. @@ -38,6 +37,7 @@ def get_array_size_symbols(sdfg): symbols.add(s) return symbols + def posify_certain_symbols(expr, syms_to_posify): """ Takes an expression and evaluates it while assuming that certain symbols are positive. @@ -48,12 +48,12 @@ def posify_certain_symbols(expr, syms_to_posify): """ expr = sp.sympify(expr) - - reps = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) - for s in syms_to_posify if s.is_positive is None} + + reps = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) for s in syms_to_posify if s.is_positive is None} expr = expr.subs(reps) return expr.subs({r: s for s, r in reps.items()}) + def symeval(val, symbols): """ Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. @@ -61,22 +61,18 @@ def symeval(val, symbols): :param val: The expression we are updating. :param symbols: Dictionary of key value pairs { old_symbol: new_symbol}. 
""" - first_replacement = { - pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) - for k in symbols.keys() - } - second_replacement = { - pystr_to_symbolic('__REPLSYM_' + k): v - for k, v in symbols.items() - } + first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} + second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} return val.subs(first_replacement).subs(second_replacement) + def evaluate_symbols(base, new): result = {} for k, v in new.items(): result[k] = symeval(v, base) return result + def count_work_matmul(node, symbols, state): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') @@ -92,6 +88,7 @@ def count_work_matmul(node, symbols, state): result *= symeval(A_memlet.data.subset.size()[-1], symbols) return result + def count_work_reduce(node, symbols, state): result = 0 if node.wcr is not None: @@ -106,16 +103,19 @@ def count_work_reduce(node, symbols, state): result = 0 return result + LIBNODES_TO_WORK = { MatMul: count_work_matmul, Transpose: lambda *args: 0, Reduce: count_work_reduce, } + def count_depth_matmul(node, symbols, state): # For now we set it equal to work: see comments in count_depth_reduce just below return count_work_matmul(node, symbols, state) + def count_depth_reduce(node, symbols, state): # depth of reduction is log2 of the work # TODO: Can we actually assume this? Or is it equal to the work? @@ -132,7 +132,6 @@ def count_depth_reduce(node, symbols, state): Reduce: count_depth_reduce, } - bigo = sp.Function('bigo') PYFUNC_TO_ARITHMETICS = { 'float': 0, @@ -152,7 +151,7 @@ def count_depth_reduce(node, symbols, state): 'ceiling': 0, 'floor': 0, 'abs': 0 -} +} class ArithmeticCounter(ast.NodeVisitor): @@ -174,7 +173,9 @@ def visit_UnaryOp(self, node): def visit_Call(self, node): fname = astunparse.unparse(node.func)[:-1] if fname not in PYFUNC_TO_ARITHMETICS: - print('WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' % fname) + print( + 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' + % fname) return self.generic_visit(node) self.count += PYFUNC_TO_ARITHMETICS[fname] return self.generic_visit(node) @@ -188,6 +189,7 @@ def visit_For(self, node): def visit_While(self, node): raise NotImplementedError + def count_arithmetic_ops_code(code): ctr = ArithmeticCounter() if isinstance(code, (tuple, list)): @@ -199,6 +201,7 @@ def count_arithmetic_ops_code(code): ctr.visit(code) return ctr.count + class DepthCounter(ast.NodeVisitor): # so far this is identical to the ArithmeticCounter above. def __init__(self): @@ -218,7 +221,9 @@ def visit_UnaryOp(self, node): def visit_Call(self, node): fname = astunparse.unparse(node.func)[:-1] if fname not in PYFUNC_TO_ARITHMETICS: - print('WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' % fname) + print( + 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' 
+ % fname) return self.generic_visit(node) self.count += PYFUNC_TO_ARITHMETICS[fname] return self.generic_visit(node) @@ -232,6 +237,7 @@ def visit_For(self, node): def visit_While(self, node): raise NotImplementedError + def count_depth_code(code): # so far this is the same as the work counter, since work = depth for each tasklet, as we can't assume any parallelism ctr = ArithmeticCounter() @@ -244,11 +250,12 @@ def count_depth_code(code): ctr.visit(code) return ctr.count + def tasklet_work(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.CPP: for oedge in state.out_edges(tasklet_node): return bigo(oedge.data.num_accesses) - + elif tasklet_node.code.language == dtypes.Language.Python: return count_arithmetic_ops_code(tasklet_node.code.code) else: @@ -257,6 +264,7 @@ def tasklet_work(tasklet_node, state): 'languages work = 1 will be counted for each tasklet.') return 1 + def tasklet_depth(tasklet_node, state): # TODO: how to get depth of CPP tasklets? # For now we use depth == work: @@ -271,18 +279,21 @@ def tasklet_depth(tasklet_node, state): 'languages depth = 1 will be counted for each tasklet.') return 1 + def get_tasklet_work(node, state): return tasklet_work(node, state), -1 + def get_tasklet_work_depth(node, state): return tasklet_work(node, state), tasklet_depth(node, state) + def get_tasklet_avg_par(node, state): return tasklet_work(node, state), tasklet_depth(node, state) - -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols) -> Tuple[sp.Expr, sp.Expr]: +def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, + symbols) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a given SDFG. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. @@ -295,7 +306,6 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana :return: A tuple containing the work and depth of the SDFG. """ - # First determine the work and depth of each state individually. # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number # of times the state will be executed. @@ -307,7 +317,6 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths[state] = sp.simplify(state_depth * state.executions) w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) - # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and # the guard, and instead places an edge between the last loop state and the exit state. # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. @@ -315,8 +324,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # identify all loops in the SDFG nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) - - # Now we need to go over each triple (node, oNode, exits). For each triple, we + + # Now we need to go over each triple (node, oNode, exits). For each triple, we # - remove edge (oNode, node), i.e. the backward edge # - for all exits e, add edge (oNode, e). 
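     # For example, for a simple while-loop with guard G, body state B and exit state E, the
     # backedge (B, G) is removed and an edge (B, E) is added, so every path through the former
     # loop now ends in the exit state and the state machine becomes acyclic.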
This edge may already exist for node, oNode, exits in nodes_oNodes_exits: @@ -357,14 +366,14 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one # single path with the least average parallelsim (of all paths with more than 0 work). if analyze_tasklet == get_tasklet_avg_par: - if state in depth_map: # and hence als state in work_map + if state in depth_map: # and hence als state in work_map # if current path has 0 depth, we don't do anything. if n_depth != 0: # see if we need to update the work and depth of the current state # we update if avg parallelism of new incoming path is less than current avg parallelism old_avg_par = sp.simplify(work_map[state] / depth_map[state]) new_avg_par = sp.simplify(n_work / n_depth) - + if depth_map[state] == 0 or new_avg_par < old_avg_par: # old value was divided by zero or new path gives actually worse avg par, then we keep new value depth_map[state] = n_depth @@ -374,17 +383,17 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana work_map[state] = n_work else: # search heaviest and deepest path separately - if state in depth_map: # and consequently also in work_map + if state in depth_map: # and consequently also in work_map depth_map[state] = sp.Max(depth_map[state], n_depth) work_map[state] = sp.Max(work_map[state], n_work) else: depth_map[state] = n_depth work_map[state] = n_work - + out_edges = sdfg.out_edges(state) # only advance after all incoming edges were visited (meaning that current work depth values of state are final). if any(iedge not in visited for iedge in sdfg.in_edges(state)): - pass + pass else: for oedge in out_edges: traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) @@ -395,14 +404,19 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana except KeyError: # If we get a KeyError above, this means that the traversal never reached the dummy_exit state. # This happens if the loops were not properly detected and broken. - raise Exception('Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') - + raise Exception( + 'Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') + sdfg_result = (sp.simplify(max_work), sp.simplify(max_depth)) w_d_map[get_uuid(sdfg)] = sdfg_result return sdfg_result -def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: +def scope_work_depth(state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols, + entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a scope. This works by traversing through the scope analyzing the work and depth of each encountered node. @@ -435,7 +449,7 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, node) # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work - w_d_map[get_uuid(node, state)] = (s_work, s_depth) + w_d_map[get_uuid(node, state)] = (s_work, s_depth) elif node == scope_exit: # don't do anything for exit nodes, everthing handled already in the corresponding entry node. 
pass @@ -458,13 +472,12 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) elif isinstance(node, nd.LibraryNode): lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) - lib_node_depth = -1 # not analyzed + lib_node_depth = -1 # not analyzed if analyze_tasklet != get_tasklet_work: # we are analyzing depth lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) w_d_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) - - + if entry is not None: # If the scope being analyzed is a map, multiply the work by the number of iterations of the map. if isinstance(entry, nd.MapEntry): @@ -475,7 +488,6 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task else: print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') - # Work inside a state can simply be summed up. But now we need to find the depth of a state (i.e. longest path). # Since dataflow graph is a DAG, this can be done in linear time. max_depth = sp.sympify(0) @@ -531,7 +543,8 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task return scope_result -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols) -> Tuple[sp.Expr, sp.Expr]: +def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, + symbols) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a state. @@ -575,24 +588,24 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No v_w = posify_certain_symbols(symeval(v_w, symbols), array_symbols) v_d = posify_certain_symbols(symeval(v_d, symbols), array_symbols) w_d_map[k] = (v_w, v_d) - - ################################################################################ # Utility functions for running the analysis from the command line ############# ################################################################################ + def main() -> None: - parser = argparse.ArgumentParser( - 'work_depth', - usage='python work_depth.py [-h] filename --analyze {work,workDepth,avgPar}', - description='Analyze the work/depth of an SDFG.' - ) + parser = argparse.ArgumentParser('work_depth', + usage='python work_depth.py [-h] filename --analyze {work,workDepth,avgPar}', + description='Analyze the work/depth of an SDFG.') parser.add_argument('filename', type=str, help='The SDFG file to analyze.') - parser.add_argument('--analyze', choices=['work', 'workDepth', 'avgPar'], default='workDepth', help='Choose what to analyze. Default: workDepth') + parser.add_argument('--analyze', + choices=['work', 'workDepth', 'avgPar'], + default='workDepth', + help='Choose what to analyze. 
Default: workDepth') args = parser.parse_args() @@ -600,7 +613,7 @@ def main() -> None: print(args.filename, 'does not exist.') exit() - if args.analyze == 'workDepth': + if args.analyze == 'workDepth': analyze_tasklet = get_tasklet_work_depth elif args.analyze == 'avgPar': analyze_tasklet = get_tasklet_avg_par @@ -611,7 +624,6 @@ def main() -> None: work_depth_map = {} analyze_sdfg(sdfg, work_depth_map, analyze_tasklet) - if args.analyze == 'workDepth': for k, v, in work_depth_map.items(): work_depth_map[k] = (str(sp.simplify(v[0])), str(sp.simplify(v[1]))) @@ -620,11 +632,11 @@ def main() -> None: work_depth_map[k] = str(sp.simplify(v[0])) elif args.analyze == 'avgPar': for k, v, in work_depth_map.items(): - work_depth_map[k] = str(sp.simplify(v[0] / v[1]) if str(v[1]) != '0' else 0) # work / depth = avg par + work_depth_map[k] = str(sp.simplify(v[0] / v[1]) if str(v[1]) != '0' else 0) # work / depth = avg par result_whole_sdfg = work_depth_map[get_uuid(sdfg)] - print(80*'-') + print(80 * '-') if args.analyze == 'workDepth': print("Work:\t", result_whole_sdfg[0]) print("Depth:\t", result_whole_sdfg[1]) @@ -632,8 +644,8 @@ def main() -> None: print("Work:\t", result_whole_sdfg) elif args.analyze == 'avgPar': print("Average Parallelism:\t", result_whole_sdfg) - print(80*'-') + print(80 * '-') if __name__ == '__main__': - main() \ No newline at end of file + main() diff --git a/dace/sdfg/work_depth_analysis/work_depth_tests.py b/dace/sdfg/work_depth_analysis/work_depth_tests.py index 8f692787d9..d0b32d5ab8 100644 --- a/dace/sdfg/work_depth_analysis/work_depth_tests.py +++ b/dace/sdfg/work_depth_analysis/work_depth_tests.py @@ -8,10 +8,8 @@ from dace.transformation.interstate import NestSDFG from dace.transformation.dataflow import MapExpansion - # TODO: add tests for library nodes (e.g. 
reduce, matMul) - N = dc.symbol('N') M = dc.symbol('M') K = dc.symbol('K') @@ -21,85 +19,91 @@ def single_map(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): z[:] = x + y + @dc.program def single_for_loop(x: dc.float64[N], y: dc.float64[N]): for i in range(N): x[i] += y[i] + @dc.program def if_else(x: dc.int64[1000], y: dc.int64[1000], z: dc.int64[1000], sum: dc.int64[1]): if x[10] > 50: - z[:] = x + y # 1000 work, 1 depth + z[:] = x + y # 1000 work, 1 depth else: - for i in range(100): # 100 work, 100 depth + for i in range(100): # 100 work, 100 depth sum += x[i] + @dc.program def if_else_sym(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): if x[10] > 50: - z[:] = x + y # N work, 1 depth + z[:] = x + y # N work, 1 depth else: - for i in range(K): # K work, K depth + for i in range(K): # K work, K depth sum += x[i] + @dc.program def nested_sdfg(x: dc.float64[N], y: dc.float64[N], z: dc.float64[N]): single_map(x, y, z) single_for_loop(x, y) + @dc.program def nested_maps(x: dc.float64[N, M], y: dc.float64[N, M], z: dc.float64[N, M]): z[:, :] = x + y + @dc.program def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): for i in range(N): for j in range(K): x[i] += y[j] + @dc.program def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): if x[10] > 50: if x[9] > 50: - z[:] = x + y # N work, 1 depth - z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth + z[:] = x + y # N work, 1 depth + z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth else: if y[9] > 50: - for i in range(K): - sum += x[i] # K work, K depth + for i in range(K): + sum += x[i] # K work, K depth else: - for j in range(M): - sum += x[j] # M work, M depth - z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth - # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth - # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth + for j in range(M): + sum += x[j] # M work, M depth + z[:] = x + y # N work, depth 1 --> total inner else: M+N work, M+1 depth + # --> total outer else: Max(K, M+N) work, Max(K, M+1) depth + # --> total over both branches: Max(K, M+N, 3*N) work, Max(K, M+1, 3) depth + @dc.program def max_of_positive_symbol(x: dc.float64[N]): if x[0] > 0: - for i in range(2*N): # work 2*N^2, depth 2*N + for i in range(2 * N): # work 2*N^2, depth 2*N x += 1 else: - for j in range(3*N): # work 3*N^2, depth 3*N + for j in range(3 * N): # work 3*N^2, depth 3*N x += 1 - # total is work 3*N^2, depth 3*N without any max - + # total is work 3*N^2, depth 3*N without any max @dc.program -def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], - x2: dc.int64[M], y2: dc.int64[M], z2: dc.int64[M], - x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]): +def multiple_array_sizes(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], x2: dc.int64[M], y2: dc.int64[M], + z2: dc.int64[M], x3: dc.int64[K], y3: dc.int64[K], z3: dc.int64[K]): if x[0] > 0: - z[:] = 2 * x + y # work 2*N, depth 2 + z[:] = 2 * x + y # work 2*N, depth 2 elif x[1] > 0: - z2[:] = 2 * x2 + y2 # work 2*M + 3, depth 5 + z2[:] = 2 * x2 + y2 # work 2*M + 3, depth 5 z2[0] += 3 + z[1] + z[2] elif x[2] > 0: - z3[:] = 2 * x3 + y3 # work 2*K, depth 2 + z3[:] = 2 * x3 + y3 # work 2*K, depth 2 elif x[3] > 0: - z[:] = 3 * x + y + 1 # work 3*N, depth 3 - # --> work= Max(3*N, 2*M, 2*K) and depth = 5 + z[:] = 3 * x + y + 1 # work 3*N, depth 3 + # --> work= Max(3*N, 2*M, 2*K) and depth = 5 @dc.program @@ -107,61 +111,71 @@ def 
unbounded_while_do(x: dc.float64[N]): while x[0] < 100: x += 1 + @dc.program def unbounded_do_while(x: dc.float64[N]): while True: x += 1 if x[0] >= 100: break - + @dc.program def unbounded_nonnegify(x: dc.float64[N]): while x[0] < 100: if x[1] < 42: - x += 3*x + x += 3 * x else: x += x + @dc.program -def continue_for_loop(x:dc.float64[N]): +def continue_for_loop(x: dc.float64[N]): for i in range(N): if x[i] > 100: continue x += 1 + @dc.program -def break_for_loop(x:dc.float64[N]): +def break_for_loop(x: dc.float64[N]): for i in range(N): if x[i] > 100: break - x += 1 + x += 1 + @dc.program -def break_while_loop(x:dc.float64[N]): +def break_while_loop(x: dc.float64[N]): while x[0] > 10: if x[1] > 100: break - x += 1 - - -tests_cases = [(single_map, (N, 1)), - (single_for_loop, (N, N)), - (if_else, (1000, 100)), - (if_else_sym, (sp.Max(K, N), sp.Max(1, K))), - (nested_sdfg, (2*N, N + 1)), - (nested_maps, (M*N, 1)), - (nested_for_loops, (K*N, K*N)), - (nested_if_else, (sp.Max(K, 3*N, M + N), sp.Max(3, K, M + 1))), - (max_of_positive_symbol, (3*N**2, 3*N)), - (multiple_array_sizes, (sp.Max(2*K, 3*N, 2*M + 3), 5)), - (unbounded_while_do, (sp.Symbol('num_execs_0_2', nonnegative=True)*N, sp.Symbol('num_execs_0_2', nonnegative=True))), - # TODO: why we get this ugly max(1, num_execs) here?? - (unbounded_do_while, (sp.Max(1,sp.Symbol('num_execs_0_1', nonnegative=True))*N, sp.Max(1,sp.Symbol('num_execs_0_1', nonnegative=True)))), - (unbounded_nonnegify, (2*sp.Symbol('num_execs_0_7', nonnegative=True)*N, 2*sp.Symbol('num_execs_0_7', nonnegative=True))), - (continue_for_loop, (sp.Symbol('num_execs_0_6', nonnegative=True)*N, sp.Symbol('num_execs_0_6', nonnegative=True))), - (break_for_loop, (N**2, N)), - (break_while_loop, (sp.Symbol('num_execs_0_5', nonnegative=True)*N, sp.Symbol('num_execs_0_5', nonnegative=True)))] + x += 1 + + +tests_cases = [ + (single_map, (N, 1)), + (single_for_loop, (N, N)), + (if_else, (1000, 100)), + (if_else_sym, (sp.Max(K, N), sp.Max(1, K))), + (nested_sdfg, (2 * N, N + 1)), + (nested_maps, (M * N, 1)), + (nested_for_loops, (K * N, K * N)), + (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), + (max_of_positive_symbol, (3 * N**2, 3 * N)), + (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), + (unbounded_while_do, (sp.Symbol('num_execs_0_2', nonnegative=True) * N, sp.Symbol('num_execs_0_2', + nonnegative=True))), + # TODO: why we get this ugly max(1, num_execs) here?? + (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)) * N, + sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)))), + (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7', nonnegative=True) * N, + 2 * sp.Symbol('num_execs_0_7', nonnegative=True))), + (continue_for_loop, (sp.Symbol('num_execs_0_6', nonnegative=True) * N, sp.Symbol('num_execs_0_6', + nonnegative=True))), + (break_for_loop, (N**2, N)), + (break_while_loop, (sp.Symbol('num_execs_0_5', nonnegative=True) * N, sp.Symbol('num_execs_0_5', nonnegative=True))) +] def test_work_depth(): @@ -188,29 +202,21 @@ def test_work_depth(): failed_tests.append(test.name) print(f'Test {test.name} failed:') print('correct', correct) - print('result', res) + print('result', res) print() except Exception as e: print(e) failed += 1 exception += 1 - - print(100*'-') - print(100*'-') + + print(100 * '-') + print(100 * '-') print(f'Ran {len(tests_cases)} tests. 
{good} succeeded and {failed} failed ' f'({exception} of those triggered an exception)') - print(100*'-') + print(100 * '-') print('failed tests:', failed_tests) - print(100*'-') - - - - - - - - + print(100 * '-') if __name__ == '__main__': - test_work_depth() \ No newline at end of file + test_work_depth() From 550622fdcfb65bdb24b0238342504129247acc7d Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Fri, 21 Jul 2023 10:38:14 +0200 Subject: [PATCH 07/18] put tests into dace/tests/sdfg --- .../sdfg}/work_depth_tests.py | 33 ++++--------------- 1 file changed, 6 insertions(+), 27 deletions(-) rename {dace/sdfg/work_depth_analysis => tests/sdfg}/work_depth_tests.py (86%) diff --git a/dace/sdfg/work_depth_analysis/work_depth_tests.py b/tests/sdfg/work_depth_tests.py similarity index 86% rename from dace/sdfg/work_depth_analysis/work_depth_tests.py rename to tests/sdfg/work_depth_tests.py index d0b32d5ab8..133afe8ae4 100644 --- a/dace/sdfg/work_depth_analysis/work_depth_tests.py +++ b/tests/sdfg/work_depth_tests.py @@ -166,7 +166,7 @@ def break_while_loop(x: dc.float64[N]): (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), (unbounded_while_do, (sp.Symbol('num_execs_0_2', nonnegative=True) * N, sp.Symbol('num_execs_0_2', nonnegative=True))), - # TODO: why we get this ugly max(1, num_execs) here?? + # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)) * N, sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)))), (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7', nonnegative=True) * N, @@ -190,32 +190,11 @@ def test_work_depth(): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test.name: sdfg.apply_transformations(MapExpansion) - try: - analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) - res = w_d_map[get_uuid(sdfg)] - - # check result - if correct == res: - good += 1 - else: - failed += 1 - failed_tests.append(test.name) - print(f'Test {test.name} failed:') - print('correct', correct) - print('result', res) - print() - except Exception as e: - print(e) - failed += 1 - exception += 1 - - print(100 * '-') - print(100 * '-') - print(f'Ran {len(tests_cases)} tests. 
{good} succeeded and {failed} failed ' - f'({exception} of those triggered an exception)') - print(100 * '-') - print('failed tests:', failed_tests) - print(100 * '-') + + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) + res = w_d_map[get_uuid(sdfg)] + # check result + assert correct == res if __name__ == '__main__': From 862aaeb2355154820069a976ebb1978fffb74755 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Fri, 21 Jul 2023 10:46:03 +0200 Subject: [PATCH 08/18] fixed import after merge --- dace/sdfg/work_depth_analysis/work_depth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index ad4f7a842c..e44e3ef3fc 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -11,8 +11,8 @@ import os import sympy as sp from copy import deepcopy -from dace.libraries.blas import MatMul, Transpose -from dace.libraries.standard import Reduce +from dace.libraries.blas import MatMul +from dace.libraries.standard import Reduce, Transpose from dace.symbolic import pystr_to_symbolic import ast import astunparse From 9ad130d04540489576c03ca525560c8e24af5477 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Sat, 22 Jul 2023 11:24:45 +0200 Subject: [PATCH 09/18] merged propgatate_states_symbolically into propagate_states --- dace/sdfg/propagation.py | 231 +++----------------- dace/sdfg/work_depth_analysis/work_depth.py | 2 +- 2 files changed, 36 insertions(+), 197 deletions(-) diff --git a/dace/sdfg/propagation.py b/dace/sdfg/propagation.py index 1e2d67ce21..0fec4812b7 100644 --- a/dace/sdfg/propagation.py +++ b/dace/sdfg/propagation.py @@ -676,194 +676,7 @@ def _annotate_loop_ranges(sdfg, unannotated_cycle_states): unannotated_cycle_states.append(cycle) -def propagate_states_symbolically(sdfg) -> None: - """ - Idea is like propagate_states, but here we dont have unbounded number of executions. - Instead, we do it symbolically and annotate unbounded loops with symbols "num_exec_{sdfg_id}_{loop_start_state_id}". - - :param sdfg: The SDFG to annotate. - :note: This operates on the SDFG in-place. - """ - - # We import here to avoid cyclic imports. - from dace.sdfg import InterstateEdge - from dace.transformation.helpers import split_interstate_edges - from dace.sdfg.analysis import cfg - - # Reset the state edge annotations (which may have changed due to transformations) - reset_state_annotations(sdfg) - - # Clean up the state machine by separating combined condition and assignment - # edges. - split_interstate_edges(sdfg) - - # To enable branch annotation, we add a temporary exit state that connects - # to all child-less states. With this, we can use the dominance frontier - # to determine a full-merge state for branches. - temp_exit_state = None - for s in sdfg.nodes(): - if sdfg.out_degree(s) == 0: - if temp_exit_state is None: - temp_exit_state = sdfg.add_state('__dace_brannotate_exit') - sdfg.add_edge(s, temp_exit_state, InterstateEdge()) - - dom_frontier = cfg.acyclic_dominance_frontier(sdfg) - - # Find any valid for loop constructs and annotate the loop ranges. Any other - # cycle should be marked as unannotated. - unannotated_cycle_states = [] - _annotate_loop_ranges(sdfg, unannotated_cycle_states) - - # Keep track of states that fully merge a previous conditional split. We do - # this so we can remove the dynamic executions flag for those states. 
- full_merge_states = set() - - visited_states = set() - - traversal_q = deque() - traversal_q.append((sdfg.start_state, 1, False, [])) - while traversal_q: - (state, proposed_executions, proposed_dynamic, itvar_stack) = traversal_q.pop() - - out_degree = sdfg.out_degree(state) - out_edges = sdfg.out_edges(state) - - # Check if the traversal reached a state that's already been visited - # (ends traversal), or if the number of executions being propagated is - # dynamic unbounded. Otherwise, continue regular traversal. - if state in visited_states: - # This state has already been visited. - if getattr(state, 'is_loop_guard', False): - # If we encounter a loop guard that's already been visited, - # we've finished traversing a loop and can remove that loop's - # iteration variable from the stack. We additively merge the - # number of executions. - state.executions += proposed_executions - else: - # If we have already visited this state, but it is NOT a loop - # guard, this means that we can reach this state via multiple - # different paths. If so, the number of executions for this - # state is given by the maximum number of executions among each - # of the paths reaching it. If the state additionally completely - # merges a previously branched out state tree, we know that the - # number of executions isn't dynamic anymore. - state.executions = sympy.Max(state.executions, proposed_executions).doit() - if state in full_merge_states: - state.dynamic_executions = False - # TODO: do we need this else here or not? - # else: - # state.dynamic_executions = (state.dynamic_executions or proposed_dynamic) - else: - # If the state hasn't been visited yet, we calculate the number of - # executions for the next state(s) and continue propagating. - visited_states.add(state) - if state in full_merge_states: - # If this state fully merges a conditional branch, this turns - # dynamic executions back off. - proposed_dynamic = False - state.executions = proposed_executions - state.dynamic_executions = proposed_dynamic - - if out_degree == 1: - # Continue with the only child state. - if not out_edges[0].data.is_unconditional(): - # If the transition to the child state is based on a - # condition, this state could be an implicit exit state. The - # child state's number of executions is thus only given as - # an upper bound and marked as dynamic. - proposed_dynamic = True - traversal_q.append((out_edges[0].dst, proposed_executions, proposed_dynamic, itvar_stack)) - elif out_degree > 1: - if getattr(state, 'is_loop_guard', False): - itvar = symbolic.symbol(state.itvar) - loop_range = state.ranges[state.itvar] - start = loop_range[0][0] - stop = loop_range[0][1] - stride = loop_range[0][2] - - # Calculate the number of loop executions. - # This resolves ranges based on the order of iteration - # variables pushed on to the stack if we're in a nested - # loop. 
- loop_executions = ceiling(((stop + 1) - start) / stride) - for outer_itvar_string in reversed(itvar_stack): - outer_range = state.ranges[outer_itvar_string] - outer_start = outer_range[0][0] - outer_stop = outer_range[0][1] - outer_stride = outer_range[0][2] - outer_itvar = symbolic.pystr_to_symbolic(outer_itvar_string) - exec_repl = loop_executions.subs({outer_itvar: (outer_itvar * outer_stride + outer_start)}) - loop_executions = Sum(exec_repl, - (outer_itvar, 0, ceiling((outer_stop - outer_start) / outer_stride))) - loop_executions = loop_executions.doit() - - loop_state = state.condition_edge.dst - end_state = (out_edges[0].dst if out_edges[1].dst == loop_state else out_edges[1].dst) - - traversal_q.append((end_state, state.executions, proposed_dynamic, itvar_stack)) - traversal_q.append((loop_state, loop_executions, proposed_dynamic, itvar_stack + [state.itvar])) - else: - # Conditional split or unannotated loop. - unannotated_loop_edge = None - to_remove = [] - for oedge in out_edges: - for cycle in unannotated_cycle_states: - if oedge.dst in cycle: - # This is an unannotated loop down this branch. - unannotated_loop_edge = oedge - # remove cycle, since it is now annotated with symbol - to_remove.append(cycle) - - for c in to_remove: - unannotated_cycle_states.remove(c) - - if unannotated_loop_edge is not None: - # Traverse as an unbounded loop. - out_edges.remove(unannotated_loop_edge) - - # traverse non-loops states normally - for oedge in out_edges: - traversal_q.append((oedge.dst, state.executions, False, itvar_stack)) - - # Introduce the num_execs symbol and propagate it down the loop. - # These symbols will always be non-negative. - traversal_q.append( - (unannotated_loop_edge.dst, - Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(unannotated_loop_edge.dst)}', - nonnegative=True), False, itvar_stack)) - else: - # Traverse as a conditional split. - proposed_executions = state.executions - proposed_dynamic = True - - # Get the dominance frontier for each child state and - # merge them into one common frontier, representing the - # branch's immediate post-dominator. If a state has no - # dominance frontier, add the state itself to the - # frontier. This takes care of the case where a branch - # is fully merged, but one branch contains no states. - common_frontier = set() - for oedge in out_edges: - frontier = dom_frontier[oedge.dst] - if not frontier: - frontier = {oedge.dst} - common_frontier |= frontier - - # Continue traversal for each child. - traversal_q.append((oedge.dst, proposed_executions, proposed_dynamic, itvar_stack)) - - # If the whole branch is not dynamic, and the - # common frontier is exactly one state, we know that - # the branch merges again at that state. - if not state.dynamic_executions and len(common_frontier) == 1: - full_merge_states.add(list(common_frontier)[0]) - - # If we had to create a temporary exit state, we remove it again here. - if temp_exit_state is not None: - sdfg.remove_node(temp_exit_state) - - -def propagate_states(sdfg) -> None: +def propagate_states(sdfg, concretize_dynamic_unbounded=False) -> None: """ Annotate the states of an SDFG with the number of executions. @@ -914,6 +727,9 @@ def propagate_states(sdfg) -> None: once. :param sdfg: The SDFG to annotate. + :param concretize_dynamic_unbounded: If True, we annotate dyncamic unbounded states with symbols of the + form "num_execs_{sdfg_id}_{loop_start_state_id}". Hence, for each + unbounded loop its states will have the same number of symbolic executions. 
:note: This operates on the SDFG in-place. """ @@ -945,8 +761,9 @@ def propagate_states(sdfg) -> None: # cycle should be marked as unannotated. unannotated_cycle_states = [] _annotate_loop_ranges(sdfg, unannotated_cycle_states) - # flatten the list - unannotated_cycle_states = [state for cycle in unannotated_cycle_states for state in cycle] + if not concretize_dynamic_unbounded: + # Flatten the list. This keeps the old behavior of propagate_states. + unannotated_cycle_states = [state for cycle in unannotated_cycle_states for state in cycle] # Keep track of states that fully merge a previous conditional split. We do # this so we can remove the dynamic executions flag for those states. @@ -988,7 +805,7 @@ def propagate_states(sdfg) -> None: # The only exception to this rule: If the state is in an # unannotated loop, i.e. should be annotated as dynamic # unbounded instead, we do that. - if (state in unannotated_cycle_states): + if (not concretize_dynamic_unbounded) and state in unannotated_cycle_states: state.executions = 0 state.dynamic_executions = True else: @@ -1060,17 +877,39 @@ def propagate_states(sdfg) -> None: else: # Conditional split or unannotated (dynamic unbounded) loop. unannotated_loop_edge = None - for oedge in out_edges: - if oedge.dst in unannotated_cycle_states: - # This is an unannotated loop down this branch. - unannotated_loop_edge = oedge + if concretize_dynamic_unbounded: + to_remove = [] + for oedge in out_edges: + for cycle in unannotated_cycle_states: + if oedge.dst in cycle: + # This is an unannotated loop down this branch. + unannotated_loop_edge = oedge + # remove cycle, since it is now annotated with symbol + to_remove.append(cycle) + + for c in to_remove: + unannotated_cycle_states.remove(c) + else: + for oedge in out_edges: + if oedge.dst in unannotated_cycle_states: + # This is an unannotated loop down this branch. + unannotated_loop_edge = oedge if unannotated_loop_edge is not None: # Traverse as an unbounded loop. out_edges.remove(unannotated_loop_edge) for oedge in out_edges: traversal_q.append((oedge.dst, state.executions, False, itvar_stack)) - traversal_q.append((unannotated_loop_edge.dst, 0, True, itvar_stack)) + if concretize_dynamic_unbounded: + # Here we introduce the num_exec symbol and propagate it down the loop. + # We can always assume these symbols to be non-negative. + traversal_q.append( + (unannotated_loop_edge.dst, + Symbol(f'num_execs_{sdfg.sdfg_id}_{sdfg.node_id(unannotated_loop_edge.dst)}', + nonnegative=True), False, itvar_stack)) + else: + # Propagate dynamic unbounded. + traversal_q.append((unannotated_loop_edge.dst, 0, True, itvar_stack)) else: # Traverse as a conditional split. proposed_executions = state.executions diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index e44e3ef3fc..fea0ad3453 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -575,7 +575,7 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state # will be executed, or to determine upper bounds for that number (such as in the case of branching) for sd in sdfg.all_sdfgs_recursive(): - propagation.propagate_states_symbolically(sd) + propagation.propagate_states(sd, concretize_dynamic_unbounded=True) # Analyze the work and depth of the SDFG. 
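
# A minimal sketch of how this entry point is driven from user code, mirroring the tests
# (the small program below is made up for illustration; only the imported names come from this module):
import dace as dc
from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth
from dace.sdfg.work_depth_analysis.helpers import get_uuid

N = dc.symbol('N')

@dc.program
def saxpy(x: dc.float64[N], y: dc.float64[N]):
    y[:] = 2 * x + y

sdfg = saxpy.to_sdfg()
w_d_map = {}
analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth)
work, depth = w_d_map[get_uuid(sdfg)]    # symbolic work and depth of the whole SDFG
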
symbols = {} From f48c4eed815c2c6e55fb6a969ba1b3a2d9a2cfdb Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Sat, 22 Jul 2023 11:27:34 +0200 Subject: [PATCH 10/18] fixed format issue in work_depth.py --- dace/sdfg/work_depth_analysis/work_depth.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index fea0ad3453..f81b7bd75d 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -26,7 +26,8 @@ def get_array_size_symbols(sdfg): Returns all symbols that appear isolated in shapes of the SDFG's arrays. These symbols can then be assumed to be positive. - :note: This only works if a symbol appears in isolation, i.e. array A[N]. If we have A[N+1], we cannot assume N to be positive. + :note: This only works if a symbol appears in isolation, i.e. array A[N]. + If we have A[N+1], we cannot assume N to be positive. :param sdfg: The SDFG in which it searches for symbols. :return: A set containing symbols which we can assume to be positive. """ From 4b3f6a70518f9e13830caf5f75616bfc59242cdd Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Sun, 23 Jul 2023 13:24:01 +0200 Subject: [PATCH 11/18] small bugfix --- dace/sdfg/work_depth_analysis/work_depth.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index f81b7bd75d..a05fe10266 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -473,6 +473,7 @@ def scope_work_depth(state: SDFGState, w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) elif isinstance(node, nd.LibraryNode): lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) + work += lib_node_work lib_node_depth = -1 # not analyzed if analyze_tasklet != get_tasklet_work: # we are analyzing depth From 7b0b2eb67b1a41a47a503a1af7c5c76875a105d9 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Wed, 26 Jul 2023 11:27:19 +0200 Subject: [PATCH 12/18] include wcr edges into analysis, improve LibraryNodes analysis --- dace/sdfg/work_depth_analysis/work_depth.py | 49 +++++++++++++-------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index a05fe10266..fc1f2089e1 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -89,6 +89,12 @@ def count_work_matmul(node, symbols, state): result *= symeval(A_memlet.data.subset.size()[-1], symbols) return result +def count_depth_matmul(node, symbols, state): + # optimal depth of a matrix multiplication is O(log(size of shared dimension)): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) + return bigo(sp.log(size_shared_dimension)) + def count_work_reduce(node, symbols, state): result = 0 @@ -104,29 +110,18 @@ def count_work_reduce(node, symbols, state): result = 0 return result +def count_depth_reduce(node, symbols, state): + # optimal depth of reduction is log of the work + return bigo(sp.log(count_work_reduce(node, symbols, state))) + LIBNODES_TO_WORK = { MatMul: count_work_matmul, Transpose: lambda *args: 0, Reduce: count_work_reduce, + # TODO: include more } - -def count_depth_matmul(node, symbols, state): - # For now we set it equal to work: see comments in count_depth_reduce just below - 
return count_work_matmul(node, symbols, state) - - -def count_depth_reduce(node, symbols, state): - # depth of reduction is log2 of the work - # TODO: Can we actually assume this? Or is it equal to the work? - # Another thing to consider is that we essetially do NOT count wcr edges as operations for now... - - # return sp.ceiling(sp.log(count_work_reduce(node, symbols, state), 2)) - # set it equal to work for now - return count_work_reduce(node, symbols, state) - - LIBNODES_TO_DEPTH = { MatMul: count_depth_matmul, Transpose: lambda *args: 0, @@ -457,6 +452,10 @@ def scope_work_depth(state: SDFGState, elif isinstance(node, nd.Tasklet): # add up work for whole state, but also save work for this node in w_d_map t_work, t_depth = analyze_tasklet(node, state) + # check if tasklet has any outgoing wcr edges + for e in state.out_edges(node): + if e.data.wcr is not None: + t_work += count_arithmetic_ops_code(e.data.wcr) work += t_work w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth)) elif isinstance(node, nd.NestedSDFG): @@ -510,6 +509,7 @@ def scope_work_depth(state: SDFGState, traversal_q.append((node, sp.sympify(0), None)) # this map keeps track of the length of the longest path ending at each state so far seen. depth_map = {} + wcr_depth_map = {} while traversal_q: node, in_depth, in_edge = traversal_q.popleft() @@ -534,11 +534,24 @@ def scope_work_depth(state: SDFGState, # replace out_edges with the out_edges of the scope exit node out_edges = state.out_edges(exit_node) for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[node], oedge)) + # check for wcr + wcr_depth = 0 + if oedge.data.wcr is not None: + wcr_depth = oedge.data.volume + if get_uuid(node, state) in wcr_depth_map: + # max + wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], wcr_depth) + else: + wcr_depth_map[get_uuid(node, state)] = wcr_depth + # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N. + # But we only want N, the code line just above this comment will then take care of that. + traversal_q.append((oedge.dst, depth_map[node] + (wcr_depth if not isinstance(oedge.dst, nd.MapExit) else 0), oedge)) if len(out_edges) == 0 or node == scope_exit: # We have reached an end node --> update max_depth max_depth = sp.Max(max_depth, depth_map[node]) - + + for uuid in wcr_depth_map: + w_d_map[uuid] = (w_d_map[uuid][0], w_d_map[uuid][1] + wcr_depth_map[uuid]) # summarise work / depth of the whole scope in the dictionary scope_result = (sp.simplify(work), sp.simplify(max_depth)) w_d_map[get_uuid(state)] = scope_result From a8efd97452a03f94a14d147f4eef4d5183056209 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Mon, 4 Sep 2023 12:11:56 +0200 Subject: [PATCH 13/18] imporved work depth. wcr now analyses, performance improved, assumptions can be passed --- dace/sdfg/work_depth_analysis/assumptions.py | 282 ++++++++++++++++++ dace/sdfg/work_depth_analysis/helpers.py | 2 + dace/sdfg/work_depth_analysis/work_depth.py | 290 ++++++++++++++----- tests/sdfg/work_depth_tests.py | 102 +++++-- 4 files changed, 582 insertions(+), 94 deletions(-) create mode 100644 dace/sdfg/work_depth_analysis/assumptions.py diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py new file mode 100644 index 0000000000..1f167f15a3 --- /dev/null +++ b/dace/sdfg/work_depth_analysis/assumptions.py @@ -0,0 +1,282 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. 
All rights reserved. + +import sympy as sp +from typing import Tuple, Dict + + +class UnionFind: + """ + Simple, not really optimized UnionFind implementation. + """ + + def __init__(self, elements) -> None: + self.ids = {e : e for e in elements} + + def add_element(self, e): + if e in self.ids: + return False + self.ids.update({e : e}) + return True + + def find(self, e): + prev = e + curr = self.ids[e] + while prev != curr: + prev = curr + curr = self.ids[curr] + # shorten the path + self.ids[e] = curr + return curr + + def union(self, e, f): + if f not in self.ids: + self.add_element(f) + self.ids[self.find(e)] = f + + +class ContradictingAssumptions(Exception): + pass + +class Assumptions: + """ + Summarises the assumptions for a single symbol in three lists: equal, greater, lesser. + """ + + def __init__(self) -> None: + self.greater = [] + self.lesser = [] + self.equal = [] + + def add_greater(self, g): + if isinstance(g, sp.Symbol): + self.greater.append(g) + else: + self.greater = [x for x in self.greater if isinstance(x, sp.Symbol) or x > g] + if len([y for y in self.greater if not isinstance(y, sp.Symbol)]) == 0: + self.greater.append(g) + self.check_consistency() + + def add_lesser(self, l): + if isinstance(l, sp.Symbol): + self.lesser.append(l) + else: + self.lesser = [x for x in self.lesser if isinstance(x, sp.Symbol) or x < l] + if len([y for y in self.lesser if not isinstance(y, sp.Symbol)]) == 0: + self.lesser.append(l) + self.check_consistency() + + def add_equal(self, e): + for x in self.equal: + if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e: + raise ContradictingAssumptions() + self.equal.append(e) + self.check_consistency() + + def check_consistency(self): + if len(self.equal) > 0: + # we know exact value + for e in self.equal: + for g in self.greater: + if (e <= g) == True: + raise ContradictingAssumptions() + for l in self.lesser: + if (e >= l) == True: + raise ContradictingAssumptions() + else: + # check if any greater > any lesser + for g in self.greater: + for l in self.lesser: + if (g > l) == True: + raise ContradictingAssumptions() + return True + + def num_assumptions(self): + # returns the number of individual assumptions for this symbol + return len(self.greater) + len(self.lesser) + len(self.equal) + +def propagate_assumptions(x, y, condensed_assumptions): + """ + Assuming x is equal to y, we propagate the assumptions on x to y. E.g. we have x==y and + x<5. Then, this method adds y<5 to the assumptions. + + :param x: A symbol. + :param y: Another symbol equal to x. + :param condensed_assumptions: Current assumptions over all symbols. + """ + if x == y: + return + assum_x = condensed_assumptions[x] + if y not in condensed_assumptions: + condensed_assumptions[y] = Assumptions() + assum_y = condensed_assumptions[y] + for e in assum_x.equal: + if e is not sp.Symbol(y): + assum_y.add_equal(e) + for g in assum_x.greater: + assum_y.add_greater(g) + for l in assum_x.lesser: + assum_y.add_lesser(l) + assum_y.check_consistency() + +def propagate_assumptions_equal_symbols(condensed_assumptions): + """ + This method handles two things: 1) It generates the substitution dict for all equality assumptions. + And 2) it propagates assumptions too all equal symbols. For each equivalence class, we find a unique + representative using UnionFind. Then, all assumptions get propagates to this symbol using + ``propagate_assumptions``. + + :param condensed_assumptions: Current assumptions over all symbols. 
+    :return: Returns a tuple consisting of 2 substitution dicts. The first one replaces each symbol with
+    the unique representative of its equivalence class. The second dict replaces each symbol with its numeric
+    value (if we assume it to be equal to some value, e.g. N==5).
+    """
+    # Make one set with a unique identifier for each equality class
+    uf = UnionFind(list(condensed_assumptions))
+    for sym in condensed_assumptions:
+        for other in condensed_assumptions[sym].equal:
+            if isinstance(other, sp.Symbol):
+                # we assume sym == other --> union these
+                uf.union(sym, other.name)
+
+    equality_subs1 = {}
+
+    # For each equivalence class, we now have one unique identifier.
+    # For each class, we give all the assumptions to this single symbol.
+    # And we swap each symbol in class for this symbol.
+    for sym in list(condensed_assumptions):
+        for other in condensed_assumptions[sym].equal:
+            if isinstance(other, sp.Symbol):
+                propagate_assumptions(sym, uf.find(sym), condensed_assumptions)
+                equality_subs1.update({sym: sp.Symbol(uf.find(sym))})
+
+    equality_subs2 = {}
+    # In a second step, each symbol gets replaced with its equal number (if present)
+    # using equality_subs2.
+    for sym, assum in condensed_assumptions.items():
+        for e in assum.equal:
+            if not isinstance(e, sp.Symbol):
+                equality_subs2.update({sym: e})
+
+    # Imagine we have M>N and M==10. We need to deduce N<10 from that. The following code handles that:
+    for sym, assum in condensed_assumptions.items():
+        for g in assum.greater:
+            if isinstance(g, sp.Symbol):
+                for e in condensed_assumptions[g.name].equal:
+                    if not isinstance(e, sp.Symbol):
+                        condensed_assumptions[sym].add_greater(e)
+                        assum.greater.remove(g)
+        for l in assum.lesser:
+            if isinstance(l, sp.Symbol):
+                for e in condensed_assumptions[l.name].equal:
+                    if not isinstance(e, sp.Symbol):
+                        condensed_assumptions[sym].add_lesser(e)
+                        assum.lesser.remove(l)
+    return equality_subs1, equality_subs2
+
+
+def parse_assumptions(assumptions, array_symbols):
+    """
+    Parses a list of assumptions into substitution dictionaries. Firstly, it gathers all assumptions and
+    keeps only the strongest ones. Afterwards it constructs two substitution dicts for the equality
+    assumptions: First dict for symbol==symbol assumptions; second dict for symbol==number assumptions.
+    The other assumptions get handled by N tuples of substitution dicts (N = max number of concurrent
+    assumptions for a single symbol). Each tuple is responsible for at most one assumption for each symbol.
+    First dict in the tuple substitutes the symbol with the assumption; second dict restores the initial symbol.
+
+    :param assumptions: List of assumption strings.
+    :param array_symbols: List of symbols we assume to be positive, since they are the size of a data container.
+    :return: Tuple consisting of the 2 dicts responsible for the equality assumptions and the list of size N
+    responsible for all other assumptions.
+    """
+
+    # TODO: This assumptions system can be improved further, especially the deduction of further assumptions
+    # from the ones we already have. An example of what is not working currently:
+    # We have assumptions N>0 N<5 and M>5.
+    # In the first substitution round we use N>0 and M>5.
+    # In the second substitution round we use N<5.
+    # Therefore, Max(M, N) will not be evaluated to M, even though from the input assumptions
+    # one can clearly deduce M>N.
+    # This happens since N<5 and M>5 are not in the same substitution round.
+    # The easiest way to fix this is probably to actually deduce the M>N assumption.
+    # This guarantees that in some substitution round, we will replace M with N + _p_M, where
+    # _p_M is some positive symbol. Hence, we would resolve Max(M, N) to N + _p_M, which is M.
+
+    # I suspect there to be many more cases where further assumptions will not be deduced properly.
+    # But if the user enters assumptions as explicitly as possible, e.g. N<5 M>5 M>N, then everything
+    # works fine.
+
+    # For each symbol x appearing as a data container size, we can assume x>0.
+    # TODO (later): Analyze size of shapes more, such that e.g. shape N + 1 --> We can assume N > -1.
+    # For now we only extract assumptions out of shapes if shape consists of only a single symbol.
+    for sym in array_symbols:
+        assumptions.append(f'{sym.name}>0')
+
+    if assumptions is None:
+        return {}, [({}, {})]
+
+    # Gather assumptions, keeping only the strongest ones for each symbol.
+    condensed_assumptions: Dict[str, Assumptions] = {}
+    for a in assumptions:
+        if '==' in a:
+            symbol, rhs = a.split('==')
+            if symbol not in condensed_assumptions:
+                condensed_assumptions[symbol] = Assumptions()
+            try:
+                condensed_assumptions[symbol].add_equal(int(rhs))
+            except ValueError:
+                condensed_assumptions[symbol].add_equal(sp.Symbol(rhs))
+        elif '>' in a:
+            symbol, rhs = a.split('>')
+            if symbol not in condensed_assumptions:
+                condensed_assumptions[symbol] = Assumptions()
+            try:
+                condensed_assumptions[symbol].add_greater(int(rhs))
+            except ValueError:
+                condensed_assumptions[symbol].add_greater(sp.Symbol(rhs))
+            # add the opposite, i.e. for x>y, we add y<x
+            if rhs not in condensed_assumptions:
+                condensed_assumptions[rhs] = Assumptions()
+            condensed_assumptions[rhs].add_lesser(sp.Symbol(symbol))
+        elif '<' in a:
+            symbol, rhs = a.split('<')
+            if symbol not in condensed_assumptions:
+                condensed_assumptions[symbol] = Assumptions()
+            try:
+                condensed_assumptions[symbol].add_lesser(int(rhs))
+            except ValueError:
+                condensed_assumptions[symbol].add_lesser(sp.Symbol(rhs))
+            # add the opposite, i.e. for x<y, we add y>x
+            if rhs not in condensed_assumptions:
+                condensed_assumptions[rhs] = Assumptions()
+            condensed_assumptions[rhs].add_greater(sp.Symbol(symbol))
+
+    # Handle equal assumptions.
+    equality_subs = propagate_assumptions_equal_symbols(condensed_assumptions)
+
+    # How many assumptions does the symbol with the most assumptions have?
+    curr_max = -1
+    for _, assum in condensed_assumptions.items():
+        if assum.num_assumptions() > curr_max:
+            curr_max = assum.num_assumptions()
+
+    all_subs = []
+    for i in range(curr_max):
+        all_subs.append(({}, {}))
+
+    # Construct all the substitution dicts. In each substitution round we take at most one assumption for each
+    # symbol. Each round has two dicts: First one swaps in the assumption and second one restores the initial
+    # symbol.
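    # As a rough, hypothetical illustration of what gets built below (symbol names are examples only):
    # for assumptions = ['N>0', 'N<5'] and no array symbols, N carries one 'greater' and one 'lesser'
    # assumption, so curr_max == 2 and two substitution rounds are constructed, containing roughly
    #   round 0: ({N: _p_N},     {_p_N: N})        # uses N > 0, with _p_N a positive symbol
    #   round 1: ({N: _n_N + 5}, {_n_N: N - 5})    # uses N < 5, with _n_N a negative symbol
    # Under round 1, an expression such as Max(N, 5) becomes Max(_n_N + 5, 5), which sympy can
    # resolve to 5 since _n_N is known to be negative.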
+ for sym, assum in condensed_assumptions.items(): + i = 0 + for g in assum.greater: + replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True) + all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g}) + all_subs[i][1].update({replacement_symbol : sp.Symbol(sym) - g}) + i += 1 + for l in assum.lesser: + replacement_symbol = sp.Symbol(f'_n_{sym}', negative=True, integer=True) + all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + l}) + all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - l}) + i += 1 + + return equality_subs, all_subs \ No newline at end of file diff --git a/dace/sdfg/work_depth_analysis/helpers.py b/dace/sdfg/work_depth_analysis/helpers.py index a80e769f64..e592fd11b5 100644 --- a/dace/sdfg/work_depth_analysis/helpers.py +++ b/dace/sdfg/work_depth_analysis/helpers.py @@ -328,4 +328,6 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph): # now we have a triple (node, oNode, exitCandidates) nodes_oNodes_exits.append((node, oNode, exitCandidates)) + # remove artificial end node + sdfg_nx.remove_node(artificial_end_node) return nodes_oNodes_exits diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index fc1f2089e1..da700bd829 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -5,7 +5,7 @@ import argparse from collections import deque from dace.sdfg import nodes as nd, propagation, InterstateEdge -from dace import SDFG, SDFGState, dtypes +from dace import SDFG, SDFGState, dtypes, int64 from dace.subsets import Range from typing import Tuple, Dict import os @@ -19,6 +19,9 @@ import warnings from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits +from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions +from dace.transformation.passes.symbol_ssa import StrictSymbolSSA +from dace.transformation.pass_pipeline import FixedPointPipeline def get_array_size_symbols(sdfg): @@ -39,22 +42,6 @@ def get_array_size_symbols(sdfg): return symbols -def posify_certain_symbols(expr, syms_to_posify): - """ - Takes an expression and evaluates it while assuming that certain symbols are positive. - - :param expr: The expression to evaluate. - :param syms_to_posify: List of symbols we assume to be positive. - :note: This is adapted from the Sympy function posify. - """ - - expr = sp.sympify(expr) - - reps = {s: sp.Dummy(s.name, positive=True, **s.assumptions0) for s in syms_to_posify if s.is_positive is None} - expr = expr.subs(reps) - return expr.subs({r: s for s, r in reps.items()}) - - def symeval(val, symbols): """ Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. 
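
# A minimal, self-contained sympy sketch (symbols and values here are examples only) of why symeval
# goes through the temporary '__REPLSYM_' names: applying the substitutions one after another would
# chain them, e.g. when two symbols are swapped, whereas the two-step variant keeps them independent.
import sympy as sp

N, M = sp.symbols('N M')
expr = N + 2 * M
# sequential substitution chains: first N -> M (giving 3*M), then M -> N, so the result is 3*N
naive = expr.subs([(N, M), (M, N)])
# two-step substitution via placeholder symbols performs the swap as intended: M + 2*N
tmp_N, tmp_M = sp.symbols('__REPLSYM_N __REPLSYM_M')
safe = expr.subs({N: tmp_N, M: tmp_M}).subs({tmp_N: M, tmp_M: N})
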
@@ -64,7 +51,7 @@ def symeval(val, symbols): """ first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} - return val.subs(first_replacement).subs(second_replacement) + return sp.simplify(val.subs(first_replacement).subs(second_replacement)) def evaluate_symbols(base, new): @@ -87,7 +74,7 @@ def count_work_matmul(node, symbols, state): result *= symeval(C_memlet.data.subset.size()[-1], symbols) # K result *= symeval(A_memlet.data.subset.size()[-1], symbols) - return result + return sp.sympify(result) def count_depth_matmul(node, symbols, state): # optimal depth of a matrix multiplication is O(log(size of shared dimension)): @@ -108,7 +95,7 @@ def count_work_reduce(node, symbols, state): result *= in_memlet.data.volume else: result = 0 - return result + return sp.sympify(result) def count_depth_reduce(node, symbols, state): # optimal depth of reduction is log of the work @@ -119,7 +106,6 @@ def count_depth_reduce(node, symbols, state): MatMul: count_work_matmul, Transpose: lambda *args: 0, Reduce: count_work_reduce, - # TODO: include more } LIBNODES_TO_DEPTH = { @@ -249,9 +235,9 @@ def count_depth_code(code): def tasklet_work(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.CPP: + # simplified work analysis for CPP tasklets. for oedge in state.out_edges(tasklet_node): - return bigo(oedge.data.num_accesses) - + return oedge.data.num_accesses elif tasklet_node.code.language == dtypes.Language.Python: return count_arithmetic_ops_code(tasklet_node.code.code) else: @@ -262,11 +248,10 @@ def tasklet_work(tasklet_node, state): def tasklet_depth(tasklet_node, state): - # TODO: how to get depth of CPP tasklets? - # For now we use depth == work: if tasklet_node.code.language == dtypes.Language.CPP: + # For now we simply take depth == work for CPP tasklets. for oedge in state.out_edges(tasklet_node): - return bigo(oedge.data.num_accesses) + return oedge.data.num_accesses if tasklet_node.code.language == dtypes.Language.Python: return count_depth_code(tasklet_node.code.code) else: @@ -277,19 +262,35 @@ def tasklet_depth(tasklet_node, state): def get_tasklet_work(node, state): - return tasklet_work(node, state), -1 + return sp.sympify(tasklet_work(node, state)), sp.sympify(-1) def get_tasklet_work_depth(node, state): - return tasklet_work(node, state), tasklet_depth(node, state) + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) def get_tasklet_avg_par(node, state): - return tasklet_work(node, state), tasklet_depth(node, state) + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) + +def update_value_map(old, new): + # add new assignments to old + old.update({k: v for k, v in new.items() if k not in old}) + # check for conflicts: + for k, v in new.items(): + if k in old and old[k] != v: + # conflict detected --> forget this mapping completely + old.pop(k) + +def do_initial_subs(w, d, eq, subs1): + """ + Calls subs three times for the give (w)ork and (d)epth values. 
+ """ + return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, - symbols) -> Tuple[sp.Expr, sp.Expr]: + symbols: Dict[str, str], detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a given SDFG. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. @@ -299,6 +300,11 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana :param w_d_map: Dictionary which will save the result. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the SDFG. """ @@ -308,9 +314,13 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths: Dict[SDFGState, sp.Expr] = {} state_works: Dict[SDFGState, sp.Expr] = {} for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols) - state_works[state] = sp.simplify(state_work * state.executions) - state_depths[state] = sp.simplify(state_depth * state.executions) + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1) + + # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. + state_work = sp.simplify(state_work * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify(state_depth * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + state_works[state], state_depths[state] = state_work, state_depth w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and @@ -324,12 +334,18 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # Now we need to go over each triple (node, oNode, exits). For each triple, we # - remove edge (oNode, node), i.e. the backward edge # - for all exits e, add edge (oNode, e). This edge may already exist + # - remove edge from node to exit (if present, i.e. while-do loop) + # - This ensures that every node with > 1 outgoing edge is a branch guard + # - useful for detailed anaylsis. for node, oNode, exits in nodes_oNodes_exits: sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) for e in exits: if len(sdfg.edges_between(oNode, e)) == 0: # no edge there yet sdfg.add_edge(oNode, e, InterstateEdge()) + if len(sdfg.edges_between(node, e)) > 0: + # edge present --> remove it + sdfg.remove_edge(sdfg.edges_between(node, e)[0]) # add a dummy exit to the SDFG, such that each path ends there. 
dummy_exit = sdfg.add_state('dummy_exit') @@ -340,6 +356,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. work_map: Dict[SDFGState, sp.Expr] = {} depth_map: Dict[SDFGState, sp.Expr] = {} + # Keeps track of assignments done on InterstateEdges. + state_value_map: Dict[SDFGState, Dict[sp.Symbol, sp.Symbol]] = {} # The dummy state has 0 work and depth. state_depths[dummy_exit] = sp.sympify(0) state_works[dummy_exit] = sp.sympify(0) @@ -348,40 +366,67 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions # have been calculated. traversal_q = deque() - traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None)) + traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) visited = set() + + num_edges = 0 + while traversal_q: - state, depth, work, ie = traversal_q.popleft() + state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() if ie is not None: visited.add(ie) + + if state in state_value_map: + # update value map: + update_value_map(state_value_map[state], value_map) + else: + state_value_map[state] = value_map - n_depth = sp.simplify(depth + state_depths[state]) - n_work = sp.simplify(work + state_works[state]) + # ignore assignments such as tmp=x[0], as those do not give much information. + value_map = {k: v for k, v in state_value_map[state].items() if '[' not in k and '[' not in v} + n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) + n_work = sp.simplify((work + state_works[state]).subs(value_map)) # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one # single path with the least average parallelsim (of all paths with more than 0 work). if analyze_tasklet == get_tasklet_avg_par: - if state in depth_map: # and hence als state in work_map - # if current path has 0 depth, we don't do anything. + if state in depth_map: # this means we have already visited this state before + cse = common_subexpr_stack.pop() + # if current path has 0 depth (--> 0 work as well), we don't do anything. 
if n_depth != 0: - # see if we need to update the work and depth of the current state + # check if we need to update the work and depth of the current state # we update if avg parallelism of new incoming path is less than current avg parallelism - old_avg_par = sp.simplify(work_map[state] / depth_map[state]) - new_avg_par = sp.simplify(n_work / n_depth) - - if depth_map[state] == 0 or new_avg_par < old_avg_par: - # old value was divided by zero or new path gives actually worse avg par, then we keep new value - depth_map[state] = n_depth - work_map[state] = n_work + if depth_map[state] == 0: + # old value was divided by zero --> we take new value anyway + depth_map[state] = cse[1] + n_depth + work_map[state] = cse[0] + n_work + else: + old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) + new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) + # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) + depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), (depth_map[state], True)) + work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), (work_map[state], True)) else: depth_map[state] = n_depth work_map[state] = n_work else: # search heaviest and deepest path separately if state in depth_map: # and consequently also in work_map - depth_map[state] = sp.Max(depth_map[state], n_depth) - work_map[state] = sp.Max(work_map[state], n_work) + # This cse value would appear in both arguments of the Max. Hence, for performance reasons, + # we pull it out of the Max expression. + # Example: We do cse + Max(a, b) instead of Max(cse + a, cse + b). + # This increases performance drastically, expecially since we avoid nesting Max expressions + # for cases where cse itself contains Max operators. + cse = common_subexpr_stack.pop() + if detailed_analysis: + # This MAX should be covered in the more detailed analysis + cond = condition_stack.pop() + work_map[state] = cse[0] + sp.Piecewise((work_map[state], sp.Not(cond)), (n_work, cond)) + depth_map[state] = cse[1] + sp.Piecewise((depth_map[state], sp.Not(cond)), (n_depth, cond)) + else: + work_map[state] = cse[0] + sp.Max(work_map[state], n_work) + depth_map[state] = cse[1] + sp.Max(depth_map[state], n_depth) else: depth_map[state] = n_depth work_map[state] = n_work @@ -392,7 +437,21 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana pass else: for oedge in out_edges: - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge)) + if len(out_edges) > 1: + # It is important to copy these stacks. Else both branches operate on the same stack. 
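Regarding the common-subexpression comment above: adding cse outside the Max is equivalent to adding it inside both arguments, but keeps the expression flat. A quick numeric sanity check with sympy (symbol names are illustrative):

    import sympy as sp

    a, b, cse = sp.symbols('a b cse')
    flat = cse + sp.Max(a, b)           # form used here
    nested = sp.Max(cse + a, cse + b)   # equivalent, but duplicates and nests cse
    vals = {a: 3, b: 7, cse: 10}
    assert flat.subs(vals) == nested.subs(vals) == 17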
+ # state is a branch guard --> save condition on stack + new_cond_stack = list(condition_stack) + new_cond_stack.append(oedge.data.condition_sympy()) + # same for common_subexr_stack + new_cse_stack = list(common_subexpr_stack) + new_cse_stack.append((work_map[state], depth_map[state])) + # same for value_map + new_value_map = dict(state_value_map[state]) + new_value_map.update({sp.Symbol(k): sp.Symbol(v) for k, v in oedge.data.assignments.items()}) + traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) + else: + value_map.update(oedge.data.assignments) + traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, common_subexpr_stack, value_map)) try: max_depth = depth_map[dummy_exit] @@ -403,7 +462,7 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana raise Exception( 'Analysis failed, since not all loops got detected. It may help to use more structured loop constructs.') - sdfg_result = (sp.simplify(max_work), sp.simplify(max_depth)) + sdfg_result = (max_work, max_depth) w_d_map[get_uuid(sdfg)] = sdfg_result return sdfg_result @@ -412,6 +471,9 @@ def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, + detailed_analysis, + equality_subs, + subs1, entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a scope. @@ -425,7 +487,14 @@ def scope_work_depth(state: SDFGState, this can be done in linear time by traversing the graph in topological order. :param state: The state in which the scope to analyze is contained. - :param sym_map: A dictionary mapping symbols to their values. + :param w_d_map: Dictionary saving the final result for each SDFG element. + :param analyze_tasklet: Function used to analyze tasklets. Either analyzes just work, work and depth or average parallelism. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. :return: A tuple containing the work and depth of the scope. """ @@ -442,7 +511,8 @@ def scope_work_depth(state: SDFGState, if isinstance(node, nd.EntryNode): # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. 
# The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, node) + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, node) + s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work w_d_map[get_uuid(node, state)] = (s_work, s_depth) @@ -456,8 +526,9 @@ def scope_work_depth(state: SDFGState, for e in state.out_edges(node): if e.data.wcr is not None: t_work += count_arithmetic_ops_code(e.data.wcr) + t_work, t_depth = do_initial_subs(t_work, t_depth, equality_subs, subs1) work += t_work - w_d_map[get_uuid(node, state)] = (sp.sympify(t_work), sp.sympify(t_depth)) + w_d_map[get_uuid(node, state)] = (t_work, t_depth) elif isinstance(node, nd.NestedSDFG): # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. # We only want global symbols in our final work depth expressions. @@ -465,18 +536,34 @@ def scope_work_depth(state: SDFGState, nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms) + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, detailed_analysis, equality_subs, subs1) + nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) # add up work for whole state, but also save work for this nested SDFG in w_d_map work += nsdfg_work w_d_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) elif isinstance(node, nd.LibraryNode): - lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) - work += lib_node_work - lib_node_depth = -1 # not analyzed + try: + lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) + except KeyError: + # add a symbol to the top level sdfg, such that the user can define it in the extension + top_level_sdfg = state.parent + # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, + # such that the user can define its value. But it doesn't... + # How to achieve this? + top_level_sdfg.add_symbol(f'{node.name}_work', int64) + lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) + lib_node_depth = sp.sympify(-1) # not analyzed if analyze_tasklet != get_tasklet_work: # we are analyzing depth - lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + try: + lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + except KeyError: + top_level_sdfg = state.parent + top_level_sdfg.add_symbol(f'{node.name}_depth', int64) + lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) + lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) + work += lib_node_work w_d_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) if entry is not None: @@ -485,7 +572,7 @@ def scope_work_depth(state: SDFGState, nmap: nd.Map = entry.map range: Range = nmap.range n_exec = range.num_elements_exact() - work = work * sp.simplify(n_exec) + work = sp.simplify(work * n_exec.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) else: print('WARNING: Only Map scopes are supported in work analysis for now. 
Assuming 1 iteration.') @@ -535,17 +622,30 @@ def scope_work_depth(state: SDFGState, out_edges = state.out_edges(exit_node) for oedge in out_edges: # check for wcr - wcr_depth = 0 + wcr_depth = sp.sympify(0) if oedge.data.wcr is not None: - wcr_depth = oedge.data.volume + # This division gives us the number of writes to each single memory location, which is the depth + # as these need to be sequential (without assumptions on HW etc). + wcr_depth = oedge.data.volume / oedge.data.subset.num_elements() if get_uuid(node, state) in wcr_depth_map: # max wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], wcr_depth) else: wcr_depth_map[get_uuid(node, state)] = wcr_depth # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N. - # But we only want N, the code line just above this comment will then take care of that. - traversal_q.append((oedge.dst, depth_map[node] + (wcr_depth if not isinstance(oedge.dst, nd.MapExit) else 0), oedge)) + wcr_depth = wcr_depth if not isinstance(oedge.dst, nd.MapExit) else sp.sympify(0) + + # only append if it's actually new information + # this e.g. helps for huge nested SDFGs with lots of inputs/outputs inside a map scope + append = True + for n, d, _ in traversal_q: + if oedge.dst == n and depth_map[node] + wcr_depth == d: + append = False + break + if append: + traversal_q.append((oedge.dst, depth_map[node] + wcr_depth, oedge)) + else: + visited.add(oedge) if len(out_edges) == 0 or node == scope_exit: # We have reached an end node --> update max_depth max_depth = sp.Max(max_depth, depth_map[node]) @@ -553,13 +653,13 @@ def scope_work_depth(state: SDFGState, for uuid in wcr_depth_map: w_d_map[uuid] = (w_d_map[uuid][0], w_d_map[uuid][1] + wcr_depth_map[uuid]) # summarise work / depth of the whole scope in the dictionary - scope_result = (sp.simplify(work), sp.simplify(max_depth)) + scope_result = (work, max_depth) w_d_map[get_uuid(state)] = scope_result return scope_result def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, - symbols) -> Tuple[sp.Expr, sp.Expr]: + symbols, detailed_analysis, equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a state. @@ -567,13 +667,18 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task :param w_d_map: The result will be saved to this map. :param analyze_tasklet: Function used to analyze tasklet nodes. :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the state. 
""" - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, None) + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, None) return work, depth -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> None: +def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], detailed_analysis: bool) -> None: """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -581,12 +686,24 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No condition and an assignment. :param sdfg: The SDFG to analyze. :param w_d_map: Dictionary of SDFG elements to (work, depth) tuples. Result will be saved in here. - :param analyze_tasklet: The function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + :param analyze_tasklet: Function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + :param assumptions: List of strings. Each string corresponds to one assumption for some symbol, e.g. 'N>5'. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). """ # deepcopy such that original sdfg not changed sdfg = deepcopy(sdfg) + # apply SSA pass + pipeline = FixedPointPipeline([StrictSymbolSSA()]) + pipeline.apply_pass(sdfg, {}) + + array_symbols = get_array_size_symbols(sdfg) + # parse assumptions + equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) + # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state # will be executed, or to determine upper bounds for that number (such as in the case of branching) for sd in sdfg.all_sdfgs_recursive(): @@ -594,17 +711,35 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet) -> No # Analyze the work and depth of the SDFG. symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols) - - # Note: This posify could be done more often to improve performance. - array_symbols = get_array_size_symbols(sdfg) + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}) + for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. - v_w = posify_certain_symbols(symeval(v_w, symbols), array_symbols) - v_d = posify_certain_symbols(symeval(v_d, symbols), array_symbols) + v_w, v_d = do_subs(v_w, v_d, all_subs) + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) w_d_map[k] = (v_w, v_d) +def do_subs(work, depth, all_subs): + """ + Handles all substitutions beyond the equality substitutions and the first substitution. + :param work: Some work expression. + :param depth: Some depth expression. + :param all_subs: List of substitution pairs to perform. + :return: Work depth expressions after doing all substitutions. 
+ """ + # first do subs2 of first sub + # then do all the remaining subs + subs2 = all_subs[0][1] if len(all_subs) > 0 else {} + work, depth = sp.simplify(sp.sympify(work).subs(subs2)), sp.simplify(sp.sympify(depth).subs(subs2)) + for i in range(1, len(all_subs)): + subs1, subs2 = all_subs[i] + work, depth = sp.simplify(work.subs(subs1)), sp.simplify(depth.subs(subs1)) + work, depth = sp.simplify(work.subs(subs2)), sp.simplify(depth.subs(subs2)) + return work, depth + + ################################################################################ # Utility functions for running the analysis from the command line ############# ################################################################################ @@ -621,7 +756,10 @@ def main() -> None: choices=['work', 'workDepth', 'avgPar'], default='workDepth', help='Choose what to analyze. Default: workDepth') + parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. x>0 x>y y==5') + parser.add_argument("--detailed", action="store_true", + help="Turns on detailed mode.") args = parser.parse_args() if not os.path.exists(args.filename): @@ -637,7 +775,7 @@ def main() -> None: sdfg = SDFG.from_file(args.filename) work_depth_map = {} - analyze_sdfg(sdfg, work_depth_map, analyze_tasklet) + analyze_sdfg(sdfg, work_depth_map, analyze_tasklet, args.assume, args.detailed) if args.analyze == 'workDepth': for k, v, in work_depth_map.items(): diff --git a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py index 133afe8ae4..924397aa1e 100644 --- a/tests/sdfg/work_depth_tests.py +++ b/tests/sdfg/work_depth_tests.py @@ -1,14 +1,18 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. """ Contains test cases for the work depth analysis. """ import dace as dc -from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth +from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth, parse_assumptions from dace.sdfg.work_depth_analysis.helpers import get_uuid +from dace.sdfg.work_depth_analysis.assumptions import ContradictingAssumptions import sympy as sp from dace.transformation.interstate import NestSDFG from dace.transformation.dataflow import MapExpansion +from pytest import raises + # TODO: add tests for library nodes (e.g. 
reduce, matMul) +# TODO: add tests for average parallelism N = dc.symbol('N') M = dc.symbol('M') @@ -65,11 +69,11 @@ def nested_for_loops(x: dc.float64[N], y: dc.float64[K]): @dc.program def nested_if_else(x: dc.int64[N], y: dc.int64[N], z: dc.int64[N], sum: dc.int64[1]): if x[10] > 50: - if x[9] > 50: + if x[9] > 40: z[:] = x + y # N work, 1 depth z[:] += 2 * x # 2*N work, 2 depth --> total outer if: 3*N work, 3 depth else: - if y[9] > 50: + if y[9] > 30: for i in range(K): sum += x[i] # K work, K depth else: @@ -152,7 +156,22 @@ def break_while_loop(x: dc.float64[N]): break x += 1 +@dc.program +def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive + if x[0] > 5: + x[:] += 1 # N+1 work, 1 depth + else: + for i in range(M): # M work, M depth + y[i+1] += y[i] + if M > N: + y[:N+1] += x[:] # N+1 work, 1 depth + else: + x[:M+1] += y[:] # M+1 work, 1 depth + # --> Work: Max(N+1, M) + Max(N+1, M+1) + # Depth: Max(1, M) + 1 + +#(sdfg, (expected_work, expected_depth)) tests_cases = [ (single_map, (N, 1)), (single_for_loop, (N, N)), @@ -164,25 +183,20 @@ def break_while_loop(x: dc.float64[N]): (nested_if_else, (sp.Max(K, 3 * N, M + N), sp.Max(3, K, M + 1))), (max_of_positive_symbol, (3 * N**2, 3 * N)), (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), - (unbounded_while_do, (sp.Symbol('num_execs_0_2', nonnegative=True) * N, sp.Symbol('num_execs_0_2', - nonnegative=True))), + (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)) * N, - sp.Max(1, sp.Symbol('num_execs_0_1', nonnegative=True)))), - (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7', nonnegative=True) * N, - 2 * sp.Symbol('num_execs_0_7', nonnegative=True))), - (continue_for_loop, (sp.Symbol('num_execs_0_6', nonnegative=True) * N, sp.Symbol('num_execs_0_6', - nonnegative=True))), + (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N, + sp.Max(1, sp.Symbol('num_execs_0_1')))), + (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, + 2 * sp.Symbol('num_execs_0_7'))), + (continue_for_loop, (sp.Symbol('num_execs_0_6') * N, sp.Symbol('num_execs_0_6'))), (break_for_loop, (N**2, N)), - (break_while_loop, (sp.Symbol('num_execs_0_5', nonnegative=True) * N, sp.Symbol('num_execs_0_5', nonnegative=True))) + (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), + (sequntial_ifs, (sp.Max(N+1, M) + sp.Max(N+1, M+1), sp.Max(1, M) + 1)) ] def test_work_depth(): - good = 0 - failed = 0 - exception = 0 - failed_tests = [] for test, correct in tests_cases: w_d_map = {} sdfg = test.to_sdfg() @@ -190,12 +204,64 @@ def test_work_depth(): sdfg.apply_transformations(NestSDFG) if 'nested_maps' in test.name: sdfg.apply_transformations(MapExpansion) - - analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth) + analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, [], False) res = w_d_map[get_uuid(sdfg)] + # substitue each symbol without assumptions. + # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. 
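The comment above points at a sympy subtlety worth spelling out: symbols that carry different assumptions are distinct objects, which is why results are normalized to assumption-free symbols before comparing against the expected expressions.

    import sympy as sp

    assert sp.Symbol('N') != sp.Symbol('N', positive=True)
    # normalizing: replace the assumption-carrying symbol by its plain counterpart
    N_pos = sp.Symbol('N', positive=True)
    assert N_pos.subs({N_pos: sp.Symbol('N')}) == sp.Symbol('N')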
+ reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)} + res = (res[0].subs(reps), res[1].subs(reps)) + reps = {s: sp.Symbol(s.name) for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols)} + correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps)) # check result assert correct == res +x, y, z, a = sp.symbols('x y z a') + +# (expr, assumptions, result) +assumptions_tests=[ + (sp.Max(x, y), ['x>y'], x), + (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), + (sp.Max(x, y), ['x==y'], y), + (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x,3)), + (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', 'x>3'], 11 + x), + (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), + (sp.Max(x, 11), ['x==y', 'x>11'], y), + (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'a<11', 'c>7'], x + 11), + (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), + (sp.Max(x, y), ['y>x', 'y==1000'], 1000), + (sp.Max(x, y), ['y0', 'N<5', 'M>5'], M) + +] + +# These assumptions should trigger the ContradictingAssumptions exception. +tests_for_exception = [ + ['x>10', 'x<9'], + ['x==y', 'x>10', 'y<9'], + ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'], + ['x==5', 'x<4'] +] + +def test_assumption_system(): + for expr, assums, res in assumptions_tests: + equality_subs, all_subs = parse_assumptions(assums, set()) + initial_expr = expr + expr = expr.subs(equality_subs[0]) + expr = expr.subs(equality_subs[1]) + for subs1, subs2 in all_subs: + expr = expr.subs(subs1) + expr = expr.subs(subs2) + assert expr == res + + for assums in tests_for_exception: + # check that the Exception gets raised. + with raises(ContradictingAssumptions): + parse_assumptions(assums, set()) + + if __name__ == '__main__': test_work_depth() + test_assumption_system() From 12c2c7333a705e1d3d0ea8b5addb1a67ae1b42d2 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Mon, 4 Sep 2023 12:24:20 +0200 Subject: [PATCH 14/18] formatting with yapf --- dace/sdfg/work_depth_analysis/assumptions.py | 35 ++++++------ dace/sdfg/work_depth_analysis/work_depth.py | 59 ++++++++++++-------- tests/sdfg/work_depth_tests.py | 51 ++++++++--------- 3 files changed, 79 insertions(+), 66 deletions(-) diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py index 1f167f15a3..c7e439cf51 100644 --- a/dace/sdfg/work_depth_analysis/assumptions.py +++ b/dace/sdfg/work_depth_analysis/assumptions.py @@ -8,16 +8,16 @@ class UnionFind: """ Simple, not really optimized UnionFind implementation. """ - + def __init__(self, elements) -> None: - self.ids = {e : e for e in elements} + self.ids = {e: e for e in elements} def add_element(self, e): if e in self.ids: return False - self.ids.update({e : e}) + self.ids.update({e: e}) return True - + def find(self, e): prev = e curr = self.ids[e] @@ -27,16 +27,17 @@ def find(self, e): # shorten the path self.ids[e] = curr return curr - + def union(self, e, f): if f not in self.ids: self.add_element(f) self.ids[self.find(e)] = f - + class ContradictingAssumptions(Exception): pass + class Assumptions: """ Summarises the assumptions for a single symbol in three lists: equal, greater, lesser. 
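A short usage sketch of the UnionFind above, as it is used to merge symbols declared equal into one equivalence class (the representative is whatever find() settles on):

    from dace.sdfg.work_depth_analysis.assumptions import UnionFind

    uf = UnionFind(['N', 'M', 'K'])
    uf.union('N', 'M')    # from an assumption like 'N==M'
    uf.union('M', 'K')    # from 'M==K'
    assert uf.find('N') == uf.find('M') == uf.find('K') == 'K'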
@@ -67,7 +68,7 @@ def add_lesser(self, l): def add_equal(self, e): for x in self.equal: - if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e: + if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e: raise ContradictingAssumptions() self.equal.append(e) self.check_consistency() @@ -89,11 +90,12 @@ def check_consistency(self): if (g > l) == True: raise ContradictingAssumptions() return True - + def num_assumptions(self): # returns the number of individual assumptions for this symbol return len(self.greater) + len(self.lesser) + len(self.equal) - + + def propagate_assumptions(x, y, condensed_assumptions): """ Assuming x is equal to y, we propagate the assumptions on x to y. E.g. we have x==y and @@ -118,6 +120,7 @@ def propagate_assumptions(x, y, condensed_assumptions): assum_y.add_lesser(l) assum_y.check_consistency() + def propagate_assumptions_equal_symbols(condensed_assumptions): """ This method handles two things: 1) It generates the substitution dict for all equality assumptions. @@ -139,7 +142,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions): uf.union(sym, other.name) equality_subs1 = {} - + # For each equivalence class, we now have one unique identifier. # For each class, we give all the assumptions to this single symbol. # And we swap each symbol in class for this symbol. @@ -148,7 +151,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions): if isinstance(other, sp.Symbol): propagate_assumptions(sym, uf.find(sym), condensed_assumptions) equality_subs1.update({sym: sp.Symbol(uf.find(sym))}) - + equality_subs2 = {} # In a second step, each symbol gets replace with its equal number (if present) # using equality_subs2. @@ -213,7 +216,7 @@ def parse_assumptions(assumptions, array_symbols): if assumptions is None: return {}, [({}, {})] - + # Gather assumptions, keeping only the strongest ones for each symbol. condensed_assumptions: Dict[str, Assumptions] = {} for a in assumptions: @@ -252,13 +255,13 @@ def parse_assumptions(assumptions, array_symbols): # Handle equal assumptions. equality_subs = propagate_assumptions_equal_symbols(condensed_assumptions) - + # How many assumptions does symbol with most assumptions have? 
curr_max = -1 for _, assum in condensed_assumptions.items(): if assum.num_assumptions() > curr_max: curr_max = assum.num_assumptions() - + all_subs = [] for i in range(curr_max): all_subs.append(({}, {})) @@ -271,7 +274,7 @@ def parse_assumptions(assumptions, array_symbols): for g in assum.greater: replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True) all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g}) - all_subs[i][1].update({replacement_symbol : sp.Symbol(sym) - g}) + all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g}) i += 1 for l in assum.lesser: replacement_symbol = sp.Symbol(f'_n_{sym}', negative=True, integer=True) @@ -279,4 +282,4 @@ def parse_assumptions(assumptions, array_symbols): all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - l}) i += 1 - return equality_subs, all_subs \ No newline at end of file + return equality_subs, all_subs diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index da700bd829..21e5b937b9 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -76,6 +76,7 @@ def count_work_matmul(node, symbols, state): result *= symeval(A_memlet.data.subset.size()[-1], symbols) return sp.sympify(result) + def count_depth_matmul(node, symbols, state): # optimal depth of a matrix multiplication is O(log(size of shared dimension)): A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') @@ -97,6 +98,7 @@ def count_work_reduce(node, symbols, state): result = 0 return sp.sympify(result) + def count_depth_reduce(node, symbols, state): # optimal depth of reduction is log of the work return bigo(sp.log(count_work_reduce(node, symbols, state))) @@ -272,6 +274,7 @@ def get_tasklet_work_depth(node, state): def get_tasklet_avg_par(node, state): return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) + def update_value_map(old, new): # add new assignments to old old.update({k: v for k, v in new.items() if k not in old}) @@ -281,6 +284,7 @@ def update_value_map(old, new): # conflict detected --> forget this mapping completely old.pop(k) + def do_initial_subs(w, d, eq, subs1): """ Calls subs three times for the give (w)ork and (d)epth values. @@ -288,8 +292,8 @@ def do_initial_subs(w, d, eq, subs1): return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, - symbols: Dict[str, str], detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], +def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols: Dict[str, str], + detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a given SDFG. 
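For reference, the library-node counters above boil down to the following for a non-batched matrix multiplication C[M, N] = A[M, K] @ B[K, N] (symbol names assumed for illustration):

    import sympy as sp

    M, N, K = sp.symbols('M N K', positive=True)
    matmul_work = 2 * M * N * K   # one multiply and one add per (i, j, k) triple
    matmul_depth = sp.log(K)      # reported wrapped in bigo(): reduction over the shared dimension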
@@ -314,11 +318,14 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths: Dict[SDFGState, sp.Expr] = {} state_works: Dict[SDFGState, sp.Expr] = {} for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1) - + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, + equality_subs, subs1) + # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. - state_work = sp.simplify(state_work * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - state_depth = sp.simplify(state_depth * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_work = sp.simplify(state_work * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify(state_depth * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) state_works[state], state_depths[state] = state_work, state_depth w_d_map[get_uuid(state)] = (state_works[state], state_depths[state]) @@ -376,7 +383,7 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana if ie is not None: visited.add(ie) - + if state in state_value_map: # update value map: update_value_map(state_value_map[state], value_map) @@ -405,8 +412,10 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) - depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), (depth_map[state], True)) - work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), (work_map[state], True)) + depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), + (depth_map[state], True)) + work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), + (work_map[state], True)) else: depth_map[state] = n_depth work_map[state] = n_work @@ -451,7 +460,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) else: value_map.update(oedge.data.assignments) - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, common_subexpr_stack, value_map)) + traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, + common_subexpr_stack, value_map)) try: max_depth = depth_map[dummy_exit] @@ -511,7 +521,8 @@ def scope_work_depth(state: SDFGState, if isinstance(node, nd.EntryNode): # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. 
# The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, node) + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, + equality_subs, subs1, node) s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work @@ -536,7 +547,8 @@ def scope_work_depth(state: SDFGState, nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, detailed_analysis, equality_subs, subs1) + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, + detailed_analysis, equality_subs, subs1) nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) # add up work for whole state, but also save work for this nested SDFG in w_d_map @@ -629,7 +641,8 @@ def scope_work_depth(state: SDFGState, wcr_depth = oedge.data.volume / oedge.data.subset.num_elements() if get_uuid(node, state) in wcr_depth_map: # max - wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], wcr_depth) + wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], + wcr_depth) else: wcr_depth_map[get_uuid(node, state)] = wcr_depth # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N. @@ -649,7 +662,7 @@ def scope_work_depth(state: SDFGState, if len(out_edges) == 0 or node == scope_exit: # We have reached an end node --> update max_depth max_depth = sp.Max(max_depth, depth_map[node]) - + for uuid in wcr_depth_map: w_d_map[uuid] = (w_d_map[uuid][0], w_d_map[uuid][1] + wcr_depth_map[uuid]) # summarise work / depth of the whole scope in the dictionary @@ -658,8 +671,8 @@ def scope_work_depth(state: SDFGState, return scope_result -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, - symbols, detailed_analysis, equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]: +def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, detailed_analysis, + equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a state. @@ -674,11 +687,13 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the state. """ - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, None) + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, + None) return work, depth -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], detailed_analysis: bool) -> None: +def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], + detailed_analysis: bool) -> None: """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -711,8 +726,9 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assum # Analyze the work and depth of the SDFG. 
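The wcr depth introduced above is the edge volume divided by the number of distinct elements written, i.e. how many conflicting updates serialize on one location. Two hypothetical cases for a map of N iterations writing through a wcr edge:

    import sympy as sp

    N, B = sp.symbols('N B', positive=True)
    scalar_wcr_depth = N / 1                  # all N updates hit one scalar: they fully serialize
    histogram_wcr_depth = sp.simplify(N / B)  # N updates spread over B bins: N/B sequential updates per bin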
symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}) - + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, + all_subs[0][0] if len(all_subs) > 0 else {}) + for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts. v_w, v_d = do_subs(v_w, v_d, all_subs) @@ -758,8 +774,7 @@ def main() -> None: help='Choose what to analyze. Default: workDepth') parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. x>0 x>y y==5') - parser.add_argument("--detailed", action="store_true", - help="Turns on detailed mode.") + parser.add_argument("--detailed", action="store_true", help="Turns on detailed mode.") args = parser.parse_args() if not os.path.exists(args.filename): diff --git a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py index 924397aa1e..05375007df 100644 --- a/tests/sdfg/work_depth_tests.py +++ b/tests/sdfg/work_depth_tests.py @@ -156,17 +156,18 @@ def break_while_loop(x: dc.float64[N]): break x += 1 + @dc.program -def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive +def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive if x[0] > 5: - x[:] += 1 # N+1 work, 1 depth + x[:] += 1 # N+1 work, 1 depth else: for i in range(M): # M work, M depth - y[i+1] += y[i] + y[i + 1] += y[i] if M > N: - y[:N+1] += x[:] # N+1 work, 1 depth + y[:N + 1] += x[:] # N+1 work, 1 depth else: - x[:M+1] += y[:] # M+1 work, 1 depth + x[:M + 1] += y[:] # M+1 work, 1 depth # --> Work: Max(N+1, M) + Max(N+1, M+1) # Depth: Max(1, M) + 1 @@ -185,14 +186,12 @@ def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assu (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)), (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))), # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this. - (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N, - sp.Max(1, sp.Symbol('num_execs_0_1')))), - (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, - 2 * sp.Symbol('num_execs_0_7'))), + (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N, sp.Max(1, sp.Symbol('num_execs_0_1')))), + (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))), (continue_for_loop, (sp.Symbol('num_execs_0_6') * N, sp.Symbol('num_execs_0_6'))), (break_for_loop, (N**2, N)), (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))), - (sequntial_ifs, (sp.Max(N+1, M) + sp.Max(N+1, M+1), sp.Max(1, M) + 1)) + (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1)) ] @@ -210,7 +209,10 @@ def test_work_depth(): # We do this since sp.Symbol('N') == Sp.Symbol('N', positive=True) --> False. 
reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)} res = (res[0].subs(reps), res[1].subs(reps)) - reps = {s: sp.Symbol(s.name) for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols)} + reps = { + s: sp.Symbol(s.name) + for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols) + } correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps)) # check result assert correct == res @@ -219,31 +221,24 @@ def test_work_depth(): x, y, z, a = sp.symbols('x y z a') # (expr, assumptions, result) -assumptions_tests=[ - (sp.Max(x, y), ['x>y'], x), - (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), - (sp.Max(x, y), ['x==y'], y), - (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x,3)), - (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', 'x>3'], 11 + x), - (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), - (sp.Max(x, 11), ['x==y', 'x>11'], y), +assumptions_tests = [ + (sp.Max(x, y), ['x>y'], x), (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), (sp.Max(x, y), ['x==y'], y), + (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x, 3)), (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', + 'x>3'], 11 + x), + (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), (sp.Max(x, 11), ['x==y', 'x>11'], y), (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'a<11', 'c>7'], x + 11), - (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), - (sp.Max(x, y), ['y>x', 'y==1000'], 1000), + (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), (sp.Max(x, y), ['y>x', 'y==1000'], 1000), (sp.Max(x, y), ['y0', 'N<5', 'M>5'], M) - ] # These assumptions should trigger the ContradictingAssumptions exception. -tests_for_exception = [ - ['x>10', 'x<9'], - ['x==y', 'x>10', 'y<9'], - ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'], - ['x==5', 'x<4'] -] +tests_for_exception = [['x>10', 'x<9'], ['x==y', 'x>10', 'y<9'], + ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'], + ['x==5', 'x<4']] + def test_assumption_system(): for expr, assums, res in assumptions_tests: From 91583e7ecd48f1d5a4aa787e99daa3f2665ba437 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Mon, 4 Sep 2023 13:43:15 +0200 Subject: [PATCH 15/18] minor changes --- dace/sdfg/work_depth_analysis/work_depth.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index 21e5b937b9..b05ccc70ae 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -376,8 +376,6 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) visited = set() - num_edges = 0 - while traversal_q: state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() @@ -480,10 +478,10 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana def scope_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, - symbols, - detailed_analysis, - equality_subs, - subs1, + symbols: Dict[str, str], + detailed_analysis: bool, + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a scope. 
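At this point the work/depth front-end is in place; mirroring the tests, a minimal programmatic invocation looks roughly like this (prog stands for any @dc.program and is hypothetical):

    from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth
    from dace.sdfg.work_depth_analysis.helpers import get_uuid

    sdfg = prog.to_sdfg()
    w_d_map = {}
    analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, ['N>0'], False)  # assumptions list, detailed=False
    work, depth = w_d_map[get_uuid(sdfg)]   # per-element results are keyed by uuid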
From eb5a6f427d47f314e3254f681639cf3f155f77c8 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Mon, 18 Sep 2023 18:23:42 +0200 Subject: [PATCH 16/18] start of op_in analysis --- .../sdfg/work_depth_analysis/op_in_helpers.py | 78 ++ .../operational_intensity.py | 1004 +++++++++++++++++ 2 files changed, 1082 insertions(+) create mode 100644 dace/sdfg/work_depth_analysis/op_in_helpers.py create mode 100644 dace/sdfg/work_depth_analysis/operational_intensity.py diff --git a/dace/sdfg/work_depth_analysis/op_in_helpers.py b/dace/sdfg/work_depth_analysis/op_in_helpers.py new file mode 100644 index 0000000000..c7c17741de --- /dev/null +++ b/dace/sdfg/work_depth_analysis/op_in_helpers.py @@ -0,0 +1,78 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position. +Further, contains class AccessStack which which corresponds to the stack used to compute the stack distance. """ + +from dace.data import Array + +class CacheLineTracker: + + def __init__(self, L) -> None: + self.array_info = {} + self.start_lines = {} + self.next_free_line = 0 + self.L = L + + def add_array(self, name: str, a: Array): + if name not in self.start_lines: + # new array encountered + self.array_info[name] = a + self.start_lines[name] = self.next_free_line + # increase next_free_line + self.next_free_line += (a.total_size * a.dtype.bytes + self.L - 1) // self.L # ceil division + + def cache_line_id(self, name: str, access: [int]): + arr = self.array_info[name] + one_d_index = 0 + for dim in range(len(access)): + i = access[dim] + one_d_index += (i + arr.offset[dim]) * arr.strides[dim] + + # divide by L to get the cache line id + return self.start_lines[name] + (one_d_index * arr.dtype.bytes) // self.L + + +class Node: + + def __init__(self, val: int, n=None) -> None: + self.v = val + self.next = n + + +class AccessStack: + """ A stack of cache line ids. For each memory access, we search the corresponding cache line id + in the stack, report its distance and move it to the top of the stack. If the id was not found, + we report a distance of -1. """ + + def __init__(self) -> None: + self.top = None + + def touch(self, id): + + curr = self.top + prev = None + found = False + distance = 0 + while curr is not None: + # check if we found id + if curr.v == id: + # take curr node out + if prev is not None: + prev.next = curr.next + curr.next = self.top + self.top = curr + + found = True + break + + # iterate further + prev = curr + curr = curr.next + distance += 1 + + if not found: + # we accessed this cache line for the first time ever + self.top = Node(id, self.top) + distance = -1 + + return distance + diff --git a/dace/sdfg/work_depth_analysis/operational_intensity.py b/dace/sdfg/work_depth_analysis/operational_intensity.py new file mode 100644 index 0000000000..af94c7f924 --- /dev/null +++ b/dace/sdfg/work_depth_analysis/operational_intensity.py @@ -0,0 +1,1004 @@ +# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. +""" Analyses the operational intensity of an input SDFG. Can be used as a Python script +or from the VS Code extension. """ + +""" +Plan: +- For each memory access, we need to figure out its cache line and then we compute its stack distance. +- For that we model the actual stack, where we push all the memory acesses (What do we push exactly? +Cache line ids?? check typescript implementation for that information.) 
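For intuition on the plan sketched above: the stack distance of an access is the number of distinct cache lines touched since the previous access to the same line (-1 for a first-ever access), which is what AccessStack is meant to report. A small list-based reference, independent of the class above:

    def stack_distances(accesses):
        stack, dists = [], []
        for line in accesses:
            if line in stack:
                d = stack.index(line)    # distinct lines used since the last touch of 'line'
                stack.remove(line)
                dists.append(d)
            else:
                dists.append(-1)         # compulsory (cold) access
            stack.insert(0, line)        # most recently used goes on top
        return dists

    assert stack_distances([0, 1, 0, 0, 2, 1]) == [-1, -1, 1, 0, -1, 2]
    # a miss is counted when the distance is -1 or exceeds the cache capacity C (in lines)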
+- How do we know which array maps to which cache line? + Idea: for each new array encountered, just assume that it is cache line aligned and starts + at the next free cache line. TODO: check if this is how it usually behaves. Or are arrays + aligned further, like base address % x == 0 for some x bigger than cache line size? +- It is also important that we take data types into account for each array. +- For each mem access we increase the miss counter if stack distance > C(apacity) or it it is a +compulsory miss. Then, in the end we know how many bytes are transferred to cache. It is: + num_misses * L(ine size in bytes) + +- Parameters to our analysis are + - input SDFG + - C(ache capacity) + - L(ine size) +""" + + + + + + + + + + +import argparse +from collections import deque +from dace.sdfg import nodes as nd, propagation, InterstateEdge +from dace import SDFG, SDFGState, dtypes, int64 +from dace.subsets import Range +from typing import Tuple, Dict +import os +import sympy as sp +from copy import deepcopy +from dace.libraries.blas import MatMul +from dace.libraries.standard import Reduce, Transpose +from dace.symbolic import pystr_to_symbolic +import ast +import astunparse +import warnings + +from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits +from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions +from dace.transformation.passes.symbol_ssa import StrictSymbolSSA +from dace.transformation.pass_pipeline import FixedPointPipeline + +from dace.data import Array +from dace.sdfg.work_depth_analysis.op_in_helpers import CacheLineTracker, AccessStack + +def get_array_size_symbols(sdfg): + """ + Returns all symbols that appear isolated in shapes of the SDFG's arrays. + These symbols can then be assumed to be positive. + + :note: This only works if a symbol appears in isolation, i.e. array A[N]. + If we have A[N+1], we cannot assume N to be positive. + :param sdfg: The SDFG in which it searches for symbols. + :return: A set containing symbols which we can assume to be positive. + """ + symbols = set() + for _, _, arr in sdfg.arrays_recursive(): + for s in arr.shape: + if isinstance(s, sp.Symbol): + symbols.add(s) + return symbols + + +def symeval(val, symbols): + """ + Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. + + :param val: The expression we are updating. + :param symbols: Dictionary of key value pairs { old_symbol: new_symbol}. 
+ """ + first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} + second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} + return sp.simplify(val.subs(first_replacement).subs(second_replacement)) + + +def evaluate_symbols(base, new): + result = {} + for k, v in new.items(): + result[k] = symeval(v, base) + return result + + +def count_work_matmul(node, symbols, state): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') + C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') + result = 2 # Multiply, add + # Batch + if len(C_memlet.data.subset) == 3: + result *= symeval(C_memlet.data.subset.size()[0], symbols) + # M*N + result *= symeval(C_memlet.data.subset.size()[-2], symbols) + result *= symeval(C_memlet.data.subset.size()[-1], symbols) + # K + result *= symeval(A_memlet.data.subset.size()[-1], symbols) + return sp.sympify(result) + + +def count_depth_matmul(node, symbols, state): + # optimal depth of a matrix multiplication is O(log(size of shared dimension)): + A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') + size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) + return bigo(sp.log(size_shared_dimension)) + + +def count_work_reduce(node, symbols, state): + result = 0 + if node.wcr is not None: + result += count_arithmetic_ops_code(node.wcr) + in_memlet = None + in_edges = state.in_edges(node) + if in_edges is not None and len(in_edges) == 1: + in_memlet = in_edges[0] + if in_memlet is not None and in_memlet.data.volume is not None: + result *= in_memlet.data.volume + else: + result = 0 + return sp.sympify(result) + + +def count_depth_reduce(node, symbols, state): + # optimal depth of reduction is log of the work + return bigo(sp.log(count_work_reduce(node, symbols, state))) + + +LIBNODES_TO_WORK = { + MatMul: count_work_matmul, + Transpose: lambda *args: 0, + Reduce: count_work_reduce, +} + +LIBNODES_TO_DEPTH = { + MatMul: count_depth_matmul, + Transpose: lambda *args: 0, + Reduce: count_depth_reduce, +} + +bigo = sp.Function('bigo') +PYFUNC_TO_ARITHMETICS = { + 'float': 0, + 'dace.float64': 0, + 'dace.int64': 0, + 'math.exp': 1, + 'exp': 1, + 'math.tanh': 1, + 'sin': 1, + 'cos': 1, + 'tanh': 1, + 'math.sqrt': 1, + 'sqrt': 1, + 'atan2:': 1, + 'min': 0, + 'max': 0, + 'ceiling': 0, + 'floor': 0, + 'abs': 0 +} + + +class ArithmeticCounter(ast.NodeVisitor): + + def __init__(self): + self.count = 0 + + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print( + 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' 
+ % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + + +def count_arithmetic_ops_code(code): + ctr = ArithmeticCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + + +class DepthCounter(ast.NodeVisitor): + # so far this is identical to the ArithmeticCounter above. + def __init__(self): + self.count = 0 + + def visit_BinOp(self, node): + if isinstance(node.op, ast.MatMult): + raise NotImplementedError('MatMult op count requires shape ' + 'inference') + self.count += 1 + return self.generic_visit(node) + + def visit_UnaryOp(self, node): + self.count += 1 + return self.generic_visit(node) + + def visit_Call(self, node): + fname = astunparse.unparse(node.func)[:-1] + if fname not in PYFUNC_TO_ARITHMETICS: + print( + 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' + % fname) + return self.generic_visit(node) + self.count += PYFUNC_TO_ARITHMETICS[fname] + return self.generic_visit(node) + + def visit_AugAssign(self, node): + return self.visit_BinOp(node) + + def visit_For(self, node): + raise NotImplementedError + + def visit_While(self, node): + raise NotImplementedError + + +def count_depth_code(code): + # so far this is the same as the work counter, since work = depth for each tasklet, as we can't assume any parallelism + ctr = ArithmeticCounter() + if isinstance(code, (tuple, list)): + for stmt in code: + ctr.visit(stmt) + elif isinstance(code, str): + ctr.visit(ast.parse(code)) + else: + ctr.visit(code) + return ctr.count + + +def tasklet_work(tasklet_node, state): + if tasklet_node.code.language == dtypes.Language.CPP: + # simplified work analysis for CPP tasklets. + for oedge in state.out_edges(tasklet_node): + return oedge.data.num_accesses + elif tasklet_node.code.language == dtypes.Language.Python: + return count_arithmetic_ops_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Work of tasklets only properly analyzed for Python or CPP. For all other ' + 'languages work = 1 will be counted for each tasklet.') + return 1 + + +def tasklet_depth(tasklet_node, state): + if tasklet_node.code.language == dtypes.Language.CPP: + # For now we simply take depth == work for CPP tasklets. + for oedge in state.out_edges(tasklet_node): + return oedge.data.num_accesses + if tasklet_node.code.language == dtypes.Language.Python: + return count_depth_code(tasklet_node.code.code) + else: + # other languages not implemented, count whole tasklet as work of 1 + warnings.warn('Depth of tasklets only properly analyzed for Python code. 
For all other ' + 'languages depth = 1 will be counted for each tasklet.') + return 1 + + +def get_tasklet_work(node, state): + return sp.sympify(tasklet_work(node, state)), sp.sympify(-1) + + +def get_tasklet_work_depth(node, state): + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) + + +def get_tasklet_avg_par(node, state): + return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) + + +def update_value_map(old, new): + # add new assignments to old + old.update({k: v for k, v in new.items() if k not in old}) + # check for conflicts: + for k, v in new.items(): + if k in old and old[k] != v: + # conflict detected --> forget this mapping completely + old.pop(k) + + +def do_initial_subs(w, d, eq, subs1): + """ + Calls subs three times for the give (w)ork and (d)epth values. + """ + return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) + + +def sdfg_op_in_OLD(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols: Dict[str, str], + detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]: + """ + Analyze the work and depth of a given SDFG. + First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. + Lastly, we compute the path with most work and the path with the most depth in order to get the total work depth. + + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary which will save the result. + :param analyze_tasklet: Function used to analyze tasklet nodes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. + :return: A tuple containing the work and depth of the SDFG. + """ + + # First determine the work and depth of each state individually. + # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number + # of times the state will be executed. + state_depths: Dict[SDFGState, sp.Expr] = {} + state_works: Dict[SDFGState, sp.Expr] = {} + for state in sdfg.nodes(): + state_work, state_depth = state_op_in_OLD(state, op_in_map, analyze_tasklet, symbols, detailed_analysis, + equality_subs, subs1) + + # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. + state_work = sp.simplify(state_work * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + state_depth = sp.simplify(state_depth * + state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + + state_works[state], state_depths[state] = state_work, state_depth + op_in_map[get_uuid(state)] = (state_works[state], state_depths[state]) + + # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and + # the guard, and instead places an edge between the last loop state and the exit state. 
+ # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. + # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. + + # identify all loops in the SDFG + nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) + + # Now we need to go over each triple (node, oNode, exits). For each triple, we + # - remove edge (oNode, node), i.e. the backward edge + # - for all exits e, add edge (oNode, e). This edge may already exist + # - remove edge from node to exit (if present, i.e. while-do loop) + # - This ensures that every node with > 1 outgoing edge is a branch guard + # - useful for detailed anaylsis. + for node, oNode, exits in nodes_oNodes_exits: + sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) + for e in exits: + if len(sdfg.edges_between(oNode, e)) == 0: + # no edge there yet + sdfg.add_edge(oNode, e, InterstateEdge()) + if len(sdfg.edges_between(node, e)) > 0: + # edge present --> remove it + sdfg.remove_edge(sdfg.edges_between(node, e)[0]) + + # add a dummy exit to the SDFG, such that each path ends there. + dummy_exit = sdfg.add_state('dummy_exit') + for state in sdfg.nodes(): + if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: + sdfg.add_edge(state, dummy_exit, InterstateEdge()) + + # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. + work_map: Dict[SDFGState, sp.Expr] = {} + depth_map: Dict[SDFGState, sp.Expr] = {} + # Keeps track of assignments done on InterstateEdges. + state_value_map: Dict[SDFGState, Dict[sp.Symbol, sp.Symbol]] = {} + # The dummy state has 0 work and depth. + state_depths[dummy_exit] = sp.sympify(0) + state_works[dummy_exit] = sp.sympify(0) + + # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to + # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions + # have been calculated. + traversal_q = deque() + traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) + visited = set() + + while traversal_q: + state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() + + if ie is not None: + visited.add(ie) + + if state in state_value_map: + # update value map: + update_value_map(state_value_map[state], value_map) + else: + state_value_map[state] = value_map + + # ignore assignments such as tmp=x[0], as those do not give much information. + value_map = {k: v for k, v in state_value_map[state].items() if '[' not in k and '[' not in v} + n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) + n_work = sp.simplify((work + state_works[state]).subs(value_map)) + + # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one + # single path with the least average parallelsim (of all paths with more than 0 work). + if analyze_tasklet == get_tasklet_avg_par: + if state in depth_map: # this means we have already visited this state before + cse = common_subexpr_stack.pop() + # if current path has 0 depth (--> 0 work as well), we don't do anything. 
+ if n_depth != 0: + # check if we need to update the work and depth of the current state + # we update if avg parallelism of new incoming path is less than current avg parallelism + if depth_map[state] == 0: + # old value was divided by zero --> we take new value anyway + depth_map[state] = cse[1] + n_depth + work_map[state] = cse[0] + n_work + else: + old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) + new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) + # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) + depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), + (depth_map[state], True)) + work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), + (work_map[state], True)) + else: + depth_map[state] = n_depth + work_map[state] = n_work + else: + # search heaviest and deepest path separately + if state in depth_map: # and consequently also in work_map + # This cse value would appear in both arguments of the Max. Hence, for performance reasons, + # we pull it out of the Max expression. + # Example: We do cse + Max(a, b) instead of Max(cse + a, cse + b). + # This increases performance drastically, expecially since we avoid nesting Max expressions + # for cases where cse itself contains Max operators. + cse = common_subexpr_stack.pop() + if detailed_analysis: + # This MAX should be covered in the more detailed analysis + cond = condition_stack.pop() + work_map[state] = cse[0] + sp.Piecewise((work_map[state], sp.Not(cond)), (n_work, cond)) + depth_map[state] = cse[1] + sp.Piecewise((depth_map[state], sp.Not(cond)), (n_depth, cond)) + else: + work_map[state] = cse[0] + sp.Max(work_map[state], n_work) + depth_map[state] = cse[1] + sp.Max(depth_map[state], n_depth) + else: + depth_map[state] = n_depth + work_map[state] = n_work + + out_edges = sdfg.out_edges(state) + # only advance after all incoming edges were visited (meaning that current work depth values of state are final). + if any(iedge not in visited for iedge in sdfg.in_edges(state)): + pass + else: + for oedge in out_edges: + if len(out_edges) > 1: + # It is important to copy these stacks. Else both branches operate on the same stack. + # state is a branch guard --> save condition on stack + new_cond_stack = list(condition_stack) + new_cond_stack.append(oedge.data.condition_sympy()) + # same for common_subexr_stack + new_cse_stack = list(common_subexpr_stack) + new_cse_stack.append((work_map[state], depth_map[state])) + # same for value_map + new_value_map = dict(state_value_map[state]) + new_value_map.update({sp.Symbol(k): sp.Symbol(v) for k, v in oedge.data.assignments.items()}) + traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) + else: + value_map.update(oedge.data.assignments) + traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, + common_subexpr_stack, value_map)) + + try: + max_depth = depth_map[dummy_exit] + max_work = work_map[dummy_exit] + except KeyError: + # If we get a KeyError above, this means that the traversal never reached the dummy_exit state. + # This happens if the loops were not properly detected and broken. + raise Exception( + 'Analysis failed, since not all loops got detected. 
It may help to use more structured loop constructs.') + + sdfg_result = (max_work, max_depth) + op_in_map[get_uuid(sdfg)] = sdfg_result + return sdfg_result + + +def scope_op_in_OLD(state: SDFGState, + op_in_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols: Dict[str, str], + detailed_analysis: bool, + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: + """ + Analyze the work and depth of a scope. + This works by traversing through the scope analyzing the work and depth of each encountered node. + Depending on what kind of node we encounter, we do the following: + - EntryNode: Recursively analyze work depth of scope. + - Tasklet: use analyze_tasklet to get work depth of tasklet node. + - NestedSDFG: After translating its local symbols to global symbols, we analyze the nested SDFG recursively. + - LibraryNode: Library nodes are analyzed with special functions depending on their type. + Work inside a state can simply be summed up, but for the depth we need to find the longest path. Since dataflow is a DAG, + this can be done in linear time by traversing the graph in topological order. + + :param state: The state in which the scope to analyze is contained. + :param op_in_map: Dictionary saving the final result for each SDFG element. + :param analyze_tasklet: Function used to analyze tasklets. Either analyzes just work, work and depth or average parallelism. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. + :param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. + :return: A tuple containing the work and depth of the scope. + """ + + # find the work and depth of each node + # for maps and nested SDFG, we do it recursively + work = sp.sympify(0) + max_depth = sp.sympify(0) + scope_nodes = state.scope_children()[entry] + scope_exit = None if entry is None else state.exit_node(entry) + for node in scope_nodes: + # add node to map + op_in_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0)) + if isinstance(node, nd.EntryNode): + # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. + # The resulting work/depth are summarized into the entry node + s_work, s_depth = scope_op_in_OLD(state, op_in_map, analyze_tasklet, symbols, detailed_analysis, + equality_subs, subs1, node) + s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) + # add up work for whole state, but also save work for this sub-scope scope in op_in_map + work += s_work + op_in_map[get_uuid(node, state)] = (s_work, s_depth) + elif node == scope_exit: + # don't do anything for exit nodes, everthing handled already in the corresponding entry node. 
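The docstring above notes that, within a state, work is summed up while depth is the longest path through the dataflow DAG, found in linear time by a topological traversal. The following self-contained sketch illustrates that idea on an invented four-node graph (node names and per-node depths are made up and unrelated to any SDFG):

from collections import defaultdict

# Invented dataflow DAG with a per-node depth (e.g. the depth of a tasklet).
edges = [('read', 'mul'), ('read', 'add'), ('mul', 'write'), ('add', 'write')]
node_depth = {'read': 0, 'mul': 1, 'add': 1, 'write': 0}

# Topological (Kahn-style) sweep: a node's path depth is its own depth plus the
# largest path depth among its predecessors.
succ, indeg = defaultdict(list), defaultdict(int)
for u, v in edges:
    succ[u].append(v)
    indeg[v] += 1
best = dict(node_depth)
ready = [n for n in node_depth if indeg[n] == 0]
while ready:
    u = ready.pop()
    for v in succ[u]:
        best[v] = max(best[v], best[u] + node_depth[v])
        indeg[v] -= 1
        if indeg[v] == 0:
            ready.append(v)
print(max(best.values()))  # 1 -- 'mul' and 'add' lie on parallel paths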
+ pass + elif isinstance(node, nd.Tasklet): + # add up work for whole state, but also save work for this node in op_in_map + # t_work, t_depth = analyze_tasklet(node, state) + + # analyze the memory accesses of this tasklet and whether they hit in cache or not + print('tasklet') + t_work, t_depth = sp.sympify(100), sp.sympify(100) + + + # check if tasklet has any outgoing wcr edges + for e in state.out_edges(node): + if e.data.wcr is not None: + t_work += count_arithmetic_ops_code(e.data.wcr) + t_work, t_depth = do_initial_subs(t_work, t_depth, equality_subs, subs1) + work += t_work + op_in_map[get_uuid(node, state)] = (t_work, t_depth) + elif isinstance(node, nd.NestedSDFG): + # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. + # We only want global symbols in our final work depth expressions. + nested_syms = {} + nested_syms.update(symbols) + nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) + # Nested SDFGs are recursively analyzed first. + nsdfg_work, nsdfg_depth = sdfg_op_in_OLD(node.sdfg, op_in_map, analyze_tasklet, nested_syms, + detailed_analysis, equality_subs, subs1) + + nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) + # add up work for whole state, but also save work for this nested SDFG in op_in_map + work += nsdfg_work + op_in_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) + elif isinstance(node, nd.LibraryNode): + try: + lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) + except KeyError: + # add a symbol to the top level sdfg, such that the user can define it in the extension + top_level_sdfg = state.parent + # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, + # such that the user can define its value. But it doesn't... + # How to achieve this? + top_level_sdfg.add_symbol(f'{node.name}_work', int64) + lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) + lib_node_depth = sp.sympify(-1) # not analyzed + if analyze_tasklet != get_tasklet_work: + # we are analyzing depth + try: + lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + except KeyError: + top_level_sdfg = state.parent + top_level_sdfg.add_symbol(f'{node.name}_depth', int64) + lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) + lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) + work += lib_node_work + op_in_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) + + if entry is not None: + # If the scope being analyzed is a map, multiply the work by the number of iterations of the map. + if isinstance(entry, nd.MapEntry): + nmap: nd.Map = entry.map + range: Range = nmap.range + n_exec = range.num_elements_exact() + work = sp.simplify(work * n_exec.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) + else: + print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') + + # Work inside a state can simply be summed up. But now we need to find the depth of a state (i.e. longest path). + # Since dataflow graph is a DAG, this can be done in linear time. + max_depth = sp.sympify(0) + # only do this if we are analyzing depth + if analyze_tasklet == get_tasklet_work_depth or analyze_tasklet == get_tasklet_avg_par: + # Calculate the maximum depth of the scope by finding the 'deepest' path from the source to the sink. 
This is done by + # a traversal in topological order, where each node propagates its current max depth for all incoming paths. + traversal_q = deque() + visited = set() + # find all starting nodes + if entry: + # the entry is the starting node + traversal_q.append((entry, sp.sympify(0), None)) + else: + for node in scope_nodes: + if len(state.in_edges(node)) == 0: + # This node is a start node of the traversal + traversal_q.append((node, sp.sympify(0), None)) + # this map keeps track of the length of the longest path ending at each state so far seen. + depth_map = {} + wcr_depth_map = {} + while traversal_q: + node, in_depth, in_edge = traversal_q.popleft() + + if in_edge is not None: + visited.add(in_edge) + + n_depth = sp.simplify(in_depth + op_in_map[get_uuid(node, state)][1]) + + if node in depth_map: + depth_map[node] = sp.Max(depth_map[node], n_depth) + else: + depth_map[node] = n_depth + + out_edges = state.out_edges(node) + # Only advance to next node, if all incoming edges have been visited or the current node is the entry (aka starting node). + # If the current node is the exit of the scope, we stop, such that we don't leave the scope. + if (all(iedge in visited for iedge in state.in_edges(node)) or node == entry) and node != scope_exit: + # If we encounter a nested map, we must not analyze its contents (as they have already been recursively analyzed). + # Hence, we continue from the outgoing edges of the corresponding exit. + if isinstance(node, nd.EntryNode) and node != entry: + exit_node = state.exit_node(node) + # replace out_edges with the out_edges of the scope exit node + out_edges = state.out_edges(exit_node) + for oedge in out_edges: + # check for wcr + wcr_depth = sp.sympify(0) + if oedge.data.wcr is not None: + # This division gives us the number of writes to each single memory location, which is the depth + # as these need to be sequential (without assumptions on HW etc). + wcr_depth = oedge.data.volume / oedge.data.subset.num_elements() + if get_uuid(node, state) in wcr_depth_map: + # max + wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], + wcr_depth) + else: + wcr_depth_map[get_uuid(node, state)] = wcr_depth + # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N. + wcr_depth = wcr_depth if not isinstance(oedge.dst, nd.MapExit) else sp.sympify(0) + + # only append if it's actually new information + # this e.g. helps for huge nested SDFGs with lots of inputs/outputs inside a map scope + append = True + for n, d, _ in traversal_q: + if oedge.dst == n and depth_map[node] + wcr_depth == d: + append = False + break + if append: + traversal_q.append((oedge.dst, depth_map[node] + wcr_depth, oedge)) + else: + visited.add(oedge) + if len(out_edges) == 0 or node == scope_exit: + # We have reached an end node --> update max_depth + max_depth = sp.Max(max_depth, depth_map[node]) + + for uuid in wcr_depth_map: + op_in_map[uuid] = (op_in_map[uuid][0], op_in_map[uuid][1] + wcr_depth_map[uuid]) + # summarise work / depth of the whole scope in the dictionary + scope_result = (work, max_depth) + op_in_map[get_uuid(state)] = scope_result + return scope_result + + +def state_op_in_OLD(state: SDFGState, op_in_map: Dict[str, sp.Expr], analyze_tasklet, symbols, detailed_analysis, + equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]: + """ + Analyze the work and depth of a state. + + :param state: The state to analyze. + :param op_in_map: The result will be saved to this map. 
+ :param analyze_tasklet: Function used to analyze tasklet nodes. + :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + :param equality_subs: Substitution dict taking care of the equality assumptions. + :param subs1: First substitution dict for greater/lesser assumptions. + :return: A tuple containing the work and depth of the state. + """ + work, depth = scope_op_in_OLD(state, op_in_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, + None) + return work, depth + + +def analyze_sdfg(sdfg: SDFG, op_in_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], + detailed_analysis: bool) -> None: + """ + Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. + + :note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a + condition and an assignment. + :param sdfg: The SDFG to analyze. + :param op_in_map: Dictionary of SDFG elements to (work, depth) tuples. Result will be saved in here. + :param analyze_tasklet: Function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. + :param assumptions: List of strings. Each string corresponds to one assumption for some symbol, e.g. 'N>5'. + :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition + and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, + as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). + """ + + # deepcopy such that original sdfg not changed + sdfg = deepcopy(sdfg) + + # apply SSA pass + pipeline = FixedPointPipeline([StrictSymbolSSA()]) + pipeline.apply_pass(sdfg, {}) + + array_symbols = get_array_size_symbols(sdfg) + # parse assumptions + equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) + + # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state + # will be executed, or to determine upper bounds for that number (such as in the case of branching) + for sd in sdfg.all_sdfgs_recursive(): + propagation.propagate_states(sd, concretize_dynamic_unbounded=True) + + # Analyze the work and depth of the SDFG. + symbols = {} + sdfg_op_in_OLD(sdfg, op_in_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, + all_subs[0][0] if len(all_subs) > 0 else {}) + + for k, (v_w, v_d) in op_in_map.items(): + # The symeval replaces nested SDFG symbols with their global counterparts. + v_w, v_d = do_subs(v_w, v_d, all_subs) + v_w = symeval(v_w, symbols) + v_d = symeval(v_d, symbols) + op_in_map[k] = (v_w, v_d) + + +def do_subs(work, depth, all_subs): + """ + Handles all substitutions beyond the equality substitutions and the first substitution. + :param work: Some work expression. + :param depth: Some depth expression. + :param all_subs: List of substitution pairs to perform. + :return: Work depth expressions after doing all substitutions. 
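A possible way to drive the analysis above, assuming the call is made from within this module (so that analyze_sdfg, get_tasklet_work_depth and get_uuid are in scope); the file name 'program.sdfg' and the assumption string 'N>0' are placeholders:

from dace import SDFG
import sympy as sp

# 'program.sdfg' is a placeholder path to any saved SDFG.
sdfg = SDFG.from_file('program.sdfg')

op_in_map = {}
# 'N>0' is a placeholder assumption on an array-size symbol.
analyze_sdfg(sdfg, op_in_map, get_tasklet_work_depth, assumptions=['N>0'], detailed_analysis=False)

# Results are keyed by element UUID; the whole-SDFG result sits under the SDFG's UUID.
work, depth = op_in_map[get_uuid(sdfg)]
print('work :', sp.simplify(work))
print('depth:', sp.simplify(depth))

Passing get_tasklet_work instead restricts the analysis to work only (depth is reported as -1), and detailed_analysis=True keeps per-branch conditions at the cost of much larger expressions.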
+ """ + # first do subs2 of first sub + # then do all the remaining subs + subs2 = all_subs[0][1] if len(all_subs) > 0 else {} + work, depth = sp.simplify(sp.sympify(work).subs(subs2)), sp.simplify(sp.sympify(depth).subs(subs2)) + for i in range(1, len(all_subs)): + subs1, subs2 = all_subs[i] + work, depth = sp.simplify(work.subs(subs1)), sp.simplify(depth.subs(subs1)) + work, depth = sp.simplify(work.subs(subs2)), sp.simplify(depth.subs(subs2)) + return work, depth + + + + + + + + + + + + +def update_mapping(map, mapping): + # update the map params and return False + # if all iterations exhausted, return True + # always increase the last one, if it is exhausted, increase the next one and so forth + map_exhausted = True + for p, range in zip(map.params[::-1], map.range[::-1]): + curr_value = mapping[p] + if curr_value < range[1]: + # update this value and we done + mapping[p] = curr_value + range[2] + map_exhausted = False + break + else: + # set current param to start again and continue + mapping[p] = range[0] + return map_exhausted + + + +def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, stack, clt, C): + # we are inside a map --> we need to iterate over the map range and check each memory access. + for p, range in zip(entry.map.params, entry.map.range): + # map each map iteration variable to its start + mapping[p] = range[0] + map_misses = 0 + while True: + # do analysis of map contents + map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, entry) + + if update_mapping(entry.map, mapping): + break + return map_misses + + +def scope_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: AccessStack, clt: CacheLineTracker, C, entry=None): + # find the work and depth of each node + # for maps and nested SDFG, we do it recursively + scope_misses = 0 + scope_nodes = state.scope_children()[entry] + for node in scope_nodes: + # add node to map + op_in_map[get_uuid(node, state)] = 0 + if isinstance(node, nd.EntryNode): + # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. + # The resulting work/depth are summarized into the entry node + map_misses = map_op_in(state, op_in_map, node, mapping, stack, clt, C) + # add up work for whole state, but also save work for this sub-scope scope in op_in_map + op_in_map[get_uuid(node, state)] = map_misses + scope_misses += map_misses + elif isinstance(node, nd.Tasklet): + # add up work for whole state, but also save work for this node in op_in_map + tasklet_misses = 0 + # analyze the memory accesses of this tasklet and whether they hit in cache or not + for e in state.in_edges(node): + if e.data.data in clt.array_info: + line_id = clt.cache_line_id(e.data.data, [x[0].subs(mapping) for x in e.data.subset.ranges]) + dist = stack.touch(line_id) + tasklet_misses += 1 if dist > C or dist == -1 else 0 + for e in state.out_edges(node): + if e.data.data in clt.array_info: + line_id = clt.cache_line_id(e.data.data, [x[0].subs(mapping) for x in e.data.subset.ranges]) + dist = stack.touch(line_id) + tasklet_misses += 1 if dist > C or dist == -1 else 0 + + # TODO: wcr edges. + scope_misses += tasklet_misses + op_in_map[get_uuid(node, state)] = tasklet_misses + elif isinstance(node, nd.NestedSDFG): + pass + # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. + # We only want global symbols in our final work depth expressions. 
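The miss test used for tasklets above (a miss whenever dist > C or dist == -1) is the classic stack-distance criterion: an access misses in a fully associative LRU cache of C lines if it is the first touch of its cache line, or if more than C distinct lines were touched since the previous access to it. Below is a minimal stand-alone model of that criterion; the trace and capacity are invented for the example, whereas the real code tracks line ids through CacheLineTracker and AccessStack.

def stack_distance_misses(trace, C):
    # Most recently used cache line id sits at the front of the list.
    stack, misses = [], 0
    for line in trace:
        if line in stack:
            dist = stack.index(line)  # distinct lines touched since the last access to this line
            stack.remove(line)
            if dist > C:
                misses += 1           # reuse distance larger than the capacity --> miss
        else:
            misses += 1               # first-ever touch of this line (the dist == -1 case)
        stack.insert(0, line)         # move / push the line to the top of the stack
    return misses

# Invented trace of cache line ids, capacity of 2 lines:
print(stack_distance_misses([0, 1, 2, 0, 1, 2], C=2))  # 3 cold misses; the reuses have distance 2 <= C

In analyze_sdfg_op_in further down, such miss counts are multiplied by the line size L to report the number of bytes moved between memory and cache.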
+ # nested_syms = {} + # nested_syms.update(symbols) + # nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) + # Nested SDFGs are recursively analyzed first. + nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C) + + # nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) + # add up work for whole state, but also save work for this nested SDFG in op_in_map + scope_misses += nsdfg_misses + op_in_map[get_uuid(node, state)] = nsdfg_misses + elif isinstance(node, nd.LibraryNode): + pass + # try: + # lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) + # except KeyError: + # # add a symbol to the top level sdfg, such that the user can define it in the extension + # top_level_sdfg = state.parent + # # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, + # # such that the user can define its value. But it doesn't... + # # How to achieve this? + # top_level_sdfg.add_symbol(f'{node.name}_work', int64) + # lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) + # lib_node_depth = sp.sympify(-1) # not analyzed + # if analyze_tasklet != get_tasklet_work: + # # we are analyzing depth + # try: + # lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) + # except KeyError: + # top_level_sdfg = state.parent + # top_level_sdfg.add_symbol(f'{node.name}_depth', int64) + # lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) + # lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) + # work += lib_node_work + # op_in_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) + op_in_map[get_uuid(state)] = scope_misses + return scope_misses + +def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mapping, stack, clt, C): + # traverse this SDFG's states + curr_state = sdfg.start_state + total_misses = 0 + while True: + total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C) + + if len(sdfg.out_edges(curr_state)) == 0: + # we reached the end state --> stop + break + else: + # take first edge with True condition + found = False + for e in sdfg.out_edges(curr_state): + if e.data.is_unconditional() or e.data.condition_sympy().subs(mapping) == True: + # save e's assignments in mapping and update curr_state + # replace values first with mapping, then update mapping + mapping.update({k: sp.sympify(v).subs(mapping) for k, v in e.data.assignments.items()}) + curr_state = e.dst + found = True + break + if not found: + print('WARNING: state has outgoing edges, but no condition of them can be' + 'evaluated as True and hence the analysis ends. ') + break + # traverse further + op_in_map[get_uuid(sdfg)] = total_misses + +def analyze_sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, sp.Expr], C, L): + # do some top level stuff + stack = AccessStack() + clt = CacheLineTracker(L) + for _, name, arr in sdfg.arrays_recursive(): + if isinstance(arr, Array): + if name in clt.array_info: + print('WARNING: This array name was already seen!!! 
Two arrays with the same name in the SDFG.') + clt.add_array(name, arr) + mapping = {} + + sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C) + + # now we have number of misses --> multiply each by L + for k, v in op_in_map.items(): + op_in_map[k] = v * L + + + + + + +################################################################################ +# Utility functions for running the analysis from the command line ############# +################################################################################ + + +def main() -> None: + + parser = argparse.ArgumentParser('operational_intensity', + usage='python operational_intensity.py [-h] filename', + description='Analyze the operational_intensity of an SDFG.') + + parser.add_argument('filename', type=str, help='The SDFG file to analyze.') + parser.add_argument('C', type=str, help='Cache size in bytes') + parser.add_argument('L', type=str, help='Cache line size in bytes') + + args = parser.parse_args() + + if not os.path.exists(args.filename): + print(args.filename, 'does not exist.') + exit() + + sdfg = SDFG.from_file(args.filename) + op_in_map = {} + analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L)) + + for k, v, in op_in_map.items(): + op_in_map[k] = str(sp.simplify(v)) + + result_whole_sdfg = op_in_map[get_uuid(sdfg)] + + print(80 * '-') + print("Bytes transferred:\t", result_whole_sdfg) + print(80 * '-') + + +if __name__ == '__main__': + main() + + + + + + From 55c15987fc3c6be98b05a8e73425c19030e10511 Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Mon, 18 Sep 2023 18:43:32 +0200 Subject: [PATCH 17/18] Revert "start of op_in analysis" This reverts commit eb5a6f427d47f314e3254f681639cf3f155f77c8. --- .../sdfg/work_depth_analysis/op_in_helpers.py | 78 -- .../operational_intensity.py | 1004 ----------------- 2 files changed, 1082 deletions(-) delete mode 100644 dace/sdfg/work_depth_analysis/op_in_helpers.py delete mode 100644 dace/sdfg/work_depth_analysis/operational_intensity.py diff --git a/dace/sdfg/work_depth_analysis/op_in_helpers.py b/dace/sdfg/work_depth_analysis/op_in_helpers.py deleted file mode 100644 index c7c17741de..0000000000 --- a/dace/sdfg/work_depth_analysis/op_in_helpers.py +++ /dev/null @@ -1,78 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Contains class CacheLineTracker which keeps track of all arrays of an SDFG and their cache line position. -Further, contains class AccessStack which which corresponds to the stack used to compute the stack distance. """ - -from dace.data import Array - -class CacheLineTracker: - - def __init__(self, L) -> None: - self.array_info = {} - self.start_lines = {} - self.next_free_line = 0 - self.L = L - - def add_array(self, name: str, a: Array): - if name not in self.start_lines: - # new array encountered - self.array_info[name] = a - self.start_lines[name] = self.next_free_line - # increase next_free_line - self.next_free_line += (a.total_size * a.dtype.bytes + self.L - 1) // self.L # ceil division - - def cache_line_id(self, name: str, access: [int]): - arr = self.array_info[name] - one_d_index = 0 - for dim in range(len(access)): - i = access[dim] - one_d_index += (i + arr.offset[dim]) * arr.strides[dim] - - # divide by L to get the cache line id - return self.start_lines[name] + (one_d_index * arr.dtype.bytes) // self.L - - -class Node: - - def __init__(self, val: int, n=None) -> None: - self.v = val - self.next = n - - -class AccessStack: - """ A stack of cache line ids. 
For each memory access, we search the corresponding cache line id - in the stack, report its distance and move it to the top of the stack. If the id was not found, - we report a distance of -1. """ - - def __init__(self) -> None: - self.top = None - - def touch(self, id): - - curr = self.top - prev = None - found = False - distance = 0 - while curr is not None: - # check if we found id - if curr.v == id: - # take curr node out - if prev is not None: - prev.next = curr.next - curr.next = self.top - self.top = curr - - found = True - break - - # iterate further - prev = curr - curr = curr.next - distance += 1 - - if not found: - # we accessed this cache line for the first time ever - self.top = Node(id, self.top) - distance = -1 - - return distance - diff --git a/dace/sdfg/work_depth_analysis/operational_intensity.py b/dace/sdfg/work_depth_analysis/operational_intensity.py deleted file mode 100644 index af94c7f924..0000000000 --- a/dace/sdfg/work_depth_analysis/operational_intensity.py +++ /dev/null @@ -1,1004 +0,0 @@ -# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. -""" Analyses the operational intensity of an input SDFG. Can be used as a Python script -or from the VS Code extension. """ - -""" -Plan: -- For each memory access, we need to figure out its cache line and then we compute its stack distance. -- For that we model the actual stack, where we push all the memory acesses (What do we push exactly? -Cache line ids?? check typescript implementation for that information.) -- How do we know which array maps to which cache line? - Idea: for each new array encountered, just assume that it is cache line aligned and starts - at the next free cache line. TODO: check if this is how it usually behaves. Or are arrays - aligned further, like base address % x == 0 for some x bigger than cache line size? -- It is also important that we take data types into account for each array. -- For each mem access we increase the miss counter if stack distance > C(apacity) or it it is a -compulsory miss. Then, in the end we know how many bytes are transferred to cache. It is: - num_misses * L(ine size in bytes) - -- Parameters to our analysis are - - input SDFG - - C(ache capacity) - - L(ine size) -""" - - - - - - - - - - -import argparse -from collections import deque -from dace.sdfg import nodes as nd, propagation, InterstateEdge -from dace import SDFG, SDFGState, dtypes, int64 -from dace.subsets import Range -from typing import Tuple, Dict -import os -import sympy as sp -from copy import deepcopy -from dace.libraries.blas import MatMul -from dace.libraries.standard import Reduce, Transpose -from dace.symbolic import pystr_to_symbolic -import ast -import astunparse -import warnings - -from dace.sdfg.work_depth_analysis.helpers import get_uuid, find_loop_guards_tails_exits -from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions -from dace.transformation.passes.symbol_ssa import StrictSymbolSSA -from dace.transformation.pass_pipeline import FixedPointPipeline - -from dace.data import Array -from dace.sdfg.work_depth_analysis.op_in_helpers import CacheLineTracker, AccessStack - -def get_array_size_symbols(sdfg): - """ - Returns all symbols that appear isolated in shapes of the SDFG's arrays. - These symbols can then be assumed to be positive. - - :note: This only works if a symbol appears in isolation, i.e. array A[N]. - If we have A[N+1], we cannot assume N to be positive. - :param sdfg: The SDFG in which it searches for symbols. 
- :return: A set containing symbols which we can assume to be positive. - """ - symbols = set() - for _, _, arr in sdfg.arrays_recursive(): - for s in arr.shape: - if isinstance(s, sp.Symbol): - symbols.add(s) - return symbols - - -def symeval(val, symbols): - """ - Takes a sympy expression and substitutes its symbols according to a dict { old_symbol: new_symbol}. - - :param val: The expression we are updating. - :param symbols: Dictionary of key value pairs { old_symbol: new_symbol}. - """ - first_replacement = {pystr_to_symbolic(k): pystr_to_symbolic('__REPLSYM_' + k) for k in symbols.keys()} - second_replacement = {pystr_to_symbolic('__REPLSYM_' + k): v for k, v in symbols.items()} - return sp.simplify(val.subs(first_replacement).subs(second_replacement)) - - -def evaluate_symbols(base, new): - result = {} - for k, v in new.items(): - result[k] = symeval(v, base) - return result - - -def count_work_matmul(node, symbols, state): - A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') - B_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_b') - C_memlet = next(e for e in state.out_edges(node) if e.src_conn == '_c') - result = 2 # Multiply, add - # Batch - if len(C_memlet.data.subset) == 3: - result *= symeval(C_memlet.data.subset.size()[0], symbols) - # M*N - result *= symeval(C_memlet.data.subset.size()[-2], symbols) - result *= symeval(C_memlet.data.subset.size()[-1], symbols) - # K - result *= symeval(A_memlet.data.subset.size()[-1], symbols) - return sp.sympify(result) - - -def count_depth_matmul(node, symbols, state): - # optimal depth of a matrix multiplication is O(log(size of shared dimension)): - A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a') - size_shared_dimension = symeval(A_memlet.data.subset.size()[-1], symbols) - return bigo(sp.log(size_shared_dimension)) - - -def count_work_reduce(node, symbols, state): - result = 0 - if node.wcr is not None: - result += count_arithmetic_ops_code(node.wcr) - in_memlet = None - in_edges = state.in_edges(node) - if in_edges is not None and len(in_edges) == 1: - in_memlet = in_edges[0] - if in_memlet is not None and in_memlet.data.volume is not None: - result *= in_memlet.data.volume - else: - result = 0 - return sp.sympify(result) - - -def count_depth_reduce(node, symbols, state): - # optimal depth of reduction is log of the work - return bigo(sp.log(count_work_reduce(node, symbols, state))) - - -LIBNODES_TO_WORK = { - MatMul: count_work_matmul, - Transpose: lambda *args: 0, - Reduce: count_work_reduce, -} - -LIBNODES_TO_DEPTH = { - MatMul: count_depth_matmul, - Transpose: lambda *args: 0, - Reduce: count_depth_reduce, -} - -bigo = sp.Function('bigo') -PYFUNC_TO_ARITHMETICS = { - 'float': 0, - 'dace.float64': 0, - 'dace.int64': 0, - 'math.exp': 1, - 'exp': 1, - 'math.tanh': 1, - 'sin': 1, - 'cos': 1, - 'tanh': 1, - 'math.sqrt': 1, - 'sqrt': 1, - 'atan2:': 1, - 'min': 0, - 'max': 0, - 'ceiling': 0, - 'floor': 0, - 'abs': 0 -} - - -class ArithmeticCounter(ast.NodeVisitor): - - def __init__(self): - self.count = 0 - - def visit_BinOp(self, node): - if isinstance(node.op, ast.MatMult): - raise NotImplementedError('MatMult op count requires shape ' - 'inference') - self.count += 1 - return self.generic_visit(node) - - def visit_UnaryOp(self, node): - self.count += 1 - return self.generic_visit(node) - - def visit_Call(self, node): - fname = astunparse.unparse(node.func)[:-1] - if fname not in PYFUNC_TO_ARITHMETICS: - print( - 'WARNING: Unrecognized python function "%s". 
If this is a type conversion, like "dace.float64", then this is fine.' - % fname) - return self.generic_visit(node) - self.count += PYFUNC_TO_ARITHMETICS[fname] - return self.generic_visit(node) - - def visit_AugAssign(self, node): - return self.visit_BinOp(node) - - def visit_For(self, node): - raise NotImplementedError - - def visit_While(self, node): - raise NotImplementedError - - -def count_arithmetic_ops_code(code): - ctr = ArithmeticCounter() - if isinstance(code, (tuple, list)): - for stmt in code: - ctr.visit(stmt) - elif isinstance(code, str): - ctr.visit(ast.parse(code)) - else: - ctr.visit(code) - return ctr.count - - -class DepthCounter(ast.NodeVisitor): - # so far this is identical to the ArithmeticCounter above. - def __init__(self): - self.count = 0 - - def visit_BinOp(self, node): - if isinstance(node.op, ast.MatMult): - raise NotImplementedError('MatMult op count requires shape ' - 'inference') - self.count += 1 - return self.generic_visit(node) - - def visit_UnaryOp(self, node): - self.count += 1 - return self.generic_visit(node) - - def visit_Call(self, node): - fname = astunparse.unparse(node.func)[:-1] - if fname not in PYFUNC_TO_ARITHMETICS: - print( - 'WARNING: Unrecognized python function "%s". If this is a type conversion, like "dace.float64", then this is fine.' - % fname) - return self.generic_visit(node) - self.count += PYFUNC_TO_ARITHMETICS[fname] - return self.generic_visit(node) - - def visit_AugAssign(self, node): - return self.visit_BinOp(node) - - def visit_For(self, node): - raise NotImplementedError - - def visit_While(self, node): - raise NotImplementedError - - -def count_depth_code(code): - # so far this is the same as the work counter, since work = depth for each tasklet, as we can't assume any parallelism - ctr = ArithmeticCounter() - if isinstance(code, (tuple, list)): - for stmt in code: - ctr.visit(stmt) - elif isinstance(code, str): - ctr.visit(ast.parse(code)) - else: - ctr.visit(code) - return ctr.count - - -def tasklet_work(tasklet_node, state): - if tasklet_node.code.language == dtypes.Language.CPP: - # simplified work analysis for CPP tasklets. - for oedge in state.out_edges(tasklet_node): - return oedge.data.num_accesses - elif tasklet_node.code.language == dtypes.Language.Python: - return count_arithmetic_ops_code(tasklet_node.code.code) - else: - # other languages not implemented, count whole tasklet as work of 1 - warnings.warn('Work of tasklets only properly analyzed for Python or CPP. For all other ' - 'languages work = 1 will be counted for each tasklet.') - return 1 - - -def tasklet_depth(tasklet_node, state): - if tasklet_node.code.language == dtypes.Language.CPP: - # For now we simply take depth == work for CPP tasklets. - for oedge in state.out_edges(tasklet_node): - return oedge.data.num_accesses - if tasklet_node.code.language == dtypes.Language.Python: - return count_depth_code(tasklet_node.code.code) - else: - # other languages not implemented, count whole tasklet as work of 1 - warnings.warn('Depth of tasklets only properly analyzed for Python code. 
For all other ' - 'languages depth = 1 will be counted for each tasklet.') - return 1 - - -def get_tasklet_work(node, state): - return sp.sympify(tasklet_work(node, state)), sp.sympify(-1) - - -def get_tasklet_work_depth(node, state): - return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) - - -def get_tasklet_avg_par(node, state): - return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state)) - - -def update_value_map(old, new): - # add new assignments to old - old.update({k: v for k, v in new.items() if k not in old}) - # check for conflicts: - for k, v in new.items(): - if k in old and old[k] != v: - # conflict detected --> forget this mapping completely - old.pop(k) - - -def do_initial_subs(w, d, eq, subs1): - """ - Calls subs three times for the give (w)ork and (d)epth values. - """ - return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) - - -def sdfg_op_in_OLD(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols: Dict[str, str], - detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]: - """ - Analyze the work and depth of a given SDFG. - First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. - Lastly, we compute the path with most work and the path with the most depth in order to get the total work depth. - - :param sdfg: The SDFG to analyze. - :param op_in_map: Dictionary which will save the result. - :param analyze_tasklet: Function used to analyze tasklet nodes. - :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. - :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition - and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, - as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). - :param equality_subs: Substitution dict taking care of the equality assumptions. - :param subs1: First substitution dict for greater/lesser assumptions. - :return: A tuple containing the work and depth of the SDFG. - """ - - # First determine the work and depth of each state individually. - # Keep track of the work and depth for each state in a dictionary, where work and depth are multiplied by the number - # of times the state will be executed. - state_depths: Dict[SDFGState, sp.Expr] = {} - state_works: Dict[SDFGState, sp.Expr] = {} - for state in sdfg.nodes(): - state_work, state_depth = state_op_in_OLD(state, op_in_map, analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1) - - # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. - state_work = sp.simplify(state_work * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - state_depth = sp.simplify(state_depth * - state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - - state_works[state], state_depths[state] = state_work, state_depth - op_in_map[get_uuid(state)] = (state_works[state], state_depths[state]) - - # Prepare the SDFG for a depth analysis by breaking loops. This removes the edge between the last loop state and - # the guard, and instead places an edge between the last loop state and the exit state. 
- # This transforms the state machine into a DAG. Hence, we can find the "heaviest" and "deepest" paths in linear time. - # Additionally, construct a dummy exit state and connect every state that has no outgoing edges to it. - - # identify all loops in the SDFG - nodes_oNodes_exits = find_loop_guards_tails_exits(sdfg._nx) - - # Now we need to go over each triple (node, oNode, exits). For each triple, we - # - remove edge (oNode, node), i.e. the backward edge - # - for all exits e, add edge (oNode, e). This edge may already exist - # - remove edge from node to exit (if present, i.e. while-do loop) - # - This ensures that every node with > 1 outgoing edge is a branch guard - # - useful for detailed anaylsis. - for node, oNode, exits in nodes_oNodes_exits: - sdfg.remove_edge(sdfg.edges_between(oNode, node)[0]) - for e in exits: - if len(sdfg.edges_between(oNode, e)) == 0: - # no edge there yet - sdfg.add_edge(oNode, e, InterstateEdge()) - if len(sdfg.edges_between(node, e)) > 0: - # edge present --> remove it - sdfg.remove_edge(sdfg.edges_between(node, e)[0]) - - # add a dummy exit to the SDFG, such that each path ends there. - dummy_exit = sdfg.add_state('dummy_exit') - for state in sdfg.nodes(): - if len(sdfg.out_edges(state)) == 0 and state != dummy_exit: - sdfg.add_edge(state, dummy_exit, InterstateEdge()) - - # These two dicts save the current length of the "heaviest", resp. "deepest", paths at each state. - work_map: Dict[SDFGState, sp.Expr] = {} - depth_map: Dict[SDFGState, sp.Expr] = {} - # Keeps track of assignments done on InterstateEdges. - state_value_map: Dict[SDFGState, Dict[sp.Symbol, sp.Symbol]] = {} - # The dummy state has 0 work and depth. - state_depths[dummy_exit] = sp.sympify(0) - state_works[dummy_exit] = sp.sympify(0) - - # Perform a BFS traversal of the state machine and calculate the maximum work / depth at each state. Only advance to - # the next state in the BFS if all incoming edges have been visited, to ensure the maximum work / depth expressions - # have been calculated. - traversal_q = deque() - traversal_q.append((sdfg.start_state, sp.sympify(0), sp.sympify(0), None, [], [], {})) - visited = set() - - while traversal_q: - state, depth, work, ie, condition_stack, common_subexpr_stack, value_map = traversal_q.popleft() - - if ie is not None: - visited.add(ie) - - if state in state_value_map: - # update value map: - update_value_map(state_value_map[state], value_map) - else: - state_value_map[state] = value_map - - # ignore assignments such as tmp=x[0], as those do not give much information. - value_map = {k: v for k, v in state_value_map[state].items() if '[' not in k and '[' not in v} - n_depth = sp.simplify((depth + state_depths[state]).subs(value_map)) - n_work = sp.simplify((work + state_works[state]).subs(value_map)) - - # If we are analysing average parallelism, we don't search "heaviest" and "deepest" paths separately, but we want one - # single path with the least average parallelsim (of all paths with more than 0 work). - if analyze_tasklet == get_tasklet_avg_par: - if state in depth_map: # this means we have already visited this state before - cse = common_subexpr_stack.pop() - # if current path has 0 depth (--> 0 work as well), we don't do anything. 
- if n_depth != 0: - # check if we need to update the work and depth of the current state - # we update if avg parallelism of new incoming path is less than current avg parallelism - if depth_map[state] == 0: - # old value was divided by zero --> we take new value anyway - depth_map[state] = cse[1] + n_depth - work_map[state] = cse[0] + n_work - else: - old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state]) - new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth) - # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater) - depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), - (depth_map[state], True)) - work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), - (work_map[state], True)) - else: - depth_map[state] = n_depth - work_map[state] = n_work - else: - # search heaviest and deepest path separately - if state in depth_map: # and consequently also in work_map - # This cse value would appear in both arguments of the Max. Hence, for performance reasons, - # we pull it out of the Max expression. - # Example: We do cse + Max(a, b) instead of Max(cse + a, cse + b). - # This increases performance drastically, expecially since we avoid nesting Max expressions - # for cases where cse itself contains Max operators. - cse = common_subexpr_stack.pop() - if detailed_analysis: - # This MAX should be covered in the more detailed analysis - cond = condition_stack.pop() - work_map[state] = cse[0] + sp.Piecewise((work_map[state], sp.Not(cond)), (n_work, cond)) - depth_map[state] = cse[1] + sp.Piecewise((depth_map[state], sp.Not(cond)), (n_depth, cond)) - else: - work_map[state] = cse[0] + sp.Max(work_map[state], n_work) - depth_map[state] = cse[1] + sp.Max(depth_map[state], n_depth) - else: - depth_map[state] = n_depth - work_map[state] = n_work - - out_edges = sdfg.out_edges(state) - # only advance after all incoming edges were visited (meaning that current work depth values of state are final). - if any(iedge not in visited for iedge in sdfg.in_edges(state)): - pass - else: - for oedge in out_edges: - if len(out_edges) > 1: - # It is important to copy these stacks. Else both branches operate on the same stack. - # state is a branch guard --> save condition on stack - new_cond_stack = list(condition_stack) - new_cond_stack.append(oedge.data.condition_sympy()) - # same for common_subexr_stack - new_cse_stack = list(common_subexpr_stack) - new_cse_stack.append((work_map[state], depth_map[state])) - # same for value_map - new_value_map = dict(state_value_map[state]) - new_value_map.update({sp.Symbol(k): sp.Symbol(v) for k, v in oedge.data.assignments.items()}) - traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map)) - else: - value_map.update(oedge.data.assignments) - traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, - common_subexpr_stack, value_map)) - - try: - max_depth = depth_map[dummy_exit] - max_work = work_map[dummy_exit] - except KeyError: - # If we get a KeyError above, this means that the traversal never reached the dummy_exit state. - # This happens if the loops were not properly detected and broken. - raise Exception( - 'Analysis failed, since not all loops got detected. 
It may help to use more structured loop constructs.') - - sdfg_result = (max_work, max_depth) - op_in_map[get_uuid(sdfg)] = sdfg_result - return sdfg_result - - -def scope_op_in_OLD(state: SDFGState, - op_in_map: Dict[str, sp.Expr], - analyze_tasklet, - symbols: Dict[str, str], - detailed_analysis: bool, - equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr], - entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: - """ - Analyze the work and depth of a scope. - This works by traversing through the scope analyzing the work and depth of each encountered node. - Depending on what kind of node we encounter, we do the following: - - EntryNode: Recursively analyze work depth of scope. - - Tasklet: use analyze_tasklet to get work depth of tasklet node. - - NestedSDFG: After translating its local symbols to global symbols, we analyze the nested SDFG recursively. - - LibraryNode: Library nodes are analyzed with special functions depending on their type. - Work inside a state can simply be summed up, but for the depth we need to find the longest path. Since dataflow is a DAG, - this can be done in linear time by traversing the graph in topological order. - - :param state: The state in which the scope to analyze is contained. - :param op_in_map: Dictionary saving the final result for each SDFG element. - :param analyze_tasklet: Function used to analyze tasklets. Either analyzes just work, work and depth or average parallelism. - :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. - :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition - and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, - as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). - :param equality_subs: Substitution dict taking care of the equality assumptions. - :param subs1: First substitution dict for greater/lesser assumptions. - :param entry: The entry node of the scope to analyze. If None, the entire state is analyzed. - :return: A tuple containing the work and depth of the scope. - """ - - # find the work and depth of each node - # for maps and nested SDFG, we do it recursively - work = sp.sympify(0) - max_depth = sp.sympify(0) - scope_nodes = state.scope_children()[entry] - scope_exit = None if entry is None else state.exit_node(entry) - for node in scope_nodes: - # add node to map - op_in_map[get_uuid(node, state)] = (sp.sympify(0), sp.sympify(0)) - if isinstance(node, nd.EntryNode): - # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. - # The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_op_in_OLD(state, op_in_map, analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1, node) - s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) - # add up work for whole state, but also save work for this sub-scope scope in op_in_map - work += s_work - op_in_map[get_uuid(node, state)] = (s_work, s_depth) - elif node == scope_exit: - # don't do anything for exit nodes, everthing handled already in the corresponding entry node. 
- pass - elif isinstance(node, nd.Tasklet): - # add up work for whole state, but also save work for this node in op_in_map - # t_work, t_depth = analyze_tasklet(node, state) - - # analyze the memory accesses of this tasklet and whether they hit in cache or not - print('tasklet') - t_work, t_depth = sp.sympify(100), sp.sympify(100) - - - # check if tasklet has any outgoing wcr edges - for e in state.out_edges(node): - if e.data.wcr is not None: - t_work += count_arithmetic_ops_code(e.data.wcr) - t_work, t_depth = do_initial_subs(t_work, t_depth, equality_subs, subs1) - work += t_work - op_in_map[get_uuid(node, state)] = (t_work, t_depth) - elif isinstance(node, nd.NestedSDFG): - # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. - # We only want global symbols in our final work depth expressions. - nested_syms = {} - nested_syms.update(symbols) - nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) - # Nested SDFGs are recursively analyzed first. - nsdfg_work, nsdfg_depth = sdfg_op_in_OLD(node.sdfg, op_in_map, analyze_tasklet, nested_syms, - detailed_analysis, equality_subs, subs1) - - nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) - # add up work for whole state, but also save work for this nested SDFG in op_in_map - work += nsdfg_work - op_in_map[get_uuid(node, state)] = (nsdfg_work, nsdfg_depth) - elif isinstance(node, nd.LibraryNode): - try: - lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) - except KeyError: - # add a symbol to the top level sdfg, such that the user can define it in the extension - top_level_sdfg = state.parent - # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, - # such that the user can define its value. But it doesn't... - # How to achieve this? - top_level_sdfg.add_symbol(f'{node.name}_work', int64) - lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) - lib_node_depth = sp.sympify(-1) # not analyzed - if analyze_tasklet != get_tasklet_work: - # we are analyzing depth - try: - lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) - except KeyError: - top_level_sdfg = state.parent - top_level_sdfg.add_symbol(f'{node.name}_depth', int64) - lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) - lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) - work += lib_node_work - op_in_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) - - if entry is not None: - # If the scope being analyzed is a map, multiply the work by the number of iterations of the map. - if isinstance(entry, nd.MapEntry): - nmap: nd.Map = entry.map - range: Range = nmap.range - n_exec = range.num_elements_exact() - work = sp.simplify(work * n_exec.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) - else: - print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') - - # Work inside a state can simply be summed up. But now we need to find the depth of a state (i.e. longest path). - # Since dataflow graph is a DAG, this can be done in linear time. - max_depth = sp.sympify(0) - # only do this if we are analyzing depth - if analyze_tasklet == get_tasklet_work_depth or analyze_tasklet == get_tasklet_avg_par: - # Calculate the maximum depth of the scope by finding the 'deepest' path from the source to the sink. 
This is done by - # a traversal in topological order, where each node propagates its current max depth for all incoming paths. - traversal_q = deque() - visited = set() - # find all starting nodes - if entry: - # the entry is the starting node - traversal_q.append((entry, sp.sympify(0), None)) - else: - for node in scope_nodes: - if len(state.in_edges(node)) == 0: - # This node is a start node of the traversal - traversal_q.append((node, sp.sympify(0), None)) - # this map keeps track of the length of the longest path ending at each state so far seen. - depth_map = {} - wcr_depth_map = {} - while traversal_q: - node, in_depth, in_edge = traversal_q.popleft() - - if in_edge is not None: - visited.add(in_edge) - - n_depth = sp.simplify(in_depth + op_in_map[get_uuid(node, state)][1]) - - if node in depth_map: - depth_map[node] = sp.Max(depth_map[node], n_depth) - else: - depth_map[node] = n_depth - - out_edges = state.out_edges(node) - # Only advance to next node, if all incoming edges have been visited or the current node is the entry (aka starting node). - # If the current node is the exit of the scope, we stop, such that we don't leave the scope. - if (all(iedge in visited for iedge in state.in_edges(node)) or node == entry) and node != scope_exit: - # If we encounter a nested map, we must not analyze its contents (as they have already been recursively analyzed). - # Hence, we continue from the outgoing edges of the corresponding exit. - if isinstance(node, nd.EntryNode) and node != entry: - exit_node = state.exit_node(node) - # replace out_edges with the out_edges of the scope exit node - out_edges = state.out_edges(exit_node) - for oedge in out_edges: - # check for wcr - wcr_depth = sp.sympify(0) - if oedge.data.wcr is not None: - # This division gives us the number of writes to each single memory location, which is the depth - # as these need to be sequential (without assumptions on HW etc). - wcr_depth = oedge.data.volume / oedge.data.subset.num_elements() - if get_uuid(node, state) in wcr_depth_map: - # max - wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], - wcr_depth) - else: - wcr_depth_map[get_uuid(node, state)] = wcr_depth - # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N. - wcr_depth = wcr_depth if not isinstance(oedge.dst, nd.MapExit) else sp.sympify(0) - - # only append if it's actually new information - # this e.g. helps for huge nested SDFGs with lots of inputs/outputs inside a map scope - append = True - for n, d, _ in traversal_q: - if oedge.dst == n and depth_map[node] + wcr_depth == d: - append = False - break - if append: - traversal_q.append((oedge.dst, depth_map[node] + wcr_depth, oedge)) - else: - visited.add(oedge) - if len(out_edges) == 0 or node == scope_exit: - # We have reached an end node --> update max_depth - max_depth = sp.Max(max_depth, depth_map[node]) - - for uuid in wcr_depth_map: - op_in_map[uuid] = (op_in_map[uuid][0], op_in_map[uuid][1] + wcr_depth_map[uuid]) - # summarise work / depth of the whole scope in the dictionary - scope_result = (work, max_depth) - op_in_map[get_uuid(state)] = scope_result - return scope_result - - -def state_op_in_OLD(state: SDFGState, op_in_map: Dict[str, sp.Expr], analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]: - """ - Analyze the work and depth of a state. - - :param state: The state to analyze. - :param op_in_map: The result will be saved to this map. 
- :param analyze_tasklet: Function used to analyze tasklet nodes. - :param symbols: A dictionary mapping local nested SDFG symbols to global symbols. - :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition - and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, - as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). - :param equality_subs: Substitution dict taking care of the equality assumptions. - :param subs1: First substitution dict for greater/lesser assumptions. - :return: A tuple containing the work and depth of the state. - """ - work, depth = scope_op_in_OLD(state, op_in_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, - None) - return work, depth - - -def analyze_sdfg(sdfg: SDFG, op_in_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], - detailed_analysis: bool) -> None: - """ - Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. - - :note: SDFGs should have split interstate edges. This means there should be no interstate edges containing both a - condition and an assignment. - :param sdfg: The SDFG to analyze. - :param op_in_map: Dictionary of SDFG elements to (work, depth) tuples. Result will be saved in here. - :param analyze_tasklet: Function used to analyze tasklet nodes. Analyzes either just work, work and depth or average parallelism. - :param assumptions: List of strings. Each string corresponds to one assumption for some symbol, e.g. 'N>5'. - :param detailed_analysis: If True, detailed analysis gets used. For each branch, we keep track of its condition - and work depth values for both branches. If False, the worst-case branch is taken. Discouraged to use on bigger SDFGs, - as computation time sky-rockets, since expression can became HUGE (depending on number of branches etc.). - """ - - # deepcopy such that original sdfg not changed - sdfg = deepcopy(sdfg) - - # apply SSA pass - pipeline = FixedPointPipeline([StrictSymbolSSA()]) - pipeline.apply_pass(sdfg, {}) - - array_symbols = get_array_size_symbols(sdfg) - # parse assumptions - equality_subs, all_subs = parse_assumptions(assumptions if assumptions is not None else [], array_symbols) - - # Run state propagation for all SDFGs recursively. This is necessary to determine the number of times each state - # will be executed, or to determine upper bounds for that number (such as in the case of branching) - for sd in sdfg.all_sdfgs_recursive(): - propagation.propagate_states(sd, concretize_dynamic_unbounded=True) - - # Analyze the work and depth of the SDFG. - symbols = {} - sdfg_op_in_OLD(sdfg, op_in_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, - all_subs[0][0] if len(all_subs) > 0 else {}) - - for k, (v_w, v_d) in op_in_map.items(): - # The symeval replaces nested SDFG symbols with their global counterparts. - v_w, v_d = do_subs(v_w, v_d, all_subs) - v_w = symeval(v_w, symbols) - v_d = symeval(v_d, symbols) - op_in_map[k] = (v_w, v_d) - - -def do_subs(work, depth, all_subs): - """ - Handles all substitutions beyond the equality substitutions and the first substitution. - :param work: Some work expression. - :param depth: Some depth expression. - :param all_subs: List of substitution pairs to perform. - :return: Work depth expressions after doing all substitutions. 
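The (subs1, subs2) pairs applied here encode each assumption in two steps; renaming through a temporary symbol first keeps one substitution from feeding into the next within the same pair. A minimal sympy sketch of that mechanism, with a made-up pair standing in for what parse_assumptions would produce for an assumption like M == N:

import sympy as sp

N, M, tmp = sp.symbols('N M tmp', positive=True)
work = N * M + N

# Hypothetical two-step encoding of the assumption M == N.
subs1 = {M: tmp}    # applied by the initial substitution step
subs2 = {tmp: N}    # applied afterwards, as in do_subs

work = sp.simplify(work.subs(subs1))
work = sp.simplify(work.subs(subs2))
print(work)         # N**2 + N (sympy may print it factored as N*(N + 1))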
- """ - # first do subs2 of first sub - # then do all the remaining subs - subs2 = all_subs[0][1] if len(all_subs) > 0 else {} - work, depth = sp.simplify(sp.sympify(work).subs(subs2)), sp.simplify(sp.sympify(depth).subs(subs2)) - for i in range(1, len(all_subs)): - subs1, subs2 = all_subs[i] - work, depth = sp.simplify(work.subs(subs1)), sp.simplify(depth.subs(subs1)) - work, depth = sp.simplify(work.subs(subs2)), sp.simplify(depth.subs(subs2)) - return work, depth - - - - - - - - - - - - -def update_mapping(map, mapping): - # update the map params and return False - # if all iterations exhausted, return True - # always increase the last one, if it is exhausted, increase the next one and so forth - map_exhausted = True - for p, range in zip(map.params[::-1], map.range[::-1]): - curr_value = mapping[p] - if curr_value < range[1]: - # update this value and we done - mapping[p] = curr_value + range[2] - map_exhausted = False - break - else: - # set current param to start again and continue - mapping[p] = range[0] - return map_exhausted - - - -def map_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], entry, mapping, stack, clt, C): - # we are inside a map --> we need to iterate over the map range and check each memory access. - for p, range in zip(entry.map.params, entry.map.range): - # map each map iteration variable to its start - mapping[p] = range[0] - map_misses = 0 - while True: - # do analysis of map contents - map_misses += scope_op_in(state, op_in_map, mapping, stack, clt, C, entry) - - if update_mapping(entry.map, mapping): - break - return map_misses - - -def scope_op_in(state: SDFGState, op_in_map: Dict[str, sp.Expr], mapping, stack: AccessStack, clt: CacheLineTracker, C, entry=None): - # find the work and depth of each node - # for maps and nested SDFG, we do it recursively - scope_misses = 0 - scope_nodes = state.scope_children()[entry] - for node in scope_nodes: - # add node to map - op_in_map[get_uuid(node, state)] = 0 - if isinstance(node, nd.EntryNode): - # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. - # The resulting work/depth are summarized into the entry node - map_misses = map_op_in(state, op_in_map, node, mapping, stack, clt, C) - # add up work for whole state, but also save work for this sub-scope scope in op_in_map - op_in_map[get_uuid(node, state)] = map_misses - scope_misses += map_misses - elif isinstance(node, nd.Tasklet): - # add up work for whole state, but also save work for this node in op_in_map - tasklet_misses = 0 - # analyze the memory accesses of this tasklet and whether they hit in cache or not - for e in state.in_edges(node): - if e.data.data in clt.array_info: - line_id = clt.cache_line_id(e.data.data, [x[0].subs(mapping) for x in e.data.subset.ranges]) - dist = stack.touch(line_id) - tasklet_misses += 1 if dist > C or dist == -1 else 0 - for e in state.out_edges(node): - if e.data.data in clt.array_info: - line_id = clt.cache_line_id(e.data.data, [x[0].subs(mapping) for x in e.data.subset.ranges]) - dist = stack.touch(line_id) - tasklet_misses += 1 if dist > C or dist == -1 else 0 - - # TODO: wcr edges. - scope_misses += tasklet_misses - op_in_map[get_uuid(node, state)] = tasklet_misses - elif isinstance(node, nd.NestedSDFG): - pass - # keep track of nested symbols: "symbols" maps local nested SDFG symbols to global symbols. - # We only want global symbols in our final work depth expressions. 
- # nested_syms = {} - # nested_syms.update(symbols) - # nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) - # Nested SDFGs are recursively analyzed first. - nsdfg_misses = sdfg_op_in(node.sdfg, op_in_map, mapping, stack, clt, C) - - # nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) - # add up work for whole state, but also save work for this nested SDFG in op_in_map - scope_misses += nsdfg_misses - op_in_map[get_uuid(node, state)] = nsdfg_misses - elif isinstance(node, nd.LibraryNode): - pass - # try: - # lib_node_work = LIBNODES_TO_WORK[type(node)](node, symbols, state) - # except KeyError: - # # add a symbol to the top level sdfg, such that the user can define it in the extension - # top_level_sdfg = state.parent - # # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, - # # such that the user can define its value. But it doesn't... - # # How to achieve this? - # top_level_sdfg.add_symbol(f'{node.name}_work', int64) - # lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) - # lib_node_depth = sp.sympify(-1) # not analyzed - # if analyze_tasklet != get_tasklet_work: - # # we are analyzing depth - # try: - # lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) - # except KeyError: - # top_level_sdfg = state.parent - # top_level_sdfg.add_symbol(f'{node.name}_depth', int64) - # lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) - # lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) - # work += lib_node_work - # op_in_map[get_uuid(node, state)] = (lib_node_work, lib_node_depth) - op_in_map[get_uuid(state)] = scope_misses - return scope_misses - -def sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, Tuple[sp.Expr, sp.Expr]], mapping, stack, clt, C): - # traverse this SDFG's states - curr_state = sdfg.start_state - total_misses = 0 - while True: - total_misses += scope_op_in(curr_state, op_in_map, mapping, stack, clt, C) - - if len(sdfg.out_edges(curr_state)) == 0: - # we reached the end state --> stop - break - else: - # take first edge with True condition - found = False - for e in sdfg.out_edges(curr_state): - if e.data.is_unconditional() or e.data.condition_sympy().subs(mapping) == True: - # save e's assignments in mapping and update curr_state - # replace values first with mapping, then update mapping - mapping.update({k: sp.sympify(v).subs(mapping) for k, v in e.data.assignments.items()}) - curr_state = e.dst - found = True - break - if not found: - print('WARNING: state has outgoing edges, but no condition of them can be' - 'evaluated as True and hence the analysis ends. ') - break - # traverse further - op_in_map[get_uuid(sdfg)] = total_misses - -def analyze_sdfg_op_in(sdfg: SDFG, op_in_map: Dict[str, sp.Expr], C, L): - # do some top level stuff - stack = AccessStack() - clt = CacheLineTracker(L) - for _, name, arr in sdfg.arrays_recursive(): - if isinstance(arr, Array): - if name in clt.array_info: - print('WARNING: This array name was already seen!!! 
Two arrays with the same name in the SDFG.') - clt.add_array(name, arr) - mapping = {} - - sdfg_op_in(sdfg, op_in_map, mapping, stack, clt, C) - - # now we have number of misses --> multiply each by L - for k, v in op_in_map.items(): - op_in_map[k] = v * L - - - - - - -################################################################################ -# Utility functions for running the analysis from the command line ############# -################################################################################ - - -def main() -> None: - - parser = argparse.ArgumentParser('operational_intensity', - usage='python operational_intensity.py [-h] filename', - description='Analyze the operational_intensity of an SDFG.') - - parser.add_argument('filename', type=str, help='The SDFG file to analyze.') - parser.add_argument('C', type=str, help='Cache size in bytes') - parser.add_argument('L', type=str, help='Cache line size in bytes') - - args = parser.parse_args() - - if not os.path.exists(args.filename): - print(args.filename, 'does not exist.') - exit() - - sdfg = SDFG.from_file(args.filename) - op_in_map = {} - analyze_sdfg_op_in(sdfg, op_in_map, int(args.C), int(args.L)) - - for k, v, in op_in_map.items(): - op_in_map[k] = str(sp.simplify(v)) - - result_whole_sdfg = op_in_map[get_uuid(sdfg)] - - print(80 * '-') - print("Bytes transferred:\t", result_whole_sdfg) - print(80 * '-') - - -if __name__ == '__main__': - main() - - - - - - From ef791242498632ff437f4ee10e940f57cee79dca Mon Sep 17 00:00:00 2001 From: Cliff Hodel Date: Thu, 21 Sep 2023 10:23:20 +0200 Subject: [PATCH 18/18] changes according to comments --- dace/sdfg/work_depth_analysis/assumptions.py | 2 +- dace/sdfg/work_depth_analysis/work_depth.py | 74 ++++++++++++-------- 2 files changed, 45 insertions(+), 31 deletions(-) diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py index c7e439cf51..6e311cde0c 100644 --- a/dace/sdfg/work_depth_analysis/assumptions.py +++ b/dace/sdfg/work_depth_analysis/assumptions.py @@ -1,7 +1,7 @@ # Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved. import sympy as sp -from typing import Tuple, Dict +from typing import Dict class UnionFind: diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py index b05ccc70ae..3549e86a20 100644 --- a/dace/sdfg/work_depth_analysis/work_depth.py +++ b/dace/sdfg/work_depth_analysis/work_depth.py @@ -5,7 +5,7 @@ import argparse from collections import deque from dace.sdfg import nodes as nd, propagation, InterstateEdge -from dace import SDFG, SDFGState, dtypes, int64 +from dace import SDFG, SDFGState, dtypes from dace.subsets import Range from typing import Tuple, Dict import os @@ -251,7 +251,7 @@ def tasklet_work(tasklet_node, state): def tasklet_depth(tasklet_node, state): if tasklet_node.code.language == dtypes.Language.CPP: - # For now we simply take depth == work for CPP tasklets. + # Depth == work for CPP tasklets. 
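The Python branch of tasklet_depth, continued just below, walks the tasklet's AST; count_arithmetic_ops_code, used earlier for WCR edges, is built around the same idea. As a rough stand-in for such a counter (not the implementation used in this patch), counting binary, unary and augmented-assignment operations:

import ast
import sympy as sp

class ArithmeticOpCounter(ast.NodeVisitor):
    """Counts arithmetic operation nodes in a Python code snippet."""

    def __init__(self):
        self.count = 0

    def visit_BinOp(self, node):        # a + b, a * b, ...
        self.count += 1
        self.generic_visit(node)

    def visit_UnaryOp(self, node):      # -a, ~a, ...
        self.count += 1
        self.generic_visit(node)

    def visit_AugAssign(self, node):    # a += b
        self.count += 1
        self.generic_visit(node)

def count_ops(code: str) -> sp.Expr:
    counter = ArithmeticOpCounter()
    counter.visit(ast.parse(code))
    return sp.sympify(counter.count)

print(count_ops('out = a * b + c'))     # -> 2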
for oedge in state.out_edges(tasklet_node): return oedge.data.num_accesses if tasklet_node.code.language == dtypes.Language.Python: @@ -292,9 +292,13 @@ def do_initial_subs(w, d, eq, subs1): return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1)) -def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols: Dict[str, str], - detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]: +def sdfg_work_depth(sdfg: SDFG, + w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + detailed_analysis: bool = False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a given SDFG. First we determine the work and depth of each state. Then we break loops in the state machine, such that we get a DAG. @@ -318,8 +322,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana state_depths: Dict[SDFGState, sp.Expr] = {} state_works: Dict[SDFGState, sp.Expr] = {} for state in sdfg.nodes(): - state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1) + state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, + detailed_analysis) # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now. state_work = sp.simplify(state_work * @@ -475,14 +479,16 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana return sdfg_result -def scope_work_depth(state: SDFGState, - w_d_map: Dict[str, sp.Expr], - analyze_tasklet, - symbols: Dict[str, str], - detailed_analysis: bool, - equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], - subs1: Dict[str, sp.Expr], - entry: nd.EntryNode = None) -> Tuple[sp.Expr, sp.Expr]: +def scope_work_depth( + state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols: Dict[str, str], + equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]], + subs1: Dict[str, sp.Expr], + entry: nd.EntryNode = None, + detailed_analysis: bool = False, +) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a scope. This works by traversing through the scope analyzing the work and depth of each encountered node. @@ -519,8 +525,8 @@ def scope_work_depth(state: SDFGState, if isinstance(node, nd.EntryNode): # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first. # The resulting work/depth are summarized into the entry node - s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1, node) + s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, node, + detailed_analysis) s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1) # add up work for whole state, but also save work for this sub-scope scope in w_d_map work += s_work @@ -545,8 +551,8 @@ def scope_work_depth(state: SDFGState, nested_syms.update(symbols) nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping)) # Nested SDFGs are recursively analyzed first. 
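Once back edges are removed, the deepest path through the resulting DAG can be found in a single pass in topological order, each node adding its own depth to the maximum over its predecessors. A self-contained sketch on a made-up four-state machine (the graph and per-state depths are illustrative only):

import sympy as sp
import networkx as nx

N = sp.Symbol('N', positive=True)

# Hypothetical state machine after loop back edges were removed.
g = nx.DiGraph([('init', 'guard'), ('guard', 'body'), ('guard', 'exit'), ('body', 'exit')])
state_depth = {'init': sp.Integer(1), 'guard': sp.Integer(0), 'body': N, 'exit': sp.Integer(1)}

longest = {}
for node in nx.topological_sort(g):
    incoming = [longest[p] for p in g.predecessors(node)]
    longest[node] = state_depth[node] + (sp.Max(*incoming) if incoming else sp.Integer(0))

print(sp.simplify(longest['exit']))   # -> N + 2: the deepest path runs through 'body'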
- nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, - detailed_analysis, equality_subs, subs1) + nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, equality_subs, + subs1, detailed_analysis) nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1) # add up work for whole state, but also save work for this nested SDFG in w_d_map @@ -561,7 +567,7 @@ def scope_work_depth(state: SDFGState, # TODO: This symbol should now appear in the VS code extension in the SDFG analysis tab, # such that the user can define its value. But it doesn't... # How to achieve this? - top_level_sdfg.add_symbol(f'{node.name}_work', int64) + top_level_sdfg.add_symbol(f'{node.name}_work', dtypes.int64) lib_node_work = sp.Symbol(f'{node.name}_work', positive=True) lib_node_depth = sp.sympify(-1) # not analyzed if analyze_tasklet != get_tasklet_work: @@ -570,7 +576,7 @@ def scope_work_depth(state: SDFGState, lib_node_depth = LIBNODES_TO_DEPTH[type(node)](node, symbols, state) except KeyError: top_level_sdfg = state.parent - top_level_sdfg.add_symbol(f'{node.name}_depth', int64) + top_level_sdfg.add_symbol(f'{node.name}_depth', dtypes.int64) lib_node_depth = sp.Symbol(f'{node.name}_depth', positive=True) lib_node_work, lib_node_depth = do_initial_subs(lib_node_work, lib_node_depth, equality_subs, subs1) work += lib_node_work @@ -581,7 +587,7 @@ def scope_work_depth(state: SDFGState, if isinstance(entry, nd.MapEntry): nmap: nd.Map = entry.map range: Range = nmap.range - n_exec = range.num_elements_exact() + n_exec = range.num_elements() work = sp.simplify(work * n_exec.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1)) else: print('WARNING: Only Map scopes are supported in work analysis for now. Assuming 1 iteration.') @@ -669,8 +675,13 @@ def scope_work_depth(state: SDFGState, return scope_result -def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, detailed_analysis, - equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]: +def state_work_depth(state: SDFGState, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + symbols, + equality_subs, + subs1, + detailed_analysis=False) -> Tuple[sp.Expr, sp.Expr]: """ Analyze the work and depth of a state. @@ -685,13 +696,16 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task :param subs1: First substitution dict for greater/lesser assumptions. :return: A tuple containing the work and depth of the state. """ - work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, - None) + work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, equality_subs, subs1, None, + detailed_analysis) return work, depth -def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], - detailed_analysis: bool) -> None: +def analyze_sdfg(sdfg: SDFG, + w_d_map: Dict[str, sp.Expr], + analyze_tasklet, + assumptions: [str], + detailed_analysis: bool = False) -> None: """ Analyze a given SDFG. We can either analyze work, work and depth or average parallelism. @@ -724,8 +738,8 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assum # Analyze the work and depth of the SDFG. 
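A hypothetical end-to-end driver for this entry point could look as follows; the axpy program is only an example, and get_uuid is assumed to be importable from the helpers module of this package:

import dace
import sympy as sp
from dace.sdfg.work_depth_analysis.work_depth import analyze_sdfg, get_tasklet_work_depth
from dace.sdfg.work_depth_analysis.helpers import get_uuid

N = dace.symbol('N')

@dace.program
def axpy(a: dace.float64, x: dace.float64[N], y: dace.float64[N]):
    y[:] = a * x + y

sdfg = axpy.to_sdfg()
w_d_map = {}
analyze_sdfg(sdfg, w_d_map, get_tasklet_work_depth, assumptions=['N>1'], detailed_analysis=False)

work, depth = w_d_map[get_uuid(sdfg)]
print('Work: ', sp.simplify(work))
print('Depth:', sp.simplify(depth))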
symbols = {} - sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, - all_subs[0][0] if len(all_subs) > 0 else {}) + sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {}, + detailed_analysis) for k, (v_w, v_d) in w_d_map.items(): # The symeval replaces nested SDFG symbols with their global counterparts.
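symeval, used here to translate local nested-SDFG symbols into their global counterparts, is not shown in this excerpt. Conceptually it has to rename local symbols to global expressions without one rename feeding into another; the rough stand-in below does this with a two-step rename through placeholder symbols:

import sympy as sp

def symeval_sketch(val, symbols):
    # symbols maps local (nested-SDFG) symbol names to global expressions,
    # e.g. {'k': 'N - 1'}.
    first = {sp.Symbol(s): sp.Symbol('__REPL_' + s) for s in symbols}
    second = {sp.Symbol('__REPL_' + s): sp.sympify(v) for s, v in symbols.items()}
    return sp.sympify(val).subs(first).subs(second)

work = sp.sympify('M*k + k')
print(symeval_sketch(work, {'k': 'N - 1', 'M': 'N'}))   # -> N*(N - 1) + N - 1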