Skip to content

Commit

Permalink
formatting with yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
Cliff Hodel committed Sep 4, 2023
1 parent 4091811 commit 12c2c73
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 66 deletions.
35 changes: 19 additions & 16 deletions dace/sdfg/work_depth_analysis/assumptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,16 +8,16 @@ class UnionFind:
"""
Simple, not really optimized UnionFind implementation.
"""

def __init__(self, elements) -> None:
self.ids = {e : e for e in elements}
self.ids = {e: e for e in elements}

def add_element(self, e):
if e in self.ids:
return False
self.ids.update({e : e})
self.ids.update({e: e})
return True

def find(self, e):
prev = e
curr = self.ids[e]
Expand All @@ -27,16 +27,17 @@ def find(self, e):
# shorten the path
self.ids[e] = curr
return curr

def union(self, e, f):
if f not in self.ids:
self.add_element(f)
self.ids[self.find(e)] = f


class ContradictingAssumptions(Exception):
pass


class Assumptions:
"""
Summarises the assumptions for a single symbol in three lists: equal, greater, lesser.
Expand Down Expand Up @@ -67,7 +68,7 @@ def add_lesser(self, l):

def add_equal(self, e):
for x in self.equal:
if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e:
if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e:
raise ContradictingAssumptions()
self.equal.append(e)
self.check_consistency()
Expand All @@ -89,11 +90,12 @@ def check_consistency(self):
if (g > l) == True:
raise ContradictingAssumptions()
return True

def num_assumptions(self):
# returns the number of individual assumptions for this symbol
return len(self.greater) + len(self.lesser) + len(self.equal)



def propagate_assumptions(x, y, condensed_assumptions):
"""
Assuming x is equal to y, we propagate the assumptions on x to y. E.g. we have x==y and
Expand All @@ -118,6 +120,7 @@ def propagate_assumptions(x, y, condensed_assumptions):
assum_y.add_lesser(l)
assum_y.check_consistency()


def propagate_assumptions_equal_symbols(condensed_assumptions):
"""
This method handles two things: 1) It generates the substitution dict for all equality assumptions.
Expand All @@ -139,7 +142,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions):
uf.union(sym, other.name)

equality_subs1 = {}

# For each equivalence class, we now have one unique identifier.
# For each class, we give all the assumptions to this single symbol.
# And we swap each symbol in class for this symbol.
Expand All @@ -148,7 +151,7 @@ def propagate_assumptions_equal_symbols(condensed_assumptions):
if isinstance(other, sp.Symbol):
propagate_assumptions(sym, uf.find(sym), condensed_assumptions)
equality_subs1.update({sym: sp.Symbol(uf.find(sym))})

equality_subs2 = {}
# In a second step, each symbol gets replace with its equal number (if present)
# using equality_subs2.
Expand Down Expand Up @@ -213,7 +216,7 @@ def parse_assumptions(assumptions, array_symbols):

if assumptions is None:
return {}, [({}, {})]

# Gather assumptions, keeping only the strongest ones for each symbol.
condensed_assumptions: Dict[str, Assumptions] = {}
for a in assumptions:
Expand Down Expand Up @@ -252,13 +255,13 @@ def parse_assumptions(assumptions, array_symbols):

# Handle equal assumptions.
equality_subs = propagate_assumptions_equal_symbols(condensed_assumptions)

# How many assumptions does symbol with most assumptions have?
curr_max = -1
for _, assum in condensed_assumptions.items():
if assum.num_assumptions() > curr_max:
curr_max = assum.num_assumptions()

all_subs = []
for i in range(curr_max):
all_subs.append(({}, {}))
Expand All @@ -271,12 +274,12 @@ def parse_assumptions(assumptions, array_symbols):
for g in assum.greater:
replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True)
all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g})
all_subs[i][1].update({replacement_symbol : sp.Symbol(sym) - g})
all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g})
i += 1
for l in assum.lesser:
replacement_symbol = sp.Symbol(f'_n_{sym}', negative=True, integer=True)
all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + l})
all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - l})
i += 1

return equality_subs, all_subs
return equality_subs, all_subs
59 changes: 37 additions & 22 deletions dace/sdfg/work_depth_analysis/work_depth.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def count_work_matmul(node, symbols, state):
result *= symeval(A_memlet.data.subset.size()[-1], symbols)
return sp.sympify(result)


def count_depth_matmul(node, symbols, state):
# optimal depth of a matrix multiplication is O(log(size of shared dimension)):
A_memlet = next(e for e in state.in_edges(node) if e.dst_conn == '_a')
Expand All @@ -97,6 +98,7 @@ def count_work_reduce(node, symbols, state):
result = 0
return sp.sympify(result)


def count_depth_reduce(node, symbols, state):
# optimal depth of reduction is log of the work
return bigo(sp.log(count_work_reduce(node, symbols, state)))
Expand Down Expand Up @@ -272,6 +274,7 @@ def get_tasklet_work_depth(node, state):
def get_tasklet_avg_par(node, state):
return sp.sympify(tasklet_work(node, state)), sp.sympify(tasklet_depth(node, state))


def update_value_map(old, new):
# add new assignments to old
old.update({k: v for k, v in new.items() if k not in old})
Expand All @@ -281,15 +284,16 @@ def update_value_map(old, new):
# conflict detected --> forget this mapping completely
old.pop(k)


def do_initial_subs(w, d, eq, subs1):
"""
Calls subs three times for the give (w)ork and (d)epth values.
"""
return sp.simplify(w.subs(eq[0]).subs(eq[1]).subs(subs1)), sp.simplify(d.subs(eq[0]).subs(eq[1]).subs(subs1))


def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet,
symbols: Dict[str, str], detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]],
def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], analyze_tasklet, symbols: Dict[str, str],
detailed_analysis: bool, equality_subs: Tuple[Dict[str, sp.Symbol], Dict[str, sp.Expr]],
subs1: Dict[str, sp.Expr]) -> Tuple[sp.Expr, sp.Expr]:
"""
Analyze the work and depth of a given SDFG.
Expand All @@ -314,11 +318,14 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana
state_depths: Dict[SDFGState, sp.Expr] = {}
state_works: Dict[SDFGState, sp.Expr] = {}
for state in sdfg.nodes():
state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1)

state_work, state_depth = state_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis,
equality_subs, subs1)

# Substitutions for state_work and state_depth already performed, but state.executions needs to be subs'd now.
state_work = sp.simplify(state_work * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))
state_depth = sp.simplify(state_depth * state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))
state_work = sp.simplify(state_work *
state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))
state_depth = sp.simplify(state_depth *
state.executions.subs(equality_subs[0]).subs(equality_subs[1]).subs(subs1))

state_works[state], state_depths[state] = state_work, state_depth
w_d_map[get_uuid(state)] = (state_works[state], state_depths[state])
Expand Down Expand Up @@ -376,7 +383,7 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana

if ie is not None:
visited.add(ie)

if state in state_value_map:
# update value map:
update_value_map(state_value_map[state], value_map)
Expand Down Expand Up @@ -405,8 +412,10 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana
old_avg_par = (cse[0] + work_map[state]) / (cse[1] + depth_map[state])
new_avg_par = (cse[0] + n_work) / (cse[1] + n_depth)
# we take either old work/depth or new work/depth (or both if we cannot determine which one is greater)
depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)), (depth_map[state], True))
work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)), (work_map[state], True))
depth_map[state] = cse[1] + sp.Piecewise((n_depth, sp.simplify(new_avg_par < old_avg_par)),
(depth_map[state], True))
work_map[state] = cse[0] + sp.Piecewise((n_work, sp.simplify(new_avg_par < old_avg_par)),
(work_map[state], True))
else:
depth_map[state] = n_depth
work_map[state] = n_work
Expand Down Expand Up @@ -451,7 +460,8 @@ def sdfg_work_depth(sdfg: SDFG, w_d_map: Dict[str, Tuple[sp.Expr, sp.Expr]], ana
traversal_q.append((oedge.dst, 0, 0, oedge, new_cond_stack, new_cse_stack, new_value_map))
else:
value_map.update(oedge.data.assignments)
traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack, common_subexpr_stack, value_map))
traversal_q.append((oedge.dst, depth_map[state], work_map[state], oedge, condition_stack,
common_subexpr_stack, value_map))

try:
max_depth = depth_map[dummy_exit]
Expand Down Expand Up @@ -511,7 +521,8 @@ def scope_work_depth(state: SDFGState,
if isinstance(node, nd.EntryNode):
# If the scope contains an entry node, we need to recursively analyze the sub-scope of the entry node first.
# The resulting work/depth are summarized into the entry node
s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, node)
s_work, s_depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis,
equality_subs, subs1, node)
s_work, s_depth = do_initial_subs(s_work, s_depth, equality_subs, subs1)
# add up work for whole state, but also save work for this sub-scope scope in w_d_map
work += s_work
Expand All @@ -536,7 +547,8 @@ def scope_work_depth(state: SDFGState,
nested_syms.update(symbols)
nested_syms.update(evaluate_symbols(symbols, node.symbol_mapping))
# Nested SDFGs are recursively analyzed first.
nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms, detailed_analysis, equality_subs, subs1)
nsdfg_work, nsdfg_depth = sdfg_work_depth(node.sdfg, w_d_map, analyze_tasklet, nested_syms,
detailed_analysis, equality_subs, subs1)

nsdfg_work, nsdfg_depth = do_initial_subs(nsdfg_work, nsdfg_depth, equality_subs, subs1)
# add up work for whole state, but also save work for this nested SDFG in w_d_map
Expand Down Expand Up @@ -629,7 +641,8 @@ def scope_work_depth(state: SDFGState,
wcr_depth = oedge.data.volume / oedge.data.subset.num_elements()
if get_uuid(node, state) in wcr_depth_map:
# max
wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)], wcr_depth)
wcr_depth_map[get_uuid(node, state)] = sp.Max(wcr_depth_map[get_uuid(node, state)],
wcr_depth)
else:
wcr_depth_map[get_uuid(node, state)] = wcr_depth
# We do not need to propagate the wcr_depth to MapExits, since else this will result in depth N + 1 for Maps of range N.
Expand All @@ -649,7 +662,7 @@ def scope_work_depth(state: SDFGState,
if len(out_edges) == 0 or node == scope_exit:
# We have reached an end node --> update max_depth
max_depth = sp.Max(max_depth, depth_map[node])

for uuid in wcr_depth_map:
w_d_map[uuid] = (w_d_map[uuid][0], w_d_map[uuid][1] + wcr_depth_map[uuid])
# summarise work / depth of the whole scope in the dictionary
Expand All @@ -658,8 +671,8 @@ def scope_work_depth(state: SDFGState,
return scope_result


def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet,
symbols, detailed_analysis, equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]:
def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_tasklet, symbols, detailed_analysis,
equality_subs, subs1) -> Tuple[sp.Expr, sp.Expr]:
"""
Analyze the work and depth of a state.
Expand All @@ -674,11 +687,13 @@ def state_work_depth(state: SDFGState, w_d_map: Dict[str, sp.Expr], analyze_task
:param subs1: First substitution dict for greater/lesser assumptions.
:return: A tuple containing the work and depth of the state.
"""
work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1, None)
work, depth = scope_work_depth(state, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, subs1,
None)
return work, depth


def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str], detailed_analysis: bool) -> None:
def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assumptions: [str],
detailed_analysis: bool) -> None:
"""
Analyze a given SDFG. We can either analyze work, work and depth or average parallelism.
Expand Down Expand Up @@ -711,8 +726,9 @@ def analyze_sdfg(sdfg: SDFG, w_d_map: Dict[str, sp.Expr], analyze_tasklet, assum

# Analyze the work and depth of the SDFG.
symbols = {}
sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs, all_subs[0][0] if len(all_subs) > 0 else {})

sdfg_work_depth(sdfg, w_d_map, analyze_tasklet, symbols, detailed_analysis, equality_subs,
all_subs[0][0] if len(all_subs) > 0 else {})

for k, (v_w, v_d) in w_d_map.items():
# The symeval replaces nested SDFG symbols with their global counterparts.
v_w, v_d = do_subs(v_w, v_d, all_subs)
Expand Down Expand Up @@ -758,8 +774,7 @@ def main() -> None:
help='Choose what to analyze. Default: workDepth')
parser.add_argument('--assume', nargs='*', help='Collect assumptions about symbols, e.g. x>0 x>y y==5')

parser.add_argument("--detailed", action="store_true",
help="Turns on detailed mode.")
parser.add_argument("--detailed", action="store_true", help="Turns on detailed mode.")
args = parser.parse_args()

if not os.path.exists(args.filename):
Expand Down
Loading

0 comments on commit 12c2c73

Please sign in to comment.