Skip to content

Commit

Permalink
MLIR improve error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 16, 2024
1 parent bf02614 commit 4ccab29
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
7 changes: 6 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
unnecessaryInstructions, gutils, TLI);
*/

bool valid = true;
for (Block &oBB : gutils->oldFunc.getFunctionBody().getBlocks()) {
// Don't create derivatives for code that results in termination
if (guaranteedUnreachable.find(&oBB) != guaranteedUnreachable.end()) {
Expand All @@ -205,7 +206,8 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
auto last = oBB.empty() ? oBB.end() : std::prev(oBB.end());
for (auto it = first; it != last; ++it) {
// TODO: propagate errors.
(void)gutils->visitChild(&*it);
auto res = gutils->visitChild(&*it);
valid &= res.succeeded();
}

createTerminator(gutils, &oBB, retType, returnValue);
Expand All @@ -232,6 +234,9 @@ FunctionOpInterface mlir::enzyme::MEnzymeLogic::CreateForwardDiff(
auto nf = gutils->newFunc;
delete gutils;

if (!valid)
return nullptr;

// if (PostOpt)
// PPC.optimizeIntermediate(nf);
// if (EnzymePrint) {
Expand Down
3 changes: 2 additions & 1 deletion enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -318,5 +318,6 @@ LogicalResult MGradientUtils::visitChild(Operation *op) {
return iface.createForwardModeTangent(builder, this);
}
}
return op->emitError() << "could not compute the adjoint for this operation";
return op->emitError() << "could not compute the adjoint for this operation "
<< *op;
}
24 changes: 19 additions & 5 deletions enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
void runOnOperation() override;

template <typename T>
void HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) {
LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) {
std::vector<DIFFE_TYPE> constants;
SmallVector<mlir::Value, 2> args;

Expand Down Expand Up @@ -83,16 +83,20 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr);
if (!newFunc)
return failure();

OpBuilder builder(CI);
auto dCI = builder.create<func::CallOp>(CI.getLoc(), newFunc.getName(),
newFunc.getResultTypes(), args);
CI.replaceAllUsesWith(dCI);
CI->erase();
return success();
}

template <typename T>
void HandleAutoDiffReverse(SymbolTableCollection &symbolTable, T CI) {
LogicalResult HandleAutoDiffReverse(SymbolTableCollection &symbolTable,
T CI) {
std::vector<DIFFE_TYPE> constants;
SmallVector<mlir::Value, 2> args;

Expand Down Expand Up @@ -144,12 +148,15 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
/*should return*/ false, mode, freeMemory, width,
/*addedType*/ nullptr, type_args, volatile_args,
/*augmented*/ nullptr, symbolTable);
if (!newFunc)
return failure();

OpBuilder builder(CI);
auto dCI = builder.create<func::CallOp>(CI.getLoc(), newFunc.getName(),
newFunc.getResultTypes(), args);
CI.replaceAllUsesWith(dCI);
CI->erase();
return success();
}

void lowerEnzymeCalls(SymbolTableCollection &symbolTable,
Expand All @@ -167,7 +174,11 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {

for (auto T : toLower) {
if (auto F = dyn_cast<enzyme::ForwardDiffOp>(T)) {
HandleAutoDiff(symbolTable, F);
auto res = HandleAutoDiff(symbolTable, F);
if (!res.succeeded()) {
signalPassFailure();
return;
}
} else {
llvm_unreachable("Illegal type");
}
Expand All @@ -187,7 +198,11 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {

for (auto T : toLower) {
if (auto F = dyn_cast<enzyme::AutoDiffOp>(T)) {
HandleAutoDiffReverse(symbolTable, F);
auto res = HandleAutoDiffReverse(symbolTable, F);
if (!res.succeeded()) {
signalPassFailure();
return;
}
} else {
llvm_unreachable("Illegal type");
}
Expand All @@ -201,7 +216,6 @@ struct DifferentiatePass : public DifferentiatePassBase<DifferentiatePass> {
namespace mlir {
namespace enzyme {
std::unique_ptr<Pass> createDifferentiatePass() {
new DifferentiatePass();
return std::make_unique<DifferentiatePass>();
}
} // namespace enzyme
Expand Down

0 comments on commit 4ccab29

Please sign in to comment.