Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Fix reverse wrap pass infra #1775

Merged
merged 1 commit into from
Mar 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

// TODO: this shouldn't depend on specific dialects except Enzyme.
Expand Down
7 changes: 2 additions & 5 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ class MEnzymeLogic {
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
bool returnUsed, DerivativeMode mode, bool freeMemory,
size_t width, mlir::Type addedType, MFnTypeInfo type_args,
std::vector<bool> volatile_args, void *augmented,
SymbolTableCollection &symbolTable);
std::vector<bool> volatile_args, void *augmented);
void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
MGradientUtilsReverse *gutils);
Expand All @@ -132,8 +131,6 @@ class MEnzymeLogic {
MGradientUtilsReverse *gutils);
void visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);
bool visitChildCustom(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);
void mapInvertArguments(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
SmallVector<mlir::Block *> getDominatorToposort(MGradientUtilsReverse *gutils,
Expand All @@ -145,4 +142,4 @@ class MEnzymeLogic {
};

} // Namespace enzyme
} // Namespace mlir
} // Namespace mlir
80 changes: 3 additions & 77 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,76 +35,6 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB,
}
}

bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) {
std::string nameDiffe = "diffe_" + op->getName().getDialectNamespace().str() +
"_" + op->getName().stripDialect().str();
std::string nameStore = "store_" + op->getName().getDialectNamespace().str() +
"_" + op->getName().stripDialect().str();

StringRef srDiffe(nameDiffe);
StringRef srStore(nameStore);

OperationName opNameDiffe(srDiffe, op->getContext());
OperationName opNameStore(srStore, op->getContext());

Operation *symbolDiffe = gutils->symbolTable.lookupNearestSymbolFrom(
op, opNameDiffe.getIdentifier());
Operation *symbolStore = gutils->symbolTable.lookupNearestSymbolFrom(
op, opNameStore.getIdentifier());

if (symbolDiffe != nullptr) {
SmallVector<Value> caches;
if (symbolStore != nullptr) {
Operation *newOp = gutils->getNewFromOriginal(op);

func::FuncOp funcStore = cast<func::FuncOp>(symbolStore);

SmallVector<Type, 2> storeResultTypes;
for (auto x : funcStore.getFunctionType().getResults()) {
storeResultTypes.push_back(x);
}

SmallVector<Value, 2> storeArgs;
for (auto x : newOp->getOperands()) {
storeArgs.push_back(x);
}

OpBuilder storeBuilder(newOp);
func::CallOp storeCI = storeBuilder.create<func::CallOp>(
op->getLoc(), srStore, storeResultTypes, storeArgs);
for (auto x : storeCI.getResults()) {
caches.push_back(gutils->initAndPushCache(x, storeBuilder));
}
}

SmallVector<Value> args;
for (Value opResult : op->getResults()) {
if (!gutils->isConstantValue(opResult)) {
Value invertValue = gutils->invertPointerM(opResult, builder);
args.push_back(invertValue);
}
}
for (Value cache : caches) {
args.push_back(gutils->popCache(cache, builder));
}

SmallVector<Type, 2> resultTypes;
for (auto x : op->getOperands()) {
resultTypes.push_back(x.getType());
}

func::CallOp dCI =
builder.create<func::CallOp>(op->getLoc(), srDiffe, resultTypes, args);
for (int i = 0; i < (int)op->getNumOperands(); i++) {
gutils->setDiffe(op->getOperand(i), dCI.getResult(i), builder);
}

return true;
}
return false;
}

/*
Create reverse mode adjoint for an operation.
*/
Expand Down Expand Up @@ -139,10 +69,7 @@ void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB,
auto last = oBB->rend();
for (auto it = first; it != last; ++it) {
Operation *op = &*it;
bool customFound = visitChildCustom(op, revBuilder, gutils);
if (!customFound) {
visitChild(op, revBuilder, gutils);
}
visitChild(op, revBuilder, gutils);
}
}
}
Expand Down Expand Up @@ -257,8 +184,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
FunctionOpInterface fn, DIFFE_TYPE retType,
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
SymbolTableCollection &symbolTable) {
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented) {

if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
Expand All @@ -268,7 +194,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
ReturnType returnValue = ReturnType::Args;
MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true,
constants, returnValue, addedType, symbolTable);
constants, returnValue, addedType);

Region &oldRegion = gutils->oldFunc.getFunctionBody();
Region &newRegion = gutils->newFunc.getFunctionBody();
Expand Down
16 changes: 7 additions & 9 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
const SmallPtrSetImpl<mlir::Value> &activevals_, DIFFE_TYPE ReturnActivity,
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_, IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_)
DerivativeMode mode_, unsigned width)
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, constantvalues_, activevals_,
ReturnActivity, ArgDiffeTypes_, originalToNewFn_,
originalToNewFnOps_, mode_, width, /*omp*/ false),
symbolTable(symbolTable_) {}
originalToNewFnOps_, mode_, width, /*omp*/ false) {}

Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Type indexType = getIndexType();
Expand Down Expand Up @@ -135,8 +134,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, mlir::Type additionalArg,
SymbolTableCollection &symbolTable_) {
ReturnType returnValue, mlir::Type additionalArg) {
std::string prefix;

switch (mode_) {
Expand Down Expand Up @@ -168,8 +166,8 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
prefix + todiff.getName(), originalToNew, originalToNewOps,
diffeReturnArg, additionalArg);

return new MGradientUtilsReverse(
Logic, newFunc, todiff, TA, invertedPointers, constant_values,
nonconstant_values, retType, constant_args, originalToNew,
originalToNewOps, mode_, width, symbolTable_);
return new MGradientUtilsReverse(Logic, newFunc, todiff, TA, invertedPointers,
constant_values, nonconstant_values, retType,
constant_args, originalToNew,
originalToNewOps, mode_, width);
}
8 changes: 2 additions & 6 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,10 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width,
SymbolTableCollection &symbolTable_);
DerivativeMode mode_, unsigned width);

IRMapping mapReverseModeBlocks;

SymbolTableCollection &symbolTable;

void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient,
OpBuilder &builder);

Expand Down Expand Up @@ -67,8 +64,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
FunctionOpInterface todiff, MTypeAnalysis &TA,
MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, mlir::Type additionalArg,
SymbolTableCollection &symbolTable_);
ReturnType returnValue, mlir::Type additionalArg);
};

} // namespace enzyme
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
fn, retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr, symbolTable);
/*augmented*/ nullptr);
if (!newFunc)
return failure();

Expand Down
21 changes: 15 additions & 6 deletions enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,21 @@ struct DifferentiateWrapperPass
volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
}

FunctionOpInterface newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
/*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory,
width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
FunctionOpInterface newFunc;
if (mode == DerivativeMode::ForwardMode) {
newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
/*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory,
width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
} else {
newFunc = Logic.CreateReverseDiff(
fn, retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
}
if (!newFunc) {
signalPassFailure();
return;
Expand Down
Loading