Skip to content

Commit

Permalink
Add top-down activity phase
Browse files Browse the repository at this point in the history
  • Loading branch information
pengmai committed Feb 29, 2024
1 parent 0b89b5c commit 7fdfc79
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 31 deletions.
122 changes: 107 additions & 15 deletions enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -848,10 +848,91 @@ void initializeSparseBackwardActivityAnnotations(FunctionOpInterface func,
}
}
}

/// Once having reached a top-level entry point, go top-down and convert the
/// relative sources/sinks into concrete active/constant results.
///
/// This would ideally be done after lowering to LLVM and during differentiation
/// because it loses context sensitivity, but this is faster to prototype with.
void topDownActivityAnalysis(FunctionOpInterface callee,
ArrayRef<enzyme::Activity> argActivities,
ArrayRef<enzyme::Activity> retActivities) {
using namespace mlir::enzyme;
MLIRContext *ctx = callee.getContext();
auto trueAttr = BoolAttr::get(ctx, true);
auto falseAttr = BoolAttr::get(ctx, false);

auto isOriginActive = [&](OriginAttr origin) {
if (auto argOriginAttr = dyn_cast<ArgumentOriginAttr>(origin)) {
return llvm::is_contained({Activity::enzyme_dup,
Activity::enzyme_dupnoneed,
Activity::enzyme_out},
argActivities[argOriginAttr.getArgNumber()]);
}
auto retOriginAttr = cast<ReturnOriginAttr>(origin);
return llvm::is_contained({Activity::enzyme_dup, Activity::enzyme_dupnoneed,
Activity::enzyme_out},
retActivities[retOriginAttr.getReturnNumber()]);
};
callee.getFunctionBody().walk([&](Operation *op) {
if (op->getNumResults() == 0) {
// Operations that don't return values are definitionally "constant"
op->setAttr("enzyme.icv", trueAttr);
} else {
// Value activity
if (op->hasAttr("enzyme.constantval")) {
op->setAttr("enzyme.icv", trueAttr);
} else if (op->hasAttr("enzyme.activeval")) {
op->setAttr("enzyme.icv", falseAttr);
} else {
auto valueSource = op->getAttrOfType<ArrayAttr>("enzyme.valsrc");
auto valueSink = op->getAttrOfType<ArrayAttr>("enzyme.valsink");
if (!(valueSource && valueSink)) {
llvm::errs() << "[activity] missing attributes for op: " << *op
<< "\n";
}
assert(valueSource && valueSink && "missing attributes for op");
bool activeSource =
llvm::any_of(valueSource.getAsRange<OriginAttr>(), isOriginActive);
bool activeSink =
llvm::any_of(valueSink.getAsRange<OriginAttr>(), isOriginActive);
bool activeVal = activeSource && activeSink;
op->setAttr("enzyme.icv", BoolAttr::get(ctx, !activeVal));
}
}
op->removeAttr("enzyme.constantval");
op->removeAttr("enzyme.activeval");
op->removeAttr("enzyme.valsrc");
op->removeAttr("enzyme.valsink");

// Instruction activity
if (op->hasAttr("enzyme.constantop")) {
op->setAttr("enzyme.ici", trueAttr);
} else if (op->hasAttr("enzyme.activeop")) {
op->setAttr("enzyme.ici", falseAttr);
} else {
bool activeSource = llvm::any_of(
op->getAttrOfType<ArrayAttr>("enzyme.opsrc").getAsRange<OriginAttr>(),
isOriginActive);
bool activeSink =
llvm::any_of(op->getAttrOfType<ArrayAttr>("enzyme.opsink")
.getAsRange<OriginAttr>(),
isOriginActive);
bool activeOp = activeSource && activeSink;
op->setAttr("enzyme.ici", BoolAttr::get(ctx, !activeOp));
}

op->removeAttr("enzyme.constantop");
op->removeAttr("enzyme.activeop");
op->removeAttr("enzyme.opsrc");
op->removeAttr("enzyme.opsink");
});
}
} // namespace

void enzyme::runActivityAnnotations(
FunctionOpInterface callee, const ActivityPrinterConfig &activityConfig) {
FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities,
const ActivityPrinterConfig &activityConfig) {
SymbolTableCollection symbolTable;
SmallVector<CallableOpInterface> sorted;
reverseToposortCallgraph(callee, &symbolTable, sorted);
Expand Down Expand Up @@ -999,14 +1080,12 @@ void enzyme::runActivityAnnotations(
auto joinActiveValueState =
[&](Value value,
std::pair<ForwardOriginsLattice, BackwardOriginsLattice> &out) {
auto *aliasClasses =
solver.getOrCreateState<AliasClassLattice>(value);
if (aliasClasses->isUndefined()) {
// Not a pointer, check the sources and sinks from the sparse state
joinActiveDataState(value, out);
} else {
// Is a pointer, see the origins of whatever it points to
if (isa<LLVM::LLVMPointerType, MemRefType>(value.getType())) {
auto *aliasClasses =
solver.getOrCreateState<AliasClassLattice>(value);
joinActivePointerState(aliasClasses->getAliasClassesObject(), out);
} else {
joinActiveDataState(value, out);
}
};

Expand Down Expand Up @@ -1055,13 +1134,15 @@ void enzyme::runActivityAnnotations(
joinActivePointerState(storedClass->getAliasClassesObject(),
opAttributes);
} else if (auto callOp = dyn_cast<CallOpInterface>(op)) {
// TODO: can we just use the summary?
// If a call op receives or returns any active pointers or data,
// consider it active.
for (Value operand : callOp.getArgOperands())
joinActiveValueState(operand, opAttributes);
for (OpResult result : callOp->getResults())
joinActiveValueState(result, opAttributes);
// TODO: tricky, requires some thought
auto callable = cast<CallableOpInterface>(callOp.resolveCallable());
if (callable->hasAttr(
EnzymeDialect::getDenseActivityAnnotationAttrName())) {
for (Value operand : callOp.getArgOperands())
joinActiveValueState(operand, opAttributes);
}
// We need to
// determine if the body of the function contains active instructions
}

// Default: the op is active iff any of its operands or results are
Expand Down Expand Up @@ -1105,4 +1186,15 @@ void enzyme::runActivityAnnotations(
}
});
}

if (!argActivities.empty() && activityConfig.annotate) {
SmallVector<enzyme::Activity> resActivities;
for (Type resultType : callee.getResultTypes()) {
resActivities.push_back(isa<FloatType, ComplexType>(resultType)
? Activity::enzyme_out
: Activity::enzyme_const);
}

topDownActivityAnalysis(callee, argActivities, resActivities);
}
}
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Analysis/ActivityAnnotations.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ class ActivityPrinterConfig {
};

void runActivityAnnotations(
FunctionOpInterface callee,
FunctionOpInterface callee, ArrayRef<enzyme::Activity> argActivities = {},
const ActivityPrinterConfig &config = ActivityPrinterConfig());

} // namespace enzyme
Expand Down
28 changes: 13 additions & 15 deletions enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,16 +146,16 @@ struct PrintActivityAnalysisPass
auto callee =
cast<FunctionOpInterface>(moduleOp.lookupSymbol(calleeAttr));

SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
resultActivities{callee.getNumResults()};
// Populate the argument activities based on either the type or the
// supplied annotation. First argument is the callee
inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call,
argActivities, resultActivities);

if (relative) {
enzyme::runActivityAnnotations(callee, config);
enzyme::runActivityAnnotations(callee, argActivities, config);
} else {
SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
resultActivities{callee.getNumResults()};

// Populate the argument activities based on either the type or the
// supplied annotation. First argument is the callee
inferArgActivitiesFromEnzymeAutodiff(callee, autodiff_call,
argActivities, resultActivities);
enzyme::runDataFlowActivityAnalysis(callee, argActivities,
/*print=*/true, verbose,
annotate);
Expand All @@ -169,15 +169,13 @@ struct PrintActivityAnalysisPass
if (callee.isExternal() || callee.isPrivate())
return;

SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
resultActivities{callee.getNumResults()};
initializeArgAndResActivities(callee, argActivities, resultActivities);

if (relative) {
enzyme::runActivityAnnotations(callee, config);
enzyme::runActivityAnnotations(callee, argActivities, config);
} else {

SmallVector<enzyme::Activity> argActivities{callee.getNumArguments()},
resultActivities{callee.getNumResults()};
initializeArgAndResActivities(callee, argActivities,
resultActivities);

enzyme::runDataFlowActivityAnalysis(callee, argActivities,
/*print=*/true, verbose,
annotate);
Expand Down

0 comments on commit 7fdfc79

Please sign in to comment.