Skip to content

Commit

Permalink
Runtime Activity: move to parameter (#2076)
Browse files Browse the repository at this point in the history
* Runtime Activity: move to parameter

* fmt
  • Loading branch information
wsmoses authored Sep 12, 2024
1 parent a0f88da commit 8aa216e
Show file tree
Hide file tree
Showing 19 changed files with 288 additions and 225 deletions.
57 changes: 31 additions & 26 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}
} else {
Value *newip = gutils->invertPointerM(&I, BuilderZ);
if (EnzymeRuntimeActivityCheck && vd[{-1}].isFloat()) {
if (gutils->runtimeActivity && vd[{-1}].isFloat()) {
// TODO handle mask
assert(!mask);

Expand Down Expand Up @@ -997,7 +997,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
}

Value *diff = nullptr;
if (!EnzymeRuntimeActivityCheck && constantval) {
if (!gutils->runtimeActivity && constantval) {
if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) {
if (!isa<UndefValue>(orig_val) &&
!isa<ConstantPointerNull>(orig_val)) {
Expand Down Expand Up @@ -1227,7 +1227,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
Value *valueop = nullptr;

if (constantval) {
if (!EnzymeRuntimeActivityCheck) {
if (!gutils->runtimeActivity) {
if (dt.isPossiblePointer() && vd[{-1, -1}] != BaseType::Integer) {
if (!isa<UndefValue>(orig_val) &&
!isa<ConstantPointerNull>(orig_val)) {
Expand Down Expand Up @@ -3290,7 +3290,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
bool errorIfNoType = true;
if ((Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeError) &&
(!gutils->isConstantValue(orig_src) && !EnzymeRuntimeActivityCheck)) {
(!gutils->isConstantValue(orig_src) && !gutils->runtimeActivity)) {
errorIfNoType = false;
}

Expand Down Expand Up @@ -3432,7 +3432,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
Legal = true;
}
if (!gutils->isConstantValue(orig_src) &&
!EnzymeRuntimeActivityCheck) {
!gutils->runtimeActivity) {
Legal = true;
}
}
Expand Down Expand Up @@ -3513,7 +3513,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
Type::getInt8Ty(ddst->getContext()), ddst, start);
}
CallInst *call;
// TODO add EnzymeRuntimeActivity (correctness)
// TODO add gutils->runtimeActivity (correctness)
if (dt.isFloat() && gutils->isConstantValue(orig_src)) {
call = BuilderZ.CreateMemSet(
ddst, ConstantInt::get(Type::getInt8Ty(ddst->getContext()), 0),
Expand Down Expand Up @@ -4105,7 +4105,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
subretType, argsInverted, TR.analyzer->interprocedural,
/*return is used*/ false,
/*shadowReturnUsed*/ false, nextTypeInfo, overwritten_args, false,
gutils->getWidth(),
gutils->runtimeActivity, gutils->getWidth(),
/*AtomicAdd*/ true,
/*OpenMP*/ true);
if (Mode == DerivativeMode::ReverseModePrimal) {
Expand Down Expand Up @@ -4315,21 +4315,23 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {

newcalled = gutils->Logic.CreatePrimalAndGradient(
RequestContext(&call, &Builder2),
(ReverseCacheKey){.todiff = cast<Function>(called),
.retType = subretType,
.constant_args = argsInverted,
.overwritten_args = overwritten_args,
.returnUsed = false,
.shadowReturnUsed = false,
.mode = DerivativeMode::ReverseModeGradient,
.width = gutils->getWidth(),
.freeMemory = true,
.AtomicAdd = true,
.additionalType =
tape ? PointerType::getUnqual(tape->getType())
: nullptr,
.forceAnonymousTape = false,
.typeInfo = nextTypeInfo},
(ReverseCacheKey){
.todiff = cast<Function>(called),
.retType = subretType,
.constant_args = argsInverted,
.overwritten_args = overwritten_args,
.returnUsed = false,
.shadowReturnUsed = false,
.mode = DerivativeMode::ReverseModeGradient,
.width = gutils->getWidth(),
.freeMemory = true,
.AtomicAdd = true,
.additionalType =
tape ? PointerType::getUnqual(tape->getType()) : nullptr,
.forceAnonymousTape = false,
.typeInfo = nextTypeInfo,
.runtimeActivity = gutils->runtimeActivity,
},
TR.analyzer->interprocedural, subdata,
/*omp*/ true);

Expand Down Expand Up @@ -4825,8 +4827,9 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer->interprocedural,
/*returnValue*/ subretused, Mode,
((DiffeGradientUtils *)gutils)->FreeMemory, gutils->getWidth(),
tape ? tape->getType() : nullptr, nextTypeInfo, overwritten_args,
((DiffeGradientUtils *)gutils)->FreeMemory, gutils->runtimeActivity,
gutils->getWidth(), tape ? tape->getType() : nullptr, nextTypeInfo,
overwritten_args,
/*augmented*/ subdata);
FT = cast<Function>(newcalled)->getFunctionType();
} else {
Expand Down Expand Up @@ -5214,7 +5217,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
RequestContext(&call, &BuilderZ), cast<Function>(called),
subretType, argsInverted, TR.analyzer->interprocedural,
/*return is used*/ subretused, shadowReturnUsed, nextTypeInfo,
overwritten_args, false, gutils->getWidth(), gutils->AtomicAdd);
overwritten_args, false, gutils->runtimeActivity,
gutils->getWidth(), gutils->AtomicAdd);
if (Mode == DerivativeMode::ReverseModePrimal) {
assert(augmentedReturn);
auto subaugmentations =
Expand Down Expand Up @@ -5645,7 +5649,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
.AtomicAdd = gutils->AtomicAdd,
.additionalType = tape ? tape->getType() : nullptr,
.forceAnonymousTape = false,
.typeInfo = nextTypeInfo},
.typeInfo = nextTypeInfo,
.runtimeActivity = gutils->runtimeActivity},
TR.analyzer->interprocedural, subdata);
if (!newcalled)
return;
Expand Down
59 changes: 33 additions & 26 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,10 @@ void EnzymeRegisterDiffUseCallHandler(char *Name,
};
}

uint8_t EnzymeGradientUtilsGetRuntimeActivity(GradientUtils *gutils) {
return gutils->runtimeActivity;
}

uint64_t EnzymeGradientUtilsGetWidth(GradientUtils *gutils) {
return gutils->getWidth();
}
Expand Down Expand Up @@ -586,9 +590,10 @@ LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
CDerivativeMode mode, uint8_t freeMemory, unsigned width,
LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented) {
CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity,
unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
uint8_t *_overwritten_args, size_t overwritten_args_size,
EnzymeAugmentedReturnPtr augmented) {
SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
constant_args_size);
Expand All @@ -601,16 +606,18 @@ LLVMValueRef EnzymeCreateForwardDiff(
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
unwrap(request_ip)),
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory, width,
unwrap(additionalArg), eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
overwritten_args, eunwrap(augmented)));
eunwrap(TA), returnValue, (DerivativeMode)mode, freeMemory,
runtimeActivity, width, unwrap(additionalArg),
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), overwritten_args,
eunwrap(augmented)));
}
LLVMValueRef EnzymeCreatePrimalAndGradient(
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
uint8_t dretUsed, CDerivativeMode mode, unsigned width, uint8_t freeMemory,
LLVMTypeRef additionalArg, uint8_t forceAnonymousTape, CFnTypeInfo typeInfo,
uint8_t dretUsed, CDerivativeMode mode, uint8_t runtimeActivity,
unsigned width, uint8_t freeMemory, LLVMTypeRef additionalArg,
uint8_t forceAnonymousTape, CFnTypeInfo typeInfo,
uint8_t *_overwritten_args, size_t overwritten_args_size,
EnzymeAugmentedReturnPtr augmented, uint8_t AtomicAdd) {
std::vector<DIFFE_TYPE> nconstant_args((DIFFE_TYPE *)constant_args,
Expand All @@ -624,30 +631,30 @@ LLVMValueRef EnzymeCreatePrimalAndGradient(
return wrap(eunwrap(Logic).CreatePrimalAndGradient(
RequestContext(cast_or_null<Instruction>(unwrap(request_req)),
unwrap(request_ip)),
(ReverseCacheKey){
.todiff = cast<Function>(unwrap(todiff)),
.retType = (DIFFE_TYPE)retType,
.constant_args = nconstant_args,
.overwritten_args = overwritten_args,
.returnUsed = (bool)returnValue,
.shadowReturnUsed = (bool)dretUsed,
.mode = (DerivativeMode)mode,
.width = width,
.freeMemory = (bool)freeMemory,
.AtomicAdd = (bool)AtomicAdd,
.additionalType = unwrap(additionalArg),
.forceAnonymousTape = (bool)forceAnonymousTape,
.typeInfo = eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
},
(ReverseCacheKey){.todiff = cast<Function>(unwrap(todiff)),
.retType = (DIFFE_TYPE)retType,
.constant_args = nconstant_args,
.overwritten_args = overwritten_args,
.returnUsed = (bool)returnValue,
.shadowReturnUsed = (bool)dretUsed,
.mode = (DerivativeMode)mode,
.width = width,
.freeMemory = (bool)freeMemory,
.AtomicAdd = (bool)AtomicAdd,
.additionalType = unwrap(additionalArg),
.forceAnonymousTape = (bool)forceAnonymousTape,
.typeInfo =
eunwrap(typeInfo, cast<Function>(unwrap(todiff))),
.runtimeActivity = (bool)runtimeActivity},
eunwrap(TA), eunwrap(augmented)));
}
EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnUsed,
uint8_t shadowReturnUsed, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
size_t overwritten_args_size, uint8_t forceAnonymousTape, unsigned width,
uint8_t AtomicAdd) {
size_t overwritten_args_size, uint8_t forceAnonymousTape,
uint8_t runtimeActivity, unsigned width, uint8_t AtomicAdd) {

SmallVector<DIFFE_TYPE, 4> nconstant_args((DIFFE_TYPE *)constant_args,
(DIFFE_TYPE *)constant_args +
Expand All @@ -663,7 +670,7 @@ EnzymeAugmentedReturnPtr EnzymeCreateAugmentedPrimal(
cast<Function>(unwrap(todiff)), (DIFFE_TYPE)retType, nconstant_args,
eunwrap(TA), returnUsed, shadowReturnUsed,
eunwrap(typeInfo, cast<Function>(unwrap(todiff))), overwritten_args,
forceAnonymousTape, width, AtomicAdd));
forceAnonymousTape, runtimeActivity, width, AtomicAdd));
}

LLVMValueRef EnzymeCreateBatch(EnzymeLogicRef Logic, LLVMValueRef request_req,
Expand Down
7 changes: 4 additions & 3 deletions enzyme/Enzyme/CApi.h
Original file line number Diff line number Diff line change
Expand Up @@ -207,9 +207,10 @@ LLVMValueRef EnzymeCreateForwardDiff(
EnzymeLogicRef Logic, LLVMValueRef request_req, LLVMBuilderRef request_ip,
LLVMValueRef todiff, CDIFFE_TYPE retType, CDIFFE_TYPE *constant_args,
size_t constant_args_size, EnzymeTypeAnalysisRef TA, uint8_t returnValue,
CDerivativeMode mode, uint8_t freeMemory, unsigned width,
LLVMTypeRef additionalArg, CFnTypeInfo typeInfo, uint8_t *_overwritten_args,
size_t overwritten_args_size, EnzymeAugmentedReturnPtr augmented);
CDerivativeMode mode, uint8_t freeMemory, uint8_t runtimeActivity,
unsigned width, LLVMTypeRef additionalArg, CFnTypeInfo typeInfo,
uint8_t *_overwritten_args, size_t overwritten_args_size,
EnzymeAugmentedReturnPtr augmented);

#ifdef __cplusplus
}
Expand Down
20 changes: 10 additions & 10 deletions enzyme/Enzyme/DiffeGradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,11 +59,11 @@ DiffeGradientUtils::DiffeGradientUtils(
const SmallPtrSetImpl<Value *> &returnvals_, DIFFE_TYPE ActiveReturn,
bool shadowReturnUsed, ArrayRef<DIFFE_TYPE> constant_values,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
DerivativeMode mode, unsigned width, bool omp)
DerivativeMode mode, bool runtimeActivity, unsigned width, bool omp)
: GradientUtils(Logic, newFunc_, oldFunc_, TLI, TA, TR, invertedPointers_,
constantvalues_, returnvals_, ActiveReturn,
shadowReturnUsed, constant_values, origToNew_, mode, width,
omp) {
shadowReturnUsed, constant_values, origToNew_, mode,
runtimeActivity, width, omp) {
if (oldFunc_->empty())
return;
assert(reverseBlocks.size() == 0);
Expand All @@ -84,11 +84,11 @@ DiffeGradientUtils::DiffeGradientUtils(
}

DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
EnzymeLogic &Logic, DerivativeMode mode, unsigned width, Function *todiff,
TargetLibraryInfo &TLI, TypeAnalysis &TA, FnTypeInfo &oldTypeInfo,
DIFFE_TYPE retType, bool shadowReturn, bool diffeReturnArg,
ArrayRef<DIFFE_TYPE> constant_args, ReturnType returnValue,
Type *additionalArg, bool omp) {
EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity,
unsigned width, Function *todiff, TargetLibraryInfo &TLI, TypeAnalysis &TA,
FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType, bool shadowReturn,
bool diffeReturnArg, ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, Type *additionalArg, bool omp) {
Function *oldFunc = todiff;
assert(mode == DerivativeMode::ReverseModeGradient ||
mode == DerivativeMode::ReverseModeCombined ||
Expand Down Expand Up @@ -162,7 +162,7 @@ DiffeGradientUtils *DiffeGradientUtils::CreateFromClone(
auto res = new DiffeGradientUtils(
Logic, newFunc, oldFunc, TLI, TA, TR, invertedPointers, constant_values,
nonconstant_values, retType, shadowReturn, constant_args, originalToNew,
mode, width, omp);
mode, runtimeActivity, width, omp);

return res;
}
Expand Down Expand Up @@ -1175,7 +1175,7 @@ void DiffeGradientUtils::addToInvertedPtrDiffe(
// are distinct statically as they are allocas/mallocs, if not compare
// the pointers and conditionally execute.
if ((!isa<AllocaInst>(basePtr) && !isAllocationCall(basePtr, TLI)) &&
EnzymeRuntimeActivityCheck && !merge) {
runtimeActivity && !merge) {
Value *shadow = Builder2.CreateICmpNE(
lookupM(getNewFromOriginal(origptr), Builder2),
lookupM(invertPointerM(origptr, Builder2), Builder2));
Expand Down
9 changes: 5 additions & 4 deletions enzyme/Enzyme/DiffeGradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,17 +69,18 @@ class DiffeGradientUtils final : public GradientUtils {
DIFFE_TYPE ActiveReturn, bool shadowReturnUsed,
llvm::ArrayRef<DIFFE_TYPE> constant_values,
llvm::ValueMap<const llvm::Value *, AssertingReplacingVH> &origToNew_,
DerivativeMode mode, unsigned width, bool omp);
DerivativeMode mode, bool runtimeActivity, unsigned width, bool omp);

public:
/// Whether to free memory in reverse pass or split forward.
bool FreeMemory;
llvm::ValueMap<const llvm::Value *, llvm::TrackingVH<llvm::AllocaInst>>
differentials;
static DiffeGradientUtils *
CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, unsigned width,
llvm::Function *todiff, llvm::TargetLibraryInfo &TLI,
TypeAnalysis &TA, FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
CreateFromClone(EnzymeLogic &Logic, DerivativeMode mode, bool runtimeActivity,
unsigned width, llvm::Function *todiff,
llvm::TargetLibraryInfo &TLI, TypeAnalysis &TA,
FnTypeInfo &oldTypeInfo, DIFFE_TYPE retType,
bool shadowReturnArg, bool diffeReturnArg,
llvm::ArrayRef<DIFFE_TYPE> constant_args,
ReturnType returnValue, llvm::Type *additionalArg, bool omp);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/DifferentialUseAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ bool DifferentialUseAnalysis::is_use_directly_needed_in_reverse(

if (!shadow)
if (auto LI = dyn_cast<LoadInst>(user)) {
if (EnzymeRuntimeActivityCheck) {
if (gutils->runtimeActivity) {
auto vd = TR.query(const_cast<llvm::Instruction *>(user));
if (!vd.isKnown()) {
auto ET = LI->getType();
Expand Down
Loading

0 comments on commit 8aa216e

Please sign in to comment.