From 17c1190e9bcc1e12e92d4bedf12f78d7a5fa5dba Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Thu, 22 Feb 2024 17:44:15 -0500 Subject: [PATCH] Define generic lattices for dataflow analyses --- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp | 190 ++------- enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h | 142 +------ enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt | 1 + .../Analysis/DataFlowActivityAnalysis.cpp | 10 +- enzyme/Enzyme/MLIR/Analysis/Lattice.cpp | 44 +++ enzyme/Enzyme/MLIR/Analysis/Lattice.h | 364 ++++++++++++++++++ .../Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp | 2 +- 7 files changed, 457 insertions(+), 296 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Analysis/Lattice.cpp create mode 100644 enzyme/Enzyme/MLIR/Analysis/Lattice.h diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp index 9af864e592d6..e10dfd68c843 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp @@ -49,67 +49,6 @@ static bool isPointerLike(Type type) { return isa(type); } -const enzyme::AliasClassSet enzyme::AliasClassSet::undefinedSet = - AliasClassSet(enzyme::AliasClassSet::State::Undefined); -const enzyme::AliasClassSet enzyme::AliasClassSet::unknownSet = - AliasClassSet(enzyme::AliasClassSet::State::Unknown); - -ChangeResult enzyme::AliasClassSet::join(const AliasClassSet &other) { - if (isUnknown()) - return ChangeResult::NoChange; - if (isUndefined() && other.isUndefined()) - return ChangeResult::NoChange; - if (other.isUnknown()) { - state = State::Unknown; - return ChangeResult::Change; - } - - ChangeResult result = updateStateToDefined(); - return insert(other.aliasClasses) | result; -} - -ChangeResult -enzyme::AliasClassSet::insert(const DenseSet &classes) { - if (isUnknown()) - return ChangeResult::NoChange; - - size_t oldSize = aliasClasses.size(); - aliasClasses.insert(classes.begin(), classes.end()); - ChangeResult result = aliasClasses.size() == oldSize ? 
ChangeResult::NoChange - : ChangeResult::Change; - return updateStateToDefined() | result; -} - -ChangeResult enzyme::AliasClassSet::markUnknown() { - if (isUnknown()) - return ChangeResult::NoChange; - - state = State::Unknown; - aliasClasses.clear(); - return ChangeResult::Change; -} - -bool enzyme::AliasClassSet::isCanonical() const { - return state == State::Defined || aliasClasses.empty(); -} - -bool enzyme::AliasClassSet::operator==( - const enzyme::AliasClassSet &other) const { - assert(isCanonical() && other.isCanonical()); - return state == other.state && llvm::equal(aliasClasses, other.aliasClasses); -} - -ChangeResult enzyme::AliasClassSet::foreachClass( - function_ref callback) const { - if (state != State::Defined) - return callback(nullptr, state); - - ChangeResult result = ChangeResult::NoChange; - for (DistinctAttr attr : aliasClasses) - result |= callback(attr, state); - return result; -} - //===----------------------------------------------------------------------===// // PointsToAnalysis //===----------------------------------------------------------------------===// @@ -122,70 +61,21 @@ static ChangeResult mergeSets(DenseSet &dest, const DenseSet &src) { } void enzyme::PointsToSets::print(raw_ostream &os) const { - if (pointsTo.empty()) { + if (map.empty()) { os << "\n"; return; } - for (const auto &[srcClass, destClasses] : pointsTo) { + for (const auto &[srcClass, destClasses] : map) { os << " " << srcClass << " points to {"; if (destClasses.isUnknown()) { os << ""; } else if (destClasses.isUndefined()) { os << ""; } else { - llvm::interleaveComma(destClasses.getAliasClasses(), os); + llvm::interleaveComma(destClasses.getElements(), os); } os << "}\n"; } - // os << "other points to unknown: " << otherPointToUnknown << "\n"; -} - -/// Union for every variable. 
-ChangeResult enzyme::PointsToSets::join(const AbstractDenseLattice &lattice) { - const auto &rhs = static_cast(lattice); - llvm::SmallDenseSet keys; - auto lhsRange = llvm::make_first_range(pointsTo); - auto rhsRange = llvm::make_first_range(rhs.pointsTo); - keys.insert(lhsRange.begin(), lhsRange.end()); - keys.insert(rhsRange.begin(), rhsRange.end()); - - ChangeResult result = ChangeResult::NoChange; - for (DistinctAttr key : keys) { - auto lhsIt = pointsTo.find(key); - auto rhsIt = rhs.pointsTo.find(key); - assert(lhsIt != pointsTo.end() || rhsIt != rhs.pointsTo.end()); - - // If present in both, join. - if (lhsIt != pointsTo.end() && rhsIt != rhs.pointsTo.end()) { - result |= lhsIt->getSecond().join(rhsIt->getSecond()); - continue; - } - - // Copy from RHS if available only there. - if (lhsIt == pointsTo.end()) { - pointsTo.try_emplace(rhsIt->getFirst(), rhsIt->getSecond()); - result = ChangeResult::Change; - } - - // Do nothing if available only in LHS. - } - return result; -} - -ChangeResult -enzyme::PointsToSets::joinPotentiallyMissing(DistinctAttr key, - const AliasClassSet &value) { - // Don't store explicitly undefined values in the mapping, keys absent from - // the mapping are treated as implicitly undefined. 
- if (value.isUndefined()) - return ChangeResult::NoChange; - - bool inserted; - decltype(pointsTo.begin()) iterator; - std::tie(iterator, inserted) = pointsTo.try_emplace(key, value); - if (!inserted) - return iterator->second.join(value); - return ChangeResult::Change; } ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, @@ -198,14 +88,14 @@ ChangeResult enzyme::PointsToSets::update(const AliasClassSet &keysToUpdate, if (keysToUpdate.isUndefined()) return ChangeResult::NoChange; - return keysToUpdate.foreachClass( + return keysToUpdate.foreachElement( [&](DistinctAttr dest, AliasClassSet::State state) { assert(state == AliasClassSet::State::Defined && "unknown must have been handled above"); #ifndef NDEBUG if (replace) { - auto it = pointsTo.find(dest); - if (it != pointsTo.end()) { + auto it = map.find(dest); + if (it != map.end()) { // Check that we are updating to a state that's >= in the // lattice. // TODO: consider a stricter check that we only replace unknown @@ -242,17 +132,17 @@ enzyme::PointsToSets::addSetsFrom(const AliasClassSet &destClasses, if (destClasses.isUndefined()) return ChangeResult::NoChange; - return destClasses.foreachClass( + return destClasses.foreachElement( [&](DistinctAttr dest, AliasClassSet::State destState) { assert(destState == AliasClassSet::State::Defined); - return srcClasses.foreachClass( + return srcClasses.foreachElement( [&](DistinctAttr src, AliasClassSet::State srcState) { const AliasClassSet *srcClasses = &AliasClassSet::getUndefined(); if (srcState == AliasClassSet::State::Unknown) srcClasses = &AliasClassSet::getUnknown(); else if (srcState == AliasClassSet::State::Defined) { - auto it = pointsTo.find(src); - if (it != pointsTo.end()) + auto it = map.find(src); + if (it != map.end()) srcClasses = &it->getSecond(); } return joinPotentiallyMissing(dest, *srcClasses); @@ -267,16 +157,10 @@ enzyme::PointsToSets::markPointToUnknown(const AliasClassSet &destClasses) { if (destClasses.isUndefined()) 
return ChangeResult::NoChange; - return destClasses.foreachClass([&](DistinctAttr dest, AliasClassSet::State) { - return joinPotentiallyMissing(dest, AliasClassSet::getUnknown()); - }); -} - -ChangeResult enzyme::PointsToSets::markAllPointToUnknown() { - ChangeResult result = ChangeResult::NoChange; - for (auto &it : pointsTo) - result |= it.getSecond().join(AliasClassSet::getUnknown()); - return result; + return destClasses.foreachElement( + [&](DistinctAttr dest, AliasClassSet::State) { + return joinPotentiallyMissing(dest, AliasClassSet::getUnknown()); + }); } ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown( @@ -285,18 +169,17 @@ ChangeResult enzyme::PointsToSets::markAllExceptPointToUnknown( return ChangeResult::NoChange; ChangeResult result = ChangeResult::NoChange; - for (auto &[key, value] : pointsTo) { - if (destClasses.isUnknown() || - !destClasses.getAliasClasses().contains(key)) { + for (auto &[key, value] : map) { + if (destClasses.isUnknown() || !destClasses.getElements().contains(key)) { result |= value.markUnknown(); } } #ifndef NDEBUG - (void)destClasses.foreachClass( + (void)destClasses.foreachElement( [&](DistinctAttr dest, AliasClassSet::State state) { if (state == AliasClassSet::State::Defined) - assert(pointsTo.contains(dest) && "unknown dest cannot be preserved"); + assert(map.contains(dest) && "unknown dest cannot be preserved"); return ChangeResult::NoChange; }); #endif // NDEBUG @@ -632,7 +515,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( // Otherwise, indicate that a pointer that belongs to any of the // classes captured by this function may be stored into the // destination class. 
- changed |= destClasses->getAliasClassesObject().foreachClass( + changed |= destClasses->getAliasClassesObject().foreachElement( [&](DistinctAttr dest, AliasClassSet::State) { return after->insert(dest, functionMayCapture); }); @@ -694,7 +577,7 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer( !nonWritableOperandClasses.isUndefined()) { DenseSet nonOperandClasses = llvm::set_difference(destClasses->getAliasClasses(), - nonWritableOperandClasses.getAliasClasses()); + nonWritableOperandClasses.getElements()); (void)resultWithoutNonWritableOperands.insert(nonOperandClasses); } else { (void)resultWithoutNonWritableOperands.join( @@ -741,25 +624,14 @@ void enzyme::PointsToPointerAnalysis::setToEntryState(PointsToSets *lattice) {} // AliasClassLattice //===----------------------------------------------------------------------===// -void enzyme::AliasClassSet::print(raw_ostream &os) const { - if (isUnknown()) { - os << ""; - } else if (isUndefined()) { - os << ""; - } else { - llvm::interleaveComma(aliasClasses, os << "{"); - os << "}"; - } -} - void enzyme::AliasClassLattice::print(raw_ostream &os) const { - if (aliasClasses.isUnknown()) { + if (elements.isUnknown()) { os << "Unknown AC"; - } else if (aliasClasses.isUndefined()) { + } else if (elements.isUndefined()) { os << "Undefined AC"; } else { - os << "size: " << aliasClasses.getAliasClasses().size() << ":\n"; - for (auto aliasClass : aliasClasses.getAliasClasses()) { + os << "size: " << elements.getElements().size() << ":\n"; + for (auto aliasClass : elements.getElements()) { os << " " << aliasClass << "\n"; } } @@ -774,12 +646,12 @@ enzyme::AliasClassLattice::alias(const AbstractSparseLattice &other) const { if (getPoint() == rhs->getPoint()) return AliasResult::MustAlias; - if (aliasClasses.isUnknown() || rhs->aliasClasses.isUnknown()) + if (elements.isUnknown() || rhs->elements.isUnknown()) return AliasResult::MayAlias; - size_t overlap = llvm::count_if( - aliasClasses.getAliasClasses(), 
[rhs](DistinctAttr aliasClass) { - return rhs->aliasClasses.getAliasClasses().contains(aliasClass); + size_t overlap = + llvm::count_if(elements.getElements(), [rhs](DistinctAttr aliasClass) { + return rhs->elements.getElements().contains(aliasClass); }); if (overlap == 0) @@ -800,7 +672,7 @@ ChangeResult enzyme::AliasClassLattice::join(const AbstractSparseLattice &other) { // Set union of the alias classes const auto *otherAliasClass = static_cast(&other); - return aliasClasses.join(otherAliasClass->aliasClasses); + return elements.join(otherAliasClass->elements); } //===----------------------------------------------------------------------===// @@ -896,7 +768,7 @@ void enzyme::AliasAnalysis::transfer( continue; } else { propagateIfChanged(result, - result->insert(srcPointsTo.getAliasClasses())); + result->insert(srcPointsTo.getElements())); } } } @@ -1041,7 +913,7 @@ void enzyme::AliasAnalysis::visitExternalCall( // If can read from argument, collect the alias classes that can this // argument may be pointing to. const auto *pointsToLattice = getOrCreateFor(call, call); - (void)srcClasses->getAliasClassesObject().foreachClass( + (void)srcClasses->getAliasClassesObject().foreachElement( [&](DistinctAttr srcClass, AliasClassSet::State state) { // Nothing to do in top/bottom case. In the top case, we have already // set `operandAliasClasses` to top above. 
diff --git a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h index 57297f1242a7..eeb70bc19a87 100644 --- a/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h +++ b/enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.h @@ -26,6 +26,8 @@ #ifndef ENZYME_MLIR_ANALYSIS_DATAFLOW_ALIASANALYSIS_H #define ENZYME_MLIR_ANALYSIS_DATAFLOW_ALIASANALYSIS_H +#include "Lattice.h" + #include "mlir/Analysis/AliasAnalysis.h" #include "mlir/Analysis/DataFlow/DenseAnalysis.h" #include "mlir/Analysis/DataFlow/SparseAnalysis.h" @@ -42,85 +44,7 @@ namespace enzyme { /// marked as "unknown", which is a conservative pessimistic state, or as /// "undefined", which is a "not-yet-analyzed" initial state. Undefined state is /// different from an empty alias set. -class AliasClassSet { -public: - enum class State { - Undefined, ///< Has not been analyzed yet (lattice bottom). - Defined, ///< Has specific alias classes. - Unknown ///< Analyzed and may point to any class (lattice top). - }; - - AliasClassSet() : state(State::Undefined) {} - - AliasClassSet(DistinctAttr single) : state(State::Defined) { - aliasClasses.insert(single); - } - - // TODO(zinenko): deprecate this and use a visitor instead. - DenseSet &getAliasClasses() { - assert(state == State::Defined); - return aliasClasses; - } - const DenseSet &getAliasClasses() const { - return const_cast(this)->getAliasClasses(); - } - - bool isUnknown() const { return state == State::Unknown; } - bool isUndefined() const { return state == State::Undefined; } - - ChangeResult join(const AliasClassSet &other); - ChangeResult insert(const DenseSet &classes); - ChangeResult markUnknown(); - - /// Returns true if this set is in the canonical form, i.e. either the state - /// is `State::Defined` or the explicit list of classes is empty, but not - /// both. - bool isCanonical() const; - - /// Returns an instance of AliasClassSet known not to alias with anything. - /// This is different from "undefined" and "unknown". 
The instance is *not* a - /// classical singleton. - static const AliasClassSet &getEmpty() { - static const AliasClassSet empty(State::Defined); - return empty; - } - - /// Returns an instance of AliasClassSet in "undefined" state, i.e. without a - /// set of alias classes. This is different from empty alias set, which - /// indicates that the value is known not to alias with any alias class. The - /// instance is *not* a classical singleton, there are other ways of obtaining - /// it. - static const AliasClassSet &getUndefined() { return undefinedSet; } - - /// Returns an instance of AliasClassSet for the "unknown" class. The instance - /// is *not* a classical singleton, there are other ways of obtaining an - /// "unknown" alias set. - static const AliasClassSet &getUnknown() { return unknownSet; } - - bool operator==(const AliasClassSet &other) const; - - void print(llvm::raw_ostream &os) const; - - ChangeResult - foreachClass(function_ref callback) const; - -private: - explicit AliasClassSet(State state) : state(state) {} - - ChangeResult updateStateToDefined() { - assert(state != State::Unknown && "cannot go back from unknown state"); - ChangeResult result = state == State::Undefined ? ChangeResult::Change - : ChangeResult::NoChange; - state = State::Defined; - return result; - } - - const static AliasClassSet unknownSet; - const static AliasClassSet undefinedSet; - - DenseSet aliasClasses; - State state; -}; +using AliasClassSet = SetLattice; //===----------------------------------------------------------------------===// // OriginalClasses @@ -179,14 +103,12 @@ class OriginalClasses { // pointers stored/loaded through memory. 
//===----------------------------------------------------------------------===// -class PointsToSets : public dataflow::AbstractDenseLattice { +class PointsToSets : public MapOfSetsLattice { public: - using AbstractDenseLattice::AbstractDenseLattice; + using MapOfSetsLattice::MapOfSetsLattice; void print(raw_ostream &os) const override; - ChangeResult join(const AbstractDenseLattice &lattice) override; - /// Mark the pointer stored in `dest` as possibly pointing to any of `values`, /// instead of the values it may be currently pointing to. ChangeResult setPointingToClasses(const AliasClassSet &destClasses, @@ -214,18 +136,13 @@ class PointsToSets : public dataflow::AbstractDenseLattice { /// Mark the entire data structure as "unknown", that is, any pointer may be /// containing any other pointer. This is the full pessimistic fixpoint. - ChangeResult markAllPointToUnknown(); + ChangeResult markAllPointToUnknown() { return markAllUnknown(); } /// Mark all alias classes except the given ones to point to the "unknown" /// alias set. ChangeResult markAllExceptPointToUnknown(const AliasClassSet &destClasses); - const AliasClassSet &getPointsTo(DistinctAttr id) const { - auto it = pointsTo.find(id); - if (it == pointsTo.end()) - return AliasClassSet::getUndefined(); - return it->getSecond(); - } + const AliasClassSet &getPointsTo(DistinctAttr id) const { return lookup(id); } private: /// Update all alias classes in `keysToUpdate` to additionally point to alias @@ -242,24 +159,6 @@ class PointsToSets : public dataflow::AbstractDenseLattice { /// in the lattice, not only the replacements described above. ChangeResult update(const AliasClassSet &keysToUpdate, const AliasClassSet &values, bool replace); - - ChangeResult joinPotentiallyMissing(DistinctAttr key, - const AliasClassSet &value); - - /// Indicates that alias classes not listed as keys in `pointsTo` point to - /// unknown alias set (when true) or an empty alias set (when false). 
- // TODO: consider also differentiating between pointing to known-empty vs. - // not-yet-computed. - // bool otherPointToUnknown = false; - - // missing from map always beings "undefined", "unknown"s are stored - // explicitly. - - /// Maps an identifier of an alias set to the set of alias sets its value may - /// belong to. When an identifier is not present in this map, it is considered - /// to point to either the unknown set or nothing, based on the value of - /// `otherPointToUnknown`. - DenseMap pointsTo; }; //===----------------------------------------------------------------------===// @@ -298,12 +197,9 @@ class PointsToPointerAnalysis // AliasClassLattice //===----------------------------------------------------------------------===// -class AliasClassLattice : public dataflow::AbstractSparseLattice { +class AliasClassLattice : public SparseSetLattice { public: - using AbstractSparseLattice::AbstractSparseLattice; - AliasClassLattice(Value value, AliasClassSet &&classes) - : dataflow::AbstractSparseLattice(value), - aliasClasses(std::move(classes)) {} + using SparseSetLattice::SparseSetLattice; void print(raw_ostream &os) const override; @@ -311,31 +207,15 @@ class AliasClassLattice : public dataflow::AbstractSparseLattice { ChangeResult join(const AbstractSparseLattice &other) override; - ChangeResult insert(const DenseSet &classes) { - return aliasClasses.insert(classes); - } - static AliasClassLattice single(Value point, DistinctAttr value) { return AliasClassLattice(point, AliasClassSet(value)); } - ChangeResult markUnknown() { return aliasClasses.markUnknown(); } - - // ChangeResult reset() { return aliasClasses.reset(); } - - /// We don't know anything about the aliasing of this value. 
- bool isUnknown() const { return aliasClasses.isUnknown(); } - - bool isUndefined() const { return aliasClasses.isUndefined(); } - const DenseSet &getAliasClasses() const { - return aliasClasses.getAliasClasses(); + return elements.getElements(); } - const AliasClassSet &getAliasClassesObject() const { return aliasClasses; } - -private: - AliasClassSet aliasClasses; + const AliasClassSet &getAliasClassesObject() const { return elements; } }; //===----------------------------------------------------------------------===// diff --git a/enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt b/enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt index 6e53ef50ce4a..7d92d9b357f3 100644 --- a/enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Analysis/CMakeLists.txt @@ -2,6 +2,7 @@ add_mlir_library(MLIREnzymeAnalysis ActivityAnalysis.cpp AliasAnalysis.cpp DataFlowActivityAnalysis.cpp + Lattice.cpp DEPENDS MLIRAutoDiffTypeInterfaceIncGen diff --git a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp index 9c540064a49a..2048a02918bd 100644 --- a/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/DataFlowActivityAnalysis.cpp @@ -490,7 +490,7 @@ std::optional getCopySource(Operation *op) { /// If the classes are undefined, the callback will not be called at all. 
void forEachAliasedAlloc(const AliasClassLattice *ptrAliasClass, function_ref forEachFn) { - (void)ptrAliasClass->getAliasClassesObject().foreachClass( + (void)ptrAliasClass->getAliasClassesObject().foreachElement( [&](DistinctAttr alloc, enzyme::AliasClassSet::State state) { if (state != enzyme::AliasClassSet::State::Undefined) forEachFn(alloc); @@ -636,7 +636,7 @@ class DenseForwardActivityAnalysis continue; auto *argAliasClasses = getOrCreateFor(block, arg); ChangeResult changed = - argAliasClasses->getAliasClassesObject().foreachClass( + argAliasClasses->getAliasClassesObject().foreachElement( [lattice](DistinctAttr argAliasClass, enzyme::AliasClassSet::State state) { if (state == enzyme::AliasClassSet::State::Undefined) @@ -687,7 +687,7 @@ class DenseBackwardActivityAnalysis } auto *argAliasClasses = getOrCreateFor(op, arg); ChangeResult changed = - argAliasClasses->getAliasClassesObject().foreachClass( + argAliasClasses->getAliasClassesObject().foreachElement( [before](DistinctAttr argAliasClass, enzyme::AliasClassSet::State state) { if (state == enzyme::AliasClassSet::State::Undefined) @@ -703,7 +703,7 @@ class DenseBackwardActivityAnalysis auto *retAliasClasses = getOrCreateFor(op, operand); ChangeResult changed = - retAliasClasses->getAliasClassesObject().foreachClass( + retAliasClasses->getAliasClassesObject().foreachElement( [before](DistinctAttr retAliasClass, enzyme::AliasClassSet::State state) { if (state == enzyme::AliasClassSet::State::Undefined) @@ -854,7 +854,7 @@ void printActivityAnalysisResults(const DataFlowSolver &solver, std::deque frontier; DenseSet visited; auto scheduleVisit = [&](const enzyme::AliasClassSet &aliasClasses) { - (void)aliasClasses.foreachClass( + (void)aliasClasses.foreachElement( [&](DistinctAttr neighbor, enzyme::AliasClassSet::State state) { assert(neighbor && "unhandled undefined/unknown case before visit"); diff --git a/enzyme/Enzyme/MLIR/Analysis/Lattice.cpp b/enzyme/Enzyme/MLIR/Analysis/Lattice.cpp new file mode 
100644
index 000000000000..4034c810d492
--- /dev/null
+++ b/enzyme/Enzyme/MLIR/Analysis/Lattice.cpp
@@ -0,0 +1,44 @@
+//===- Lattice.cpp - Implementation of common dataflow lattices -----------===//
+//
+// Enzyme Project
+//
+// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// If using this code in an academic setting, please cite the following:
+// @inproceedings{NEURIPS2020_9332c513,
+// author = {Moses, William and Churavy, Valentin},
+// booktitle = {Advances in Neural Information Processing Systems},
+// editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H.
+// Lin}, pages = {12472--12485}, publisher = {Curran Associates, Inc.}, title =
+// {Instead of Rewriting Foreign Code for Machine Learning, Automatically
+// Synthesize Fast Gradients}, url =
+// {https://proceedings.neurips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf},
+// volume = {33},
+// year = {2020}
+// }
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the implementation of reusable lattices in dataflow
+// analyses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "Lattice.h"
+
+using namespace mlir;
+
+bool enzyme::sortAttributes(Attribute a, Attribute b) {
+  auto distinctA = dyn_cast<DistinctAttr>(a);
+  auto distinctB = dyn_cast<DistinctAttr>(b);
+  if (distinctA && distinctB) {
+    auto strA = dyn_cast_if_present<StringAttr>(distinctA.getReferencedAttr());
+    auto strB = dyn_cast_if_present<StringAttr>(distinctB.getReferencedAttr());
+    if (strA && strB)
+      return strA.strref() < strB.strref();
+  }
+  // No names to compare: order by the attributes' storage pointers instead.
+  return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+}
diff --git a/enzyme/Enzyme/MLIR/Analysis/Lattice.h b/enzyme/Enzyme/MLIR/Analysis/Lattice.h
new file mode 100644
index 000000000000..a10d0b248972
--- /dev/null
+++ b/enzyme/Enzyme/MLIR/Analysis/Lattice.h
@@ -0,0 +1,364 @@
+//===- Lattice.h - Declaration of common dataflow lattices ----------------===//
+//
+// Enzyme Project
+//
+// Part of the Enzyme Project, under the Apache License v2.0 with LLVM
+// Exceptions. See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+// If using this code in an academic setting, please cite the following:
+// @inproceedings{NEURIPS2020_9332c513,
+// author = {Moses, William and Churavy, Valentin},
+// booktitle = {Advances in Neural Information Processing Systems},
+// editor = {H. Larochelle and M. Ranzato and R. Hadsell and M. F. Balcan and H.
+// Lin}, pages = {12472--12485}, publisher = {Curran Associates, Inc.}, title =
+// {Instead of Rewriting Foreign Code for Machine Learning, Automatically
+// Synthesize Fast Gradients}, url =
+// {https://proceedings.neurips.cc/paper/2020/file/9332c513ef44b682e9347822c2e457ac-Paper.pdf},
+// volume = {33},
+// year = {2020}
+// }
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains the declaration of reusable lattices in dataflow analyses.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H
+#define ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H
+
+#include "mlir/Analysis/DataFlow/DenseAnalysis.h"
+#include "mlir/Analysis/DataFlow/SparseAnalysis.h"
+#include "mlir/Analysis/DataFlowFramework.h"
+
+namespace mlir {
+namespace enzyme {
+
+//===----------------------------------------------------------------------===//
+// SetLattice
+//
+// A data structure representing a set of elements.
It may be undefined, meaning
+// the analysis has no information about it, or unknown, meaning the analysis
+// has conservatively assumed it could contain anything.
+//===----------------------------------------------------------------------===//
+
+template <typename ValueT> class SetLattice {
+public:
+  enum class State {
+    Undefined, ///< Has not been analyzed yet (lattice bottom).
+    Defined,   ///< Has specific elements.
+    Unknown    ///< Analyzed and may contain anything (lattice top).
+  };
+
+  SetLattice() : state(State::Undefined) {}
+
+  SetLattice(ValueT single) : state(State::Defined) { elements.insert(single); }
+
+  // TODO(zinenko): deprecate this and use a visitor instead.
+  DenseSet<ValueT> &getElements() {
+    assert(state == State::Defined);
+    return elements;
+  }
+
+  const DenseSet<ValueT> &getElements() const {
+    return const_cast<SetLattice<ValueT> *>(this)->getElements();
+  }
+
+  bool isUnknown() const { return state == State::Unknown; }
+  bool isUndefined() const { return state == State::Undefined; }
+
+  ChangeResult join(const SetLattice &other) {
+    if (isUnknown())
+      return ChangeResult::NoChange;
+    if (isUndefined() && other.isUndefined())
+      return ChangeResult::NoChange;
+    if (other.isUnknown()) {
+      // Also drops the stored elements so the set stays canonical.
+      return markUnknown();
+    }
+
+    ChangeResult result = updateStateToDefined();
+    return insert(other.elements) | result;
+  }
+
+  ChangeResult insert(const DenseSet<ValueT> &newElements) {
+    if (isUnknown())
+      return ChangeResult::NoChange;
+
+    size_t oldSize = elements.size();
+    elements.insert(newElements.begin(), newElements.end());
+    ChangeResult result = elements.size() == oldSize ? ChangeResult::NoChange
+                                                     : ChangeResult::Change;
+    return updateStateToDefined() | result;
+  }
+
+  ChangeResult markUnknown() {
+    if (isUnknown())
+      return ChangeResult::NoChange;
+
+    state = State::Unknown;
+    elements.clear();
+    return ChangeResult::Change;
+  }
+
+  /// Returns true if this set is in the canonical form, i.e. the explicit list
+  /// of elements is empty unless the state is `State::Defined` (a defined set
+  /// may itself be empty).
+  bool isCanonical() const {
+    return state == State::Defined || elements.empty();
+  }
+
+  /// Returns an instance of SetLattice known not to have any elements.
+  /// This is different from "undefined" and "unknown". The instance is *not* a
+  /// classical singleton.
+  static const SetLattice &getEmpty() {
+    static const SetLattice empty(State::Defined);
+    return empty;
+  }
+
+  /// Returns an instance of SetLattice in "undefined" state, i.e. without a set
+  /// of elements. This is different from empty set, which indicates that the
+  /// set is known not to contain any elements. The instance is *not* a
+  /// classical singleton, there are other ways of obtaining it.
+  static const SetLattice &getUndefined() { return undefinedSet; }
+
+  /// Returns an instance of SetLattice in the "unknown" state. The instance
+  /// is *not* a classical singleton, there are other ways of obtaining an
+  /// "unknown" set.
+  static const SetLattice &getUnknown() { return unknownSet; }
+
+  bool operator==(const SetLattice &other) const {
+    assert(isCanonical() && other.isCanonical());
+    return state == other.state && elements == other.elements; // set equality
+  }
+
+  void print(llvm::raw_ostream &os) const {
+    if (isUnknown()) {
+      os << "<unknown>";
+    } else if (isUndefined()) {
+      os << "<undefined>";
+    } else {
+      llvm::interleaveComma(elements, os << "{");
+      os << "}";
+    }
+  }
+
+  ChangeResult
+  foreachElement(function_ref<ChangeResult(ValueT, State)> callback) const {
+    if (state != State::Defined)
+      return callback(nullptr, state);
+
+    ChangeResult result = ChangeResult::NoChange;
+    for (ValueT element : elements)
+      result |= callback(element, state);
+    return result;
+  }
+
+private:
+  explicit SetLattice(State state) : state(state) {}
+
+  ChangeResult updateStateToDefined() {
+    assert(state != State::Unknown && "cannot go back from unknown state");
+    ChangeResult result = state == State::Undefined ? ChangeResult::Change
+                                                    : ChangeResult::NoChange;
+    state = State::Defined;
+    return result;
+  }
+
+  const static SetLattice unknownSet;
+  const static SetLattice undefinedSet;
+
+  DenseSet<ValueT> elements;
+  State state;
+};
+
+template <typename ValueT>
+const SetLattice<ValueT> SetLattice<ValueT>::unknownSet =
+    SetLattice<ValueT>(SetLattice<ValueT>::State::Unknown);
+
+template <typename ValueT>
+const SetLattice<ValueT> SetLattice<ValueT>::undefinedSet =
+    SetLattice<ValueT>(SetLattice<ValueT>::State::Undefined);
+
+/// Used when serializing to ensure a consistent order.
+bool sortAttributes(Attribute a, Attribute b);
+
+//===----------------------------------------------------------------------===//
+// SparseSetLattice
+//
+// An abstract lattice for sparse analyses that wraps a set lattice.
+//===----------------------------------------------------------------------===//
+
+template <typename ValueT>
+class SparseSetLattice : public dataflow::AbstractSparseLattice {
+public:
+  using AbstractSparseLattice::AbstractSparseLattice;
+  SparseSetLattice(Value value, SetLattice<ValueT> &&elements)
+      : dataflow::AbstractSparseLattice(value), elements(std::move(elements)) {}
+
+  Attribute serialize(MLIRContext *ctx) const { return serializeSetNaive(ctx); }
+
+  ChangeResult merge(const SetLattice<ValueT> &other) {
+    return elements.join(other);
+  }
+
+  ChangeResult insert(const DenseSet<ValueT> &newElements) {
+    return elements.insert(newElements);
+  }
+
+  ChangeResult markUnknown() { return elements.markUnknown(); }
+
+  bool isUnknown() const { return elements.isUnknown(); }
+
+  bool isUndefined() const { return elements.isUndefined(); }
+
+  const DenseSet<ValueT> &getElements() const { return elements.getElements(); }
+
+protected:
+  SetLattice<ValueT> elements;
+
+private:
+  Attribute serializeSetNaive(MLIRContext *ctx) const {
+    if (elements.isUndefined())
+      return StringAttr::get(ctx, "<undefined>");
+    if (elements.isUnknown())
+      return StringAttr::get(ctx, "<unknown>");
+    SmallVector<Attribute> elementsVec;
+    for (Attribute element : elements.getElements()) {
+      elementsVec.push_back(element);
+    }
+    llvm::sort(elementsVec, sortAttributes);
+    return ArrayAttr::get(ctx, elementsVec);
+  }
+};
+
+//===----------------------------------------------------------------------===// +// MapOfSetsLattice +//===----------------------------------------------------------------------===// + +template +class MapOfSetsLattice : public dataflow::AbstractDenseLattice { +public: + using AbstractDenseLattice::AbstractDenseLattice; + + Attribute serialize(MLIRContext *ctx) const { + return serializeMapOfSetsNaive(ctx); + } + + ChangeResult join(const AbstractDenseLattice &other) { + const auto &rhs = + static_cast &>(other); + llvm::SmallDenseSet keys; + auto lhsRange = llvm::make_first_range(map); + auto rhsRange = llvm::make_first_range(rhs.map); + keys.insert(lhsRange.begin(), lhsRange.end()); + keys.insert(rhsRange.begin(), rhsRange.end()); + + ChangeResult result = ChangeResult::NoChange; + for (DistinctAttr key : keys) { + auto lhsIt = map.find(key); + auto rhsIt = rhs.map.find(key); + assert(lhsIt != map.end() || rhsIt != rhs.map.end()); + + // If present in both, join. + if (lhsIt != map.end() && rhsIt != rhs.map.end()) { + result |= lhsIt->getSecond().join(rhsIt->getSecond()); + continue; + } + + // Copy from RHS if available only there. + if (lhsIt == map.end()) { + map.try_emplace(rhsIt->getFirst(), rhsIt->getSecond()); + result = ChangeResult::Change; + } + + // Do nothing if available only in LHS. + } + return result; + } + + /// Map all keys to all values.
+ ChangeResult insert(const SetLattice &keysToUpdate, + const SetLattice &values) { + if (keysToUpdate.isUnknown()) + return markAllUnknown(); + + if (keysToUpdate.isUndefined()) + return ChangeResult::NoChange; + + return keysToUpdate.foreachElement( + [&](DistinctAttr key, typename SetLattice::State state) { + assert(state == SetLattice::State::Defined && + "unknown must have been handled above"); + return joinPotentiallyMissing(key, values); + }); + } + + ChangeResult markAllUnknown() { + ChangeResult result = ChangeResult::NoChange; + for (auto &it : map) + result |= it.getSecond().join(SetLattice::getUnknown()); + return result; + } + + const SetLattice &lookup(KeyT key) const { + auto it = map.find(key); + if (it == map.end()) + return SetLattice::getUndefined(); + return it->getSecond(); + } + +protected: + ChangeResult joinPotentiallyMissing(KeyT key, + const SetLattice &value) { + // Don't store explicitly undefined values in the mapping, keys absent from + // the mapping are treated as implicitly undefined. + if (value.isUndefined()) + return ChangeResult::NoChange; + + bool inserted; + decltype(map.begin()) iterator; + std::tie(iterator, inserted) = map.try_emplace(key, value); + if (!inserted) + return iterator->second.join(value); + return ChangeResult::Change; + } + + /// Maps a key to a set of values. When a key is not present in this map, it + /// is considered to map to an uninitialized set. 
+ DenseMap> map; + +private: + Attribute serializeMapOfSetsNaive(MLIRContext *ctx) const { + SmallVector pointsToArray; + + for (const auto &[srcClass, destClasses] : map) { + SmallVector pair = {srcClass}; + SmallVector aliasClasses; + if (destClasses.isUnknown()) { + aliasClasses.push_back(StringAttr::get(ctx, "unknown")); + } else if (destClasses.isUndefined()) { + aliasClasses.push_back(StringAttr::get(ctx, "undefined")); + } else { + for (const Attribute &destClass : destClasses.getElements()) { + aliasClasses.push_back(destClass); + } + llvm::sort(aliasClasses, sortAttributes); + } + pair.push_back(ArrayAttr::get(ctx, aliasClasses)); + pointsToArray.push_back(ArrayAttr::get(ctx, pair)); + } + llvm::sort(pointsToArray, [&](Attribute a, Attribute b) { + auto arrA = cast(a); + auto arrB = cast(b); + return sortAttributes(arrA[0], arrB[0]); + }); + return ArrayAttr::get(ctx, pointsToArray); + } +}; + +} // namespace enzyme +} // namespace mlir + +#endif // ENZYME_MLIR_ANALYSIS_DATAFLOW_LATTICE_H diff --git a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp index c2680ae6db72..0cfc2eda33da 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintAliasAnalysis.cpp @@ -91,7 +91,7 @@ struct PrintAliasAnalysisPass continue; // TODO(zinenko): this has been overriding the argument... // Use an array attr instead (will break syntactic tests). - (void)state->getAliasClassesObject().foreachClass( + (void)state->getAliasClassesObject().foreachElement( [&](DistinctAttr aliasClass, enzyme::AliasClassSet::State state) { if (state == enzyme::AliasClassSet::State::Undefined) funcOp.setArgAttr(