diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index 9f1d3cd65dfc..a3abf6991e95 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -4821,7 +4821,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (call.getAttributes().hasParamAttr(i, attr)) { if (gutils->getWidth() == 1) { structAttrs[args.size()].push_back(call.getParamAttr(i, attr)); - } else if (attr == "enzymejl_returnRoots") { + } else if (attr == std::string("enzymejl_returnRoots")) { structAttrs[args.size()].push_back( Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); } @@ -5155,7 +5155,7 @@ class AdjointGenerator : public llvm::InstVisitor { if (gutils->getWidth() == 1) { structAttrs[pre_args.size()].push_back( call.getParamAttr(i, attr)); - } else if (attr == "enzymejl_returnRoots") { + } else if (attr == std::string("enzymejl_returnRoots")) { structAttrs[pre_args.size()].push_back( Attribute::get(call.getContext(), "enzymejl_returnRoots_v")); } diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 76fc71921654..f930d4e1375b 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -3937,14 +3937,40 @@ Function *EnzymeLogic::CreatePrimalAndGradient( hasTape = false; // res.first.push_back(StructType::get(todiff->getContext(), {})); } else { - llvm::errs() << "expected args: ["; + std::string s; + llvm::raw_string_ostream ss(s); + ss << "Bad function type of custom reverse pass for function " + << key.todiff->getName() << " of type " + << *key.todiff->getFunctionType() << "\n"; + ss << " expected gradient function to have argument types ["; + bool seen = false; for (auto a : res.first) { - llvm::errs() << *a << " "; + if (seen) + ss << ", "; + seen = true; + ss << *a; + } + ss << "]\n"; + ss << " Instead found " << foundcalled->getName() << " of type " + << *foundcalled->getFunctionType() << "\n"; + Value *toshow = key.todiff; + if (context.req) { + toshow = context.req; + ss << " at context: " << *context.req; + } else { + ss << *key.todiff << "\n"; + } + if (CustomErrorHandler) { + CustomErrorHandler(ss.str().c_str(), wrap(toshow), + ErrorType::NoDerivative, nullptr, wrap(key.todiff), + wrap(context.ip)); + } else if (context.req) { + EmitFailure("NoDerivative", context.req->getDebugLoc(), context.req, + ss.str()); + } else { + assert(0 && "bad type for custom gradient"); + llvm_unreachable("bad type for custom gradient"); } - llvm::errs() << "]\n"; - llvm::errs() << *foundcalled << "\n"; - assert(0 && "bad type for custom gradient"); - llvm_unreachable("bad type for custom gradient"); } auto st = dyn_cast(foundcalled->getReturnType());