Skip to content

Commit

Permalink
MLIR support LogicalResult return in reversemode
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Mar 6, 2024
1 parent 0b62188 commit 6ebd9ee
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ class NoopRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
NoopRevAutoDiffInterface<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}
SmallVector<Value> caches) const {
return success();
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
Expand All @@ -131,10 +133,11 @@ class ReturnRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
ReturnRevAutoDiffInterface<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
returnReverseHandler(op, builder, gutils);
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down Expand Up @@ -181,9 +184,11 @@ class AutoDiffUsingAllocationRev
: public ReverseAutoDiffOpInterface::ExternalModel<
AutoDiffUsingAllocationRev<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}
SmallVector<Value> caches) const {
return success();
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ template <typename T_>
struct GenericOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<
GenericOpInterfaceReverse<T_>, T_> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto linalgOp = cast<linalg::LinalgOp>(op);
Expand Down Expand Up @@ -255,6 +255,7 @@ struct GenericOpInterfaceReverse
cacheBuilder.getArrayAttr(indexingMapsAttr));
adjoint->setAttr(adjoint.getIndexingMapsAttrName(),
builder.getArrayAttr(indexingMapsAttrAdjoint));
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ namespace {
struct LoadOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse,
memref::LoadOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto loadOp = cast<memref::LoadOp>(op);
Expand Down Expand Up @@ -62,6 +62,7 @@ struct LoadOpInterfaceReverse
ArrayRef<Value>(retrievedArguments));
}
}
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down Expand Up @@ -96,7 +97,7 @@ struct LoadOpInterfaceReverse
struct StoreOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse,
memref::StoreOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto storeOp = cast<memref::StoreOp>(op);
Expand Down Expand Up @@ -133,6 +134,7 @@ struct StoreOpInterfaceReverse
ArrayRef<Value>(retrievedArguments));
}
}
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down Expand Up @@ -167,9 +169,11 @@ struct StoreOpInterfaceReverse
struct SubViewOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<
SubViewOpInterfaceReverse, memref::SubViewOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}
SmallVector<Value> caches) const {
return success();
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ namespace {
struct ForOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<ForOpInterfaceReverse,
scf::ForOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto forOp = cast<scf::ForOp>(op);
Expand Down Expand Up @@ -149,6 +149,7 @@ struct ForOpInterfaceReverse
buildFuncReturnOp);
}
}
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> {
/*desc=*/[{
Emits a reverse-mode adjoint of the given function.
}],
/*retTy=*/"void",
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"createReverseModeAdjoint",
/*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::enzyme::MGradientUtilsReverse *":$gutils, "SmallVector<Value>":$caches)
>,
Expand Down
8 changes: 3 additions & 5 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,13 @@ class MEnzymeLogic {
handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB,
MGradientUtilsReverse *gutils,
llvm::function_ref<buildReturnFunction> buildReturnOp);
void visitChildren(Block *oBB, Block *reverseBB,
LogicalResult visitChildren(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
void visitChild(Operation *op, OpBuilder &builder,
LogicalResult visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);
void mapInvertArguments(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
SmallVector<mlir::Block *> getDominatorToposort(MGradientUtilsReverse *gutils,
Region &region);
void differentiate(MGradientUtilsReverse *gutils, Region &oldRegion,
LogicalResult differentiate(MGradientUtilsReverse *gutils, Region &oldRegion,
Region &newRegion,
llvm::function_ref<buildReturnFunction> buildFuncRetrunOp,
std::function<std::pair<Value, Value>(Type)> cacheCreator);
Expand Down
34 changes: 18 additions & 16 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,40 +38,35 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB,
/*
Create reverse mode adjoint for an operation.
*/
void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder,
LogicalResult MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) {
if ((op->getBlock()->getTerminator() != op) &&
llvm::all_of(op->getResults(),
[gutils](Value v) { return gutils->isConstantValue(v); }) &&
gutils->isConstantInstruction(op)) {
return;
return success();
}
if (auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
SmallVector<Value> caches = ifaceOp.cacheValues(gutils);
ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
return;
/*
for (int indexResult = 0; indexResult < (int)op->getNumResults();
indexResult++) {
Value result = op->getResult(indexResult);
gutils->clearValue(result, builder);
}
*/
return ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
}
op->emitError() << "could not compute the adjoint for this operation " << *op;
return failure();
}

void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB,
LogicalResult MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils) {
OpBuilder revBuilder(reverseBB, reverseBB->end());
bool valid = true;
if (!oBB->empty()) {
auto first = oBB->rbegin();
auto last = oBB->rend();
for (auto it = first; it != last; ++it) {
Operation *op = &*it;
visitChild(op, revBuilder, gutils);
valid &= visitChild(op, revBuilder, gutils).succeeded();
}
}
return success(valid);
}

void MEnzymeLogic::handlePredecessors(
Expand Down Expand Up @@ -161,7 +156,7 @@ void MEnzymeLogic::handlePredecessors(
}
}

void MEnzymeLogic::differentiate(
LogicalResult MEnzymeLogic::differentiate(
MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion,
llvm::function_ref<buildReturnFunction> buildFuncReturnOp,
std::function<std::pair<Value, Value>(Type)> cacheCreator) {
Expand All @@ -171,13 +166,15 @@ void MEnzymeLogic::differentiate(

gutils->createReverseModeBlocks(oldRegion, newRegion);

bool valid = true;
for (auto &oBB : oldRegion) {
Block *newBB = gutils->getNewFromOriginal(&oBB);
Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB);
handleReturns(&oBB, newBB, reverseBB, gutils);
visitChildren(&oBB, reverseBB, gutils);
valid &= visitChildren(&oBB, reverseBB, gutils).succeeded();
handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp);
}
return success(valid);
}

FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
Expand Down Expand Up @@ -212,14 +209,19 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

gutils->forceAugmentedReturns();

differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr);
auto res = differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr);

auto nf = gutils->newFunc;


// llvm::errs() << "nf\n";
// nf.dump();
// llvm::errs() << "nf end\n";

delete gutils;

if (!res.succeeded())
return nullptr;

return nf;
}
6 changes: 4 additions & 2 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree,
os << " MGradientUtilsReverse *gutils) const "
"{}\n";

os << " void createReverseModeAdjoint(Operation *op0, OpBuilder "
os << " LogicalResult createReverseModeAdjoint(Operation *op0, OpBuilder "
"&builder,\n";
os << " MGradientUtilsReverse *gutils,\n";
os << " SmallVector<Value> caches) const {\n";
Expand Down Expand Up @@ -1995,6 +1995,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,

if (intrinsic == IntrDerivatives || intrinsic == CallDerivatives)
os << " return true;\n }\n";
else if (intrinsic == MLIRDerivatives)
os << " return success();\n }\n";
else
os << " return;\n }\n";
if (intrinsic == MLIRDerivatives)
Expand Down Expand Up @@ -2063,7 +2065,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
auto origName = "op";
emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps);
emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps);
os << " return;\n";
os << " return success();\n";
os << " }\n";
os << " };\n";
}
Expand Down

0 comments on commit 6ebd9ee

Please sign in to comment.