Skip to content

Commit

Permalink
Added reverse mode adjoint for linalg.generic (#1134)
Browse files Browse the repository at this point in the history
* Added reverse mode adjoint for linalg.generic

* clang-format

* Only allocate shadow memory in initialization block for enzyme_out arguments

* added simplify-memref-cache pass to enable lowering of linalg.generic adjoint

* Addressing some comments by @ftysne

* Added genericOpInterface to all LinalgStructuredOps

* Added linalg.generic adjoint to all linalg
structured ops.

* Addressing code review

* Added some documentation

* Addressing comments

* Fix errors with migration to MLIR 16.0.5

* Added support for aliasing gradients

* Clang Format

* Fixing bugs with higher dimensional memrefs

* Add passes to legalize AddTo op

* Added Enzyme GenericAdjointOp

* Cleanup

---------

Co-authored-by: Jacob Peng <jacobmpeng@gmail.com>
  • Loading branch information
umatin and pengmai authored Jul 11, 2023
1 parent 2d2fdff commit 3d35e80
Show file tree
Hide file tree
Showing 21 changed files with 945 additions and 37 deletions.
34 changes: 32 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -6,22 +6,27 @@
//
//===----------------------------------------------------------------------===//

//include "mlir/Dialect/Linalg/IR/LinalgBase.td"

#ifndef ENZYME_OPS
#define ENZYME_OPS

include "Dialect.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/IR/EnumAttr.td"

include "mlir/IR/OpBase.td"
include "mlir/IR/FunctionInterfaces.td"
include "mlir/IR/SymbolInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"

include "mlir/IR/AttrTypeBase.td"

include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"


def Activity : I32EnumAttr<"Activity",
"Possible activity states for variables",
[
Expand Down Expand Up @@ -132,4 +137,29 @@ def ShadowedGradient : Enzyme_Type<"ShadowedGradient"> {
let assemblyFormat = "`<` $basetype `>`";
}

def AddToOp : Enzyme_Op<"addTo", [Pure, Terminator, ReturnLike]>,
    Arguments<(ins Variadic<AnyType>:$values)> {
  let summary = "Linalg add to operation";
  let description = [{
    Terminator for the body region of `enzyme.genericAdjoint`. Unlike
    `linalg.yield`, which overwrites the output element, each operand of
    `enzyme.addTo` is intended to be accumulated (added) into the
    corresponding output of the enclosing op, as required when propagating
    partial derivatives in reverse-mode AD. Dedicated passes legalize this
    op into standard constructs after differentiation.
  }];
  let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
  let hasCustomAssemblyFormat = 0;
  let hasVerifier = 0;
}

def GenericAdjointOp : Enzyme_Op<"genericAdjoint", [AttrSizedOperandSegments]> {
  let summary = "Reverse-mode adjoint of a linalg.generic operation";
  let description = [{
    Structurally mirrors `linalg.generic`: variadic inputs and shaped
    outputs, an array of indexing maps, iterator types, and an optional
    doc/library_call pair, with the computation carried in a single region.
    It is produced by Enzyme's reverse-mode differentiation of linalg
    structured ops; its region is the differentiated body of the original
    op and is terminated by `enzyme.addTo`, which accumulates the computed
    partial derivatives into the outputs rather than overwriting them.
  }];

  let arguments = (ins Variadic<AnyType>:$inputs,
                       Variadic<AnyShaped>:$outputs,
                       AffineMapArrayAttr:$indexing_maps,
                       ArrayAttr:$iterator_types,
                       OptionalAttr<StrAttr>:$doc,
                       OptionalAttr<StrAttr>:$library_call);
  let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
  let regions = (region AnyRegion:$region);

}

#endif // ENZYME_OPS
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/MLIR/Dialect/Ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

#include "mlir/IR/DialectImplementation.h"
Expand Down
2 changes: 2 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_mlir_library(MLIREnzymeImplementations
ArithAutoDiffOpInterfaceImpl.cpp
LLVMAutoDiffOpInterfaceImpl.cpp
MemRefAutoDiffOpInterfaceImpl.cpp
LinalgAutoDiffOpInterfaceImpl.cpp
BuiltinAutoDiffTypeInterfaceImpl.cpp
SCFAutoDiffOpInterfaceImpl.cpp

Expand All @@ -15,4 +16,5 @@ add_mlir_library(MLIREnzymeImplementations
MLIREnzymeAutoDiffInterface
MLIRIR
MLIRSCFDialect
MLIRLinalgDialect
)
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,6 @@ void registerBuiltinDialectAutoDiffInterface(DialectRegistry &registry);
void registerLLVMDialectAutoDiffInterface(DialectRegistry &registry);
void registerMemRefDialectAutoDiffInterface(DialectRegistry &registry);
void registerSCFDialectAutoDiffInterface(DialectRegistry &registry);
void registerLinalgDialectAutoDiffInterface(DialectRegistry &registry);
} // namespace enzyme
} // namespace mlir
284 changes: 284 additions & 0 deletions enzyme/Enzyme/MLIR/Implementations/LinalgAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
//===- LinalgAutoDiffOpInterfaceImpl.cpp - Interface external model -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains the external model implementation of the automatic
// differentiation op interfaces for the upstream MLIR linalg dialect.
//
//===----------------------------------------------------------------------===//

#include "Implementations/CoreDialectsAutoDiffImplementations.h"
#include "Interfaces/AutoDiffOpInterface.h"
#include "Interfaces/AutoDiffTypeInterface.h"
#include "Interfaces/GradientUtils.h"
#include "Interfaces/GradientUtilsReverse.h"

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"

#include "mlir/IR/DialectRegistry.h"
#include "mlir/Support/LogicalResult.h"
#include <functional>

#include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include "mlir/Dialect/Shape/IR/ShapeOpsTypes.h.inc"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;
using namespace mlir::enzyme;

namespace {

// Builds a subview of `inp` that visits the buffer in reverse: for every
// dimension the offset is size-1 and the stride is -1, so element i maps to
// element size-1-i. Used so the adjoint loop nest can read/write gradients
// in the opposite iteration order of the forward op.
Value invertMemref(Value inp, OpBuilder &builder, Location loc) {
  auto memrefTy = cast<MemRefType>(inp.getType());
  Value minusOne = builder.create<arith::ConstantIndexOp>(loc, -1);

  SmallVector<Value> offsets, sizes, strides;
  int rank = memrefTy.getShape().size();
  for (int d = 0; d < rank; ++d) {
    Value size = builder.create<memref::DimOp>(loc, inp, d);
    // Offset of size-1 paired with stride -1 walks the dimension backwards.
    Value lastIndex = builder.create<arith::AddIOp>(loc, size, minusOne);
    offsets.push_back(lastIndex);
    sizes.push_back(size);
    strides.push_back(minusOne);
  }
  return builder.create<memref::SubViewOp>(loc, inp, ValueRange(offsets),
                                           ValueRange(sizes),
                                           ValueRange(strides));
}

// Unpacks the op's `indexing_maps` array attribute into a vector of
// AffineMaps (the counterpart of LinalgOp::getIndexingMapsArray for the
// enzyme.genericAdjoint op, which does not implement that interface).
SmallVector<AffineMap> getIndexingMapsArray(enzyme::GenericAdjointOp &op) {
  SmallVector<AffineMap> maps;
  for (Attribute mapAttr : op.getIndexingMapsAttr().getValue())
    maps.push_back(mapAttr.cast<AffineMapAttr>().getValue());
  return maps;
}

/// External model attaching reverse-mode AD support to a linalg structured
/// op type T_ (instantiated for every structured op by attachAllInterfaces).
template <typename T_>
struct GenericOpInterfaceReverse
    : public ReverseAutoDiffOpInterface::ExternalModel<
          GenericOpInterfaceReverse<T_>, T_> {
  // Emits the adjoint of `op` as an enzyme.genericAdjoint whose region is
  // the differentiated body of the (generalized) forward op. Values the
  // adjoint body needs from the forward pass are cached by appending extra
  // outputs to the forward op and extra inputs to the adjoint.
  void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
                                MGradientUtilsReverse *gutils,
                                SmallVector<Value> caches) const {
    auto linalgOp = cast<linalg::LinalgOp>(op);
    assert(linalgOp.hasBufferSemantics() &&
           "Linalg op with tensor semantics not yet supported");

    linalg::LinalgOp newOp =
        cast<linalg::LinalgOp>(gutils->getNewFromOriginal(linalgOp));

    // Replace the op by a linalg.generic op if necessary, so that the rest
    // of the logic can work on a uniform representation with an explicit
    // region.
    // TODO : IRRewriter rewriter(builder.getContext()/*,
    // builder.getListener()*/);
    ConversionPatternRewriter rewriter(builder.getContext());
    auto failureOrLinalgOp = generalizeNamedOp(rewriter, newOp);
    if (!failed(failureOrLinalgOp)) {
      linalg::GenericOp replacement = failureOrLinalgOp.value();
      auto scope = OpBuilder::InsertionGuard(builder);
      builder.setInsertionPointAfter(newOp);
      builder.insert(replacement);
      newOp.erase();
      newOp = replacement;
    }

    // Builder placed right before the forward op; used for cache
    // allocations that must dominate it.
    auto cacheBuilder = OpBuilder(newOp, builder.getListener());

    // Calculate the iteration domain: materialize the runtime size of every
    // operand dimension, then map them through the op's shapes-to-loops
    // affine map to get one extent per loop.
    AffineMap aMap = newOp.getShapesToLoopsMap();
    SmallVector<Value> dims;
    for (OpOperand *input : newOp.getDpsInputOperands()) {
      auto shape = cast<MemRefType>(input->get().getType()).getShape();
      for (unsigned i = 0; i < shape.size(); i++) {
        auto dimI =
            cacheBuilder.create<arith::ConstantIndexOp>(op->getLoc(), i);
        auto dim = cacheBuilder.create<memref::DimOp>(op->getLoc(),
                                                      input->get(), dimI);
        dims.push_back(dim);
      }
    }
    for (OpOperand *output : newOp.getDpsInitOperands()) {
      auto shape = cast<MemRefType>(output->get().getType()).getShape();
      for (unsigned i = 0; i < shape.size(); i++) {
        auto dimI =
            cacheBuilder.create<arith::ConstantIndexOp>(op->getLoc(), i);
        auto dim = cacheBuilder.create<memref::DimOp>(op->getLoc(),
                                                      output->get(), dimI);
        dims.push_back(dim);
      }
    }

    SmallVector<Value> iterationDomains;
    SmallVector<int64_t> shapes;
    for (unsigned int i = 0; i < aMap.getNumResults(); i++) {
      AffineMap subMap = aMap.getSubMap({i});
      Value domain = cacheBuilder.create<AffineApplyOp>(op->getLoc(), subMap,
                                                        ValueRange(dims));
      iterationDomains.push_back(domain);
      shapes.push_back(ShapedType::kDynamic);
    }
    //

    // Collect the adjoint's operands. Note the role swap: gradients of the
    // forward outputs feed the adjoint as inputs (passed first below), and
    // gradients of the forward inputs are the adjoint's outputs. All
    // gradient memrefs are reversed so the adjoint iterates in the opposite
    // order of the forward op. Operands without a shadow are skipped.
    SmallVector<Value> inputs, outputs;
    SmallVector<AffineMap> indexingMaps;
    SmallVector<utils::IteratorType> iteratorTypes{
        linalgOp.getNumLoops(), utils::IteratorType::parallel};

    for (OpOperand *output : linalgOp.getDpsInitOperands()) {
      if (!gutils->hasInvertPointer(output->get())) {
        continue;
      }
      indexingMaps.push_back(linalgOp.getMatchingIndexingMap(output));
      Value out = gutils->invertPointerM(output->get(), builder);
      Value view = invertMemref(out, builder, op->getLoc());
      outputs.push_back(view);
    }

    for (OpOperand *input : linalgOp.getDpsInputOperands()) {
      if (!gutils->hasInvertPointer(input->get())) {
        continue;
      }
      indexingMaps.push_back(linalgOp.getMatchingIndexingMap(input));
      Value inp = gutils->invertPointerM(input->get(), builder);
      Value view = invertMemref(inp, builder, op->getLoc());
      inputs.push_back(view);
    }

    ArrayAttr indexingMapsArrayAttr =
        builder.getAffineMapArrayAttr(indexingMaps);
    ArrayAttr iteratorTypesArrayAttr =
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes, [&](utils::IteratorType iter) -> mlir::Attribute {
              return linalg::IteratorTypeAttr::get(builder.getContext(), iter);
            })));
    auto adjoint = builder.create<enzyme::GenericAdjointOp>(
        op->getLoc(), TypeRange(), ValueRange(outputs), ValueRange(inputs),
        indexingMapsArrayAttr, iteratorTypesArrayAttr, StringAttr(),
        StringAttr());

    // The adjoint region terminator accumulates the derivatives of the
    // forward inputs via enzyme.addTo (only the first numInputs results).
    int numInputs = inputs.size();
    auto buildFuncReturnOp = [numInputs, indexingMaps, &newOp, &adjoint,
                              &inputs](OpBuilder &builder, Location loc,
                                       SmallVector<Value> retargs) {
      builder.create<enzyme::AddToOp>(
          loc, ValueRange{retargs}.take_front(numInputs));
      return;
    };

    Region *newOpRegion = newOp.getBlock()->getParent();
    int numInputsNewOp = cast<linalg::GenericOp>(newOp).getInputs().size();
    Region *adjointRegion = &adjoint.getRegion();
    int numInputsAdjoint = adjoint.getInputs().size();
    Location loc = op->getLoc();
    int numCaches = 0;
    SmallVector<Value> pushCaches;

    // Cache hook invoked by the differentiation logic for every forward
    // value the adjoint body needs: creates an enzyme.init "push" value in
    // the forward region and a matching block argument ("pop") in the
    // adjoint region.
    auto hook = [newOpRegion, adjointRegion, loc, &numCaches = numCaches,
                 numInputsNewOp, numInputsAdjoint,
                 &pushCaches = pushCaches](Type t) {
      OpBuilder builder(newOpRegion);
      Value pushCache = builder.create<enzyme::InitOp>(loc, t);
      pushCaches.push_back(pushCache);
      newOpRegion->addArgument(t, loc);

      Value popCache =
          adjointRegion->insertArgument(numInputsAdjoint + numCaches, t, loc);
      numCaches++;
      return std::make_pair(pushCache, popCache);
    };

    gutils->Logic.differentiate(
        gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(),
        /*parentRegion=*/false, buildFuncReturnOp, hook);

    // Yield each pushed cache value from the forward body so it is written
    // into the cache buffer added as an extra output below.
    auto newOpYield = cast<linalg::YieldOp>(
        cast<linalg::GenericOp>(newOp).getBodyRegion().front().getTerminator());
    for (Value pc : pushCaches) {
      newOpYield.getValuesMutable().append(pc);
    }

    // The AddToOp operands correspond to block arguments appended here so
    // the adjoint region's signature matches its terminator.
    Block *body = &(adjoint.getRegion().front());
    auto yieldOp = cast<enzyme::AddToOp>(body->getTerminator());
    for (auto opOperand : yieldOp.getOperands()) {
      body->addArgument(opOperand.getType(), opOperand.getLoc());
    }

    OpBuilder builderAdd(yieldOp);

    // For every cached value: allocate a buffer spanning the iteration
    // domain, write it as an extra forward output, then pop it and feed the
    // reversed view as an extra adjoint input with identity indexing.
    auto newIndexingMaps = newOp.getIndexingMapsArray();
    auto indexingMapsAdjoint = getIndexingMapsArray(adjoint);
    for (int i = 0; i < numCaches; i++) {
      Value cacheArg = body->getArgument(outputs.size() + i);

      Type ct = cacheArg.getType();
      Type type = MemRefType::get(shapes, ct);
      auto alloc = cacheBuilder.create<memref::AllocOp>(
          op->getLoc(), type, ValueRange(iterationDomains));
      Value cache = gutils->initAndPushCache(alloc, cacheBuilder);
      // TODO use higher level API
      alloc->setAttr(
          alloc.getOperandSegmentSizesAttrName(),
          cacheBuilder.getDenseI32ArrayAttr({iterationDomains.size(), 0}));

      cast<linalg::GenericOp>(newOp).getOutputsMutable().append(
          ValueRange({alloc}));
      newIndexingMaps.push_back(AffineMap::getMultiDimIdentityMap(
          iterationDomains.size(), cacheBuilder.getContext()));

      builderAdd.setInsertionPoint(adjoint);
      Value retrievedValue = gutils->popCache(cache, builderAdd);
      retrievedValue = invertMemref(retrievedValue, builderAdd, op->getLoc());
      adjoint.getInputsMutable().append(ValueRange({retrievedValue}));
      indexingMapsAdjoint.insert(
          indexingMapsAdjoint.begin() + numInputsAdjoint + i,
          AffineMap::getMultiDimIdentityMap(iterationDomains.size(),
                                            builderAdd.getContext()));
    }
    // Rebuild both ops' indexing_maps attributes to include the maps added
    // for the cache operands.
    SmallVector<Attribute> indexingMapsAttr;
    SmallVector<Attribute> indexingMapsAttrAdjoint;
    for (auto &map : newIndexingMaps) {
      indexingMapsAttr.push_back(AffineMapAttr::get(map));
    }
    for (auto &map : indexingMapsAdjoint) {
      indexingMapsAttrAdjoint.push_back(AffineMapAttr::get(map));
    }
    cast<linalg::GenericOp>(newOp).setIndexingMapsAttr(
        cacheBuilder.getArrayAttr(indexingMapsAttr));
    adjoint->setAttr(adjoint.getIndexingMapsAttrName(),
                     builder.getArrayAttr(indexingMapsAttrAdjoint));
  }

  // Caching is handled inside createReverseModeAdjoint via the hook; no
  // standalone values need to be cached up front.
  SmallVector<Value> cacheValues(Operation *op,
                                 MGradientUtilsReverse *gutils) const {
    return SmallVector<Value>();
  }

  // Shadow allocation is handled elsewhere (initialization block); nothing
  // to do per-op here.
  void createShadowValues(Operation *op, OpBuilder &builder,
                          MGradientUtilsReverse *gutils) const {}
};
} // namespace

// Attaches GenericOpInterfaceReverse<T> to every op type in Ts via a fold
// expression, so each linalg structured op gains reverse-mode AD support.
template <typename... Ts> void attachAllInterfaces(MLIRContext *context) {
  (Ts::template attachInterface<GenericOpInterfaceReverse<Ts>>(*context), ...);
}

// Registers the reverse-mode AD external models for the linalg dialect.
// GET_OP_LIST expands to the comma-separated list of all linalg structured
// op types, so attachAllInterfaces covers every one of them.
void mlir::enzyme::registerLinalgDialectAutoDiffInterface(
    DialectRegistry &registry) {
  registry.addExtension(+[](MLIRContext *context, linalg::LinalgDialect *) {
    attachAllInterfaces<
#define GET_OP_LIST
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"
        >(context);
  });
}
Loading

0 comments on commit 3d35e80

Please sign in to comment.