Skip to content

Commit

Permalink
[MLIR] Fix reverse wrap pass infra (#1775)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Mar 2, 2024
1 parent 456cf5e commit 0b62188
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 105 deletions.
1 change: 0 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/FunctionInterfaces.h"

// TODO: this shouldn't depend on specific dialects except Enzyme.
Expand Down
7 changes: 2 additions & 5 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,7 @@ class MEnzymeLogic {
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA,
bool returnUsed, DerivativeMode mode, bool freeMemory,
size_t width, mlir::Type addedType, MFnTypeInfo type_args,
std::vector<bool> volatile_args, void *augmented,
SymbolTableCollection &symbolTable);
std::vector<bool> volatile_args, void *augmented);
void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
MGradientUtilsReverse *gutils);
Expand All @@ -132,8 +131,6 @@ class MEnzymeLogic {
MGradientUtilsReverse *gutils);
void visitChild(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);
bool visitChildCustom(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);
void mapInvertArguments(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
SmallVector<mlir::Block *> getDominatorToposort(MGradientUtilsReverse *gutils,
Expand All @@ -145,4 +142,4 @@ class MEnzymeLogic {
};

} // Namespace enzyme
} // Namespace mlir
} // Namespace mlir
80 changes: 3 additions & 77 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,76 +35,6 @@ void handleReturns(Block *oBB, Block *newBB, Block *reverseBB,
}
}

bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) {
std::string nameDiffe = "diffe_" + op->getName().getDialectNamespace().str() +
"_" + op->getName().stripDialect().str();
std::string nameStore = "store_" + op->getName().getDialectNamespace().str() +
"_" + op->getName().stripDialect().str();

StringRef srDiffe(nameDiffe);
StringRef srStore(nameStore);

OperationName opNameDiffe(srDiffe, op->getContext());
OperationName opNameStore(srStore, op->getContext());

Operation *symbolDiffe = gutils->symbolTable.lookupNearestSymbolFrom(
op, opNameDiffe.getIdentifier());
Operation *symbolStore = gutils->symbolTable.lookupNearestSymbolFrom(
op, opNameStore.getIdentifier());

if (symbolDiffe != nullptr) {
SmallVector<Value> caches;
if (symbolStore != nullptr) {
Operation *newOp = gutils->getNewFromOriginal(op);

func::FuncOp funcStore = cast<func::FuncOp>(symbolStore);

SmallVector<Type, 2> storeResultTypes;
for (auto x : funcStore.getFunctionType().getResults()) {
storeResultTypes.push_back(x);
}

SmallVector<Value, 2> storeArgs;
for (auto x : newOp->getOperands()) {
storeArgs.push_back(x);
}

OpBuilder storeBuilder(newOp);
func::CallOp storeCI = storeBuilder.create<func::CallOp>(
op->getLoc(), srStore, storeResultTypes, storeArgs);
for (auto x : storeCI.getResults()) {
caches.push_back(gutils->initAndPushCache(x, storeBuilder));
}
}

SmallVector<Value> args;
for (Value opResult : op->getResults()) {
if (!gutils->isConstantValue(opResult)) {
Value invertValue = gutils->invertPointerM(opResult, builder);
args.push_back(invertValue);
}
}
for (Value cache : caches) {
args.push_back(gutils->popCache(cache, builder));
}

SmallVector<Type, 2> resultTypes;
for (auto x : op->getOperands()) {
resultTypes.push_back(x.getType());
}

func::CallOp dCI =
builder.create<func::CallOp>(op->getLoc(), srDiffe, resultTypes, args);
for (int i = 0; i < (int)op->getNumOperands(); i++) {
gutils->setDiffe(op->getOperand(i), dCI.getResult(i), builder);
}

return true;
}
return false;
}

/*
Create reverse mode adjoint for an operation.
*/
Expand Down Expand Up @@ -139,10 +69,7 @@ void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB,
auto last = oBB->rend();
for (auto it = first; it != last; ++it) {
Operation *op = &*it;
bool customFound = visitChildCustom(op, revBuilder, gutils);
if (!customFound) {
visitChild(op, revBuilder, gutils);
}
visitChild(op, revBuilder, gutils);
}
}
}
Expand Down Expand Up @@ -257,8 +184,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
FunctionOpInterface fn, DIFFE_TYPE retType,
std::vector<DIFFE_TYPE> constants, MTypeAnalysis &TA, bool returnUsed,
DerivativeMode mode, bool freeMemory, size_t width, mlir::Type addedType,
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented,
SymbolTableCollection &symbolTable) {
MFnTypeInfo type_args, std::vector<bool> volatile_args, void *augmented) {

if (fn.getFunctionBody().empty()) {
llvm::errs() << fn << "\n";
Expand All @@ -268,7 +194,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(
ReturnType returnValue = ReturnType::Args;
MGradientUtilsReverse *gutils = MGradientUtilsReverse::CreateFromClone(
*this, mode, width, fn, TA, type_args, retType, /*diffeReturnArg*/ true,
constants, returnValue, addedType, symbolTable);
constants, returnValue, addedType);

Region &oldRegion = gutils->oldFunc.getFunctionBody();
Region &newRegion = gutils->newFunc.getFunctionBody();
Expand Down
16 changes: 7 additions & 9 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,11 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
const SmallPtrSetImpl<mlir::Value> &activevals_, DIFFE_TYPE ReturnActivity,
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_, IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width, SymbolTableCollection &symbolTable_)
DerivativeMode mode_, unsigned width)
: MDiffeGradientUtils(Logic, newFunc_, oldFunc_, TA_, /*MTypeResults*/ {},
invertedPointers_, constantvalues_, activevals_,
ReturnActivity, ArgDiffeTypes_, originalToNewFn_,
originalToNewFnOps_, mode_, width, /*omp*/ false),
symbolTable(symbolTable_) {}
originalToNewFnOps_, mode_, width, /*omp*/ false) {}

Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Type indexType = getIndexType();
Expand Down Expand Up @@ -135,8 +134,7 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width,
FunctionOpInterface todiff, MTypeAnalysis &TA, MFnTypeInfo &oldTypeInfo,
DIFFE_TYPE retType, bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, mlir::Type additionalArg,
SymbolTableCollection &symbolTable_) {
ReturnType returnValue, mlir::Type additionalArg) {
std::string prefix;

switch (mode_) {
Expand Down Expand Up @@ -168,8 +166,8 @@ MGradientUtilsReverse *MGradientUtilsReverse::CreateFromClone(
prefix + todiff.getName(), originalToNew, originalToNewOps,
diffeReturnArg, additionalArg);

return new MGradientUtilsReverse(
Logic, newFunc, todiff, TA, invertedPointers, constant_values,
nonconstant_values, retType, constant_args, originalToNew,
originalToNewOps, mode_, width, symbolTable_);
return new MGradientUtilsReverse(Logic, newFunc, todiff, TA, invertedPointers,
constant_values, nonconstant_values, retType,
constant_args, originalToNew,
originalToNewOps, mode_, width);
}
8 changes: 2 additions & 6 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,10 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
ArrayRef<DIFFE_TYPE> ArgDiffeTypes_,
IRMapping &originalToNewFn_,
std::map<Operation *, Operation *> &originalToNewFnOps_,
DerivativeMode mode_, unsigned width,
SymbolTableCollection &symbolTable_);
DerivativeMode mode_, unsigned width);

IRMapping mapReverseModeBlocks;

SymbolTableCollection &symbolTable;

void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient,
OpBuilder &builder);

Expand Down Expand Up @@ -67,8 +64,7 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
FunctionOpInterface todiff, MTypeAnalysis &TA,
MFnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, mlir::Type additionalArg,
SymbolTableCollection &symbolTable_);
ReturnType returnValue, mlir::Type additionalArg);
};

} // namespace enzyme
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
fn, retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr, symbolTable);
/*augmented*/ nullptr);
if (!newFunc)
return failure();

Expand Down
21 changes: 15 additions & 6 deletions enzyme/Enzyme/MLIR/Passes/EnzymeWrapPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,21 @@ struct DifferentiateWrapperPass
volatile_args.push_back(!(mode == DerivativeMode::ReverseModeCombined));
}

FunctionOpInterface newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
/*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory,
width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
FunctionOpInterface newFunc;
if (mode == DerivativeMode::ForwardMode) {
newFunc = Logic.CreateForwardDiff(
fn, retType, constants, TA,
/*should return*/ (retType == DIFFE_TYPE::DUP_ARG), mode, freeMemory,
width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
} else {
newFunc = Logic.CreateReverseDiff(
fn, retType, constants, TA,
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
}
if (!newFunc) {
signalPassFailure();
return;
Expand Down

0 comments on commit 0b62188

Please sign in to comment.