Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ProbProg] Likelihood only mode #1303

Merged
merged 3 commits into from
Jun 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10117,14 +10117,20 @@ class AdjointGenerator
dchoice = diffe(&call, Builder2);
}

auto gradient_setter = cast<Function>(
cast<ValueAsMetadata>(
call.getMetadata("enzyme_gradient_setter")->getOperand(0).get())
->getValue());
#if LLVM_VERSION_MAJOR >= 10
if (call.hasMetadata("enzyme_gradient_setter")) {
#else
if (call.getMetadata("enzyme_gradient_setter")) {
#endif
auto gradient_setter = cast<Function>(
cast<ValueAsMetadata>(
call.getMetadata("enzyme_gradient_setter")->getOperand(0).get())
->getValue());

TraceUtils::InsertChoiceGradient(
Builder2, gradient_setter->getFunctionType(), gradient_setter,
daddress, dchoice, dtrace);
TraceUtils::InsertChoiceGradient(
Builder2, gradient_setter->getFunctionType(), gradient_setter,
daddress, dchoice, dtrace);
}

return;
}
Expand Down
30 changes: 13 additions & 17 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -202,18 +202,16 @@ EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M) {
}

EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface(
LLVMContextRef C, LLVMValueRef sampleFunction,
LLVMValueRef getTraceFunction, LLVMValueRef getChoiceFunction,
LLVMValueRef insertCallFunction, LLVMValueRef insertChoiceFunction,
LLVMValueRef insertArgumentFunction, LLVMValueRef insertReturnFunction,
LLVMValueRef insertFunctionFunction,
LLVMContextRef C, LLVMValueRef getTraceFunction,
LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction,
LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction,
LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction,
LLVMValueRef insertChoiceGradientFunction,
LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction,
LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction,
LLVMValueRef hasChoiceFunction) {
return (EnzymeTraceInterfaceRef)(new StaticTraceInterface(
*unwrap(C), cast<Function>(unwrap(sampleFunction)),
cast<Function>(unwrap(getTraceFunction)),
*unwrap(C), cast<Function>(unwrap(getTraceFunction)),
cast<Function>(unwrap(getChoiceFunction)),
cast<Function>(unwrap(insertCallFunction)),
cast<Function>(unwrap(insertChoiceFunction)),
Expand Down Expand Up @@ -622,13 +620,11 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
forceAnonymousTape, width, AtomicAdd));
}

LLVMValueRef CreateTrace(EnzymeLogicRef Logic, LLVMValueRef totrace,
LLVMValueRef *generative_functions,
size_t generative_functions_size,
const char *active_random_variables[],
size_t active_random_variables_size,
CProbProgMode mode, uint8_t autodiff,
EnzymeTraceInterfaceRef interface) {
LLVMValueRef CreateTrace(
EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef sample_function,
LLVMValueRef *generative_functions, size_t generative_functions_size,
const char *active_random_variables[], size_t active_random_variables_size,
CProbProgMode mode, uint8_t autodiff, EnzymeTraceInterfaceRef interface) {

SmallPtrSet<Function *, 4> GenerativeFunctions;
for (size_t i = 0; i < generative_functions_size; i++) {
Expand All @@ -641,9 +637,9 @@ LLVMValueRef CreateTrace(EnzymeLogicRef Logic, LLVMValueRef totrace,
}

return wrap(eunwrap(Logic).CreateTrace(
cast<Function>(unwrap(totrace)), GenerativeFunctions,
ActiveRandomVariables, (ProbProgMode)mode, (bool)autodiff,
eunwrap(interface)));
cast<Function>(unwrap(totrace)), cast<Function>(unwrap(sample_function)),
GenerativeFunctions, ActiveRandomVariables, (ProbProgMode)mode,
(bool)autodiff, eunwrap(interface)));
}

LLVMValueRef
Expand Down
21 changes: 9 additions & 12 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,13 +158,11 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
size_t uncacheable_args_size, uint8_t forceAnonymousTape, unsigned width,
uint8_t AtomicAdd);

LLVMValueRef CreateTrace(EnzymeLogicRef Logic, LLVMValueRef totrace,
LLVMValueRef *generative_functions,
size_t generative_functions_size,
const char *active_random_variables[],
size_t active_random_variables_size,
CProbProgMode mode, uint8_t autodiff,
EnzymeTraceInterfaceRef interface);
LLVMValueRef CreateTrace(
EnzymeLogicRef Logic, LLVMValueRef totrace, LLVMValueRef sample_function,
LLVMValueRef *generative_functions, size_t generative_functions_size,
const char *active_random_variables[], size_t active_random_variables_size,
CProbProgMode mode, uint8_t autodiff, EnzymeTraceInterfaceRef interface);

typedef uint8_t (*CustomRuleType)(int /*direction*/, CTypeTreeRef /*return*/,
CTypeTreeRef * /*args*/,
Expand All @@ -180,11 +178,10 @@ void FreeTypeAnalysis(EnzymeTypeAnalysisRef);

EnzymeTraceInterfaceRef FindEnzymeStaticTraceInterface(LLVMModuleRef M);
EnzymeTraceInterfaceRef CreateEnzymeStaticTraceInterface(
LLVMContextRef C, LLVMValueRef sampleFunction,
LLVMValueRef getTraceFunction, LLVMValueRef getChoiceFunction,
LLVMValueRef insertCallFunction, LLVMValueRef insertChoiceFunction,
LLVMValueRef insertArgumentFunction, LLVMValueRef insertReturnFunction,
LLVMValueRef insertFunctionFunction,
LLVMContextRef C, LLVMValueRef getTraceFunction,
LLVMValueRef getChoiceFunction, LLVMValueRef insertCallFunction,
LLVMValueRef insertChoiceFunction, LLVMValueRef insertArgumentFunction,
LLVMValueRef insertReturnFunction, LLVMValueRef insertFunctionFunction,
LLVMValueRef insertChoiceGradientFunction,
LLVMValueRef insertArgumentGradientFunction, LLVMValueRef newTraceFunction,
LLVMValueRef freeTraceFunction, LLVMValueRef hasCallFunction,
Expand Down
47 changes: 33 additions & 14 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1796,24 +1796,37 @@ class EnzymeBase {

// Interface
bool has_dynamic_interface = dynamic_interface != nullptr;
TraceInterface *interface;
bool needs_interface =
mode == ProbProgMode::Trace || mode == ProbProgMode::Condition;
TraceInterface *interface = nullptr;
if (has_dynamic_interface) {
interface =
new DynamicTraceInterface(dynamic_interface, CI->getFunction());
} else {
} else if (needs_interface) {
interface = new StaticTraceInterface(F->getParent());
}

bool autodiff = dtrace;
// Find sample function
Function *sampleFunction = nullptr;
for (auto &&interface_func : F->getParent()->functions()) {
if (interface_func.getName().contains("__enzyme_sample")) {
assert(interface_func.getFunctionType()->getNumParams() >= 3);
sampleFunction = &interface_func;
}
}

assert(sampleFunction);

bool autodiff = dtrace || dlikelihood;
IRBuilder<> AllocaBuilder(CI->getParent()->getFirstNonPHI());

if (!likelihood) {
likelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
nullptr, "likelihood");
Builder.CreateStore(ConstantFP::getNullValue(Builder.getDoubleTy()),
likelihood);
args.push_back(likelihood);
}
args.push_back(likelihood);

if (autodiff && !dlikelihood) {
dlikelihood = AllocaBuilder.CreateAlloca(AllocaBuilder.getDoubleTy(),
Expand All @@ -1836,15 +1849,17 @@ class EnzymeBase {
constants.push_back(DIFFE_TYPE::CONSTANT);
}

args.push_back(trace);
dargs.push_back(trace);
constants.push_back(DIFFE_TYPE::CONSTANT);
if (mode == ProbProgMode::Trace || mode == ProbProgMode::Condition) {
args.push_back(trace);
dargs.push_back(trace);
constants.push_back(DIFFE_TYPE::CONSTANT);
}

// Determine generative functions
SmallPtrSet<Function *, 4> generativeFunctions;
SetVector<Function *, std::deque<Function *>> workList;
workList.insert(interface->getSampleFunction());
generativeFunctions.insert(interface->getSampleFunction());
workList.insert(sampleFunction);
generativeFunctions.insert(sampleFunction);

while (!workList.empty()) {
auto todo = *workList.begin();
Expand All @@ -1871,9 +1886,9 @@ class EnzymeBase {
#endif
}

auto newFunc =
Logic.CreateTrace(F, generativeFunctions, opt->ActiveRandomVariables,
mode, autodiff, interface);
auto newFunc = Logic.CreateTrace(F, sampleFunction, generativeFunctions,
opt->ActiveRandomVariables, mode, autodiff,
interface);

if (!autodiff) {
auto call = CallInst::Create(newFunc->getFunctionType(), newFunc, args);
Expand Down Expand Up @@ -2324,6 +2339,10 @@ class EnzymeBase {
} else if (Fn->getName().contains("__enzyme_batch")) {
enableEnzyme = true;
batch = true;
} else if (Fn->getName().contains("__enzyme_likelihood")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Likelihood;
probProg = true;
} else if (Fn->getName().contains("__enzyme_trace")) {
enableEnzyme = true;
probProgMode = ProbProgMode::Trace;
Expand Down Expand Up @@ -2645,8 +2664,8 @@ class EnzymeBase {
for (auto &&Inst : BB) {
if (auto CI = dyn_cast<CallInst>(&Inst)) {
Function *enzyme_sample = CI->getCalledFunction();
if (enzyme_sample && enzyme_sample->getName().contains(
TraceInterface::sampleFunctionName)) {
if (enzyme_sample &&
enzyme_sample->getName().contains("__enzyme_sample")) {
if (CI->getNumOperands() < 3) {
EmitFailure(
"IllegalNumberOfArguments", CI->getDebugLoc(), CI,
Expand Down
6 changes: 3 additions & 3 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4988,7 +4988,7 @@ llvm::Function *EnzymeLogic::CreateBatch(Function *tobatch, unsigned width,
};

Function *
EnzymeLogic::CreateTrace(Function *totrace,
EnzymeLogic::CreateTrace(Function *totrace, Function *sampleFunction,
SmallPtrSetImpl<Function *> &GenerativeFunctions,
StringSet<> &ActiveRandomVariables, ProbProgMode mode,
bool autodiff, TraceInterface *interface) {
Expand All @@ -4998,8 +4998,8 @@ EnzymeLogic::CreateTrace(Function *totrace,
}

ValueToValueMapTy originalToNewFn;
TraceUtils *tutils =
TraceUtils::FromClone(mode, interface, totrace, originalToNewFn);
TraceUtils *tutils = TraceUtils::FromClone(mode, sampleFunction, interface,
totrace, originalToNewFn);
TraceGenerator *tracer =
new TraceGenerator(*this, tutils, autodiff, originalToNewFn,
GenerativeFunctions, ActiveRandomVariables);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ class EnzymeLogic {
BATCH_TYPE ret_type);

llvm::Function *
CreateTrace(llvm::Function *totrace,
CreateTrace(llvm::Function *totrace, llvm::Function *sampleFunction,
llvm::SmallPtrSetImpl<llvm::Function *> &GenerativeFunctions,
llvm::StringSet<> &ActiveRandomVariables, ProbProgMode mode,
bool autodiff, TraceInterface *interface);
Expand Down
Loading
Loading