Skip to content

Commit

Permalink
TypeTree speedup shiftindicies (#1744)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Feb 20, 2024
1 parent d04f103 commit 1ea2e0b
Showing 1 changed file with 152 additions and 21 deletions.
173 changes: 152 additions & 21 deletions enzyme/Enzyme/TypeAnalysis/TypeTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -724,17 +724,51 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
}

/// Replace mappings in the range in [offset, offset+maxSize] with those in
// [addOffset, addOffset + maxSize]. In other worse, select all mappings in
// [addOffset, addOffset + maxSize]. In other words, select all mappings in
// [offset, offset+maxSize] then add `addOffset`
TypeTree ShiftIndices(const llvm::DataLayout &dl, const int offset,
const int maxSize, size_t addOffset = 0) const {

// If we have no terms 1+ layer deep return the current result as a shift
// won't change anything. This also makes the latercode simpler as it
// can assume at least a first index exists.
if (minIndices.size() == 0)
return *this;

// If we have no size in return, simply return an empty type tree. Again
// this simplifies later code which can assume that a minus one expantion
// will always result in an added variable (which would not be the case
// on a size == 0).
if (maxSize == 0)
return TypeTree();

TypeTree Result;

// The normal orIn / insert methods do collision checking, which is slow
// (and presently O(n)). This is because an expansion of a -1 which could
// conflict with a fixed value. Consider calling this
// ShiftIndicies(offset=0, maxSize=2, addOffset=0, tt={[-1]:Integer,
// [1]:Anything}) the -1 would expand to [0]:Int, [1]:Int, which would need
// to be merged with [1]:Anything
//
// The only possible values which can cause a conflict are minus -1's.
// As a result, we start with a fast insertion (aka without check) of
// non-expanded values, since they just do a literal shift which needs no
// extra checking, besides bounds checks.
//
// Since we're doing things manually, we also need to manually preserve TT
// invariants. Specifically, TT limits all values to have offsets <
// MAX_OFFSET, unless it is the smallest offset at that depth. (e.g. so we
// can still hava typetree {[123456]:Int}, even if limit is 100).
//
// First compute the minimum 0th index to be kept.
Result.minIndices.resize(minIndices.size(), INT_MAX);

for (const auto &pair : mapping) {
if (pair.first.size() == 0) {
if (pair.second == BaseType::Pointer ||
pair.second == BaseType::Anything) {
Result.insert(pair.first, pair.second);
Result.mapping.emplace(pair.first, pair.second);
continue;
}

Expand All @@ -743,55 +777,152 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
llvm_unreachable("ShiftIndices called on a nonpointer/anything");
}

std::vector<int> next(pair.first);
int next0 = pair.first[0];

if (next0 == -1) {
if (maxSize == -1) {
// Max size does not clip the next index

// If we have a follow up offset add, we lose the -1 since we only
// represent [0, inf) with -1 not the [addOffset, inf) required here
if (addOffset != 0) {
next0 = addOffset;
}

} else {
// We're going to insert addOffset + 0...maxSize so the new minIndex
// is addOffset
Result.minIndices[0] = addOffset;
for (size_t i = 1, sz = pair.first.size(); i < sz; i++)
if (pair.first[i] < Result.minIndices[i])
Result.minIndices[i] = pair.first[i];
continue;
}
} else {
// Too small for range
if (next0 < offset) {
continue;
}
next0 -= offset;

if (maxSize != -1) {
if (next0 >= maxSize)
continue;
}

next0 += addOffset;
}
if (next0 < Result.minIndices[0])
Result.minIndices[0] = next0;
for (size_t i = 1, sz = pair.first.size(); i < sz; i++)
if (pair.first[i] < Result.minIndices[i])
Result.minIndices[i] = pair.first[i];
}

// Max depth of actual inserted values
size_t maxInsertedDepth = 0;

// Insert all
for (const auto &pair : mapping) {
if (pair.first.size() == 0)
continue;

if (next[0] == -1) {
int next0 = pair.first[0];

if (next0 == -1) {
if (maxSize == -1) {
// Max size does not clip the next index

// If we have a follow up offset add, we lose the -1 since we only
// represent [0, inf) with -1 not the [addOffset, inf) required here
if (addOffset != 0) {
next[0] = addOffset;
next0 = addOffset;
}

} else {
// This needs to become 0...maxSize as seen below
// This needs to become 0...maxSize handled separately as it is the
// only insertion that could have collisions
continue;
}
} else {
// Too small for range
if (next[0] < offset) {
if (next0 < offset) {
continue;
}
next[0] -= offset;
next0 -= offset;

if (maxSize != -1) {
if (next[0] >= maxSize)
if (next0 >= maxSize)
continue;
}

next[0] += addOffset;
next0 += addOffset;
}

size_t chunk = 1;
auto op = operator[]({pair.first[0]});
if (auto flt = op.isFloat()) {
chunk = dl.getTypeSizeInBits(flt) / 8;
} else if (op == BaseType::Pointer) {
chunk = dl.getPointerSizeInBits() / 8;
// If after moving this would not merit being kept for being a min index
// or being within the max type offset, skip it.
if (next0 > MaxTypeOffset) {
bool minIndex = next0 == Result.minIndices[0];
if (!minIndex)
for (size_t i = 1; i < pair.first.size(); i++) {
if (pair.first[i] == Result.minIndices[i]) {
minIndex = true;
break;
}
}
if (!minIndex)
continue;
}

if (next[0] == -1 && maxSize != -1) {
std::vector<int> next(pair.first);
next[0] = next0;
Result.mapping.emplace(next, pair.second);
if (next.size() > maxInsertedDepth)
maxInsertedDepth = next.size();
}

// Insert and expand the minus one, if needed
if (maxSize != -1)
for (const auto &pair : mapping) {
if (pair.first.size() == 0)
continue;
if (pair.first[0] != -1)
continue;

size_t chunk = 1;
std::vector<int> next(pair.first);
auto op = operator[]({next[0]});
if (auto flt = op.isFloat()) {
chunk = dl.getTypeSizeInBits(flt) / 8;
} else if (op == BaseType::Pointer) {
chunk = dl.getPointerSizeInBits() / 8;
}
auto offincr = (chunk - offset % chunk) % chunk;
bool inserted = false;
for (int i = offincr; i < maxSize; i += chunk) {
next[0] = i + addOffset;
Result.orIn(next, pair.second);
ConcreteType prev(pair.second);
// We can use faster checks here, since we know there can be no
// -1's that we would conflict with, only conflicts from previous
// fixed value insertions.
auto found = Result.mapping.find(next);
if (found != Result.mapping.end()) {
// orIn returns if changed, update the value in the map if so
// with the new value.
if (prev.orIn(found->second, /*pointerIntSame*/ false))
found->second = prev;
} else {
Result.mapping.emplace(next, pair.second);
}
inserted = true;
}
} else {
Result.orIn(next, pair.second);
if (inserted && next.size() > maxInsertedDepth)
maxInsertedDepth = next.size();
}
}

// Resize minIndices down if we dropped any higher-depth indices for being
// out of scope.
Result.minIndices.resize(maxInsertedDepth);
return Result;
}

Expand Down

0 comments on commit 1ea2e0b

Please sign in to comment.