From cf55d27ff4bf20e3704af6f660df6a2d76aa755d Mon Sep 17 00:00:00 2001 From: Jacob Peng Date: Wed, 28 Feb 2024 22:19:59 -0500 Subject: [PATCH] Traverse down call ops when going top down --- .../MLIR/Analysis/ActivityAnnotations.cpp | 67 +++++++++++++++++-- 1 file changed, 63 insertions(+), 4 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp index 86c0df694f28..047394896abe 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp @@ -849,16 +849,21 @@ void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func, } } +using OriginsPair = + std::pair; + /// Once having reached a top-level entry point, go top-down and convert the /// relative sources/sinks into concrete active/constant results. /// /// This would ideally be done after lowering to LLVM and during differentiation /// because it loses context sensitivity, but this is faster to prototype with. -void topDownActivityAnalysis(FunctionOpInterface callee, - ArrayRef argActivities, - ArrayRef retActivities) { +void topDownActivityAnalysis( + FunctionOpInterface callee, ArrayRef argActivities, + ArrayRef retActivities, + DenseMap &blockArgOrigins) { using namespace mlir::enzyme; MLIRContext *ctx = callee.getContext(); + callee->setAttr("enzyme.visited", UnitAttr::get(ctx)); auto trueAttr = BoolAttr::get(ctx, true); auto falseAttr = BoolAttr::get(ctx, false); @@ -926,6 +931,45 @@ void topDownActivityAnalysis(FunctionOpInterface callee, op->removeAttr("enzyme.activeop"); op->removeAttr("enzyme.opsrc"); op->removeAttr("enzyme.opsink"); + + if (auto callOp = dyn_cast(op)) { + auto funcOp = cast(callOp.resolveCallable()); + if (!funcOp->hasAttr("enzyme.visited")) { + SmallVector callArgActivities, callResActivities; + for (Value operand : callOp.getArgOperands()) { + if (auto *definingOp = operand.getDefiningOp()) { + bool icv = + definingOp->getAttrOfType("enzyme.icv").getValue(); + callArgActivities.push_back(icv ? Activity::enzyme_const + : Activity::enzyme_out); + } else { + BlockArgument blockArg = cast(operand); + const OriginsPair &originsPair = blockArgOrigins.at(blockArg); + const ForwardOriginsLattice &sources = originsPair.first; + const BackwardOriginsLattice &sinks = originsPair.second; + bool argActive = false; + if (sources.isUnknown() || sinks.isUnknown()) { + argActive = true; + } else if (sources.isUndefined() || sinks.isUndefined()) { + argActive = false; + } else { + argActive = llvm::any_of(sources.getOrigins(), isOriginActive) && + llvm::any_of(sinks.getOrigins(), isOriginActive); + } + callArgActivities.push_back(argActive ? Activity::enzyme_out + : Activity::enzyme_const); + } + } + if (op->getNumResults() != 0) { + bool icv = op->getAttrOfType("enzyme.icv").getValue(); + callResActivities.push_back(icv ? Activity::enzyme_const + : Activity::enzyme_out); + } + + topDownActivityAnalysis(funcOp, callArgActivities, callResActivities, + blockArgOrigins); + } + } }); } } // namespace @@ -938,6 +982,9 @@ void enzyme::runActivityAnnotations( reverseToposortCallgraph(callee, &symbolTable, sorted); raw_ostream &os = llvm::outs(); + // TODO: is there any way of serializing information in a block argument? + DenseMap blockArgOrigins; + StringRef pointerSummaryName = EnzymeDialect::getPointerSummaryAttrName(); for (CallableOpInterface node : sorted) { if (!node.getCallableRegion() || node->hasAttr(pointerSummaryName)) @@ -1165,6 +1212,17 @@ void enzyme::runActivityAnnotations( } }; + // We lose the solver state when going top down and I don't know a better + // way to serialize block argument information. + node.getCallableRegion()->walk([&](Block *block) { + for (BlockArgument blockArg : block->getArguments()) { + OriginsPair blockArgAttributes({blockArg, ValueOriginSet()}, + {blockArg, ValueOriginSet()}); + joinActiveValueState(blockArg, blockArgAttributes); + blockArgOrigins.try_emplace(blockArg, blockArgAttributes); + } + }); + node.getCallableRegion()->walk([&](Operation *op) { if (activityConfig.annotate) annotateActivity(op); @@ -1195,6 +1253,7 @@ void enzyme::runActivityAnnotations( : Activity::enzyme_const); } - topDownActivityAnalysis(callee, argActivities, resActivities); + topDownActivityAnalysis(callee, argActivities, resActivities, + blockArgOrigins); } }