Skip to content

Commit

Permalink
Fix MLIR memory bug
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Jan 22, 2024
1 parent 459070a commit 2b880c2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/enzyme-mlir.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- uses: actions/checkout@v3
with:
repository: 'llvm/llvm-project'
ref: '5ed11e767c0c39a3bc8e035588e7a383849d46a8'
ref: 'bc82cfb38d83f1afeb2c290aa472c2e2e88919cb'
path: 'llvm-project'

- name: Get MLIR commit hash
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ struct GenericOpInterfaceReverse
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
auto linalgOp = cast<linalg::LinalgOp>(op);
assert(linalgOp.hasBufferSemantics() &&
assert(linalgOp.hasPureBufferSemantics() &&
"Linalg op with tensor semantics not yet supported");

linalg::LinalgOp newOp =
Expand Down Expand Up @@ -278,4 +278,4 @@ void mlir::enzyme::registerLinalgDialectAutoDiffInterface(
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
>(context);
});
}
}
28 changes: 16 additions & 12 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ void MEnzymeLogic::handlePredecessors(
} else {
SmallVector<Block *> blocks;
SmallVector<APInt> indices;
SmallVector<ValueRange> arguments;
SmallVector<SmallVector<Value>> arguments;
SmallVector<Value> defaultArguments;
Block *defaultBlock;
int i = 1;
for (Block *predecessor : oBB->getPredecessors()) {
Block *defaultBlock = nullptr;
for (auto pair : llvm::enumerate(oBB->getPredecessors())) {
auto predecessor = pair.value();
auto idx = pair.index();
Block *predecessorRevMode =
gutils->mapReverseModeBlocks.lookupOrNull(predecessor);

Expand All @@ -250,10 +251,10 @@ void MEnzymeLogic::handlePredecessors(
}
}
}
if (predecessor != *(oBB->getPredecessors().begin())) {
if (idx != 0) {
blocks.push_back(predecessorRevMode);
indices.push_back(APInt(32, i++));
arguments.push_back(operands);
indices.push_back(APInt(32, idx - 1));
arguments.emplace_back(std::move(operands));
} else {
defaultBlock = predecessorRevMode;
defaultArguments = operands;
Expand All @@ -275,15 +276,19 @@ void MEnzymeLogic::handlePredecessors(
oBB->getPredecessors().end()) {
// If there is only one block we can directly create a branch for
// simplicity sake
revBuilder.create<cf::BranchOp>(loc, defaultBlock, defaultArguments);
auto bop =
revBuilder.create<cf::BranchOp>(loc, defaultBlock, defaultArguments);
} else {
Value cache = gutils->insertInit(gutils->getIndexCacheType());
Value flag =
revBuilder.create<enzyme::PopOp>(loc, gutils->getIndexType(), cache);

revBuilder.create<cf::SwitchOp>(
SmallVector<ValueRange> argumentRanges;
for (const auto &a : arguments)
argumentRanges.emplace_back(a);
auto bop = revBuilder.create<cf::SwitchOp>(
loc, flag, defaultBlock, defaultArguments, ArrayRef<APInt>(indices),
ArrayRef<Block *>(blocks), ArrayRef<ValueRange>(arguments));
ArrayRef<Block *>(blocks), argumentRanges);

Value origin = newBB->addArgument(gutils->getIndexType(), loc);

Expand Down Expand Up @@ -356,7 +361,6 @@ void MEnzymeLogic::differentiate(
Block *oBB = *it;
Block *newBB = gutils->getNewFromOriginal(oBB);
Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB);

mapInvertArguments(oBB, reverseBB, gutils);
handleReturns(oBB, newBB, reverseBB, gutils, parentRegion);
visitChildren(oBB, reverseBB, gutils);
Expand Down Expand Up @@ -401,4 +405,4 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff(

delete gutils;
return nf;
}
}

0 comments on commit 2b880c2

Please sign in to comment.