diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index 16a11db1529c..7d74fdafbdbf 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -111,13 +111,15 @@ def Gradient : Enzyme_Type<"Gradient"> { def SetOp : Enzyme_Op<"set"> { let summary = "Store the current value of the gradient"; - let arguments = (ins AnyType : $gradient, AnyType : $value); + let arguments = (ins Arg:$gradient, AnyType : $value); let results = (outs ); } def GetOp : Enzyme_Op<"get"> { let summary = "Load current value of gradient"; - let arguments = (ins AnyType : $gradient); + let arguments = (ins Arg:$gradient); let results = (outs AnyType); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index 775cc4647146..ade1be7e6406 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -201,9 +201,9 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler( return success(); } - -void mlir::enzyme::detail::returnReverseHandler(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) { +void mlir::enzyme::detail::returnReverseHandler(Operation *op, + OpBuilder &builder, + MGradientUtilsReverse *gutils) { size_t num_out = 0; for (auto act : gutils->RetDiffeTypes) { if (act == DIFFE_TYPE::OUT_DIFF) @@ -216,10 +216,10 @@ void mlir::enzyme::detail::returnReverseHandler(Operation *op, OpBuilder &builde for (auto &&[op, act] : llvm::zip(op->getOperands(), gutils->RetDiffeTypes)) { if (act == DIFFE_TYPE::OUT_DIFF) { if (!gutils->isConstantValue(op)) { - auto d_out = args[args.size() - num_out + idx]; - gutils->addToDiffe(op, d_out, builder); - } - idx++; + auto d_out = args[args.size() - num_out + idx]; + gutils->addToDiffe(op, d_out, builder); + } + idx++; } } } diff 
--git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 74c16d44d2f1..cbad734656b1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -113,7 +113,6 @@ class NoopRevAutoDiffInterface : public ReverseAutoDiffOpInterface::ExternalModel< NoopRevAutoDiffInterface, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector caches) const {} @@ -124,8 +123,7 @@ class NoopRevAutoDiffInterface } void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const { - } + MGradientUtilsReverse *gutils) const {} }; template @@ -133,7 +131,6 @@ class ReturnRevAutoDiffInterface : public ReverseAutoDiffOpInterface::ExternalModel< ReturnRevAutoDiffInterface, OpTy> { public: - void createReverseModeAdjoint(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils, SmallVector caches) const { @@ -146,8 +143,7 @@ class ReturnRevAutoDiffInterface } void createShadowValues(Operation *op, OpBuilder &builder, - MGradientUtilsReverse *gutils) const { - } + MGradientUtilsReverse *gutils) const {} }; // Implements the forward autodiff interface for operations which are @@ -212,7 +208,8 @@ void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) { template void registerAutoDiffUsingBranchInterface(MLIRContext &context) { OpTy::template attachInterface>(context); - OpTy::template attachInterface>(context); + OpTy::template attachInterface>( + context); } // Registers AutoDiffUsingRegionTerminator for the given op. 
template diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 385b16fc2e04..db9f4b08d6e1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -165,7 +165,8 @@ struct GenericOpInterfaceReverse StringAttr()); int numInputs = inputs.size(); - auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder, Block* oBB) { + auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder, + Block *oBB) { auto loc = oBB->rbegin()->getLoc(); SmallVector retargs; for (auto arg : oBB->getArguments()) { @@ -196,8 +197,8 @@ struct GenericOpInterfaceReverse return std::make_pair(pushCache, popCache); }; - gutils->Logic.differentiate( - gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(), buildFuncReturnOp, hook); + gutils->Logic.differentiate(gutils, *linalgOp.getBlock()->getParent(), + adjoint.getRegion(), buildFuncReturnOp, hook); auto newOpYield = cast( cast(newOp).getBodyRegion().front().getTerminator()); diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 9451ce88b415..cd21c5c548b9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -119,17 +119,18 @@ struct StoreOpInterfaceReverse if (!iface.isMutable()) { if (!gutils->isConstantValue(val)) { - Value loadedGradient = - builder.create(storeOp.getLoc(), memrefGradient, - ArrayRef(retrievedArguments)); + Value loadedGradient = builder.create( + storeOp.getLoc(), memrefGradient, + ArrayRef(retrievedArguments)); gutils->addToDiffe(val, loadedGradient, builder); } - auto zero = cast(gutils->getShadowType(val.getType())).createNullValue(builder, op->getLoc()); + auto zero = + 
cast(gutils->getShadowType(val.getType())) + .createNullValue(builder, op->getLoc()); builder.create(storeOp.getLoc(), zero, memrefGradient, - ArrayRef(retrievedArguments)); - + ArrayRef(retrievedArguments)); } } } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index 8c39fc25bf16..72fbe2106a53 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -40,102 +40,114 @@ struct ForOpInterfaceReverse SmallVector caches) const { auto forOp = cast(op); - // Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 SmallVector resDiffes; for (OpResult v : forOp.getResults()) { - if (!gutils->isConstantValue(v)) { + if (!gutils->isConstantValue(v)) { auto autoDiffType = cast(v.getType()); - if (autoDiffType.isMutable()) { + if (!autoDiffType.isMutable()) { auto prev = gutils->diffe(v, builder); gutils->zeroDiffe(v, builder); resDiffes.push_back(prev); continue; } - } - resDiffes.push_back(nullptr); + } + resDiffes.push_back(nullptr); } + for (auto ® : op->getRegions()) { + auto termIface = + cast(reg.begin()->getTerminator()); - for (auto ® : op->getRegions()) { - auto termIface = cast( - reg.begin()->getTerminator()); - - - SmallVector successors; - termIface.getSuccessorRegions( - SmallVector(termIface->getNumOperands(), Attribute()), - successors); + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); for (auto &successor : successors) { - if (!successor.isParent()) continue; - OperandRange operandRange = - termIface.getSuccessorOperands(successor); + if (!successor.isParent()) + continue; + OperandRange operandRange = termIface.getSuccessorOperands(successor); assert(operandRange.size() == resDiffes.size()); - // There is an assumption here that there is only regions that branch to the 
successor. - // Specifically, otherwise we would need to gutils->addToDiffe select (if came from that result) + + // There is an assumption here that there is only regions that branch to + // the successor. Specifically, otherwise we would need to + // gutils->addToDiffe select (if came from that result) for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) { - if (!post) continue; - if (!gutils->isConstantValue(prev)) - gutils->addToDiffe(prev, post, builder); + if (!post) + continue; + if (!gutils->isConstantValue(prev)) + gutils->addToDiffe(prev, post, builder); } } - } + } // End Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 auto start = gutils->popCache(caches[0], builder); auto end = gutils->popCache(caches[1], builder); auto step = gutils->popCache(caches[2], builder); - auto repFor = builder.create( - forOp.getLoc(), start, end, step, ArrayRef()); + auto repFor = builder.create(forOp.getLoc(), start, end, step, + ArrayRef()); // erase scf yield repFor.getBody()->begin()->erase(); - for (auto &&[oldReg, newReg] : llvm::zip(op->getRegions(), repFor->getRegions())) { + for (auto &&[oldReg, newReg] : + llvm::zip(op->getRegions(), repFor->getRegions())) { - // This code assumes at most one terminating block for each region (lest the append happen multiple times) - auto buildFuncReturnOp = [&](OpBuilder &builder, Block* oBB) { - - auto loc = oBB->rbegin()->getLoc(); + // This code assumes at most one terminating block for each region (lest + // the append happen multiple times) + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { + auto loc = oBB->rbegin()->getLoc(); - auto idx = repFor.getInductionVar(); + auto idx = repFor.getInductionVar(); - auto lhs = builder.create(loc, idx, step); + auto lhs = builder.create(loc, idx, step); - // This needs to know a condition describing which predecessor this will return to, to select the right value - // Here we use the condition i + step >= end to determine the last iteration + // 
This needs to know a condition describing which predecessor this will + // return to, to select the right value Here we use the condition i + + // step >= end to determine the last iteration - auto condition = builder.create(loc, arith::CmpIPredicate::sge, lhs, end); + auto condition = builder.create( + loc, arith::CmpIPredicate::sge, lhs, end); - for (auto [arg, init_arg] : llvm::zip(oBB->getArguments(), forOp.getInitArgs())) { - if (!gutils->isConstantValue(arg) && !cast(arg.getType()).isMutable()) { - auto diffe = gutils->diffe(arg, builder); - gutils->zeroDiffe(arg, builder); + for (auto [arg, init_arg] : + llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + auto diffe = gutils->diffe(arg, builder); + gutils->zeroDiffe(arg, builder); - auto zero = cast(diffe.getType()).createNullValue(builder, loc); - auto outside = builder.create(loc, condition, diffe, zero); - auto inside = builder.create(loc, condition, zero, diffe); + auto zero = cast(diffe.getType()) + .createNullValue(builder, loc); + auto outside = + builder.create(loc, condition, diffe, zero); + auto inside = + builder.create(loc, condition, zero, diffe); - // For each predecessor, if we came from that predecessor += the shadow of the arg [after zero'ing] - if (!gutils->isConstantValue(init_arg)) { - gutils->addToDiffe(init_arg, outside, builder); - } + // For each predecessor, if we came from that predecessor += the + // shadow of the arg [after zero'ing] + if (!gutils->isConstantValue(init_arg)) { + gutils->addToDiffe(init_arg, outside, builder); + } - if (!gutils->isConstantValue(arg)) { - gutils->addToDiffe(arg, inside, builder); - } - } + if (!gutils->isConstantValue(arg)) { + gutils->addToDiffe(arg, inside, builder); } - builder.create(loc); - }; - - for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { - gutils->mapReverseModeBlocks.map(&oBB, &revBB); - gutils->Logic.visitChildren(&oBB, &revBB, gutils); - 
Block *newBB = gutils->getNewFromOriginal(&oBB); - gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, buildFuncReturnOp); + } } + builder.create(loc); + }; + + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->mapReverseModeBlocks.map(&oBB, &revBB); + } + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->Logic.visitChildren(&oBB, &revBB, gutils); + Block *newBB = gutils->getNewFromOriginal(&oBB); + gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, + buildFuncReturnOp); + } } } diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index 82db673103cd..dc98a5c4c6a6 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -124,9 +124,10 @@ class MEnzymeLogic { void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); - void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, - llvm::function_ref buildReturnOp); + void + handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, + MGradientUtilsReverse *gutils, + llvm::function_ref buildReturnOp); void visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); void visitChild(Operation *op, OpBuilder &builder, diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 06ccd344d35f..f39766a0a639 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -21,7 +21,7 @@ using namespace mlir; using namespace mlir::enzyme; void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils) { + MGradientUtilsReverse *gutils) { if (oBB->getNumSuccessors() == 0) { Operation *returnStatement = newBB->getTerminator(); gutils->erase(returnStatement); @@ -110,7 +110,8 @@ Create reverse mode adjoint for an 
operation. */ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) { - if ((op->getBlock()->getTerminator() != op) && llvm::all_of(op->getResults(), + if ((op->getBlock()->getTerminator() != op) && + llvm::all_of(op->getResults(), [gutils](Value v) { return gutils->isConstantValue(v); }) && gutils->isConstantInstruction(op)) { return; @@ -156,73 +157,80 @@ void MEnzymeLogic::handlePredecessors( Location loc = oBB->rbegin()->getLoc(); // TODO remove dependency on CF dialect - Value cache = gutils->insertInit(gutils->getIndexCacheType()); + Value cache = gutils->insertInit(gutils->getIndexCacheType()); - Value flag = - revBuilder.create(loc, gutils->getIndexType(), cache); + Value flag = + revBuilder.create(loc, gutils->getIndexType(), cache); - Block* defaultBlock = nullptr; + Block *defaultBlock = nullptr; - SmallVector blocks; - SmallVector indices; + SmallVector blocks; + SmallVector indices; - OpBuilder newBuilder(newBB, newBB->begin()); + OpBuilder newBuilder(newBB, newBB->begin()); - SmallVector diffes; - for (auto arg : oBB->getArguments()) { - if (!gutils->isConstantValue(arg) && !cast(arg.getType()).isMutable()) { - diffes.push_back(gutils->diffe(arg, revBuilder)); - gutils->zeroDiffe(arg, revBuilder); - continue; - } - diffes.push_back(nullptr); + SmallVector diffes; + for (auto arg : oBB->getArguments()) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + diffes.push_back(gutils->diffe(arg, revBuilder)); + gutils->zeroDiffe(arg, revBuilder); + continue; } + diffes.push_back(nullptr); + } + for (auto [idx, pred] : llvm::enumerate(oBB->getPredecessors())) { + auto reversePred = gutils->mapReverseModeBlocks.lookupOrNull(pred); - for (auto [idx, pred] : llvm::enumerate(oBB->getPredecessors())) { - auto reversePred = gutils->mapReverseModeBlocks.lookupOrNull(pred); - - Block *newPred = gutils->getNewFromOriginal(pred); + Block *newPred = gutils->getNewFromOriginal(pred); - OpBuilder 
predecessorBuilder(newPred->getTerminator()); + OpBuilder predecessorBuilder(newPred->getTerminator()); - Value pred_idx_c = - predecessorBuilder.create(loc, idx - 1, 32); - predecessorBuilder.create(loc, cache, pred_idx_c); + Value pred_idx_c = + predecessorBuilder.create(loc, idx - 1, 32); + predecessorBuilder.create(loc, cache, pred_idx_c); - if (idx == 0) { - defaultBlock = reversePred; + if (idx == 0) { + defaultBlock = reversePred; - } else { - indices.push_back(APInt(32, idx - 1)); - blocks.push_back(reversePred); - } + } else { + indices.push_back(APInt(32, idx - 1)); + blocks.push_back(reversePred); + } - auto term = pred->getTerminator(); - if (auto iface = dyn_cast(term)) { - for (auto &op : term->getOpOperands()) - if (auto blk_idx = iface.getSuccessorBlockArgument(op.getOperandNumber())) - if ((*blk_idx).getOwner() == oBB) { + auto term = pred->getTerminator(); + if (auto iface = dyn_cast(term)) { + for (auto &op : term->getOpOperands()) + if (auto blk_idx = + iface.getSuccessorBlockArgument(op.getOperandNumber())) + if ((*blk_idx).getOwner() == oBB) { auto idx = (*blk_idx).getArgNumber(); if (diffes[idx]) { - - Value rev_idx_c = revBuilder.create(loc, idx - 1, 32); - auto to_prop = revBuilder.create(loc, revBuilder.create(loc, arith::CmpIPredicate::eq, flag, rev_idx_c), - diffes[idx], cast(diffes[idx].getType()).createNullValue(revBuilder, loc)); + Value rev_idx_c = + revBuilder.create(loc, idx - 1, 32); + + auto to_prop = revBuilder.create( + loc, + revBuilder.create( + loc, arith::CmpIPredicate::eq, flag, rev_idx_c), + diffes[idx], + cast(diffes[idx].getType()) + .createNullValue(revBuilder, loc)); gutils->addToDiffe(op.get(), to_prop, revBuilder); } } - } else { - assert(0 && "predecessor did not implement branch op interface"); - } + } else { + assert(0 && "predecessor did not implement branch op interface"); } + } - revBuilder.create( - loc, flag, defaultBlock, ArrayRef(), ArrayRef(indices), - ArrayRef(blocks), SmallVector(indices.size(), 
ValueRange())); - + revBuilder.create( + loc, flag, defaultBlock, ArrayRef(), ArrayRef(indices), + ArrayRef(blocks), + SmallVector(indices.size(), ValueRange())); } } @@ -265,7 +273,7 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( Region &oldRegion = gutils->oldFunc.getFunctionBody(); Region &newRegion = gutils->newFunc.getFunctionBody(); - auto buildFuncReturnOp = [&](OpBuilder &builder, Block* oBB) { + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { SmallVector retargs; for (auto [arg, cv] : llvm::zip(oBB->getArguments(), constants)) { if (cv == DIFFE_TYPE::OUT_DIFF) { diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 3fd241c31437..ebae44c9efa1 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -41,7 +41,8 @@ mlir::enzyme::MGradientUtils::MGradientUtils( originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), RetDiffeTypes(1, ReturnActivity) { + TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), + RetDiffeTypes(1, ReturnActivity) { /* for (BasicBlock &BB : *oldFunc) { @@ -170,7 +171,8 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, llvm_unreachable("could not invert pointer"); } -mlir::Value mlir::enzyme::MDiffeGradientUtils::getDifferential(mlir::Value oval) { +mlir::Value +mlir::enzyme::MDiffeGradientUtils::getDifferential(mlir::Value oval) { auto found = differentials.lookupOrNull(oval); if (found != nullptr) return found; @@ -179,17 +181,19 @@ mlir::Value mlir::enzyme::MDiffeGradientUtils::getDifferential(mlir::Value oval) OpBuilder builder(oval.getContext()); builder.setInsertionPointToStart(initializationBlock); - auto shadow = builder.create(oval.getLoc(), 
enzyme::GradientType::get(oval.getContext(), shadowty)); - auto toset = cast(shadowty).createNullValue(builder, oval.getLoc()); + auto shadow = builder.create( + oval.getLoc(), enzyme::GradientType::get(oval.getContext(), shadowty)); + auto toset = cast(shadowty).createNullValue( + builder, oval.getLoc()); builder.create(oval.getLoc(), shadow, toset); differentials.map(oval, shadow); return shadow; } - -void mlir::enzyme::MDiffeGradientUtils::setDiffe(mlir::Value oval, mlir::Value toset, - OpBuilder &BuilderM) { +void mlir::enzyme::MDiffeGradientUtils::setDiffe(mlir::Value oval, + mlir::Value toset, + OpBuilder &BuilderM) { assert(!isConstantValue(oval)); auto iface = oval.getType().cast(); if (!iface.isMutable()) { @@ -200,18 +204,20 @@ void mlir::enzyme::MDiffeGradientUtils::setDiffe(mlir::Value oval, mlir::Value t } } -void mlir::enzyme::MDiffeGradientUtils::zeroDiffe(mlir::Value oval, OpBuilder &BuilderM) { +void mlir::enzyme::MDiffeGradientUtils::zeroDiffe(mlir::Value oval, + OpBuilder &BuilderM) { assert(!isConstantValue(oval)); auto iface = getShadowType(oval.getType()).cast(); assert(!iface.isMutable()); setDiffe(oval, iface.createNullValue(BuilderM, oval.getLoc()), BuilderM); } - -mlir::Value mlir::enzyme::MDiffeGradientUtils::diffe(mlir::Value oval, OpBuilder &BuilderM) { +mlir::Value mlir::enzyme::MDiffeGradientUtils::diffe(mlir::Value oval, + OpBuilder &BuilderM) { auto shadow = getDifferential(oval); - return BuilderM.create(oval.getLoc(), getShadowType(oval.getType()), shadow); + return BuilderM.create(oval.getLoc(), + getShadowType(oval.getType()), shadow); } void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, @@ -271,13 +277,15 @@ void mlir::enzyme::MGradientUtils::forceAugmentedReturns() { continue; auto i = val.getArgNumber(); if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || cast(val.getType()).isMutable()) { + mode == DerivativeMode::ForwardModeSplit || + 
cast(val.getType()).isMutable()) { mlir::Value dval; if (i == blk->getArguments().size() - 1) dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc()); else - dval = nblk->insertArgument(nblk->args_begin() + i + 1, - getShadowType(val.getType()), val.getLoc()); + dval = + nblk->insertArgument(nblk->args_begin() + i + 1, + getShadowType(val.getType()), val.getLoc()); invertedPointers.map(val, dval); } @@ -290,15 +298,16 @@ void mlir::enzyme::MGradientUtils::forceAugmentedReturns() { OpBuilder BuilderZ(getNewFromOriginal(inst)); for (auto res : inst->getResults()) { - if (isConstantValue(res)) continue; + if (isConstantValue(res)) + continue; if (!(mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit || cast(res.getType()).isMutable())) - continue; - mlir::Type antiTy = getShadowType(res.getType()); - auto anti = - BuilderZ.create(res.getLoc(), antiTy); - invertedPointers.map(res, anti); + mode == DerivativeMode::ForwardModeSplit || + cast(res.getType()).isMutable())) + continue; + mlir::Type antiTy = getShadowType(res.getType()); + auto anti = BuilderZ.create(res.getLoc(), antiTy); + invertedPointers.map(res, anti); } }); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index ef01f0542046..32a7fe068d07 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -88,14 +88,14 @@ class MDiffeGradientUtils : public MGradientUtils { IRMapping differentials; Block *initializationBlock; + public: - mlir::Value getDifferential(mlir::Value origv); void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder); void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder); - + mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder); MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, @@ -111,8 +111,8 @@ class MDiffeGradientUtils : public MGradientUtils { : MGradientUtils(Logic, 
newFunc_, oldFunc_, TA, TR, invertedPointers_, constantvalues_, activevals_, ActiveReturn, constant_values, origToNew_, origToNewOps_, mode, width, - omp), - initializationBlock(&*(newFunc.getFunctionBody().begin())) {} + omp), + initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically diffe constructor static MDiffeGradientUtils * diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 0e8c2752f101..20d34247fb3a 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -40,8 +40,7 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( invertedPointers_, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_, mode_, width, /*omp*/ false), - symbolTable(symbolTable_) { -} + symbolTable(symbolTable_) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -117,7 +116,7 @@ void mlir::enzyme::MGradientUtilsReverse::addToDiffe(Value oldGradient, Value operandGradient = diffe(oldGradient, builder); auto iface = cast(addedGradient.getType()); auto added = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, - addedGradient); + addedGradient); setDiffe(oldGradient, added, builder); } diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index 963ef28d79a1..de7328dcb186 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -130,7 +130,7 @@ struct DifferentiatePass : public DifferentiatePassBase { auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - + auto mode = DerivativeMode::ReverseModeCombined; DIFFE_TYPE retType = mode_from_fn(fn, mode); @@ -138,8 +138,6 @@ struct DifferentiatePass : public DifferentiatePassBase { mlir::Value res = 
CI.getInputs()[CI.getInputs().size() - 1]; args.push_back(res); - - MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); bool freeMemory = true; diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index dff94c63f16c..5ced93e86512 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -22,6 +22,7 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/IR/Dominance.h" #include "llvm/Support/raw_ostream.h" @@ -30,10 +31,9 @@ using namespace mlir; using namespace enzyme; namespace { - // Starting at the beginning of blk, is there a path that can execute -// check before end. -bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) { +// check before end. +bool mayExecuteBefore(Block *blk, Operation *check, Operation *end) { auto reg = blk->getParent(); assert(reg->isAncestor(end->getParentRegion())); @@ -57,8 +57,8 @@ bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) { // If we've seen the thing to check with, it may execute before if (op.isAncestor(check)) { - // The sole exception to this is if they are in the same sub region, which is - // known to execute only once. TODO this later + // The sole exception to this is if they are in the same sub region, + // which is known to execute only once. TODO this later /* if (op.isAncestor(end)) { @@ -71,15 +71,16 @@ bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) { return true; } - // Otherwise if we've seen the end op, this path is over as the route we found here - // didn't first find a check. + // Otherwise if we've seen the end op, this path is over as the route we + // found here didn't first find a check. 
if (op.isAncestor(end)) { seenEnd = true; break; } } - if (seenEnd) continue; + if (seenEnd) + continue; // If we didn't find the end, try all successors for (auto succ : cur->getSuccessors()) { @@ -90,10 +91,10 @@ bool mayExecuteBefore(Block* blk, Operation* check, Operation *end) { return false; } -bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) { +bool mayExecuteBetween(Operation *start, Operation *check, Operation *end) { - for (auto op = start->getNextNode(); op != nullptr; op++) { - // This check op has been found after start in its block + for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) { + // This check op has been found after start in its block if (op->isAncestor(check)) { return true; } @@ -102,7 +103,7 @@ bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) { } } - Block* blk = start->getBlock(); + Block *blk = start->getBlock(); auto reg = blk->getParent(); if (reg->isAncestor(end->getParentRegion())) { @@ -118,45 +119,46 @@ bool mayExecuteBetween(Operation *start, Operation* check, Operation *end) { return mayExecuteBetween(start->getParentOp(), check, end); } -// TODO this isn't necessarily correct. This is because there could be a -// non dominating use bewteen the dominating one and the op, causing +// TODO this isn't necessarily correct. This is because there could be a +// non dominating use bewteen the dominating one and the op, causing // correctness issues when not seen. 
In interim, be conservative and only // succeed if these have the same parent block, and no other ops in path -template +template T findNearestDominatingOpByUse(Operation *op, Value v) { DominanceInfo dInfo; PostDominanceInfo pdInfo; SmallVector options; + SmallVector conflicts; for (Operation *userSet : v.getUsers()) { if (auto setOp = dyn_cast(userSet)) { options.push_back(setOp); + conflicts.push_back(setOp); + continue; + } + if (auto setOp = dyn_cast(userSet)) { + conflicts.push_back(setOp); + continue; } } - if (options.size() == 1 && dInfo.dominates(options[0], op)) - return options[0]; - llvm::errs() << " scope: " << *op->getParentOp() << "\n"; - llvm::errs() << " want to replace " << *op << "\n"; for (auto opt : options) { if (!dInfo.dominates(opt, op)) continue; bool conflict = false; - llvm::errs() << " trying: " << *opt << "\n"; - for (auto opt2 : options) { - if (opt == opt2) continue; - - llvm::errs() << " conflict check: " << *opt2 << "\n"; + for (auto opt2 : conflicts) { + if (opt == opt2) + continue; + if (opt2 == op) + continue; if (!mayExecuteBetween(opt, opt2, op)) { - llvm::errs() << " + known good since occurs before store\n"; continue; } conflict = true; } if (!conflict) { - llvm::errs() << " - replaced with " << *opt << "\n"; return opt; } } @@ -164,78 +166,137 @@ T findNearestDominatingOpByUse(Operation *op, Value v) { return nullptr; } +struct PopSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PopOp pop, + PatternRewriter &rewriter) const final { + + auto init = pop.getCache().getDefiningOp(); + if (!init) + return failure(); + + SmallVector pops; + SmallVector pushes; + for (Operation *userSet : init.getResult().getUsers()) { + if (auto push = dyn_cast(userSet)) { + pushes.push_back(push); + continue; + } + if (auto pop = dyn_cast(userSet)) { + pops.push_back(pop); + continue; + } + return failure(); + } + + if (auto push = findNearestDominatingOpByUse( + pop, 
init)) { + // Do the block check to conservatively avoid multi execute push/pop + if (pop->getBlock() == push->getBlock()) { + rewriter.replaceOp(pop, push.getValue()); + rewriter.eraseOp(push); + return success(); + } + } + + return failure(); + } +}; + +struct GetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::GetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + if (isa(userSet)) + continue; + return failure(); + } + + if (auto set = findNearestDominatingOpByUse(get, init)) { + rewriter.replaceOp(get, set.getValue()); + return success(); + } + return failure(); + } +}; + +struct SetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::SetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct PushSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PushOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getCache().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct InitSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::InitOp get, + PatternRewriter &rewriter) const final { + + if (get.use_empty()) { + rewriter.eraseOp(get); + return success(); + } + 
return failure(); + } +}; + struct RemoveUnusedEnzymeOpsPass : public enzyme::RemoveUnusedEnzymeOpsPassBase { void runOnOperation() override { - SmallVector inits; - getOperation()->walk([&](Operation *op) { - if (auto initOp = dyn_cast(op)) { - inits.push_back(initOp); - } - }); - - for (auto initOp : inits) { - DominanceInfo dInfo; - Value v = initOp; - if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userSet : v.getUsers()) { - if (isa(userSet)) continue; - if (isa(userSet)) continue; - llvm::errs() << " unknown user of grad: " << *userSet << "\n"; - replaceable = false; - } - if (replaceable) { - // Do replacing - bool allDelete = true; - for (Operation *userGet : make_early_inc_range(v.getUsers())) { - if (auto getOp = dyn_cast(userGet)) { - if (auto setOp = - findNearestDominatingOpByUse(userGet, v)) { - getOp.replaceAllUsesWith(setOp.getValue()); - getOp->erase(); - continue; - } - allDelete = false; - } - } - if (allDelete) { - for (Operation *userGet : make_early_inc_range(v.getUsers())) { - userGet->erase(); - } - initOp->erase(); - } - continue; - } - } else if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - - SmallVector pops; - for (Operation *userSet : v.getUsers()) { - if (isa(userSet)) continue; - if (auto pop = dyn_cast(userSet)) { - pops.push_back(pop); - continue; - } - llvm::errs() << " unknown user of cache: " << *userSet << "\n"; - replaceable = false; - } + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); - if (replaceable) - for (auto pop : pops) { - if (auto push = findNearestDominatingOpByUse(pop, v)) { - pop.replaceAllUsesWith(push.getValue()); - pop->erase(); - push->erase(); - } - } - if (v.use_empty()) { - initOp->erase(); - } - continue; - } - } + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); } }; diff --git a/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp 
b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp index 5b93e5ae8242..de04163fea8f 100644 --- a/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp +++ b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp @@ -24,7 +24,7 @@ using namespace enzyme; using llvm::errs; namespace { - struct AddSimplify : public OpRewritePattern { +struct AddSimplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(arith::AddFOp op, @@ -44,7 +44,7 @@ namespace { } }; - struct SubSimplify : public OpRewritePattern { +struct SubSimplify : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(arith::SubFOp op, diff --git a/enzyme/test/MLIR/Passes/dualpush.mlir b/enzyme/test/MLIR/Passes/dualpush.mlir new file mode 100644 index 000000000000..582feddeabff --- /dev/null +++ b/enzyme/test/MLIR/Passes/dualpush.mlir @@ -0,0 +1,48 @@ +// RUN: %eopt -remove-unnecessary-enzyme-ops %s | FileCheck %s + +// This pop cannot be removed even though we know the first popped value will be -1 +// the other pops will be conditional + +module { + func.func private @diffebbargs(%arg0: f64) { + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant 0.000000e+00 : f64 + %3 = "enzyme.init"() : () -> !enzyme.Cache + "enzyme.push"(%3, %c0_i32) : (!enzyme.Cache, i32) -> () + cf.br ^bb1(%arg0 : f64) + ^bb1(%7: f64): // 2 preds: ^bb0, ^bb1 + %8 = arith.cmpf ult, %7, %cst : f64 + "enzyme.push"(%3, %c-1_i32) : (!enzyme.Cache, i32) -> () + cf.cond_br %8, ^bb1(%7 : f64), ^bb4 + ^bb4: // 2 preds: ^bb1, ^bb4 + %18 = "enzyme.pop"(%3) : (!enzyme.Cache) -> i32 + cf.switch %18 : i32, [ + default: ^bb4, + 0: ^bb5 + ] + ^bb5: // pred: ^bb4 + return + } +} + +// CHECK: func.func private @diffebbargs(%arg0: f64) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () ->
!enzyme.Cache +// CHECK-NEXT: "enzyme.push"(%0, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%arg0 : f64) +// CHECK-NEXT: ^bb1(%1: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %2 = arith.cmpf ult, %1, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %2, ^bb1(%1 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: %3 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: cf.switch %3 : i32, [ +// CHECK-NEXT: default: ^bb2, +// CHECK-NEXT: 0: ^bb3 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb3: // pred: ^bb2 +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir index 1dd48c3b89f7..141ff46aaade 100644 --- a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir +++ b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s module { func.func @bbargs(%x: f64) -> f64 { @@ -25,49 +25,32 @@ module { // CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f64 // CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f64 // CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache -// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Gradient -// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %2 = "enzyme.init"() : () -> !enzyme.Cache -// CHECK-NEXT: %3 = "enzyme.init"() : () -> !enzyme.Gradient -// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %4 = arith.addf %arg0, %cst : f64 -// CHECK-NEXT: "enzyme.push"(%2, %c0_i32) : (!enzyme.Cache, i32) -> () -// CHECK-NEXT: cf.br ^bb1(%4 : f64) -// CHECK-NEXT: ^bb1(%5: f64): // 2 preds: ^bb0, ^bb1 -// CHECK-NEXT: %6 = arith.cmpf ult, %5, %cst_0 : f64 -// CHECK-NEXT: 
"enzyme.push"(%2, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %2 = arith.addf %arg0, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%2 : f64) +// CHECK-NEXT: ^bb1(%3: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %4 = arith.cmpf ult, %3, %cst_0 : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c-1_i32) : (!enzyme.Cache, i32) -> () // CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () -// CHECK-NEXT: cf.cond_br %6, ^bb1(%5 : f64), ^bb2 +// CHECK-NEXT: cf.cond_br %4, ^bb1(%3 : f64), ^bb2 // CHECK-NEXT: ^bb2: // pred: ^bb1 -// CHECK-NEXT: %7 = arith.addf %arg1, %cst_0 : f64 -// CHECK-NEXT: %8 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 -// CHECK-NEXT: %9 = arith.cmpi eq, %8, %c-1_i32 : i32 -// CHECK-NEXT: %10 = arith.select %9, %7, %cst_0 : f64 -// CHECK-NEXT: %11 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: %12 = arith.addf %11, %10 : f64 -// CHECK-NEXT: "enzyme.set"(%1, %12) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %5 = arith.addf %arg1, %cst_0 : f64 +// CHECK-NEXT: %6 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %7 = arith.cmpi eq, %6, %c-1_i32 : i32 +// CHECK-NEXT: %8 = arith.select %7, %5, %cst_0 : f64 +// CHECK-NEXT: %9 = arith.addf %8, %cst_0 : f64 // CHECK-NEXT: cf.br ^bb3 // CHECK-NEXT: ^bb3: // 2 preds: ^bb2, ^bb3 -// CHECK-NEXT: %13 = "enzyme.pop"(%2) : (!enzyme.Cache) -> i32 -// CHECK-NEXT: %14 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: "enzyme.set"(%1, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %15 = arith.cmpi eq, %13, %c-1_i32 : i32 -// CHECK-NEXT: %16 = arith.select %15, %14, %cst_0 : f64 -// CHECK-NEXT: %17 = "enzyme.get"(%1) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: %18 = arith.addf %17, %16 : f64 -// CHECK-NEXT: "enzyme.set"(%1, %18) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %19 = arith.cmpi eq, %13, %c-1_i32 : i32 -// CHECK-NEXT: 
%20 = arith.select %19, %14, %cst_0 : f64 -// CHECK-NEXT: %21 = "enzyme.get"(%3) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: %22 = arith.addf %21, %20 : f64 -// CHECK-NEXT: "enzyme.set"(%3, %22) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: cf.switch %13 : i32, [ +// CHECK-NEXT: %10 = "enzyme.pop"(%1) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %11 = arith.cmpi eq, %10, %c-1_i32 : i32 +// CHECK-NEXT: %12 = arith.select %11, %9, %cst_0 : f64 +// CHECK-NEXT: %13 = arith.addf %12, %cst_0 : f64 +// CHECK-NEXT: cf.switch %10 : i32, [ // CHECK-NEXT: default: ^bb3, // CHECK-NEXT: 0: ^bb4 // CHECK-NEXT: ] // CHECK-NEXT: ^bb4: // pred: ^bb3 -// CHECK-NEXT: %23 = "enzyme.get"(%3) : (!enzyme.Gradient) -> f64 -// CHECK-NEXT: "enzyme.set"(%3, %cst_0) : (!enzyme.Gradient, f64) -> () -// CHECK-NEXT: %24 = arith.addf %23, %cst_0 : f64 -// CHECK-NEXT: return %24 : f64 +// CHECK-NEXT: %14 = arith.addf %13, %cst_0 : f64 +// CHECK-NEXT: return %14 : f64 // CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index e0a74d88145a..9934152def61 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -enzyme-simplify-math -canonicalize %s | FileCheck %s module { func.func @ppow(%x: f64) -> f64 { @@ -19,29 +19,46 @@ module { } } -// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 +// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 { +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %[[one:.+]] = arith.constant 1.0 +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: 
"enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () -// Make sure the right values are being cached in the primal -// CHECK: %[[one:.+]] = arith.constant 1.0 -// CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -// CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]]) -// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) - -// Ensure the right value is yielded in the adjoint -// CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]]) -// CHECK: %[[dr:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]]) -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]]) -// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) -// CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]]) -// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]]) -// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] -// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) : -// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) -// CHECK-NEXT: %[[dr_next:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: scf.yield %[[dr_next]] -// CHECK: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) -// CHECK-NEXT: return %[[final]] +// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = 
%[[one]]) -> (f64) { +// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[fwd]] : f64 +// CHECK-NEXT: } +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 { +// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]] +// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64 +// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) : +// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index +// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %12 : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) +// CHECK-NEXT: return %[[final]] \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/square.mlir 
b/enzyme/test/MLIR/ReverseMode/square.mlir index 37d57d426033..4bcae3bb8000 100644 --- a/enzyme/test/MLIR/ReverseMode/square.mlir +++ b/enzyme/test/MLIR/ReverseMode/square.mlir @@ -1,6 +1,6 @@ // RUN: %eopt --enzyme %s | FileCheck %s -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM -// RUN: %eopt --enzyme --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN module { func.func @square(%x: f64) -> f64 { @@ -60,14 +60,9 @@ module { // REM: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { // REM-NEXT: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 -// REM-NEXT: %[[cst_0:.+]] = arith.constant 0.000000e+00 : f64 -// REM-NEXT: %[[pmu:.+]] = arith.mulf %arg0, %arg0 : f64 -// REM-NEXT: cf.br ^bb1 -// REM-NEXT: ^bb1: // pred: ^bb0 -// REM-NEXT: %[[a1:.+]] = arith.addf %[[cst_0]], %arg1 : f64 -// REM-NEXT: %[[cst_1:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[a1:.+]] = arith.addf %arg1, %[[cst]] : f64 // REM-NEXT: %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64 -// REM-NEXT: %[[a3:.+]] = arith.addf %[[cst]], %[[a2]] : f64 +// REM-NEXT: %[[a3:.+]] = arith.addf %[[a2]], %[[cst]] : f64 // REM-NEXT: %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64 // REM-NEXT: %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64 // REM-NEXT: return %[[a5]] : f64