From 07d26b3a3e29c99b30405107216d4751eeb42adb Mon Sep 17 00:00:00 2001 From: "William S. Moses" Date: Sun, 18 Feb 2024 00:12:48 -0500 Subject: [PATCH] fix --- .../MLIR/Implementations/ArithDerivatives.td | 4 ++-- enzyme/Enzyme/MLIR/Implementations/Common.td | 7 +++++++ .../CoreDialectsAutoDiffImplementations.cpp | 17 +++++++++++++++++ .../CoreDialectsAutoDiffImplementations.h | 2 ++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 3 ++- 5 files changed, 30 insertions(+), 3 deletions(-) diff --git a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td index 3ed038fa4cb6..3d53793be3af 100644 --- a/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td @@ -28,6 +28,6 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y), [ (CheckedDivF (DiffeRet), $y), (NegF (MulF (CheckedDivF (DiffeRet), $y), (DivF $x, $y))) - ] - // (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y)) + ], + (CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y)) >; diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 99be53497229..d076b9cda6b4 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -61,6 +61,9 @@ class DiffeRetIndex indices_> { } def DiffeRet : DiffeRetIndex<[-1]>; +def Shadow : Operation { +} + class Inst : Operation { string name = mnemonic; string dialect = dialect_; @@ -69,6 +72,10 @@ class Inst : Operation { + +} + class ConstantFP : Operation { string value = val; string dialect = dialect_; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 025946ef9651..fdeba19a2379 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -19,6 +19,23 @@ using namespace mlir; using namespace mlir::enzyme; +mlir::Attribute mlir::enzyme::getConstantAttr(mlir::Type type, + llvm::StringRef value) { + using namespace mlir; + if (auto T = dyn_cast(type)) { + size_t num = 1; + for (auto sz : T.getShape()) + num *= sz; + APFloat apvalue(T.getElementType().cast().getFloatSemantics(), + value); + SmallVector supportedValues(num, apvalue); + return DenseFPElementsAttr::get(type.cast(), supportedValues); + } + auto T = cast(type); + APFloat apvalue(T.getFloatSemantics(), value); + return FloatAttr::get(T, apvalue); +} + void mlir::enzyme::detail::branchingForwardHandler(Operation *inst, OpBuilder &builder, MGradientUtils *gutils) { diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index a7ec6b179986..cb0be114a2ca 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -198,5 +198,7 @@ void registerLinalgDialectAutoDiffInterface(DialectRegistry ®istry); void registerMathDialectAutoDiffInterface(DialectRegistry ®istry); void registerCoreDialectAutodiffInterfaces(DialectRegistry ®istry); + +mlir::Attribute getConstantAttr(mlir::Type type, llvm::StringRef value); } // namespace enzyme } // namespace mlir diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index fa11ee4a5ad4..40168a3e5a5e 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -465,7 +465,8 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os, assert(!isVec); ord = ord1; } - os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), "; + os << ord << ".getType(), mlir::enzyme::getConstantAttr(" << ord + << ".getType(), "; os << "\"" << value->getValue() << "\"))"; } else { if (resultRoot->getNumArgs() != 1)