Skip to content

Commit

Permalink
Traverse down call ops when going top down
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Feb 29, 2024
1 parent 7fdfc79 commit cf55d27
Showing 1 changed file with 63 additions and 4 deletions.
67 changes: 63 additions & 4 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -849,16 +849,21 @@ void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func,
}
}

using OriginsPair =
std::pair<enzyme::ForwardOriginsLattice, enzyme::BackwardOriginsLattice>;

/// Once having reached a top-level entry point, go top-down and convert the
/// relative sources/sinks into concrete active/constant results.
///
/// This would ideally be done after lowering to LLVM and during differentiation
/// because it loses context sensitivity, but this is faster to prototype with.
void topDownActivityAnalysis(FunctionOpInterface callee,
ArrayRef<enzyme::Activity> argActivities,
ArrayRef<enzyme::Activity> retActivities) {
void topDownActivityAnalysis(
FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities,
ArrayRef<enzyme::Activity> retActivities,
DenseMap<BlockArgument, OriginsPair> &blockArgOrigins) {
using namespace mlir::enzyme;
MLIRContext *ctx = callee.getContext();
callee->setAttr("enzyme.visited", UnitAttr::get(ctx));
auto trueAttr = BoolAttr::get(ctx, true);
auto falseAttr = BoolAttr::get(ctx, false);

Expand Down Expand Up @@ -926,6 +931,45 @@ void topDownActivityAnalysis(FunctionOpInterface callee,
op->removeAttr("enzyme.activeop");
op->removeAttr("enzyme.opsrc");
op->removeAttr("enzyme.opsink");

if (auto callOp = dyn_cast<CallOpInterface>(op)) {
auto funcOp = cast<FunctionOpInterface>(callOp.resolveCallable());
if (!funcOp->hasAttr("enzyme.visited")) {
SmallVector<Activity> callArgActivities, callResActivities;
for (Value operand : callOp.getArgOperands()) {
if (auto *definingOp = operand.getDefiningOp()) {
bool icv =
definingOp->getAttrOfType<BoolAttr>("enzyme.icv").getValue();
callArgActivities.push_back(icv ? Activity::enzyme_const
: Activity::enzyme_out);
} else {
BlockArgument blockArg = cast<BlockArgument>(operand);
const OriginsPair &originsPair = blockArgOrigins.at(blockArg);
const ForwardOriginsLattice &sources = originsPair.first;
const BackwardOriginsLattice &sinks = originsPair.second;
bool argActive = false;
if (sources.isUnknown() || sinks.isUnknown()) {
argActive = true;
} else if (sources.isUndefined() || sinks.isUndefined()) {
argActive = false;
} else {
argActive = llvm::any_of(sources.getOrigins(), isOriginActive) &&
llvm::any_of(sinks.getOrigins(), isOriginActive);
}
callArgActivities.push_back(argActive ? Activity::enzyme_out
: Activity::enzyme_const);
}
}
if (op->getNumResults() != 0) {
bool icv = op->getAttrOfType<BoolAttr>("enzyme.icv").getValue();
callResActivities.push_back(icv ? Activity::enzyme_const
: Activity::enzyme_out);
}

topDownActivityAnalysis(funcOp, callArgActivities, callResActivities,
blockArgOrigins);
}
}
});
}
} // namespace
Expand All @@ -938,6 +982,9 @@ void enzyme::runActivityAnnotations(
reverseToposortCallgraph(callee, &symbolTable, sorted);
raw_ostream &os = llvm::outs();

// TODO: is there any way of serializing information in a block argument?
DenseMap<BlockArgument, OriginsPair> blockArgOrigins;

StringRef pointerSummaryName = EnzymeDialect::getPointerSummaryAttrName();
for (CallableOpInterface node : sorted) {
if (!node.getCallableRegion() || node->hasAttr(pointerSummaryName))
Expand Down Expand Up @@ -1165,6 +1212,17 @@ void enzyme::runActivityAnnotations(
}
};

// We lose the solver state when going top down and I don't know a better
// way to serialize block argument information.
node.getCallableRegion()->walk([&](Block *block) {
for (BlockArgument blockArg : block->getArguments()) {
OriginsPair blockArgAttributes({blockArg, ValueOriginSet()},
{blockArg, ValueOriginSet()});
joinActiveValueState(blockArg, blockArgAttributes);
blockArgOrigins.try_emplace(blockArg, blockArgAttributes);
}
});

node.getCallableRegion()->walk([&](Operation *op) {
if (activityConfig.annotate)
annotateActivity(op);
Expand Down Expand Up @@ -1195,6 +1253,7 @@ void enzyme::runActivityAnnotations(
: Activity::enzyme_const);
}

topDownActivityAnalysis(callee, argActivities, resActivities);
topDownActivityAnalysis(callee, argActivities, resActivities,
blockArgOrigins);
}
}

0 comments on commit cf55d27

Please sign in to comment.