From 03673004eafec7cb242bd33ac3981b58d965ce5c Mon Sep 17 00:00:00 2001 From: William Moses Date: Sun, 11 Aug 2024 15:04:24 -0700 Subject: [PATCH] Improve no return index error (#2040) --- enzyme/Enzyme/AdjointGenerator.h | 45 ++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 11 deletions(-) diff --git a/enzyme/Enzyme/AdjointGenerator.h b/enzyme/Enzyme/AdjointGenerator.h index ee263f841a6..0c17d2b14b1 100644 --- a/enzyme/Enzyme/AdjointGenerator.h +++ b/enzyme/Enzyme/AdjointGenerator.h @@ -5658,18 +5658,41 @@ class AdjointGenerator : public llvm::InstVisitor { if (Mode == DerivativeMode::ReverseModeCombined || Mode == DerivativeMode::ReverseModePrimal) { - auto drval = *differetIdx; - newip = (drval < 0) - ? augmentcall - : BuilderZ.CreateExtractValue(augmentcall, - {(unsigned)drval}, - call.getName() + "'ac"); - assert(newip->getType() == placeholder->getType()); - placeholder->replaceAllUsesWith(newip); - if (placeholder == &*BuilderZ.GetInsertPoint()) { - BuilderZ.SetInsertPoint(placeholder->getNextNode()); + if (!differetIdx) { + std::string str; + raw_string_ostream ss(str); + ss << "Did not have return index set when differentiating " + "function\n"; + ss << " call" << call << "\n"; + ss << " augmentcall" << *augmentcall << "\n"; + if (CustomErrorHandler) { + CustomErrorHandler(str.c_str(), wrap(&call), + ErrorType::InternalError, nullptr, nullptr, + nullptr); + } else { + EmitFailure("GetIndexError", call.getDebugLoc(), &call, + ss.str()); + } + placeholder->replaceAllUsesWith( + UndefValue::get(placeholder->getType())); + if (placeholder == &*BuilderZ.GetInsertPoint()) { + BuilderZ.SetInsertPoint(placeholder->getNextNode()); + } + gutils->erase(placeholder); + } else { + auto drval = *differetIdx; + newip = (drval < 0) + ? augmentcall + : BuilderZ.CreateExtractValue(augmentcall, + {(unsigned)drval}, + call.getName() + "'ac"); + assert(newip->getType() == placeholder->getType()); + placeholder->replaceAllUsesWith(newip); + if (placeholder == &*BuilderZ.GetInsertPoint()) { + BuilderZ.SetInsertPoint(placeholder->getNextNode()); + } + gutils->erase(placeholder); } - gutils->erase(placeholder); } else { newip = placeholder; }