Skip to content

Commit

Permalink
Fix cblas
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 24, 2023
1 parent fb4c017 commit 0d244a1
Showing 1 changed file with 38 additions and 15 deletions.
53 changes: 38 additions & 15 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2494,7 +2494,8 @@ llvm::Value *to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef,
/// Choose between two dimension values according to a BLAS transpose flag.
///
/// Emits a single `select` instruction through \p B that yields \p dim1 when
/// `is_normal(B, trans, byRef, cublas)` evaluates to true and \p dim2
/// otherwise.
///
/// \param B      builder used to emit the comparison and the select.
/// \param trans  transpose flag, forwarded unchanged to `is_normal`.
/// \param dim1   value selected when the flag is classified as "normal".
/// \param dim2   value selected otherwise.
/// \param byRef  forwarded to `is_normal`.
/// \param cublas forwarded to `is_normal`.
/// \return the selected dimension value.
llvm::Value *select_vec_dims(IRBuilder<> &B, llvm::Value *trans,
                             llvm::Value *dim1, llvm::Value *dim2, bool byRef,
                             bool cublas) {
  // Evaluate the predicate exactly once: is_normal may emit IR (such as a
  // load of the flag), so calling it twice would duplicate that IR.
  llvm::Value *norm = is_normal(B, trans, byRef, cublas);
  return B.CreateSelect(norm, dim1, dim2);
}
Expand Down Expand Up @@ -2526,23 +2527,24 @@ llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef,
isNormal = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
return isNormal;
}
IntegerType *charTy;
// Explicitly support 'N' always, since we use in the rule infra
if (auto CI = dyn_cast<ConstantInt>(trans)) {
if (CI->getValue() == 'N' || CI->getValue() == 'n')
return ConstantInt::getTrue(
B.getContext()); //(Type::getInt1Ty(B.getContext()), true);
}
if (byRef) {
// can't inspect opaque ptr, so assume 8 (Julia)
charTy = IntegerType::get(trans->getContext(), 8);
IntegerType *charTy = IntegerType::get(trans->getContext(), 8);
trans = B.CreateLoad(charTy, trans, "loaded.trans");

// fortran blas
return B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n')),
B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N')));
} else {
// we can inspect scalars
unsigned int len = trans->getType()->getScalarSizeInBits();
charTy = IntegerType::get(trans->getContext(), len);
return B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
}

Value *trueVal = ConstantInt::getTrue(trans->getContext());

Value *isNormal =
B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n')),
B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N')));
return isNormal;
}

// Ok. Here we are.
Expand Down Expand Up @@ -2622,6 +2624,22 @@ llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name) {

if (!byRef) {
// Explicitly support 'N' always, since we use in the rule infra
if (auto CI = dyn_cast<ConstantInt>(V)) {
if (CI->getValue() == 'N')
return ConstantInt::get(CI->getType(), 'T');
if (CI->getValue() == 'n')
return ConstantInt::get(CI->getType(), 't');
}

// cblas
return B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)),
ConstantInt::get(V->getType(), 112),
ConstantInt::get(V->getType(), 111));
}

if (byRef) {
auto charType = IntegerType::get(V->getContext(), 8);
V = B.CreateLoad(charType, V, "ld." + name);
Expand Down Expand Up @@ -2658,9 +2676,14 @@ SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,

Value *cond = nullptr;
if (!cublas) {
cond = B.CreateOr(
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')),
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')));

if (!byRef) {
cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
} else {
cond = B.CreateOr(
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')),
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')));
}
} else {
// CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2
// TODO: verify
Expand Down

0 comments on commit 0d244a1

Please sign in to comment.