Skip to content

Commit

Permalink
TypeAnalysis speed up Canonicalize
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 21, 2024
1 parent d4cb8b3 commit ca1e485
Showing 1 changed file with 115 additions and 1 deletion.
116 changes: 115 additions & 1 deletion enzyme/Enzyme/TypeAnalysis/TypeTree.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,7 +673,7 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {

/// Given that this tree represents something of at most size len,
/// canonicalize this, creating -1's where possible
void CanonicalizeInPlace(size_t len, const llvm::DataLayout &dl) {
void CanonicalizeInPlaceOld(size_t len, const llvm::DataLayout &dl) {
bool canonicalized = true;
for (const auto &pair : mapping) {
assert(pair.first.size() != 0);
Expand Down Expand Up @@ -755,6 +755,120 @@ class TypeTree : public std::enable_shared_from_this<TypeTree> {
}
}
}

void CanonicalizeInPlace(size_t len, const llvm::DataLayout &dl) {
bool canonicalized = true;
for (const auto &pair : mapping) {
assert(pair.first.size() != 0);
if (pair.first[0] != -1) {
canonicalized = false;
break;
}
}
if (canonicalized)
return;

auto Old = *this;
Old.CanonicalizeInPlaceOld(len, dl);

// Map of indices[1:] => ( End => possible Index[0] )
std::map<const std::vector<int>, std::map<ConcreteType, std::set<int>>>
staging;

for (const auto &pair : mapping) {

std::vector<int> next(pair.first.begin() + 1, pair.first.end());
if (pair.first[0] != -1) {
if ((size_t)pair.first[0] >= len) {
llvm::errs() << str() << "\n";
llvm::errs() << " canonicalizing " << len << "\n";
}
assert((size_t)pair.first[0] < len);
}
staging[next][pair.second].insert(pair.first[0]);
}

// TypeTree mappings which did not get combined
std::map<const std::vector<int>, ConcreteType> unCombinedToAdd;

// TypeTree mappings which did get combined into an outer -1
std::map<const std::vector<int>, ConcreteType> combinedToAdd;

for (const auto &pair : staging) {
auto &pnext = pair.first;
for (const auto &pair2 : pair.second) {
auto dt = pair2.first;
const auto &set = pair2.second;

bool legalCombine = false;

// See if we can canonicalize the outermost index into a -1
if (!set.count(-1)) {
size_t chunk = 1;
if (pnext.size() > 0) {
chunk = dl.getPointerSizeInBits() / 8;
} else {
if (auto flt = dt.isFloat()) {
chunk = dl.getTypeSizeInBits(flt) / 8;
} else if (dt == BaseType::Pointer) {
chunk = dl.getPointerSizeInBits() / 8;
}
}

legalCombine = true;
for (size_t i = 0; i < len; i += chunk) {
if (!set.count(i)) {
legalCombine = false;
break;
}
}
}

std::vector<int> next;
next.reserve(pnext.size() + 1);
next.push_back(-1);
for (auto v : pnext)
next.push_back(v);

if (legalCombine) {
combinedToAdd.emplace(next, dt);
} else {
for (auto e : set) {
next[0] = e;
unCombinedToAdd.emplace(next, dt);
}
}
}
}

// If we combined nothing, just return since there are no
// changes.
if (combinedToAdd.size() == 0) {
assert(*this == Old);
assert(this->minIndices == Old.minIndices);
return;
}

// Non-combined ones do not conflict, since they were already in
// a TT which we can assume contained no conflicts.
mapping = std::move(unCombinedToAdd);
minIndices[0] = -1;

// Fusing several terms into a minus one can create a conflict
// if the prior minus one was already in the map
// time, or also generated by fusion.
// E.g. {-1:Anything, [0]:Pointer} on 8 -> create a [-1]:Pointer
// which conflicts
// Alternatively [-1,-1,-1]:Pointer, and generated a [-1,0,-1] fusion
for (const auto &pair : combinedToAdd) {
insert(pair.first, pair.second);
}

assert(*this == Old);
assert(this->minIndices == Old.minIndices);

return;
}

/// Keep only pointers (or anything's) to a repeated value (represented by -1)
TypeTree KeepMinusOne(bool &legal) const {
Expand Down

0 comments on commit ca1e485

Please sign in to comment.