diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp index 1bf1442c9854..2ba2af8e378c 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp @@ -18,6 +18,10 @@ using namespace mlir; +static bool isPossiblyActive(Type type) { + return isa(type); +} + template void printSetLattice(const enzyme::SparseSetLattice &setLattice, raw_ostream &os) { @@ -55,6 +59,9 @@ void enzyme::ForwardActivityAnnotationAnalysis::setToEntryState( assert(lattice->isUndefined()); return; } + if (!isPossiblyActive(arg.getType())) { + return; + } auto funcOp = cast(arg.getOwner()->getParentOp()); auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp), @@ -79,10 +86,6 @@ static bool isNoOp(Operation *op) { LLVM::LifetimeEndOp, LLVM::AssumeOp>(op); } -static bool isPossiblyActive(Type type) { - return isa(type); -} - void enzyme::ForwardActivityAnnotationAnalysis::visitOperation( Operation *op, ArrayRef operands, ArrayRef results) { @@ -459,6 +462,9 @@ void enzyme::DenseActivityAnnotationAnalysis::visitOperation( propagateIfChanged(after, changed); } else if (isa(effect.getEffect())) { if (std::optional stored = getStored(op)) { + if (!isPossiblyActive(stored->getType())) { + continue; + } auto *origins = getOrCreateFor(op, *stored); auto *dest = getOrCreateFor(op, value); propagateIfChanged(after, after->insert(dest->getAliasClassesObject(), @@ -821,6 +827,18 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis::processCopy( } namespace { + +// TODO: the alias summary attribute is sufficent to get the correct behaviour +// here, but it would be nice if these were not hardcoded. +void annotateHardcoded(FunctionOpInterface func) { + if (func.getName() == "lgamma" || func.getName() == "tanh") { + MLIRContext *ctx = func.getContext(); + SmallVector arr = {StringAttr::get(ctx, "")}; + func->setAttr(enzyme::EnzymeDialect::getAliasSummaryAttrName(), + ArrayAttr::get(ctx, arr)); + } +} + /// Starting from callee, compute a reverse (bottom-up) topological sorting of /// all functions transitively called from callee. void reverseToposortCallgraph(CallableOpInterface callee, @@ -1006,10 +1024,14 @@ void enzyme::runActivityAnnotations( StringRef pointerSummaryName = EnzymeDialect::getPointerSummaryAttrName(); for (CallableOpInterface node : sorted) { + annotateHardcoded(cast(node.getOperation())); + if (!node.getCallableRegion() || node->hasAttr(pointerSummaryName)) continue; auto funcOp = cast(node.getOperation()); - os << "[ata] processing function @" << funcOp.getName() << "\n"; + if (activityConfig.verbose) { + os << "[ata] processing function @" << funcOp.getName() << "\n"; + } DataFlowConfig dataFlowConfig; dataFlowConfig.setInterprocedural(false); DataFlowSolver solver(dataFlowConfig);