diff --git a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp index 1415559b3f6b..30fbfc8ef345 100644 --- a/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Analysis/ActivityAnalysis.cpp @@ -248,7 +248,7 @@ static Operation *getFunctionFromCall(CallOpInterface iface) { return SymbolTable::lookupNearestSymbolFrom(iface.getOperation(), symbol); } -constexpr bool EnzymePrintActivity = false; +constexpr bool EnzymePrintActivity = true; /// Is the use of value val as an argument of call CI known to be inactive /// This tool can only be used when in DOWN mode diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index ac88d0fc86fa..daeb2a0dcba9 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -32,6 +32,22 @@ using namespace mlir; namespace { using llvm::errs; +enzyme::Activity getDefaultActivity(Type argType) { + if (argType.isIntOrIndex()) + return enzyme::Activity::enzyme_const; + + if (isa(argType)) + return enzyme::Activity::enzyme_out; + + if (auto T = dyn_cast(argType)) + return getDefaultActivity(T.getElementType()); + + if (isa(argType)) + return enzyme::Activity::enzyme_dup; + + return enzyme::Activity::enzyme_const; +} + struct PrintActivityAnalysisPass : public enzyme::PrintActivityAnalysisPassBase { @@ -43,25 +59,18 @@ struct PrintActivityAnalysisPass MutableArrayRef resActivities) const { for (const auto &[idx, argType] : llvm::enumerate(callee.getArgumentTypes())) { - if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs || - argType.isIntOrIndex()) - argActivities[idx] = enzyme::Activity::enzyme_const; - else if (isa(argType)) - argActivities[idx] = enzyme::Activity::enzyme_out; - else if (isa(argType)) - argActivities[idx] = enzyme::Activity::enzyme_dup; - else + if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs) argActivities[idx] = enzyme::Activity::enzyme_const; + else + argActivities[idx] = getDefaultActivity(argType); } for (const auto &[idx, resType] : llvm::enumerate(callee.getResultTypes())) { if (duplicatedRet) resActivities[idx] = (enzyme::Activity::enzyme_dup); - else if (isa(resType)) - resActivities[idx] = (enzyme::Activity::enzyme_out); else - resActivities[idx] = (enzyme::Activity::enzyme_const); + resActivities[idx] = getDefaultActivity(resType); } } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 7796486fe5de..a5f7fa768f55 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -2050,9 +2050,17 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " };\n"; os << " bool isArgInactive(mlir::Operation*, size_t idx) const {\n"; for (auto diffarg : diffargs) { + if (diffarg == -1) { + os << " return false;\n"; + break; + } os << " if (idx == " << diffarg << ") return false;\n"; } for (auto diffarg : storedargs) { + if (diffarg == -1) { + os << " return false;\n"; + break; + } os << " if (idx == " << diffarg << ") return false;\n"; } os << " return true;\n }\n";