From 87fc569b026451d498908f1b1c665272ff2d5989 Mon Sep 17 00:00:00 2001 From: William Moses Date: Wed, 6 Mar 2024 11:23:46 -0600 Subject: [PATCH] MLIR enable vararg activity (#1781) --- .../CoreDialectsAutoDiffImplementations.cpp | 3 +- .../MLIR/Passes/PrintActivityAnalysis.cpp | 29 ++++++++++++------- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 8 +++++ 3 files changed, 29 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index ade1be7e6406..4129be80c1d5 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -141,7 +141,8 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( } else { if (gutils->isConstantValue(operand.get())) { - if (contains(storedVals, operand.getOperandNumber())) { + if (contains(storedVals, operand.getOperandNumber()) || + contains(storedVals, -1)) { if (auto iface = dyn_cast(operand.get().getType())) { if (!iface.isMutable()) { diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index ac88d0fc86fa..f86535c1d82d 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -32,6 +32,22 @@ using namespace mlir; namespace { using llvm::errs; +enzyme::Activity getDefaultActivity(Type argType) { + if (argType.isIntOrIndex()) + return enzyme::Activity::enzyme_const; + + if (isa(argType)) + return enzyme::Activity::enzyme_out; + + if (auto T = dyn_cast(argType)) + return getDefaultActivity(T.getElementType()); + + if (isa(argType)) + return enzyme::Activity::enzyme_dup; + + return enzyme::Activity::enzyme_const; +} + struct PrintActivityAnalysisPass : public enzyme::PrintActivityAnalysisPassBase { @@ -43,25 +59,18 @@ struct PrintActivityAnalysisPass MutableArrayRef resActivities) const { for (const auto &[idx, argType] : llvm::enumerate(callee.getArgumentTypes())) { - if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs || - argType.isIntOrIndex()) + if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs) argActivities[idx] = enzyme::Activity::enzyme_const; - else if (isa(argType)) - argActivities[idx] = enzyme::Activity::enzyme_out; - else if (isa(argType)) - argActivities[idx] = enzyme::Activity::enzyme_dup; else - argActivities[idx] = enzyme::Activity::enzyme_const; + argActivities[idx] = getDefaultActivity(argType); } for (const auto &[idx, resType] : llvm::enumerate(callee.getResultTypes())) { if (duplicatedRet) resActivities[idx] = (enzyme::Activity::enzyme_dup); - else if (isa(resType)) - resActivities[idx] = (enzyme::Activity::enzyme_out); else - resActivities[idx] = (enzyme::Activity::enzyme_const); + resActivities[idx] = getDefaultActivity(resType); } } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 7796486fe5de..a5f7fa768f55 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -2050,9 +2050,17 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " };\n"; os << " bool isArgInactive(mlir::Operation*, size_t idx) const {\n"; for (auto diffarg : diffargs) { + if (diffarg == -1) { + os << " return false;\n"; + break; + } os << " if (idx == " << diffarg << ") return false;\n"; } for (auto diffarg : storedargs) { + if (diffarg == -1) { + os << " return false;\n"; + break; + } os << " if (idx == " << diffarg << ") return false;\n"; } os << " return true;\n }\n";