Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MLIR support LogicalResult return in reversemode #1779

Merged
merged 1 commit into from
Mar 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,11 @@ class NoopRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
NoopRevAutoDiffInterface<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
return success();
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
Expand All @@ -131,10 +133,11 @@ class ReturnRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
ReturnRevAutoDiffInterface<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
returnReverseHandler(op, builder, gutils);
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down Expand Up @@ -181,9 +184,11 @@ class AutoDiffUsingAllocationRev
: public ReverseAutoDiffOpInterface::ExternalModel<
AutoDiffUsingAllocationRev<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
return success();
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,9 +68,9 @@ template <typename T_>
struct GenericOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<
GenericOpInterfaceReverse<T_>, T_> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto linalgOp = cast<linalg::LinalgOp>(op);
assert(linalgOp.hasPureBufferSemantics() &&
"Linalg op with tensor semantics not yet supported");
Expand Down Expand Up @@ -255,6 +255,7 @@ struct GenericOpInterfaceReverse
cacheBuilder.getArrayAttr(indexingMapsAttr));
adjoint->setAttr(adjoint.getIndexingMapsAttrName(),
builder.getArrayAttr(indexingMapsAttrAdjoint));
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ namespace {
struct LoadOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<LoadOpInterfaceReverse,
memref::LoadOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto loadOp = cast<memref::LoadOp>(op);
Value memref = loadOp.getMemref();

Expand All @@ -62,6 +62,7 @@ struct LoadOpInterfaceReverse
ArrayRef<Value>(retrievedArguments));
}
}
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down Expand Up @@ -96,9 +97,9 @@ struct LoadOpInterfaceReverse
struct StoreOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<StoreOpInterfaceReverse,
memref::StoreOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto storeOp = cast<memref::StoreOp>(op);
Value val = storeOp.getValue();
Value memref = storeOp.getMemref();
Expand Down Expand Up @@ -133,6 +134,7 @@ struct StoreOpInterfaceReverse
ArrayRef<Value>(retrievedArguments));
}
}
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down Expand Up @@ -167,9 +169,11 @@ struct StoreOpInterfaceReverse
struct SubViewOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<
SubViewOpInterfaceReverse, memref::SubViewOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
return success();
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ namespace {
struct ForOpInterfaceReverse
: public ReverseAutoDiffOpInterface::ExternalModel<ForOpInterfaceReverse,
scf::ForOp> {
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
LogicalResult createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto forOp = cast<scf::ForOp>(op);

// Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0
Expand Down Expand Up @@ -149,6 +149,7 @@ struct ForOpInterfaceReverse
buildFuncReturnOp);
}
}
return success();
}

SmallVector<Value> cacheValues(Operation *op,
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/AutoDiffOpInterface.td
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def ReverseAutoDiffOpInterface : OpInterface<"ReverseAutoDiffOpInterface"> {
/*desc=*/[{
Emits a reverse-mode adjoint of the given function.
}],
/*retTy=*/"void",
/*retTy=*/"::mlir::LogicalResult",
/*methodName=*/"createReverseModeAdjoint",
/*args=*/(ins "::mlir::OpBuilder &":$builder, "::mlir::enzyme::MGradientUtilsReverse *":$gutils, "SmallVector<Value>":$caches)
>,
Expand Down
19 changes: 9 additions & 10 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,17 @@ class MEnzymeLogic {
handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB,
MGradientUtilsReverse *gutils,
llvm::function_ref<buildReturnFunction> buildReturnOp);
void visitChildren(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
void visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);
LogicalResult visitChildren(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
LogicalResult visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);
void mapInvertArguments(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
SmallVector<mlir::Block *> getDominatorToposort(MGradientUtilsReverse *gutils,
Region &region);
void differentiate(MGradientUtilsReverse *gutils, Region &oldRegion,
Region &newRegion,
llvm::function_ref<buildReturnFunction> buildFuncRetrunOp,
std::function<std::pair<Value, Value>(Type)> cacheCreator);
LogicalResult
differentiate(MGradientUtilsReverse *gutils, Region &oldRegion,
Region &newRegion,
llvm::function_ref<buildReturnFunction> buildFuncRetrunOp,
std::function<std::pair<Value, Value>(Type)> cacheCreator);
};

} // Namespace enzyme
Expand Down
38 changes: 20 additions & 18 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,40 +38,35 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB,
/*
Create reverse mode adjoint for an operation.
*/
void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) {
LogicalResult MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) {
if ((op->getBlock()->getTerminator() != op) &&
llvm::all_of(op->getResults(),
[gutils](Value v) { return gutils->isConstantValue(v); }) &&
gutils->isConstantInstruction(op)) {
return;
return success();
}
if (auto ifaceOp = dyn_cast<ReverseAutoDiffOpInterface>(op)) {
SmallVector<Value> caches = ifaceOp.cacheValues(gutils);
ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
return;
/*
for (int indexResult = 0; indexResult < (int)op->getNumResults();
indexResult++) {
Value result = op->getResult(indexResult);
gutils->clearValue(result, builder);
}
*/
return ifaceOp.createReverseModeAdjoint(builder, gutils, caches);
}
op->emitError() << "could not compute the adjoint for this operation " << *op;
return failure();
}

void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils) {
LogicalResult MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils) {
OpBuilder revBuilder(reverseBB, reverseBB->end());
bool valid = true;
if (!oBB->empty()) {
auto first = oBB->rbegin();
auto last = oBB->rend();
for (auto it = first; it != last; ++it) {
Operation *op = &*it;
visitChild(op, revBuilder, gutils);
valid &= visitChild(op, revBuilder, gutils).succeeded();
}
}
return success(valid);
}

void MEnzymeLogic::handlePredecessors(
Expand Down Expand Up @@ -161,7 +156,7 @@ void MEnzymeLogic::handlePredecessors(
}
}

void MEnzymeLogic::differentiate(
LogicalResult MEnzymeLogic::differentiate(
MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion,
llvm::function_ref<buildReturnFunction> buildFuncReturnOp,
std::function<std::pair<Value, Value>(Type)> cacheCreator) {
Expand All @@ -171,13 +166,15 @@ void MEnzymeLogic::differentiate(

gutils->createReverseModeBlocks(oldRegion, newRegion);

bool valid = true;
for (auto &oBB : oldRegion) {
Block *newBB = gutils->getNewFromOriginal(&oBB);
Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB);
handleReturns(&oBB, newBB, reverseBB, gutils);
visitChildren(&oBB, reverseBB, gutils);
valid &= visitChildren(&oBB, reverseBB, gutils).succeeded();
handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp);
}
return success(valid);
}

FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
Expand Down Expand Up @@ -212,7 +209,8 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

gutils->forceAugmentedReturns();

differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr);
auto res =
differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr);

auto nf = gutils->newFunc;

Expand All @@ -221,5 +219,9 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
// llvm::errs() << "nf end\n";

delete gutils;

if (!res.succeeded())
return nullptr;

return nf;
}
6 changes: 4 additions & 2 deletions enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ static void emitMLIRReverse(raw_ostream &os, Record *pattern, DagInit *tree,
os << " MGradientUtilsReverse *gutils) const "
"{}\n";

os << " void createReverseModeAdjoint(Operation *op0, OpBuilder "
os << " LogicalResult createReverseModeAdjoint(Operation *op0, OpBuilder "
"&builder,\n";
os << " MGradientUtilsReverse *gutils,\n";
os << " SmallVector<Value> caches) const {\n";
Expand Down Expand Up @@ -1995,6 +1995,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,

if (intrinsic == IntrDerivatives || intrinsic == CallDerivatives)
os << " return true;\n }\n";
else if (intrinsic == MLIRDerivatives)
os << " return success();\n }\n";
else
os << " return;\n }\n";
if (intrinsic == MLIRDerivatives)
Expand Down Expand Up @@ -2063,7 +2065,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os,
auto origName = "op";
emitMLIRReverse(os, pattern, tree, intrinsic, origName, argOps);
emitReverseCommon(os, pattern, tree, intrinsic, origName, argOps);
os << " return;\n";
os << " return success();\n";
os << " }\n";
os << " };\n";
}
Expand Down
Loading