Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to work depth analysis #1363

Merged
merged 30 commits into from
Sep 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
34df15c
initial push of work_depth analysis script
Jul 4, 2023
e000247
adding tests to work_depth analysis
Jul 4, 2023
58e1906
rename work depth analysis
Jul 4, 2023
0b3fdea
todos added
Jul 11, 2023
05f8af3
Merge branch 'spcl:master' into master
hodelcl Jul 11, 2023
50cdd98
Merge branch 'spcl:master' into master
hodelcl Jul 18, 2023
7f3a997
code ready for PR
Jul 20, 2023
027d9de
yapf for formatting
Jul 20, 2023
cd68a09
Merge branch 'master' of github.com:hodelcl/dace
Jul 20, 2023
550622f
put tests into dace/tests/sdfg
Jul 21, 2023
9efad2a
Merge branch 'master' of github.com:hodelcl/dace
Jul 21, 2023
862aaeb
fixed import after merge
Jul 21, 2023
9ad130d
merged propgatate_states_symbolically into propagate_states
Jul 22, 2023
f48c4ee
fixed format issue in work_depth.py
Jul 22, 2023
e3f667e
Merge branch 'spcl:master' into master
hodelcl Jul 23, 2023
4b3f6a7
small bugfix
Jul 23, 2023
7b0b2eb
include wcr edges into analysis, improve LibraryNodes analysis
Jul 26, 2023
ca8f8ef
Merge branch 'master' into master
phschaad Aug 11, 2023
a8efd97
imporved work depth. wcr now analyses, performance improved, assumpti…
Sep 4, 2023
e5a2687
Merge branch 'spcl:master' into master
hodelcl Sep 4, 2023
4091811
Merge branch 'master' of github.com:hodelcl/dace
Sep 4, 2023
12c2c73
formatting with yapf
Sep 4, 2023
91583e7
minor changes
Sep 4, 2023
2a84e5a
Merge branch 'master' into master
phschaad Sep 5, 2023
7431e25
Merge branch 'master' into master
phschaad Sep 18, 2023
eb5a6f4
start of op_in analysis
Sep 18, 2023
3cee571
Merge branch 'master' of github.com:hodelcl/dace
Sep 18, 2023
55c1598
Revert "start of op_in analysis"
Sep 18, 2023
ef79124
changes according to comments
Sep 21, 2023
f29ddb0
Merge branch 'master' into master
phschaad Sep 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
285 changes: 285 additions & 0 deletions dace/sdfg/work_depth_analysis/assumptions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,285 @@
# Copyright 2019-2023 ETH Zurich and the DaCe authors. All rights reserved.

import sympy as sp
from typing import Dict


class UnionFind:
"""
Simple, not really optimized UnionFind implementation.
"""

def __init__(self, elements) -> None:
self.ids = {e: e for e in elements}

def add_element(self, e):
if e in self.ids:
return False
self.ids.update({e: e})
return True

def find(self, e):
prev = e
curr = self.ids[e]
while prev != curr:
prev = curr
curr = self.ids[curr]
# shorten the path
self.ids[e] = curr
return curr

def union(self, e, f):
if f not in self.ids:
self.add_element(f)
self.ids[self.find(e)] = f


class ContradictingAssumptions(Exception):
pass


class Assumptions:
"""
Summarises the assumptions for a single symbol in three lists: equal, greater, lesser.
"""

def __init__(self) -> None:
self.greater = []
self.lesser = []
self.equal = []

def add_greater(self, g):
if isinstance(g, sp.Symbol):
self.greater.append(g)
else:
self.greater = [x for x in self.greater if isinstance(x, sp.Symbol) or x > g]
if len([y for y in self.greater if not isinstance(y, sp.Symbol)]) == 0:
self.greater.append(g)
self.check_consistency()

def add_lesser(self, l):
if isinstance(l, sp.Symbol):
self.lesser.append(l)
else:
self.lesser = [x for x in self.lesser if isinstance(x, sp.Symbol) or x < l]
if len([y for y in self.lesser if not isinstance(y, sp.Symbol)]) == 0:
self.lesser.append(l)
self.check_consistency()

def add_equal(self, e):
for x in self.equal:
if not (isinstance(x, sp.Symbol) or isinstance(e, sp.Symbol)) and x != e:
raise ContradictingAssumptions()
self.equal.append(e)
self.check_consistency()

def check_consistency(self):
if len(self.equal) > 0:
# we know exact value
for e in self.equal:
for g in self.greater:
if (e <= g) == True:
raise ContradictingAssumptions()
for l in self.lesser:
if (e >= l) == True:
raise ContradictingAssumptions()
else:
# check if any greater > any lesser
for g in self.greater:
for l in self.lesser:
if (g > l) == True:
raise ContradictingAssumptions()
return True

def num_assumptions(self):
# returns the number of individual assumptions for this symbol
return len(self.greater) + len(self.lesser) + len(self.equal)


def propagate_assumptions(x, y, condensed_assumptions):
"""
Assuming x is equal to y, we propagate the assumptions on x to y. E.g. we have x==y and
x<5. Then, this method adds y<5 to the assumptions.

:param x: A symbol.
:param y: Another symbol equal to x.
:param condensed_assumptions: Current assumptions over all symbols.
"""
if x == y:
return
assum_x = condensed_assumptions[x]
if y not in condensed_assumptions:
condensed_assumptions[y] = Assumptions()
assum_y = condensed_assumptions[y]
for e in assum_x.equal:
if e is not sp.Symbol(y):
assum_y.add_equal(e)
for g in assum_x.greater:
assum_y.add_greater(g)
for l in assum_x.lesser:
assum_y.add_lesser(l)
assum_y.check_consistency()


def propagate_assumptions_equal_symbols(condensed_assumptions):
"""
This method handles two things: 1) It generates the substitution dict for all equality assumptions.
And 2) it propagates assumptions too all equal symbols. For each equivalence class, we find a unique
representative using UnionFind. Then, all assumptions get propagates to this symbol using
``propagate_assumptions``.

:param condensed_assumptions: Current assumptions over all symbols.
:return: Returns a tuple consisting of 2 substitution dicts. The first one replaces each symbol with
the unique representative of its equivalence class. The second dict replaces each symbol with its numeric
value (if we assume it to be equal some value, e.g. N==5).
"""
# Make one set with unique identifier for each equality class
uf = UnionFind(list(condensed_assumptions))
for sym in condensed_assumptions:
for other in condensed_assumptions[sym].equal:
if isinstance(other, sp.Symbol):
# we assume sym == other --> union these
uf.union(sym, other.name)

equality_subs1 = {}

# For each equivalence class, we now have one unique identifier.
# For each class, we give all the assumptions to this single symbol.
# And we swap each symbol in class for this symbol.
for sym in list(condensed_assumptions):
for other in condensed_assumptions[sym].equal:
if isinstance(other, sp.Symbol):
propagate_assumptions(sym, uf.find(sym), condensed_assumptions)
equality_subs1.update({sym: sp.Symbol(uf.find(sym))})

equality_subs2 = {}
# In a second step, each symbol gets replace with its equal number (if present)
# using equality_subs2.
for sym, assum in condensed_assumptions.items():
for e in assum.equal:
if not isinstance(e, sp.Symbol):
equality_subs2.update({sym: e})

# Imagine we have M>N and M==10. We need to deduce N<10 from that. Following code handles that:
for sym, assum in condensed_assumptions.items():
for g in assum.greater:
if isinstance(g, sp.Symbol):
for e in condensed_assumptions[g.name].equal:
if not isinstance(e, sp.Symbol):
condensed_assumptions[sym].add_greater(e)
assum.greater.remove(g)
for l in assum.lesser:
if isinstance(l, sp.Symbol):
for e in condensed_assumptions[l.name].equal:
if not isinstance(e, sp.Symbol):
condensed_assumptions[sym].add_lesser(e)
assum.lesser.remove(l)
return equality_subs1, equality_subs2


def parse_assumptions(assumptions, array_symbols):
"""
Parses a list of assumptions into substitution dictionaries. Firstly, it gathers all assumptions and
keeps only the strongest ones. Afterwards it constructs two substitution dicts for the equality
assumptions: First dict for symbol==symbol assumptions; second dict for symbol==number assumptions.
The other assumptions get handles by N tuples of substitution dicts (N = max number of concurrent
assumptions for a single symbol). Each tuple is responsible for at most one assumption for each symbol.
First dict in the tuple substitutes the symbol with the assumption; second dict restores the initial symbol.

:param assumptions: List of assumption strings.
:param array_symbols: List of symbols we assume to be positive, since they are the size of a data container.
:return: Tuple consisting of the 2 dicts responsible for the equality assumptions and the list of size N
reponsible for all other assumptions.
"""

# TODO: This assumptions system can be improved further, especially the deduction of further assumptions
# from the ones we already have. An example of what is not working currently:
# We have assumptions N>0 N<5 and M>5.
# In the first substitution round we use N>0 and M>5.
# In the second substitution round we use N<5.
# Therefore, Max(M, N) will not be evaluated to M, even though from the input assumptions
# one can clearly deduce M>N.
# This happens since N<5 and M>5 are not in the same substitution round.
# The easiest way to fix this is probably to actually deduce the M>N assumption.
# This guarantees that in some substitution round, we will replace M with N + _p_M, where
# _p_M is some positive symbol. Hence, we would resolve Max(M, N) to N + _p_M, which is M.

# I suspect there to be many more cases where further assumptions will not be deduced properly.
# But if the user enters assumptions as explicitly as possible, e.g. N<5 M>5 M>N, then everything
# works fine.

# For each symbol x appearing as a data container size, we can assume x>0.
# TODO (later): Analyze size of shapes more, such that e.g. shape N + 1 --> We can assume N > -1.
# For now we only extract assumptions out of shapes if shape consists of only a single symbol.
for sym in array_symbols:
assumptions.append(f'{sym.name}>0')

if assumptions is None:
return {}, [({}, {})]

# Gather assumptions, keeping only the strongest ones for each symbol.
condensed_assumptions: Dict[str, Assumptions] = {}
for a in assumptions:
if '==' in a:
symbol, rhs = a.split('==')
if symbol not in condensed_assumptions:
condensed_assumptions[symbol] = Assumptions()
try:
condensed_assumptions[symbol].add_equal(int(rhs))
except ValueError:
condensed_assumptions[symbol].add_equal(sp.Symbol(rhs))
elif '>' in a:
symbol, rhs = a.split('>')
if symbol not in condensed_assumptions:
condensed_assumptions[symbol] = Assumptions()
try:
condensed_assumptions[symbol].add_greater(int(rhs))
except ValueError:
condensed_assumptions[symbol].add_greater(sp.Symbol(rhs))
# add the opposite, i.e. for x>y, we add y<x
if rhs not in condensed_assumptions:
condensed_assumptions[rhs] = Assumptions()
condensed_assumptions[rhs].add_lesser(sp.Symbol(symbol))
elif '<' in a:
symbol, rhs = a.split('<')
if symbol not in condensed_assumptions:
condensed_assumptions[symbol] = Assumptions()
try:
condensed_assumptions[symbol].add_lesser(int(rhs))
except ValueError:
condensed_assumptions[symbol].add_lesser(sp.Symbol(rhs))
# add the opposite, i.e. for x<y, we add y>x
if rhs not in condensed_assumptions:
condensed_assumptions[rhs] = Assumptions()
condensed_assumptions[rhs].add_greater(sp.Symbol(symbol))

# Handle equal assumptions.
equality_subs = propagate_assumptions_equal_symbols(condensed_assumptions)

# How many assumptions does symbol with most assumptions have?
curr_max = -1
for _, assum in condensed_assumptions.items():
if assum.num_assumptions() > curr_max:
curr_max = assum.num_assumptions()

all_subs = []
for i in range(curr_max):
all_subs.append(({}, {}))

# Construct all the substitution dicts. In each substitution round we take at most one assumption for each
# symbol. Each round has two dicts: First one swaps in the assumption and second one restores the initial
# symbol.
for sym, assum in condensed_assumptions.items():
i = 0
for g in assum.greater:
replacement_symbol = sp.Symbol(f'_p_{sym}', positive=True, integer=True)
all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + g})
all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - g})
i += 1
for l in assum.lesser:
replacement_symbol = sp.Symbol(f'_n_{sym}', negative=True, integer=True)
all_subs[i][0].update({sp.Symbol(sym): replacement_symbol + l})
all_subs[i][1].update({replacement_symbol: sp.Symbol(sym) - l})
i += 1

return equality_subs, all_subs
2 changes: 2 additions & 0 deletions dace/sdfg/work_depth_analysis/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,4 +328,6 @@ def find_loop_guards_tails_exits(sdfg_nx: nx.DiGraph):
# now we have a triple (node, oNode, exitCandidates)
nodes_oNodes_exits.append((node, oNode, exitCandidates))

# remove artificial end node
sdfg_nx.remove_node(artificial_end_node)
return nodes_oNodes_exits
Loading
Loading