Skip to content

Commit

Permalink
Define generic lattices for dataflow analyses (#1769)
Browse files Browse the repository at this point in the history
Add LLVM_DUMP_METHOD so that this method does not get stripped for debug builds

Co-authored-by: Tim Gymnich <tgymnich@icloud.com>
  • Loading branch information
pengmai and tgymnich authored Mar 8, 2024
1 parent 0a129ae commit 651b42f
Show file tree
Hide file tree
Showing 7 changed files with 461 additions and 296 deletions.
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_library(MLIREnzymeAnalysis
ActivityAnalysis.cpp
DataFlowAliasAnalysis.cpp
DataFlowActivityAnalysis.cpp
DataFlowLattice.cpp

DEPENDS
MLIRAutoDiffTypeInterfaceIncGen
Expand Down
10 changes: 5 additions & 5 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ std::optional<Value> getCopySource(Operation *op) {
/// If the classes are undefined, the callback will not be called at all.
void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass,
function_ref<void(DistinctAttr)> forEachFn) {
(void)ptrAliasClass->getAliasClassesObject().foreachClass(
(void)ptrAliasClass->getAliasClassesObject().foreachElement(
[&](DistinctAttr alloc, enzyme::AliasClassSet::State state) {
if (state != enzyme::AliasClassSet::State::Undefined)
forEachFn(alloc);
Expand Down Expand Up @@ -636,7 +636,7 @@ class DenseForwardActivityAnalysis
continue;
auto *argAliasClasses = getOrCreateFor<AliasClassLattice>(block, arg);
ChangeResult changed =
argAliasClasses->getAliasClassesObject().foreachClass(
argAliasClasses->getAliasClassesObject().foreachElement(
[lattice](DistinctAttr argAliasClass,
enzyme::AliasClassSet::State state) {
if (state == enzyme::AliasClassSet::State::Undefined)
Expand Down Expand Up @@ -687,7 +687,7 @@ class DenseBackwardActivityAnalysis
}
auto *argAliasClasses = getOrCreateFor<AliasClassLattice>(op, arg);
ChangeResult changed =
argAliasClasses->getAliasClassesObject().foreachClass(
argAliasClasses->getAliasClassesObject().foreachElement(
[before](DistinctAttr argAliasClass,
enzyme::AliasClassSet::State state) {
if (state == enzyme::AliasClassSet::State::Undefined)
Expand All @@ -703,7 +703,7 @@ class DenseBackwardActivityAnalysis
auto *retAliasClasses =
getOrCreateFor<AliasClassLattice>(op, operand);
ChangeResult changed =
retAliasClasses->getAliasClassesObject().foreachClass(
retAliasClasses->getAliasClassesObject().foreachElement(
[before](DistinctAttr retAliasClass,
enzyme::AliasClassSet::State state) {
if (state == enzyme::AliasClassSet::State::Undefined)
Expand Down Expand Up @@ -854,7 +854,7 @@ void printActivityAnalysisResults(const DataFlowSolver &solver,
std::deque<DistinctAttr> frontier;
DenseSet<DistinctAttr> visited;
auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) {
(void)aliasClasses.foreachClass(
(void)aliasClasses.foreachElement(
[&](DistinctAttr neighbor, enzyme::AliasClassSet::State state) {
assert(neighbor &&
"unhandled undefined/unknown case before visit");
Expand Down
190 changes: 31 additions & 159 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,67 +49,6 @@ static bool isPointerLike(Type type) {
return isa<MemRefType, LLVM::LLVMPointerType>(type);
}

const enzyme::AliasClassSet enzyme::AliasClassSet::undefinedSet =
AliasClassSet(enzyme::AliasClassSet::State::Undefined);
const enzyme::AliasClassSet enzyme::AliasClassSet::unknownSet =
AliasClassSet(enzyme::AliasClassSet::State::Unknown);

ChangeResult enzyme::AliasClassSet::join(const AliasClassSet &other) {
if (isUnknown())
return ChangeResult::NoChange;
if (isUndefined() && other.isUndefined())
return ChangeResult::NoChange;
if (other.isUnknown()) {
state = State::Unknown;
return ChangeResult::Change;
}

ChangeResult result = updateStateToDefined();
return insert(other.aliasClasses) | result;
}

ChangeResult
enzyme::AliasClassSet::insert(const DenseSet<DistinctAttr> &classes) {
if (isUnknown())
return ChangeResult::NoChange;

size_t oldSize = aliasClasses.size();
aliasClasses.insert(classes.begin(), classes.end());
ChangeResult result = aliasClasses.size() == oldSize ? ChangeResult::NoChange
: ChangeResult::Change;
return updateStateToDefined() | result;
}

ChangeResult enzyme::AliasClassSet::markUnknown() {
if (isUnknown())
return ChangeResult::NoChange;

state = State::Unknown;
aliasClasses.clear();
return ChangeResult::Change;
}

bool enzyme::AliasClassSet::isCanonical() const {
return state == State::Defined || aliasClasses.empty();
}

bool enzyme::AliasClassSet::operator==(
const enzyme::AliasClassSet &other) const {
assert(isCanonical() && other.isCanonical());
return state == other.state && llvm::equal(aliasClasses, other.aliasClasses);
}

ChangeResult enzyme::AliasClassSet::foreachClass(
function_ref<ChangeResult(DistinctAttr, State)> callback) const {
if (state != State::Defined)
return callback(nullptr, state);

ChangeResult result = ChangeResult::NoChange;
for (DistinctAttr attr : aliasClasses)
result |= callback(attr, state);
return result;
}

//===----------------------------------------------------------------------===//
// PointsToAnalysis
//===----------------------------------------------------------------------===//
Expand All @@ -122,70 +61,21 @@ static ChangeResult mergeSets(DenseSet<T> &dest, const DenseSet<T> &src) {
}

void enzyme::PointsToSets::print(raw_ostream &os) const {
if (pointsTo.empty()) {
if (map.empty()) {
os << "<empty>\n";
return;
}
for (const auto &[srcClass, destClasses] : pointsTo) {
for (const auto &[srcClass, destClasses] : map) {
os << " " << srcClass << " points to {";
if (destClasses.isUnknown()) {
os << "<unknown>";
} else if (destClasses.isUndefined()) {
os << "<undefined>";
} else {
llvm::interleaveComma(destClasses.getAliasClasses(), os);
llvm::interleaveComma(destClasses.getElements(), os);
}
os << "}\n";
}
// os << "other points to unknown: " << otherPointToUnknown << "\n";
}

/// Union for every variable.
ChangeResult enzyme::PointsToSets::join(const AbstractDenseLattice &lattice) {
const auto &rhs = static_cast<const PointsToSets &>(lattice);
llvm::SmallDenseSet<DistinctAttr> keys;
auto lhsRange = llvm::make_first_range(pointsTo);
auto rhsRange = llvm::make_first_range(rhs.pointsTo);
keys.insert(lhsRange.begin(), lhsRange.end());
keys.insert(rhsRange.begin(), rhsRange.end());

ChangeResult result = ChangeResult::NoChange;
for (DistinctAttr key : keys) {
auto lhsIt = pointsTo.find(key);
auto rhsIt = rhs.pointsTo.find(key);
assert(lhsIt != pointsTo.end() || rhsIt != rhs.pointsTo.end());

// If present in both, join.
if (lhsIt != pointsTo.end() && rhsIt != rhs.pointsTo.end()) {
result |= lhsIt->getSecond().join(rhsIt->getSecond());
continue;
}

// Copy from RHS if available only there.
if (lhsIt == pointsTo.end()) {
pointsTo.try_emplace(rhsIt->getFirst(), rhsIt->getSecond());
result = ChangeResult::Change;
}

// Do nothing if available only in LHS.
}
return result;
}

ChangeResult
enzyme::PointsToSets::joinPotentiallyMissing(DistinctAttr key,
const AliasClassSet &value) {
// Don't store explicitly undefined values in the mapping, keys absent from
// the mapping are treated as implicitly undefined.
if (value.isUndefined())
return ChangeResult::NoChange;

bool inserted;
decltype(pointsTo.begin()) iterator;
std::tie(iterator, inserted) = pointsTo.try_emplace(key, value);
if (!inserted)
return iterator->second.join(value);
return ChangeResult::Change;
}

ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
Expand All @@ -198,14 +88,14 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
if (keysToUpdate.isUndefined())
return ChangeResult::NoChange;

return keysToUpdate.foreachClass(
return keysToUpdate.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State state) {
assert(state == AliasClassSet::State::Defined &&
"unknown must have been handled above");
#ifndef NDEBUG
if (replace) {
auto it = pointsTo.find(dest);
if (it != pointsTo.end()) {
auto it = map.find(dest);
if (it != map.end()) {
// Check that we are updating to a state that's >= in the
// lattice.
// TODO: consider a stricter check that we only replace unknown
Expand Down Expand Up @@ -242,17 +132,17 @@ enzyme::PointsToSets::addSetsFrom(const AliasClassSet &destClasses,
if (destClasses.isUndefined())
return ChangeResult::NoChange;

return destClasses.foreachClass(
return destClasses.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State destState) {
assert(destState == AliasClassSet::State::Defined);
return srcClasses.foreachClass(
return srcClasses.foreachElement(
[&](DistinctAttr src, AliasClassSet::State srcState) {
const AliasClassSet *srcClasses = &AliasClassSet::getUndefined();
if (srcState == AliasClassSet::State::Unknown)
srcClasses = &AliasClassSet::getUnknown();
else if (srcState == AliasClassSet::State::Defined) {
auto it = pointsTo.find(src);
if (it != pointsTo.end())
auto it = map.find(src);
if (it != map.end())
srcClasses = &it->getSecond();
}
return joinPotentiallyMissing(dest, *srcClasses);
Expand All @@ -267,16 +157,10 @@ enzyme::PointsToSets::markPointToUnknown(const AliasClassSet &destClasses) {
if (destClasses.isUndefined())
return ChangeResult::NoChange;

return destClasses.foreachClass([&](DistinctAttr dest, AliasClassSet::State) {
return joinPotentiallyMissing(dest, AliasClassSet::getUnknown());
});
}

ChangeResult enzyme::PointsToSets::markAllPointToUnknown() {
ChangeResult result = ChangeResult::NoChange;
for (auto &it : pointsTo)
result |= it.getSecond().join(AliasClassSet::getUnknown());
return result;
return destClasses.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State) {
return joinPotentiallyMissing(dest, AliasClassSet::getUnknown());
});
}

ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown(
Expand All @@ -285,18 +169,17 @@ ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown(
return ChangeResult::NoChange;

ChangeResult result = ChangeResult::NoChange;
for (auto &[key, value] : pointsTo) {
if (destClasses.isUnknown() ||
!destClasses.getAliasClasses().contains(key)) {
for (auto &[key, value] : map) {
if (destClasses.isUnknown() || !destClasses.getElements().contains(key)) {
result |= value.markUnknown();
}
}

#ifndef NDEBUG
(void)destClasses.foreachClass(
(void)destClasses.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State state) {
if (state == AliasClassSet::State::Defined)
assert(pointsTo.contains(dest) && "unknown dest cannot be preserved");
assert(map.contains(dest) && "unknown dest cannot be preserved");
return ChangeResult::NoChange;
});
#endif // NDEBUG
Expand Down Expand Up @@ -632,7 +515,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
// Otherwise, indicate that a pointer that belongs to any of the
// classes captured by this function may be stored into the
// destination class.
changed |= destClasses->getAliasClassesObject().foreachClass(
changed |= destClasses->getAliasClassesObject().foreachElement(
[&](DistinctAttr dest, AliasClassSet::State) {
return after->insert(dest, functionMayCapture);
});
Expand Down Expand Up @@ -694,7 +577,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
!nonWritableOperandClasses.isUndefined()) {
DenseSet<DistinctAttr> nonOperandClasses =
llvm::set_difference(destClasses->getAliasClasses(),
nonWritableOperandClasses.getAliasClasses());
nonWritableOperandClasses.getElements());
(void)resultWithoutNonWritableOperands.insert(nonOperandClasses);
} else {
(void)resultWithoutNonWritableOperands.join(
Expand Down Expand Up @@ -741,25 +624,14 @@ void enzyme::PointsToPointerAnalysis::setToEntryState(PointsToSets *lattice) {}
// AliasClassLattice
//===----------------------------------------------------------------------===//

void enzyme::AliasClassSet::print(raw_ostream &os) const {
if (isUnknown()) {
os << "<unknown>";
} else if (isUndefined()) {
os << "<undefined>";
} else {
llvm::interleaveComma(aliasClasses, os << "{");
os << "}";
}
}

void enzyme::AliasClassLattice::print(raw_ostream &os) const {
if (aliasClasses.isUnknown()) {
if (elements.isUnknown()) {
os << "Unknown AC";
} else if (aliasClasses.isUndefined()) {
} else if (elements.isUndefined()) {
os << "Undefined AC";
} else {
os << "size: " << aliasClasses.getAliasClasses().size() << ":\n";
for (auto aliasClass : aliasClasses.getAliasClasses()) {
os << "size: " << elements.getElements().size() << ":\n";
for (auto aliasClass : elements.getElements()) {
os << " " << aliasClass << "\n";
}
}
Expand All @@ -774,12 +646,12 @@ enzyme::AliasClassLattice::alias(const AbstractSparseLattice &other) const {
if (getPoint() == rhs->getPoint())
return AliasResult::MustAlias;

if (aliasClasses.isUnknown() || rhs->aliasClasses.isUnknown())
if (elements.isUnknown() || rhs->elements.isUnknown())
return AliasResult::MayAlias;

size_t overlap = llvm::count_if(
aliasClasses.getAliasClasses(), [rhs](DistinctAttr aliasClass) {
return rhs->aliasClasses.getAliasClasses().contains(aliasClass);
size_t overlap =
llvm::count_if(elements.getElements(), [rhs](DistinctAttr aliasClass) {
return rhs->elements.getElements().contains(aliasClass);
});

if (overlap == 0)
Expand All @@ -800,7 +672,7 @@ ChangeResult
enzyme::AliasClassLattice::join(const AbstractSparseLattice &other) {
// Set union of the alias classes
const auto *otherAliasClass = static_cast<const AliasClassLattice *>(&other);
return aliasClasses.join(otherAliasClass->aliasClasses);
return elements.join(otherAliasClass->elements);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -896,7 +768,7 @@ void enzyme::AliasAnalysis::transfer(
continue;
} else {
propagateIfChanged(result,
result->insert(srcPointsTo.getAliasClasses()));
result->insert(srcPointsTo.getElements()));
}
}
}
Expand Down Expand Up @@ -1041,7 +913,7 @@ void enzyme::AliasAnalysis::visitExternalCall(
// If can read from argument, collect the alias classes that can this
// argument may be pointing to.
const auto *pointsToLattice = getOrCreateFor<PointsToSets>(call, call);
(void)srcClasses->getAliasClassesObject().foreachClass(
(void)srcClasses->getAliasClassesObject().foreachElement(
[&](DistinctAttr srcClass, AliasClassSet::State state) {
// Nothing to do in top/bottom case. In the top case, we have already
// set `operandAliasClasses` to top above.
Expand Down
Loading

0 comments on commit 651b42f

Please sign in to comment.