Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 18, 2024
1 parent 9e3c10a commit 07d26b3
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 3 deletions.
4 changes: 2 additions & 2 deletions enzyme/Enzyme/MLIR/Implementations/ArithDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ def : MLIRDerivative<"arith", "DivFOp", (Op $x, $y),
[
(CheckedDivF (DiffeRet), $y),
(NegF (MulF (CheckedDivF (DiffeRet), $y), (DivF $x, $y)))
]
// (CheckedDiv (FSub (SelectIfActive $x, (FMul (Shadow $x), $y), (Zero $x)), (SelectIfActive $y, (FMul (Shadow $y), $x), (Zero $y))), (FMul $y, $y))
],
(CheckedDivF (SubF (SelectIfActive $x, (MulF (Shadow $x), $y), (ConstantFP<"0","arith", "ConstantOp"> $x)), (SelectIfActive $y, (MulF (Shadow $y), $x), (ConstantFP<"0","arith","ConstantOp"> $y))), (MulF $y, $y))
>;
7 changes: 7 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ class DiffeRetIndex<list<int> indices_> {
}
def DiffeRet : DiffeRetIndex<[-1]>;

def Shadow : Operation</*primal*/0, /*shadow*/1> {
}

class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*/0> {
string name = mnemonic;
string dialect = dialect_;
Expand All @@ -69,6 +72,10 @@ class Inst<string mnemonic, string dialect_> : Operation</*primal*/1, /*shadow*
def Op {
}

def SelectIfActive : Operation</*primal*/0, /*shadow*/0, /*custom*/1> {

}

class ConstantFP<string val, string dialect_, string op_> : Operation</*primal*/0, /*shadow*/0> {
string value = val;
string dialect = dialect_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,23 @@
using namespace mlir;
using namespace mlir::enzyme;

mlir::Attribute mlir::enzyme::getConstantAttr(mlir::Type type,
llvm::StringRef value) {
using namespace mlir;
if (auto T = dyn_cast<TensorType>(type)) {
size_t num = 1;
for (auto sz : T.getShape())
num *= sz;
APFloat apvalue(T.getElementType().cast<FloatType>().getFloatSemantics(),
value);
SmallVector<APFloat> supportedValues(num, apvalue);
return DenseFPElementsAttr::get(type.cast<ShapedType>(), supportedValues);
}
auto T = cast<FloatType>(type);
APFloat apvalue(T.getFloatSemantics(), value);
return FloatAttr::get(T, apvalue);
}

void mlir::enzyme::detail::branchingForwardHandler(Operation *inst,
OpBuilder &builder,
MGradientUtils *gutils) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,5 +198,7 @@ void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
void registerMathDialectAutoDiffInterface(DialectRegistry &registry);

void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);

mlir::Attribute getConstantAttr(mlir::Type type, llvm::StringRef value);
} // namespace enzyme
} // namespace mlir
3 changes: 2 additions & 1 deletion enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,7 +465,8 @@ bool handle(const Twine &curIndent, const Twine &argPattern, raw_ostream &os,
assert(!isVec);
ord = ord1;
}
os << ord << ".getType(), getTensorAttr(" << ord << ".getType(), ";
os << ord << ".getType(), mlir::enzyme::getConstantAttr(" << ord
<< ".getType(), ";
os << "\"" << value->getValue() << "\"))";
} else {
if (resultRoot->getNumArgs() != 1)
Expand Down

0 comments on commit 07d26b3

Please sign in to comment.