From 53410f7d9cac8fbed6bf6c284c870a84c62724a6 Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Mon, 26 Feb 2024 23:39:50 -0500 Subject: [PATCH] Switch PointsTo set to inheriting from MapOfSetsLattice --- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp | 121 ++---------------- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h | 36 +----- enzyme/Enzyme/MLIR/Analysis/Lattice.h | 36 +++--- 3 files changed, 32 insertions(+), 161 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index 4d2731caadf3..660e2df11140 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -62,55 +62,6 @@ static ChangeResult mergeSets(DenseSet &dest, const DenseSet &src) { return dest.size() == oldSize ? ChangeResult::NoChange : ChangeResult::Change; } -Attribute enzyme::PointsToSets::serialize(MLIRContext *ctx) const { - SmallVector pointsToArray; - auto sortKeys = [&](Attribute a, Attribute b) { - auto distinctA = dyn_cast(a); - auto distinctB = dyn_cast(b); - // If not distinct attributes, sort them arbitrarily. - if (!(distinctA && distinctB)) - return &a < &b; - - auto pseudoA = dyn_cast_if_present( - distinctA.getReferencedAttr()); - auto pseudoB = dyn_cast_if_present( - distinctB.getReferencedAttr()); - auto strA = dyn_cast_if_present(distinctA.getReferencedAttr()); - auto strB = dyn_cast_if_present(distinctB.getReferencedAttr()); - if (pseudoA && pseudoB) { - return std::make_pair(pseudoA.getArgNumber(), pseudoA.getDepth()) < - std::make_pair(pseudoB.getArgNumber(), pseudoB.getDepth()); - } else if (strA && strB) { - return strA.strref() < strB.strref(); - } - // Order pseudo classes before fresh classes - return pseudoA && !pseudoB; - }; - - for (const auto &[srcClass, destClasses] : pointsTo) { - SmallVector pair = {srcClass}; - SmallVector aliasClasses; - if (destClasses.isUnknown()) { - aliasClasses.push_back(StringAttr::get(ctx, "unknown")); - } else if (destClasses.isUndefined()) { - aliasClasses.push_back(StringAttr::get(ctx, "undefined")); - } else { - for (const DistinctAttr &destClass : destClasses.getElements()) { - aliasClasses.push_back(destClass); - } - llvm::sort(aliasClasses, sortKeys); - } - pair.push_back(ArrayAttr::get(ctx, aliasClasses)); - pointsToArray.push_back(ArrayAttr::get(ctx, pair)); - } - llvm::sort(pointsToArray, [&](Attribute a, Attribute b) { - auto arrA = cast(a); - auto arrB = cast(b); - return sortKeys(arrA[0], arrB[0]); - }); - return ArrayAttr::get(ctx, pointsToArray); -} - // TODO: a bit easier to prototype with a dense map directly, evaluate // if it'd be better to change the PointsToSets data structure to // support this @@ -143,11 +94,11 @@ deserializePointsTo(ArrayAttr summaryAttr, } void enzyme::PointsToSets::print(raw_ostream &os) const { - if (pointsTo.empty()) { + if (map.empty()) { os << "\n"; return; } - for (const auto &[srcClass, destClasses] : pointsTo) { + for (const auto &[srcClass, destClasses] : map) { os << " " << srcClass << " points to {"; if (destClasses.isUnknown()) { os << ""; @@ -158,55 +109,6 @@ void enzyme::PointsToSets::print(raw_ostream &os) const { } os << "}\n"; } - // os << "other points to unknown: " << otherPointToUnknown << "\n"; -} - -/// Union for every variable. -ChangeResult enzyme::PointsToSets::join(const AbstractDenseLattice &lattice) { - const auto &rhs = static_cast(lattice); - llvm::SmallDenseSet keys; - auto lhsRange = llvm::make_first_range(pointsTo); - auto rhsRange = llvm::make_first_range(rhs.pointsTo); - keys.insert(lhsRange.begin(), lhsRange.end()); - keys.insert(rhsRange.begin(), rhsRange.end()); - - ChangeResult result = ChangeResult::NoChange; - for (DistinctAttr key : keys) { - auto lhsIt = pointsTo.find(key); - auto rhsIt = rhs.pointsTo.find(key); - assert(lhsIt != pointsTo.end() || rhsIt != rhs.pointsTo.end()); - - // If present in both, join. - if (lhsIt != pointsTo.end() && rhsIt != rhs.pointsTo.end()) { - result |= lhsIt->getSecond().join(rhsIt->getSecond()); - continue; - } - - // Copy from RHS if available only there. - if (lhsIt == pointsTo.end()) { - pointsTo.try_emplace(rhsIt->getFirst(), rhsIt->getSecond()); - result = ChangeResult::Change; - } - - // Do nothing if available only in LHS. - } - return result; -} - -ChangeResult -enzyme::PointsToSets::joinPotentiallyMissing(DistinctAttr key, - const AliasClassSet &value) { - // Don't store explicitly undefined values in the mapping, keys absent from - // the mapping are treated as implicitly undefined. - if (value.isUndefined()) - return ChangeResult::NoChange; - - bool inserted; - decltype(pointsTo.begin()) iterator; - std::tie(iterator, inserted) = pointsTo.try_emplace(key, value); - if (!inserted) - return iterator->second.join(value); - return ChangeResult::Change; } ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, @@ -225,8 +127,8 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, "unknown must have been handled above"); #ifndef NDEBUG if (replace) { - auto it = pointsTo.find(dest); - if (it != pointsTo.end()) { + auto it = map.find(dest); + if (it != map.end()) { // Check that we are updating to a state that's >= in the // lattice. // TODO: consider a stricter check that we only replace unknown @@ -272,8 +174,8 @@ enzyme::PointsToSets::addSetsFrom(const AliasClassSet &destClasses, if (srcState == AliasClassSet::State::Unknown) srcClasses = &AliasClassSet::getUnknown(); else if (srcState == AliasClassSet::State::Defined) { - auto it = pointsTo.find(src); - if (it != pointsTo.end()) + auto it = map.find(src); + if (it != map.end()) srcClasses = &it->getSecond(); } return joinPotentiallyMissing(dest, *srcClasses); @@ -294,20 +196,13 @@ enzyme::PointsToSets::markPointToUnknown(const AliasClassSet &destClasses) { }); } -ChangeResult enzyme::PointsToSets::markAllPointToUnknown() { - ChangeResult result = ChangeResult::NoChange; - for (auto &it : pointsTo) - result |= it.getSecond().join(AliasClassSet::getUnknown()); - return result; -} - ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown( const AliasClassSet &destClasses) { if (destClasses.isUndefined()) return ChangeResult::NoChange; ChangeResult result = ChangeResult::NoChange; - for (auto &[key, value] : pointsTo) { + for (auto &[key, value] : map) { if (destClasses.isUnknown() || !destClasses.getElements().contains(key)) { result |= value.markUnknown(); } @@ -317,7 +212,7 @@ ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown( (void)destClasses.foreachElement( [&](DistinctAttr dest, AliasClassSet::State state) { if (state == AliasClassSet::State::Defined) - assert(pointsTo.contains(dest) && "unknown dest cannot be preserved"); + assert(map.contains(dest) && "unknown dest cannot be preserved"); return ChangeResult::NoChange; }); #endif // NDEBUG diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h index 95180ce6f662..dd7110a7b21a 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h @@ -103,17 +103,12 @@ class OriginalClasses { // pointers stored/loaded through memory. //===----------------------------------------------------------------------===// -class PointsToSets : public dataflow::AbstractDenseLattice { +class PointsToSets : public MapOfSetsLattice { public: - using AbstractDenseLattice::AbstractDenseLattice; - - // Serialize the points-to information of this state into an attribute. - Attribute serialize(MLIRContext *ctx) const; + using MapOfSetsLattice::MapOfSetsLattice; void print(raw_ostream &os) const override; - ChangeResult join(const AbstractDenseLattice &lattice) override; - /// Mark the pointer stored in `dest` as possibly pointing to any of `values`, /// instead of the values it may be currently pointing to. ChangeResult setPointingToClasses(const AliasClassSet &destClasses, @@ -141,18 +136,13 @@ class PointsToSets : public dataflow::AbstractDenseLattice { /// Mark the entire data structure as "unknown", that is, any pointer may be /// containing any other pointer. This is the full pessimistic fixpoint. - ChangeResult markAllPointToUnknown(); + ChangeResult markAllPointToUnknown() { return markAllUnknown(); } /// Mark all alias classes except the given ones to point to the "unknown" /// alias set. ChangeResult markAllExceptPointToUnknown(const AliasClassSet &destClasses); - const AliasClassSet &getPointsTo(DistinctAttr id) const { - auto it = pointsTo.find(id); - if (it == pointsTo.end()) - return AliasClassSet::getUndefined(); - return it->getSecond(); - } + const AliasClassSet &getPointsTo(DistinctAttr id) const { return lookup(id); } private: /// Update all alias classes in `keysToUpdate` to additionally point to alias @@ -169,24 +159,6 @@ class PointsToSets : public dataflow::AbstractDenseLattice { /// in the lattice, not only the replacements described above. ChangeResult update(const AliasClassSet &keysToUpdate, const AliasClassSet &values, bool replace); - - ChangeResult joinPotentiallyMissing(DistinctAttr key, - const AliasClassSet &value); - - /// Indicates that alias classes not listed as keys in `pointsTo` point to - /// unknown alias set (when true) or an empty alias set (when false). - // TODO: consider also differentiating between pointing to known-empty vs. - // not-yet-computed. - // bool otherPointToUnknown = false; - - // missing from map always beings "undefined", "unknown"s are stored - // explicitly. - - /// Maps an identifier of an alias set to the set of alias sets its value may - /// belong to. When an identifier is not present in this map, it is considered - /// to point to either the unknown set or nothing, based on the value of - /// `otherPointToUnknown`. - DenseMap pointsTo; }; //===----------------------------------------------------------------------===// diff --git a/enzyme/Enzyme/MLIR/Analysis/Lattice.h b/enzyme/Enzyme/MLIR/Analysis/Lattice.h index eda117881eeb..9c8b6b187295 100644 --- a/enzyme/Enzyme/MLIR/Analysis/Lattice.h +++ b/enzyme/Enzyme/MLIR/Analysis/Lattice.h @@ -7,13 +7,16 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // // If using this code in an academic setting, please cite the following: -// @incollection{enzymeNeurips, -// title = {Instead of Rewriting Foreign Code for Machine Learning, -// Automatically Synthesize Fast Gradients}, -// author = {Moses, William S. and Churavy, Valentin}, -// booktitle = {Advances in Neural Information Processing Systems 33}, -// year = {2020}, -// note = {To appear in}, +// @inproceedings{NEURIPS2020_9332c513, +// author = {Moses, William and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems}, +// editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H. +// Lin}, pages = {12472--12485}, publisher = {Curran Associates, Inc.}, title = +// {Instead of Rewriting Foreign Code for Machine Learning, Automatically +// Synthesize Fast Gradients}, url = +// {https://proceedings.neurips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf}, +// volume = {33}, +// year = {2020} // } // //===----------------------------------------------------------------------===// @@ -275,7 +278,6 @@ class MapOfSetsLattice : public dataflow::AbstractDenseLattice { return result; } - // TODO(jacob): switch over the alias class lattices to using these /// Map all keys to all values. ChangeResult insert(const SetLattice &keysToUpdate, const SetLattice &values) { @@ -300,6 +302,14 @@ class MapOfSetsLattice : public dataflow::AbstractDenseLattice { return result; } + const SetLattice &lookup(KeyT key) const { + auto it = map.find(key); + if (it == map.end()) + return SetLattice::getUndefined(); + return it->getSecond(); + } + +protected: ChangeResult joinPotentiallyMissing(KeyT key, const SetLattice &value) { // Don't store explicitly undefined values in the mapping, keys absent from @@ -315,14 +325,8 @@ class MapOfSetsLattice : public dataflow::AbstractDenseLattice { return ChangeResult::Change; } - const SetLattice &lookup(KeyT key) const { - auto it = map.find(key); - if (it == map.end()) - return SetLattice::getUndefined(); - return it->getSecond(); - } - -protected: + /// Maps a key to a set of values. When a key is not present in this map, it + /// is considered to map to an uninitialized set. DenseMap> map; private: