Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 16, 2023
1 parent 37fa7a0 commit 6b607d4
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 6 deletions.
6 changes: 3 additions & 3 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
llvm::ArrayRef<llvm::Value *> args,
llvm::Type *copy_retty,
llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
std::string copy_name = blas.prefix + blas.floatType + "copy" + blas.suffix;
auto copy_name = Twine(blas.prefix) + blas.floatType + "copy" + blas.suffix;

SmallVector<Type *, 1> tys;
for (auto arg : args)
Expand All @@ -655,7 +655,7 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,
BlasInfo blas, llvm::ArrayRef<llvm::Value *> args,
llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
std::string copy_name = blas.prefix + blas.floatType + "lacpy" + blas.suffix;
auto copy_name = Twine(blas.prefix) + blas.floatType + "lacpy" + blas.suffix;

SmallVector<Type *, 1> tys;
for (auto arg : args)
Expand All @@ -674,7 +674,7 @@ void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas,
ArrayRef<OperandBundleDef> bundles, bool byRef,
bool julia_decl) {
// add spmv diag update call if not already present
std::string fnc_name = "__enzyme_spmv_diag" + blas.floatType + blas.suffix;
auto fnc_name = Twine("__enzyme_spmv_diag") + blas.floatType + blas.suffix;

// spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa)
auto FDiagUpdateT = FunctionType::get(
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Integration/blasinfra.h
Original file line number Diff line number Diff line change
Expand Up @@ -1022,6 +1022,7 @@ BlasInfo pointer_to_index(void *v, BlasInfo inputs[6]) {
for (int i = 3; i < 6; i++)
if (inputs[i].ptr == v)
return inputs[i];
printty(v);
assert(0 && " illegal pointer to invert");
}

Expand Down Expand Up @@ -1152,8 +1153,6 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test,

auto alpha = rcall.farg1;

auto cualpha = pointer_to_index(rcall.pin_arg2, inputs);

auto N = rcall.iarg1;
auto incX = rcall.iarg4;
auto incY = rcall.iarg5;
Expand All @@ -1162,6 +1161,7 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test,
checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace);

if (rcall.abi == ABIType::CUBLAS) {
auto cualpha = pointer_to_index(rcall.pin_arg2, inputs);
checkVector(cualpha, "alpha", /*len=*/1, /*inc=*/1, test, rcall, trace);
}
return;
Expand Down
1 change: 0 additions & 1 deletion enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) {
for (size_t argPos = (lv23 ? 1 : 0); argPos < typeMap.size(); argPos++) {
auto users = argUsers.lookup(argPos);
auto name = nameVec[argPos];
size_t i = (lv23 ? argPos - 1 : argPos);
os << " if (val == arg_" << name << " && need_" << name << " && !cache_"
<< name << ")\n"
<< " return true;\n";
Expand Down

0 comments on commit 6b607d4

Please sign in to comment.