diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ba61beddc531..757f411fef89 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -10117,14 +10117,20 @@ class AdjointGenerator dchoice = diffe(&call, Builder2); } - auto gradient_setter = cast( - cast( - call.getMetadata("enzyme_gradient_setter")->getOperand(0).get()) - ->getValue()); +#if LLVM_VERSION_MAJOR >= 10 + if (call.hasMetadata("enzyme_gradient_setter")) { +#else + if (call.getMetadata("enzyme_gradient_setter")) { +#endif + auto gradient_setter = cast( + cast( + call.getMetadata("enzyme_gradient_setter")->getOperand(0).get()) + ->getValue()); - TraceUtils::InsertChoiceGradient( - Builder2, gradient_setter->getFunctionType(), gradient_setter, - daddress, dchoice, dtrace); + TraceUtils::InsertChoiceGradient( + Builder2, gradient_setter->getFunctionType(), gradient_setter, + daddress, dchoice, dtrace); + } return; } diff --git a/enzyme/Enzyme/CApi.cpp b/enzyme/Enzyme/CApi.cpp index 5226bf7a8c87..a9df932db252 100644 --- a/enzyme/Enzyme/CApi.cpp +++ b/enzyme/Enzyme/CApi.cpp @@ -202,18 +202,16 @@ EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M) { } EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface( - LLVMContextRef C, LLVMValueRef sampleFunction, - LLVMValueRef getTraceFunction, LLVMValueRef getChoiceFunction, - LLVMValueRef insertCallFunction, LLVMValueRef insertChoiceFunction, - LLVMValueRef insertArgumentFunction, LLVMValueRef insertReturnFunction, - LLVMValueRef insertFunctionFunction, + LLVMContextRef C, LLVMValueRef getTraceFunction, + LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, + LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, + LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, LLVMValueRef insertChoiceGradientFunction, LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, LLVMValueRef hasChoiceFunction) { return (EnzymeTraceInterfaceRef)(new StaticTraceInterface( - *unwrap(C), cast(unwrap(sampleFunction)), - cast(unwrap(getTraceFunction)), + *unwrap(C), cast(unwrap(getTraceFunction)), cast(unwrap(getChoiceFunction)), cast(unwrap(insertCallFunction)), cast(unwrap(insertChoiceFunction)), @@ -622,13 +620,11 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( forceAnonymousTape, width, AtomicAdd)); } -LLVMValueRef CreateTrace(EnzymeLogicRef Logic, LLVMValueRef totrace, - LLVMValueRef *generative_functions, - size_t generative_functions_size, - const char *active_random_variables[], - size_t active_random_variables_size, - CProbProgMode mode, uint8_t autodiff, - EnzymeTraceInterfaceRef interface) { +LLVMValueRef CreateTrace( + EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef sample_function, + LLVMValueRef *generative_functions, size_t generative_functions_size, + const char *active_random_variables[], size_t active_random_variables_size, + CProbProgMode mode, uint8_t autodiff, EnzymeTraceInterfaceRef interface) { SmallPtrSet GenerativeFunctions; for (size_t i = 0; i < generative_functions_size; i++) { @@ -641,9 +637,9 @@ LLVMValueRef CreateTrace(EnzymeLogicRef Logic, LLVMValueRef totrace, } return wrap(eunwrap(Logic).CreateTrace( - cast(unwrap(totrace)), GenerativeFunctions, - ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff, - eunwrap(interface))); + cast(unwrap(totrace)), cast(unwrap(sample_function)), + GenerativeFunctions, ActiveRandomVariables, (ProbProgMode)mode, + (bool)autodiff, eunwrap(interface))); } LLVMValueRef diff --git a/enzyme/Enzyme/CApi.h b/enzyme/Enzyme/CApi.h index c6ce34434d36..e23c79e8cfe0 100644 --- a/enzyme/Enzyme/CApi.h +++ b/enzyme/Enzyme/CApi.h @@ -158,13 +158,11 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal( size_t uncacheable_args_size, uint8_t forceAnonymousTape, unsigned width, uint8_t AtomicAdd); -LLVMValueRef CreateTrace(EnzymeLogicRef Logic, LLVMValueRef totrace, - LLVMValueRef *generative_functions, - size_t generative_functions_size, - const char *active_random_variables[], - size_t active_random_variables_size, - CProbProgMode mode, uint8_t autodiff, - EnzymeTraceInterfaceRef interface); +LLVMValueRef CreateTrace( + EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef sample_function, + LLVMValueRef *generative_functions, size_t generative_functions_size, + const char *active_random_variables[], size_t active_random_variables_size, + CProbProgMode mode, uint8_t autodiff, EnzymeTraceInterfaceRef interface); typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/, CTypeTreeRef * /*args*/, @@ -180,11 +178,10 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef); EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M); EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface( - LLVMContextRef C, LLVMValueRef sampleFunction, - LLVMValueRef getTraceFunction, LLVMValueRef getChoiceFunction, - LLVMValueRef insertCallFunction, LLVMValueRef insertChoiceFunction, - LLVMValueRef insertArgumentFunction, LLVMValueRef insertReturnFunction, - LLVMValueRef insertFunctionFunction, + LLVMContextRef C, LLVMValueRef getTraceFunction, + LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction, + LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction, + LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction, LLVMValueRef insertChoiceGradientFunction, LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction, LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction, diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 7aa58aa84e7c..634d350fed81 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -1796,15 +1796,28 @@ class EnzymeBase { // Interface bool has_dynamic_interface = dynamic_interface != nullptr; - TraceInterface *interface; + bool needs_interface = + mode == ProbProgMode::Trace || mode == ProbProgMode::Condition; + TraceInterface *interface = nullptr; if (has_dynamic_interface) { interface = new DynamicTraceInterface(dynamic_interface, CI->getFunction()); - } else { + } else if (needs_interface) { interface = new StaticTraceInterface(F->getParent()); } - bool autodiff = dtrace; + // Find sample function + Function *sampleFunction = nullptr; + for (auto &&interface_func : F->getParent()->functions()) { + if (interface_func.getName().contains("__enzyme_sample")) { + assert(interface_func.getFunctionType()->getNumParams() >= 3); + sampleFunction = &interface_func; + } + } + + assert(sampleFunction); + + bool autodiff = dtrace || dlikelihood; IRBuilder<> AllocaBuilder(CI->getParent()->getFirstNonPHI()); if (!likelihood) { @@ -1812,8 +1825,8 @@ class EnzymeBase { nullptr, "likelihood"); Builder.CreateStore(ConstantFP::getNullValue(Builder.getDoubleTy()), likelihood); - args.push_back(likelihood); } + args.push_back(likelihood); if (autodiff && !dlikelihood) { dlikelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(), @@ -1836,15 +1849,17 @@ class EnzymeBase { constants.push_back(DIFFE_TYPE::CONSTANT); } - args.push_back(trace); - dargs.push_back(trace); - constants.push_back(DIFFE_TYPE::CONSTANT); + if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { + args.push_back(trace); + dargs.push_back(trace); + constants.push_back(DIFFE_TYPE::CONSTANT); + } // Determine generative functions SmallPtrSet generativeFunctions; SetVector> workList; - workList.insert(interface->getSampleFunction()); - generativeFunctions.insert(interface->getSampleFunction()); + workList.insert(sampleFunction); + generativeFunctions.insert(sampleFunction); while (!workList.empty()) { auto todo = *workList.begin(); @@ -1871,9 +1886,9 @@ class EnzymeBase { #endif } - auto newFunc = - Logic.CreateTrace(F, generativeFunctions, opt->ActiveRandomVariables, - mode, autodiff, interface); + auto newFunc = Logic.CreateTrace(F, sampleFunction, generativeFunctions, + opt->ActiveRandomVariables, mode, autodiff, + interface); if (!autodiff) { auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args); @@ -2324,6 +2339,10 @@ class EnzymeBase { } else if (Fn->getName().contains("__enzyme_batch")) { enableEnzyme = true; batch = true; + } else if (Fn->getName().contains("__enzyme_likelihood")) { + enableEnzyme = true; + probProgMode = ProbProgMode::Likelihood; + probProg = true; } else if (Fn->getName().contains("__enzyme_trace")) { enableEnzyme = true; probProgMode = ProbProgMode::Trace; @@ -2645,8 +2664,8 @@ class EnzymeBase { for (auto &&Inst : BB) { if (auto CI = dyn_cast(&Inst)) { Function *enzyme_sample = CI->getCalledFunction(); - if (enzyme_sample && enzyme_sample->getName().contains( - TraceInterface::sampleFunctionName)) { + if (enzyme_sample && + enzyme_sample->getName().contains("__enzyme_sample")) { if (CI->getNumOperands() < 3) { EmitFailure( "IllegalNumberOfArguments", CI->getDebugLoc(), CI, diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index d73fd5a02c9a..e7990ffa2185 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -4988,7 +4988,7 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width, }; Function * -EnzymeLogic::CreateTrace(Function *totrace, +EnzymeLogic::CreateTrace(Function *totrace, Function *sampleFunction, SmallPtrSetImpl &GenerativeFunctions, StringSet<> &ActiveRandomVariables, ProbProgMode mode, bool autodiff, TraceInterface *interface) { @@ -4998,8 +4998,8 @@ EnzymeLogic::CreateTrace(Function *totrace, } ValueToValueMapTy originalToNewFn; - TraceUtils *tutils = - TraceUtils::FromClone(mode, interface, totrace, originalToNewFn); + TraceUtils *tutils = TraceUtils::FromClone(mode, sampleFunction, interface, + totrace, originalToNewFn); TraceGenerator *tracer = new TraceGenerator(*this, tutils, autodiff, originalToNewFn, GenerativeFunctions, ActiveRandomVariables); diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index e447b9673f41..67b05898e797 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -469,7 +469,7 @@ class EnzymeLogic { BATCH_TYPE ret_type); llvm::Function * - CreateTrace(llvm::Function *totrace, + CreateTrace(llvm::Function *totrace, llvm::Function *sampleFunction, llvm::SmallPtrSetImpl &GenerativeFunctions, llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode, bool autodiff, TraceInterface *interface); diff --git a/enzyme/Enzyme/TraceGenerator.cpp b/enzyme/Enzyme/TraceGenerator.cpp index 7c5506d4936b..15b182b63223 100644 --- a/enzyme/Enzyme/TraceGenerator.cpp +++ b/enzyme/Enzyme/TraceGenerator.cpp @@ -59,6 +59,9 @@ TraceGenerator::TraceGenerator( }; void TraceGenerator::visitFunction(Function &F) { + if (mode == ProbProgMode::Likelihood) + return; + auto fn = tutils->newFunc; auto entry = fn->getEntryBlock().getFirstNonPHIOrDbgOrLifetime(); @@ -124,6 +127,7 @@ void TraceGenerator::handleSampleCall(CallInst &call, CallInst *new_call) { Function *samplefn = GetFunctionFromValue(new_call->getArgOperand(0)); Function *likelihoodfn = GetFunctionFromValue(new_call->getArgOperand(1)); + Value *address = new_call->getArgOperand(2); IRBuilder<> Builder(new_call); @@ -136,15 +140,21 @@ void TraceGenerator::handleSampleCall(CallInst &call, CallInst *new_call) { OutlineBuilder.CreateRet(choice); }; - std::string mode_str = - mode == ProbProgMode::Condition ? "condition" : "sample"; + const char *mode_str; + switch (mode) { + case ProbProgMode::Likelihood: + case ProbProgMode::Trace: + mode_str = "sample"; + break; + case ProbProgMode::Condition: + mode_str = "condition"; + break; + } auto sample_call = tutils->CreateOutlinedFunction( - Builder, OutlinedSample, - tutils->getTraceInterface()->insertChoiceTy()->getParamType(2), Args, - false, mode_str + "_" + samplefn->getName()); + Builder, OutlinedSample, samplefn->getReturnType(), Args, false, + Twine(mode_str) + "_" + samplefn->getName()); - Value *address = Args[0]; StringRef const_address; bool is_address_const = getConstantStringInfo(address, const_address); bool is_random_var_active = @@ -166,7 +176,8 @@ void TraceGenerator::handleSampleCall(CallInst &call, CallInst *new_call) { sample_call->addAttribute(AttributeList::FunctionIndex, activity_attribute); #endif - if (autodiff) { + if (autodiff && + (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition)) { auto gradient_setter = ValueAsMetadata::get(tutils->interface->insertChoiceGradient(Builder)); auto gradient_setter_node = @@ -195,35 +206,37 @@ void TraceGenerator::handleSampleCall(CallInst &call, CallInst *new_call) { // create outlined trace function - Value *trace_args[] = {new_call->getArgOperand(2), score, sample_call}; + if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { + Value *trace_args[] = {address, score, sample_call}; - auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder, - TraceUtils *OutlineTutils, - ArrayRef Arguments) { - OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1], - Arguments[2]); - OutlineBuilder.CreateRetVoid(); - }; + auto OutlinedTrace = [](IRBuilder<> &OutlineBuilder, + TraceUtils *OutlineTutils, + ArrayRef Arguments) { + OutlineTutils->InsertChoice(OutlineBuilder, Arguments[0], Arguments[1], + Arguments[2]); + OutlineBuilder.CreateRetVoid(); + }; - auto trace_call = tutils->CreateOutlinedFunction( - Builder, OutlinedTrace, Builder.getVoidTy(), trace_args, false, - "outline_insert_choice"); + auto trace_call = tutils->CreateOutlinedFunction( + Builder, OutlinedTrace, Builder.getVoidTy(), trace_args, false, + "outline_insert_choice"); #if LLVM_VERSION_MAJOR >= 14 - trace_call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_inactive")); - trace_call->addAttributeAtIndex( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_notypeanalysis")); + trace_call->addAttributeAtIndex( + AttributeList::FunctionIndex, + Attribute::get(call.getContext(), "enzyme_inactive")); + trace_call->addAttributeAtIndex( + AttributeList::FunctionIndex, + Attribute::get(call.getContext(), "enzyme_notypeanalysis")); #else - trace_call->addAttribute( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_inactive")); - trace_call->addAttribute( - AttributeList::FunctionIndex, - Attribute::get(call.getContext(), "enzyme_notypeanalysis")); + trace_call->addAttribute( + AttributeList::FunctionIndex, + Attribute::get(call.getContext(), "enzyme_inactive")); + trace_call->addAttribute( + AttributeList::FunctionIndex, + Attribute::get(call.getContext(), "enzyme_notypeanalysis")); #endif + } sample_call->takeName(new_call); new_call->replaceAllUsesWith(sample_call); @@ -232,36 +245,49 @@ void TraceGenerator::handleSampleCall(CallInst &call, CallInst *new_call) { void TraceGenerator::handleArbitraryCall(CallInst &call, CallInst *new_call) { IRBuilder<> Builder(new_call); - auto str = call.getName() + "." + call.getCalledFunction()->getName(); - auto address = Builder.CreateGlobalStringPtr(str.str()); SmallVector args; for (auto it = new_call->arg_begin(); it != new_call->arg_end(); it++) { args.push_back(*it); } - args.push_back(tutils->getLikelihood()); - Function *called = getFunctionFromCall(&call); assert(called); - Function *samplefn = - Logic.CreateTrace(called, generativeFunctions, activeRandomVariables, - mode, autodiff, tutils->interface); + Function *samplefn = Logic.CreateTrace( + called, tutils->sampleFunction, generativeFunctions, + activeRandomVariables, mode, autodiff, tutils->interface); - auto trace = tutils->CreateTrace(Builder); - - Instruction *tracecall; + Instruction *replacement; switch (mode) { + case ProbProgMode::Likelihood: { + SmallVector args_and_likelihood = SmallVector(args); + args_and_likelihood.push_back(tutils->getLikelihood()); + replacement = + Builder.CreateCall(samplefn->getFunctionType(), samplefn, + args_and_likelihood, "eval." + called->getName()); + break; + } case ProbProgMode::Trace: { + auto trace = tutils->CreateTrace(Builder); + auto address = Builder.CreateGlobalStringPtr( + (call.getName() + "." + called->getName()).str()); + SmallVector args_and_trace = SmallVector(args); + args_and_trace.push_back(tutils->getLikelihood()); args_and_trace.push_back(trace); - tracecall = + replacement = Builder.CreateCall(samplefn->getFunctionType(), samplefn, args_and_trace, "trace." + called->getName()); + + tutils->InsertCall(Builder, address, trace); break; } case ProbProgMode::Condition: { + auto trace = tutils->CreateTrace(Builder); + auto address = Builder.CreateGlobalStringPtr( + (call.getName() + "." + called->getName()).str()); + Instruction *hasCall = tutils->HasCall(Builder, address, "has.call." + call.getName()); Instruction *ThenTerm, *ElseTerm; @@ -277,6 +303,7 @@ void TraceGenerator::handleArbitraryCall(CallInst &call, CallInst *new_call) { SmallVector args_and_cond = SmallVector(args); auto observations = tutils->GetTrace(Builder, address, called->getName() + ".subtrace"); + args_and_cond.push_back(tutils->getLikelihood()); args_and_cond.push_back(observations); args_and_cond.push_back(trace); ThenTracecall = @@ -291,6 +318,7 @@ void TraceGenerator::handleArbitraryCall(CallInst &call, CallInst *new_call) { SmallVector args_and_null = SmallVector(args); auto observations = ConstantPointerNull::get(cast( tutils->getTraceInterface()->newTraceTy()->getReturnType())); + args_and_null.push_back(tutils->getLikelihood()); args_and_null.push_back(observations); args_and_null.push_back(trace); ElseTracecall = @@ -303,14 +331,15 @@ void TraceGenerator::handleArbitraryCall(CallInst &call, CallInst *new_call) { 2, call.getName()); phi->addIncoming(ThenTracecall, ThenTerm->getParent()); phi->addIncoming(ElseTracecall, ElseTerm->getParent()); - tracecall = phi; + replacement = phi; + + tutils->InsertCall(Builder, address, trace); + break; } } - tutils->InsertCall(Builder, address, trace); - - tracecall->takeName(new_call); - new_call->replaceAllUsesWith(tracecall); + replacement->takeName(new_call); + new_call->replaceAllUsesWith(replacement); new_call->eraseFromParent(); } @@ -321,8 +350,7 @@ void TraceGenerator::visitCallInst(CallInst &call) { CallInst *new_call = dyn_cast(originalToNewFn[&call]); - if (call.getCalledFunction() == - tutils->getTraceInterface()->getSampleFunction()) { + if (tutils->isSampleCall(&call)) { handleSampleCall(call, new_call); } else { handleArbitraryCall(call, new_call); diff --git a/enzyme/Enzyme/TraceInterface.cpp b/enzyme/Enzyme/TraceInterface.cpp index 02c98de5acb2..4d4f65eeb1d9 100644 --- a/enzyme/Enzyme/TraceInterface.cpp +++ b/enzyme/Enzyme/TraceInterface.cpp @@ -187,9 +187,6 @@ StaticTraceInterface::StaticTraceInterface(Module *M) } else if (F.getName().contains("__enzyme_has_choice")) { assert(F.getFunctionType() == hasChoiceTy()); hasChoiceFunction = &F; - } else if (F.getName().contains(sampleFunctionName)) { - assert(F.getFunctionType()->getNumParams() >= 3); - sampleFunction = &F; } } @@ -209,7 +206,6 @@ StaticTraceInterface::StaticTraceInterface(Module *M) assert(hasCallFunction); assert(hasChoiceFunction); - assert(sampleFunction); newTraceFunction->addFnAttr("enzyme_notypeanalysis"); freeTraceFunction->addFnAttr("enzyme_notypeanalysis"); @@ -224,7 +220,6 @@ StaticTraceInterface::StaticTraceInterface(Module *M) insertArgumentGradientFunction->addFnAttr("enzyme_notypeanalysis"); hasCallFunction->addFnAttr("enzyme_notypeanalysis"); hasChoiceFunction->addFnAttr("enzyme_notypeanalysis"); - sampleFunction->addFnAttr("enzyme_notypeanalysis"); newTraceFunction->addFnAttr("enzyme_inactive"); freeTraceFunction->addFnAttr("enzyme_inactive"); @@ -239,7 +234,6 @@ StaticTraceInterface::StaticTraceInterface(Module *M) insertArgumentGradientFunction->addFnAttr("enzyme_inactive"); hasCallFunction->addFnAttr("enzyme_inactive"); hasChoiceFunction->addFnAttr("enzyme_inactive"); - sampleFunction->addFnAttr("enzyme_inactive"); newTraceFunction->addFnAttr(Attribute::NoFree); getTraceFunction->addFnAttr(Attribute::NoFree); @@ -253,20 +247,18 @@ StaticTraceInterface::StaticTraceInterface(Module *M) insertArgumentGradientFunction->addFnAttr(Attribute::NoFree); hasCallFunction->addFnAttr(Attribute::NoFree); hasChoiceFunction->addFnAttr(Attribute::NoFree); - sampleFunction->addFnAttr(Attribute::NoFree); } StaticTraceInterface::StaticTraceInterface( - LLVMContext &C, Function *sampleFunction, Function *getTraceFunction, - Function *getChoiceFunction, Function *insertCallFunction, - Function *insertChoiceFunction, Function *insertArgumentFunction, - Function *insertReturnFunction, Function *insertFunctionFunction, - Function *insertChoiceGradientFunction, + LLVMContext &C, Function *getTraceFunction, Function *getChoiceFunction, + Function *insertCallFunction, Function *insertChoiceFunction, + Function *insertArgumentFunction, Function *insertReturnFunction, + Function *insertFunctionFunction, Function *insertChoiceGradientFunction, Function *insertArgumentGradientFunction, Function *newTraceFunction, Function *freeTraceFunction, Function *hasCallFunction, Function *hasChoiceFunction) - : TraceInterface(C), sampleFunction(sampleFunction), - getTraceFunction(getTraceFunction), getChoiceFunction(getChoiceFunction), + : TraceInterface(C), getTraceFunction(getTraceFunction), + getChoiceFunction(getChoiceFunction), insertCallFunction(insertCallFunction), insertChoiceFunction(insertChoiceFunction), insertArgumentFunction(insertArgumentFunction), @@ -277,9 +269,6 @@ StaticTraceInterface::StaticTraceInterface( newTraceFunction(newTraceFunction), freeTraceFunction(freeTraceFunction), hasCallFunction(hasCallFunction), hasChoiceFunction(hasChoiceFunction){}; -// implemented by enzyme -Function *StaticTraceInterface::getSampleFunction() { return sampleFunction; } - // user implemented Value *StaticTraceInterface::getTrace(IRBuilder<> &Builder) { return getTraceFunction; @@ -324,14 +313,6 @@ Value *StaticTraceInterface::hasChoice(IRBuilder<> &Builder) { DynamicTraceInterface::DynamicTraceInterface(Value *dynamicInterface, Function *F) : TraceInterface(F->getContext()) { - for (auto &&interface_func : F->getParent()->functions()) { - if (interface_func.getName().contains(TraceInterface::sampleFunctionName)) { - assert(interface_func.getFunctionType()->getNumParams() >= 3); - sampleFunction = &interface_func; - } - } - - assert(sampleFunction); assert(dynamicInterface); auto &M = *F->getParent(); @@ -382,7 +363,6 @@ DynamicTraceInterface::DynamicTraceInterface(Value *dynamicInterface, assert(hasCallFunction); assert(hasChoiceFunction); - assert(sampleFunction); } Function *DynamicTraceInterface::MaterializeInterfaceFunction( @@ -418,8 +398,6 @@ Function *DynamicTraceInterface::MaterializeInterfaceFunction( return F; } -Function *DynamicTraceInterface::getSampleFunction() { return sampleFunction; } - // user implemented Value *DynamicTraceInterface::getTrace(IRBuilder<> &Builder) { return getTraceFunction; diff --git a/enzyme/Enzyme/TraceInterface.h b/enzyme/Enzyme/TraceInterface.h index d26a008bb809..eeca7d1abc35 100644 --- a/enzyme/Enzyme/TraceInterface.h +++ b/enzyme/Enzyme/TraceInterface.h @@ -42,10 +42,6 @@ class TraceInterface { virtual ~TraceInterface() = default; public: - // implemented by enzyme - virtual llvm::Function *getSampleFunction() = 0; - static constexpr const char sampleFunctionName[] = "__enzyme_sample"; - // user implemented virtual llvm::Value *getTrace(llvm::IRBuilder<> &Builder) = 0; virtual llvm::Value *getChoice(llvm::IRBuilder<> &Builder) = 0; @@ -103,8 +99,6 @@ class TraceInterface { class StaticTraceInterface final : public TraceInterface { private: - llvm::Function *sampleFunction = nullptr; - // user implemented llvm::Function *getTraceFunction = nullptr; llvm::Function *getChoiceFunction = nullptr; llvm::Function *insertCallFunction = nullptr; @@ -122,24 +116,23 @@ class StaticTraceInterface final : public TraceInterface { public: StaticTraceInterface(llvm::Module *M); - StaticTraceInterface( - llvm::LLVMContext &C, llvm::Function *sampleFunction, - llvm::Function *getTraceFunction, llvm::Function *getChoiceFunction, - llvm::Function *insertCallFunction, llvm::Function *insertChoiceFunction, - llvm::Function *insertArgumentFunction, - llvm::Function *insertReturnFunction, - llvm::Function *insertFunctionFunction, - llvm::Function *insertChoiceGradientFunction, - llvm::Function *insertArgumentGradientFunction, - llvm::Function *newTraceFunction, llvm::Function *freeTraceFunction, - llvm::Function *hasCallFunction, llvm::Function *hasChoiceFunction); + StaticTraceInterface(llvm::LLVMContext &C, llvm::Function *getTraceFunction, + llvm::Function *getChoiceFunction, + llvm::Function *insertCallFunction, + llvm::Function *insertChoiceFunction, + llvm::Function *insertArgumentFunction, + llvm::Function *insertReturnFunction, + llvm::Function *insertFunctionFunction, + llvm::Function *insertChoiceGradientFunction, + llvm::Function *insertArgumentGradientFunction, + llvm::Function *newTraceFunction, + llvm::Function *freeTraceFunction, + llvm::Function *hasCallFunction, + llvm::Function *hasChoiceFunction); ~StaticTraceInterface() = default; public: - // implemented by enzyme - llvm::Function *getSampleFunction(); - // user implemented llvm::Value *getTrace(llvm::IRBuilder<> &Builder); llvm::Value *getChoice(llvm::IRBuilder<> &Builder); @@ -157,9 +150,6 @@ class StaticTraceInterface final : public TraceInterface { }; class DynamicTraceInterface final : public TraceInterface { -private: - llvm::Function *sampleFunction = nullptr; - private: llvm::Function *getTraceFunction; llvm::Function *getChoiceFunction; @@ -188,9 +178,6 @@ class DynamicTraceInterface final : public TraceInterface { const llvm::Twine &Name = ""); public: - // implemented by enzyme - llvm::Function *getSampleFunction(); - // user implemented llvm::Value *getTrace(llvm::IRBuilder<> &Builder); llvm::Value *getChoice(llvm::IRBuilder<> &Builder); diff --git a/enzyme/Enzyme/TraceUtils.cpp b/enzyme/Enzyme/TraceUtils.cpp index 0bf1c43dda89..e09eb8cdef4f 100644 --- a/enzyme/Enzyme/TraceUtils.cpp +++ b/enzyme/Enzyme/TraceUtils.cpp @@ -44,22 +44,19 @@ using namespace llvm; -TraceUtils::TraceUtils(ProbProgMode mode, Function *newFunc, Argument *trace, +TraceUtils::TraceUtils(ProbProgMode mode, Function *sampleFunction, + Function *newFunc, Argument *trace, Argument *observations, Argument *likelihood, TraceInterface *interface) : trace(trace), observations(observations), likelihood(likelihood), - interface(interface), mode(mode), newFunc(newFunc){}; + sampleFunction(sampleFunction), interface(interface), mode(mode), + newFunc(newFunc){}; -TraceUtils *TraceUtils::FromClone(ProbProgMode mode, TraceInterface *interface, - Function *oldFunc, +TraceUtils *TraceUtils::FromClone(ProbProgMode mode, Function *sampleFunction, + TraceInterface *interface, Function *oldFunc, ValueToValueMapTy &originalToNewFn) { - assert(interface); - auto &Context = oldFunc->getContext(); - FunctionType *orig_FTy = oldFunc->getFunctionType(); - Type *traceType = TraceInterface::getTraceTy(Context)->getReturnType(); - SmallVector params; for (unsigned i = 0; i < orig_FTy->getNumParams(); ++i) { @@ -69,20 +66,34 @@ TraceUtils *TraceUtils::FromClone(ProbProgMode mode, TraceInterface *interface, Type *likelihood_acc_type = PointerType::getDoublePtrTy(Context); params.push_back(likelihood_acc_type); - if (mode == ProbProgMode::Condition) - params.push_back(traceType); + if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { + Type *traceType = interface->getTraceTy()->getReturnType(); + + if (mode == ProbProgMode::Condition) + params.push_back(traceType); - params.push_back(traceType); + params.push_back(traceType); + } Type *RetTy = oldFunc->getReturnType(); FunctionType *FTy = FunctionType::get(RetTy, params, oldFunc->isVarArg()); - std::string mode_str = - (mode == ProbProgMode::Condition ? "condition" : "trace"); + const char *mode_str; + switch (mode) { + case ProbProgMode::Likelihood: + mode_str = "likelihood"; + break; + case ProbProgMode::Trace: + mode_str = "trace"; + break; + case ProbProgMode::Condition: + mode_str = "condition"; + break; + } Function *newFunc = Function::Create( FTy, Function::LinkageTypes::InternalLinkage, - mode_str + "_" + oldFunc->getName(), oldFunc->getParent()); + Twine(mode_str) + "_" + oldFunc->getName(), oldFunc->getParent()); auto DestArg = newFunc->arg_begin(); auto SrcArg = oldFunc->arg_begin(); @@ -111,11 +122,14 @@ TraceUtils *TraceUtils::FromClone(ProbProgMode mode, TraceInterface *interface, Argument *observations = nullptr; Argument *likelihood = nullptr; - auto arg = newFunc->arg_end() - 1; + auto arg = newFunc->arg_end(); - trace = arg; - arg->setName("trace"); - arg->addAttr(Attribute::get(Context, TraceParameterAttribute)); + if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { + arg -= 1; + trace = arg; + arg->setName("trace"); + arg->addAttr(Attribute::get(Context, TraceParameterAttribute)); + } if (mode == ProbProgMode::Condition) { arg -= 1; @@ -129,8 +143,8 @@ TraceUtils *TraceUtils::FromClone(ProbProgMode mode, TraceInterface *interface, arg->setName("likelihood"); arg->addAttr(Attribute::get(Context, LikelihoodParameterAttribute)); - return new TraceUtils(mode, newFunc, trace, observations, likelihood, - interface); + return new TraceUtils(mode, sampleFunction, newFunc, trace, observations, + likelihood, interface); }; TraceUtils::~TraceUtils() = default; @@ -390,6 +404,7 @@ Instruction *TraceUtils::SampleOrCondition(IRBuilder<> &Builder, auto parent_fn = Builder.GetInsertBlock()->getParent(); switch (mode) { + case ProbProgMode::Likelihood: case ProbProgMode::Trace: { auto sample_call = Builder.CreateCall(sample_fn->getFunctionType(), sample_fn, sample_args); @@ -449,8 +464,11 @@ CallInst *TraceUtils::CreateOutlinedFunction( Tys.push_back(observations->getType()); } - Vals.push_back(trace); - Tys.push_back(trace->getType()); + if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) { + Vals.push_back(trace); + Tys.push_back(trace->getType()); + } + FunctionType *FTy = FunctionType::get(RetTy, Tys, false); Function *F = Function::Create(FTy, Function::LinkageTypes::InternalLinkage, Name, M); @@ -472,12 +490,19 @@ CallInst *TraceUtils::CreateOutlinedFunction( if (mode == ProbProgMode::Condition) observations_arg = idx++; - Argument *trace_arg = idx++; + Argument *trace_arg = nullptr; + if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) + trace_arg = idx++; - TraceUtils OutlineTutils = TraceUtils(mode, F, trace_arg, observations_arg, - likelihood_arg, interface); + TraceUtils OutlineTutils = + TraceUtils(mode, sampleFunction, F, trace_arg, observations_arg, + likelihood_arg, interface); IRBuilder<> OutlineBuilder(Entry); Outlined(OutlineBuilder, &OutlineTutils, Rets); return Builder.CreateCall(FTy, F, Vals); } + +bool TraceUtils::isSampleCall(CallInst *call) { + return call->getCalledFunction() == sampleFunction; +} diff --git a/enzyme/Enzyme/TraceUtils.h b/enzyme/Enzyme/TraceUtils.h index 8e1788932802..8791ad1060c3 100644 --- a/enzyme/Enzyme/TraceUtils.h +++ b/enzyme/Enzyme/TraceUtils.h @@ -44,13 +44,14 @@ class TraceUtils { private: llvm::Value *trace; - llvm::Value *observations = nullptr; - llvm::Value *likelihood = nullptr; + llvm::Value *observations; + llvm::Value *likelihood; public: TraceInterface *interface; ProbProgMode mode; llvm::Function *newFunc; + llvm::Function *sampleFunction; constexpr static const char TraceParameterAttribute[] = "enzyme_trace"; constexpr static const char ObservationsParameterAttribute[] = @@ -59,13 +60,14 @@ class TraceUtils { "enzyme_likelihood"; public: - TraceUtils(ProbProgMode mode, llvm::Function *newFunc, llvm::Argument *trace, + TraceUtils(ProbProgMode mode, llvm::Function *sampleFunction, + llvm::Function *newFunc, llvm::Argument *trace, llvm::Argument *observations, llvm::Argument *likelihood, TraceInterface *interface); static TraceUtils * - FromClone(ProbProgMode mode, TraceInterface *interface, - llvm::Function *oldFunc, + FromClone(ProbProgMode mode, llvm::Function *sampleFunction, + TraceInterface *interface, llvm::Function *oldFunc, llvm::ValueMap &originalToNewFn); @@ -141,5 +143,7 @@ class TraceUtils { Outlined, llvm::Type *RetTy, llvm::ArrayRef Arguments, bool needsLikelihood = true, const llvm::Twine &Name = ""); + + bool isSampleCall(llvm::CallInst *call); }; #endif /* TraceUtils_h */ diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 5b4dcaa413f6..4a8aa43f90fe 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -321,8 +321,9 @@ enum class DerivativeMode { }; enum class ProbProgMode { - Trace = 0, - Condition = 1, + Likelihood = 0, + Trace = 1, + Condition = 2, }; /// Classification of value as an original program diff --git a/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll b/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll index 71a47661a977..e92e42bbd2f4 100644 --- a/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll +++ b/enzyme/test/Enzyme/ProbProg/condition-dynamic.ll @@ -139,12 +139,12 @@ entry: ; CHECK-NEXT: %new_trace.i = load i8* ()*, i8* ()** @new_trace_ptr ; CHECK-NEXT: %17 = call i8* %new_trace.i() ; CHECK-NEXT: %has_call.i = load i1 (i8*, i8*)*, i1 (i8*, i8*)** @has_call_ptr -; CHECK-NEXT: %18 = call i1 %has_call.i(i8* %observations, i8* getelementptr inbounds ([21 x i8], [21 x i8]* @2, i32 0, i32 0)) +; CHECK-NEXT: %18 = call i1 %has_call.i(i8* %observations, i8* getelementptr inbounds ([21 x i8], [21 x i8]* @6, i32 0, i32 0)) ; CHECK-NEXT: br i1 %18, label %condition.call2.with.trace, label %condition.call2.without.trace ; CHECK: condition.call2.with.trace: ; preds = %condition_normal.2.exit ; CHECK-NEXT: %get_trace.i = load i8* (i8*, i8*)*, i8* (i8*, i8*)** @get_trace_ptr -; CHECK-NEXT: %19 = call i8* %get_trace.i(i8* %observations, i8* getelementptr inbounds ([21 x i8], [21 x i8]* @2, i32 0, i32 0)) +; CHECK-NEXT: %19 = call i8* %get_trace.i(i8* %observations, i8* getelementptr inbounds ([21 x i8], [21 x i8]* @6, i32 0, i32 0)) ; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %5, double %12, double* %data, i32 %n, double* %likelihood, i8* %19, i8* %17) ; CHECK-NEXT: br label %entry.cntd @@ -155,7 +155,7 @@ entry: ; CHECK: entry.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace ; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ] ; CHECK-NEXT: %insert_call.i = load void (i8*, i8*, i8*)*, void (i8*, i8*, i8*)** @insert_call_ptr -; CHECK-NEXT: call void %insert_call.i(i8* %trace, i8* getelementptr inbounds ([21 x i8], [21 x i8]* @2, i32 0, i32 0), i8* %17) +; CHECK-NEXT: call void %insert_call.i(i8* %trace, i8* getelementptr inbounds ([21 x i8], [21 x i8]* @6, i32 0, i32 0), i8* %17) ; CHECK-NEXT: %20 = bitcast double %call2 to i64 ; CHECK-NEXT: %21 = inttoptr i64 %20 to i8* ; CHECK-NEXT: %insert_return.i = load void (i8*, i8*, i64)*, void (i8*, i8*, i64)** @insert_return_ptr @@ -170,15 +170,15 @@ entry: ; CHECK-NEXT: call void %insert_function.i(i8* %trace, i8* bitcast (double (double, double, double*, i32, double*, i8*, i8*)* @condition_calculate_loss to i8*)) ; CHECK-NEXT: %0 = bitcast double %m to i64 ; CHECK-NEXT: %1 = inttoptr i64 %0 to i8* -; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @3, i32 0, i32 0), i8* %1, i64 8) +; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @2, i32 0, i32 0), i8* %1, i64 8) ; CHECK-NEXT: %2 = bitcast double %b to i64 ; CHECK-NEXT: %3 = inttoptr i64 %2 to i8* -; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @4, i32 0, i32 0), i8* %3, i64 8) +; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @3, i32 0, i32 0), i8* %3, i64 8) ; CHECK-NEXT: %4 = bitcast double* %data to i8* -; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([5 x i8], [5 x i8]* @5, i32 0, i32 0), i8* %4, i64 0) +; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([5 x i8], [5 x i8]* @4, i32 0, i32 0), i8* %4, i64 0) ; CHECK-NEXT: %5 = zext i32 %n to i64 ; CHECK-NEXT: %6 = inttoptr i64 %5 to i8* -; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @6, i32 0, i32 0), i8* %6, i64 4) +; CHECK-NEXT: call void @insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @5, i32 0, i32 0), i8* %6, i64 4) ; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0 ; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup diff --git a/enzyme/test/Enzyme/ProbProg/condition-static.ll b/enzyme/test/Enzyme/ProbProg/condition-static.ll index c0bb60ee6591..d3292575eefe 100644 --- a/enzyme/test/Enzyme/ProbProg/condition-static.ll +++ b/enzyme/test/Enzyme/ProbProg/condition-static.ll @@ -143,11 +143,11 @@ entry: ; CHECK-NEXT: %16 = inttoptr i64 %15 to i8* ; CHECK-NEXT: call void @__enzyme_insert_choice(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @.str.2, i64 0, i64 0), double %likelihood.call1, i8* %16, i64 8) ; CHECK-NEXT: %trace2 = call i8* @__enzyme_newtrace() -; CHECK-NEXT: %has.call.call2 = call i1 @__enzyme_has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @2, i32 0, i32 0)) +; CHECK-NEXT: %has.call.call2 = call i1 @__enzyme_has_call(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @6, i32 0, i32 0)) ; CHECK-NEXT: br i1 %has.call.call2, label %condition.call2.with.trace, label %condition.call2.without.trace ; CHECK: condition.call2.with.trace: ; preds = %condition_normal.2.exit -; CHECK-NEXT: %calculate_loss.subtrace = call i8* @__enzyme_get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @2, i32 0, i32 0)) +; CHECK-NEXT: %calculate_loss.subtrace = call i8* @__enzyme_get_trace(i8* %observations, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @6, i32 0, i32 0)) ; CHECK-NEXT: %condition.calculate_loss = call double @condition_calculate_loss(double %5, double %12, double* %data, i32 %n, double* %likelihood, i8* %calculate_loss.subtrace, i8* %trace2) ; CHECK-NEXT: br label %entry.cntd @@ -157,7 +157,7 @@ entry: ; CHECK: entry.cntd: ; preds = %condition.call2.without.trace, %condition.call2.with.trace ; CHECK-NEXT: %call2 = phi double [ %condition.calculate_loss, %condition.call2.with.trace ], [ %trace.calculate_loss, %condition.call2.without.trace ] -; CHECK-NEXT: call void @__enzyme_insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @2, i32 0, i32 0), i8* %trace2) +; CHECK-NEXT: call void @__enzyme_insert_call(i8* %trace, i8* nocapture readonly getelementptr inbounds ([21 x i8], [21 x i8]* @6, i32 0, i32 0), i8* %trace2) ; CHECK-NEXT: %17 = bitcast double %call2 to i64 ; CHECK-NEXT: %18 = inttoptr i64 %17 to i8* ; CHECK-NEXT: call void @__enzyme_insert_return(i8* %trace, i8* %18, i64 8) @@ -170,15 +170,15 @@ entry: ; CHECK-NEXT: call void @__enzyme_insert_function(i8* %trace, i8* bitcast (double (double, double, double*, i32, double*, i8*, i8*)* @condition_calculate_loss to i8*)) ; CHECK-NEXT: %0 = bitcast double %m to i64 ; CHECK-NEXT: %1 = inttoptr i64 %0 to i8* -; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @3, i32 0, i32 0), i8* %1, i64 8) +; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @2, i32 0, i32 0), i8* %1, i64 8) ; CHECK-NEXT: %2 = bitcast double %b to i64 ; CHECK-NEXT: %3 = inttoptr i64 %2 to i8* -; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @4, i32 0, i32 0), i8* %3, i64 8) +; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @3, i32 0, i32 0), i8* %3, i64 8) ; CHECK-NEXT: %4 = bitcast double* %data to i8* -; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([5 x i8], [5 x i8]* @5, i32 0, i32 0), i8* %4, i64 0) +; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([5 x i8], [5 x i8]* @4, i32 0, i32 0), i8* %4, i64 0) ; CHECK-NEXT: %5 = zext i32 %n to i64 ; CHECK-NEXT: %6 = inttoptr i64 %5 to i8* -; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @6, i32 0, i32 0), i8* %6, i64 4) +; CHECK-NEXT: call void @__enzyme_insert_argument(i8* %trace, i8* nocapture readonly getelementptr inbounds ([2 x i8], [2 x i8]* @5, i32 0, i32 0), i8* %6, i64 4) ; CHECK-NEXT: %cmp19 = icmp sgt i32 %n, 0 ; CHECK-NEXT: br i1 %cmp19, label %for.body.preheader, label %for.cond.cleanup diff --git a/enzyme/test/Enzyme/ProbProg/likelihood-only.ll b/enzyme/test/Enzyme/ProbProg/likelihood-only.ll new file mode 100644 index 000000000000..e24c94e1b927 --- /dev/null +++ b/enzyme/test/Enzyme/ProbProg/likelihood-only.ll @@ -0,0 +1,75 @@ +; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -instsimplify -S | FileCheck %s; fi +; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,instsimplify)" -enzyme-preopt=false -S | FileCheck %s + +declare double @normal(double, double) +declare double @exp(double) +declare double @log(double) + +define double @normal_logpdf(double %mean, double %var, double %x) { + %i = fdiv double 1.000000e+00, %var + %i3 = fmul double %i, 0x40040D931FF62705 + %i4 = fdiv double %mean, %var + %i5 = fsub double %x, %i4 + %i6 = fmul double %i5, %i5 + %i7 = fmul double %i6, -5.000000e-01 + %i8 = tail call double @exp(double %i7) + %i9 = fmul double %i3, %i8 + %i10 = tail call double @log(double %i9) + ret double %i10 +} + +@.str = private constant [3 x i8] c"mu\00" +@.str.1 = private constant [2 x i8] c"x\00" + +@enzyme_duplikelihood = global i32 0 + +declare double @__enzyme_sample(double (double, double)*, double (double, double, double)*, i8*, double, double) +declare void @__enzyme_likelihood(void ()*, i32, double*, double*) + +define void @test() { +entry: + %mu = call double @__enzyme_sample(double (double, double)* @normal, double (double, double, double)* @normal_logpdf, i8* getelementptr inbounds ([3 x i8], [3 x i8]* @.str, i64 0, i64 0), double 0.0, double 1.0) + %x = call double @__enzyme_sample(double (double, double)* @normal, double (double, double, double)* @normal_logpdf, i8* getelementptr inbounds ([2 x i8], [2 x i8]* @.str.1, i64 0, i64 0), double %mu, double 1.0) + ret void +} + +define double @generate() { +entry: + %0 = load i32, i32* @enzyme_duplikelihood + %likelihood = alloca double + %dlikelihood = alloca double + store double 1.0, double* %dlikelihood + tail call void @__enzyme_likelihood(void ()* @test, i32 %0, double* %likelihood, double* %dlikelihood) + %res = load double, double* %likelihood + ret double %res +} + +; CHECK: define internal void @diffelikelihood_test(double* "enzyme_likelihood" %likelihood, double* %"likelihood'") +; CHECK-NEXT: entry: +; CHECK-NEXT: %0 = call double @normal(double 0.000000e+00, double 1.000000e+00) +; CHECK-NEXT: %likelihood.mu = call fast double @augmented_normal_logpdf.2(double 0.000000e+00, double 1.000000e+00, double %0) +; CHECK-NEXT: %log_prob_sum = load double, double* %likelihood +; CHECK-NEXT: %1 = fadd double %log_prob_sum, %likelihood.mu +; CHECK-NEXT: store double %1, double* %likelihood +; CHECK-NEXT: %2 = call double @normal(double %0, double 1.000000e+00) +; CHECK-NEXT: %likelihood.x = call fast double @augmented_normal_logpdf(double %0, double 1.000000e+00, double %2) +; CHECK-NEXT: %log_prob_sum1 = load double, double* %likelihood +; CHECK-NEXT: %3 = fadd double %log_prob_sum1, %likelihood.x +; CHECK-NEXT: store double %3, double* %likelihood +; CHECK-NEXT: br label %invertentry + +; CHECK: invertentry: ; preds = %entry +; CHECK-NEXT: %4 = load double, double* %"likelihood'" +; CHECK-NEXT: store double 0.000000e+00, double* %"likelihood'" +; CHECK-NEXT: %5 = load double, double* %"likelihood'" +; CHECK-NEXT: %6 = fadd fast double %5, %4 +; CHECK-NEXT: store double %6, double* %"likelihood'" +; CHECK-NEXT: %7 = call { double, double } @diffenormal_logpdf(double %0, double 1.000000e+00, double %2, double %4) +; CHECK-NEXT: %8 = load double, double* %"likelihood'" +; CHECK-NEXT: store double 0.000000e+00, double* %"likelihood'" +; CHECK-NEXT: %9 = load double, double* %"likelihood'" +; CHECK-NEXT: %10 = fadd fast double %9, %8 +; CHECK-NEXT: store double %10, double* %"likelihood'" +; CHECK-NEXT: %11 = call { double } @diffenormal_logpdf.3(double 0.000000e+00, double 1.000000e+00, double %0, double %8) +; CHECK-NEXT: ret void +; CHECK-NEXT: } \ No newline at end of file