diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp index d0fb9e339f46..ad45d890893a 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp @@ -850,7 +850,8 @@ void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func, } } // namespace -void enzyme::runActivityAnnotations(FunctionOpInterface callee) { +void enzyme::runActivityAnnotations( + FunctionOpInterface callee, const ActivityPrinterConfig &activityConfig) { SymbolTableCollection symbolTable; SmallVector sorted; reverseToposortCallgraph(callee, &symbolTable, sorted); @@ -862,9 +863,9 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) { continue; auto funcOp = cast(node.getOperation()); os << "[ata] processing function @" << funcOp.getName() << "\n"; - DataFlowConfig config; - config.setInterprocedural(false); - DataFlowSolver solver(config); + DataFlowConfig dataFlowConfig; + dataFlowConfig.setInterprocedural(false); + DataFlowSolver solver(dataFlowConfig); SymbolTableCollection symbolTable; solver.load(); @@ -910,12 +911,7 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) { } } - // for (BlockArgument arg : node.getCallableRegion()->getArguments()) { - // auto *backwardState = - // solver.getOrCreateState(arg); - // os << "[debug] backward state for arg " << arg.getArgNumber() << ": " - // << *backwardState << "\n"; - // } + // Sparse alias annotations SmallVector aliasAttributes(returnAliasClasses.size()); llvm::transform(returnAliasClasses, aliasAttributes.begin(), [&](enzyme::AliasClassLattice lattice) { @@ -924,24 +920,29 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) { node->setAttr(EnzymeDialect::getAliasSummaryAttrName(), ArrayAttr::get(node.getContext(), aliasAttributes)); + // Points-to-pointer annotations node->setAttr(pointerSummaryName, p2sets.serialize(node.getContext())); - os << "[ata] p2p summary:\n"; - if (node->getAttrOfType(pointerSummaryName).size() == 0) { - os << " \n"; - } - for (ArrayAttr pair : node->getAttrOfType(pointerSummaryName) - .getAsRange()) { - os << " " << pair[0] << " -> " << pair[1] << "\n"; + if (activityConfig.verbose) { + os << "[ata] p2p summary:\n"; + if (node->getAttrOfType(pointerSummaryName).size() == 0) { + os << " \n"; + } + for (ArrayAttr pair : node->getAttrOfType(pointerSummaryName) + .getAsRange()) { + os << " " << pair[0] << " -> " << pair[1] << "\n"; + } } node->setAttr(EnzymeDialect::getDenseActivityAnnotationAttrName(), forwardOriginsMap.serialize(node.getContext())); - os << "[ata] forward value origins:\n"; - for (ArrayAttr pair : - node->getAttrOfType( - EnzymeDialect::getDenseActivityAnnotationAttrName()) - .getAsRange()) { - os << " " << pair[0] << " originates from " << pair[1] << "\n"; + if (activityConfig.verbose) { + os << "[ata] forward value origins:\n"; + for (ArrayAttr pair : + node->getAttrOfType( + EnzymeDialect::getDenseActivityAnnotationAttrName()) + .getAsRange()) { + os << " " << pair[0] << " originates from " << pair[1] << "\n"; + } } auto *backwardOriginsMap = @@ -949,10 +950,12 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) { &node.getCallableRegion()->front().front()); Attribute backwardOrigins = backwardOriginsMap->serialize(node.getContext()); - os << "[ata] backward value origins:\n"; - for (ArrayAttr pair : - cast(backwardOrigins).getAsRange()) { - os << " " << pair[0] << " goes to " << pair[1] << "\n"; + if (activityConfig.verbose) { + os << "[ata] backward value origins:\n"; + for (ArrayAttr pair : + cast(backwardOrigins).getAsRange()) { + os << " " << pair[0] << " goes to " << pair[1] << "\n"; + } } // Serialize return origins @@ -967,39 +970,138 @@ void enzyme::runActivityAnnotations(FunctionOpInterface callee) { node->setAttr( EnzymeDialect::getSparseActivityAnnotationAttrName(), ArrayAttr::get(node.getContext(), serializedReturnOperandOrigins)); - os << "[ata] return origins: " - << node->getAttr(EnzymeDialect::getSparseActivityAnnotationAttrName()) - << "\n"; + if (activityConfig.verbose) { + os << "[ata] return origins: " + << node->getAttr(EnzymeDialect::getSparseActivityAnnotationAttrName()) + << "\n"; + } - node.getCallableRegion()->walk([&](Operation *op) { - if (op->hasAttr("tag")) { - for (OpResult result : op->getResults()) { + auto joinActiveDataState = + [&](Value value, + std::pair &out) { + auto *sources = solver.getOrCreateState(value); + auto *sinks = solver.getOrCreateState(value); + (void)out.first.join(*sources); + (void)out.second.meet(*sinks); + }; + + auto joinActivePointerState = + [&](const AliasClassSet &aliasClasses, + std::pair &out) { + traversePointsToSets( + aliasClasses, p2sets, [&](DistinctAttr aliasClass) { + (void)out.first.merge(forwardOriginsMap.getOrigins(aliasClass)); + (void)out.second.merge( + backwardOriginsMap->getOrigins(aliasClass)); + }); + }; + + auto joinActiveValueState = + [&](Value value, + std::pair &out) { auto *aliasClasses = - solver.getOrCreateState(result); + solver.getOrCreateState(value); if (aliasClasses->isUndefined()) { // Not a pointer, check the sources and sinks from the sparse state - auto *sources = - solver.getOrCreateState(result); - auto *sinks = - solver.getOrCreateState(result); - os << op->getAttr("tag") << "(#" << result.getResultNumber() - << ")\n" - << " sources: " << sources->serialize(ctx) << "\n" - << " sinks: " << sinks->serialize(ctx) << "\n"; + joinActiveDataState(value, out); } else { // Is a pointer, see the origins of whatever it points to - ForwardOriginsLattice sources(result, ValueOriginSet()); - BackwardOriginsLattice sinks(result, ValueOriginSet()); - traversePointsToSets( - aliasClasses->getAliasClassesObject(), p2sets, - [&](DistinctAttr aliasClass) { - (void)sources.merge(forwardOriginsMap.getOrigins(aliasClass)); - (void)sinks.merge(backwardOriginsMap->getOrigins(aliasClass)); - }); + joinActivePointerState(aliasClasses->getAliasClassesObject(), out); + } + }; + + auto annotateActivity = [&](Operation *op) { + assert(op->getNumResults() < 2 && op->getNumRegions() == 0 && + "annotation only supports the LLVM dialect"); + auto unitAttr = UnitAttr::get(ctx); + // Check activity of values + for (OpResult result : op->getResults()) { + std::pair + activityAttributes({result, ValueOriginSet()}, + {result, ValueOriginSet()}); + joinActiveValueState(result, activityAttributes); + const auto &sources = activityAttributes.first; + const auto &sinks = activityAttributes.second; + // Possible states: if either source or sink is undefined or empty, the + // value is always constant. + if (sources.isUnknown() || sinks.isUnknown()) { + // Always active + op->setAttr("enzyme.activeval", unitAttr); + } else if (sources.isUndefined() || sinks.isUndefined()) { + // Always constant + op->setAttr("enzyme.constantval", unitAttr); + } else { + // Conditionally active depending on the activity of sources and sinks + op->setAttr("enzyme.valsrc", sources.serialize(ctx)); + op->setAttr("enzyme.valsink", sinks.serialize(ctx)); + } + } + // Check activity of operation + StringRef opSourceAttrName = "enzyme.opsrc"; + StringRef opSinkAttrName = "enzyme.opsink"; + std::pair opAttributes( + {nullptr, ValueOriginSet()}, {nullptr, ValueOriginSet()}); + if (isPure(op)) { + // A pure operation can only propagate data via its results + std::pair opAttributes( + {nullptr, ValueOriginSet()}, {nullptr, ValueOriginSet()}); + for (OpResult result : op->getResults()) { + joinActiveDataState(result, opAttributes); + } + } else { + // We need a special case because stores of active pointers don't fit + // the definition but are active instructions + if (auto storeOp = dyn_cast(op)) { + auto *storedClass = + solver.getOrCreateState(storeOp.getValue()); + joinActivePointerState(storedClass->getAliasClassesObject(), + opAttributes); + } else if (auto callOp = dyn_cast(op)) { + // TODO: can we just use the summary? + // If a call op receives or returns any active pointers or data, + // consider it active. + for (Value operand : callOp.getArgOperands()) + joinActiveValueState(operand, opAttributes); + for (OpResult result : callOp->getResults()) + joinActiveValueState(result, opAttributes); + } + + // Default: the op is active iff any of its operands or results are + // active data. + for (Value operand : op->getOperands()) + joinActiveDataState(operand, opAttributes); + for (OpResult result : op->getResults()) + joinActiveDataState(result, opAttributes); + } + + const auto &opSources = opAttributes.first; + const auto &opSinks = opAttributes.second; + if (opSources.isUnknown() || opSinks.isUnknown()) { + op->setAttr("enzyme.activeop", unitAttr); + } else if (opSources.isUndefined() || opSinks.isUndefined()) { + op->setAttr("enzyme.constantop", unitAttr); + } else { + op->setAttr(opSourceAttrName, opAttributes.first.serialize(ctx)); + op->setAttr(opSinkAttrName, opAttributes.second.serialize(ctx)); + } + }; + + node.getCallableRegion()->walk([&](Operation *op) { + if (activityConfig.annotate) + annotateActivity(op); + if (activityConfig.verbose) { + if (op->hasAttr("tag")) { + for (OpResult result : op->getResults()) { + std::pair + activityAttributes({result, ValueOriginSet()}, + {result, ValueOriginSet()}); + joinActiveValueState(result, activityAttributes); os << op->getAttr("tag") << "(#" << result.getResultNumber() << ")\n" - << " sources: " << sources.serialize(ctx) << "\n" - << " sinks: " << sinks.serialize(ctx) << "\n"; + << " sources: " << activityAttributes.first.serialize(ctx) + << "\n" + << " sinks: " << activityAttributes.second.serialize(ctx) + << "\n"; } } } diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h index 933137a96f02..cbd65e2ee449 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h @@ -211,7 +211,22 @@ class DenseBackwardActivityAnnotationAnalysis const BackwardOriginsMap &after, BackwardOriginsMap *before); }; -void runActivityAnnotations(FunctionOpInterface callee); +class ActivityPrinterConfig { +public: + ActivityPrinterConfig() = default; + + /// Output extra information for debugging + bool verbose = false; + /// Annotate the IR with activity information for every operation. Currently + /// only supports the LLVM dialect. + bool annotate = false; + /// Infer the starting argument state from an __enzyme_autodiff call. + bool inferFromAutodiff = false; +}; + +void runActivityAnnotations( + FunctionOpInterface callee, + const ActivityPrinterConfig &config = ActivityPrinterConfig()); } // namespace enzyme } // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Analysis/Lattice.h b/enzyme/Enzyme/MLIR/Analysis/Lattice.h index 9c8b6b187295..25fd37ebf57b 100644 --- a/enzyme/Enzyme/MLIR/Analysis/Lattice.h +++ b/enzyme/Enzyme/MLIR/Analysis/Lattice.h @@ -197,7 +197,7 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice { SparseSetLattice(Value value, SetLattice &&elements) : dataflow::AbstractSparseLattice(value), elements(std::move(elements)) {} - Attribute serialize(MLIRContext *ctx) { return serializeSetNaive(ctx); } + Attribute serialize(MLIRContext *ctx) const { return serializeSetNaive(ctx); } ChangeResult merge(const SetLattice &other) { return elements.join(other); @@ -219,7 +219,7 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice { SetLattice elements; private: - Attribute serializeSetNaive(MLIRContext *ctx) { + Attribute serializeSetNaive(MLIRContext *ctx) const { if (elements.isUndefined()) return StringAttr::get(ctx, ""); if (elements.isUnknown()) @@ -235,6 +235,9 @@ class SparseSetLattice : public dataflow::AbstractSparseLattice { //===----------------------------------------------------------------------===// // MapOfSetsLattice +// +// A lattice for use in dense analyses that maps keys (usually static memory +// locations) to sets of values. //===----------------------------------------------------------------------===// template diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index a59789f01cbc..503eff26ac2c 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -105,11 +105,11 @@ def PrintActivityAnalysisPass : Pass<"print-activity-analysis"> { /*description=*/"Annotate every operation and value with its activity" >, Option< - /*C++ variable name=*/"useAnnotations", - /*CLI argument=*/"use-annotations", + /*C++ variable name=*/"relative", + /*CLI argument=*/"relative", /*type=*/"bool", /*default=*/"false", - /*description=*/"Use bottom-up activity annotations" + /*description=*/"Use relative bottom-up activity analysis" >, Option< /*C++ variable name=*/"inactiveArgs", diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index a8f792c5eb31..daccc8d13d7f 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -29,7 +29,6 @@ using namespace mlir; namespace { -using llvm::errs; struct PrintActivityAnalysisPass : public enzyme::PrintActivityAnalysisPassBase { @@ -118,6 +117,11 @@ struct PrintActivityAnalysisPass } void runOnOperation() override { + enzyme::ActivityPrinterConfig config; + config.annotate = annotate; + config.inferFromAutodiff = false; + config.verbose = verbose; + auto moduleOp = cast(getOperation()); if (annotate) { @@ -142,8 +146,8 @@ struct PrintActivityAnalysisPass auto callee = cast(moduleOp.lookupSymbol(calleeAttr)); - if (useAnnotations) { - enzyme::runActivityAnnotations(callee); + if (relative) { + enzyme::runActivityAnnotations(callee, config); } else { SmallVector argActivities{callee.getNumArguments()}, resultActivities{callee.getNumResults()}; @@ -161,12 +165,12 @@ struct PrintActivityAnalysisPass } if (funcsToAnalyze.empty()) { - moduleOp.walk([this](FunctionOpInterface callee) { + moduleOp.walk([this, &config](FunctionOpInterface callee) { if (callee.isExternal() || callee.isPrivate()) return; - if (useAnnotations) { - enzyme::runActivityAnnotations(callee); + if (relative) { + enzyme::runActivityAnnotations(callee, config); } else { SmallVector argActivities{callee.getNumArguments()}, diff --git a/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir b/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir index 809e50e77e46..82587036cc8e 100644 --- a/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir +++ b/enzyme/test/MLIR/ActivityAnalysis/Summaries/basic.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s +// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s // CHECK-LABEL: processing function @sparse_callee // CHECK: "fadd"(#0) diff --git a/enzyme/test/MLIR/ActivityAnalysis/Summaries/bude.mlir b/enzyme/test/MLIR/ActivityAnalysis/Summaries/bude.mlir index 45c5a31d7141..32b8c879957b 100644 --- a/enzyme/test/MLIR/ActivityAnalysis/Summaries/bude.mlir +++ b/enzyme/test/MLIR/ActivityAnalysis/Summaries/bude.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s +// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s #alias_scope_domain = #llvm.alias_scope_domain, description = "wrap"> #loop_unroll = #llvm.loop_unroll diff --git a/enzyme/test/MLIR/AliasAnalysis/Summaries/leaf_nodes.mlir b/enzyme/test/MLIR/AliasAnalysis/Summaries/leaf_nodes.mlir index 13372df22036..fb82d0c3c21d 100644 --- a/enzyme/test/MLIR/AliasAnalysis/Summaries/leaf_nodes.mlir +++ b/enzyme/test/MLIR/AliasAnalysis/Summaries/leaf_nodes.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s +// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s #alias_scope_domain = #llvm.alias_scope_domain, description = "mat_mult"> #alias_scope_domain1 = #llvm.alias_scope_domain, description = "mat_mult"> diff --git a/enzyme/test/MLIR/AliasAnalysis/Summaries/parent_nodes.mlir b/enzyme/test/MLIR/AliasAnalysis/Summaries/parent_nodes.mlir index 52f5b7f9fd49..a952fa7cb04a 100644 --- a/enzyme/test/MLIR/AliasAnalysis/Summaries/parent_nodes.mlir +++ b/enzyme/test/MLIR/AliasAnalysis/Summaries/parent_nodes.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --print-activity-analysis='use-annotations' --split-input-file %s | FileCheck %s +// RUN: %eopt --print-activity-analysis='relative verbose' --split-input-file %s | FileCheck %s #alias_scope_domain = #llvm.alias_scope_domain, description = "mat_mult"> #alias_scope_domain1 = #llvm.alias_scope_domain, description = "mat_mult">