Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mlir] Define generic set/map lattices for dataflow analyses #1769

Merged
merged 1 commit into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_library(MLIREnzymeAnalysis
ActivityAnalysis.cpp
DataFlowAliasAnalysis.cpp
DataFlowActivityAnalysis.cpp
DataFlowLattice.cpp

DEPENDS
MLIRAutoDiffTypeInterfaceIncGen
Expand Down
10 changes: 5 additions & 5 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,7 @@ std::optional<Value> getCopySource(Operation *op) {
/// If the classes are undefined, the callback will not be called at all.
void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass,
function_ref<void(DistinctAttr)> forEachFn) {
(void)ptrAliasClass->getAliasClassesObject().foreachClass(
(void)ptrAliasClass->getAliasClassesObject().foreachElement(
[&](DistinctAttr alloc, enzyme::AliasClassSet::State state) {
if (state != enzyme::AliasClassSet::State::Undefined)
forEachFn(alloc);
Expand Down Expand Up @@ -636,7 +636,7 @@ class DenseForwardActivityAnalysis
continue;
auto *argAliasClasses = getOrCreateFor<AliasClassLattice>(block, arg);
ChangeResult changed =
argAliasClasses->getAliasClassesObject().foreachClass(
argAliasClasses->getAliasClassesObject().foreachElement(
[lattice](DistinctAttr argAliasClass,
enzyme::AliasClassSet::State state) {
if (state == enzyme::AliasClassSet::State::Undefined)
Expand Down Expand Up @@ -687,7 +687,7 @@ class DenseBackwardActivityAnalysis
}
auto *argAliasClasses = getOrCreateFor<AliasClassLattice>(op, arg);
ChangeResult changed =
argAliasClasses->getAliasClassesObject().foreachClass(
argAliasClasses->getAliasClassesObject().foreachElement(
[before](DistinctAttr argAliasClass,
enzyme::AliasClassSet::State state) {
if (state == enzyme::AliasClassSet::State::Undefined)
Expand All @@ -703,7 +703,7 @@ class DenseBackwardActivityAnalysis
auto *retAliasClasses =
getOrCreateFor<AliasClassLattice>(op, operand);
ChangeResult changed =
retAliasClasses->getAliasClassesObject().foreachClass(
retAliasClasses->getAliasClassesObject().foreachElement(
[before](DistinctAttr retAliasClass,
enzyme::AliasClassSet::State state) {
if (state == enzyme::AliasClassSet::State::Undefined)
Expand Down Expand Up @@ -854,7 +854,7 @@ void printActivityAnalysisResults(const DataFlowSolver &solver,
std::deque<DistinctAttr> frontier;
DenseSet<DistinctAttr> visited;
auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) {
(void)aliasClasses.foreachClass(
(void)aliasClasses.foreachElement(
[&](DistinctAttr neighbor, enzyme::AliasClassSet::State state) {
assert(neighbor &&
"unhandled undefined/unknown case before visit");
Expand Down
190 changes: 31 additions & 159 deletions enzyme/Enzyme/MLIR/Analysis/DataFlowAliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,67 +49,6 @@ static bool isPointerLike(Type type) {
return isa<MemRefType, LLVM::LLVMPointerType>(type);
}

const enzyme::AliasClassSet enzyme::AliasClassSet::undefinedSet =
AliasClassSet(enzyme::AliasClassSet::State::Undefined);
const enzyme::AliasClassSet enzyme::AliasClassSet::unknownSet =
AliasClassSet(enzyme::AliasClassSet::State::Unknown);

ChangeResult enzyme::AliasClassSet::join(const AliasClassSet &other) {
if (isUnknown())
return ChangeResult::NoChange;
if (isUndefined() && other.isUndefined())
return ChangeResult::NoChange;
if (other.isUnknown()) {
state = State::Unknown;
return ChangeResult::Change;
}

ChangeResult result = updateStateToDefined();
return insert(other.aliasClasses) | result;
}

ChangeResult
enzyme::AliasClassSet::insert(const DenseSet<DistinctAttr> &classes) {
if (isUnknown())
return ChangeResult::NoChange;

size_t oldSize = aliasClasses.size();
aliasClasses.insert(classes.begin(), classes.end());
ChangeResult result = aliasClasses.size() == oldSize ? ChangeResult::NoChange
: ChangeResult::Change;
return updateStateToDefined() | result;
}

ChangeResult enzyme::AliasClassSet::markUnknown() {
if (isUnknown())
return ChangeResult::NoChange;

state = State::Unknown;
aliasClasses.clear();
return ChangeResult::Change;
}

bool enzyme::AliasClassSet::isCanonical() const {
return state == State::Defined || aliasClasses.empty();
}

bool enzyme::AliasClassSet::operator==(
const enzyme::AliasClassSet &other) const {
assert(isCanonical() && other.isCanonical());
return state == other.state && llvm::equal(aliasClasses, other.aliasClasses);
}

ChangeResult enzyme::AliasClassSet::foreachClass(
function_ref<ChangeResult(DistinctAttr, State)> callback) const {
if (state != State::Defined)
return callback(nullptr, state);

ChangeResult result = ChangeResult::NoChange;
for (DistinctAttr attr : aliasClasses)
result |= callback(attr, state);
return result;
}

//===----------------------------------------------------------------------===//
// PointsToAnalysis
//===----------------------------------------------------------------------===//
Expand All @@ -122,70 +61,21 @@ static ChangeResult mergeSets(DenseSet<T> &dest, const DenseSet<T> &src) {
}

void enzyme::PointsToSets::print(raw_ostream &os) const {
if (pointsTo.empty()) {
if (map.empty()) {
os << "<empty>\n";
return;
}
for (const auto &[srcClass, destClasses] : pointsTo) {
for (const auto &[srcClass, destClasses] : map) {
os << " " << srcClass << " points to {";
if (destClasses.isUnknown()) {
os << "<unknown>";
} else if (destClasses.isUndefined()) {
os << "<undefined>";
} else {
llvm::interleaveComma(destClasses.getAliasClasses(), os);
llvm::interleaveComma(destClasses.getElements(), os);
}
os << "}\n";
}
// os << "other points to unknown: " << otherPointToUnknown << "\n";
}

/// Union for every variable.
ChangeResult enzyme::PointsToSets::join(const AbstractDenseLattice &lattice) {
const auto &rhs = static_cast<const PointsToSets &>(lattice);
llvm::SmallDenseSet<DistinctAttr> keys;
auto lhsRange = llvm::make_first_range(pointsTo);
auto rhsRange = llvm::make_first_range(rhs.pointsTo);
keys.insert(lhsRange.begin(), lhsRange.end());
keys.insert(rhsRange.begin(), rhsRange.end());

ChangeResult result = ChangeResult::NoChange;
for (DistinctAttr key : keys) {
auto lhsIt = pointsTo.find(key);
auto rhsIt = rhs.pointsTo.find(key);
assert(lhsIt != pointsTo.end() || rhsIt != rhs.pointsTo.end());

// If present in both, join.
if (lhsIt != pointsTo.end() && rhsIt != rhs.pointsTo.end()) {
result |= lhsIt->getSecond().join(rhsIt->getSecond());
continue;
}

// Copy from RHS if available only there.
if (lhsIt == pointsTo.end()) {
pointsTo.try_emplace(rhsIt->getFirst(), rhsIt->getSecond());
result = ChangeResult::Change;
}

// Do nothing if available only in LHS.
}
return result;
}

ChangeResult
enzyme::PointsToSets::joinPotentiallyMissing(DistinctAttr key,
const AliasClassSet &value) {
// Don't store explicitly undefined values in the mapping, keys absent from
// the mapping are treated as implicitly undefined.
if (value.isUndefined())
return ChangeResult::NoChange;

bool inserted;
decltype(pointsTo.begin()) iterator;
std::tie(iterator, inserted) = pointsTo.try_emplace(key, value);
if (!inserted)
return iterator->second.join(value);
return ChangeResult::Change;
}

ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
Expand All @@ -198,14 +88,14 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
if (keysToUpdate.isUndefined())
return ChangeResult::NoChange;

return keysToUpdate.foreachClass(
return keysToUpdate.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State state) {
assert(state == AliasClassSet::State::Defined &&
"unknown must have been handled above");
#ifndef NDEBUG
if (replace) {
auto it = pointsTo.find(dest);
if (it != pointsTo.end()) {
auto it = map.find(dest);
if (it != map.end()) {
// Check that we are updating to a state that's >= in the
// lattice.
// TODO: consider a stricter check that we only replace unknown
Expand Down Expand Up @@ -242,17 +132,17 @@ enzyme::PointsToSets::addSetsFrom(const AliasClassSet &destClasses,
if (destClasses.isUndefined())
return ChangeResult::NoChange;

return destClasses.foreachClass(
return destClasses.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State destState) {
assert(destState == AliasClassSet::State::Defined);
return srcClasses.foreachClass(
return srcClasses.foreachElement(
[&](DistinctAttr src, AliasClassSet::State srcState) {
const AliasClassSet *srcClasses = &AliasClassSet::getUndefined();
if (srcState == AliasClassSet::State::Unknown)
srcClasses = &AliasClassSet::getUnknown();
else if (srcState == AliasClassSet::State::Defined) {
auto it = pointsTo.find(src);
if (it != pointsTo.end())
auto it = map.find(src);
if (it != map.end())
srcClasses = &it->getSecond();
}
return joinPotentiallyMissing(dest, *srcClasses);
Expand All @@ -267,16 +157,10 @@ enzyme::PointsToSets::markPointToUnknown(const AliasClassSet &destClasses) {
if (destClasses.isUndefined())
return ChangeResult::NoChange;

return destClasses.foreachClass([&](DistinctAttr dest, AliasClassSet::State) {
return joinPotentiallyMissing(dest, AliasClassSet::getUnknown());
});
}

ChangeResult enzyme::PointsToSets::markAllPointToUnknown() {
ChangeResult result = ChangeResult::NoChange;
for (auto &it : pointsTo)
result |= it.getSecond().join(AliasClassSet::getUnknown());
return result;
return destClasses.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State) {
return joinPotentiallyMissing(dest, AliasClassSet::getUnknown());
});
}

ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown(
Expand All @@ -285,18 +169,17 @@ ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown(
return ChangeResult::NoChange;

ChangeResult result = ChangeResult::NoChange;
for (auto &[key, value] : pointsTo) {
if (destClasses.isUnknown() ||
!destClasses.getAliasClasses().contains(key)) {
for (auto &[key, value] : map) {
if (destClasses.isUnknown() || !destClasses.getElements().contains(key)) {
result |= value.markUnknown();
}
}

#ifndef NDEBUG
(void)destClasses.foreachClass(
(void)destClasses.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State state) {
if (state == AliasClassSet::State::Defined)
assert(pointsTo.contains(dest) && "unknown dest cannot be preserved");
assert(map.contains(dest) && "unknown dest cannot be preserved");
return ChangeResult::NoChange;
});
#endif // NDEBUG
Expand Down Expand Up @@ -632,7 +515,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
// Otherwise, indicate that a pointer that belongs to any of the
// classes captured by this function may be stored into the
// destination class.
changed |= destClasses->getAliasClassesObject().foreachClass(
changed |= destClasses->getAliasClassesObject().foreachElement(
[&](DistinctAttr dest, AliasClassSet::State) {
return after->insert(dest, functionMayCapture);
});
Expand Down Expand Up @@ -694,7 +577,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
!nonWritableOperandClasses.isUndefined()) {
DenseSet<DistinctAttr> nonOperandClasses =
llvm::set_difference(destClasses->getAliasClasses(),
nonWritableOperandClasses.getAliasClasses());
nonWritableOperandClasses.getElements());
(void)resultWithoutNonWritableOperands.insert(nonOperandClasses);
} else {
(void)resultWithoutNonWritableOperands.join(
Expand Down Expand Up @@ -741,25 +624,14 @@ void enzyme::PointsToPointerAnalysis::setToEntryState(PointsToSets *lattice) {}
// AliasClassLattice
//===----------------------------------------------------------------------===//

void enzyme::AliasClassSet::print(raw_ostream &os) const {
if (isUnknown()) {
os << "<unknown>";
} else if (isUndefined()) {
os << "<undefined>";
} else {
llvm::interleaveComma(aliasClasses, os << "{");
os << "}";
}
}

void enzyme::AliasClassLattice::print(raw_ostream &os) const {
if (aliasClasses.isUnknown()) {
if (elements.isUnknown()) {
os << "Unknown AC";
} else if (aliasClasses.isUndefined()) {
} else if (elements.isUndefined()) {
os << "Undefined AC";
} else {
os << "size: " << aliasClasses.getAliasClasses().size() << ":\n";
for (auto aliasClass : aliasClasses.getAliasClasses()) {
os << "size: " << elements.getElements().size() << ":\n";
for (auto aliasClass : elements.getElements()) {
os << " " << aliasClass << "\n";
}
}
Expand All @@ -774,12 +646,12 @@ enzyme::AliasClassLattice::alias(const AbstractSparseLattice &other) const {
if (getPoint() == rhs->getPoint())
return AliasResult::MustAlias;

if (aliasClasses.isUnknown() || rhs->aliasClasses.isUnknown())
if (elements.isUnknown() || rhs->elements.isUnknown())
return AliasResult::MayAlias;

size_t overlap = llvm::count_if(
aliasClasses.getAliasClasses(), [rhs](DistinctAttr aliasClass) {
return rhs->aliasClasses.getAliasClasses().contains(aliasClass);
size_t overlap =
llvm::count_if(elements.getElements(), [rhs](DistinctAttr aliasClass) {
return rhs->elements.getElements().contains(aliasClass);
});

if (overlap == 0)
Expand All @@ -800,7 +672,7 @@ ChangeResult
enzyme::AliasClassLattice::join(const AbstractSparseLattice &other) {
// Set union of the alias classes
const auto *otherAliasClass = static_cast<const AliasClassLattice *>(&other);
return aliasClasses.join(otherAliasClass->aliasClasses);
return elements.join(otherAliasClass->elements);
}

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -896,7 +768,7 @@ void enzyme::AliasAnalysis::transfer(
continue;
} else {
propagateIfChanged(result,
result->insert(srcPointsTo.getAliasClasses()));
result->insert(srcPointsTo.getElements()));
}
}
}
Expand Down Expand Up @@ -1041,7 +913,7 @@ void enzyme::AliasAnalysis::visitExternalCall(
// If can read from argument, collect the alias classes that can this
// argument may be pointing to.
const auto *pointsToLattice = getOrCreateFor<PointsToSets>(call, call);
(void)srcClasses->getAliasClassesObject().foreachClass(
(void)srcClasses->getAliasClassesObject().foreachElement(
[&](DistinctAttr srcClass, AliasClassSet::State state) {
// Nothing to do in top/bottom case. In the top case, we have already
// set `operandAliasClasses` to top above.
Expand Down
Loading
Loading