Skip to content

Commit

Permalink
Enable get return diffe type to be used in more scenarios (#1300)
Browse files Browse the repository at this point in the history
* Enable get return diffe type to be used in more scenarios

* Fix subretused of gc allocation
  • Loading branch information
wsmoses authored Jun 26, 2023
1 parent 366d3c0 commit bccd9ff
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -9618,7 +9618,7 @@ class AdjointGenerator
funcName == "ijl_alloc_array_3d" || funcName == "ijl_array_copy" ||
funcName == "julia.gc_alloc_obj" || funcName == "jl_gc_alloc_typed" ||
funcName == "ijl_gc_alloc_typed") {
if (unnecessaryValues.count(&call)) {
if (!subretused) {
eraseIfUnused(call, /*erase*/ true, /*check*/ false);
return;
}
Expand Down
4 changes: 2 additions & 2 deletions enzyme/Enzyme/CApi.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -397,8 +397,8 @@ EnzymeGradientUtilsGetReturnDiffeType(GradientUtils *G, LLVMValueRef oval,
uint8_t *needsShadow) {
bool needsPrimalB;
bool needsShadowB;
auto res = (CDIFFE_TYPE)(G->getReturnDiffeType(cast<CallInst>(unwrap(oval)),
&needsPrimalB, &needsShadowB));
auto res = (CDIFFE_TYPE)(G->getReturnDiffeType(unwrap(oval), &needsPrimalB,
&needsShadowB));
if (needsPrimal)
*needsPrimal = needsPrimalB;
if (needsShadow)
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4339,7 +4339,7 @@ GradientUtils *GradientUtils::CreateFromClone(
return res;
}

DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::CallInst *orig,
DIFFE_TYPE GradientUtils::getReturnDiffeType(llvm::Value *orig,
bool *primalReturnUsedP,
bool *shadowReturnUsedP) const {
bool shadowReturnUsed = false;
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/GradientUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ class GradientUtils : public CacheUtility {
public:
DIFFE_TYPE getDiffeType(llvm::Value *v, bool foreignFunction) const;

DIFFE_TYPE getReturnDiffeType(llvm::CallInst *orig, bool *primalReturnUsedP,
DIFFE_TYPE getReturnDiffeType(llvm::Value *orig, bool *primalReturnUsedP,
bool *shadowReturnUsedP) const;

static GradientUtils *
Expand Down

0 comments on commit bccd9ff

Please sign in to comment.