Skip to content

Commit

Permalink
Sparse alias tests, renamed attributes and moved them to Enzyme dialect
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Feb 27, 2024
1 parent 01a458e commit 15334ad
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 68 deletions.
98 changes: 59 additions & 39 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "ActivityAnnotations.h"
#include "AliasAnalysis.h"
#include "Dialect/Dialect.h"
#include "Dialect/Ops.h"
#include "Lattice.h"

Expand All @@ -17,10 +18,6 @@

using namespace mlir;

static StringRef getActivityAnnotationAttrName() { return "activedeps"; }
static StringRef getPointerSummaryAttrName() { return "p2psummary"; }
static StringRef getReturnOriginsAttrName() { return "returnorigins"; }

template <typename ValueT>
void printSetLattice(const enzyme::SparseSetLattice<ValueT> &setLattice,
raw_ostream &os) {
Expand Down Expand Up @@ -184,8 +181,8 @@ void enzyme::ForwardActivityAnnotationAnalysis::visitExternalCall(

if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto returnOriginsAttr =
callee->getAttrOfType<ArrayAttr>(getReturnOriginsAttrName())) {
if (auto returnOriginsAttr = callee->getAttrOfType<ArrayAttr>(
EnzymeDialect::getSparseActivityAnnotationAttrName())) {
SmallVector<ValueOriginSet> returnOrigins;
deserializeReturnOrigins(returnOriginsAttr, returnOrigins);
return processCallToSummarizedFunc(call, returnOrigins, operands,
Expand Down Expand Up @@ -286,8 +283,8 @@ void enzyme::BackwardActivityAnnotationAnalysis::visitExternalCall(

if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto returnOriginsAttr =
callee->getAttrOfType<ArrayAttr>(getReturnOriginsAttrName())) {
if (auto returnOriginsAttr = callee->getAttrOfType<ArrayAttr>(
EnzymeDialect::getSparseActivityAnnotationAttrName())) {
SmallVector<ValueOriginSet> returnOrigins;
deserializeReturnOrigins(returnOriginsAttr, returnOrigins);
return processCallToSummarizedFunc(call, returnOrigins, operands,
Expand Down Expand Up @@ -513,7 +510,7 @@ void enzyme::DenseActivityAnnotationAnalysis::visitCallControlFlowTransfer(
if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
getActivityAnnotationAttrName())) {
EnzymeDialect::getDenseActivityAnnotationAttrName())) {
DenseMap<DistinctAttr, ValueOriginSet> summary;
deserializePointsTo(summaryAttr, summary);
return processCallToSummarizedFunc(call, summary, before, after);
Expand Down Expand Up @@ -634,7 +631,7 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis::
if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
getActivityAnnotationAttrName())) {
EnzymeDialect::getDenseActivityAnnotationAttrName())) {
DenseMap<DistinctAttr, ValueOriginSet> summary;
deserializePointsTo(summaryAttr, summary);
return processCallToSummarizedFunc(call, summary, after, before);
Expand Down Expand Up @@ -859,8 +856,9 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
reverseToposortCallgraph(callee, &symbolTable, sorted);
raw_ostream &os = llvm::outs();

StringRef pointerSummaryName = EnzymeDialect::getPointerSummaryAttrName();
for (CallableOpInterface node : sorted) {
if (!node.getCallableRegion() || node->hasAttr(getPointerSummaryAttrName()))
if (!node.getCallableRegion() || node->hasAttr(pointerSummaryName))
continue;
auto funcOp = cast<FunctionOpInterface>(node.getOperation());
os << "[ata] processing function @" << funcOp.getName() << "\n";
Expand All @@ -887,7 +885,7 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {

// Create the overall summary by joining sets at all return sites.
enzyme::PointsToSets p2sets(nullptr);
enzyme::ForwardOriginsMap voMap(nullptr);
enzyme::ForwardOriginsMap forwardOriginsMap(nullptr);
size_t numResults = node.getResultTypes().size();
SmallVector<enzyme::ForwardOriginsLattice> returnOperandOrigins(
numResults, ForwardOriginsLattice(nullptr));
Expand All @@ -900,7 +898,7 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
auto *returnOrigins =
solver.lookupState<enzyme::ForwardOriginsMap>(&op);
if (returnOrigins)
(void)voMap.join(*returnOrigins);
(void)forwardOriginsMap.join(*returnOrigins);

for (OpOperand &operand : op.getOpOperands()) {
(void)returnAliasClasses[operand.getOperandNumber()].join(
Expand All @@ -918,37 +916,39 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
// os << "[debug] backward state for arg " << arg.getArgNumber() << ": "
// << *backwardState << "\n";
// }
SmallVector<Attribute> aliasAttributes(returnAliasClasses.size());
llvm::transform(returnAliasClasses, aliasAttributes.begin(),
[&](enzyme::AliasClassLattice lattice) {
return lattice.serialize(node.getContext());
});
node->setAttr(EnzymeDialect::getAliasSummaryAttrName(),
ArrayAttr::get(node.getContext(), aliasAttributes));

for (auto lattice : returnAliasClasses) {
os << "[debug] return alias class: " << lattice << "\n";
}

node->setAttr(getPointerSummaryAttrName(),
p2sets.serialize(node.getContext()));
node->setAttr(pointerSummaryName, p2sets.serialize(node.getContext()));
os << "[ata] p2p summary:\n";
if (node->getAttrOfType<ArrayAttr>(getPointerSummaryAttrName()).size() ==
0) {
if (node->getAttrOfType<ArrayAttr>(pointerSummaryName).size() == 0) {
os << " <empty>\n";
}
for (ArrayAttr pair :
node->getAttrOfType<ArrayAttr>(getPointerSummaryAttrName())
.getAsRange<ArrayAttr>()) {
for (ArrayAttr pair : node->getAttrOfType<ArrayAttr>(pointerSummaryName)
.getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " -> " << pair[1] << "\n";
}

node->setAttr(getActivityAnnotationAttrName(),
voMap.serialize(node.getContext()));
node->setAttr(EnzymeDialect::getDenseActivityAnnotationAttrName(),
forwardOriginsMap.serialize(node.getContext()));
os << "[ata] forward value origins:\n";
for (ArrayAttr pair :
node->getAttrOfType<ArrayAttr>(getActivityAnnotationAttrName())
node->getAttrOfType<ArrayAttr>(
EnzymeDialect::getDenseActivityAnnotationAttrName())
.getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " originates from " << pair[1] << "\n";
}

auto *backwardOriginMap =
auto *backwardOriginsMap =
solver.getOrCreateState<enzyme::BackwardOriginsMap>(
&node.getCallableRegion()->front().front());
Attribute backwardOrigins = backwardOriginMap->serialize(node.getContext());
Attribute backwardOrigins =
backwardOriginsMap->serialize(node.getContext());
os << "[ata] backward value origins:\n";
for (ArrayAttr pair :
cast<ArrayAttr>(backwardOrigins).getAsRange<ArrayAttr>()) {
Expand All @@ -965,22 +965,42 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
return lattice.serialize(ctx);
});
node->setAttr(
getReturnOriginsAttrName(),
EnzymeDialect::getSparseActivityAnnotationAttrName(),
ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins));
os << "[ata] return origins: " << node->getAttr(getReturnOriginsAttrName())
os << "[ata] return origins: "
<< node->getAttr(EnzymeDialect::getSparseActivityAnnotationAttrName())
<< "\n";

node.getCallableRegion()->walk([&](Operation *op) {
if (op->hasAttr("tag")) {
for (OpResult result : op->getResults()) {
auto *sources =
solver.getOrCreateState<enzyme::ForwardOriginsLattice>(result);
auto *sinks =
solver.getOrCreateState<enzyme::BackwardOriginsLattice>(result);
os << op->getAttr("tag") << "(#" << result.getResultNumber()
<< ") sources:\n"
<< *sources << "sinks:\n"
<< *sinks << "\n";
auto *aliasClasses =
solver.getOrCreateState<enzyme::AliasClassLattice>(result);
if (aliasClasses->isUndefined()) {
// Not a pointer, check the sources and sinks from the sparse state
auto *sources =
solver.getOrCreateState<enzyme::ForwardOriginsLattice>(result);
auto *sinks =
solver.getOrCreateState<enzyme::BackwardOriginsLattice>(result);
os << op->getAttr("tag") << "(#" << result.getResultNumber()
<< ")\n"
<< " sources: " << sources->serialize(ctx) << "\n"
<< " sinks: " << sinks->serialize(ctx) << "\n";
} else {
// Is a pointer, see the origins of whatever it points to
ForwardOriginsLattice sources(result, ValueOriginSet());
BackwardOriginsLattice sinks(result, ValueOriginSet());
traversePointsToSets(
aliasClasses->getAliasClassesObject(), p2sets,
[&](DistinctAttr aliasClass) {
(void)sources.merge(forwardOriginsMap.getOrigins(aliasClass));
(void)sinks.merge(backwardOriginsMap->getOrigins(aliasClass));
});
os << op->getAttr("tag") << "(#" << result.getResultNumber()
<< ")\n"
<< " sources: " << sources.serialize(ctx) << "\n"
<< " sinks: " << sinks.serialize(ctx) << "\n";
}
}
}
});
Expand Down
62 changes: 61 additions & 1 deletion enzyme/Enzyme/MLIR/Analysis/AliasAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
//
//===----------------------------------------------------------------------===//
#include "AliasAnalysis.h"
#include "Dialect/Dialect.h"
#include "Dialect/Ops.h"

#include "mlir/Analysis/AliasAnalysis.h"
Expand Down Expand Up @@ -629,7 +630,8 @@ void enzyme::PointsToPointerAnalysis::visitCallControlFlowTransfer(
// into pointers that are non-arguments.
if (auto callee = SymbolTable::lookupNearestSymbolFrom<FunctionOpInterface>(
call, symbol.getLeafReference())) {
if (auto summaryAttr = callee->getAttrOfType<ArrayAttr>("p2psummary")) {
if (auto summaryAttr = callee->getAttrOfType<ArrayAttr>(
EnzymeDialect::getPointerSummaryAttrName())) {
DenseMap<DistinctAttr, AliasClassSet> summary;
deserializePointsTo(summaryAttr, summary);
return processCallToSummarizedFunc(call, summary, after);
Expand Down Expand Up @@ -1127,6 +1129,28 @@ void enzyme::AliasAnalysis::visitOperation(
}
}

static void
deserializeAliasSummary(ArrayAttr summary,
SmallVectorImpl<enzyme::AliasClassSet> &out) {
for (Attribute element : summary) {
if (auto strAttr = dyn_cast<StringAttr>(element)) {
if (strAttr.getValue() == "<unknown>") {
out.push_back(enzyme::AliasClassSet::getUnknown());
} else {
assert(strAttr.getValue() == "<undefined>");
out.push_back(enzyme::AliasClassSet::getUndefined());
}
} else {
enzyme::AliasClassSet aliasClasses;
for (DistinctAttr aliasClass :
cast<ArrayAttr>(element).getAsRange<DistinctAttr>()) {
(void)aliasClasses.insert({aliasClass});
}
out.push_back(aliasClasses);
}
}
}

void enzyme::AliasAnalysis::visitExternalCall(
CallOpInterface call, ArrayRef<const AliasClassLattice *> operands,
ArrayRef<AliasClassLattice *> results) {
Expand Down Expand Up @@ -1155,6 +1179,42 @@ void enzyme::AliasAnalysis::visitExternalCall(
if (!callee)
return markResultsUnknown();

if (auto aliasSummaryAttr = callee->getAttrOfType<ArrayAttr>(
EnzymeDialect::getAliasSummaryAttrName())) {
// The summary tells us what operands may alias with what results
SmallVector<AliasClassSet> aliasSummary;
deserializeAliasSummary(aliasSummaryAttr, aliasSummary);

assert(results.size() == aliasSummary.size());
for (auto &&[i, resultSummary] : llvm::enumerate(aliasSummary)) {
// Would be nice to zip over results and aliasSummary, but requires
// capture of structured binding (may require a newer clang version)
AliasClassLattice *result = results[i];
ChangeResult changed = ChangeResult::NoChange;
if (resultSummary.isUndefined())
continue;
if (resultSummary.isUnknown())
changed |= result->markUnknown();
else {
changed |= resultSummary.foreachElement(
[&](DistinctAttr aliasClass, AliasClassSet::State state) {
assert(state == AliasClassSet::State::Defined);
if (auto pseudoClass = dyn_cast<PseudoAliasClassAttr>(
aliasClass.getReferencedAttr())) {
assert(
pseudoClass.getDepth() == 0 &&
"sparse alias summaries for depth > 0 not yet implemented");
return result->join(*operands[pseudoClass.getArgNumber()]);
} else {
return result->insert({aliasClass});
}
});
}
propagateIfChanged(result, changed);
}
return;
}

// Collect alias classes that can be read through the arguments.
std::optional<LLVM::ModRefInfo> argModRef = getFunctionArgModRef(callee);
std::optional<LLVM::ModRefInfo> otherModRef = getFunctionOtherModRef(callee);
Expand Down
8 changes: 8 additions & 0 deletions enzyme/Enzyme/MLIR/Dialect/Dialect.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ def Enzyme_Dialect : Dialect {
let cppNamespace = "::mlir::enzyme";
let useDefaultAttributePrinterParser = 1;
let useDefaultTypePrinterParser = 1;

let extraClassDeclaration = [{
/// Names of analysis summary attributes
static StringRef getPointerSummaryAttrName() { return "enzyme.p2p"; }
static StringRef getAliasSummaryAttrName() { return "enzyme.alias"; }
static StringRef getDenseActivityAnnotationAttrName() { return "enzyme.denseactive"; }
static StringRef getSparseActivityAnnotationAttrName() { return "enzyme.sparseactive"; }
}];
}

//===----------------------------------------------------------------------===//
Expand Down
Loading

0 comments on commit 15334ad

Please sign in to comment.