Skip to content

Commit

Permalink
Do not initialize sparse annotations of non-varied arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Mar 5, 2024
1 parent b7feaa2 commit a93fa0a
Showing 1 changed file with 27 additions and 5 deletions.
32 changes: 27 additions & 5 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@

using namespace mlir;

static bool isPossiblyActive(Type type) {
return isa<FloatType, ComplexType>(type);
}

template <typename ValueT>
void printSetLattice(const enzyme::SparseSetLattice<ValueT> &setLattice,
raw_ostream &os) {
Expand Down Expand Up @@ -55,6 +59,9 @@ void enzyme::ForwardActivityAnnotationAnalysis::setToEntryState(
assert(lattice->isUndefined());
return;
}
if (!isPossiblyActive(arg.getType())) {
return;
}

auto funcOp = cast<FunctionOpInterface>(arg.getOwner()->getParentOp());
auto origin = ArgumentOriginAttr::get(FlatSymbolRefAttr::get(funcOp),
Expand All @@ -79,10 +86,6 @@ static bool isNoOp(Operation *op) {
LLVM::LifetimeEndOp, LLVM::AssumeOp>(op);
}

static bool isPossiblyActive(Type type) {
return isa<FloatType, ComplexType>(type);
}

void enzyme::ForwardActivityAnnotationAnalysis::visitOperation(
Operation *op, ArrayRef<const ForwardOriginsLattice *> operands,
ArrayRef<ForwardOriginsLattice *> results) {
Expand Down Expand Up @@ -459,6 +462,9 @@ void enzyme::DenseActivityAnnotationAnalysis::visitOperation(
propagateIfChanged(after, changed);
} else if (isa<MemoryEffects::Write>(effect.getEffect())) {
if (std::optional<Value> stored = getStored(op)) {
if (!isPossiblyActive(stored->getType())) {
continue;
}
auto *origins = getOrCreateFor<ForwardOriginsLattice>(op, *stored);
auto *dest = getOrCreateFor<AliasClassLattice>(op, value);
propagateIfChanged(after, after->insert(dest->getAliasClassesObject(),
Expand Down Expand Up @@ -821,6 +827,18 @@ void enzyme::DenseBackwardActivityAnnotationAnalysis::processCopy(
}

namespace {

// TODO: the alias summary attribute is sufficent to get the correct behaviour
// here, but it would be nice if these were not hardcoded.
void annotateHardcoded(FunctionOpInterface func) {
if (func.getName() == "lgamma" || func.getName() == "tanh") {
MLIRContext *ctx = func.getContext();
SmallVector<Attribute> arr = {StringAttr::get(ctx, "<undefined>")};
func->setAttr(enzyme::EnzymeDialect::getAliasSummaryAttrName(),
ArrayAttr::get(ctx, arr));
}
}

/// Starting from callee, compute a reverse (bottom-up) topological sorting of
/// all functions transitively called from callee.
void reverseToposortCallgraph(CallableOpInterface callee,
Expand Down Expand Up @@ -1006,10 +1024,14 @@ void enzyme::runActivityAnnotations(

StringRef pointerSummaryName = EnzymeDialect::getPointerSummaryAttrName();
for (CallableOpInterface node : sorted) {
annotateHardcoded(cast<FunctionOpInterface>(node.getOperation()));

if (!node.getCallableRegion() || node->hasAttr(pointerSummaryName))
continue;
auto funcOp = cast<FunctionOpInterface>(node.getOperation());
os << "[ata] processing function @" << funcOp.getName() << "\n";
if (activityConfig.verbose) {
os << "[ata] processing function @" << funcOp.getName() << "\n";
}
DataFlowConfig dataFlowConfig;
dataFlowConfig.setInterprocedural(false);
DataFlowSolver solver(dataFlowConfig);
Expand Down

0 comments on commit a93fa0a

Please sign in to comment.