From c22c1f64323205ddbcfccf7f87354cfd3c82a041 Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Wed, 6 Mar 2024 10:44:37 -0500 Subject: [PATCH] MLIR enable vararg activity --- .../MLIR/Passes/PrintActivityAnalysis.cpp | 29 ++++++++++++------- enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 8 +++++ 2 files changed, 27 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp index ac88d0fc86fa..f86535c1d82d 100644 --- a/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp +++ b/enzyme/Enzyme/MLIR/Passes/PrintActivityAnalysis.cpp @@ -32,6 +32,22 @@ using namespace mlir; namespace { using llvm::errs; +enzyme::Activity getDefaultActivity(Type argType) { + if (argType.isIntOrIndex()) + return enzyme::Activity::enzyme_const; + + if (isa(argType)) + return enzyme::Activity::enzyme_out; + + if (auto T = dyn_cast(argType)) + return getDefaultActivity(T.getElementType()); + + if (isa(argType)) + return enzyme::Activity::enzyme_dup; + + return enzyme::Activity::enzyme_const; +} + struct PrintActivityAnalysisPass : public enzyme::PrintActivityAnalysisPassBase { @@ -43,25 +59,18 @@ struct PrintActivityAnalysisPass MutableArrayRef resActivities) const { for (const auto &[idx, argType] : llvm::enumerate(callee.getArgumentTypes())) { - if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs || - argType.isIntOrIndex()) + if (callee.getArgAttr(idx, "enzyme.const") || inactiveArgs) argActivities[idx] = enzyme::Activity::enzyme_const; - else if (isa(argType)) - argActivities[idx] = enzyme::Activity::enzyme_out; - else if (isa(argType)) - argActivities[idx] = enzyme::Activity::enzyme_dup; else - argActivities[idx] = enzyme::Activity::enzyme_const; + argActivities[idx] = getDefaultActivity(argType); } for (const auto &[idx, resType] : llvm::enumerate(callee.getResultTypes())) { if (duplicatedRet) resActivities[idx] = (enzyme::Activity::enzyme_dup); - else if (isa(resType)) - resActivities[idx] = (enzyme::Activity::enzyme_out); else - resActivities[idx] = (enzyme::Activity::enzyme_const); + resActivities[idx] = getDefaultActivity(resType); } } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 7796486fe5de..a5f7fa768f55 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -2050,9 +2050,17 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " };\n"; os << " bool isArgInactive(mlir::Operation*, size_t idx) const {\n"; for (auto diffarg : diffargs) { + if (diffarg == -1) { + os << " return false;\n"; + break; + } os << " if (idx == " << diffarg << ") return false;\n"; } for (auto diffarg : storedargs) { + if (diffarg == -1) { + os << " return false;\n"; + break; + } os << " if (idx == " << diffarg << ") return false;\n"; } os << " return true;\n }\n";