Skip to content

Commit

Permalink
Simplify no derivative error handling and add runtime error support (#…
Browse files Browse the repository at this point in the history
…1971)

* Simplify no derivative error handling

* runtime errors
  • Loading branch information
wsmoses authored Jul 8, 2024
1 parent cd39401 commit e24ec43
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 327 deletions.
192 changes: 31 additions & 161 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "cannot handle unknown instruction\n" << inst;
IRBuilder<> Builder2(&inst);
getForwardBuilder(Builder2);
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&inst), ErrorType::NoDerivative,
gutils, nullptr, wrap(&Builder2));
} else {
EmitFailure("NoDerivative", inst.getDebugLoc(), &inst, ss.str());
}
EmitNoDerivativeError(ss.str(), inst, gutils, Builder2);
if (!gutils->isConstantValue(&inst)) {
if (Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeError ||
Expand Down Expand Up @@ -466,14 +461,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
EmitWarning("CannotDeduceType", I, ss.str());
goto known;
}
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
EmitNoTypeError(str, I, gutils, BuilderZ);
known:;
}

Expand Down Expand Up @@ -911,12 +899,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
llvm::raw_string_ostream ss(s);
ss << *I.getParent()->getParent() << "\n" << I << "\n";
ss << " Active atomic inst not yet handled";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoDerivative,
gutils, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("NoDerivative", I.getDebugLoc(), &I, ss.str());
}
EmitNoDerivativeError(ss.str(), I, gutils, BuilderZ);
if (!gutils->isConstantValue(&I)) {
if (Mode == DerivativeMode::ForwardMode ||
Mode == DerivativeMode::ForwardModeError ||
Expand Down Expand Up @@ -1048,14 +1031,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
EmitWarning("CannotDeduceType", I, ss.str());
goto known;
}
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
EmitNoTypeError(str, I, gutils, BuilderZ);
return;
known:;
}
Expand All @@ -1075,12 +1051,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
raw_string_ostream ss(str);
ss << "Cannot deduce single type of store " << I << vd.str()
<< " size: " << storeSize;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
EmitNoTypeError(str, I, gutils, BuilderZ);
return;
}
}
Expand Down Expand Up @@ -1153,12 +1124,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
raw_string_ostream ss(str);
ss << "Cannot deduce type of store " << I << vd.str()
<< " size: " << storeSize;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&I), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
}
EmitNoTypeError(str, I, gutils, BuilderZ);
break;
}

Expand Down Expand Up @@ -1462,16 +1428,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce adding type (cast) of " << I;
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&Builder2));
return;
} else {
ss << "\n";
TR.dump(ss);
EmitFailure("CannotDeduceType", I.getDebugLoc(), &I, ss.str());
return;
}
EmitNoTypeError(str, I, gutils, Builder2);
}
assert(FT);

Expand All @@ -1489,15 +1446,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
llvm::raw_string_ostream ss(s);
ss << *I.getParent()->getParent() << "\n";
ss << "cannot handle above cast " << I << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&I),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&Builder2));
} else {
ss << "\n";
TR.dump(ss);
EmitFailure("CannotHandleCast", I.getDebugLoc(), &I, ss.str());
}
EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
return (llvm::Value *)UndefValue::get(op0->getType());
}
};
Expand Down Expand Up @@ -2029,13 +1978,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
raw_string_ostream ss(str);
ss << "Cannot deduce type of insertvalue ins " << IVI
<< " size: " << size0 << " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
}
EmitNoTypeError(str, IVI, gutils, Builder2);
}
}

Expand Down Expand Up @@ -2114,13 +2057,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce type of insertvalue agg " << IVI
<< " start: " << start << " size: " << size1
<< " TT: " << TT.str();
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&IVI), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&Builder2));
} else {
EmitFailure("CannotDeduceType", IVI.getDebugLoc(), &IVI,
ss.str());
}
EmitNoTypeError(str, IVI, gutils, Builder2);
}
}

Expand Down Expand Up @@ -2622,12 +2559,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " type: " << TR.query(&I).str() << "\n";
}
ss << "cannot handle unknown binary operator: " << BO << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&BO), ErrorType::NoDerivative,
gutils, nullptr, wrap(&Builder2));
} else {
EmitFailure("NoDerivative", BO.getDebugLoc(), &BO, ss.str());
}
EmitNoDerivativeError(ss.str(), BO, gutils, Builder2);
}

done:;
Expand Down Expand Up @@ -2863,18 +2795,11 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
<< " type: " << TR.query(&I).str() << "\n";
}
ss << "cannot handle unknown binary operator: " << BO << "\n";
if (CustomErrorHandler) {
auto rval = unwrap(CustomErrorHandler(ss.str().c_str(), wrap(&BO),
ErrorType::NoDerivative, gutils,
nullptr, wrap(&Builder2)));
if (!rval)
rval = Constant::getNullValue(gutils->getShadowType(BO.getType()));
if (!gutils->isConstantValue(&BO))
setDiffe(&BO, rval, Builder2);
} else {
EmitFailure("NoDerivative", BO.getDebugLoc(), &BO, ss.str());
return;
}
auto rval = EmitNoDerivativeError(ss.str(), BO, gutils, Builder2);
if (!rval)
rval = Constant::getNullValue(gutils->getShadowType(BO.getType()));
if (!gutils->isConstantValue(&BO))
setDiffe(&BO, rval, Builder2);
break;
}
}
Expand Down Expand Up @@ -2924,12 +2849,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "couldn't handle non constant inst in memset to "
"propagate differential to\n"
<< MS;
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&MS), ErrorType::NoDerivative,
gutils, nullptr, wrap(&BuilderZ));
} else {
EmitFailure("NoDerivative", MS.getDebugLoc(), &MS, ss.str());
}
EmitNoDerivativeError(ss.str(), MS, gutils, BuilderZ);
}

if (Mode == DerivativeMode::ForwardMode ||
Expand Down Expand Up @@ -3150,14 +3070,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of memset " << MS;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&MS), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
TR.dump(ss);
EmitFailure("CannotDeduceType", MS.getDebugLoc(), &MS, ss.str());
}
EmitNoTypeError(str, MS, gutils, BuilderZ);
return;
}
known:;
Expand Down Expand Up @@ -3458,15 +3371,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
std::string str;
raw_string_ostream ss(str);
ss << "Cannot deduce type of copy " << MTI;
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&MTI), ErrorType::NoType,
TR.analyzer, nullptr, wrap(&BuilderZ));
} else {
ss << "\n";
ss << *gutils->oldFunc << "\n";
TR.dump(ss);
EmitFailure("CannotDeduceType", MTI.getDebugLoc(), &MTI, ss.str());
}
EmitNoTypeError(str, MTI, gutils, BuilderZ);
vd = TypeTree(BaseType::Integer).Only(0, &MTI);
} else {
vd = TypeTree(BaseType::Pointer).Only(0, &MTI);
Expand Down Expand Up @@ -3810,17 +3715,10 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << *gutils->oldFunc << "\n";
ss << *gutils->newFunc << "\n";
ss << "cannot handle (augmented) unknown intrinsic\n" << I;
if (CustomErrorHandler) {
IRBuilder<> BuilderZ(&I);
getForwardBuilder(BuilderZ);
CustomErrorHandler(ss.str().c_str(), wrap(&I),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&BuilderZ));
return false;
} else {
EmitFailure("NoDerivative", I.getDebugLoc(), &I, ss.str());
return false;
}
IRBuilder<> BuilderZ(&I);
getForwardBuilder(BuilderZ);
EmitNoDerivativeError(ss.str(), I, gutils, BuilderZ);
return false;
}
return false;
}
Expand Down Expand Up @@ -3951,15 +3849,8 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "cannot handle (reverse) unknown intrinsic\n"
<< Intrinsic::getName(ID) << "\n"
<< I;
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&I),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&Builder2));
return false;
} else {
EmitFailure("NoDerivative", I.getDebugLoc(), &I, ss.str());
return false;
}
EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
return false;
}
return false;
}
Expand Down Expand Up @@ -4031,13 +3922,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "cannot handle (forward) unknown intrinsic\n"
<< Intrinsic::getName(ID) << "\n"
<< I;
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&I),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&Builder2));
} else {
EmitFailure("NoDerivative", I.getDebugLoc(), &I, ss.str());
}
EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
if (!gutils->isConstantValue(&I))
setDiffe(&I,
Constant::getNullValue(gutils->getShadowType(I.getType())),
Expand Down Expand Up @@ -5353,14 +5238,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
raw_string_ostream ss(str);
ss << "cannot find shadow for " << *callval
<< " for use as function in " << call;
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&call),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&BuilderZ));
} else {
EmitFailure("NoDerivative", call.getDebugLoc(), &call, ss.str());
return;
}
EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);
}
newcalled = gutils->invertPointerM(callval, BuilderZ);

Expand Down Expand Up @@ -5829,20 +5707,12 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "in Mode: " << to_string(Mode) << "\n";
ss << " orig: " << call << " callval: " << *callval << "\n";
ss << " constant function being called, but active call instruction\n";
if (CustomErrorHandler) {
auto val = unwrap(CustomErrorHandler(ss.str().c_str(), wrap(&call),
ErrorType::NoDerivative, gutils,
nullptr, wrap(&Builder2)));
if (val)
newcalled = val;
else
newcalled =
UndefValue::get(gutils->getShadowType(callval->getType()));
} else {
EmitFailure("NoDerivative", call.getDebugLoc(), &call, ss.str());
auto val = EmitNoDerivativeError(ss.str(), call, gutils, Builder2);
if (val)
newcalled = val;
else
newcalled =
UndefValue::get(gutils->getShadowType(callval->getType()));
}
} else {
newcalled = lookup(gutils->invertPointerM(callval, Builder2), Builder2);
}
Expand Down
24 changes: 4 additions & 20 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1164,18 +1164,10 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
if (!isSum) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << *gutils->oldFunc << "\n";
ss << *gutils->newFunc << "\n";
ss << " call: " << call << "\n";
ss << " unhandled mpi_reduce op: " << *orig_op << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&call),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&BuilderZ));
} else {
EmitFailure("NoDerivative", call.getDebugLoc(), &call, ss.str());
return;
}
EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);
return;
}

Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
Expand Down Expand Up @@ -1413,18 +1405,10 @@ void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
if (!isSum) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << *gutils->oldFunc << "\n";
ss << *gutils->newFunc << "\n";
ss << " call: " << call << "\n";
ss << " unhandled mpi_allreduce op: " << *orig_op << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(&call),
ErrorType::NoDerivative, gutils, nullptr,
wrap(&BuilderZ));
} else {
EmitFailure("NoDerivative", call.getDebugLoc(), &call, ss.str());
return;
}
EmitNoDerivativeError(ss.str(), call, gutils, BuilderZ);
return;
}

Value *shadow_recvbuf = gutils->invertPointerM(orig_recvbuf, Builder2);
Expand Down
Loading

0 comments on commit e24ec43

Please sign in to comment.