Skip to content

Commit

Permalink
C++ error message for incorrect custom gradient type (#1764)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Feb 27, 2024
1 parent 1588442 commit 3e2de5d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 8 deletions.
4 changes: 2 additions & 2 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -4821,7 +4821,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
if (call.getAttributes().hasParamAttr(i, attr)) {
if (gutils->getWidth() == 1) {
structAttrs[args.size()].push_back(call.getParamAttr(i, attr));
} else if (attr == "enzymejl_returnRoots") {
} else if (attr == std::string("enzymejl_returnRoots")) {
structAttrs[args.size()].push_back(
Attribute::get(call.getContext(), "enzymejl_returnRoots_v"));
}
Expand Down Expand Up @@ -5155,7 +5155,7 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
if (gutils->getWidth() == 1) {
structAttrs[pre_args.size()].push_back(
call.getParamAttr(i, attr));
} else if (attr == "enzymejl_returnRoots") {
} else if (attr == std::string("enzymejl_returnRoots")) {
structAttrs[pre_args.size()].push_back(
Attribute::get(call.getContext(), "enzymejl_returnRoots_v"));
}
Expand Down
38 changes: 32 additions & 6 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3937,14 +3937,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient(
hasTape = false;
// res.first.push_back(StructType::get(todiff->getContext(), {}));
} else {
llvm::errs() << "expected args: [";
std::string s;
llvm::raw_string_ostream ss(s);
ss << "Bad function type of custom reverse pass for function "
<< key.todiff->getName() << " of type "
<< *key.todiff->getFunctionType() << "\n";
ss << " expected gradient function to have argument types [";
bool seen = false;
for (auto a : res.first) {
llvm::errs() << *a << " ";
if (seen)
ss << ", ";
seen = true;
ss << *a;
}
ss << "]\n";
ss << " Instead found " << foundcalled->getName() << " of type "
<< *foundcalled->getFunctionType() << "\n";
Value *toshow = key.todiff;
if (context.req) {
toshow = context.req;
ss << " at context: " << *context.req;
} else {
ss << *key.todiff << "\n";
}
if (CustomErrorHandler) {
CustomErrorHandler(ss.str().c_str(), wrap(toshow),
ErrorType::NoDerivative, nullptr, wrap(key.todiff),
wrap(context.ip));
} else if (context.req) {
EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req,
ss.str());
} else {
assert(0 && "bad type for custom gradient");
llvm_unreachable("bad type for custom gradient");
}
llvm::errs() << "]\n";
llvm::errs() << *foundcalled << "\n";
assert(0 && "bad type for custom gradient");
llvm_unreachable("bad type for custom gradient");
}

auto st = dyn_cast<StructType>(foundcalled->getReturnType());
Expand Down

0 comments on commit 3e2de5d

Please sign in to comment.