diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 59f1d1b08e6..f75cef60165 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5142,7 +5142,7 @@ class TruncateGenerator : public llvm::InstVisitor { void visitExtractValueInst(llvm::ExtractValueInst &EEI) { return; } void visitInsertValueInst(llvm::InsertValueInst &EEI) { return; } CallInst *createMPFRCall(llvm::IRBuilder<> &B, llvm::Instruction &I, - llvm::Type *RetTy, SmallVectorImpl &Args) { + llvm::Type *RetTy, SmallVectorImpl &ArgsIn) { std::string Name; if (auto BO = dyn_cast(&I)) { Name = "binop_" + std::string(BO->getOpcodeName()); @@ -5164,8 +5164,11 @@ class TruncateGenerator : public llvm::InstVisitor { } std::string MangledName = - std::string("__enzyme_mpfr_") + truncation.mangleString() + "_" + Name; + std::string("__enzyme_mpfr_") + truncation.mangleFrom() + "_" + Name; auto F = newFunc->getParent()->getFunction(MangledName); + SmallVector Args(ArgsIn.begin(), ArgsIn.end()); + Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); + Args.push_back(B.getInt64(truncation.getTo().significandWidth)); if (!F) { SmallVector ArgTypes; for (auto Arg : Args) @@ -5467,7 +5470,7 @@ llvm::Function *EnzymeLogic::CreateTruncateFunc(RequestContext context, FunctionType *FTy = FunctionType::get(NewTy, params, totrunc->isVarArg()); std::string truncName = std::string("__enzyme_done_truncate_") + truncateModeStr(mode) + - "_func_" + truncation.mangleString() + "_" + totrunc->getName().str(); + "_func_" + truncation.mangleTruncation() + "_" + totrunc->getName().str(); Function *NewF = Function::Create(FTy, totrunc->getLinkage(), truncName, totrunc->getParent()); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index 190b4dabc60..ce329c429c0 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -297,6 +297,7 @@ enum TruncateMode { TruncMemMode, TruncOpMode, TruncOpFullModuleMode }; case TruncOpFullModuleMode: return "op_full_module"; } + llvm_unreachable("Invalid truncation mode"); } struct FloatRepresentation { @@ -366,6 +367,7 @@ struct FloatTruncation { llvm::report_fatal_error( "Float truncation `from` and `to` type must not be the same."); } + FloatRepresentation getTo() { return to;} unsigned getFromTypeWidth() { return from.getTypeWidth(); } unsigned getToTypeWidth() { return to.getTypeWidth(); } llvm::Type *getFromType(llvm::LLVMContext &ctx) { @@ -388,9 +390,12 @@ struct FloatTruncation { bool operator<(const FloatTruncation &other) const { return std::tuple(from, to) < std::tuple(other.from, other.to); } - std::string mangleString() const { + std::string mangleTruncation() const { return from.to_string() + "to" + to.to_string(); } + std::string mangleFrom() const { + return from.to_string(); + } }; class EnzymeLogic {