Skip to content

Commit

Permalink
attribute BLAS functions in PreserveNVVM (#1505)
Browse files Browse the repository at this point in the history
* attribute BLAS functions in PreserveNVVM

* extra Blas header to include instead of inc files

* Revert "extra Blas header to include instead of inc files"

This reverts commit d57adb3.

* attribute BLAS functions which we generate too

* attribute functions earlier

* only attribute empty functions (declarations without a body)
  • Loading branch information
ZuseZ4 authored Oct 25, 2023
1 parent 31f1380 commit 1a7d5d5
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 8 deletions.
18 changes: 16 additions & 2 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,14 +117,16 @@ llvm::cl::opt<bool> EnzymeOMPOpt("enzyme-omp-opt", cl::init(false), cl::Hidden,
#define addAttribute addAttributeAtIndex
#define getAttribute getAttributeAtIndex
#endif
void attributeKnownFunctions(llvm::Function &F) {
bool attributeKnownFunctions(llvm::Function &F) {
bool changed = false;
if (F.getName().contains("__enzyme_float") ||
F.getName().contains("__enzyme_double") ||
F.getName().contains("__enzyme_integer") ||
F.getName().contains("__enzyme_pointer") ||
F.getName().contains("__enzyme_todense") ||
F.getName().contains("__enzyme_iter") ||
F.getName().contains("__enzyme_virtualreverse")) {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyReadsMemory();
F.setOnlyWritesMemory();
Expand All @@ -140,6 +142,7 @@ void attributeKnownFunctions(llvm::Function &F) {
}
}
if (F.getName() == "memcmp") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyAccessesArgMemory();
F.setOnlyReadsMemory();
Expand All @@ -159,13 +162,15 @@ void attributeKnownFunctions(llvm::Function &F) {
}
}

attributeTablegen(F);
changed |= attributeTablegen(F);

if (F.getName() ==
"_ZNSt7__cxx1112basic_stringIcSt11char_traitsIcESaIcEE9_M_createERmm") {
changed = true;
F.addFnAttr(Attribute::NoFree);
}
if (F.getName() == "MPI_Irecv" || F.getName() == "PMPI_Irecv") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyAccessesInaccessibleMemOrArgMem();
#else
Expand All @@ -184,6 +189,7 @@ void attributeKnownFunctions(llvm::Function &F) {
F.addParamAttr(6, Attribute::WriteOnly);
}
if (F.getName() == "MPI_Isend" || F.getName() == "PMPI_Isend") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyAccessesInaccessibleMemOrArgMem();
#else
Expand All @@ -203,6 +209,7 @@ void attributeKnownFunctions(llvm::Function &F) {
}
if (F.getName() == "MPI_Comm_rank" || F.getName() == "PMPI_Comm_rank" ||
F.getName() == "MPI_Comm_size" || F.getName() == "PMPI_Comm_size") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyAccessesInaccessibleMemOrArgMem();
#else
Expand All @@ -224,6 +231,7 @@ void attributeKnownFunctions(llvm::Function &F) {
}
}
if (F.getName() == "MPI_Wait" || F.getName() == "PMPI_Wait") {
changed = true;
F.addFnAttr(Attribute::NoUnwind);
F.addFnAttr(Attribute::NoRecurse);
F.addFnAttr(Attribute::WillReturn);
Expand All @@ -234,6 +242,7 @@ void attributeKnownFunctions(llvm::Function &F) {
F.addParamAttr(1, Attribute::NoCapture);
}
if (F.getName() == "MPI_Waitall" || F.getName() == "PMPI_Waitall") {
changed = true;
F.addFnAttr(Attribute::NoUnwind);
F.addFnAttr(Attribute::NoRecurse);
F.addFnAttr(Attribute::WillReturn);
Expand Down Expand Up @@ -269,6 +278,7 @@ void attributeKnownFunctions(llvm::Function &F) {
}
if (auto GV = dyn_cast<GlobalVariable>(C)) {
if (GV->getName() == "ompi_mpi_cxx_bool") {
changed = true;
CI->addAttribute(
AttributeList::FunctionIndex,
Attribute::get(CI->getContext(), "enzyme_inactive"));
Expand All @@ -282,6 +292,7 @@ void attributeKnownFunctions(llvm::Function &F) {

if (F.getName() == "omp_get_max_threads" ||
F.getName() == "omp_get_thread_num") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyAccessesInaccessibleMemory();
F.setOnlyReadsMemory();
Expand All @@ -292,6 +303,7 @@ void attributeKnownFunctions(llvm::Function &F) {
}
if (F.getName() == "frexp" || F.getName() == "frexpf" ||
F.getName() == "frexpl") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyAccessesArgMemory();
#else
Expand All @@ -301,13 +313,15 @@ void attributeKnownFunctions(llvm::Function &F) {
}
if (F.getName() == "__fd_sincos_1" || F.getName() == "__fd_cos_1" ||
F.getName() == "__mth_i_ipowi") {
changed = true;
#if LLVM_VERSION_MAJOR >= 16
F.setOnlyReadsMemory();
F.setOnlyWritesMemory();
#else
F.addFnAttr(Attribute::ReadNone);
#endif
}
return changed;
}

namespace {
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/PreserveNVVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,12 @@ bool preserveLinkage(bool Begin, Function &F) {

bool preserveNVVM(bool Begin, Function &F) {
bool changed = false;

auto name = getFuncName(&F);
if (Begin) {
changed |= attributeKnownFunctions(F);
}

StringMap<std::pair<std::string, std::string>> Implements;
for (std::string T : {"", "f"}) {
// sincos, sinpi, cospi, sincospi, cyl_bessel_i1
Expand Down
5 changes: 5 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -648,6 +648,8 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,

FunctionType *FT = FunctionType::get(copy_retty, tys, false);
auto fn = M.getOrInsertFunction(copy_name.str(), FT);
Function *F = cast<Function>(fn.getCallee());
attributeKnownFunctions(*F);

B.CreateCall(fn, args, bundles);
}
Expand All @@ -663,6 +665,8 @@ void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,

auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false);
auto fn = M.getOrInsertFunction(copy_name.str(), FT);
Function *F = cast<Function>(fn.getCallee());
attributeKnownFunctions(*F);

B.CreateCall(fn, args, bundles);
}
Expand Down Expand Up @@ -880,6 +884,7 @@ getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT}, false);
Function *FDot =
cast<Function>(M.getOrInsertFunction(dot_name, FDotT).getCallee());
attributeKnownFunctions(*F);

// now add the implementation for the inner_prod call
F->setLinkage(Function::LinkageTypes::InternalLinkage);
Expand Down
2 changes: 1 addition & 1 deletion enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1496,7 +1496,7 @@ static inline bool isNoCapture(const llvm::CallInst *call, size_t idx) {
return false;
}

void attributeKnownFunctions(llvm::Function &F);
bool attributeKnownFunctions(llvm::Function &F);

llvm::Constant *getUndefinedValueForType(llvm::Type *T, bool forceZero = false);

Expand Down
21 changes: 16 additions & 5 deletions enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
void emit_attributeBLASCaller(ArrayRef<TGPattern> blasPatterns,
raw_ostream &os) {
os << "void attributeBLAS(BlasInfo blas, llvm::Function *F) { \n";
os << " if (!F->empty())\n";
os << " return;\n";
for (auto &&pattern : blasPatterns) {
auto name = pattern.getName();
os << " if (blas.function == \"" << name << "\") { \n"
Expand All @@ -17,6 +19,8 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
auto name = pattern.getName();
bool lv23 = pattern.isBLASLevel2or3();
os << "void attribute_" << name << "(BlasInfo blas, llvm::Function *F) {\n";
os << " if (!F->empty())\n";
os << " return;\n";
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
Expand Down Expand Up @@ -186,15 +190,20 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) {
}
emit_attributeBLASCaller(newBlasPatterns, os);

os << "void attributeTablegen(llvm::Function &F) {\n";
os << "bool attributeTablegen(llvm::Function &F) {\n";
os << " auto name = getFuncName(&F);\n";
os << " auto changed = false;\n";
os << " auto blasMetaData = extractBLAS(name);\n";
os << " #if LLVM_VERSION_MAJOR >= 16\n";
os << " if (blasMetaData.has_value())\n";
os << " if (F.empty() && blasMetaData.has_value()) {\n";
os << " attributeBLAS(blasMetaData.value(), &F);\n";
os << " changed = true;\n";
os << " }\n";
os << " #else\n";
os << " if (blasMetaData.hasValue())\n";
os << " if (F.empty() && blasMetaData.hasValue()) {\n";
os << " attributeBLAS(blasMetaData.getValue(), &F);\n";
os << " changed = true;\n";
os << " }\n";
os << " #endif\n";
{
const auto &patterns = RK.getAllDerivedDefinitions("CallPattern");
Expand All @@ -209,7 +218,9 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) {
prev = true;
}
os << ") && F.getFunctionType()->getNumParams() == " << tree->getNumArgs()
<< " ){\n";
<< " ){\n"
<< " changed = true;\n";

for (auto attr : *pattern->getValueAsListInit("FnAttrs")) {
auto attrDef = cast<DefInit>(attr)->getDef();
auto attrName = attrDef->getValueAsString("name");
Expand Down Expand Up @@ -251,6 +262,6 @@ void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) {
os << " }\n";
}
}

os << " return changed;\n";
os << "}\n";
}

0 comments on commit 1a7d5d5

Please sign in to comment.