Skip to content

Commit

Permalink
BLAS: fix blas erasure (#1948)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jun 29, 2024
1 parent cc545bc commit 015a772
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 5 deletions.
1 change: 1 addition & 0 deletions enzyme/Enzyme/GradientUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8226,6 +8226,7 @@ void GradientUtils::eraseFictiousPHIs() {
for (auto pair : phis) {
auto pp = pair.first;
if (pp->getNumUses() != 0) {
assert(0);
if (CustomErrorHandler) {
std::string str;
raw_string_ostream ss(str);
Expand Down
15 changes: 10 additions & 5 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -255,17 +255,22 @@ void emit_free_and_ending(const TGPattern &pattern, raw_ostream &os) {
"Constant::getNullValue(call.getType()));\n"
<< " }\n";

os << " bool shouldErase = true;\n";
os << " if (gutils->knownRecomputeHeuristic.find(&call) !=\n"
<< " gutils->knownRecomputeHeuristic.end()) {\n"
<< " if (!gutils->knownRecomputeHeuristic[&call]) {\n"
<< " auto cv = gutils->cacheForReverse(BuilderZ, newCall,\n"
<< " getIndex(&call, CacheType::Self, BuilderZ));\n"
<< " shouldErase = false;\n"
<< " }\n"
<< " } else if (Mode == DerivativeMode::ReverseModeGradient) { \n"
<< " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n"
<< " } else { \n"
<< " eraseIfUnused(call); \n"
<< " } \n"
<< " }\n"
<< " if (shouldErase) {\n"
<< " if (Mode == DerivativeMode::ReverseModeGradient) { \n"
<< " eraseIfUnused(call, /*erase*/ true, /*check*/ false); \n"
<< " } else { \n"
<< " eraseIfUnused(call); \n"
<< " } \n"
<< " }\n"
<< " return true;\n"
<< "#ifdef __clang__\n"
<< "#pragma clang diagnostic pop\n"
Expand Down

0 comments on commit 015a772

Please sign in to comment.