Skip to content

Commit

Permalink
Switch PointsTo set to inheriting from MapOfSetsLattice
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Feb 27, 2024
1 parent 15334ad commit 53410f7
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 161 deletions.
121 changes: 8 additions & 113 deletions enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,55 +62,6 @@ static ChangeResult mergeSets(DenseSet<T> &dest, const DenseSet<T> &src) {
return dest.size() == oldSize ? ChangeResult::NoChange : ChangeResult::Change;
}

Attribute enzyme::PointsToSets::serialize(MLIRContext *ctx) const {
SmallVector<Attribute> pointsToArray;
auto sortKeys = [&](Attribute a, Attribute b) {
auto distinctA = dyn_cast<DistinctAttr>(a);
auto distinctB = dyn_cast<DistinctAttr>(b);
// If not distinct attributes, sort them arbitrarily.
if (!(distinctA && distinctB))
return &a < &b;

auto pseudoA = dyn_cast_if_present<PseudoAliasClassAttr>(
distinctA.getReferencedAttr());
auto pseudoB = dyn_cast_if_present<PseudoAliasClassAttr>(
distinctB.getReferencedAttr());
auto strA = dyn_cast_if_present<StringAttr>(distinctA.getReferencedAttr());
auto strB = dyn_cast_if_present<StringAttr>(distinctB.getReferencedAttr());
if (pseudoA && pseudoB) {
return std::make_pair(pseudoA.getArgNumber(), pseudoA.getDepth()) <
std::make_pair(pseudoB.getArgNumber(), pseudoB.getDepth());
} else if (strA && strB) {
return strA.strref() < strB.strref();
}
// Order pseudo classes before fresh classes
return pseudoA && !pseudoB;
};

for (const auto &[srcClass, destClasses] : pointsTo) {
SmallVector<Attribute, 2> pair = {srcClass};
SmallVector<Attribute, 5> aliasClasses;
if (destClasses.isUnknown()) {
aliasClasses.push_back(StringAttr::get(ctx, "unknown"));
} else if (destClasses.isUndefined()) {
aliasClasses.push_back(StringAttr::get(ctx, "undefined"));
} else {
for (const DistinctAttr &destClass : destClasses.getElements()) {
aliasClasses.push_back(destClass);
}
llvm::sort(aliasClasses, sortKeys);
}
pair.push_back(ArrayAttr::get(ctx, aliasClasses));
pointsToArray.push_back(ArrayAttr::get(ctx, pair));
}
llvm::sort(pointsToArray, [&](Attribute a, Attribute b) {
auto arrA = cast<ArrayAttr>(a);
auto arrB = cast<ArrayAttr>(b);
return sortKeys(arrA[0], arrB[0]);
});
return ArrayAttr::get(ctx, pointsToArray);
}

// TODO: a bit easier to prototype with a dense map directly, evaluate
// if it'd be better to change the PointsToSets data structure to
// support this
Expand Down Expand Up @@ -143,11 +94,11 @@ deserializePointsTo(ArrayAttr summaryAttr,
}

void enzyme::PointsToSets::print(raw_ostream &os) const {
if (pointsTo.empty()) {
if (map.empty()) {
os << "<empty>\n";
return;
}
for (const auto &[srcClass, destClasses] : pointsTo) {
for (const auto &[srcClass, destClasses] : map) {
os << " " << srcClass << " points to {";
if (destClasses.isUnknown()) {
os << "<unknown>";
Expand All @@ -158,55 +109,6 @@ void enzyme::PointsToSets::print(raw_ostream &os) const {
}
os << "}\n";
}
// os << "other points to unknown: " << otherPointToUnknown << "\n";
}

/// Union for every variable.
ChangeResult enzyme::PointsToSets::join(const AbstractDenseLattice &lattice) {
const auto &rhs = static_cast<const PointsToSets &>(lattice);
llvm::SmallDenseSet<DistinctAttr> keys;
auto lhsRange = llvm::make_first_range(pointsTo);
auto rhsRange = llvm::make_first_range(rhs.pointsTo);
keys.insert(lhsRange.begin(), lhsRange.end());
keys.insert(rhsRange.begin(), rhsRange.end());

ChangeResult result = ChangeResult::NoChange;
for (DistinctAttr key : keys) {
auto lhsIt = pointsTo.find(key);
auto rhsIt = rhs.pointsTo.find(key);
assert(lhsIt != pointsTo.end() || rhsIt != rhs.pointsTo.end());

// If present in both, join.
if (lhsIt != pointsTo.end() && rhsIt != rhs.pointsTo.end()) {
result |= lhsIt->getSecond().join(rhsIt->getSecond());
continue;
}

// Copy from RHS if available only there.
if (lhsIt == pointsTo.end()) {
pointsTo.try_emplace(rhsIt->getFirst(), rhsIt->getSecond());
result = ChangeResult::Change;
}

// Do nothing if available only in LHS.
}
return result;
}

ChangeResult
enzyme::PointsToSets::joinPotentiallyMissing(DistinctAttr key,
const AliasClassSet &value) {
// Don't store explicitly undefined values in the mapping, keys absent from
// the mapping are treated as implicitly undefined.
if (value.isUndefined())
return ChangeResult::NoChange;

bool inserted;
decltype(pointsTo.begin()) iterator;
std::tie(iterator, inserted) = pointsTo.try_emplace(key, value);
if (!inserted)
return iterator->second.join(value);
return ChangeResult::Change;
}

ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
Expand All @@ -225,8 +127,8 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate,
"unknown must have been handled above");
#ifndef NDEBUG
if (replace) {
auto it = pointsTo.find(dest);
if (it != pointsTo.end()) {
auto it = map.find(dest);
if (it != map.end()) {
// Check that we are updating to a state that's >= in the
// lattice.
// TODO: consider a stricter check that we only replace unknown
Expand Down Expand Up @@ -272,8 +174,8 @@ enzyme::PointsToSets::addSetsFrom(const AliasClassSet &destClasses,
if (srcState == AliasClassSet::State::Unknown)
srcClasses = &AliasClassSet::getUnknown();
else if (srcState == AliasClassSet::State::Defined) {
auto it = pointsTo.find(src);
if (it != pointsTo.end())
auto it = map.find(src);
if (it != map.end())
srcClasses = &it->getSecond();
}
return joinPotentiallyMissing(dest, *srcClasses);
Expand All @@ -294,20 +196,13 @@ enzyme::PointsToSets::markPointToUnknown(const AliasClassSet &destClasses) {
});
}

ChangeResult enzyme::PointsToSets::markAllPointToUnknown() {
ChangeResult result = ChangeResult::NoChange;
for (auto &it : pointsTo)
result |= it.getSecond().join(AliasClassSet::getUnknown());
return result;
}

ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown(
const AliasClassSet &destClasses) {
if (destClasses.isUndefined())
return ChangeResult::NoChange;

ChangeResult result = ChangeResult::NoChange;
for (auto &[key, value] : pointsTo) {
for (auto &[key, value] : map) {
if (destClasses.isUnknown() || !destClasses.getElements().contains(key)) {
result |= value.markUnknown();
}
Expand All @@ -317,7 +212,7 @@ ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown(
(void)destClasses.foreachElement(
[&](DistinctAttr dest, AliasClassSet::State state) {
if (state == AliasClassSet::State::Defined)
assert(pointsTo.contains(dest) && "unknown dest cannot be preserved");
assert(map.contains(dest) && "unknown dest cannot be preserved");
return ChangeResult::NoChange;
});
#endif // NDEBUG
Expand Down
36 changes: 4 additions & 32 deletions enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,12 @@ class OriginalClasses {
// pointers stored/loaded through memory.
//===----------------------------------------------------------------------===//

class PointsToSets : public dataflow::AbstractDenseLattice {
class PointsToSets : public MapOfSetsLattice<DistinctAttr, DistinctAttr> {
public:
using AbstractDenseLattice::AbstractDenseLattice;

// Serialize the points-to information of this state into an attribute.
Attribute serialize(MLIRContext *ctx) const;
using MapOfSetsLattice::MapOfSetsLattice;

void print(raw_ostream &os) const override;

ChangeResult join(const AbstractDenseLattice &lattice) override;

/// Mark the pointer stored in `dest` as possibly pointing to any of `values`,
/// instead of the values it may be currently pointing to.
ChangeResult setPointingToClasses(const AliasClassSet &destClasses,
Expand Down Expand Up @@ -141,18 +136,13 @@ class PointsToSets : public dataflow::AbstractDenseLattice {

/// Mark the entire data structure as "unknown", that is, any pointer may be
/// containing any other pointer. This is the full pessimistic fixpoint.
ChangeResult markAllPointToUnknown();
ChangeResult markAllPointToUnknown() { return markAllUnknown(); }

/// Mark all alias classes except the given ones to point to the "unknown"
/// alias set.
ChangeResult markAllExceptPointToUnknown(const AliasClassSet &destClasses);

const AliasClassSet &getPointsTo(DistinctAttr id) const {
auto it = pointsTo.find(id);
if (it == pointsTo.end())
return AliasClassSet::getUndefined();
return it->getSecond();
}
const AliasClassSet &getPointsTo(DistinctAttr id) const { return lookup(id); }

private:
/// Update all alias classes in `keysToUpdate` to additionally point to alias
Expand All @@ -169,24 +159,6 @@ class PointsToSets : public dataflow::AbstractDenseLattice {
/// in the lattice, not only the replacements described above.
ChangeResult update(const AliasClassSet &keysToUpdate,
const AliasClassSet &values, bool replace);

ChangeResult joinPotentiallyMissing(DistinctAttr key,
const AliasClassSet &value);

/// Indicates that alias classes not listed as keys in `pointsTo` point to
/// unknown alias set (when true) or an empty alias set (when false).
// TODO: consider also differentiating between pointing to known-empty vs.
// not-yet-computed.
// bool otherPointToUnknown = false;

// missing from map always beings "undefined", "unknown"s are stored
// explicitly.

/// Maps an identifier of an alias set to the set of alias sets its value may
/// belong to. When an identifier is not present in this map, it is considered
/// to point to either the unknown set or nothing, based on the value of
/// `otherPointToUnknown`.
DenseMap<DistinctAttr, AliasClassSet> pointsTo;
};

//===----------------------------------------------------------------------===//
Expand Down
36 changes: 20 additions & 16 deletions enzyme/Enzyme/MLIR/Analysis/Lattice.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,16 @@
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
// If using this code in an academic setting, please cite the following:
// @incollection{enzymeNeurips,
// title = {Instead of Rewriting Foreign Code for Machine Learning,
// Automatically Synthesize Fast Gradients},
// author = {Moses, William S. and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems 33},
// year = {2020},
// note = {To appear in},
// @inproceedings{NEURIPS2020_9332c513,
// author = {Moses, William and Churavy, Valentin},
// booktitle = {Advances in Neural Information Processing Systems},
// editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H.
// Lin}, pages = {12472--12485}, publisher = {Curran Associates, Inc.}, title =
// {Instead of Rewriting Foreign Code for Machine Learning, Automatically
// Synthesize Fast Gradients}, url =
// {https://proceedings.neurips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf},
// volume = {33},
// year = {2020}
// }
//
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -275,7 +278,6 @@ class MapOfSetsLattice : public dataflow::AbstractDenseLattice {
return result;
}

// TODO(jacob): switch over the alias class lattices to using these
/// Map all keys to all values.
ChangeResult insert(const SetLattice<KeyT> &keysToUpdate,
const SetLattice<ElementT> &values) {
Expand All @@ -300,6 +302,14 @@ class MapOfSetsLattice : public dataflow::AbstractDenseLattice {
return result;
}

const SetLattice<ElementT> &lookup(KeyT key) const {
auto it = map.find(key);
if (it == map.end())
return SetLattice<ElementT>::getUndefined();
return it->getSecond();
}

protected:
ChangeResult joinPotentiallyMissing(KeyT key,
const SetLattice<ElementT> &value) {
// Don't store explicitly undefined values in the mapping, keys absent from
Expand All @@ -315,14 +325,8 @@ class MapOfSetsLattice : public dataflow::AbstractDenseLattice {
return ChangeResult::Change;
}

const SetLattice<ElementT> &lookup(KeyT key) const {
auto it = map.find(key);
if (it == map.end())
return SetLattice<ElementT>::getUndefined();
return it->getSecond();
}

protected:
/// Maps a key to a set of values. When a key is not present in this map, it
/// is considered to map to an uninitialized set.
DenseMap<KeyT, SetLattice<ElementT>> map;

private:
Expand Down

0 comments on commit 53410f7

Please sign in to comment.