Skip to content

Commit

Permalink
Cleanup and Fixup MLIR reverse mode (#1771)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Mar 1, 2024
1 parent b97aa9d commit b96c443
Show file tree
Hide file tree
Showing 33 changed files with 1,101 additions and 908 deletions.
18 changes: 18 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,23 @@ gentbl(
],
)

gentbl(
name = "func-derivatives",
tbl_outs = [(
"-gen-mlir-derivatives",
"Enzyme/MLIR/Implementations/FuncDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/MLIR/Implementations/FuncDerivatives.td",
td_srcs = [
"Enzyme/MLIR/Implementations/FuncDerivatives.td",
"Enzyme/MLIR/Implementations/Common.td",
],
deps = [
":enzyme-tblgen",
],
)

cc_library(
name = "EnzymeMLIR",
srcs = glob([
Expand Down Expand Up @@ -582,6 +599,7 @@ cc_library(
":arith-derivatives",
":cf-derivatives",
":llvm-derivatives",
":func-derivatives",
":math-derivatives",
":memref-derivatives",
":nvvm-derivatives",
Expand Down
42 changes: 14 additions & 28 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,6 @@ def PopOp : Enzyme_Op<"pop"> {
let results = (outs AnyType:$output);
}

def ClearOp : Enzyme_Op<"clear"> {
let summary = "Remove top element from ShadowedGradient";
let arguments = (ins AnyType : $cache);
let results = (outs );
}

def InitOp : Enzyme_Op<"init"> {
let summary = "Creat enzyme.gradient and enzyme.cache";
let arguments = (ins );
Expand All @@ -105,36 +99,28 @@ def Cache : Enzyme_Type<"Cache"> {
let assemblyFormat = "`<` $type `>`";
}

def SetOp : Enzyme_Op<"set"> {
let summary = "Write to gradient";
let arguments = (ins AnyType : $gradient, AnyType : $value);
let results = (outs );
}

def GetOp : Enzyme_Op<"get"> {
let summary = "Load value of gradient";
let arguments = (ins AnyType : $gradient);
let results = (outs AnyType);
}

def Gradient : Enzyme_Type<"Gradient"> {
let summary = "Stores gradient if it cant be stroed in a value.";
let summary = "Mutable storage for accumulating gradients";
let description = [{
"Cache for reverse pass"
Mutable storage for accumulating derivatives of immutable types (e.g. adding all the partial derivatives from users of a float64)
}];
let parameters = (ins "Type":$basetype);
let mnemonic = "Gradient";
let assemblyFormat = "`<` $basetype `>`";
}

def ShadowedGradient : Enzyme_Type<"ShadowedGradient"> {
let summary = "Stores gradients which need to be initialized with shadow values from the forward pass.";
let description = [{
"Cache for reverse pass"
}];
let parameters = (ins "Type":$basetype);
let mnemonic = "ShadowedGradient";
let assemblyFormat = "`<` $basetype `>`";
def SetOp : Enzyme_Op<"set"> {
let summary = "Store the current value of the gradient";
let arguments = (ins Arg<AnyType, "the reference to store to",
[MemWrite]>:$gradient, AnyType : $value);
let results = (outs );
}

def GetOp : Enzyme_Op<"get"> {
let summary = "Load current value of gradient";
let arguments = (ins Arg<AnyType, "the reference to load from",
[MemRead]>:$gradient);
let results = (outs AnyType);
}

def AddToOp : Enzyme_Op<"addTo", [Pure, Terminator, ReturnLike]>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ class FloatTypeInterface
return self;
}

bool requiresShadow(Type self) const { return false; }
bool isMutable(Type self) const { return false; }
LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
return failure();
Expand Down Expand Up @@ -77,7 +77,7 @@ class TensorTypeInterface
return self;
}

bool requiresShadow(Type self) const { return false; }
bool isMutable(Type self) const { return false; }
LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
return failure();
Expand Down Expand Up @@ -105,7 +105,7 @@ class IntegerTypeInterface
return self;
}

bool requiresShadow(Type self) const { return false; }
bool isMutable(Type self) const { return false; }
LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
return failure();
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Implementations/CFDerivatives.td
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
include "Common.td"

def : BranchOp<"cf", "CondBranchOp">;
def : BranchOp<"cf", "BranchOp">;
def : BranchOp<"cf", "SwitchOp">;
7 changes: 7 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,18 @@ set(LLVM_TARGET_DEFINITIONS MathDerivatives.td)
enzyme_tablegen(MathDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(MathDerivativesIncGen)

set(LLVM_TARGET_DEFINITIONS FuncDerivatives.td)
enzyme_tablegen(FuncDerivatives.inc -gen-mlir-derivatives)
add_public_tablegen_target(FuncDerivativesIncGen)

add_mlir_library(MLIREnzymeImplementations
AffineAutoDiffOpInterfaceImpl.cpp
ArithAutoDiffOpInterfaceImpl.cpp
CoreDialectsAutoDiffImplementations.cpp
LLVMAutoDiffOpInterfaceImpl.cpp
NVVMAutoDiffOpInterfaceImpl.cpp
MemRefAutoDiffOpInterfaceImpl.cpp
FuncAutoDiffOpInterfaceImpl.cpp
LinalgAutoDiffOpInterfaceImpl.cpp
BuiltinAutoDiffTypeInterfaceImpl.cpp
SCFAutoDiffOpInterfaceImpl.cpp
Expand All @@ -48,6 +53,7 @@ add_mlir_library(MLIREnzymeImplementations
AffineDerivativesIncGen
ArithDerivativesIncGen
LLVMDerivativesIncGen
FuncDerivativesIncGen
NVVMDerivativesIncGen
SCFDerivativesIncGen
CFDerivativesIncGen
Expand All @@ -56,6 +62,7 @@ add_mlir_library(MLIREnzymeImplementations

LINK_LIBS PUBLIC
MLIRArithDialect
MLIRFuncDialect
MLIRLLVMDialect
MLIRMemRefDialect
MLIREnzymeAutoDiffInterface
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/Common.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ class MemoryIdentityOp<string dialect_, string opName_, list<int> ptrargs_, list

class ReadOnlyIdentityOp<string dialect_, string opName_, list<int> ptrargs_> : MemoryIdentityOp<dialect_, opName_, ptrargs_>;

class ReturnOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
}

class BranchOp<string dialect_, string opName_> {
string dialect = dialect_;
string opName = opName_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"

using namespace mlir;
using namespace mlir::enzyme;
Expand Down Expand Up @@ -143,8 +144,7 @@ LogicalResult mlir::enzyme::detail::memoryIdentityForwardHandler(
if (contains(storedVals, operand.getOperandNumber())) {
if (auto iface =
dyn_cast<AutoDiffTypeInterface>(operand.get().getType())) {
if (!iface.requiresShadow()) {
// TODO only do if mutable
if (!iface.isMutable()) {
Type retTy = iface.getShadowType();
auto toret = retTy.cast<AutoDiffTypeInterface>().createNullValue(
builder, operand.get().getLoc());
Expand Down Expand Up @@ -201,6 +201,29 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler(
return success();
}

void mlir::enzyme::detail::returnReverseHandler(Operation *op,
OpBuilder &builder,
MGradientUtilsReverse *gutils) {
size_t num_out = 0;
for (auto act : gutils->RetDiffeTypes) {
if (act == DIFFE_TYPE::OUT_DIFF)
num_out++;
}

size_t idx = 0;
auto args = gutils->newFunc->getRegions().begin()->begin()->getArguments();

for (auto &&[op, act] : llvm::zip(op->getOperands(), gutils->RetDiffeTypes)) {
if (act == DIFFE_TYPE::OUT_DIFF) {
if (!gutils->isConstantValue(op)) {
auto d_out = args[args.size() - num_out + idx];
gutils->addToDiffe(op, d_out, builder);
}
idx++;
}
}
}

void mlir::enzyme::detail::regionTerminatorForwardHandler(
Operation *origTerminator, OpBuilder &builder, MGradientUtils *gutils) {
auto parentOp = origTerminator->getParentOp();
Expand Down Expand Up @@ -401,4 +424,5 @@ void mlir::enzyme::registerCoreDialectAutodiffInterfaces(
enzyme::registerSCFDialectAutoDiffInterface(registry);
enzyme::registerCFDialectAutoDiffInterface(registry);
enzyme::registerLinalgDialectAutoDiffInterface(registry);
enzyme::registerFuncDialectAutoDiffInterface(registry);
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ void branchingForwardHandler(Operation *op, OpBuilder &builder,
void regionTerminatorForwardHandler(Operation *op, OpBuilder &builder,
MGradientUtils *gutils);

// Implements reverse-mode differentiation of return operations.
void returnReverseHandler(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils);

// Implements forward-mode differentiation of read-only (including read-none)
// operations which do not perform computation
LogicalResult memoryIdentityForwardHandler(Operation *op, OpBuilder &builder,
Expand Down Expand Up @@ -104,6 +108,44 @@ class AutoDiffUsingRegionTerminator
}
};

template <typename OpTy>
class NoopRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
NoopRevAutoDiffInterface<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
return SmallVector<Value>();
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {}
};

template <typename OpTy>
class ReturnRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
ReturnRevAutoDiffInterface<OpTy>, OpTy> {
public:
void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
returnReverseHandler(op, builder, gutils);
}

SmallVector<Value> cacheValues(Operation *op,
MGradientUtilsReverse *gutils) const {
return SmallVector<Value>();
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {}
};

// Implements the forward autodiff interface for operations which are
// read only and identity like (aka not computing sin of mem read).
template <typename OpTy, int... storedvals>
Expand Down Expand Up @@ -166,12 +208,24 @@ void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) {
template <typename OpTy>
void registerAutoDiffUsingBranchInterface(MLIRContext &context) {
OpTy::template attachInterface<detail::AutoDiffUsingBranch<OpTy>>(context);
OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(
context);
}
// Registers AutoDiffUsingRegionTerminator for the given op.
template <typename OpTy>
void registerAutoDiffUsingRegionTerminatorInterface(MLIRContext &context) {
OpTy::template attachInterface<detail::AutoDiffUsingRegionTerminator<OpTy>>(
context);
OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(
context);
}
// Registers registerAutoDiffUsingReturnInterface for the given op.
template <typename OpTy>
void registerAutoDiffUsingReturnInterface(MLIRContext &context) {
OpTy::template attachInterface<detail::AutoDiffUsingRegionTerminator<OpTy>>(
context);
OpTy::template attachInterface<detail::ReturnRevAutoDiffInterface<OpTy>>(
context);
}
// Registers AutoDiffUsingMemoryIdentity for the given op.
template <typename OpTy, int... storedvals>
Expand Down Expand Up @@ -199,6 +253,7 @@ void registerSCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
void registerMathDialectAutoDiffInterface(DialectRegistry &registry);
void registerFuncDialectAutoDiffInterface(DialectRegistry &registry);

void registerCoreDialectAutodiffInterfaces(DialectRegistry &registry);

Expand Down
37 changes: 37 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===- FuncAutoDiffOpInterfaceImpl.cpp - Interface external model --------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR arithmetic dialect.
//
//===----------------------------------------------------------------------===//

#include "Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"

#include "Dialect/Ops.h"
#include "mlir/IR/TypeSupport.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {
#include "Implementations/FuncDerivatives.inc"
} // namespace

void mlir::enzyme::registerFuncDialectAutoDiffInterface(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *context, func::FuncDialect *) {
registerInterfaces(context);
});
}
3 changes: 3 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/FuncDerivatives.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
include "Common.td"

def : ReturnOp<"func", "ReturnOp">;
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class PointerTypeInterface
return self;
}

bool requiresShadow(Type self) const { return true; }
bool isMutable(Type self) const { return true; }

LogicalResult zeroInPlace(Type self, OpBuilder &builder, Location loc,
Value val) const {
Expand Down
Loading

0 comments on commit b96c443

Please sign in to comment.