Skip to content

Commit

Permalink
MLIR enable vararg activity
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 6, 2024
1 parent 9706635 commit c18863f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,8 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler(
} else {
if (gutils->isConstantValue(operand.get())) {

if (contains(storedVals, operand.getOperandNumber())) {
if (contains(storedVals, operand.getOperandNumber()) ||
contains(storedVals, -1)) {
if (auto iface =
dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
if (!iface.isMutable()) {
Expand Down
29 changes: 19 additions & 10 deletions enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,22 @@ using namespace mlir;
namespace {
using llvm::errs;

enzyme::Activity getDefaultActivity(Type argType) {
if (argType.isIntOrIndex())
return enzyme::Activity::enzyme_const;

if (isa<FloatType, ComplexType>(argType))
return enzyme::Activity::enzyme_out;

if (auto T = dyn_cast<TensorType>(argType))
return getDefaultActivity(T.getElementType());

if (isa<LLVM::LLVMPointerType, MemRefType>(argType))
return enzyme::Activity::enzyme_dup;

return enzyme::Activity::enzyme_const;
}

struct PrintActivityAnalysisPass
: public enzyme::PrintActivityAnalysisPassBase<PrintActivityAnalysisPass> {

Expand All @@ -43,25 +59,18 @@ struct PrintActivityAnalysisPass
MutableArrayRef<enzyme::Activity> resActivities) const {
for (const auto &[idx, argType] :
llvm::enumerate(callee.getArgumentTypes())) {
if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs ||
argType.isIntOrIndex())
if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs)
argActivities[idx] = enzyme::Activity::enzyme_const;
else if (isa<FloatType, ComplexType>(argType))
argActivities[idx] = enzyme::Activity::enzyme_out;
else if (isa<LLVM::LLVMPointerType, MemRefType>(argType))
argActivities[idx] = enzyme::Activity::enzyme_dup;
else
argActivities[idx] = enzyme::Activity::enzyme_const;
argActivities[idx] = getDefaultActivity(argType);
}

for (const auto &[idx, resType] :
llvm::enumerate(callee.getResultTypes())) {
if (duplicatedRet)
resActivities[idx] = (enzyme::Activity::enzyme_dup);
else if (isa<FloatType>(resType))
resActivities[idx] = (enzyme::Activity::enzyme_out);
else
resActivities[idx] = (enzyme::Activity::enzyme_const);
resActivities[idx] = getDefaultActivity(resType);
}
}

Expand Down
8 changes: 8 additions & 0 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2050,9 +2050,17 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
os << " };\n";
os << " bool isArgInactive(mlir::Operation*, size_t idx) const {\n";
for (auto diffarg : diffargs) {
if (diffarg == -1) {
os << " return false;\n";
break;
}
os << " if (idx == " << diffarg << ") return false;\n";
}
for (auto diffarg : storedargs) {
if (diffarg == -1) {
os << " return false;\n";
break;
}
os << " if (idx == " << diffarg << ") return false;\n";
}
os << " return true;\n }\n";
Expand Down

0 comments on commit c18863f

Please sign in to comment.