Skip to content

Commit

Permalink
Use binary operators in tablegen (#1293)
Browse files Browse the repository at this point in the history
* fix

* binops

* fix fwd vec tests

* wip reverse

* remaining tests fixed

* fix blas in rebase

* With fdiv

* reformat

* fix

* fix tests

* fix tests

* apply suggestion
  • Loading branch information
wsmoses authored Jun 26, 2023
1 parent c21f77b commit 50b44bf
Show file tree
Hide file tree
Showing 168 changed files with 1,657 additions and 1,728 deletions.
14 changes: 14 additions & 0 deletions enzyme/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ gentbl(
":enzyme-tblgen",
],
)
gentbl(
name = "gen-binop-derivatives",
tbl_outs = [(
"-gen-binop-derivatives",
"BinopDerivatives.inc",
)],
tblgen = ":enzyme-tblgen",
td_file = "Enzyme/InstructionDerivatives.td",
td_srcs = ["Enzyme/BlasDerivatives.td"],
deps = [
":enzyme-tblgen",
],
)
gentbl(
name = "blas-derivatives",
tbl_outs = [(
Expand Down Expand Up @@ -120,6 +133,7 @@ cc_library(
"@llvm-project//llvm:IRReader",
":call-derivatives",
":intr-derivatives",
":binop-derivatives",
":blas-derivatives",
":blas-attributor",
":blas-typeanalysis",
Expand Down
323 changes: 61 additions & 262 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -2121,83 +2121,18 @@ class AdjointGenerator
return;
}

switch (Mode) {
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined:
if (gutils->isConstantInstruction(&BO))
return;
createBinaryOperatorAdjoint(BO);
break;
case DerivativeMode::ForwardMode:
case DerivativeMode::ForwardModeSplit:
createBinaryOperatorDual(BO);
break;
case DerivativeMode::ReverseModePrimal:
return;
}
}

void createBinaryOperatorAdjoint(llvm::BinaryOperator &BO) {
using namespace llvm;

IRBuilder<> Builder2(&BO);
getReverseBuilder(Builder2);

Value *orig_op0 = BO.getOperand(0);
Value *orig_op1 = BO.getOperand(1);
bool constantval0 = gutils->isConstantValue(orig_op0);
bool constantval1 = gutils->isConstantValue(orig_op1);

Value *dif0 = nullptr;
Value *dif1 = nullptr;
Value *idiff = diffe(&BO, Builder2);

Type *addingType = BO.getType();

switch (BO.getOpcode()) {
case Instruction::FMul: {
if (!constantval0) {
Value *op0 = lookup(gutils->getNewFromOriginal(orig_op1), Builder2);
auto rule = [&](Value *idiff) {
return checkedMul(Builder2, idiff, op0,
"m0diffe" + orig_op0->getName());
};
dif0 = applyChainRule(orig_op0->getType(), Builder2, rule, idiff);
}
if (!constantval1) {
auto rule = [&](Value *idiff) {
return checkedMul(
Builder2, idiff,
lookup(gutils->getNewFromOriginal(orig_op0), Builder2),
"m1diffe" + orig_op1->getName());
};
dif1 = applyChainRule(orig_op1->getType(), Builder2, rule, idiff);
}
break;
}
case Instruction::FAdd: {
if (!constantval0)
dif0 = idiff;
if (!constantval1)
dif1 = idiff;
break;
}
case Instruction::FSub: {
if (!constantval0)
dif0 = idiff;
if (!constantval1) {
auto rule = [&](Value *idiff) { return Builder2.CreateFNeg(idiff); };
dif1 = applyChainRule(orig_op1->getType(), Builder2, rule, idiff);
}
break;
}
case Instruction::FDiv: {
if (BO.getOpcode() == llvm::Instruction::FDiv &&
(Mode == DerivativeMode::ReverseModeGradient ||
Mode == DerivativeMode::ReverseModeCombined) &&
!gutils->isConstantValue(&BO)) {
using namespace llvm;
// Required loopy phi = [in, BO, BO, ..., BO]
// 1) phi is only used in this B0
// 2) BO dominates all latches
// 3) phi == B0 whenever not coming from preheader [implies 2]
// 4) [optional but done for ease] one exit to make it easier to
// calculation the product at that point
Value *orig_op0 = BO.getOperand(0);
if (auto P0 = dyn_cast<PHINode>(orig_op0)) {
LoopContext lc;
SmallVector<Instruction *, 4> activeUses;
Expand All @@ -2221,6 +2156,20 @@ class AdjointGenerator
}
}
if (allIncoming && lc.exitBlocks.size() == 1) {

IRBuilder<> Builder2(&BO);
getReverseBuilder(Builder2);

Value *orig_op1 = BO.getOperand(1);
bool constantval0 = gutils->isConstantValue(orig_op0);
bool constantval1 = gutils->isConstantValue(orig_op1);

Value *dif0 = nullptr;
Value *dif1 = nullptr;
Value *idiff = diffe(&BO, Builder2);

Type *addingType = BO.getType();

if (!constantval1) {
IRBuilder<> EB(*lc.exitBlocks.begin());
getReverseBuilder(EB, /*original=*/false);
Expand Down Expand Up @@ -2277,62 +2226,51 @@ class AdjointGenerator
}
}
}
if (!constantval0) {
Value *op1 = lookup(gutils->getNewFromOriginal(orig_op1), Builder2);
auto rule = [&](Value *idiff) {
return Builder2.CreateFDiv(idiff, op1,
"d0diffe" + orig_op0->getName());
};
dif0 = applyChainRule(orig_op0->getType(), Builder2, rule, idiff);
}
if (!constantval1) {
Value *lop1 = lookup(gutils->getNewFromOriginal(orig_op1), Builder2);
Value *lastdiv = lookup(gutils->getNewFromOriginal(&BO), Builder2);

auto rule = [&](Value *idiff) {
auto res = Builder2.CreateFNeg(
Builder2.CreateFMul(lastdiv, Builder2.CreateFDiv(idiff, lop1)));
if (EnzymeStrongZero) {
res = CreateSelect(
Builder2,
Builder2.CreateFCmpOEQ(
idiff, Constant::getNullValue(idiff->getType())),
idiff, res);
}
return res;
};
dif1 = applyChainRule(orig_op1->getType(), Builder2, rule, idiff);
}
break;
}
case Instruction::FRem: {
if (!constantval0) {
dif0 = idiff;
}
if (!constantval1) {
auto M = gutils->newFunc->getParent();
Value *lop0 = lookup(gutils->getNewFromOriginal(orig_op0), Builder2);
Value *lop1 = lookup(gutils->getNewFromOriginal(orig_op1), Builder2);
Value *div = Builder2.CreateFDiv(lop0, lop1);

Type *tys[] = {div->getType()};
Value *args[] = {div};
args[0] = Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::fabs, tys), args);
args[0] = Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::floor, tys), args);
Value *args2[] = {args[0], div};
args[0] = Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::copysign, tys), args2);
args[0] = Builder2.CreateFNeg(args[0]);

auto rule = [&](Value *idiff) {
return checkedMul(Builder2, idiff, args[0]);
};
dif1 = applyChainRule(orig_op1->getType(), Builder2, rule, idiff);

{
using namespace llvm;
switch (BO.getOpcode()) {
#include "BinopDerivatives.inc"
default:
break;
}
}

switch (Mode) {
case DerivativeMode::ReverseModeGradient:
case DerivativeMode::ReverseModeCombined:
if (gutils->isConstantInstruction(&BO))
return;
createBinaryOperatorAdjoint(BO);
break;
case DerivativeMode::ForwardMode:
case DerivativeMode::ForwardModeSplit:
createBinaryOperatorDual(BO);
break;
case DerivativeMode::ReverseModePrimal:
return;
}
}

void createBinaryOperatorAdjoint(llvm::BinaryOperator &BO) {
using namespace llvm;

IRBuilder<> Builder2(&BO);
getReverseBuilder(Builder2);

Value *orig_op0 = BO.getOperand(0);
Value *orig_op1 = BO.getOperand(1);
bool constantval0 = gutils->isConstantValue(orig_op0);
bool constantval1 = gutils->isConstantValue(orig_op1);

Value *dif0 = nullptr;
Value *dif1 = nullptr;
Value *idiff = diffe(&BO, Builder2);

Type *addingType = BO.getType();

switch (BO.getOpcode()) {
case Instruction::LShr: {
if (!constantval0) {
if (auto ci = dyn_cast<ConstantInt>(orig_op1)) {
Expand Down Expand Up @@ -2617,145 +2555,6 @@ class AdjointGenerator
constantval1 ? nullptr : diffe(orig_op1, Builder2)};

switch (BO.getOpcode()) {
case Instruction::FMul: {
if (!constantval0 && !constantval1) {
auto rule = [&](Value *dif0, Value *dif1) {
Value *idiff0 =
checkedMul(Builder2, dif0, gutils->getNewFromOriginal(orig_op1));
Value *idiff1 =
checkedMul(Builder2, dif1, gutils->getNewFromOriginal(orig_op0));
return Builder2.CreateFAdd(idiff0, idiff1);
};
Value *diff =
applyChainRule(BO.getType(), Builder2, rule, dif[0], dif[1]);
setDiffe(&BO, diff, Builder2);
} else if (!constantval0) {
auto rule = [&](Value *dif0) {
return checkedMul(Builder2, dif0,
gutils->getNewFromOriginal(orig_op1));
};
Value *idiff0 = applyChainRule(BO.getType(), Builder2, rule, dif[0]);
setDiffe(&BO, idiff0, Builder2);
} else if (!constantval1) {
auto rule = [&](Value *dif1) {
return checkedMul(Builder2, dif1,
gutils->getNewFromOriginal(orig_op0));
};
Value *idiff1 = applyChainRule(BO.getType(), Builder2, rule, dif[1]);
setDiffe(&BO, idiff1, Builder2);
}
break;
}
case Instruction::FAdd: {
if (!constantval0 && !constantval1) {
auto rule = [&](Value *dif0, Value *dif1) {
return Builder2.CreateFAdd(dif0, dif1);
};
Value *diff =
applyChainRule(BO.getType(), Builder2, rule, dif[0], dif[1]);
setDiffe(&BO, diff, Builder2);
} else if (!constantval0) {
setDiffe(&BO, dif[0], Builder2);
} else if (!constantval1) {
setDiffe(&BO, dif[1], Builder2);
}
break;
}
case Instruction::FSub: {
if (!constantval0 && !constantval1) {
auto rule = [&](Value *dif0, Value *dif1) {
return Builder2.CreateFAdd(dif0, Builder2.CreateFNeg(dif1));
};
Value *diff =
applyChainRule(BO.getType(), Builder2, rule, dif[0], dif[1]);
setDiffe(&BO, diff, Builder2);
} else if (!constantval0) {
setDiffe(&BO, dif[0], Builder2);
} else if (!constantval1) {
auto rule = [&](Value *dif1) { return Builder2.CreateFNeg(dif1); };
Value *diff = applyChainRule(BO.getType(), Builder2, rule, dif[1]);
setDiffe(&BO, diff, Builder2);
}
break;
}
case Instruction::FDiv: {
Value *idiff3 = nullptr;
if (!constantval0 && !constantval1) {
auto rule = [&](Value *dif0, Value *dif1) {
Value *idiff1 =
checkedMul(Builder2, dif0, gutils->getNewFromOriginal(orig_op1));
Value *idiff2 =
checkedMul(Builder2, dif1, gutils->getNewFromOriginal(orig_op0));
return Builder2.CreateFSub(idiff1, idiff2);
};
idiff3 = applyChainRule(BO.getType(), Builder2, rule, dif[0], dif[1]);
} else if (!constantval0) {
auto rule = [&](Value *dif0) {
return checkedMul(Builder2, dif0,
gutils->getNewFromOriginal(orig_op1));
};
idiff3 = applyChainRule(BO.getType(), Builder2, rule, dif[0]);
} else if (!constantval1) {
auto rule = [&](Value *dif1) {
return checkedMul(
Builder2, dif1,
Builder2.CreateFNeg(gutils->getNewFromOriginal(orig_op0)));
};
idiff3 = applyChainRule(BO.getType(), Builder2, rule, dif[1]);
}

Value *idiff4 = Builder2.CreateFMul(gutils->getNewFromOriginal(orig_op1),
gutils->getNewFromOriginal(orig_op1));

auto rule = [&](Value *idiff3) {
return Builder2.CreateFDiv(idiff3, idiff4);
};

Value *idiff5 = applyChainRule(BO.getType(), Builder2, rule, idiff3);
setDiffe(&BO, idiff5, Builder2);

break;
}
case Instruction::FRem: {
Value *round = nullptr;
if (!constantval1) {
auto M = gutils->newFunc->getParent();
Value *lop0 = gutils->getNewFromOriginal(orig_op0);
Value *lop1 = gutils->getNewFromOriginal(orig_op1);
Value *div = Builder2.CreateFDiv(lop0, lop1);

Type *tys[] = {div->getType()};
Value *args[] = {div};
args[0] = Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::fabs, tys), args);
args[0] = Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::floor, tys), args);
Value *args2[] = {args[0], div};
args[0] = Builder2.CreateCall(
Intrinsic::getDeclaration(M, Intrinsic::copysign, tys), args2);
round = Builder2.CreateFNeg(args[0]);
}

if (!constantval0 && !constantval1) {
auto rule = [&](Value *dif0, Value *dif1) {
return Builder2.CreateFAdd(dif0, checkedMul(Builder2, dif1, round));
};
setDiffe(
&BO,
applyChainRule(orig_op1->getType(), Builder2, rule, dif[0], dif[1]),
Builder2);
} else if (!constantval0) {
setDiffe(&BO, dif[0], Builder2);
} else if (!constantval1) {
auto rule = [&](Value *dif1) {
return checkedMul(Builder2, dif1, round);
};
setDiffe(&BO,
applyChainRule(orig_op1->getType(), Builder2, rule, dif[1]),
Builder2);
}
break;
}
case Instruction::And: {
// If & against 0b10000000000 and a float the result is 0
auto &dl = gutils->oldFunc->getParent()->getDataLayout();
Expand Down
Loading

0 comments on commit 50b44bf

Please sign in to comment.