Skip to content

Commit

Permalink
Implement annotating ops with relative sources and sinks
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Feb 27, 2024
1 parent 53410f7 commit 8059822
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 68 deletions.
206 changes: 154 additions & 52 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,8 @@ void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func,
}
} // namespace

void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
void enzyme::runActivityAnnotations(
FunctionOpInterface callee, const ActivityPrinterConfig &activityConfig) {
SymbolTableCollection symbolTable;
SmallVector<CallableOpInterface> sorted;
reverseToposortCallgraph(callee, &symbolTable, sorted);
Expand All @@ -862,9 +863,9 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
continue;
auto funcOp = cast<FunctionOpInterface>(node.getOperation());
os << "[ata] processing function @" << funcOp.getName() << "\n";
DataFlowConfig config;
config.setInterprocedural(false);
DataFlowSolver solver(config);
DataFlowConfig dataFlowConfig;
dataFlowConfig.setInterprocedural(false);
DataFlowSolver solver(dataFlowConfig);
SymbolTableCollection symbolTable;

solver.load<dataflow::SparseConstantPropagation>();
Expand Down Expand Up @@ -910,12 +911,7 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
}
}

// for (BlockArgument arg : node.getCallableRegion()->getArguments()) {
// auto *backwardState =
// solver.getOrCreateState<enzyme::BackwardOriginsLattice>(arg);
// os << "[debug] backward state for arg " << arg.getArgNumber() << ": "
// << *backwardState << "\n";
// }
// Sparse alias annotations
SmallVector<Attribute> aliasAttributes(returnAliasClasses.size());
llvm::transform(returnAliasClasses, aliasAttributes.begin(),
[&](enzyme::AliasClassLattice lattice) {
Expand All @@ -924,35 +920,42 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
node->setAttr(EnzymeDialect::getAliasSummaryAttrName(),
ArrayAttr::get(node.getContext(), aliasAttributes));

// Points-to-pointer annotations
node->setAttr(pointerSummaryName, p2sets.serialize(node.getContext()));
os << "[ata] p2p summary:\n";
if (node->getAttrOfType<ArrayAttr>(pointerSummaryName).size() == 0) {
os << " <empty>\n";
}
for (ArrayAttr pair : node->getAttrOfType<ArrayAttr>(pointerSummaryName)
.getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " -> " << pair[1] << "\n";
if (activityConfig.verbose) {
os << "[ata] p2p summary:\n";
if (node->getAttrOfType<ArrayAttr>(pointerSummaryName).size() == 0) {
os << " <empty>\n";
}
for (ArrayAttr pair : node->getAttrOfType<ArrayAttr>(pointerSummaryName)
.getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " -> " << pair[1] << "\n";
}
}

node->setAttr(EnzymeDialect::getDenseActivityAnnotationAttrName(),
forwardOriginsMap.serialize(node.getContext()));
os << "[ata] forward value origins:\n";
for (ArrayAttr pair :
node->getAttrOfType<ArrayAttr>(
EnzymeDialect::getDenseActivityAnnotationAttrName())
.getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " originates from " << pair[1] << "\n";
if (activityConfig.verbose) {
os << "[ata] forward value origins:\n";
for (ArrayAttr pair :
node->getAttrOfType<ArrayAttr>(
EnzymeDialect::getDenseActivityAnnotationAttrName())
.getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " originates from " << pair[1] << "\n";
}
}

auto *backwardOriginsMap =
solver.getOrCreateState<enzyme::BackwardOriginsMap>(
&node.getCallableRegion()->front().front());
Attribute backwardOrigins =
backwardOriginsMap->serialize(node.getContext());
os << "[ata] backward value origins:\n";
for (ArrayAttr pair :
cast<ArrayAttr>(backwardOrigins).getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " goes to " << pair[1] << "\n";
if (activityConfig.verbose) {
os << "[ata] backward value origins:\n";
for (ArrayAttr pair :
cast<ArrayAttr>(backwardOrigins).getAsRange<ArrayAttr>()) {
os << " " << pair[0] << " goes to " << pair[1] << "\n";
}
}

// Serialize return origins
Expand All @@ -967,39 +970,138 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) {
node->setAttr(
EnzymeDialect::getSparseActivityAnnotationAttrName(),
ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins));
os << "[ata] return origins: "
<< node->getAttr(EnzymeDialect::getSparseActivityAnnotationAttrName())
<< "\n";
if (activityConfig.verbose) {
os << "[ata] return origins: "
<< node->getAttr(EnzymeDialect::getSparseActivityAnnotationAttrName())
<< "\n";
}

node.getCallableRegion()->walk([&](Operation *op) {
if (op->hasAttr("tag")) {
for (OpResult result : op->getResults()) {
auto joinActiveDataState =
[&](Value value,
std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
auto *sources = solver.getOrCreateState<ForwardOriginsLattice>(value);
auto *sinks = solver.getOrCreateState<BackwardOriginsLattice>(value);
(void)out.first.join(*sources);
(void)out.second.meet(*sinks);
};

// For every alias class visited by traversePointsToSets starting from
// `aliasClasses` (using the points-to sets in `p2sets`), merges the origins
// recorded for that class in the forward and backward origin maps into
// `out.first` and `out.second` respectively.
auto joinActivePointerState =
    [&](const AliasClassSet &aliasClasses,
        std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
      auto mergeOriginsOfClass = [&](DistinctAttr aliasClass) {
        (void)out.first.merge(forwardOriginsMap.getOrigins(aliasClass));
        (void)out.second.merge(backwardOriginsMap->getOrigins(aliasClass));
      };
      traversePointsToSets(aliasClasses, p2sets, mergeOriginsOfClass);
    };

auto joinActiveValueState =
[&](Value value,
std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
auto *aliasClasses =
solver.getOrCreateState<enzyme::AliasClassLattice>(result);
solver.getOrCreateState<AliasClassLattice>(value);
if (aliasClasses->isUndefined()) {
// Not a pointer, check the sources and sinks from the sparse state
auto *sources =
solver.getOrCreateState<enzyme::ForwardOriginsLattice>(result);
auto *sinks =
solver.getOrCreateState<enzyme::BackwardOriginsLattice>(result);
os << op->getAttr("tag") << "(#" << result.getResultNumber()
<< ")\n"
<< " sources: " << sources->serialize(ctx) << "\n"
<< " sinks: " << sinks->serialize(ctx) << "\n";
joinActiveDataState(value, out);
} else {
// Is a pointer, see the origins of whatever it points to
ForwardOriginsLattice sources(result, ValueOriginSet());
BackwardOriginsLattice sinks(result, ValueOriginSet());
traversePointsToSets(
aliasClasses->getAliasClassesObject(), p2sets,
[&](DistinctAttr aliasClass) {
(void)sources.merge(forwardOriginsMap.getOrigins(aliasClass));
(void)sinks.merge(backwardOriginsMap->getOrigins(aliasClass));
});
joinActivePointerState(aliasClasses->getAliasClassesObject(), out);
}
};

// Attaches activity attributes to `op`.
//
// Per result value, one of:
//   - "enzyme.activeval"   (unit): sources or sinks are unknown.
//   - "enzyme.constantval" (unit): sources or sinks are undefined.
//   - "enzyme.valsrc" / "enzyme.valsink": the serialized source and sink sets.
//
// For the operation itself, one of:
//   - "enzyme.activeop"   (unit): sources or sinks are unknown.
//   - "enzyme.constantop" (unit): sources or sinks are undefined.
//   - "enzyme.opsrc" / "enzyme.opsink": the serialized source and sink sets.
//
// Only operations with at most one result and no regions are supported
// (asserted below).
auto annotateActivity = [&](Operation *op) {
  assert(op->getNumResults() < 2 && op->getNumRegions() == 0 &&
         "annotation only supports the LLVM dialect");
  auto unitAttr = UnitAttr::get(ctx);

  // Check activity of values.
  for (OpResult result : op->getResults()) {
    std::pair<ForwardOriginsLattice, BackwardOriginsLattice>
        activityAttributes({result, ValueOriginSet()},
                           {result, ValueOriginSet()});
    joinActiveValueState(result, activityAttributes);
    const auto &sources = activityAttributes.first;
    const auto &sinks = activityAttributes.second;
    // Three possible outcomes: unknown on either side means always active,
    // undefined on either side means always constant, otherwise the activity
    // is relative to the recorded sources and sinks.
    if (sources.isUnknown() || sinks.isUnknown()) {
      // Always active
      op->setAttr("enzyme.activeval", unitAttr);
    } else if (sources.isUndefined() || sinks.isUndefined()) {
      // Always constant
      op->setAttr("enzyme.constantval", unitAttr);
    } else {
      // Conditionally active depending on the activity of sources and sinks
      op->setAttr("enzyme.valsrc", sources.serialize(ctx));
      op->setAttr("enzyme.valsink", sinks.serialize(ctx));
    }
  }

  // Check activity of operation.
  StringRef opSourceAttrName = "enzyme.opsrc";
  StringRef opSinkAttrName = "enzyme.opsink";
  // Single accumulator shared by every branch below. It must not be
  // re-declared inside a branch: a shadowing local would be discarded at the
  // end of that branch and the op would always be classified from the
  // untouched outer state.
  std::pair<ForwardOriginsLattice, BackwardOriginsLattice> opAttributes(
      {nullptr, ValueOriginSet()}, {nullptr, ValueOriginSet()});
  if (isPure(op)) {
    // A pure operation can only propagate data via its results
    for (OpResult result : op->getResults())
      joinActiveDataState(result, opAttributes);
  } else {
    // We need a special case because stores of active pointers don't fit
    // the definition but are active instructions
    if (auto storeOp = dyn_cast<LLVM::StoreOp>(op)) {
      auto *storedClass =
          solver.getOrCreateState<AliasClassLattice>(storeOp.getValue());
      joinActivePointerState(storedClass->getAliasClassesObject(),
                             opAttributes);
    } else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
      // TODO: can we just use the summary?
      // If a call op receives or returns any active pointers or data,
      // consider it active.
      for (Value operand : callOp.getArgOperands())
        joinActiveValueState(operand, opAttributes);
      for (OpResult result : callOp->getResults())
        joinActiveValueState(result, opAttributes);
    }

    // Default: the op is active iff any of its operands or results are
    // active data.
    for (Value operand : op->getOperands())
      joinActiveDataState(operand, opAttributes);
    for (OpResult result : op->getResults())
      joinActiveDataState(result, opAttributes);
  }

  const auto &opSources = opAttributes.first;
  const auto &opSinks = opAttributes.second;
  if (opSources.isUnknown() || opSinks.isUnknown()) {
    op->setAttr("enzyme.activeop", unitAttr);
  } else if (opSources.isUndefined() || opSinks.isUndefined()) {
    op->setAttr("enzyme.constantop", unitAttr);
  } else {
    op->setAttr(opSourceAttrName, opSources.serialize(ctx));
    op->setAttr(opSinkAttrName, opSinks.serialize(ctx));
  }
};

node.getCallableRegion()->walk([&](Operation *op) {
if (activityConfig.annotate)
annotateActivity(op);
if (activityConfig.verbose) {
if (op->hasAttr("tag")) {
for (OpResult result : op->getResults()) {
std::pair<ForwardOriginsLattice, BackwardOriginsLattice>
activityAttributes({result, ValueOriginSet()},
{result, ValueOriginSet()});
joinActiveValueState(result, activityAttributes);
os << op->getAttr("tag") << "(#" << result.getResultNumber()
<< ")\n"
<< " sources: " << sources.serialize(ctx) << "\n"
<< " sinks: " << sinks.serialize(ctx) << "\n";
<< " sources: " << activityAttributes.first.serialize(ctx)
<< "\n"
<< " sinks: " << activityAttributes.second.serialize(ctx)
<< "\n";
}
}
}
Expand Down
17 changes: 16 additions & 1 deletion enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,22 @@ class DenseBackwardActivityAnnotationAnalysis
const BackwardOriginsMap &after, BackwardOriginsMap *before);
};

void runActivityAnnotations(FunctionOpInterface callee);
/// Options accepted by runActivityAnnotations. Every flag defaults to false,
/// so a default-constructed config requests no extra output and no
/// per-operation IR annotation.
class ActivityPrinterConfig {
public:
ActivityPrinterConfig() = default;

/// Output extra information for debugging
bool verbose = false;
/// Annotate the IR with activity information for every operation. Currently
/// only supports the LLVM dialect.
bool annotate = false;
/// Infer the starting argument state from an __enzyme_autodiff call.
// NOTE(review): no reader of this flag is visible alongside this declaration;
// confirm whether it is consumed anywhere yet.
bool inferFromAutodiff = false;
};

/// Runs the activity annotation analysis for `callee`. `config` selects
/// optional debug output and per-operation IR annotation; when omitted, a
/// default-constructed ActivityPrinterConfig is used.
void runActivityAnnotations(
FunctionOpInterface callee,
const ActivityPrinterConfig &config = ActivityPrinterConfig());

} // namespace enzyme
} // namespace mlir
Expand Down
7 changes: 5 additions & 2 deletions enzyme/Enzyme/MLIR/Analysis/Lattice.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice {
SparseSetLattice(Value value, SetLattice<ValueT> &&elements)
: dataflow::AbstractSparseLattice(value), elements(std::move(elements)) {}

Attribute serialize(MLIRContext *ctx) { return serializeSetNaive(ctx); }
Attribute serialize(MLIRContext *ctx) const { return serializeSetNaive(ctx); }

ChangeResult merge(const SetLattice<ValueT> &other) {
return elements.join(other);
Expand All @@ -219,7 +219,7 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice {
SetLattice<ValueT> elements;

private:
Attribute serializeSetNaive(MLIRContext *ctx) {
Attribute serializeSetNaive(MLIRContext *ctx) const {
if (elements.isUndefined())
return StringAttr::get(ctx, "<undefined>");
if (elements.isUnknown())
Expand All @@ -235,6 +235,9 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice {

//===----------------------------------------------------------------------===//
// MapOfSetsLattice
//
// A lattice for use in dense analyses that maps keys (usually static memory
// locations) to sets of values.
//===----------------------------------------------------------------------===//

template <typename KeyT, typename ElementT>
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/MLIR/Passes/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> {
/*description=*/"Annotate every operation and value with its activity"
>,
Option<
/*C++ variable name=*/"useAnnotations",
/*CLI argument=*/"use-annotations",
/*C++ variable name=*/"relative",
/*CLI argument=*/"relative",
/*type=*/"bool",
/*default=*/"false",
/*description=*/"Use bottom-up activity annotations"
/*description=*/"Use relative bottom-up activity analysis"
>,
Option<
/*C++ variable name=*/"inactiveArgs",
Expand Down
16 changes: 10 additions & 6 deletions enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
using namespace mlir;

namespace {
using llvm::errs;

struct PrintActivityAnalysisPass
: public enzyme::PrintActivityAnalysisPassBase<PrintActivityAnalysisPass> {
Expand Down Expand Up @@ -118,6 +117,11 @@ struct PrintActivityAnalysisPass
}

void runOnOperation() override {
enzyme::ActivityPrinterConfig config;
config.annotate = annotate;
config.inferFromAutodiff = false;
config.verbose = verbose;

auto moduleOp = cast<ModuleOp>(getOperation());

if (annotate) {
Expand All @@ -142,8 +146,8 @@ struct PrintActivityAnalysisPass
auto callee =
cast<FunctionOpInterface>(moduleOp.lookupSymbol(calleeAttr));

if (useAnnotations) {
enzyme::runActivityAnnotations(callee);
if (relative) {
enzyme::runActivityAnnotations(callee, config);
} else {
SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
resultActivities{callee.getNumResults()};
Expand All @@ -161,12 +165,12 @@ struct PrintActivityAnalysisPass
}

if (funcsToAnalyze.empty()) {
moduleOp.walk([this](FunctionOpInterface callee) {
moduleOp.walk([this, &config](FunctionOpInterface callee) {
if (callee.isExternal() || callee.isPrivate())
return;

if (useAnnotations) {
enzyme::runActivityAnnotations(callee);
if (relative) {
enzyme::runActivityAnnotations(callee, config);
} else {

SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s
// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s

// CHECK-LABEL: processing function @sparse_callee
// CHECK: "fadd"(#0)
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/ActivityAnalysis/Summaries/bude.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s
// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s

#alias_scope_domain = #llvm.alias_scope_domain<id = distinct[0]<>, description = "wrap">
#loop_unroll = #llvm.loop_unroll<disable = true>
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/AliasAnalysis/Summaries/leaf_nodes.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s
// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s

#alias_scope_domain = #llvm.alias_scope_domain<id = distinct[0]<>, description = "mat_mult">
#alias_scope_domain1 = #llvm.alias_scope_domain<id = distinct[1]<>, description = "mat_mult">
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/MLIR/AliasAnalysis/Summaries/parent_nodes.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s
// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s

#alias_scope_domain = #llvm.alias_scope_domain<id = distinct[0]<>, description = "mat_mult">
#alias_scope_domain1 = #llvm.alias_scope_domain<id = distinct[1]<>, description = "mat_mult">
Expand Down

0 comments on commit 8059822

Please sign in to comment.