Skip to content

Commit

Permalink
only attribute empty fncs
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Oct 25, 2023
1 parent 875c199 commit 17e6835
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 6 deletions.
8 changes: 4 additions & 4 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,6 @@ Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType,
return F;
}

void attribute_copy(BlasInfo blas, llvm::Function *F);
void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
llvm::ArrayRef<llvm::Value *> args,
llvm::Type *copy_retty,
Expand All @@ -650,7 +649,7 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
FunctionType *FT = FunctionType::get(copy_retty, tys, false);
auto fn = M.getOrInsertFunction(copy_name.str(), FT);
Function *F = cast<Function>(fn.getCallee());
attribute_copy(blas, F);
attributeKnownFunctions(*F);

B.CreateCall(fn, args, bundles);
}
Expand All @@ -666,6 +665,8 @@ void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,

auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false);
auto fn = M.getOrInsertFunction(copy_name.str(), FT);
Function *F = cast<Function>(fn.getCallee());
attributeKnownFunctions(*F);

B.CreateCall(fn, args, bundles);
}
Expand Down Expand Up @@ -859,7 +860,6 @@ void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas,
return;
}

void attribute_dot(BlasInfo blas, llvm::Function *F);
llvm::CallInst *
getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
IntegerType *IT, Type *BlasPT, Type *BlasIT, Type *fpTy,
Expand All @@ -884,7 +884,7 @@ getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT}, false);
Function *FDot =
cast<Function>(M.getOrInsertFunction(dot_name, FDotT).getCallee());
attribute_dot(blas, FDot);
attributeKnownFunctions(*F);

// now add the implementation for the inner_prod call
F->setLinkage(Function::LinkageTypes::InternalLinkage);
Expand Down
8 changes: 6 additions & 2 deletions enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
void emit_attributeBLASCaller(ArrayRef<TGPattern> blasPatterns,
raw_ostream &os) {
os << "void attributeBLAS(BlasInfo blas, llvm::Function *F) { \n";
os << " if (!F->empty())\n";
os << " return;\n";
for (auto &&pattern : blasPatterns) {
auto name = pattern.getName();
os << " if (blas.function == \"" << name << "\") { \n"
Expand All @@ -17,6 +19,8 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
auto name = pattern.getName();
bool lv23 = pattern.isBLASLevel2or3();
os << "void attribute_" << name << "(BlasInfo blas, llvm::Function *F) {\n";
os << " if (!F->empty())\n";
os << " return;\n";
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
Expand Down Expand Up @@ -191,12 +195,12 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) {
os << " auto changed = false;\n";
os << " auto blasMetaData = extractBLAS(name);\n";
os << " #if LLVM_VERSION_MAJOR >= 16\n";
os << " if (blasMetaData.has_value()) {\n";
os << " if (F.empty() && blasMetaData.has_value()) {\n";
os << " attributeBLAS(blasMetaData.value(), &F);\n";
os << " changed = true;\n";
os << " }\n";
os << " #else\n";
os << " if (blasMetaData.hasValue()) {\n";
os << " if (F.empty() && blasMetaData.hasValue()) {\n";
os << " attributeBLAS(blasMetaData.getValue(), &F);\n";
os << " changed = true;\n";
os << " }\n";
Expand Down

0 comments on commit 17e6835

Please sign in to comment.