Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add functions to truncate and expand fp values #1615

Merged
merged 3 commits into from
Jan 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 53 additions & 11 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1314,14 +1314,14 @@ class EnzymeBase {
return type_args;
}

bool HandleTruncate(CallInst *CI) {
bool HandleTruncateFunc(CallInst *CI) {
IRBuilder<> Builder(CI);
Function *F = parseFunctionParameter(CI);
if (!F)
return false;
if (CI->arg_size() != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate", *CI,
"Had incorrect number of args to __enzyme_truncate_func", *CI,
" - expected 3");
return false;
}
Expand All @@ -1330,7 +1330,7 @@ class EnzymeBase {
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
RequestContext context(CI, &Builder);
llvm::Value *res = Logic.CreateTruncate(
llvm::Value *res = Logic.CreateTruncateFunc(
context, F, (unsigned)Cfrom->getValue().getZExtValue(),
(unsigned)Cto->getValue().getZExtValue());
if (!res)
Expand All @@ -1341,6 +1341,28 @@ class EnzymeBase {
return true;
}

bool HandleTruncateValue(CallInst *CI, bool isTruncate) {
IRBuilder<> Builder(CI);
if (CI->arg_size() != 3) {
EmitFailure("TooManyArgs", CI->getDebugLoc(), CI,
"Had incorrect number of args to __enzyme_truncate_value",
*CI, " - expected 3");
return false;
}
auto Cfrom = cast<ConstantInt>(CI->getArgOperand(1));
assert(Cfrom);
auto Cto = cast<ConstantInt>(CI->getArgOperand(2));
assert(Cto);
auto Addr = CI->getArgOperand(0);
RequestContext context(CI, &Builder);
bool res = Logic.CreateTruncateValue(
context, Addr, (unsigned)Cfrom->getValue().getZExtValue(),
(unsigned)Cto->getValue().getZExtValue(), isTruncate);
if (!res)
return false;
return true;
}

bool HandleBatch(CallInst *CI) {
unsigned width = 1;
unsigned truei = 0;
Expand Down Expand Up @@ -2088,7 +2110,9 @@ class EnzymeBase {
MapVector<CallInst *, DerivativeMode> toVirtual;
MapVector<CallInst *, DerivativeMode> toSize;
SmallVector<CallInst *, 4> toBatch;
SmallVector<CallInst *, 4> toTruncate;
SmallVector<CallInst *, 4> toTruncateFunc;
SmallVector<CallInst *, 4> toTruncateValue;
SmallVector<CallInst *, 4> toExpandValue;
MapVector<CallInst *, ProbProgMode> toProbProg;
SetVector<CallInst *> InactiveCalls;
SetVector<CallInst *> IterCalls;
Expand Down Expand Up @@ -2398,7 +2422,9 @@ class EnzymeBase {
bool virtualCall = false;
bool sizeOnly = false;
bool batch = false;
bool truncate = false;
bool truncateFunc = false;
bool truncateValue = false;
bool expandValue = false;
bool probProg = false;
DerivativeMode derivativeMode;
ProbProgMode probProgMode;
Expand Down Expand Up @@ -2428,9 +2454,15 @@ class EnzymeBase {
} else if (Fn->getName().contains("__enzyme_batch")) {
enableEnzyme = true;
batch = true;
} else if (Fn->getName().contains("__enzyme_truncate")) {
} else if (Fn->getName().contains("__enzyme_truncate_func")) {
enableEnzyme = true;
truncate = true;
truncateFunc = true;
} else if (Fn->getName().contains("__enzyme_truncate_value")) {
enableEnzyme = true;
truncateValue = true;
} else if (Fn->getName().contains("__enzyme_expand_value")) {
enableEnzyme = true;
expandValue = true;
} else if (Fn->getName().contains("__enzyme_likelihood")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Likelihood;
Expand Down Expand Up @@ -2488,8 +2520,12 @@ class EnzymeBase {
toSize[CI] = derivativeMode;
else if (batch)
toBatch.push_back(CI);
else if (truncate)
toTruncate.push_back(CI);
else if (truncateFunc)
toTruncateFunc.push_back(CI);
else if (truncateValue)
toTruncateValue.push_back(CI);
else if (expandValue)
toExpandValue.push_back(CI);
else if (probProg) {
toProbProg[CI] = probProgMode;
} else
Expand Down Expand Up @@ -2583,8 +2619,14 @@ class EnzymeBase {
for (auto call : toBatch) {
HandleBatch(call);
}
for (auto call : toTruncate) {
HandleTruncate(call);
for (auto call : toTruncateFunc) {
HandleTruncateFunc(call);
}
for (auto call : toTruncateValue) {
HandleTruncateValue(call, true);
}
for (auto call : toExpandValue) {
HandleTruncateValue(call, false);
}

for (auto &&[call, mode] : toProbProg) {
Expand Down
134 changes: 98 additions & 36 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4816,11 +4816,53 @@ Function *EnzymeLogic::CreateForwardDiff(
return nf;
}

static Type *getTypeForWidth(LLVMContext &ctx, unsigned width) {
switch (width) {
default:
return llvm::Type::getIntNTy(ctx, width);
case 64:
return llvm::Type::getDoubleTy(ctx);
case 32:
return llvm::Type::getFloatTy(ctx);
case 16:
return llvm::Type::getHalfTy(ctx);
}
}

static Value *floatTruncate(IRBuilderBase &B, Value *v, Value *tmpBlock,
unsigned fromwidth, unsigned towidth) {
Type *fromTy = getTypeForWidth(B.getContext(), fromwidth);
Type *toTy = getTypeForWidth(B.getContext(), towidth);
if (!tmpBlock)
tmpBlock = B.CreateAlloca(fromTy);
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
toTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(toTy)));
}

static Value *floatExpand(IRBuilderBase &B, Value *v, Value *tmpBlock,
unsigned fromwidth, unsigned towidth) {
Type *fromTy = getTypeForWidth(B.getContext(), fromwidth);
if (!tmpBlock)
tmpBlock = B.CreateAlloca(fromTy);
auto c0 =
Constant::getNullValue(llvm::Type::getIntNTy(B.getContext(), fromwidth));
B.CreateStore(
c0, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(c0->getType())));
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
fromTy, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(fromTy)));
}

class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
private:
ValueToValueMapTy &originalToNewFn;
unsigned fromwidth;
unsigned towidth;
Type *fromType;
Type *toType;
Function *oldFunc;
Function *newFunc;
AllocaInst *tmpBlock;
Expand All @@ -4833,7 +4875,11 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
: originalToNewFn(originalToNewFn), fromwidth(fromwidth),
towidth(towidth), oldFunc(oldFunc), newFunc(newFunc), Logic(Logic) {
IRBuilder<> B(&newFunc->getEntryBlock().front());
tmpBlock = B.CreateAlloca(getTypeForWidth(fromwidth));

fromType = getTypeForWidth(B.getContext(), fromwidth);
toType = getTypeForWidth(B.getContext(), towidth);

tmpBlock = B.CreateAlloca(fromType);
}

void visitInstruction(llvm::Instruction &inst) {
Expand All @@ -4851,42 +4897,16 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
todo(inst);
}

Type *getTypeForWidth(unsigned width) {
switch (width) {
default:
return llvm::Type::getIntNTy(oldFunc->getContext(), width);
case 64:
return llvm::Type::getDoubleTy(oldFunc->getContext());
case 32:
return llvm::Type::getFloatTy(oldFunc->getContext());
case 16:
return llvm::Type::getHalfTy(oldFunc->getContext());
}
}
Type *getFromType() { return fromType; }

Type *getFromType() { return getTypeForWidth(fromwidth); }

Type *getToType() { return getTypeForWidth(towidth); }
Type *getToType() { return toType; }

Value *truncate(IRBuilder<> &B, Value *v) {
Type *nextType = getTypeForWidth(towidth);
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
nextType,
B.CreatePointerCast(tmpBlock, PointerType::getUnqual(nextType)));
return floatTruncate(B, v, tmpBlock, fromwidth, towidth);
}

Value *expand(IRBuilder<> &B, Value *v) {
Type *origT = getFromType();
auto c0 = Constant::getNullValue(
llvm::Type::getIntNTy(oldFunc->getContext(), fromwidth));
B.CreateStore(c0, B.CreatePointerCast(
tmpBlock, PointerType::getUnqual(c0->getType())));
B.CreateStore(
v, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(v->getType())));
return B.CreateLoad(
origT, B.CreatePointerCast(tmpBlock, PointerType::getUnqual(origT)));
return floatExpand(B, v, tmpBlock, fromwidth, towidth);
}

void todo(llvm::Instruction &I) {
Expand Down Expand Up @@ -5183,7 +5203,7 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {

Value *GetShadow(RequestContext &ctx, Value *v) {
if (auto F = dyn_cast<Function>(v))
return Logic.CreateTruncate(ctx, F, fromwidth, towidth);
return Logic.CreateTruncateFunc(ctx, F, fromwidth, towidth);
llvm::errs() << " unknown get truncated func: " << *v << "\n";
llvm_unreachable("unknown get truncated func");
return v;
Expand All @@ -5206,10 +5226,52 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator> {
}
};

llvm::Function *EnzymeLogic::CreateTruncate(RequestContext context,
llvm::Function *totrunc,
unsigned fromwidth,
unsigned towidth) {
bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v,
unsigned fromwidth, unsigned towidth,
bool isTruncate) {
assert(context.req && context.ip);

if (fromwidth == towidth) {
context.req->eraseFromParent();
return true;
}

if (fromwidth < towidth) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << "Cannot truncate into a large width\n";
wsmoses marked this conversation as resolved.
Show resolved Hide resolved
if (context.req) {
ss << " at context: " << *context.req;
EmitFailure("NoTruncate", context.req->getDebugLoc(), context.req,
ss.str());
return false;
}
llvm_unreachable("failed to truncate value");
}

IRBuilderBase &B = *context.ip;
Type *fromTy = getTypeForWidth(B.getContext(), fromwidth);
Type *toTy = getTypeForWidth(B.getContext(), towidth);

Value *converted = nullptr;
if (isTruncate)
converted =
floatExpand(B, B.CreateFPTrunc(v, toTy), nullptr, fromwidth, towidth);
else
converted =
B.CreateFPExt(floatTruncate(B, v, nullptr, fromwidth, towidth), fromTy);
assert(converted);

context.req->replaceAllUsesWith(converted);
context.req->eraseFromParent();

return true;
}

llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context,
llvm::Function *totrunc,
unsigned fromwidth,
unsigned towidth) {
if (fromwidth == towidth)
return totrunc;

Expand Down
9 changes: 6 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -512,9 +512,12 @@ class EnzymeLogic {

using TruncateCacheKey = std::tuple<llvm::Function *, unsigned, unsigned>;
std::map<TruncateCacheKey, llvm::Function *> TruncateCachedFunctions;
llvm::Function *CreateTruncate(RequestContext context,
llvm::Function *tobatch, unsigned fromwidth,
unsigned towidth);
llvm::Function *CreateTruncateFunc(RequestContext context,
llvm::Function *tobatch,
unsigned fromwidth, unsigned towidth);
bool CreateTruncateValue(RequestContext context, llvm::Value *addr,
unsigned fromwidth, unsigned towidth,
bool isTruncate);

/// Create a traced version of a function
/// \p context the instruction which requested this trace (or null).
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/Truncate/cmp.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ define i1 @f(double %x, double %y) {
ret i1 %res
}

declare i1 (double, double)* @__enzyme_truncate(...)
declare i1 (double, double)* @__enzyme_truncate_func(...)

define i1 @tester(double %x, double %y) {
entry:
%ptr = call i1 (double, double)* (...) @__enzyme_truncate(i1 (double, double)* @f, i64 64, i64 32)
%ptr = call i1 (double, double)* (...) @__enzyme_truncate_func(i1 (double, double)* @f, i64 64, i64 32)
%res = call i1 %ptr(double %x, double %y)
ret i1 %res
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/Truncate/intrinsic.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ define double @f(double %x, double %y) {
ret double %res
}

declare double (double, double)* @__enzyme_truncate(...)
declare double (double, double)* @__enzyme_truncate_func(...)

define double @tester(double %x, double %y) {
entry:
%ptr = call double (double, double)* (...) @__enzyme_truncate(double (double, double)* @f, i64 64, i64 32)
%ptr = call double (double, double)* (...) @__enzyme_truncate_func(double (double, double)* @f, i64 64, i64 32)
%res = call double %ptr(double %x, double %y)
ret double %res
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/Truncate/select.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ define double @f(double %x, double %y, i1 %cond) {
ret double %res
}

declare double (double, double, i1)* @__enzyme_truncate(...)
declare double (double, double, i1)* @__enzyme_truncate_func(...)

define double @tester(double %x, double %y, i1 %cond) {
entry:
%ptr = call double (double, double, i1)* (...) @__enzyme_truncate(double (double, double, i1)* @f, i64 64, i64 32)
%ptr = call double (double, double, i1)* (...) @__enzyme_truncate_func(double (double, double, i1)* @f, i64 64, i64 32)
%res = call double %ptr(double %x, double %y, i1 %cond)
ret double %res
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Enzyme/Truncate/simple.ll
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@ define void @f(double* %x) {
ret void
}

declare void (double*)* @__enzyme_truncate(...)
declare void (double*)* @__enzyme_truncate_func(...)

define void @tester(double* %data) {
entry:
%ptr = call void (double*)* (...) @__enzyme_truncate(void (double*)* @f, i64 64, i64 32)
%ptr = call void (double*)* (...) @__enzyme_truncate_func(void (double*)* @f, i64 64, i64 32)
call void %ptr(double* %data)
ret void
}
Expand Down
Loading
Loading