Skip to content

Commit

Permalink
Runtime complex ret (#1957)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jul 1, 2024
1 parent d23b53a commit b53704d
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3183,8 +3183,11 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,

if (retType != DIFFE_TYPE::CONSTANT) {
auto ret = inst->getOperand(0);
if (!ret->getType()->isFPOrFPVectorTy() &&
TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
Type *rt = ret->getType();
while (auto AT = dyn_cast<ArrayType>(rt))
rt = AT->getElementType();
bool floatLike = rt->isFPOrFPVectorTy();
if (!floatLike && TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
if (gutils->isConstantValue(ret)) {
if (!EnzymeRuntimeActivityCheck &&
TR.query(ret)[{-1}].isPossiblePointer()) {
Expand Down Expand Up @@ -3214,8 +3217,6 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
rt = AT->getElementType();
bool floatLike = rt->isFPOrFPVectorTy();

if (auto AT = dyn_cast<ArrayType>(ret->getType()))
floatLike |= AT->getElementType()->isFPOrFPVectorTy();
if (retType == DIFFE_TYPE::CONSTANT) {
toret = gutils->getNewFromOriginal(ret);
} else if (!floatLike &&
Expand All @@ -3237,11 +3238,15 @@ void createTerminator(DiffeGradientUtils *gutils, BasicBlock *oBB,
assert(false && "Invalid return type");
auto ret = inst->getOperand(0);

Type *rt = ret->getType();
while (auto AT = dyn_cast<ArrayType>(rt))
rt = AT->getElementType();
bool floatLike = rt->isFPOrFPVectorTy();

toret =
nBuilder.CreateInsertValue(toret, gutils->getNewFromOriginal(ret), 0);

if (!ret->getType()->isFPOrFPVectorTy() &&
TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
if (!floatLike && TR.getReturnAnalysis().Inner0().isPossiblePointer()) {
toret = nBuilder.CreateInsertValue(
toret,
invertedPtr ? invertedPtr : gutils->invertPointerM(ret, nBuilder), 1);
Expand Down

0 comments on commit b53704d

Please sign in to comment.