Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Feb 29, 2024
1 parent 005f4fc commit cef6a48
Show file tree
Hide file tree
Showing 18 changed files with 465 additions and 333 deletions.
6 changes: 4 additions & 2 deletions enzyme/Enzyme/MLIR/Dialect/EnzymeOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -111,13 +111,15 @@ def Gradient : Enzyme_Type<"Gradient"> {

def SetOp : Enzyme_Op<"set"> {
let summary = "Store the current value of the gradient";
let arguments = (ins AnyType : $gradient, AnyType : $value);
let arguments = (ins Arg<AnyType, "the reference to store to",
[MemWrite]>:$gradient, AnyType : $value);
let results = (outs );
}

def GetOp : Enzyme_Op<"get"> {
let summary = "Load current value of gradient";
let arguments = (ins AnyType : $gradient);
let arguments = (ins Arg<AnyType, "the reference to load from",
[MemRead]>:$gradient);
let results = (outs AnyType);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,9 @@ LogicalResult mlir::enzyme::detail::allocationForwardHandler(
return success();
}


void mlir::enzyme::detail::returnReverseHandler(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) {
void mlir::enzyme::detail::returnReverseHandler(Operation *op,
OpBuilder &builder,
MGradientUtilsReverse *gutils) {
size_t num_out = 0;
for (auto act : gutils->RetDiffeTypes) {
if (act == DIFFE_TYPE::OUT_DIFF)
Expand All @@ -216,10 +216,10 @@ void mlir::enzyme::detail::returnReverseHandler(Operation *op, OpBuilder &builde
for (auto &&[op, act] : llvm::zip(op->getOperands(), gutils->RetDiffeTypes)) {
if (act == DIFFE_TYPE::OUT_DIFF) {
if (!gutils->isConstantValue(op)) {
auto d_out = args[args.size() - num_out + idx];
gutils->addToDiffe(op, d_out, builder);
}
idx++;
auto d_out = args[args.size() - num_out + idx];
gutils->addToDiffe(op, d_out, builder);
}
idx++;
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,6 @@ class NoopRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
NoopRevAutoDiffInterface<OpTy>, OpTy> {
public:

void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {}
Expand All @@ -124,16 +123,14 @@ class NoopRevAutoDiffInterface
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {
}
MGradientUtilsReverse *gutils) const {}
};

template <typename OpTy>
class ReturnRevAutoDiffInterface
: public ReverseAutoDiffOpInterface::ExternalModel<
ReturnRevAutoDiffInterface<OpTy>, OpTy> {
public:

void createReverseModeAdjoint(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils,
SmallVector<Value> caches) const {
Expand All @@ -146,8 +143,7 @@ class ReturnRevAutoDiffInterface
}

void createShadowValues(Operation *op, OpBuilder &builder,
MGradientUtilsReverse *gutils) const {
}
MGradientUtilsReverse *gutils) const {}
};

// Implements the forward autodiff interface for operations which are
Expand Down Expand Up @@ -212,7 +208,8 @@ void registerAutoDiffUsingControlFlowInterface(MLIRContext &context) {
template <typename OpTy>
void registerAutoDiffUsingBranchInterface(MLIRContext &context) {
OpTy::template attachInterface<detail::AutoDiffUsingBranch<OpTy>>(context);
OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(context);
OpTy::template attachInterface<detail::NoopRevAutoDiffInterface<OpTy>>(
context);
}
// Registers AutoDiffUsingRegionTerminator for the given op.
template <typename OpTy>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,8 @@ struct GenericOpInterfaceReverse
StringAttr());

int numInputs = inputs.size();
auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder, Block* oBB) {
auto buildFuncReturnOp = [&gutils, numInputs](OpBuilder &builder,
Block *oBB) {
auto loc = oBB->rbegin()->getLoc();
SmallVector<Value> retargs;
for (auto arg : oBB->getArguments()) {
Expand Down Expand Up @@ -196,8 +197,8 @@ struct GenericOpInterfaceReverse
return std::make_pair(pushCache, popCache);
};

gutils->Logic.differentiate(
gutils, *linalgOp.getBlock()->getParent(), adjoint.getRegion(), buildFuncReturnOp, hook);
gutils->Logic.differentiate(gutils, *linalgOp.getBlock()->getParent(),
adjoint.getRegion(), buildFuncReturnOp, hook);

auto newOpYield = cast<linalg::YieldOp>(
cast<linalg::GenericOp>(newOp).getBodyRegion().front().getTerminator());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,17 +119,18 @@ struct StoreOpInterfaceReverse

if (!iface.isMutable()) {
if (!gutils->isConstantValue(val)) {
Value loadedGradient =
builder.create<memref::LoadOp>(storeOp.getLoc(), memrefGradient,
ArrayRef<Value>(retrievedArguments));
Value loadedGradient = builder.create<memref::LoadOp>(
storeOp.getLoc(), memrefGradient,
ArrayRef<Value>(retrievedArguments));
gutils->addToDiffe(val, loadedGradient, builder);
}

auto zero = cast<AutoDiffTypeInterface>(gutils->getShadowType(val.getType())).createNullValue(builder, op->getLoc());
auto zero =
cast<AutoDiffTypeInterface>(gutils->getShadowType(val.getType()))
.createNullValue(builder, op->getLoc());

builder.create<memref::StoreOp>(storeOp.getLoc(), zero, memrefGradient,
ArrayRef<Value>(retrievedArguments));

ArrayRef<Value>(retrievedArguments));
}
}
}
Expand Down
128 changes: 70 additions & 58 deletions enzyme/Enzyme/MLIR/Implementations/SCFAutoDiffOpInterfaceImpl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,102 +40,114 @@ struct ForOpInterfaceReverse
SmallVector<Value> caches) const {
auto forOp = cast<scf::ForOp>(op);


// Begin Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0
SmallVector<Value, 1> resDiffes;
for (OpResult v : forOp.getResults()) {
if (!gutils->isConstantValue(v)) {
if (!gutils->isConstantValue(v)) {
auto autoDiffType = cast<AutoDiffTypeInterface>(v.getType());
if (autoDiffType.isMutable()) {
if (!autoDiffType.isMutable()) {
auto prev = gutils->diffe(v, builder);
gutils->zeroDiffe(v, builder);
resDiffes.push_back(prev);
continue;
}
}
resDiffes.push_back(nullptr);
}
resDiffes.push_back(nullptr);
}

for (auto &reg : op->getRegions()) {
auto termIface =
cast<RegionBranchTerminatorOpInterface>(reg.begin()->getTerminator());

for (auto &reg : op->getRegions()) {
auto termIface = cast<RegionBranchTerminatorOpInterface>(
reg.begin()->getTerminator());


SmallVector<RegionSuccessor> successors;
termIface.getSuccessorRegions(
SmallVector<Attribute>(termIface->getNumOperands(), Attribute()),
successors);
SmallVector<RegionSuccessor> successors;
termIface.getSuccessorRegions(
SmallVector<Attribute>(termIface->getNumOperands(), Attribute()),
successors);

for (auto &successor : successors) {
if (!successor.isParent()) continue;
OperandRange operandRange =
termIface.getSuccessorOperands(successor);
if (!successor.isParent())
continue;
OperandRange operandRange = termIface.getSuccessorOperands(successor);
assert(operandRange.size() == resDiffes.size());
// There is an assumption here that there is only regions that branch to the successor.
// Specifically, otherwise we would need to gutils->addToDiffe select (if came from that result)

// There is an assumption here that there is only regions that branch to
// the successor. Specifically, otherwise we would need to
// gutils->addToDiffe select (if came from that result)
for (auto &&[prev, post] : llvm::zip(operandRange, resDiffes)) {
if (!post) continue;
if (!gutils->isConstantValue(prev))
gutils->addToDiffe(prev, post, builder);
if (!post)
continue;
if (!gutils->isConstantValue(prev))
gutils->addToDiffe(prev, post, builder);
}
}
}
}
// End Perform d(yielded value[i]) += d(result[i]); d(result[i]) = 0

auto start = gutils->popCache(caches[0], builder);
auto end = gutils->popCache(caches[1], builder);
auto step = gutils->popCache(caches[2], builder);

auto repFor = builder.create<scf::ForOp>(
forOp.getLoc(), start, end, step, ArrayRef<Value>());
auto repFor = builder.create<scf::ForOp>(forOp.getLoc(), start, end, step,
ArrayRef<Value>());
// erase scf yield
repFor.getBody()->begin()->erase();

for (auto &&[oldReg, newReg] : llvm::zip(op->getRegions(), repFor->getRegions())) {
for (auto &&[oldReg, newReg] :
llvm::zip(op->getRegions(), repFor->getRegions())) {

// This code assumes at most one terminating block for each region (lest the append happen multiple times)
auto buildFuncReturnOp = [&](OpBuilder &builder, Block* oBB) {
auto loc = oBB->rbegin()->getLoc();
// This code assumes at most one terminating block for each region (lest
// the append happen multiple times)
auto buildFuncReturnOp = [&](OpBuilder &builder, Block *oBB) {
auto loc = oBB->rbegin()->getLoc();

auto idx = repFor.getInductionVar();
auto idx = repFor.getInductionVar();

auto lhs = builder.create<arith::AddIOp>(loc, idx, step);
auto lhs = builder.create<arith::AddIOp>(loc, idx, step);

// This needs to know a condition describing which predecessor this will return to, to select the right value
// Here we use the condition i + step >= end to determine the last iteration
// This needs to know a condition describing which predecessor this will
// return to, to select the right value Here we use the condition i +
// step >= end to determine the last iteration

auto condition = builder.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, lhs, end);
auto condition = builder.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::sge, lhs, end);

for (auto [arg, init_arg] : llvm::zip(oBB->getArguments(), forOp.getInitArgs())) {
if (!gutils->isConstantValue(arg) && !cast<AutoDiffTypeInterface>(arg.getType()).isMutable()) {
auto diffe = gutils->diffe(arg, builder);
gutils->zeroDiffe(arg, builder);
for (auto [arg, init_arg] :
llvm::zip(oBB->getArguments().slice(1), forOp.getInitArgs())) {
if (!gutils->isConstantValue(arg) &&
!cast<AutoDiffTypeInterface>(arg.getType()).isMutable()) {
auto diffe = gutils->diffe(arg, builder);
gutils->zeroDiffe(arg, builder);

auto zero = cast<AutoDiffTypeInterface>(diffe.getType()).createNullValue(builder, loc);
auto outside = builder.create<arith::SelectOp>(loc, condition, diffe, zero);
auto inside = builder.create<arith::SelectOp>(loc, condition, zero, diffe);
auto zero = cast<AutoDiffTypeInterface>(diffe.getType())
.createNullValue(builder, loc);
auto outside =
builder.create<arith::SelectOp>(loc, condition, diffe, zero);
auto inside =
builder.create<arith::SelectOp>(loc, condition, zero, diffe);

// For each predecessor, if we came from that predecessor += the shadow of the arg [after zero'ing]
if (!gutils->isConstantValue(init_arg)) {
gutils->addToDiffe(init_arg, outside, builder);
}
// For each predecessor, if we came from that predecessor += the
// shadow of the arg [after zero'ing]
if (!gutils->isConstantValue(init_arg)) {
gutils->addToDiffe(init_arg, outside, builder);
}

if (!gutils->isConstantValue(arg)) {
gutils->addToDiffe(arg, inside, builder);
}
}
if (!gutils->isConstantValue(arg)) {
gutils->addToDiffe(arg, inside, builder);
}
builder.create<scf::YieldOp>(loc);
};

for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->mapReverseModeBlocks.map(&oBB, &revBB);
gutils->Logic.visitChildren(&oBB, &revBB, gutils);
Block *newBB = gutils->getNewFromOriginal(&oBB);
gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils, buildFuncReturnOp);
}
}
builder.create<scf::YieldOp>(loc);
};

for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->mapReverseModeBlocks.map(&oBB, &revBB);
}
for (auto &&[oBB, revBB] : llvm::zip(oldReg, newReg)) {
gutils->Logic.visitChildren(&oBB, &revBB, gutils);
Block *newBB = gutils->getNewFromOriginal(&oBB);
gutils->Logic.handlePredecessors(&oBB, newBB, &revBB, gutils,
buildFuncReturnOp);
}
}
}

Expand Down
7 changes: 4 additions & 3 deletions enzyme/Enzyme/MLIR/Interfaces/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ class MEnzymeLogic {
void
initializeShadowValues(SmallVector<mlir::Block *> &dominatorToposortBlocks,
MGradientUtilsReverse *gutils);
void handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB,
MGradientUtilsReverse *gutils,
llvm::function_ref<buildReturnFunction> buildReturnOp);
void
handlePredecessors(Block *oBB, Block *newBB, Block *reverseBB,
MGradientUtilsReverse *gutils,
llvm::function_ref<buildReturnFunction> buildReturnOp);
void visitChildren(Block *oBB, Block *reverseBB,
MGradientUtilsReverse *gutils);
void visitChild(Operation *op, OpBuilder &builder,
Expand Down
Loading

0 comments on commit cef6a48

Please sign in to comment.