Skip to content

Commit

Permalink
additional cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 29, 2024
1 parent b9ab98f commit 01d8acf
Show file tree
Hide file tree
Showing 3 changed files with 1 addition and 67 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder,
func::CallOp dCI =
builder.create<func::CallOp>(op->getLoc(), srDiffe, resultTypes, args);
for (int i = 0; i < (int)op->getNumOperands(); i++) {
gutils->mapInvertPointer(op->getOperand(i), dCI.getResult(i), builder);
gutils->setDiffe(op->getOperand(i), dCI.getResult(i), builder);
}

return true;
Expand Down
59 changes: 0 additions & 59 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,6 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse(
ReturnActivity, ArgDiffeTypes_, originalToNewFn_,
originalToNewFnOps_, mode_, width, /*omp*/ false),
symbolTable(symbolTable_) {

initInitializationBlock(invertedPointers_, ArgDiffeTypes_);
}

// for(auto x : v.getUsers()){x->dump();} DEBUG

bool MGradientUtilsReverse::onlyUsedInParentBlock(Value v) {
return !v.isUsedOutsideOfBlock(v.getParentBlock());
}

Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() {
Expand Down Expand Up @@ -129,57 +121,6 @@ void mlir::enzyme::MGradientUtilsReverse::addToDiffe(Value oldGradient,
setDiffe(oldGradient, added, builder);
}

void mlir::enzyme::MGradientUtilsReverse::mapInvertPointer(
mlir::Value v, mlir::Value invertValue, OpBuilder &builder) {
assert(0);
/*
if (!invertedPointersGlobal.contains(v)) {
Value g = insertInitGradient(v, builder);
invertedPointersGlobal.map(v, g);
}
Value gradient = invertedPointersGlobal.lookupOrNull(v);
builder.create<enzyme::SetOp>(v.getLoc(), gradient, invertValue);
*/
}

void MGradientUtilsReverse::initInitializationBlock(
IRMapping invertedPointers_, ArrayRef<DIFFE_TYPE> argDiffeTypes) {

OpBuilder initializationBuilder(
&*(this->newFunc.getFunctionBody().begin()),
this->newFunc.getFunctionBody().begin()->begin());

/*
for (const auto &[val, diffe_type] : llvm::zip(
this->oldFunc.getFunctionBody().getArguments(), argDiffeTypes)) {
if (diffe_type != DIFFE_TYPE::OUT_DIFF) {
continue;
}
auto iface = dyn_cast<AutoDiffTypeInterface>(val.getType());
if (!iface) {
llvm_unreachable(
"Type does not have an associated AutoDiffTypeInterface");
}
Value zero = iface.createNullValue(initializationBuilder, val.getLoc());
mapInvertPointer(val, zero, initializationBuilder);
}
for (auto const &x : invertedPointers_.getValueMap()) {
if (auto iface = dyn_cast<AutoDiffTypeInterface>(x.first.getType())) {
if (!iface.isMutable()) {
mapShadowValue(x.first, x.second,
initializationBuilder); // This may create an unnecessary
// ShadowedGradient which could
// be avoidable TODO
} else {
mapInvertPointer(x.first, x.second, initializationBuilder);
}
} else {
llvm_unreachable("TODO not implemented");
}
}
*/
}

void MGradientUtilsReverse::createReverseModeBlocks(Region &oldFunc,
Region &newFunc) {
for (auto it = oldFunc.getBlocks().rbegin(); it != oldFunc.getBlocks().rend();
Expand Down
7 changes: 0 additions & 7 deletions enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {

void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient,
OpBuilder &builder);
void mapInvertPointer(mlir::Value v, mlir::Value invertValue,
OpBuilder &builder);

Type getIndexType();
Value insertInit(Type t);
Expand All @@ -58,11 +56,6 @@ class MGradientUtilsReverse : public MDiffeGradientUtils {
Type getIndexCacheType();
Value initAndPushCache(Value v, OpBuilder &builder);

void initInitializationBlock(IRMapping invertedPointers_,
ArrayRef<DIFFE_TYPE> argDiffeTypes);

bool onlyUsedInParentBlock(Value v);

Operation *cloneWithNewOperands(OpBuilder &B, Operation *op);

Value popCache(Value cache, OpBuilder &builder);
Expand Down

0 comments on commit 01d8acf

Please sign in to comment.