From 456cf5eddf25fda4d25e499e2f621df5f0ac28d1 Mon Sep 17 00:00:00 2001 From: William Moses Date: Fri, 1 Mar 2024 12:25:11 -0800 Subject: [PATCH] [MLIR] Add read-only reverse mode arg (#1774) --- .../MLIR/Implementations/ArithDerivatives.td | 6 + enzyme/Enzyme/MLIR/Implementations/Common.td | 14 +- enzyme/test/MLIR/ForwardMode/trunc.mlir | 18 + enzyme/test/MLIR/ReverseMode/trunc.mlir | 18 + enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 596 ++++++++++-------- 5 files changed, 370 insertions(+), 282 deletions(-) create mode 100644 enzyme/test/MLIR/ForwardMode/trunc.mlir create mode 100644 enzyme/test/MLIR/ReverseMode/trunc.mlir diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index 3d53793be3af..eb0294b4d24d 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -31,3 +31,9 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y), ], (CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y)) >; + +def ExtF : ArithInst<"ExtFOp">; +def TruncF : ArithInst<"TruncFOp">; + +def : ReadOnlyIdentityOp<"arith", "TruncFOp", [0], (Op $x), [(ExtF (TypeOf $x), (DiffeRet))]>; +def : ReadOnlyIdentityOp<"arith", "ExtFOp", [0], (Op $x), [(TruncF (TypeOf $x), (DiffeRet))]>; diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 3924f4527b00..099e614b8bcd 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -17,14 +17,21 @@ class ControlFlowOp { string impl = impl_; } -class MemoryIdentityOp ptrargs_, list storedargs_ = []> { + +def Unimplemented { + +} + +class MemoryIdentityOp ptrargs_, list storedargs_ = [], dag patternToMatch=(Unimplemented), list reverse_ = []> { string dialect = dialect_; string opName = opName_; + dag PatternToMatch = patternToMatch; list ptrargs = ptrargs_; list storedargs = storedargs_; + list reverse = reverse_; } -class ReadOnlyIdentityOp ptrargs_> : MemoryIdentityOp; +class ReadOnlyIdentityOp ptrargs_, dag patternToMatch=(Unimplemented), list reverse_ = []> : MemoryIdentityOp; class ReturnOp { string dialect = dialect_; @@ -94,6 +101,9 @@ class ConstantFP : Ope def ResultTypes : GlobalExprgetResultTypes()">; +def TypeOf : Operation { +} + class ArithInst : Inst; class MathInst : Inst; diff --git a/enzyme/test/MLIR/ForwardMode/trunc.mlir b/enzyme/test/MLIR/ForwardMode/trunc.mlir new file mode 100644 index 000000000000..8f3918add4ac --- /dev/null +++ b/enzyme/test/MLIR/ForwardMode/trunc.mlir @@ -0,0 +1,18 @@ +// RUN: %eopt --enzyme %s | FileCheck %s + +module { + func.func @f(%x : f64) -> f32 { + %y = arith.truncf %x : f64 to f32 + return %y : f32 + } + func.func @dsq(%x : f64, %dx : f64) -> f32 { + %r = enzyme.fwddiff @f(%x, %dx) { activity=[#enzyme] } : (f64, f64) -> (f32) + return %r : f32 + } +} + +// CHECK: func.func private @fwddiffef(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f32 { +// CHECK-NEXT: %[[dy:.+]] = arith.truncf %[[arg1]] : f64 to f32 +// CHECK-NEXT: %[[y:.+]] = arith.truncf %[[arg0]] : f64 to f32 +// CHECK-NEXT: return %[[dy]] : f32 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/trunc.mlir b/enzyme/test/MLIR/ReverseMode/trunc.mlir new file mode 100644 index 000000000000..c078cb0634ae --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/trunc.mlir @@ -0,0 +1,18 @@ +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN + +module { + func.func @f(%x: f64) -> f32 { + %next = arith.truncf %x : f64 to f32 + return %next : f32 + } + + func.func @dsquare(%x: f64, %dr: f32) -> f64 { + %r = enzyme.autodiff @f(%x, %dr) { activity=[#enzyme] } : (f64, f32) -> f64 + return %r : f64 + } +} + +// FIN: func.func private @diffef(%[[x:.+]]: f64, %[[dx:.+]]: f32) -> f64 { +// FIN-NEXT: %[[res:.+]] = arith.extf %[[dx]] : f32 to f64 +// FIN-NEXT: return %[[res]] : f64 +// FIN-NEXT: } diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 5f456c8755c6..6eff540b6622 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -354,7 +354,10 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, PrintFatalError(pattern->getLoc(), Twine("unknown named operand in typeof") + resultTree->getAsString()); - os << "->getType()"; + if (intrinsic == MLIRDerivatives) + os << ".getType()"; + else + os << "->getType()"; return false; } else if (opName == "VectorSize" || Def->isSubClassOf("VectorSize")) { if (resultRoot->getNumArgs() != 1) @@ -1268,6 +1271,298 @@ static void emitHeaderIncludes(const RecordKeeper &recordKeeper, os << "};\n"; } +static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree, + ActionType intrinsic, StringRef origName, + ListInit *argOps) { + + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << "struct " << opName << "RevDerivative : \n"; + os << " public " + "ReverseAutoDiffOpInterface::ExternalModel<" + << opName << "RevDerivative, " << dialect << "::" << opName << "> {\n"; + os << " SmallVector cachedArguments(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " SmallVector toret(op->getNumOperands(), false);\n"; + StringMap> varNameToCondition; + + std::function)> insert = + [&](DagInit *ptree, ArrayRef prev) { + for (auto treeEn : llvm::enumerate(ptree->getArgs())) { + auto tree = treeEn.value(); + auto name = ptree->getArgNameStr(treeEn.index()); + SmallVector next(prev.begin(), prev.end()); + next.push_back(treeEn.index()); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (name.size()) { + varNameToCondition[name] = std::make_tuple( + "idx == " + std::to_string(treeEn.index()), "", false); + } + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + varNameToCondition[tree->getNameStr()] = + std::make_tuple("ILLEGAL", "ILLEGAL", false); + + os << " for (size_t idx=0; idxgetNumOperands(); idx++) {\n"; + os << " bool used = false;\n"; + printDiffUse(os, " ", argOps, origName, intrinsic, tree, + varNameToCondition); + os << " toret[idx] = used;\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + + os << " SmallVector cacheValues(Operation *op,\n"; + os << " MGradientUtilsReverse *gutils) " + "const {\n"; + os << " if (gutils->isConstantInstruction(op) || " + "gutils->isConstantValue(op->getResult(0))) return {};\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " SmallVector toret;\n"; + os << " OpBuilder builder(gutils->getNewFromOriginal(op));\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " Value cache = " + "gutils->initAndPushCache(gutils->getNewFromOriginal(op->" + "getOperand(en.index())), builder);\n"; + os << " toret.push_back(cache);\n"; + os << " }\n"; + os << " return toret;\n"; + os << " }\n"; + os << "\n"; + os << " void createShadowValues(Operation *op, OpBuilder &builder,\n"; + os << " MGradientUtilsReverse *gutils) const " + "{}\n"; + + os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " + "&builder,\n"; + os << " MGradientUtilsReverse *gutils,\n"; + os << " SmallVector caches) const {\n"; + os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; + os << " mlir::Value dif = nullptr;\n"; +} + +static VariableSetting parseVariables(DagInit *tree, ActionType intrinsic, + StringRef origName) { + VariableSetting nameToOrdinal; + std::function)> insert = + [&](DagInit *ptree, ArrayRef prev) { + unsigned i = 0; + for (auto tree : ptree->getArgs()) { + SmallVector next(prev.begin(), prev.end()); + next.push_back(i); + if (auto dg = dyn_cast(tree)) + insert(dg, next); + + if (ptree->getArgNameStr(i).size()) { + std::string op; + if (intrinsic != MLIRDerivatives) + op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); + else + op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); + if (prev.size() > 0) { + op = "gutils->extractMeta(Builder2, " + op + + ", ArrayRef({"; + bool first = true; + for (unsigned i = 1; i < next.size(); i++) { + if (!first) + op += ", "; + op += std::to_string(next[i]); + } + op += "}))"; + } + nameToOrdinal.insert(ptree->getArgNameStr(i), op, false); + } + i++; + } + }; + + insert(tree, {}); + + if (tree->getNameStr().size()) + nameToOrdinal.insert(tree->getNameStr(), + (Twine("(&") + origName + ")").str(), false); + return nameToOrdinal; +} + +static void emitReverseCommon(raw_ostream &os, Record *pattern, DagInit *tree, + ActionType intrinsic, StringRef origName, + ListInit *argOps) { + auto nameToOrdinal = parseVariables(tree, intrinsic, origName); + + bool seen = false; + for (auto argOpEn : enumerate(*argOps)) { + size_t argIdx = argOpEn.index(); + if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + if (Def->getValueAsBit("asserting")) + os << " assert(gutils->isConstantValue(" << origName << ".getOperand(" + << argIdx << ")));\n"; + continue; + } + } + + os << " "; + if (seen) + os << "} else "; + seen = true; + if (intrinsic == MLIRDerivatives) { + os << "if (!dif && !gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + } else { + os << "if (!dif && !gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + } + DagInit *resultTree = cast(argOpEn.value()); + if (hasDiffeRet(resultTree)) { + if (intrinsic == MLIRDerivatives) { + os << " dif = gutils->diffe(" << origName << ", builder);\n"; + os << " gutils->zeroDiffe(" << origName << ", builder);\n"; + } else { + os << " dif = diffe(&" << origName << ", Builder2);\n"; + os << " setDiffe(&" << origName + << ", " + "Constant::getNullValue(gutils->getShadowType(" + << origName + << ".getType())), " + "Builder2);\n"; + } + } + } + if (seen) + os << " }\n"; + + if (intrinsic == MLIRDerivatives) { + os << " SmallVector operands(op->getNumOperands(), nullptr);\n"; + os << " auto neededArgs = cachedArguments(op, gutils);\n"; + os << " size_t count = 0;\n"; + os << " for (auto en : llvm::enumerate(neededArgs))\n"; + os << " if (en.value()) {\n"; + os << " operands[en.index()] = " + "gutils->popCache(caches[count], builder);\n"; + os << " count++;\n"; + os << " }\n"; + } + + std::function, Init *)> revres = + [&](size_t argIdx, ArrayRef idx, Init *ival) { + if (DagInit *resultTree = dyn_cast(ival)) { + auto Def = cast(resultTree->getOperator())->getDef(); + if (Def->isSubClassOf("MultiReturn")) { + unsigned i = 0; + for (auto r : resultTree->getArgs()) { + SmallVector next(idx.begin(), idx.end()); + next.push_back(i); + revres(argIdx, next, r); + i++; + } + return; + } + if (Def->isSubClassOf("InactiveArgSpec")) { + return; + } + const char *curIndent = " "; + os << curIndent << "{\n"; + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value tmp = "; + else + os << curIndent << INDENT << "Value *tmp = "; + bool vectorValued = handle( + Twine(curIndent) + INDENT, "revarg", os, pattern, resultTree, + (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", + nameToOrdinal, /*lookup*/ true, idx, origName, + /*newFromOriginal*/ true, intrinsic); + os << ";\n"; + + if (intrinsic == MLIRDerivatives) { + os << "assert(toadd == nullptr); toadd = tmp;\n"; + } else { + os << curIndent << INDENT + << "Value *out = " + "UndefValue::get(gutils->getShadowType(" + << origName << ".getOperand(" << argIdx << ")->getType()));\n"; + + os << curIndent << INDENT + << "for(unsigned int idx=0, W=gutils->getWidth(); " + "idxgetWidth() == " + "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " + "nullptr;\n"; + os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; + if (vectorValued) + os << curIndent << INDENT << INDENT + << "if (gutils->getWidth() > 1) next = " + "gutils->extractMeta(Builder2, next, idx);\n"; + os << curIndent << INDENT << INDENT + << "if (prev) next = Builder2.CreateFAdd(prev, " + "next);\n"; + os << curIndent << INDENT << INDENT + << "out = (gutils->getWidth() > 1) ? " + "Builder2.CreateInsertValue(out, next, idx) : next;\n"; + os << curIndent << INDENT << "}\n"; + os << curIndent << INDENT << "toadd = out;\n"; + } + os << curIndent << "}\n"; + + } else if (ListInit *lst = dyn_cast(ival)) { + unsigned i = 0; + for (auto elem : *lst) { + SmallVector next(idx.begin(), idx.end()); + next.push_back(i); + revres(argIdx, next, elem); + i++; + } + } else + assert(0); + }; + + for (auto argOpEn : enumerate(*argOps)) { + size_t argIdx = argOpEn.index(); + if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { + auto opName = resultRoot->getOperator()->getAsString(); + auto Def = cast(resultRoot->getOperator())->getDef(); + if (opName == "InactiveArgSpec" || Def->isSubClassOf("InactiveArgSpec")) { + continue; + } + } + + const char *curIndent = " "; + if (intrinsic == MLIRDerivatives) + os << curIndent << "if (!gutils->isConstantValue(" << origName + << "->getOperand(" << argIdx << "))) {\n"; + else + os << curIndent << "if (!gutils->isConstantValue(" << origName + << ".getOperand(" << argIdx << "))) {\n"; + initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); + if (intrinsic == MLIRDerivatives) + os << curIndent << INDENT << "mlir::Value toadd = nullptr;\n"; + else + os << curIndent << INDENT << "Value *toadd = nullptr;\n"; + revres(argIdx, {}, argOpEn.value()); + + if (intrinsic == MLIRDerivatives) { + os << curIndent << INDENT << "if (toadd) gutils->addToDiffe(" << origName + << "->getOperand(" << argIdx << "), toadd, builder);\n"; + } else { + os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName + << ".getOperand(" << argIdx << "), toadd"; + os << ", Builder2, " << origName << ".getOperand(" << argIdx + << ")->getType());\n"; + } + os << curIndent << "}\n"; + } +} static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, ActionType intrinsic) { emitSourceFileHeader("Rewriters", os); @@ -1467,45 +1762,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } } - VariableSetting nameToOrdinal; - - std::function)> insert = - [&](DagInit *ptree, ArrayRef prev) { - unsigned i = 0; - for (auto tree : ptree->getArgs()) { - SmallVector next(prev.begin(), prev.end()); - next.push_back(i); - if (auto dg = dyn_cast(tree)) - insert(dg, next); - - if (ptree->getArgNameStr(i).size()) { - std::string op; - if (intrinsic != MLIRDerivatives) - op = (origName + ".getOperand(" + Twine(next[0]) + ")").str(); - else - op = (origName + "->getOperand(" + Twine(next[0]) + ")").str(); - if (prev.size() > 0) { - op = "gutils->extractMeta(Builder2, " + op + - ", ArrayRef({"; - bool first = true; - for (unsigned i = 1; i < next.size(); i++) { - if (!first) - op += ", "; - op += std::to_string(next[i]); - } - op += "}))"; - } - nameToOrdinal.insert(ptree->getArgNameStr(i), op, false); - } - i++; - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - nameToOrdinal.insert(tree->getNameStr(), - (Twine("(&") + origName + ")").str(), false); + VariableSetting nameToOrdinal = parseVariables(tree, intrinsic, origName); if (intrinsic != BinopDerivatives && intrinsic != InstDerivatives && intrinsic != MLIRDerivatives) { @@ -1706,248 +1963,10 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " Value *dif = nullptr;\n"; } else { os << "};\n"; - auto opName = pattern->getValueAsString("opName"); - auto dialect = pattern->getValueAsString("dialect"); - os << "struct " << opName << "RevDerivative : \n"; - os << " public " - "ReverseAutoDiffOpInterface::ExternalModel<" - << opName << "RevDerivative, " << dialect << "::" << opName << "> {\n"; - os << " SmallVector cachedArguments(Operation *op,\n"; - os << " MGradientUtilsReverse *gutils) " - "const {\n"; - os << " SmallVector toret(op->getNumOperands(), false);\n"; - StringMap> varNameToCondition; - - std::function)> insert = - [&](DagInit *ptree, ArrayRef prev) { - for (auto treeEn : llvm::enumerate(ptree->getArgs())) { - auto tree = treeEn.value(); - auto name = ptree->getArgNameStr(treeEn.index()); - SmallVector next(prev.begin(), prev.end()); - next.push_back(treeEn.index()); - if (auto dg = dyn_cast(tree)) - insert(dg, next); - - if (name.size()) { - varNameToCondition[name] = std::make_tuple( - "idx == " + std::to_string(treeEn.index()), "", false); - } - } - }; - - insert(tree, {}); - - if (tree->getNameStr().size()) - varNameToCondition[tree->getNameStr()] = - std::make_tuple("ILLEGAL", "ILLEGAL", false); - - os << " for (size_t idx=0; idxgetNumOperands(); idx++) {\n"; - os << " bool used = false;\n"; - printDiffUse(os, " ", argOps, origName, intrinsic, tree, - varNameToCondition); - os << " toret[idx] = used;\n"; - os << " }\n"; - os << " return toret;\n"; - os << " }\n"; - - os << " SmallVector cacheValues(Operation *op,\n"; - os << " MGradientUtilsReverse *gutils) " - "const {\n"; - os << " if (gutils->isConstantInstruction(op) || " - "gutils->isConstantValue(op->getResult(0))) return {};\n"; - os << " auto neededArgs = cachedArguments(op, gutils);\n"; - os << " SmallVector toret;\n"; - os << " OpBuilder builder(gutils->getNewFromOriginal(op));\n"; - os << " for (auto en : llvm::enumerate(neededArgs))\n"; - os << " if (en.value()) {\n"; - os << " Value cache = " - "gutils->initAndPushCache(gutils->getNewFromOriginal(op->" - "getOperand(en.index())), builder);\n"; - os << " toret.push_back(cache);\n"; - os << " }\n"; - os << " return toret;\n"; - os << " }\n"; - os << "\n"; - os << " void createShadowValues(Operation *op, OpBuilder &builder,\n"; - os << " MGradientUtilsReverse *gutils) const " - "{}\n"; - - os << " void createReverseModeAdjoint(Operation *op0, OpBuilder " - "&builder,\n"; - os << " MGradientUtilsReverse *gutils,\n"; - os << " SmallVector caches) const {\n"; - os << " auto op = cast<" << dialect << "::" << opName << ">(op0);\n"; - os << " mlir::Value dif = nullptr;\n"; - } - // TODO vector - - bool seen = false; - for (auto argOpEn : enumerate(*argOps)) { - size_t argIdx = argOpEn.index(); - if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { - auto opName = resultRoot->getOperator()->getAsString(); - auto Def = cast(resultRoot->getOperator())->getDef(); - if (opName == "InactiveArgSpec" || - Def->isSubClassOf("InactiveArgSpec")) { - if (Def->getValueAsBit("asserting")) - os << " assert(gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << ")));\n"; - continue; - } - } - - os << " "; - if (seen) - os << "} else "; - seen = true; - if (intrinsic == MLIRDerivatives) { - os << "if (!dif && !gutils->isConstantValue(" << origName - << "->getOperand(" << argIdx << "))) {\n"; - } else { - os << "if (!dif && !gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; - } - DagInit *resultTree = cast(argOpEn.value()); - if (hasDiffeRet(resultTree)) { - if (intrinsic == MLIRDerivatives) { - os << " dif = gutils->diffe(" << origName << ", builder);\n"; - os << " gutils->zeroDiffe(" << origName << ", builder);\n"; - } else { - os << " dif = diffe(&" << origName << ", Builder2);\n"; - os << " setDiffe(&" << origName - << ", " - "Constant::getNullValue(gutils->getShadowType(" - << origName - << ".getType())), " - "Builder2);\n"; - } - } + emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); } - if (seen) - os << " }\n"; - - if (intrinsic == MLIRDerivatives) { - os << " SmallVector operands(op->getNumOperands(), nullptr);\n"; - os << " auto neededArgs = cachedArguments(op, gutils);\n"; - os << " size_t count = 0;\n"; - os << " for (auto en : llvm::enumerate(neededArgs))\n"; - os << " if (en.value()) {\n"; - os << " operands[en.index()] = " - "gutils->popCache(caches[count], builder);\n"; - os << " count++;\n"; - os << " }\n"; - } - - std::function, Init *)> revres = - [&](size_t argIdx, ArrayRef idx, Init *ival) { - if (DagInit *resultTree = dyn_cast(ival)) { - auto Def = cast(resultTree->getOperator())->getDef(); - if (Def->isSubClassOf("MultiReturn")) { - unsigned i = 0; - for (auto r : resultTree->getArgs()) { - SmallVector next(idx.begin(), idx.end()); - next.push_back(i); - revres(argIdx, next, r); - i++; - } - return; - } - if (Def->isSubClassOf("InactiveArgSpec")) { - return; - } - const char *curIndent = " "; - os << curIndent << "{\n"; - if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value tmp = "; - else - os << curIndent << INDENT << "Value *tmp = "; - bool vectorValued = handle( - Twine(curIndent) + INDENT, "revarg", os, pattern, resultTree, - (intrinsic == MLIRDerivatives) ? "builder" : "Builder2", - nameToOrdinal, /*lookup*/ true, idx, origName, - /*newFromOriginal*/ true, intrinsic); - os << ";\n"; - - if (intrinsic == MLIRDerivatives) { - os << "assert(toadd == nullptr); toadd = tmp;\n"; - } else { - os << curIndent << INDENT - << "Value *out = " - "UndefValue::get(gutils->getShadowType(" - << origName << ".getOperand(" << argIdx << ")->getType()));\n"; - - os << curIndent << INDENT - << "for(unsigned int idx=0, W=gutils->getWidth(); " - "idxgetWidth() == " - "1 ? toadd : gutils->extractMeta(Builder2, toadd, idx)) : " - "nullptr;\n"; - os << curIndent << INDENT << INDENT << "Value *next = tmp;\n"; - if (vectorValued) - os << curIndent << INDENT << INDENT - << "if (gutils->getWidth() > 1) next = " - "gutils->extractMeta(Builder2, next, idx);\n"; - os << curIndent << INDENT << INDENT - << "if (prev) next = Builder2.CreateFAdd(prev, " - "next);\n"; - os << curIndent << INDENT << INDENT - << "out = (gutils->getWidth() > 1) ? " - "Builder2.CreateInsertValue(out, next, idx) : next;\n"; - os << curIndent << INDENT << "}\n"; - os << curIndent << INDENT << "toadd = out;\n"; - } - os << curIndent << "}\n"; - - } else if (ListInit *lst = dyn_cast(ival)) { - unsigned i = 0; - for (auto elem : *lst) { - SmallVector next(idx.begin(), idx.end()); - next.push_back(i); - revres(argIdx, next, elem); - i++; - } - } else - assert(0); - }; - for (auto argOpEn : enumerate(*argOps)) { - size_t argIdx = argOpEn.index(); - if (DagInit *resultRoot = dyn_cast(argOpEn.value())) { - auto opName = resultRoot->getOperator()->getAsString(); - auto Def = cast(resultRoot->getOperator())->getDef(); - if (opName == "InactiveArgSpec" || - Def->isSubClassOf("InactiveArgSpec")) { - continue; - } - } - - const char *curIndent = " "; - if (intrinsic == MLIRDerivatives) - os << curIndent << "if (!gutils->isConstantValue(" << origName - << "->getOperand(" << argIdx << "))) {\n"; - else - os << curIndent << "if (!gutils->isConstantValue(" << origName - << ".getOperand(" << argIdx << "))) {\n"; - initializeNames(Twine(curIndent) + INDENT, os, argOpEn.value(), "local"); - if (intrinsic == MLIRDerivatives) - os << curIndent << INDENT << "mlir::Value toadd = nullptr;\n"; - else - os << curIndent << INDENT << "Value *toadd = nullptr;\n"; - revres(argIdx, {}, argOpEn.value()); - - if (intrinsic == MLIRDerivatives) { - os << curIndent << INDENT << "if (toadd) gutils->addToDiffe(" - << origName << "->getOperand(" << argIdx << "), toadd, builder);\n"; - } else { - os << curIndent << INDENT << "if (toadd) addToDiffe(" << origName - << ".getOperand(" << argIdx << "), toadd"; - os << ", Builder2, " << origName << ".getOperand(" << argIdx - << ")->getType());\n"; - } - os << curIndent << "}\n"; - } + emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); if (intrinsic != MLIRDerivatives) { os << " auto found = gutils->invertedPointers.find(&(" << origName @@ -2036,6 +2055,18 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, } os << " return true;\n }\n"; os << "};\n"; + + DagInit *tree = pattern->getValueAsDag("PatternToMatch"); + + if (tree->getOperator()->getAsString() != "Unimplemented") { + ListInit *argOps = pattern->getValueAsListInit("reverse"); + auto origName = "op"; + emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps); + emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps); + os << " return;\n"; + os << " }\n"; + os << " };\n"; + } } const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp"); @@ -2081,6 +2112,11 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, for (auto storedarg : pattern->getValueAsListOfInts("storedargs")) os << ", " << storedarg; os << ">(*context);\n"; + DagInit *tree = pattern->getValueAsDag("PatternToMatch"); + if (tree->getOperator()->getAsString() != "Unimplemented") { + os << " " << dialect << "::" << opName << "::attachInterface<" + << opName << "RevDerivative>(*context);\n"; + } } for (Record *pattern : brpatterns) { auto opName = pattern->getValueAsString("opName");