From b96c4439ee0be92cea74d35d3dd33ac53bf9ca5d Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 29 Feb 2024 20:49:35 -0800 Subject: [PATCH] Cleanup and Fixup MLIR reverse mode (#1771) --- enzyme/BUILD | 18 + enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td | 42 +- .../BuiltinAutoDiffTypeInterfaceImpl.cpp | 6 +- .../MLIR/Implementations/CFDerivatives.td | 1 + .../MLIR/Implementations/CMakeLists.txt | 7 + enzyme/Enzyme/MLIR/Implementations/Common.td | 5 + .../CoreDialectsAutoDiffImplementations.cpp | 28 +- .../CoreDialectsAutoDiffImplementations.h | 55 +++ .../FuncAutoDiffOpInterfaceImpl.cpp | 37 ++ .../MLIR/Implementations/FuncDerivatives.td | 3 + .../LLVMAutoDiffOpInterfaceImpl.cpp | 2 +- .../LinalgAutoDiffOpInterfaceImpl.cpp | 18 +- .../MemRefAutoDiffOpInterfaceImpl.cpp | 59 +-- .../SCFAutoDiffOpInterfaceImpl.cpp | 138 +++++-- .../MLIR/Interfaces/AutoDiffTypeInterface.td | 4 +- enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h | 15 +- .../MLIR/Interfaces/EnzymeLogicReverse.cpp | 320 +++++---------- .../Enzyme/MLIR/Interfaces/GradientUtils.cpp | 142 +++---- enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h | 25 +- .../MLIR/Interfaces/GradientUtilsReverse.cpp | 233 +---------- .../MLIR/Interfaces/GradientUtilsReverse.h | 37 +- enzyme/Enzyme/MLIR/Passes/CMakeLists.txt | 2 +- enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp | 29 +- enzyme/Enzyme/MLIR/Passes/Passes.h | 2 +- enzyme/Enzyme/MLIR/Passes/Passes.td | 6 +- .../MLIR/Passes/RemoveUnusedEnzymeOps.cpp | 372 ++++++++++++------ .../MLIR/Passes/ShadowedGradientToCache.cpp | 64 --- enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp | 88 +++++ enzyme/test/MLIR/Passes/dualpush.mlir | 48 +++ enzyme/test/MLIR/ReverseMode/bbarg-order.mlir | 49 ++- enzyme/test/MLIR/ReverseMode/pow.mlir | 69 ++-- enzyme/test/MLIR/ReverseMode/square.mlir | 75 ++++ enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp | 10 +- 33 files changed, 1101 insertions(+), 908 deletions(-) create mode 100644 enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp create mode 
100644 enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td delete mode 100644 enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp create mode 100644 enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp create mode 100644 enzyme/test/MLIR/Passes/dualpush.mlir create mode 100644 enzyme/test/MLIR/ReverseMode/square.mlir diff --git a/enzyme/BUILD b/enzyme/BUILD index dc36cb4ad0c5..bfb41ed3f9ff 100644 --- a/enzyme/BUILD +++ b/enzyme/BUILD @@ -547,6 +547,23 @@ gentbl( ], ) +gentbl( + name = "func-derivatives", + tbl_outs = [( + "-gen-mlir-derivatives", + "Enzyme/MLIR/Implementations/FuncDerivatives.inc", + )], + tblgen = ":enzyme-tblgen", + td_file = "Enzyme/MLIR/Implementations/FuncDerivatives.td", + td_srcs = [ + "Enzyme/MLIR/Implementations/FuncDerivatives.td", + "Enzyme/MLIR/Implementations/Common.td", + ], + deps = [ + ":enzyme-tblgen", + ], +) + cc_library( name = "EnzymeMLIR", srcs = glob([ @@ -582,6 +599,7 @@ cc_library( ":arith-derivatives", ":cf-derivatives", ":llvm-derivatives", + ":func-derivatives", ":math-derivatives", ":memref-derivatives", ":nvvm-derivatives", diff --git a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td index de51913400a6..7d74fdafbdbf 100644 --- a/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td +++ b/enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td @@ -83,12 +83,6 @@ def PopOp : Enzyme_Op<"pop"> { let results = (outs AnyType:$output); } -def ClearOp : Enzyme_Op<"clear"> { - let summary = "Remove top element from ShadowedGradient"; - let arguments = (ins AnyType : $cache); - let results = (outs ); -} - def InitOp : Enzyme_Op<"init"> { let summary = "Creat enzyme.gradient and enzyme.cache"; let arguments = (ins ); @@ -105,36 +99,28 @@ def Cache : Enzyme_Type<"Cache"> { let assemblyFormat = "`<` $type `>`"; } -def SetOp : Enzyme_Op<"set"> { - let summary = "Write to gradient"; - let arguments = (ins AnyType : $gradient, AnyType : $value); - let results = (outs ); -} - -def GetOp : Enzyme_Op<"get"> { - let summary = "Load value of 
gradient"; - let arguments = (ins AnyType : $gradient); - let results = (outs AnyType); -} - def Gradient : Enzyme_Type<"Gradient"> { - let summary = "Stores gradient if it cant be stroed in a value."; + let summary = "Mutable storage for accumulating gradients"; let description = [{ - "Cache for reverse pass" + Mutable storage for accumulating derivatives of immutable types (e.g. adding all the partial derivatives from users of a float64) }]; let parameters = (ins "Type":$basetype); let mnemonic = "Gradient"; let assemblyFormat = "`<` $basetype `>`"; } -def ShadowedGradient : Enzyme_Type<"ShadowedGradient"> { - let summary = "Stores gradients which need to be initialized with shadow values from the forward pass."; - let description = [{ - "Cache for reverse pass" - }]; - let parameters = (ins "Type":$basetype); - let mnemonic = "ShadowedGradient"; - let assemblyFormat = "`<` $basetype `>`"; +def SetOp : Enzyme_Op<"set"> { + let summary = "Store the current value of the gradient"; + let arguments = (ins Arg:$gradient, AnyType : $value); + let results = (outs ); +} + +def GetOp : Enzyme_Op<"get"> { + let summary = "Load current value of gradient"; + let arguments = (ins Arg:$gradient); + let results = (outs AnyType); } def AddToOp : Enzyme_Op<"addTo", [Pure, Terminator, ReturnLike]>, diff --git a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp index a15b3283e10d..058f49e87ac9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/BuiltinAutoDiffTypeInterfaceImpl.cpp @@ -44,7 +44,7 @@ class FloatTypeInterface return self; } - bool requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { return failure(); @@ -77,7 +77,7 @@ class TensorTypeInterface return self; } - bool 
requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { return failure(); @@ -105,7 +105,7 @@ class IntegerTypeInterface return self; } - bool requiresShadow(Type self) const { return false; } + bool isMutable(Type self) const { return false; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { return failure(); diff --git a/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td index 8b4f41696a9d..0b522e72ccf2 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td +++ b/enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td @@ -1,4 +1,5 @@ include "Common.td" +def : BranchOp<"cf", "CondBranchOp">; def : BranchOp<"cf", "BranchOp">; def : BranchOp<"cf", "SwitchOp">; diff --git a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt index ab156fa1b36f..3508e71d9adc 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt @@ -30,6 +30,10 @@ set(LLVM_TARGET_DEFINITIONS MathDerivatives.td) enzyme_tablegen(MathDerivatives.inc -gen-mlir-derivatives) add_public_tablegen_target(MathDerivativesIncGen) +set(LLVM_TARGET_DEFINITIONS FuncDerivatives.td) +enzyme_tablegen(FuncDerivatives.inc -gen-mlir-derivatives) +add_public_tablegen_target(FuncDerivativesIncGen) + add_mlir_library(MLIREnzymeImplementations AffineAutoDiffOpInterfaceImpl.cpp ArithAutoDiffOpInterfaceImpl.cpp @@ -37,6 +41,7 @@ add_mlir_library(MLIREnzymeImplementations LLVMAutoDiffOpInterfaceImpl.cpp NVVMAutoDiffOpInterfaceImpl.cpp MemRefAutoDiffOpInterfaceImpl.cpp + FuncAutoDiffOpInterfaceImpl.cpp LinalgAutoDiffOpInterfaceImpl.cpp BuiltinAutoDiffTypeInterfaceImpl.cpp SCFAutoDiffOpInterfaceImpl.cpp @@ -48,6 +53,7 @@ add_mlir_library(MLIREnzymeImplementations 
AffineDerivativesIncGen ArithDerivativesIncGen LLVMDerivativesIncGen + FuncDerivativesIncGen NVVMDerivativesIncGen SCFDerivativesIncGen CFDerivativesIncGen @@ -56,6 +62,7 @@ add_mlir_library(MLIREnzymeImplementations LINK_LIBS PUBLIC MLIRArithDialect + MLIRFuncDialect MLIRLLVMDialect MLIRMemRefDialect MLIREnzymeAutoDiffInterface diff --git a/enzyme/Enzyme/MLIR/Implementations/Common.td b/enzyme/Enzyme/MLIR/Implementations/Common.td index 33d9f12c2f37..3924f4527b00 100644 --- a/enzyme/Enzyme/MLIR/Implementations/Common.td +++ b/enzyme/Enzyme/MLIR/Implementations/Common.td @@ -26,6 +26,11 @@ class MemoryIdentityOp ptrargs_, list class ReadOnlyIdentityOp ptrargs_> : MemoryIdentityOp; +class ReturnOp { + string dialect = dialect_; + string opName = opName_; +} + class BranchOp { string dialect = dialect_; string opName = opName_; diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp index f9b2d61b0dce..ade1be7e6406 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.cpp @@ -15,6 +15,7 @@ #include "Interfaces/AutoDiffOpInterface.h" #include "Interfaces/AutoDiffTypeInterface.h" #include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" using namespace mlir; using namespace mlir::enzyme; @@ -143,8 +144,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler( if (contains(storedVals, operand.getOperandNumber())) { if (auto iface = dyn_cast(operand.get().getType())) { - if (!iface.requiresShadow()) { - // TODO only do if mutable + if (!iface.isMutable()) { Type retTy = iface.getShadowType(); auto toret = retTy.cast().createNullValue( builder, operand.get().getLoc()); @@ -201,6 +201,29 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler( return success(); } +void 
mlir::enzyme::detail::returnReverseHandler(Operation *op, + OpBuilder &builder, + MGradientUtilsReverse *gutils) { + size_t num_out = 0; + for (auto act : gutils->RetDiffeTypes) { + if (act == DIFFE_TYPE::OUT_DIFF) + num_out++; + } + + size_t idx = 0; + auto args = gutils->newFunc->getRegions().begin()->begin()->getArguments(); + + for (auto &&[op, act] : llvm::zip(op->getOperands(), gutils->RetDiffeTypes)) { + if (act == DIFFE_TYPE::OUT_DIFF) { + if (!gutils->isConstantValue(op)) { + auto d_out = args[args.size() - num_out + idx]; + gutils->addToDiffe(op, d_out, builder); + } + idx++; + } + } +} + void mlir::enzyme::detail::regionTerminatorForwardHandler( Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) { auto parentOp = origTerminator->getParentOp(); @@ -401,4 +424,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces( enzyme::registerSCFDialectAutoDiffInterface(registry); enzyme::registerCFDialectAutoDiffInterface(registry); enzyme::registerLinalgDialectAutoDiffInterface(registry); + enzyme::registerFuncDialectAutoDiffInterface(registry); } diff --git a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h index 7ee0be2adb70..cbad734656b1 100644 --- a/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h +++ b/enzyme/Enzyme/MLIR/Implementations/CoreDialectsAutoDiffImplementations.h @@ -52,6 +52,10 @@ void branchingForwardHandler(Operation *op, OpBuilder &builder, void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder, MGradientUtils *gutils); +// Implements reverse-mode differentiation of return operations. 
+void returnReverseHandler(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils); + // Implements forward-mode differentiation of read-only (including read-none) // operations which do not perform computation LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder, @@ -104,6 +108,44 @@ class AutoDiffUsingRegionTerminator } }; +template +class NoopRevAutoDiffInterface + : public ReverseAutoDiffOpInterface::ExternalModel< + NoopRevAutoDiffInterface, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const {} + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + +template +class ReturnRevAutoDiffInterface + : public ReverseAutoDiffOpInterface::ExternalModel< + ReturnRevAutoDiffInterface, OpTy> { +public: + void createReverseModeAdjoint(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils, + SmallVector caches) const { + returnReverseHandler(op, builder, gutils); + } + + SmallVector cacheValues(Operation *op, + MGradientUtilsReverse *gutils) const { + return SmallVector(); + } + + void createShadowValues(Operation *op, OpBuilder &builder, + MGradientUtilsReverse *gutils) const {} +}; + // Implements the forward autodiff interface for operations which are // read only and identity like (aka not computing sin of mem read). template @@ -166,12 +208,24 @@ void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) { template void registerAutoDiffUsingBranchInterface(MLIRContext &context) { OpTy::template attachInterface>(context); + OpTy::template attachInterface>( + context); } // Registers AutoDiffUsingRegionTerminator for the given op. 
template void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) { OpTy::template attachInterface>( context); + OpTy::template attachInterface>( + context); +} +// Registers AutoDiffUsingReturnInterface for the given op. +template +void registerAutoDiffUsingReturnInterface(MLIRContext &context) { + OpTy::template attachInterface>( + context); + OpTy::template attachInterface>( + context); } // Registers AutoDiffUsingMemoryIdentity for the given op. template @@ -199,6 +253,7 @@ void registerSCFDialectAutoDiffInterface(DialectRegistry &registry); void registerCFDialectAutoDiffInterface(DialectRegistry &registry); void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry); void registerMathDialectAutoDiffInterface(DialectRegistry &registry); +void registerFuncDialectAutoDiffInterface(DialectRegistry &registry); void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry); diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp new file mode 100644 index 000000000000..dddce795adfd --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp @@ -0,0 +1,37 @@ +//===- FuncAutoDiffOpInterfaceImpl.cpp - Interface external model --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains the external model implementation of the automatic +// differentiation op interfaces for the upstream MLIR func dialect. 
+// +//===----------------------------------------------------------------------===// + +#include "Implementations/CoreDialectsAutoDiffImplementations.h" +#include "Interfaces/AutoDiffOpInterface.h" +#include "Interfaces/GradientUtils.h" +#include "Interfaces/GradientUtilsReverse.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/DialectRegistry.h" +#include "mlir/Support/LogicalResult.h" + +#include "Dialect/Ops.h" +#include "mlir/IR/TypeSupport.h" + +using namespace mlir; +using namespace mlir::enzyme; + +namespace { +#include "Implementations/FuncDerivatives.inc" +} // namespace + +void mlir::enzyme::registerFuncDialectAutoDiffInterface( + DialectRegistry ®istry) { + registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) { + registerInterfaces(context); + }); +} diff --git a/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td b/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td new file mode 100644 index 000000000000..005246887fdf --- /dev/null +++ b/enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td @@ -0,0 +1,3 @@ +include "Common.td" + +def : ReturnOp<"func", "ReturnOp">; \ No newline at end of file diff --git a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp index 472cd0a43d61..264278e97ef9 100644 --- a/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LLVMAutoDiffOpInterfaceImpl.cpp @@ -55,7 +55,7 @@ class PointerTypeInterface return self; } - bool requiresShadow(Type self) const { return true; } + bool isMutable(Type self) const { return true; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { diff --git a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp index 8ecb851af90b..db9f4b08d6e1 100644 --- 
a/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp @@ -133,7 +133,7 @@ struct GenericOpInterfaceReverse linalgOp.getNumLoops(), utils::IteratorType::parallel}; for (OpOperand &output : linalgOp.getDpsInitsMutable()) { - if (!gutils->hasInvertPointer(output.get())) { + if (gutils->isConstantValue(output.get())) { continue; } indexingMaps.push_back(linalgOp.getMatchingIndexingMap(&output)); @@ -143,7 +143,7 @@ struct GenericOpInterfaceReverse } for (OpOperand *input : linalgOp.getDpsInputOperands()) { - if (!gutils->hasInvertPointer(input->get())) { + if (gutils->isConstantValue(input->get())) { continue; } indexingMaps.push_back(linalgOp.getMatchingIndexingMap(input)); @@ -165,8 +165,13 @@ struct GenericOpInterfaceReverse StringAttr()); int numInputs = inputs.size(); - auto buildFuncReturnOp = [numInputs](OpBuilder &builder, Location loc, - SmallVector retargs) { + auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder, + Block *oBB) { + auto loc = oBB->rbegin()->getLoc(); + SmallVector retargs; + for (auto arg : oBB->getArguments()) { + retargs.push_back(gutils->invertPointerM(arg, builder)); + } builder.create( loc, ValueRange{retargs}.take_front(numInputs)); return; @@ -192,9 +197,8 @@ struct GenericOpInterfaceReverse return std::make_pair(pushCache, popCache); }; - gutils->Logic.differentiate( - gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(), - /*parentRegion=*/false, buildFuncReturnOp, hook); + gutils->Logic.differentiate(gutils, *linalgOp.getBlock()->getParent(), + adjoint.getRegion(), buildFuncReturnOp, hook); auto newOpYield = cast( cast(newOp).getBodyRegion().front().getTerminator()); diff --git a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp index 1b811caf1b81..cd21c5c548b9 100644 --- 
a/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/MemRefAutoDiffOpInterfaceImpl.cpp @@ -41,8 +41,8 @@ struct LoadOpInterfaceReverse Value memref = loadOp.getMemref(); if (auto iface = dyn_cast(loadOp.getType())) { - if (gutils->hasInvertPointer(loadOp) && - gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(loadOp) && + !gutils->isConstantValue(memref)) { Value gradient = gutils->invertPointerM(loadOp, builder); Value memrefGradient = gutils->invertPointerM(memref, builder); @@ -70,8 +70,8 @@ struct LoadOpInterfaceReverse Value memref = loadOp.getMemref(); ValueRange indices = loadOp.getIndices(); if (auto iface = dyn_cast(loadOp.getType())) { - if (gutils->hasInvertPointer(loadOp) && - gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(loadOp) && + !gutils->isConstantValue(memref)) { OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); SmallVector caches; for (Value v : indices) { @@ -104,28 +104,33 @@ struct StoreOpInterfaceReverse Value memref = storeOp.getMemref(); // ValueRange indices = storeOp.getIndices(); - if (auto iface = dyn_cast(val.getType())) { - if (gutils->hasInvertPointer(memref)) { - OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); + auto iface = cast(val.getType()); - Value memrefGradient = gutils->invertPointerM(memref, builder); + if (!gutils->isConstantValue(memref)) { + OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); - SmallVector retrievedArguments; - for (Value cache : caches) { - Value retrievedValue = gutils->popCache(cache, builder); - retrievedArguments.push_back(retrievedValue); - } + Value memrefGradient = gutils->invertPointerM(memref, builder); - Value loadedGradient = - builder.create(storeOp.getLoc(), memrefGradient, - ArrayRef(retrievedArguments)); - Value addedGradient = loadedGradient; - if (gutils->hasInvertPointer(val)) { - Value gradient = gutils->invertPointerM(val, builder); - addedGradient = 
iface.createAddOp(builder, storeOp.getLoc(), gradient, - loadedGradient); + SmallVector retrievedArguments; + for (Value cache : caches) { + Value retrievedValue = gutils->popCache(cache, builder); + retrievedArguments.push_back(retrievedValue); + } + + if (!iface.isMutable()) { + if (!gutils->isConstantValue(val)) { + Value loadedGradient = builder.create( + storeOp.getLoc(), memrefGradient, + ArrayRef(retrievedArguments)); + gutils->addToDiffe(val, loadedGradient, builder); } - gutils->mapInvertPointer(val, addedGradient, builder); + + auto zero = + cast(gutils->getShadowType(val.getType())) + .createNullValue(builder, op->getLoc()); + + builder.create(storeOp.getLoc(), zero, memrefGradient, + ArrayRef(retrievedArguments)); } } } @@ -137,7 +142,7 @@ struct StoreOpInterfaceReverse ValueRange indices = storeOp.getIndices(); Value val = storeOp.getValue(); if (auto iface = dyn_cast(val.getType())) { - if (gutils->hasInvertPointer(memref)) { + if (!gutils->isConstantValue(memref)) { OpBuilder cacheBuilder(gutils->getNewFromOriginal(op)); SmallVector caches; for (Value v : indices) { @@ -175,13 +180,13 @@ struct SubViewOpInterfaceReverse MGradientUtilsReverse *gutils) const { auto subviewOp = cast(op); auto newSubviewOp = cast(gutils->getNewFromOriginal(op)); - if (gutils->hasInvertPointer(subviewOp.getSource())) { + if (!gutils->isConstantValue(subviewOp.getSource())) { Value shadow = builder.create( op->getLoc(), newSubviewOp.getType(), gutils->invertPointerM(subviewOp.getSource(), builder), newSubviewOp.getMixedOffsets(), newSubviewOp.getMixedSizes(), newSubviewOp.getMixedStrides()); - gutils->mapShadowValue(subviewOp, shadow, builder); + gutils->setDiffe(subviewOp, shadow, builder); } } }; @@ -205,13 +210,13 @@ class MemRefTypeInterface return self; } - bool requiresShadow(Type self) const { return true; } + bool isMutable(Type self) const { return true; } LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc, Value val) const { auto MT = 
cast(self); if (auto iface = dyn_cast(MT.getElementType())) { - if (!iface.requiresShadow()) { + if (!iface.isMutable()) { Value zero = iface.createNullValue(builder, loc); builder.create(loc, zero, val); } diff --git a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp index bab415ff247f..72fbe2106a53 100644 --- a/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp +++ b/enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp @@ -40,45 +40,113 @@ struct ForOpInterfaceReverse SmallVector caches) const { auto forOp = cast(op); - SmallVector nArgs; - for (Value v : forOp.getResults()) { - if (auto iface = dyn_cast(v.getType())) { - if (gutils->hasInvertPointer(v)) { - nArgs.push_back(gutils->invertPointerM(v, builder)); - } else { - nArgs.push_back(iface.createNullValue(builder, v.getLoc())); + // Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 + SmallVector resDiffes; + for (OpResult v : forOp.getResults()) { + if (!gutils->isConstantValue(v)) { + auto autoDiffType = cast(v.getType()); + if (!autoDiffType.isMutable()) { + auto prev = gutils->diffe(v, builder); + gutils->zeroDiffe(v, builder); + resDiffes.push_back(prev); + continue; } } + resDiffes.push_back(nullptr); } - auto repFor = builder.create( - forOp.getLoc(), gutils->popCache(caches[0], builder), - gutils->popCache(caches[1], builder), - gutils->popCache(caches[2], builder), nArgs); // TODO - repFor.getRegion().begin()->erase(); - - auto buildFuncReturnOp = [](OpBuilder &builder, Location loc, - SmallVector retargs) { - builder.create(loc, retargs); - return; - }; - - gutils->Logic.differentiate(gutils, forOp.getRegion(), repFor.getRegion(), - /*parentRegion=*/false, buildFuncReturnOp, - nullptr); - - // Insert the index which is carried by the scf for op. 
- Type indexType = IndexType::get(builder.getContext()); - repFor.getRegion().insertArgument((unsigned)0, indexType, forOp.getLoc()); - - for (const auto &[iterOperand, adjResult] : - llvm::zip(forOp.getInitArgs(), repFor.getResults())) { - if (gutils->hasInvertPointer(iterOperand)) { - auto autoDiffType = cast(iterOperand.getType()); - Value before = gutils->invertPointerM(iterOperand, builder); - Value after = autoDiffType.createAddOp(builder, forOp.getLoc(), before, - adjResult); - gutils->mapInvertPointer(iterOperand, after, builder); + for (auto &reg : op->getRegions()) { + auto termIface = + cast(reg.begin()->getTerminator()); + + SmallVector successors; + termIface.getSuccessorRegions( + SmallVector(termIface->getNumOperands(), Attribute()), + successors); + + for (auto &successor : successors) { + if (!successor.isParent()) + continue; + OperandRange operandRange = termIface.getSuccessorOperands(successor); + assert(operandRange.size() == resDiffes.size()); + + // There is an assumption here that there are only regions that branch to + the successor. 
Specifically, otherwise we would need to + // gutils->addToDiffe select (if came from that result) + for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) { + if (!post) + continue; + if (!gutils->isConstantValue(prev)) + gutils->addToDiffe(prev, post, builder); + } + } + } + // End Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0 + + auto start = gutils->popCache(caches[0], builder); + auto end = gutils->popCache(caches[1], builder); + auto step = gutils->popCache(caches[2], builder); + + auto repFor = builder.create(forOp.getLoc(), start, end, step, + ArrayRef()); + // erase scf yield + repFor.getBody()->begin()->erase(); + + for (auto &&[oldReg, newReg] : + llvm::zip(op->getRegions(), repFor->getRegions())) { + + // This code assumes at most one terminating block for each region (lest + // the append happen multiple times) + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { + auto loc = oBB->rbegin()->getLoc(); + + auto idx = repFor.getInductionVar(); + + auto lhs = builder.create(loc, idx, step); + + // This needs to know a condition describing which predecessor this will + // return to, to select the right value. Here we use the condition i + + // step >= end to determine the last iteration + + auto condition = builder.create( + loc, arith::CmpIPredicate::sge, lhs, end); + + for (auto [arg, init_arg] : + llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + auto diffe = gutils->diffe(arg, builder); + gutils->zeroDiffe(arg, builder); + + auto zero = cast(diffe.getType()) + .createNullValue(builder, loc); + auto outside = + builder.create(loc, condition, diffe, zero); + auto inside = + builder.create(loc, condition, zero, diffe); + + // For each predecessor, if we came from that predecessor += the + // shadow of the arg [after zero'ing] + if (!gutils->isConstantValue(init_arg)) { + gutils->addToDiffe(init_arg, outside, builder); + } + + if 
(!gutils->isConstantValue(arg)) { + gutils->addToDiffe(arg, inside, builder); + } + } + } + builder.create(loc); + }; + + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->mapReverseModeBlocks.map(&oBB, &revBB); + } + for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) { + gutils->Logic.visitChildren(&oBB, &revBB, gutils); + Block *newBB = gutils->getNewFromOriginal(&oBB); + gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, + buildFuncReturnOp); } } } diff --git a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td index 956f616751dc..0f6d38b4ffa4 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td +++ b/enzyme/Enzyme/MLIR/Interfaces/AutoDiffTypeInterface.td @@ -61,10 +61,10 @@ def AutoDiffTypeInterface : TypeInterface<"AutoDiffTypeInterface"> { >, InterfaceMethod< /*desc=*/[{ - Returns if the Type needs to be cleared. + Returns whether the type is mutable in place or not. }], /*retTy=*/"bool", - /*methodName=*/"requiresShadow", + /*methodName=*/"isMutable", /*args=*/(ins ) > ]; diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h index 223f604eaabb..dc98a5c4c6a6 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h @@ -10,8 +10,7 @@ namespace mlir { namespace enzyme { -typedef void(buildReturnFunction)(OpBuilder &, Location, - SmallVector); +typedef void(buildReturnFunction)(OpBuilder &, mlir::Block *); class MGradientUtilsReverse; @@ -125,24 +124,22 @@ class MEnzymeLogic { void initializeShadowValues(SmallVector &dominatorToposortBlocks, MGradientUtilsReverse *gutils); - void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, - llvm::function_ref buildReturnOp, - bool parentRegion); + void + handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB, + MGradientUtilsReverse *gutils, + llvm::function_ref 
buildReturnOp); void visitChildren(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); void visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); bool visitChildCustom(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils); - void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, bool parentRegion); void mapInvertArguments(Block *oBB, Block *reverseBB, MGradientUtilsReverse *gutils); SmallVector getDominatorToposort(MGradientUtilsReverse *gutils, Region ®ion); void differentiate(MGradientUtilsReverse *gutils, Region &oldRegion, - Region &newRegion, bool parentRegion, + Region &newRegion, llvm::function_ref buildFuncRetrunOp, std::function(Type)> cacheCreator); }; diff --git a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp index 39de939bc27a..f39766a0a639 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/EnzymeLogicReverse.cpp @@ -11,8 +11,6 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/Dominance.h" -#include "llvm/ADT/BreadthFirstIterator.h" #include "EnzymeLogic.h" #include "Interfaces/GradientUtils.h" @@ -22,71 +20,18 @@ using namespace mlir; using namespace mlir::enzyme; -SmallVector -MEnzymeLogic::getDominatorToposort(MGradientUtilsReverse *gutils, - Region ®ion) { - SmallVector dominatorToposortBlocks; - if (region.hasOneBlock()) { - dominatorToposortBlocks.push_back(&*(region.begin())); - } else { - auto dInfo = mlir::detail::DominanceInfoBase(nullptr); - llvm::DominatorTreeBase &dt = - dInfo.getDomTree(&(gutils->oldFunc.getFunctionBody())); - auto root = dt.getNode(&*(region.begin())); +void handleReturns(Block *oBB, Block *newBB, Block *reverseBB, + MGradientUtilsReverse *gutils) { + if (oBB->getNumSuccessors() == 0) { + Operation *returnStatement = newBB->getTerminator(); + 
gutils->erase(returnStatement); - for (llvm::DomTreeNodeBase *node : llvm::breadth_first(root)) { - dominatorToposortBlocks.push_back(node->getBlock()); - } - } - return dominatorToposortBlocks; -} + OpBuilder forwardToBackwardBuilder(newBB, newBB->end()); -void MEnzymeLogic::mapInvertArguments(Block *oBB, Block *reverseBB, - MGradientUtilsReverse *gutils) { - OpBuilder builder(reverseBB, reverseBB->begin()); - for (int i = 0; i < (int)gutils->mapBlockArguments[oBB].size(); i++) { - auto x = gutils->mapBlockArguments[oBB][i]; - if (auto iface = x.second.getType().dyn_cast()) { - Value added = reverseBB->getArgument(i); - if (gutils->hasInvertPointer(x.second)) { - added = iface.createAddOp(builder, x.second.getLoc(), added, - gutils->invertPointerM(x.second, builder)); - } - gutils->mapInvertPointer(x.second, added, builder); - } - } -} + Operation *newBranchOp = forwardToBackwardBuilder.create( + oBB->getTerminator()->getLoc(), reverseBB); -void MEnzymeLogic::handleReturns(Block *oBB, Block *newBB, Block *reverseBB, - MGradientUtilsReverse *gutils, - bool parentRegion) { - if (oBB->getNumSuccessors() == 0) { - if (parentRegion) { - Operation *returnStatement = newBB->getTerminator(); - gutils->erase(returnStatement); - - OpBuilder forwardToBackwardBuilder(newBB, newBB->end()); - gutils->mapInvertPointer( - oBB->getTerminator()->getOperand(0), - gutils->newFunc.getArgument(gutils->newFunc.getNumArguments() - 1), - forwardToBackwardBuilder); // TODO handle multiple return values - Operation *newBranchOp = forwardToBackwardBuilder.create( - oBB->getTerminator()->getLoc(), reverseBB); - - gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp; - } else { - Operation *terminator = oBB->getTerminator(); - OpBuilder builder(reverseBB, reverseBB->begin()); - - int i = 0; - for (OpOperand &operand : terminator->getOpOperands()) { - Value val = operand.get(); - if (auto iface = val.getType().dyn_cast()) { - gutils->mapInvertPointer(val, reverseBB->getArgument(i), 
builder); - i++; - } - } - } + gutils->originalToNewFnOps[oBB->getTerminator()] = newBranchOp; } } @@ -135,7 +80,7 @@ bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder, SmallVector args; for (Value opResult : op->getResults()) { - if (gutils->hasInvertPointer(opResult)) { + if (!gutils->isConstantValue(opResult)) { Value invertValue = gutils->invertPointerM(opResult, builder); args.push_back(invertValue); } @@ -152,7 +97,7 @@ bool MEnzymeLogic::visitChildCustom(Operation *op, OpBuilder &builder, func::CallOp dCI = builder.create(op->getLoc(), srDiffe, resultTypes, args); for (int i = 0; i < (int)op->getNumOperands(); i++) { - gutils->mapInvertPointer(op->getOperand(i), dCI.getResult(i), builder); + gutils->setDiffe(op->getOperand(i), dCI.getResult(i), builder); } return true; @@ -165,7 +110,8 @@ Create reverse mode adjoint for an operation. */ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, MGradientUtilsReverse *gutils) { - if (llvm::all_of(op->getResults(), + if ((op->getBlock()->getTerminator() != op) && + llvm::all_of(op->getResults(), [gutils](Value v) { return gutils->isConstantValue(v); }) && gutils->isConstantInstruction(op)) { return; @@ -173,12 +119,14 @@ void MEnzymeLogic::visitChild(Operation *op, OpBuilder &builder, if (auto ifaceOp = dyn_cast(op)) { SmallVector caches = ifaceOp.cacheValues(gutils); ifaceOp.createReverseModeAdjoint(builder, gutils, caches); - + return; + /* for (int indexResult = 0; indexResult < (int)op->getNumResults(); indexResult++) { Value result = op->getResult(indexResult); gutils->clearValue(result, builder); } + */ } op->emitError() << "could not compute the adjoint for this operation " << *op; } @@ -201,176 +149,107 @@ void MEnzymeLogic::visitChildren(Block *oBB, Block *reverseBB, void MEnzymeLogic::handlePredecessors( Block *oBB, Block *newBB, Block *reverseBB, MGradientUtilsReverse *gutils, - llvm::function_ref buildReturnOp, bool parentRegion) { + llvm::function_ref buildReturnOp) { 
OpBuilder revBuilder(reverseBB, reverseBB->end()); if (oBB->hasNoPredecessors()) { - SmallVector retargs; - // We need different handling on the top level due to - // the presence of duplicated args since we don't yet have activity analysis - if (parentRegion) { - assert(gutils->ArgDiffeTypes.size() == - gutils->oldFunc.getNumArguments() && - "Mismatch of activity array size vs # original function args"); - for (const auto &[diffeType, oldArg] : - llvm::zip(gutils->ArgDiffeTypes, oBB->getArguments())) { - if (diffeType == DIFFE_TYPE::OUT_DIFF) { - retargs.push_back(gutils->invertPointerM(oldArg, revBuilder)); - } - } - } else { - for (auto arg : oBB->getArguments()) { - if (gutils->hasInvertPointer(arg)) { - retargs.push_back(gutils->invertPointerM(arg, revBuilder)); - } - } - } - buildReturnOp(revBuilder, oBB->rbegin()->getLoc(), retargs); + buildReturnOp(revBuilder, oBB); } else { + Location loc = oBB->rbegin()->getLoc(); + // TODO remove dependency on CF dialect + + Value cache = gutils->insertInit(gutils->getIndexCacheType()); + + Value flag = + revBuilder.create(loc, gutils->getIndexType(), cache); + + Block *defaultBlock = nullptr; + SmallVector blocks; SmallVector indices; - SmallVector> arguments; - SmallVector defaultArguments; - Block *defaultBlock = nullptr; - for (auto pair : llvm::enumerate(oBB->getPredecessors())) { - auto predecessor = pair.value(); - auto idx = pair.index(); - Block *predecessorRevMode = - gutils->mapReverseModeBlocks.lookupOrNull(predecessor); - - SmallVector operands; - auto argumentsIt = gutils->mapBlockArguments.find(predecessor); - if (argumentsIt != gutils->mapBlockArguments.end()) { - for (auto operandOld : argumentsIt->second) { - if (oBB == operandOld.first.getParentBlock() && - gutils->hasInvertPointer(operandOld.first)) { - operands.push_back( - gutils->invertPointerM(operandOld.first, revBuilder)); - } else { - if (auto iface = operandOld.first.getType() - .dyn_cast()) { - Value nullValue = - 
iface.createNullValue(revBuilder, oBB->rbegin()->getLoc()); - operands.push_back(nullValue); - } else { - llvm_unreachable("no canonial null value found"); - } - } - } - } - if (idx != 0) { - blocks.push_back(predecessorRevMode); - indices.push_back(APInt(32, idx - 1)); - arguments.emplace_back(std::move(operands)); - } else { - defaultBlock = predecessorRevMode; - defaultArguments = operands; + + OpBuilder newBuilder(newBB, newBB->begin()); + + SmallVector diffes; + for (auto arg : oBB->getArguments()) { + if (!gutils->isConstantValue(arg) && + !cast(arg.getType()).isMutable()) { + diffes.push_back(gutils->diffe(arg, revBuilder)); + gutils->zeroDiffe(arg, revBuilder); + continue; } + diffes.push_back(nullptr); } - // Clear invert pointers of all arguments with gradient - for (auto argument : oBB->getArguments()) { - if (gutils->hasInvertPointer(argument)) { - auto iface = argument.getType().cast(); - Value nullValue = iface.createNullValue(revBuilder, argument.getLoc()); - gutils->mapInvertPointer(argument, nullValue, revBuilder); + for (auto [idx, pred] : llvm::enumerate(oBB->getPredecessors())) { + auto reversePred = gutils->mapReverseModeBlocks.lookupOrNull(pred); + + Block *newPred = gutils->getNewFromOriginal(pred); + + OpBuilder predecessorBuilder(newPred->getTerminator()); + + Value pred_idx_c = + predecessorBuilder.create(loc, idx - 1, 32); + predecessorBuilder.create(loc, cache, pred_idx_c); + + if (idx == 0) { + defaultBlock = reversePred; + + } else { + indices.push_back(APInt(32, idx - 1)); + blocks.push_back(reversePred); } - } - Location loc = oBB->rbegin()->getLoc(); - // Remove Dependency to CF dialect - if (std::next(oBB->getPredecessors().begin()) == - oBB->getPredecessors().end()) { - // If there is only one block we can directly create a branch for - // simplicity sake - revBuilder.create(loc, defaultBlock, defaultArguments); - } else { - Value cache = gutils->insertInit(gutils->getIndexCacheType()); - Value flag = - revBuilder.create(loc, 
gutils->getIndexType(), cache); - - SmallVector argumentRanges; - for (const auto &a : arguments) - argumentRanges.emplace_back(a); - revBuilder.create( - loc, flag, defaultBlock, defaultArguments, ArrayRef(indices), - ArrayRef(blocks), argumentRanges); - - Value origin = newBB->addArgument(gutils->getIndexType(), loc); - - OpBuilder newBuilder(newBB, newBB->begin()); - newBuilder.create(loc, cache, origin); - - int j = 0; - for (Block *predecessor : oBB->getPredecessors()) { - Block *newPredecessor = gutils->getNewFromOriginal(predecessor); - - OpBuilder predecessorBuilder(newPredecessor, - std::prev(newPredecessor->end())); - Value indicator = - predecessorBuilder.create(loc, j++, 32); - - Operation *terminator = newPredecessor->getTerminator(); - if (auto binst = dyn_cast(terminator)) { - for (unsigned i = 0; i < terminator->getNumSuccessors(); i++) { - if (terminator->getSuccessor(i) == newBB) { - SuccessorOperands sOps = binst.getSuccessorOperands(i); - sOps.append(indicator); + auto term = pred->getTerminator(); + if (auto iface = dyn_cast(term)) { + for (auto &op : term->getOpOperands()) + if (auto blk_idx = + iface.getSuccessorBlockArgument(op.getOperandNumber())) + if ((*blk_idx).getOwner() == oBB) { + auto idx = (*blk_idx).getArgNumber(); + if (diffes[idx]) { + + Value rev_idx_c = + revBuilder.create(loc, idx - 1, 32); + + auto to_prop = revBuilder.create( + loc, + revBuilder.create( + loc, arith::CmpIPredicate::eq, flag, rev_idx_c), + diffes[idx], + cast(diffes[idx].getType()) + .createNullValue(revBuilder, loc)); + + gutils->addToDiffe(op.get(), to_prop, revBuilder); + } } - } - } else { - llvm_unreachable("invalid terminator"); - } + } else { + assert(0 && "predecessor did not implement branch op interface"); } } - } -} -void MEnzymeLogic::initializeShadowValues( - SmallVector &dominatorToposortBlocks, - MGradientUtilsReverse *gutils) { - for (auto it = dominatorToposortBlocks.begin(); - it != dominatorToposortBlocks.end(); ++it) { - Block *oBB = *it; 
- - if (!oBB->empty()) { - for (auto it = oBB->begin(); it != oBB->end(); ++it) { - Operation *op = &*it; - Operation *newOp = gutils->getNewFromOriginal(op); - - if (auto ifaceOp = dyn_cast(op)) { - OpBuilder builder(newOp); - ifaceOp.createShadowValues(builder, gutils); - } - } - } + revBuilder.create( + loc, flag, defaultBlock, ArrayRef(), ArrayRef(indices), + ArrayRef(blocks), + SmallVector(indices.size(), ValueRange())); } } void MEnzymeLogic::differentiate( MGradientUtilsReverse *gutils, Region &oldRegion, Region &newRegion, - bool parentRegion, llvm::function_ref buildFuncReturnOp, std::function(Type)> cacheCreator) { gutils->registerCacheCreatorHook(cacheCreator); auto scope = llvm::make_scope_exit( [&]() { gutils->deregisterCacheCreatorHook(cacheCreator); }); - gutils->createReverseModeBlocks(oldRegion, newRegion, parentRegion); - - SmallVector dominatorToposortBlocks = - getDominatorToposort(gutils, oldRegion); - initializeShadowValues(dominatorToposortBlocks, gutils); - - for (auto it = dominatorToposortBlocks.rbegin(); - it != dominatorToposortBlocks.rend(); ++it) { - Block *oBB = *it; - Block *newBB = gutils->getNewFromOriginal(oBB); - Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(oBB); - mapInvertArguments(oBB, reverseBB, gutils); - handleReturns(oBB, newBB, reverseBB, gutils, parentRegion); - visitChildren(oBB, reverseBB, gutils); - handlePredecessors(oBB, newBB, reverseBB, gutils, buildFuncReturnOp, - parentRegion); + gutils->createReverseModeBlocks(oldRegion, newRegion); + + for (auto &oBB : oldRegion) { + Block *newBB = gutils->getNewFromOriginal(&oBB); + Block *reverseBB = gutils->mapReverseModeBlocks.lookupOrNull(&oBB); + handleReturns(&oBB, newBB, reverseBB, gutils); + visitChildren(&oBB, reverseBB, gutils); + handlePredecessors(&oBB, newBB, reverseBB, gutils, buildFuncReturnOp); } } @@ -394,13 +273,20 @@ FunctionOpInterface MEnzymeLogic::CreateReverseDiff( Region &oldRegion = gutils->oldFunc.getFunctionBody(); Region &newRegion 
= gutils->newFunc.getFunctionBody(); - auto buildFuncReturnOp = [](OpBuilder &builder, Location loc, - SmallVector retargs) { - builder.create(loc, retargs); + auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) { + SmallVector retargs; + for (auto [arg, cv] : llvm::zip(oBB->getArguments(), constants)) { + if (cv == DIFFE_TYPE::OUT_DIFF) { + retargs.push_back(gutils->diffe(arg, builder)); + } + } + builder.create(oBB->rbegin()->getLoc(), retargs); return; }; - differentiate(gutils, oldRegion, newRegion, true, buildFuncReturnOp, nullptr); + gutils->forceAugmentedReturns(); + + differentiate(gutils, oldRegion, newRegion, buildFuncReturnOp, nullptr); auto nf = gutils->newFunc; diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp index 4598394b9e44..ebae44c9efa1 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.cpp @@ -41,7 +41,8 @@ mlir::enzyme::MGradientUtils::MGradientUtils( originalToNewFnOps(originalToNewFnOps_), blocksNotForAnalysis(), activityAnalyzer(std::make_unique( blocksNotForAnalysis, constantvalues_, activevals_, ReturnActivity)), - TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_) { + TA(TA_), TR(TR_), omp(omp), width(width), ArgDiffeTypes(ArgDiffeTypes_), + RetDiffeTypes(1, ReturnActivity) { /* for (BasicBlock &BB : *oldFunc) { @@ -170,6 +171,55 @@ mlir::Value mlir::enzyme::MGradientUtils::invertPointerM(mlir::Value v, llvm_unreachable("could not invert pointer"); } +mlir::Value +mlir::enzyme::MDiffeGradientUtils::getDifferential(mlir::Value oval) { + auto found = differentials.lookupOrNull(oval); + if (found != nullptr) + return found; + + auto shadowty = getShadowType(oval.getType()); + OpBuilder builder(oval.getContext()); + builder.setInsertionPointToStart(initializationBlock); + + auto shadow = builder.create( + oval.getLoc(), enzyme::GradientType::get(oval.getContext(), shadowty)); + auto toset = 
cast(shadowty).createNullValue( + builder, oval.getLoc()); + builder.create(oval.getLoc(), shadow, toset); + + differentials.map(oval, shadow); + return shadow; +} + +void mlir::enzyme::MDiffeGradientUtils::setDiffe(mlir::Value oval, + mlir::Value toset, + OpBuilder &BuilderM) { + assert(!isConstantValue(oval)); + auto iface = oval.getType().cast(); + if (!iface.isMutable()) { + auto shadow = getDifferential(oval); + BuilderM.create(oval.getLoc(), shadow, toset); + } else { + MGradientUtils::setDiffe(oval, toset, BuilderM); + } +} + +void mlir::enzyme::MDiffeGradientUtils::zeroDiffe(mlir::Value oval, + OpBuilder &BuilderM) { + assert(!isConstantValue(oval)); + auto iface = getShadowType(oval.getType()).cast(); + assert(!iface.isMutable()); + setDiffe(oval, iface.createNullValue(BuilderM, oval.getLoc()), BuilderM); +} + +mlir::Value mlir::enzyme::MDiffeGradientUtils::diffe(mlir::Value oval, + OpBuilder &BuilderM) { + + auto shadow = getDifferential(oval); + return BuilderM.create(oval.getLoc(), + getShadowType(oval.getType()), shadow); +} + void mlir::enzyme::MGradientUtils::setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM) { /* @@ -226,81 +276,39 @@ void mlir::enzyme::MGradientUtils::forceAugmentedReturns() { if (isConstantValue(val)) continue; auto i = val.getArgNumber(); - mlir::Value dval; - if (i == blk->getArguments().size() - 1) - dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc()); - else - dval = nblk->insertArgument(nblk->args_begin() + i + 1, - getShadowType(val.getType()), val.getLoc()); - - invertedPointers.map(val, dval); + if (mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit || + cast(val.getType()).isMutable()) { + mlir::Value dval; + if (i == blk->getArguments().size() - 1) + dval = nblk->addArgument(getShadowType(val.getType()), val.getLoc()); + else + dval = + nblk->insertArgument(nblk->args_begin() + i + 1, + getShadowType(val.getType()), val.getLoc()); + + 
invertedPointers.map(val, dval); + } } }); oldFunc.walk([&](Operation *inst) { if (inst == oldFunc) return; - if (mode == DerivativeMode::ForwardMode || - mode == DerivativeMode::ForwardModeSplit) { - OpBuilder BuilderZ(getNewFromOriginal(inst)); - for (auto res : inst->getResults()) { - if (!isConstantValue(res)) { - mlir::Type antiTy = getShadowType(res.getType()); - auto anti = - BuilderZ.create(res.getLoc(), antiTy); - invertedPointers.map(res, anti); - } - } - return; - } - /* - - if (inst->getType()->isFPOrFPVectorTy()) - continue; //! op->getType()->isPointerTy() && - //! !op->getType()->isIntegerTy()) { - - if (!TR.query(inst)[{-1}].isPossiblePointer()) - continue; - - if (isa(inst)) { - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); - Type *antiTy = getShadowType(inst->getType()); - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, inst->getName() + "'il_phi"); - invertedPointers.insert(std::make_pair( - (const Value *)inst, InvertedPointerVH(this, anti))); - continue; - } - - if (!isa(inst)) { - continue; - } - - if (isa(inst)) { - continue; - } - - if (isConstantValue(inst)) { - continue; - } - - CallInst *op = cast(inst); - Function *called = op->getCalledFunction(); - IRBuilder<> BuilderZ(inst); - getForwardBuilder(BuilderZ); - Type *antiTy = getShadowType(inst->getType()); - - PHINode *anti = - BuilderZ.CreatePHI(antiTy, 1, op->getName() + "'ip_phi"); - invertedPointers.insert( - std::make_pair((const Value *)inst, InvertedPointerVH(this, anti))); + OpBuilder BuilderZ(getNewFromOriginal(inst)); + for (auto res : inst->getResults()) { + if (isConstantValue(res)) + continue; - if (called && isAllocationFunction(called->getName(), TLI)) { - anti->setName(op->getName() + "'mi"); + if (!(mode == DerivativeMode::ForwardMode || + mode == DerivativeMode::ForwardModeSplit || + cast(res.getType()).isMutable())) + continue; + mlir::Type antiTy = getShadowType(res.getType()); + auto anti = BuilderZ.create(res.getLoc(), antiTy); + 
invertedPointers.map(res, anti); } - */ }); } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h index 7ef99a1014fd..32a7fe068d07 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtils.h @@ -38,6 +38,7 @@ class MGradientUtils { bool omp; unsigned width; + SmallVector RetDiffeTypes; ArrayRef ArgDiffeTypes; mlir::Value getNewFromOriginal(const mlir::Value originst) const; @@ -68,16 +69,35 @@ class MGradientUtils { bool isConstantInstruction(mlir::Operation *v) const; bool isConstantValue(mlir::Value v) const; mlir::Value invertPointerM(mlir::Value v, OpBuilder &Builder2); - void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); void forceAugmentedReturns(); Operation *cloneWithNewOperands(OpBuilder &B, Operation *op); LogicalResult visitChild(Operation *op); + + void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder); + + mlir::Type getShadowType(mlir::Type T) { + auto iface = cast(T); + return iface.getShadowType(width); + } }; class MDiffeGradientUtils : public MGradientUtils { +protected: + IRMapping differentials; + + Block *initializationBlock; + public: + mlir::Value getDifferential(mlir::Value origv); + + void setDiffe(mlir::Value origv, mlir::Value newv, mlir::OpBuilder &builder); + + void zeroDiffe(mlir::Value origv, mlir::OpBuilder &builder); + + mlir::Value diffe(mlir::Value origv, mlir::OpBuilder &builder); + MDiffeGradientUtils(MEnzymeLogic &Logic, FunctionOpInterface newFunc_, FunctionOpInterface oldFunc_, MTypeAnalysis &TA, MTypeResults TR, IRMapping &invertedPointers_, @@ -91,7 +111,8 @@ class MDiffeGradientUtils : public MGradientUtils { : MGradientUtils(Logic, newFunc_, oldFunc_, TA, TR, invertedPointers_, constantvalues_, activevals_, ActiveReturn, constant_values, origToNew_, origToNewOps_, mode, width, - omp) {} + omp), + initializationBlock(&*(newFunc.getFunctionBody().begin())) {} // Technically 
diffe constructor static MDiffeGradientUtils * diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp index 65d0349f1476..20d34247fb3a 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.cpp @@ -40,16 +40,7 @@ mlir::enzyme::MGradientUtilsReverse::MGradientUtilsReverse( invertedPointers_, constantvalues_, activevals_, ReturnActivity, ArgDiffeTypes_, originalToNewFn_, originalToNewFnOps_, mode_, width, /*omp*/ false), - symbolTable(symbolTable_) { - - initInitializationBlock(invertedPointers_, ArgDiffeTypes_); -} - -// for(auto x : v.getUsers()){x->dump();} DEBUG - -bool MGradientUtilsReverse::onlyUsedInParentBlock(Value v) { - return !v.isUsedOutsideOfBlock(v.getParentBlock()); -} + symbolTable(symbolTable_) {} Type mlir::enzyme::MGradientUtilsReverse::getIndexCacheType() { Type indexType = getIndexType(); @@ -109,34 +100,6 @@ Value MGradientUtilsReverse::popCache(Value cache, OpBuilder &builder) { cache); } -// Gradient -Type mlir::enzyme::MGradientUtilsReverse::getGradientType(Value v) { - Type valueType = v.getType(); - return GradientType::get(v.getContext(), valueType); -} - -Value mlir::enzyme::MGradientUtilsReverse::insertInitGradient( - mlir::Value v, OpBuilder &builder) { - Type gradientType = getGradientType(v); - OpBuilder initBuilder(initializationBlock, initializationBlock->begin()); - Value gradient = initBuilder.create(v.getLoc(), gradientType); - return gradient; -} - -// Shadow Gradient -Type mlir::enzyme::MGradientUtilsReverse::getShadowedGradientType(Value v) { - Type valueType = v.getType(); - return ShadowedGradientType::get(v.getContext(), valueType); -} - -Value mlir::enzyme::MGradientUtilsReverse::insertInitShadowedGradient( - mlir::Value v, OpBuilder &builder) { - Type gradientType = getShadowedGradientType(v); - OpBuilder initBuilder(initializationBlock, initializationBlock->begin()); - Value 
gradient = initBuilder.create(v.getLoc(), gradientType); - return gradient; -} - Operation * mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, Operation *op) { @@ -146,205 +109,25 @@ mlir::enzyme::MGradientUtilsReverse::cloneWithNewOperands(OpBuilder &B, return B.clone(*op, map); } -bool mlir::enzyme::MGradientUtilsReverse::requiresShadow(Type t) { - if (auto iface = dyn_cast(t)) { - return iface.requiresShadow(); - } - return false; -} - void mlir::enzyme::MGradientUtilsReverse::addToDiffe(Value oldGradient, Value addedGradient, OpBuilder &builder) { - // TODO - Value gradient = addedGradient; - if (hasInvertPointer(oldGradient)) { - Value operandGradient = invertPointerM(oldGradient, builder); - auto iface = cast(addedGradient.getType()); - gradient = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, + assert(!isConstantValue(oldGradient)); + Value operandGradient = diffe(oldGradient, builder); + auto iface = cast(addedGradient.getType()); + auto added = iface.createAddOp(builder, oldGradient.getLoc(), operandGradient, addedGradient); - } - mapInvertPointer(oldGradient, gradient, builder); -} - -Value mlir::enzyme::MGradientUtilsReverse::diffe(Value v, OpBuilder &builder) { - return invertPointerM(v, builder); -} - -/* -The value v must have an invert pointer -*/ -Value mlir::enzyme::MGradientUtilsReverse::invertPointerM(Value v, - OpBuilder &builder) { - if (invertedPointersGlobal.contains(v)) { - Value gradient = invertedPointersGlobal.lookupOrNull(v); - Type type = gradient.getType(); - - if (GradientType gType = dyn_cast(type)) { - Value ret = builder.create(v.getLoc(), gType.getBasetype(), - gradient); - return ret; - } else { - llvm_unreachable("found invalid type"); - } - } else if (invertedPointersShadow.contains(v)) { - Value gradient = invertedPointersShadow.lookupOrNull(v); - Type type = gradient.getType(); - - if (ShadowedGradientType gType = - dyn_cast(type)) { - Value ret = builder.create(v.getLoc(), 
gType.getBasetype(), - gradient); - return ret; - } else { - llvm_unreachable("found invalid type"); - } - } - - llvm::errs() << " could not invert pointer v " << v << "\n"; - llvm_unreachable("could not invert pointer"); -} - -void mlir::enzyme::MGradientUtilsReverse::mapInvertPointer( - mlir::Value v, mlir::Value invertValue, OpBuilder &builder) { - if (!invertedPointersGlobal.contains(v)) { - Value g = insertInitGradient(v, builder); - invertedPointersGlobal.map(v, g); - } - Value gradient = invertedPointersGlobal.lookupOrNull(v); - builder.create(v.getLoc(), gradient, invertValue); -} - -Value mlir::enzyme::MGradientUtilsReverse::getShadowValue(mlir::Value v) { - return shadowValues.lookupOrNull(v); -} - -void mlir::enzyme::MGradientUtilsReverse::mapShadowValue(mlir::Value v, - mlir::Value shadow, - OpBuilder &builder) { - assert(!invertedPointersShadow.contains( - v)); // Shadow Values must only be mapped exactly once - - Value cache = insertInitShadowedGradient(v, builder); - invertedPointersShadow.map(v, cache); - - builder.create(v.getLoc(), cache, shadow); - - shadowValues.map(v, shadow); -} - -void mlir::enzyme::MGradientUtilsReverse::clearValue(mlir::Value v, - OpBuilder &builder) { - if (invertedPointersGlobal.contains(v)) { - if (!onlyUsedInParentBlock(v)) { // TODO is this necessary? 
- Value gradient = invertedPointersGlobal.lookupOrNull(v); - Type type = cast(gradient.getType()).getBasetype(); - if (auto iface = dyn_cast(type)) { - Value zero = iface.createNullValue(builder, v.getLoc()); - builder.create(v.getLoc(), gradient, zero); - } else { - llvm_unreachable( - "Type does not have an associated AutoDiffTypeInterface"); - } - } - } else if (invertedPointersShadow.contains(v)) { - Value gradient = invertedPointersShadow.lookupOrNull(v); - builder.create(v.getLoc(), gradient); - } -} - -bool mlir::enzyme::MGradientUtilsReverse::hasInvertPointer(mlir::Value v) { - return (invertedPointersGlobal.contains(v)) || - (invertedPointersShadow.contains(v)); -} - -void MGradientUtilsReverse::initInitializationBlock( - IRMapping invertedPointers_, ArrayRef argDiffeTypes) { - initializationBlock = &*(this->newFunc.getFunctionBody().begin()); - - OpBuilder initializationBuilder( - &*(this->newFunc.getFunctionBody().begin()), - this->newFunc.getFunctionBody().begin()->begin()); - - for (const auto &[val, diffe_type] : llvm::zip( - this->oldFunc.getFunctionBody().getArguments(), argDiffeTypes)) { - if (diffe_type != DIFFE_TYPE::OUT_DIFF) { - continue; - } - auto iface = dyn_cast(val.getType()); - if (!iface) { - llvm_unreachable( - "Type does not have an associated AutoDiffTypeInterface"); - } - Value zero = iface.createNullValue(initializationBuilder, val.getLoc()); - mapInvertPointer(val, zero, initializationBuilder); - } - for (auto const &x : invertedPointers_.getValueMap()) { - if (auto iface = dyn_cast(x.first.getType())) { - if (iface.requiresShadow()) { - mapShadowValue(x.first, x.second, - initializationBuilder); // This may create an unnecessary - // ShadowedGradient which could - // be avoidable TODO - } else { - mapInvertPointer(x.first, x.second, initializationBuilder); - } - } else { - llvm_unreachable("TODO not implemented"); - } - } + setDiffe(oldGradient, added, builder); } void MGradientUtilsReverse::createReverseModeBlocks(Region 
&oldFunc, - Region &newFunc, - bool isParentRegion) { + Region &newFunc) { for (auto it = oldFunc.getBlocks().rbegin(); it != oldFunc.getBlocks().rend(); ++it) { Block *block = &*it; Block *reverseBlock = new Block(); - - SmallVector> - reverseModeArguments; // Argument, Assigned value (2. is technically not - // necessary but simplifies code a lot) - - // Add reverse mode Arguments to Block - Operation *term = block->getTerminator(); - mlir::BranchOpInterface brOp = dyn_cast(term); - bool returnLike = term->hasTrait(); - if (brOp) { - for (int i = 0; i < (int)term->getNumSuccessors(); i++) { - SuccessorOperands sOps = brOp.getSuccessorOperands(i); - Block *successorBlock = term->getSuccessor(i); - - assert(successorBlock->getNumArguments() == sOps.size()); - for (int j = 0; j < (int)sOps.size(); j++) { - // Check if the argument needs a gradient - if (auto iface = successorBlock->getArgument(j) - .getType() - .dyn_cast()) { - reverseModeArguments.push_back(std::pair( - successorBlock->getArgument(j), sOps[j])); - } - } - } - for (auto it : reverseModeArguments) { - reverseBlock->addArgument(it.second.getType(), it.second.getLoc()); - } - - mapBlockArguments[block] = reverseModeArguments; - } else if (returnLike) { - if (!isParentRegion) { - for (OpOperand &operand : term->getOpOperands()) { - Value val = operand.get(); - if (auto iface = val.getType().dyn_cast()) { - reverseBlock->addArgument(val.getType(), val.getLoc()); - } - } - } - } - - mapReverseModeBlocks.map(block, reverseBlock); newFunc.getBlocks().insert(newFunc.end(), reverseBlock); + mapReverseModeBlocks.map(block, reverseBlock); } } diff --git a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h index 3d5d4d147269..d3b2e818391f 100644 --- a/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h +++ b/enzyme/Enzyme/MLIR/Interfaces/GradientUtilsReverse.h @@ -34,32 +34,13 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { DerivativeMode 
mode_, unsigned width, SymbolTableCollection &symbolTable_); - IRMapping invertedPointersGlobal; - IRMapping invertedPointersShadow; - IRMapping shadowValues; - Block *initializationBlock; - IRMapping mapReverseModeBlocks; - DenseMap>> mapBlockArguments; SymbolTableCollection &symbolTable; - bool hasInvertPointer(mlir::Value v); - mlir::Value invertPointerM(mlir::Value v, OpBuilder &builder); - mlir::Value diffe(mlir::Value v, OpBuilder &builder); - void addToDiffe(mlir::Value oldGradient, mlir::Value addedGradient, OpBuilder &builder); - void mapInvertPointer(mlir::Value v, mlir::Value invertValue, - OpBuilder &builder); - - mlir::Value getShadowValue(mlir::Value v); - void mapShadowValue(mlir::Value v, mlir::Value invertValue, - OpBuilder &builder); - void clearValue(mlir::Value v, OpBuilder &builder); - - void setDiffe(mlir::Value val, mlir::Value toset, OpBuilder &BuilderM); Type getIndexType(); Value insertInit(Type t); @@ -75,27 +56,11 @@ class MGradientUtilsReverse : public MDiffeGradientUtils { Type getIndexCacheType(); Value initAndPushCache(Value v, OpBuilder &builder); - // Gradient - Type getGradientType(Value t); - Value insertInitGradient(mlir::Value v, OpBuilder &builder); - - // ShadowedGradient - Type getShadowedGradientType(Value t); - Value insertInitShadowedGradient(mlir::Value v, OpBuilder &builder); - - bool requiresShadow(Type t); - - void initInitializationBlock(IRMapping invertedPointers_, - ArrayRef argDiffeTypes); - - bool onlyUsedInParentBlock(Value v); - Operation *cloneWithNewOperands(OpBuilder &B, Operation *op); Value popCache(Value cache, OpBuilder &builder); - void createReverseModeBlocks(Region &oldFunc, Region &newFunc, - bool isParentRegion = false); + void createReverseModeBlocks(Region &oldFunc, Region &newFunc); static MGradientUtilsReverse * CreateFromClone(MEnzymeLogic &Logic, DerivativeMode mode_, unsigned width, diff --git a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt index 
8a18dc33d195..90e9506d2bb1 100644 --- a/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt +++ b/enzyme/Enzyme/MLIR/Passes/CMakeLists.txt @@ -9,7 +9,7 @@ add_mlir_dialect_library(MLIREnzymeTransforms PrintActivityAnalysis.cpp PrintAliasAnalysis.cpp EnzymeToMemRef.cpp - ShadowedGradientToCache.cpp + SimplifyMath.cpp AddToOpToIndexAndLoad.cpp AddToOpToSplit.cpp RemoveUnusedEnzymeOps.cpp diff --git a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp index e43db26b21a2..de7328dcb186 100644 --- a/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp +++ b/enzyme/Enzyme/MLIR/Passes/EnzymeMLIRPass.cpp @@ -31,6 +31,17 @@ struct DifferentiatePass : public DifferentiatePassBase { void runOnOperation() override; + static DIFFE_TYPE mode_from_fn(FunctionOpInterface fn, DerivativeMode mode) { + DIFFE_TYPE retType = DIFFE_TYPE::CONSTANT; + if (fn.getNumResults() != 0) { + if (mode == DerivativeMode::ReverseModeCombined) + retType = DIFFE_TYPE::OUT_DIFF; + else + retType = DIFFE_TYPE::DUP_ARG; + } + return retType; + } + template LogicalResult HandleAutoDiff(SymbolTableCollection &symbolTable, T CI) { std::vector constants; @@ -60,12 +71,11 @@ struct DifferentiatePass : public DifferentiatePassBase { auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - DIFFE_TYPE retType = - fn.getNumResults() == 0 ? 
DIFFE_TYPE::CONSTANT : DIFFE_TYPE::DUP_ARG; + auto mode = DerivativeMode::ForwardMode; + DIFFE_TYPE retType = mode_from_fn(fn, mode); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); - auto mode = DerivativeMode::ForwardMode; bool freeMemory = true; size_t width = 1; @@ -118,19 +128,18 @@ struct DifferentiatePass : public DifferentiatePassBase { truei++; } - // Add the return gradient - mlir::Value res = CI.getInputs()[CI.getInputs().size() - 1]; - args.push_back(res); - auto *symbolOp = symbolTable.lookupNearestSymbolFrom(CI, CI.getFnAttr()); auto fn = cast(symbolOp); - DIFFE_TYPE retType = - fn.getNumResults() == 0 ? DIFFE_TYPE::CONSTANT : DIFFE_TYPE::DUP_ARG; + auto mode = DerivativeMode::ReverseModeCombined; + DIFFE_TYPE retType = mode_from_fn(fn, mode); + + // Add the return gradient + mlir::Value res = CI.getInputs()[CI.getInputs().size() - 1]; + args.push_back(res); MTypeAnalysis TA; auto type_args = TA.getAnalyzedTypeInfo(fn); - auto mode = DerivativeMode::ReverseModeGradient; bool freeMemory = true; size_t width = 1; diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.h b/enzyme/Enzyme/MLIR/Passes/Passes.h index 80c88373090c..c4eff488c332 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.h +++ b/enzyme/Enzyme/MLIR/Passes/Passes.h @@ -30,7 +30,7 @@ std::unique_ptr createPrintAliasAnalysisPass(); std::unique_ptr createEnzymeToMemRefPass(); -std::unique_ptr createShadowedGradientToCachePass(); +std::unique_ptr createMathematicSimplificationPass(); std::unique_ptr createAddToOpToIndexAndLoadPass(); diff --git a/enzyme/Enzyme/MLIR/Passes/Passes.td b/enzyme/Enzyme/MLIR/Passes/Passes.td index d2a04fc37412..c1617eaf4af0 100644 --- a/enzyme/Enzyme/MLIR/Passes/Passes.td +++ b/enzyme/Enzyme/MLIR/Passes/Passes.td @@ -141,9 +141,9 @@ def EnzymeOpsToMemRefPass : Pass<"convert-enzyme-to-memref"> { let constructor = "mlir::enzyme::createEnzymeToMemRefPass()"; } -def ShadowedGradientToCachePass : Pass<"convert-enzyme-shadowed-gradient-to-cache"> { - let summary = 
"Convert Enzyme Shadowed Gradient to Cache Ops"; - let constructor = "mlir::enzyme::createShadowedGradientToCachePass()"; +def MathematicSimplificationPass : Pass<"enzyme-simplify-math"> { + let summary = "Simplify basic mathematical operations"; + let constructor = "mlir::enzyme::createMathematicSimplificationPass()"; } def AddToOpToIndexAndLoadPass : Pass<"add-to-op-to-index-and-load"> { diff --git a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp index e93334164dd0..5ced93e86512 100644 --- a/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp +++ b/enzyme/Enzyme/MLIR/Passes/RemoveUnusedEnzymeOps.cpp @@ -22,168 +22,284 @@ #include "mlir/Transforms/DialectConversion.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "mlir/IR/Dominance.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; using namespace enzyme; -using llvm::errs; namespace { -// TODO: Expand to region branches?? -bool reachable(Operation *a, Operation *b) { - Block *aBlock = a->getBlock(); - Block *bBlock = b->getBlock(); - if (aBlock == bBlock) { - if (a->isBeforeInBlock(b)) { - return true; - } - } +// Starting at the beginning of blk, is there a path that can execute +// check before end. 
+bool mayExecuteBefore(Block *blk, Operation *check, Operation *end) { + auto reg = blk->getParent(); + assert(reg->isAncestor(end->getParentRegion())); + DenseSet visitedBlocks; + SmallVector blocksToVisit; + for (auto succ : blk->getSuccessors()) { + blocksToVisit.push_back(succ); + } - blocksToVisit.push_back(aBlock); while (!blocksToVisit.empty()) { - Block *processedBlock = blocksToVisit[blocksToVisit.size() - 1]; - blocksToVisit.pop_back(); + Block *cur = blocksToVisit.pop_back_val(); + + if (visitedBlocks.contains(cur)) + continue; + + visitedBlocks.insert(cur); + + bool seenEnd = false; + for (auto &op : *cur) { + + // If we've seen the thing to check with, it may execute before + if (op.isAncestor(check)) { + // The sole exception to this is if they are in the same sub region, + // which is known to execute only once. TODO this later + /* + if (op.isAncestor(end)) { + + for (auto reg2 : op.getRegions()) { + + } + } + */ - for (Block *successor : processedBlock->getSuccessors()) { - if (!visitedBlocks.contains(successor)) { - visitedBlocks.insert(successor); - blocksToVisit.push_back(successor); + return true; + } - if (successor == bBlock) - return true; + // Otherwise if we've seen the end op, this path is over as the route we + // found here didn't first find a check. 
+ if (op.isAncestor(end)) { + seenEnd = true; + break; + } + } + + if (seenEnd) + continue; + + // If we didn't find the end, try all successors + for (auto succ : cur->getSuccessors()) { + blocksToVisit.push_back(succ); + } + } + return false; } -template -Operation *findNearestDominatingOpByUse(Operation *op, Value v) { +bool mayExecuteBetween(Operation *start, Operation *check, Operation *end) { + + for (auto op = start->getNextNode(); op != nullptr; op = op->getNextNode()) { + // This check op has been found after start in its block + if (op->isAncestor(check)) { + return true; + } + if (op->isAncestor(end)) { + return false; + } + } + + Block *blk = start->getBlock(); + + auto reg = blk->getParent(); + if (reg->isAncestor(end->getParentRegion())) { + return mayExecuteBefore(blk, check, end); + } + + // If the check is in the parent op, but the end is not, assume + // we may execute that parent op part before going to any later ops + if (reg->isAncestor(check->getParentRegion())) { + return true; + } + + return mayExecuteBetween(start->getParentOp(), check, end); +} + +// TODO this isn't necessarily correct. This is because there could be a +// non dominating use between the dominating one and the op, causing +// correctness issues when not seen.
In interim, be conservative and only +// succeed if these have the same parent block, and no other ops in path +template +T findNearestDominatingOpByUse(Operation *op, Value v) { DominanceInfo dInfo; + PostDominanceInfo pdInfo; - Operation *closestSetOp = nullptr; + SmallVector options; + SmallVector conflicts; for (Operation *userSet : v.getUsers()) { if (auto setOp = dyn_cast(userSet)) { - if (dInfo.dominates(userSet, op)) { - if (closestSetOp == nullptr) { - closestSetOp = userSet; - } else if (dInfo.dominates(closestSetOp, userSet)) { - closestSetOp = userSet; - } + options.push_back(setOp); + conflicts.push_back(setOp); + continue; + } + if (auto setOp = dyn_cast(userSet)) { + conflicts.push_back(setOp); + continue; + } + } + + for (auto opt : options) { + if (!dInfo.dominates(opt, op)) + continue; + bool conflict = false; + for (auto opt2 : conflicts) { + if (opt == opt2) + continue; + if (opt2 == op) + continue; + + if (!mayExecuteBetween(opt, opt2, op)) { + continue; } + + conflict = true; + } + if (!conflict) { + return opt; } } - return closestSetOp; + + return nullptr; } +struct PopSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PopOp pop, + PatternRewriter &rewriter) const final { + + auto init = pop.getCache().getDefiningOp(); + if (!init) + return failure(); + + SmallVector pops; + SmallVector pushes; + for (Operation *userSet : init.getResult().getUsers()) { + if (auto push = dyn_cast(userSet)) { + pushes.push_back(push); + continue; + } + if (auto pop = dyn_cast(userSet)) { + pops.push_back(pop); + continue; + } + return failure(); + } + + if (auto push = findNearestDominatingOpByUse( + pop, init)) { + // Do the block check to conservatively avoid multi execute push/pop + if (pop->getBlock() == push->getBlock()) { + rewriter.replaceOp(pop, push.getValue()); + rewriter.eraseOp(push); + return success(); + } + } + + return failure(); + } +}; + +struct GetSimplify : public 
OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::GetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + if (isa(userSet)) + continue; + return failure(); + } + + if (auto set = findNearestDominatingOpByUse(get, init)) { + rewriter.replaceOp(get, set.getValue()); + return success(); + } + return failure(); + } +}; + +struct SetSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::SetOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getGradient().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct PushSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::PushOp get, + PatternRewriter &rewriter) const final { + + auto init = get.getCache().getDefiningOp(); + if (!init) + return failure(); + + for (Operation *userSet : init.getResult().getUsers()) { + if (isa(userSet)) + continue; + return failure(); + } + + rewriter.eraseOp(get); + return success(); + } +}; + +struct InitSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(enzyme::InitOp get, + PatternRewriter &rewriter) const final { + + if (get.use_empty()) { + rewriter.eraseOp(get); + return success(); + } + return failure(); + } +}; + struct RemoveUnusedEnzymeOpsPass : public enzyme::RemoveUnusedEnzymeOpsPassBase { void runOnOperation() override { - getOperation()->walk([&](Operation *op) { - DominanceInfo dInfo; - if (auto initOp = dyn_cast(op)) { - Value v = initOp; - if (auto type = 
dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userSet : v.getUsers()) { - if (auto setOp = dyn_cast(userSet)) { - for (Operation *userGet : v.getUsers()) { - if (auto getOp = dyn_cast(userGet)) { - // We can safely delete an enzyme.gradient op if each pair of - // enzyme.set and enzyme.get ops are either not reachable or - // are reachable and do not exist inside a loop - bool relatedButNotInLoop = - dInfo.dominates(userSet, userGet) && - !reachable(getOp, setOp); - bool unrelated = !reachable(setOp, getOp); - if (!(relatedButNotInLoop || unrelated)) { - replaceable = false; - } - } - } - } - } - if (replaceable) { - // Do replacing - for (Operation *userGet : v.getUsers()) { - if (auto getOp = dyn_cast(userGet)) { - Operation *closestSetOp = - findNearestDominatingOpByUse(userGet, v); - auto setOp = dyn_cast(closestSetOp); - getOp.replaceAllUsesWith(setOp.getValue()); - } - } - for (Operation *userGet : v.getUsers()) { - userGet->erase(); - } - op->erase(); - } - } else if (auto type = dyn_cast(initOp.getType())) { - bool replaceable = true; - for (Operation *userPush : v.getUsers()) { - if (auto pushOp = dyn_cast(userPush)) { - // There should only be exactly one push per pop - if (reachable(userPush, userPush)) { - replaceable = false; - } - int numAssociatedPops = 0; - for (Operation *user : v.getUsers()) { - if (auto popOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Pops always need to be dominated by the push - if (dInfo.dominates(userPush, user)) { - numAssociatedPops++; - } else { - replaceable = false; - } - } - } - if (auto getOp = dyn_cast(user)) { - if (reachable(userPush, user)) { - // Gets always need to be dominated by the push - if (!dInfo.dominates(userPush, user)) { - replaceable = false; - } - } - } - } - // There should only be one pop per push - if (numAssociatedPops > 1) { - replaceable = false; - } - } - } - if (replaceable) { - // Do replacing - for (Operation *user : v.getUsers()) { - if (auto 
popOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - popOp.replaceAllUsesWith(pushOp.getValue()); - } - if (auto getOp = dyn_cast(user)) { - Operation *closestPushOp = - findNearestDominatingOpByUse(user, v); - auto pushOp = dyn_cast(closestPushOp); - getOp.replaceAllUsesWith(pushOp.getValue()); - } - } - for (Operation *user : v.getUsers()) { - user->erase(); - } - op->erase(); - } - } - } - }); - }; + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + } }; + } // end anonymous namespace namespace mlir { diff --git a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp b/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp deleted file mode 100644 index d4e374d6c68f..000000000000 --- a/enzyme/Enzyme/MLIR/Passes/ShadowedGradientToCache.cpp +++ /dev/null @@ -1,64 +0,0 @@ -//===- ShadowedGradientToCache.cpp - Lower Shadowed Gradient ops -//------------------ // -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a pass to lower custom ops generated by the Enzyme AD -// procedure to the MemRef dialect. 
-//===----------------------------------------------------------------------===// - -#include "Dialect/Ops.h" -#include "PassDetails.h" -#include "Passes/Passes.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Transforms/DialectConversion.h" - -#include "llvm/Support/raw_ostream.h" - -using namespace mlir; -using namespace enzyme; -using llvm::errs; -namespace { -struct ShadowedGradientToCachePass - : public enzyme::ShadowedGradientToCachePassBase< - ShadowedGradientToCachePass> { - void runOnOperation() override { - getOperation()->walk([&](Operation *op) { - if (auto initOp = dyn_cast(op)) { - if (auto type = - dyn_cast(initOp.getType())) { - Type cacheType = CacheType::get(op->getContext(), type.getBasetype()); - - OpBuilder builder(op); - Value buffer = builder.create(op->getLoc(), cacheType); - - initOp.replaceAllUsesWith(buffer); - initOp->erase(); - } - } - if (auto clearOp = dyn_cast(op)) { - if (auto type = - dyn_cast(clearOp.getCache().getType())) { - OpBuilder builder(op); - builder.create(op->getLoc(), type.getType(), - clearOp.getCache()); - - clearOp->erase(); - } - } - }); - }; -}; -} // end anonymous namespace - -namespace mlir { -namespace enzyme { -std::unique_ptr createShadowedGradientToCachePass() { - return std::make_unique(); -} -} // namespace enzyme -} // namespace mlir diff --git a/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp new file mode 100644 index 000000000000..de04163fea8f --- /dev/null +++ b/enzyme/Enzyme/MLIR/Passes/SimplifyMath.cpp @@ -0,0 +1,88 @@ +//===- SimplifyMath.cpp - Simplify Mathematical operations ------------------ // +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file implements a pass to simplify basic mathematical operations +// produced within the Enzyme AD procedure. +//===----------------------------------------------------------------------===// + +#include "Dialect/Ops.h" +#include "PassDetails.h" +#include "Passes/Passes.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" + +using namespace mlir; +using namespace enzyme; +using llvm::errs; +namespace { + +struct AddSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::AddFOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getRhs()); + return success(); + } + + if (matchPattern(op.getRhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + + return failure(); + } +}; + +struct SubSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(arith::SubFOp op, + PatternRewriter &rewriter) const final { + + if (matchPattern(op.getRhs(), m_AnyZeroFloat())) { + rewriter.replaceOp(op, op.getLhs()); + return success(); + } + + if (matchPattern(op.getLhs(), m_AnyZeroFloat())) { + rewriter.replaceOpWithNewOp(op, op.getRhs()); + return success(); + } + + return failure(); + } +}; + +struct MathematicSimplification + : public enzyme::MathematicSimplificationPassBase< + MathematicSimplification> { + void runOnOperation() override { + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext()); + + GreedyRewriteConfig config; + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config); + }; +}; +} // end anonymous namespace + 
+namespace mlir { +namespace enzyme { +std::unique_ptr createMathematicSimplificationPass() { + return std::make_unique(); +} +} // namespace enzyme +} // namespace mlir diff --git a/enzyme/test/MLIR/Passes/dualpush.mlir b/enzyme/test/MLIR/Passes/dualpush.mlir new file mode 100644 index 000000000000..582feddeabff --- /dev/null +++ b/enzyme/test/MLIR/Passes/dualpush.mlir @@ -0,0 +1,48 @@ +// RUN: %eopt -remove-unnecessary-enzyme-ops %s | FileCheck %s + +// This pop cannot be removed even though we know the first popped value will be -1 +// the other pops will be conditional + +module { + func.func private @diffebbargs(%arg0: f64) { + %c0_i32 = arith.constant 0 : i32 + %c-1_i32 = arith.constant -1 : i32 + %cst = arith.constant 0.000000e+00 : f64 + %3 = "enzyme.init"() : () -> !enzyme.Cache + "enzyme.push"(%3, %c0_i32) : (!enzyme.Cache, i32) -> () + cf.br ^bb1(%arg0 : f64) + ^bb1(%7: f64): // 2 preds: ^bb0, ^bb1 + %8 = arith.cmpf ult, %7, %cst : f64 + "enzyme.push"(%3, %c-1_i32) : (!enzyme.Cache, i32) -> () + cf.cond_br %8, ^bb1(%7 : f64), ^bb4 + ^bb4: // 2 preds: ^bb3, ^bb4 + %18 = "enzyme.pop"(%3) : (!enzyme.Cache) -> i32 + cf.switch %18 : i32, [ + default: ^bb4, + 0: ^bb5 + ] + ^bb5: // pred: ^bb4 + return + } +} + +// CHECK: func.func private @diffebbargs(%arg0: f64) { +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %cst = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: "enzyme.push"(%0, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%arg0 : f64) +// CHECK-NEXT: ^bb1(%1: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %2 = arith.cmpf ult, %1, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %2, ^bb1(%1 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // 2 preds: ^bb1, ^bb2 +// CHECK-NEXT: %3 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: cf.switch %3 : i32, [ // 
CHECK-NEXT: default: ^bb2, +// CHECK-NEXT: 0: ^bb3 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb3: // pred: ^bb2 +// CHECK-NEXT: return +// CHECK-NEXT: } \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir index e5bb39eea040..141ff46aaade 100644 --- a/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir +++ b/enzyme/test/MLIR/ReverseMode/bbarg-order.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -canonicalize %s | FileCheck %s module { func.func @bbargs(%x: f64) -> f64 { @@ -19,15 +19,38 @@ module { } } -// CHECK: func.func @diff(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { -// CHECK-NEXT: %[[i0:.+]] = call @diffebbargs(%[[arg0]], %[[arg1]]) : (f64, f64) -> f64 -// CHECK-NEXT: return %[[i0:.+]] -// CHECK: func.func private @diffebbargs(%[[arg0:.+]]: f64, %[[arg1:.+]]: f64) -> f64 { - -// There should be exactly one block with two f64 args, and their values should be accumulated -// in the shadow. 
-// CHECK: ^[[BBMULTI:.+]](%[[fst:.+]]: f64, %[[snd:.+]]: f64): -// CHECK-NEXT: "enzyme.set"(%[[shadow:.+]], %[[fst]]) -// CHECK-NEXT: %[[before:.+]] = "enzyme.get"(%[[shadow]]) -// CHECK-NEXT: %[[after:.+]] = arith.addf %[[snd]], %[[before]] -// CHECK-NEXT: "enzyme.set"(%[[shadow]], %[[after]]) +// CHECK: func.func private @diffebbargs(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %c-1_i32 = arith.constant -1 : i32 +// CHECK-NEXT: %c0_i32 = arith.constant 0 : i32 +// CHECK-NEXT: %cst = arith.constant 1.000000e+00 : f64 +// CHECK-NEXT: %cst_0 = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %0 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %1 = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %2 = arith.addf %arg0, %cst : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c0_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.br ^bb1(%2 : f64) +// CHECK-NEXT: ^bb1(%3: f64): // 2 preds: ^bb0, ^bb1 +// CHECK-NEXT: %4 = arith.cmpf ult, %3, %cst_0 : f64 +// CHECK-NEXT: "enzyme.push"(%1, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: "enzyme.push"(%0, %c-1_i32) : (!enzyme.Cache, i32) -> () +// CHECK-NEXT: cf.cond_br %4, ^bb1(%3 : f64), ^bb2 +// CHECK-NEXT: ^bb2: // pred: ^bb1 +// CHECK-NEXT: %5 = arith.addf %arg1, %cst_0 : f64 +// CHECK-NEXT: %6 = "enzyme.pop"(%0) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %7 = arith.cmpi eq, %6, %c-1_i32 : i32 +// CHECK-NEXT: %8 = arith.select %7, %5, %cst_0 : f64 +// CHECK-NEXT: %9 = arith.addf %8, %cst_0 : f64 +// CHECK-NEXT: cf.br ^bb3 +// CHECK-NEXT: ^bb3: // 2 preds: ^bb2, ^bb3 +// CHECK-NEXT: %10 = "enzyme.pop"(%1) : (!enzyme.Cache) -> i32 +// CHECK-NEXT: %11 = arith.cmpi eq, %10, %c-1_i32 : i32 +// CHECK-NEXT: %12 = arith.select %11, %9, %cst_0 : f64 +// CHECK-NEXT: %13 = arith.addf %12, %cst_0 : f64 +// CHECK-NEXT: cf.switch %10 : i32, [ +// CHECK-NEXT: default: ^bb3, +// CHECK-NEXT: 0: ^bb4 +// CHECK-NEXT: ] +// CHECK-NEXT: ^bb4: // pred: ^bb3 +// CHECK-NEXT: %14 = arith.addf %13, %cst_0 : f64 +// CHECK-NEXT: 
return %14 : f64 +// CHECK-NEXT: } diff --git a/enzyme/test/MLIR/ReverseMode/pow.mlir b/enzyme/test/MLIR/ReverseMode/pow.mlir index 5c5596ec389e..9934152def61 100644 --- a/enzyme/test/MLIR/ReverseMode/pow.mlir +++ b/enzyme/test/MLIR/ReverseMode/pow.mlir @@ -1,4 +1,4 @@ -// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme -canonicalize --remove-unnecessary-enzyme-ops -enzyme-simplify-math -canonicalize %s | FileCheck %s module { func.func @ppow(%x: f64) -> f64 { @@ -19,29 +19,46 @@ module { } } -// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 +// CHECK: func.func private @diffeppow(%[[x:.+]]: f64, %[[dr:.+]]: f64) -> f64 { +// CHECK-NEXT: %c10 = arith.constant 10 : index +// CHECK-NEXT: %c1 = arith.constant 1 : index +// CHECK-NEXT: %c0 = arith.constant 0 : index +// CHECK-NEXT: %[[one:.+]] = arith.constant 1.0 +// CHECK-NEXT: %[[zero:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: %[[xshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[itshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[xcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rcache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rshadow:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () -// Make sure the right values are being cached in the primal -// CHECK: %[[one:.+]] = arith.constant 1.0 -// CHECK: scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -// CHECK-NEXT: "enzyme.push"(%[[rcache:.+]], %[[r_it]]) -// CHECK-NEXT: "enzyme.push"(%[[xcache:.+]], %[[x]]) - -// Ensure the right value is yielded in the adjoint -// CHECK: "enzyme.set"(%[[rshadow:.+]], %[[dr]]) -// CHECK: %[[dr:.+]] = 
"enzyme.get"(%[[rshadow]]) -// CHECK: scf.for %[[iv:.+]] = %[[lb:.+]] to %[[ub:.+]] step %[[step:.+]] iter_args(%[[dr_it:.+]] = %[[dr]]) -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_it]]) -// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) -// CHECK-NEXT: %[[x:.+]] = "enzyme.pop"(%[[xcache]]) -// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x]] -// CHECK-NEXT: "enzyme.set"(%[[rshadow:.+]], %[[dr_next]]) -// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] -// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow:.+]]) : -// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] -// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) -// CHECK-NEXT: %[[dr_next:.+]] = "enzyme.get"(%[[rshadow]]) -// CHECK-NEXT: scf.yield %[[dr_next]] -// CHECK: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) -// CHECK-NEXT: return %[[final]] +// CHECK-NEXT: %{{.+}} = scf.for %[[iv:.+]] = %c0 to %c10 step %c1 iter_args(%[[r_it:.+]] = %[[one]]) -> (f64) { +// CHECK-NEXT: "enzyme.push"(%[[rcache]], %[[r_it]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[xcache]], %[[x]]) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[fwd:.+]] = arith.mulf %[[r_it]], %[[x]] : f64 +// CHECK-NEXT: scf.yield %[[fwd]] : f64 +// CHECK-NEXT: } +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[dr]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: scf.for %[[div:.+]] = %c0 to %c10 step %c1 { +// CHECK-NEXT: %[[dr_it:.+]] = "enzyme.get"(%[[rshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: "enzyme.set"(%[[rshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[r_cached:.+]] = "enzyme.pop"(%[[rcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[x_cached:.+]] = "enzyme.pop"(%[[xcache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dr_next:.+]] = arith.mulf %[[dr_it]], %[[x_cached]] +// CHECK-NEXT: %[[previts:.+]] = "enzyme.get"(%[[itshadow]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: 
%[[postits:.+]] = arith.addf %[[previts]], %[[dr_next]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[postits]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[dx_next:.+]] = arith.mulf %[[dr_it]], %[[r_cached]] : f64 +// CHECK-NEXT: %[[dx0:.+]] = "enzyme.get"(%[[xshadow]]) : +// CHECK-NEXT: %[[dx1:.+]] = arith.addf %[[dx0]], %[[dx_next]] +// CHECK-NEXT: "enzyme.set"(%[[xshadow]], %[[dx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[divp1:.+]] = arith.addi %[[div]], %c1 : index +// CHECK-NEXT: %[[last:.+]] = arith.cmpi sge, %[[divp1]], %c10 : index +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[zero]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[sel:.+]] = arith.select %[[last]], %[[zero]], %12 : f64 +// CHECK-NEXT: "enzyme.set"(%[[itshadow]], %[[sel]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: } +// CHECK-NEXT: %[[final:.+]] = "enzyme.get"(%[[xshadow]]) +// CHECK-NEXT: return %[[final]] \ No newline at end of file diff --git a/enzyme/test/MLIR/ReverseMode/square.mlir b/enzyme/test/MLIR/ReverseMode/square.mlir new file mode 100644 index 000000000000..4bcae3bb8000 --- /dev/null +++ b/enzyme/test/MLIR/ReverseMode/square.mlir @@ -0,0 +1,75 @@ +// RUN: %eopt --enzyme %s | FileCheck %s +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops %s | FileCheck %s --check-prefix=REM +// RUN: %eopt --enzyme --canonicalize --remove-unnecessary-enzyme-ops --canonicalize --enzyme-simplify-math --cse %s | FileCheck %s --check-prefix=FIN + +module { + func.func @square(%x: f64) -> f64 { + %next = arith.mulf %x, %x : f64 + return %next : f64 + } + + func.func @dsquare(%x: f64, %dr: f64) -> f64 { + %r = enzyme.autodiff @square(%x, %dr) { activity=[#enzyme] } : (f64, f64) -> f64 + return %r : f64 + } +} + + +// CHECK: func.func @dsquare(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %0 = call @diffesquare(%arg0, %arg1) : (f64, f64) -> f64 +// CHECK-NEXT: return %0 : f64 +// CHECK-NEXT: } + +// CHECK: func.func private 
@diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// CHECK-NEXT: %[[dx:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: %[[c0:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[c0]]) : (!enzyme.Gradient, f64) -> () + +// CHECK-NEXT: %[[lhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache +// CHECK-NEXT: %[[rhscache:.+]] = "enzyme.init"() : () -> !enzyme.Cache + +// CHECK-NEXT: %[[dy:.+]] = "enzyme.init"() : () -> !enzyme.Gradient +// CHECK-NEXT: %[[c1:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[c1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[rhscache]], %arg0) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: "enzyme.push"(%[[lhscache]], %arg0) : (!enzyme.Cache, f64) -> () +// CHECK-NEXT: %[[mul:.+]] = arith.mulf %arg0, %arg0 : f64 +// CHECK-NEXT: cf.br ^bb1 + +// CHECK: ^bb1: // pred: ^bb0 +// CHECK-NEXT: %[[prevdret0:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdret0:.+]] = arith.addf %[[prevdret0]], %arg1 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[postdret0]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[prevdret:.+]] = "enzyme.get"(%[[dy]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[c2:.+]] = arith.constant 0.000000e+00 : f64 +// CHECK-NEXT: "enzyme.set"(%[[dy]], %[[c2]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[postlhs:.+]] = "enzyme.pop"(%[[rhscache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[postrhs:.+]] = "enzyme.pop"(%[[lhscache]]) : (!enzyme.Cache) -> f64 +// CHECK-NEXT: %[[dlhs:.+]] = arith.mulf %[[prevdret]], %[[postrhs]] : f64 +// CHECK-NEXT: %[[prevdx1:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdx1:.+]] = arith.addf %[[prevdx1]], %[[dlhs]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[postdx1]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[drhs:.+]] = arith.mulf %[[prevdret]], %[[postlhs]] : f64 +// CHECK-NEXT: %[[prevdx2:.+]] = 
"enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: %[[postdx2:.+]] = arith.addf %[[prevdx2]], %[[drhs]] : f64 +// CHECK-NEXT: "enzyme.set"(%[[dx]], %[[postdx2]]) : (!enzyme.Gradient, f64) -> () +// CHECK-NEXT: %[[res:.+]] = "enzyme.get"(%[[dx]]) : (!enzyme.Gradient) -> f64 +// CHECK-NEXT: return %[[res]] : f64 +// CHECK-NEXT: } + + +// REM: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// REM-NEXT: %[[cst:.+]] = arith.constant 0.000000e+00 : f64 +// REM-NEXT: %[[a1:.+]] = arith.addf %arg1, %[[cst]] : f64 +// REM-NEXT: %[[a2:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a3:.+]] = arith.addf %[[a2]], %[[cst]] : f64 +// REM-NEXT: %[[a4:.+]] = arith.mulf %[[a1]], %arg0 : f64 +// REM-NEXT: %[[a5:.+]] = arith.addf %[[a3]], %[[a4]] : f64 +// REM-NEXT: return %[[a5]] : f64 +// REM-NEXT: } + +// FIN: func.func private @diffesquare(%arg0: f64, %arg1: f64) -> f64 { +// FIN-NEXT: %0 = arith.mulf %arg1, %arg0 : f64 +// FIN-NEXT: %1 = arith.addf %0, %0 : f64 +// FIN-NEXT: return %1 : f64 +// FIN-NEXT: } \ No newline at end of file diff --git a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp index 143c85ea684e..5f456c8755c6 100644 --- a/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/enzyme-tblgen.cpp @@ -1811,7 +1811,7 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, if (hasDiffeRet(resultTree)) { if (intrinsic == MLIRDerivatives) { os << " dif = gutils->diffe(" << origName << ", builder);\n"; - os << " gutils->clearValue(" << origName << ", builder);\n"; + os << " gutils->zeroDiffe(" << origName << ", builder);\n"; } else { os << " dif = diffe(&" << origName << ", Builder2);\n"; os << " setDiffe(&" << origName @@ -2040,6 +2040,8 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, const auto &brpatterns = recordKeeper.getAllDerivedDefinitions("BranchOp"); + const auto &retpatterns = 
recordKeeper.getAllDerivedDefinitions("ReturnOp"); + const auto ®tpatterns = recordKeeper.getAllDerivedDefinitions("RegionTerminatorOp"); @@ -2092,6 +2094,12 @@ static void emitDerivatives(const RecordKeeper &recordKeeper, raw_ostream &os, os << " registerAutoDiffUsingRegionTerminatorInterface<" << dialect << "::" << opName << ">(*context);\n"; } + for (Record *pattern : retpatterns) { + auto opName = pattern->getValueAsString("opName"); + auto dialect = pattern->getValueAsString("dialect"); + os << " registerAutoDiffUsingReturnInterface<" << dialect + << "::" << opName << ">(*context);\n"; + } for (Record *pattern : allocpatterns) { auto opName = pattern->getValueAsString("opName"); auto dialect = pattern->getValueAsString("dialect");