diff --git a/dace/sdfg/work_depth_analysis/assumptions.py b/dace/sdfg/work_depth_analysis/assumptions.py
index 1f167f15a3..c7e439cf51 100644
--- a/dace/sdfg/work_depth_analysis/assumptions.py
+++ b/dace/sdfg/work_depth_analysis/assumptions.py
@@ -8,16 +8,16 @@ class UnionFind:
     """
    Simple, not really optimized UnionFind implementation.
    """
-    
+
     def __init__(self, elements) -> None:
-        self.ids = {e : e for e in elements}
+        self.ids = {e: e for e in elements}
 
     def add_element(self, e):
         if e in self.ids:
             return False
-        self.ids.update({e : e})
+        self.ids.update({e: e})
         return True
-    
+
     def find(self, e):
         prev = e
         curr = self.ids[e]
@@ -27,16 +27,17 @@ def find(self, e):
         # shorten the path
         self.ids[e] = curr
         return curr
-    
+
     def union(self, e, f):
         if f not in self.ids:
             self.add_element(f)
         self.ids[self.find(e)] = f
-    
+
 
 class ContradictingAssumptions(Exception):
     pass
 
+
 class Assumptions:
     """ Summarises the assumptions for a single symbol in three lists: equal, greater, lesser. """
 
@@ -67,7 +68,7 @@ def add_lesser(self, l):
 
     def add_equal(self, e):
         for x in self.equal:
-            if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e: 
+            if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e:
                 raise ContradictingAssumptions()
         self.equal.append(e)
         self.check_consistency()
@@ -89,11 +90,12 @@ def check_consistency(self):
                     if (g > l) == True:
                         raise ContradictingAssumptions()
         return True
-    
+
     def num_assumptions(self):
         # returns the number of individual assumptions for this symbol
         return len(self.greater) + len(self.lesser) + len(self.equal)
-    
+
+
 def propagate_assumptions(x, y, condensed_assumptions):
     """
     Assuming x is equal to y, we propagate the assumptions on x to y. E.g. we have x==y and
@@ -118,6 +120,7 @@ def propagate_assumptions(x, y, condensed_assumptions):
             assum_y.add_lesser(l)
         assum_y.check_consistency()
 
+
 def propagate_assumptions_equal_symbols(condensed_assumptions):
     """
     This method handles two things:
     1) It generates the substitution dict for all equality assumptions.
@@ -139,7 +142,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions):
                 uf.union(sym, other.name)
 
     equality_subs1 = {}
-    
+
     # For each equivalence class, we now have one unique identifier.
     # For each class, we give all the assumptions to this single symbol.
     # And we swap each symbol in class for this symbol.
@@ -148,7 +151,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions):
             if isinstance(other, sp.Symbol):
                 propagate_assumptions(sym, uf.find(sym), condensed_assumptions)
                 equality_subs1.update({sym: sp.Symbol(uf.find(sym))})
-    
+
     equality_subs2 = {}
     # In a second step, each symbol gets replaced with its equal number (if present)
     # using equality_subs2.
@@ -213,7 +216,7 @@ def parse_assumptions(assumptions, array_symbols):
 
     if assumptions is None:
         return {}, [({}, {})]
-    
+
     # Gather assumptions, keeping only the strongest ones for each symbol.
     condensed_assumptions: Dict[str, Assumptions] = {}
     for a in assumptions:
@@ -252,13 +255,13 @@
 
     # Handle equal assumptions.
     equality_subs = propagate_assumptions_equal_symbols(condensed_assumptions)
-    
+
     # How many assumptions does symbol with most assumptions have?
     curr_max = -1
     for _, assum in condensed_assumptions.items():
         if assum.num_assumptions() > curr_max:
             curr_max = assum.num_assumptions()
-    
+
     all_subs = []
     for i in range(curr_max):
         all_subs.append(({}, {}))
@@ -271,7 +274,7 @@ def parse_assumptions(assumptions, array_symbols):
         for g in assum.greater:
             replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True)
             all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g})
-            all_subs[i][1].update({replacement_symbol : sp.Symbol(sym) - g})
+            all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g})
             i += 1
         for l in assum.lesser:
             replacement_symbol = sp.Symbol(f'_n_{sym}', negative=True, integer=True)
@@ -279,4 +282,4 @@ def parse_assumptions(assumptions, array_symbols):
             all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - l})
             i += 1
 
-    return equality_subs, all_subs
\ No newline at end of file
+    return equality_subs, all_subs
diff --git a/dace/sdfg/work_depth_analysis/work_depth.py b/dace/sdfg/work_depth_analysis/work_depth.py
index da700bd829..21e5b937b9 100644
--- a/dace/sdfg/work_depth_analysis/work_depth.py
+++ b/dace/sdfg/work_depth_analysis/work_depth.py
@@ -76,6 +76,7 @@ def count_work_matmul(node, symbols, state):
     result *= symeval(A_memlet.data.subset.size()[-1], symbols)
     return sp.sympify(result)
 
+
 def count_depth_matmul(node, symbols, state):
     # optimal depth of a matrix multiplication is O(log(size of shared dimension)):
     A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a')
@@ -97,6 +98,7 @@ def count_work_reduce(node, symbols, state):
         result = 0
     return sp.sympify(result)
 
+
 def count_depth_reduce(node, symbols, state):
     # optimal depth of reduction is log of the work
     return bigo(sp.log(count_work_reduce(node, symbols, state)))
@@ -272,6 +274,7 @@ def get_tasklet_work_depth(node, state):
 def get_tasklet_avg_par(node, state):
     return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state))
 
+
 def update_value_map(old, new):
     # add new assignments to old
     old.update({k: v for k, v in new.items() if k not in old})
@@ -281,6 +284,7 @@ def update_value_map(old, new):
             # conflict detected --> forget this mapping completely
             old.pop(k)
 
+
 def do_initial_subs(w, d, eq, subs1):
     """
     Calls subs three times for the given (w)ork and (d)epth values.
@@ -287,9 +291,9 @@ def do_initial_subs(w, d, eq, subs1):
     """
     return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1))
 
 
-def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet,
-                    symbols: Dict[str, str], detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]],
+def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols: Dict[str, str],
+                    detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]],
                     subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]:
     """
     Analyze the work and depth of a given SDFG.
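Note: the hunks above implement the assumption machinery that `work_depth.py` consumes. A one-sided bound such as `x > 5` is encoded by `parse_assumptions` as a pair of substitution dicts: a forward dict that rewrites `x` in terms of a sign-restricted helper symbol so sympy can resolve `Max`/`Min` terms, and a backward dict that maps the helper back to `x`; `do_initial_subs` then applies them to the work and depth expressions. A minimal sketch of the idea in plain sympy (illustration only, not part of the patch; the helper name mirrors the `f'_p_{sym}'` pattern above):

```python
import sympy as sp

x = sp.Symbol('x')
# Helper symbol declared positive, mirroring parse_assumptions' f'_p_{sym}'.
_p_x = sp.Symbol('_p_x', positive=True, integer=True)

forward = {x: _p_x + 5}   # plays the role of all_subs[i][0] for the assumption x > 5
backward = {_p_x: x - 5}  # plays the role of all_subs[i][1]

expr = sp.Max(x, 5)
simplified = expr.subs(forward)   # Max(_p_x + 5, 5) collapses to _p_x + 5 because _p_x > 0
print(simplified.subs(backward))  # prints: x
```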
@@ -314,11 +318,14 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana
     state_depths: Dict[SDFGState, sp.Expr] = {}
     state_works: Dict[SDFGState, sp.Expr] = {}
     for state in sdfg.nodes():
-        state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1)
-        
+        state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis,
+                                                   equality_subs, subs1)
+
         # Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now.
-        state_work = sp.simplify(state_work * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))
-        state_depth = sp.simplify(state_depth * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))
+        state_work = sp.simplify(state_work *
+                                 state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))
+        state_depth = sp.simplify(state_depth *
+                                  state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))
         state_works[state], state_depths[state] = state_work, state_depth
         w_d_map[get_uuid(state)] = (state_works[state], state_depths[state])
 
@@ -376,7 +383,7 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana
 
         if ie is not None:
             visited.add(ie)
-    
+
         if state in state_value_map:
             # update value map:
             update_value_map(state_value_map[state], value_map)
@@ -405,8 +412,10 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana
                     old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state])
                     new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth)
                     # we take either old work/depth or new work/depth (or both if we cannot determine which one is greater)
-                    depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), (depth_map[state], True))
-                    work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), (work_map[state], True))
+                    depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)),
+                                                             (depth_map[state], True))
+                    work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)),
+                                                            (work_map[state], True))
                 else:
                     depth_map[state] = n_depth
                     work_map[state] = n_work
@@ -451,7 +460,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana
                     traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map))
                 else:
                     value_map.update(oedge.data.assignments)
-                    traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, common_subexpr_stack, value_map))
+                    traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack,
+                                        common_subexpr_stack, value_map))
 
     try:
         max_depth = depth_map[dummy_exit]
@@ -511,7 +521,8 @@ def scope_work_depth(state: SDFGState,
         if isinstance(node, nd.EntryNode):
             # If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first.
             # The resulting work/depth are summarized into the entry node
-            s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, node)
+            s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis,
+                                               equality_subs, subs1, node)
             s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1)
             # add up work for whole state, but also save work for this sub-scope in w_d_map
             work += s_work
@@ -536,7 +547,8 @@ def scope_work_depth(state: SDFGState,
             nested_syms.update(symbols)
             nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping))
             # Nested SDFGs are recursively analyzed first.
-            nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, detailed_analysis, equality_subs, subs1)
+            nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms,
+                                                      detailed_analysis, equality_subs, subs1)
 
             nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1)
             # add up work for whole state, but also save work for this nested SDFG in w_d_map
@@ -629,7 +641,8 @@ def scope_work_depth(state: SDFGState,
                         wcr_depth = oedge.data.volume / oedge.data.subset.num_elements()
                         if get_uuid(node, state) in wcr_depth_map:
                             # max
-                            wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], wcr_depth)
+                            wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)],
+                                                                          wcr_depth)
                         else:
                             wcr_depth_map[get_uuid(node, state)] = wcr_depth
             # We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N.
@@ -649,7 +662,7 @@ def scope_work_depth(state: SDFGState,
         if len(out_edges) == 0 or node == scope_exit:
             # We have reached an end node --> update max_depth
             max_depth = sp.Max(max_depth, depth_map[node])
-    
+
     for uuid in wcr_depth_map:
         w_d_map[uuid] = (w_d_map[uuid][0], w_d_map[uuid][1] + wcr_depth_map[uuid])
     # summarise work / depth of the whole scope in the dictionary
@@ -658,8 +671,8 @@ def scope_work_depth(state: SDFGState,
     return scope_result
 
 
-def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet,
-                     symbols, detailed_analysis, equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]:
+def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, detailed_analysis,
+                     equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]:
     """
     Analyze the work and depth of a state.
 
@@ -674,11 +687,13 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task
     :param subs1: First substitution dict for greater/lesser assumptions.
     :return: A tuple containing the work and depth of the state.
     """
-    work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, None)
+    work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1,
+                                   None)
     return work, depth
 
 
-def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], detailed_analysis: bool) -> None:
+def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str],
+                 detailed_analysis: bool) -> None:
     """
     Analyze a given SDFG. We can either analyze work, work and depth or average parallelism.
 
@@ -711,8 +726,9 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assum
 
     # Analyze the work and depth of the SDFG.
     symbols = {}
-    sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {})
-    
+    sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs,
+                    all_subs[0][0] if len(all_subs) > 0 else {})
+
     for k, (v_w, v_d) in w_d_map.items():
         # The symeval replaces nested SDFG symbols with their global counterparts.
         v_w, v_d = do_subs(v_w, v_d, all_subs)
@@ -758,8 +774,7 @@ def main() -> None:
                         help='Choose what to analyze. Default: workDepth')
     parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. x>0 x>y y==5')
 
-    parser.add_argument("--detailed", action="store_true",
-                        help="Turns on detailed mode.")
+    parser.add_argument("--detailed", action="store_true", help="Turns on detailed mode.")
 
     args = parser.parse_args()
     if not os.path.exists(args.filename):
diff --git a/tests/sdfg/work_depth_tests.py b/tests/sdfg/work_depth_tests.py
index 924397aa1e..05375007df 100644
--- a/tests/sdfg/work_depth_tests.py
+++ b/tests/sdfg/work_depth_tests.py
@@ -156,17 +156,18 @@ def break_while_loop(x: dc.float64[N]):
             break
         x += 1
 
+
 @dc.program
-def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assume N, M to be positive
+def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]):  # --> cannot assume N, M to be positive
     if x[0] > 5:
-        x[:] += 1 # N+1 work, 1 depth
+        x[:] += 1  # N+1 work, 1 depth
     else:
         for i in range(M): # M work, M depth
-            y[i+1] += y[i]
+            y[i + 1] += y[i]
     if M > N:
-        y[:N+1] += x[:] # N+1 work, 1 depth
+        y[:N + 1] += x[:]  # N+1 work, 1 depth
     else:
-        x[:M+1] += y[:] # M+1 work, 1 depth
+        x[:M + 1] += y[:]  # M+1 work, 1 depth
     # --> Work: Max(N+1, M) + Max(N+1, M+1)
     #     Depth: Max(1, M) + 1
 
@@ -185,14 +186,12 @@ def sequntial_ifs(x: dc.float64[N + 1], y: dc.float64[M + 1]): # --> cannot assu
     (multiple_array_sizes, (sp.Max(2 * K, 3 * N, 2 * M + 3), 5)),
     (unbounded_while_do, (sp.Symbol('num_execs_0_2') * N, sp.Symbol('num_execs_0_2'))),
     # We get this Max(1, num_execs), since it is a do-while loop, but the num_execs symbol does not capture this.
-    (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N,
-                          sp.Max(1, sp.Symbol('num_execs_0_1')))),
-    (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N,
-                           2 * sp.Symbol('num_execs_0_7'))),
+    (unbounded_do_while, (sp.Max(1, sp.Symbol('num_execs_0_1')) * N, sp.Max(1, sp.Symbol('num_execs_0_1')))),
+    (unbounded_nonnegify, (2 * sp.Symbol('num_execs_0_7') * N, 2 * sp.Symbol('num_execs_0_7'))),
     (continue_for_loop, (sp.Symbol('num_execs_0_6') * N, sp.Symbol('num_execs_0_6'))),
     (break_for_loop, (N**2, N)),
     (break_while_loop, (sp.Symbol('num_execs_0_5') * N, sp.Symbol('num_execs_0_5'))),
-    (sequntial_ifs, (sp.Max(N+1, M) + sp.Max(N+1, M+1), sp.Max(1, M) + 1))
+    (sequntial_ifs, (sp.Max(N + 1, M) + sp.Max(N + 1, M + 1), sp.Max(1, M) + 1))
 ]
 
 
@@ -210,7 +209,10 @@ def test_work_depth():
         # We do this since sp.Symbol('N') == sp.Symbol('N', positive=True) --> False.
         reps = {s: sp.Symbol(s.name) for s in (res[0].free_symbols | res[1].free_symbols)}
         res = (res[0].subs(reps), res[1].subs(reps))
-        reps = {s: sp.Symbol(s.name) for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols)}
+        reps = {
+            s: sp.Symbol(s.name)
+            for s in (sp.sympify(correct[0]).free_symbols | sp.sympify(correct[1]).free_symbols)
+        }
         correct = (sp.sympify(correct[0]).subs(reps), sp.sympify(correct[1]).subs(reps))
         # check result
         assert correct == res
@@ -219,31 +221,24 @@
 
 x, y, z, a = sp.symbols('x y z a')
 
 # (expr, assumptions, result)
-assumptions_tests=[
-    (sp.Max(x, y), ['x>y'], x),
-    (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)),
-    (sp.Max(x, y), ['x==y'], y),
-    (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x,3)),
-    (sp.Max(x, 11) + sp.Max(x, 3), ['x<11', 'x>3'], 11 + x),
-    (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x),
-    (sp.Max(x, 11), ['x==y', 'x>11'], y),
+assumptions_tests = [
+    (sp.Max(x, y), ['x>y'], x), (sp.Max(x, y, z), ['x>y'], sp.Max(x, z)), (sp.Max(x, y), ['x==y'], y),
+    (sp.Max(x, 11) + sp.Max(x, 3), ['x<11'], 11 + sp.Max(x, 3)), (sp.Max(x, 11) + sp.Max(x, 3), ['x<11',
+                                                                                                 'x>3'], 11 + x),
+    (sp.Max(x, 11), ['x>5', 'x>3', 'x>11'], x), (sp.Max(x, 11), ['x==y', 'x>11'], y),
     (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'a<11', 'c>7'], x + 11),
-    (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18),
-    (sp.Max(x, y), ['y>x', 'y==1000'], 1000),
+    (sp.Max(x, 11) + sp.Max(a, 5), ['a==b', 'b==c', 'c==x', 'b==7'], 18), (sp.Max(x, y), ['y>x', 'y==1000'], 1000),
     (sp.Max(x, y), ['y<x'], x),
     (sp.Max(N, M), ['N>0', 'N<5', 'M>5'], M)
-    ]
+]
 
 # These assumptions should trigger the ContradictingAssumptions exception.
-tests_for_exception = [
-    ['x>10', 'x<9'],
-    ['x==y', 'x>10', 'y<9'],
-    ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'],
-    ['x==5', 'x<4']
-]
+tests_for_exception = [['x>10', 'x<9'], ['x==y', 'x>10', 'y<9'],
+                       ['a==b', 'b==c', 'c==d', 'd==e', 'e==f', 'x==y', 'y==z', 'z>b', 'x==5', 'd==100'],
+                       ['x==5', 'x<4']]
+
 
 def test_assumption_system():
     for expr, assums, res in assumptions_tests:
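Note: each `assumptions_tests` entry above is checked by parsing the assumptions once and then chaining the returned substitutions: first the equality pair, then each (forward, backward) pair for the greater/lesser bounds. A self-contained sketch of that flow, assuming the module path introduced by this patch (not a verbatim excerpt of the test):

```python
import sympy as sp
from dace.sdfg.work_depth_analysis.assumptions import parse_assumptions

x, y = sp.symbols('x y')
expr = sp.Max(x, y)

# equality_subs is a pair of dicts; all_subs is a list of (forward, backward) pairs.
equality_subs, all_subs = parse_assumptions(['x>y'], set())

expr = expr.subs(equality_subs[0]).subs(equality_subs[1])
for subs1, subs2 in all_subs:
    expr = expr.subs(subs1).subs(subs2)

print(expr)  # expected: x, matching the first assumptions_tests entry
```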
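Note: the equality entries, for example `['a==b', 'b==c', 'c==x', 'b==7']` yielding `18`, work by collapsing the symbols into one equivalence class (the UnionFind in `assumptions.py`) and then pinning that class to its known numeric value. Illustrated with plain sympy (the two dicts below are hypothetical stand-ins for `equality_subs1`/`equality_subs2`):

```python
import sympy as sp

a, b, c, x = sp.symbols('a b c x')
expr = sp.Max(x, 11) + sp.Max(a, 5)

# Step 1: one representative for the equivalence class {a, b, c, x}.
class_rep = {a: x, b: x, c: x}
# Step 2: the equality b == 7 pins the whole class to 7.
pin_value = {x: 7}

print(expr.subs(class_rep).subs(pin_value))  # Max(7, 11) + Max(7, 5) = 11 + 7 = 18
```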