Skip to content

Commit

Permalink
wip fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 16, 2023
1 parent 0ee7418 commit 5974276
Show file tree
Hide file tree
Showing 7 changed files with 108 additions and 102 deletions.
102 changes: 45 additions & 57 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,10 +637,10 @@ Function *getOrInsertDifferentialFloatMemcpy(Module &M, Type *elementType,
}

void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
llvm::ArrayRef<llvm::Value *> args, llvm::Type *copy_retty,
llvm::ArrayRef<llvm::Value *> args,
llvm::Type *copy_retty,
llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
std::string copy_name =
blas.prefix + blas.floatType + "copy" + blas.suffix;
std::string copy_name = blas.prefix + blas.floatType + "copy" + blas.suffix;

SmallVector<Type *, 1> tys;
for (auto arg : args)
Expand All @@ -655,8 +655,7 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,
BlasInfo blas, llvm::ArrayRef<llvm::Value *> args,
llvm::ArrayRef<llvm::OperandBundleDef> bundles) {
std::string copy_name =
blas.prefix + blas.floatType + "lacpy" + blas.suffix;
std::string copy_name = blas.prefix + blas.floatType + "lacpy" + blas.suffix;

SmallVector<Type *, 1> tys;
for (auto arg : args)
Expand All @@ -675,8 +674,7 @@ void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas,
ArrayRef<OperandBundleDef> bundles, bool byRef,
bool julia_decl) {
// add spmv diag update call if not already present
std::string fnc_name =
"__enzyme_spmv_diag" + blas.floatType + blas.suffix;
std::string fnc_name = "__enzyme_spmv_diag" + blas.floatType + blas.suffix;

// spmvDiagHelper(uplo, n, alpha, x, incx, ya, incy, APa)
auto FDiagUpdateT = FunctionType::get(
Expand Down Expand Up @@ -867,8 +865,7 @@ getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
assert(fpTy->isFloatingPointTy());

// add inner_prod call if not already present
std::string prod_name =
"__enzyme_inner_prod" + blas.floatType + blas.suffix;
std::string prod_name = "__enzyme_inner_prod" + blas.floatType + blas.suffix;
auto FInnerProdT =
FunctionType::get(fpTy, {BlasIT, BlasIT, BlasPT, BlasIT, BlasPT}, false);
Function *F =
Expand All @@ -878,8 +875,7 @@ getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
return B.CreateCall(F, args, bundles);

// add dot call if not already present
std::string dot_name =
blas.prefix + blas.floatType + "dot" + blas.suffix;
std::string dot_name = blas.prefix + blas.floatType + "dot" + blas.suffix;
auto FDotT =
FunctionType::get(fpTy, {BlasIT, BlasPT, BlasIT, BlasPT, BlasIT}, false);
Function *FDot =
Expand Down Expand Up @@ -929,12 +925,13 @@ getorInsertInnerProd(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,

{
IRBuilder<> B1(entry);
Value *blasOne = to_blas_callconv(B1, ConstantInt::get(IT, 1), byRef, cublas, IT,
B1, "constant.one");
Value *blasOne = to_blas_callconv(B1, ConstantInt::get(IT, 1), byRef,
cublas, IT, B1, "constant.one");
Value *m = load_if_ref(B1, IT, blasm, byRef);
Value *n = load_if_ref(B1, IT, blasn, byRef);
Value *size = B1.CreateNUWMul(m, n, "mat.size");
Value *blasSize = to_blas_callconv(B1, size, byRef, cublas, IT, B1, "mat.size");
Value *blasSize =
to_blas_callconv(B1, size, byRef, cublas, IT, B1, "mat.size");
B1.CreateCondBr(B1.CreateICmpEQ(size, ConstantInt::get(IT, 0)), end, init);

IRBuilder<> B2(init);
Expand Down Expand Up @@ -2357,59 +2354,47 @@ std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
#endif
{
const char* extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv"};
const char* floatType[] = {"s", "d"}; // c, z
const char* prefixes[] = {"" /*Fortran*/, "cblas_"};
const char* suffixes[] = {"", "_", "64_", "_64_"};
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv"};
const char *floatType[] = {"s", "d"}; // c, z
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
const char *suffixes[] = {"", "_", "64_", "_64_"};
for (auto t : floatType) {
for (auto f : extractable) {
for (auto p : prefixes) {
for (auto s : suffixes) {
if (in == (Twine(p) + t + f + s).str()) {
bool is64 = llvm::StringRef(s).contains("64");
return BlasInfo{
t,
p,
s,
f,
is64,
t, p, s, f, is64,
};
}
}
}
}
}
// c interface to cublas
const char* cuCFloatType[] = {"S", "D"}; // c, z
const char* cuFFloatType[] = {"s", "d"}; // c, z
const char* cuCPrefixes[] = {"cublas"};
const char *cuCFloatType[] = {"S", "D"}; // c, z
const char *cuFFloatType[] = {"s", "d"}; // c, z
const char *cuCPrefixes[] = {"cublas"};
for (auto t : llvm::enumerate(cuCFloatType)) {
for (auto f : extractable) {
for (auto p : cuCPrefixes) {
if (in == (Twine(p) + t.value() + f).str()) {
return BlasInfo{
cuFFloatType[t.index()],
p,
"",
f,
false,
t.value(), p, "", f, false,
};
}
}
}
}
// Fortran interface to cublas
const char* cuFPrefixes[] = {"cublas_"};
const char *cuFPrefixes[] = {"cublas_"};
for (auto t : cuFFloatType) {
for (auto f : extractable) {
for (auto p : cuFPrefixes) {
if (in == (Twine(p) + t + f).str()) {
return BlasInfo{
t,
p,
"",
f,
false,
t, p, "", f, false,
};
}
}
Expand Down Expand Up @@ -2473,8 +2458,8 @@ void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty,

// A null julia_decl means this is not a Julia declaration; otherwise it is the
// integer type to cast to
llvm::Value *to_blas_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas,
IntegerType *julia_decl,
llvm::Value *to_blas_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, IntegerType *julia_decl,
IRBuilder<> &entryBuilder,
llvm::Twine const &name) {
if (!byRef)
Expand Down Expand Up @@ -2507,7 +2492,8 @@ llvm::Value *to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef,
}

llvm::Value *select_vec_dims(IRBuilder<> &B, llvm::Value *trans,
llvm::Value *dim1, llvm::Value *dim2, bool byRef, bool cublas) {
llvm::Value *dim1, llvm::Value *dim2, bool byRef,
bool cublas) {
Value *width = B.CreateSelect(is_normal(B, trans, byRef, cublas), dim1, dim2);

return width;
Expand All @@ -2533,7 +2519,8 @@ Value *is_uper(IRBuilder<> &B, Value *trans, bool byRef) {
return isUper;
}

llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas) {
llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef,
bool cublas) {
if (cublas) {
Value *isNormal = nullptr;
isNormal = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
Expand Down Expand Up @@ -2568,12 +2555,12 @@ llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V, bool cublas) {
llvm::Type *T = V->getType();
Value *out;
if (cublas) {
out = B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(T, 1)),
ConstantInt::get(V->getType(), 0),
B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(T, 0)),
ConstantInt::get(V->getType(), 1),
ConstantInt::get(V->getType(), 42)));
out =
B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(T, 1)),
ConstantInt::get(V->getType(), 0),
B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(T, 0)),
ConstantInt::get(V->getType(), 1),
ConstantInt::get(V->getType(), 42)));
return out;
}
if (T->isIntegerTy(8)) {
Expand Down Expand Up @@ -2620,20 +2607,21 @@ llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V, bool cublas) {
llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B,
llvm::ArrayRef<llvm::Value *> trans,
llvm::Value *arg_ld, llvm::Value *dim1,
llvm::Value *dim2, bool cacheMat,
bool byRef, bool cublas) {
llvm::Value *dim2, bool cacheMat, bool byRef,
bool cublas) {
if (!cacheMat)
return arg_ld;

assert(trans.size() == 1);

llvm::Value *width = CreateSelect(B, is_normal(B, trans[0], byRef, cublas), dim1, dim2);
llvm::Value *width =
CreateSelect(B, is_normal(B, trans[0], byRef, cublas), dim1, dim2);

return width;
}

llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas,
llvm::IntegerType *julia_decl,
llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *julia_decl,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name) {

Expand Down Expand Up @@ -2674,12 +2662,12 @@ SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
Value *cond = nullptr;
if (!cublas) {
cond = B.CreateOr(
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')),
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')));
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')),
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')));
} else {
// CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2
// TODO: verify
cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
// CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2
// TODO: verify
cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0));
}
assert(row.size() == col.size());
SmallVector<Value *, 1> toreturn;
Expand Down
22 changes: 13 additions & 9 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -625,7 +625,8 @@ llvm::Function *getOrInsertDifferentialFloatMemcpy(

/// Create function for type that performs memcpy with a stride using blas copy
void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas,
llvm::ArrayRef<llvm::Value *> args, llvm::Type *cublas_retty,
llvm::ArrayRef<llvm::Value *> args,
llvm::Type *cublas_retty,
llvm::ArrayRef<llvm::OperandBundleDef> bundles);

/// Create function for type that performs memcpy using lapack copy
Expand Down Expand Up @@ -1628,8 +1629,8 @@ llvm::Value *load_if_ref(llvm::IRBuilder<> &B, llvm::IntegerType *intType,

// A null julia_decl means this is not a Julia declaration; otherwise it is the
// integer type to cast to
llvm::Value *to_blas_callconv(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas,
llvm::IntegerType *julia_decl,
llvm::Value *to_blas_callconv(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *julia_decl,
llvm::IRBuilder<> &entryBuilder,
llvm::Twine const & = "");
llvm::Value *to_blas_fp_callconv(llvm::IRBuilder<> &B, llvm::Value *V,
Expand All @@ -1640,8 +1641,8 @@ llvm::Value *to_blas_fp_callconv(llvm::IRBuilder<> &B, llvm::Value *V,
llvm::Value *get_cached_mat_width(llvm::IRBuilder<> &B,
llvm::ArrayRef<llvm::Value *> trans,
llvm::Value *arg_ld, llvm::Value *dim_1,
llvm::Value *dim_2, bool cacheMat,
bool byRef, bool cublas);
llvm::Value *dim_2, bool cacheMat, bool byRef,
bool cublas);

template <typename T>
static inline void append(llvm::SmallVectorImpl<T> &vec) {}
Expand All @@ -1658,15 +1659,18 @@ static inline llvm::SmallVector<llvm::Value *, 1> concat_values(T &&...t) {
return res;
}

llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas);
llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef,
bool cublas);
llvm::Value *is_uper(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef);
llvm::Value *select_vec_dims(llvm::IRBuilder<> &B, llvm::Value *trans,
llvm::Value *dim1, llvm::Value *dim2, bool byRef, bool cublas);
llvm::Value *dim1, llvm::Value *dim2, bool byRef,
bool cublas);
// first one assumes V is an Integer
llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool cublas);
// second one assumes V is an Integer or a ptr to an int (depends on byRef)
llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef, bool cublas,
llvm::IntegerType *IT, llvm::IRBuilder<> &entryBuilder,
llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool byRef,
bool cublas, llvm::IntegerType *IT,
llvm::IRBuilder<> &entryBuilder,
const llvm::Twine &name);
llvm::SmallVector<llvm::Value *, 1>
get_blas_row(llvm::IRBuilder<> &B, llvm::ArrayRef<llvm::Value *> trans,
Expand Down
5 changes: 2 additions & 3 deletions enzyme/test/Integration/ReverseMode/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,6 @@ static void gemvTests() {
{

bool trans = !is_normal(transA);
printf("trans: %s\n", trans ? "true" : "false");
auto handle = DEFAULT_CUBLAS_HANDLE;
std::string Test = "GEMV active A, C ";
BlasInfo inputs[6] = {/*A*/ BlasInfo(A, CUBLAS_LAYOUT, M, N, lda),
Expand Down Expand Up @@ -328,8 +327,8 @@ static void gemmTests() {
transB_bool ? dC : A, transB_bool ? incC : lda,
transB_bool ? A : dC, transB_bool ? lda : incC, 1.0, dB, incB);

// not supported yet by cublas @wsmoses
// cublasDlascl(handle, 'G', 0, 0, 1.0, beta, M, N, dC, incC /*, extra
// TODO: we are currently faking support here; this needs to be actually implemented
cublasDlascl(handle, (cublasOperation_t)'G', 0, 0, 1.0, beta, M, N, dC, incC /*, extra
// 0*/ );

checkTest(Test);
Expand Down
16 changes: 10 additions & 6 deletions enzyme/test/Integration/blasinfra.h
Original file line number Diff line number Diff line change
Expand Up @@ -906,9 +906,8 @@ cublasDgemv(cublasHandle_t *handle, cublasOperation_t trans, int M, int N,
double alpha, double *A, int lda, double *X, int incx, double beta,
double *Y, int incy) {
BlasCall call = {ABIType::CUBLAS,handle,
inDerivative, CallType::GEMV, Y, A, X, alpha, beta,
CUBLAS_LAYOUT,
(char)trans, UNUSED_TRANS, M, N, UNUSED_INT, lda, incx, incy};
inDerivative, CallType::GEMV, Y, A, X, alpha, beta, CUBLAS_LAYOUT,
(char)trans, UNUSED_TRANS, M, N, UNUSED_INT, lda, incx, incy};
calls.push_back(call);
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}
Expand Down Expand Up @@ -1032,7 +1031,9 @@ void checkVector(BlasInfo info, std::string vecname, int length, int increment,
printf("Error in test %s, invalid memory\n", test.c_str());
printTrace(trace);
printcall(rcall);
printf(" Input %s length must be ", vecname.c_str());
printf(" Input %s (", vecname.c_str());
printty(info.ptr);
printf(") length must be ");
printty(info.vec_length);
printf(" found ");
printty(length);
Expand All @@ -1043,7 +1044,9 @@ void checkVector(BlasInfo info, std::string vecname, int length, int increment,
printf("Error in test %s, invalid memory\n", test.c_str());
printTrace(trace);
printcall(rcall);
printf(" Input %s increment must be ", vecname.c_str());
printf(" Input %s (", vecname.c_str());
printty(info.ptr);
printf(") increment must be ");
printty(info.vec_increment);
printf(" found ");
printty(increment);
Expand Down Expand Up @@ -1167,9 +1170,10 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test,
auto A = pointer_to_index(rcall.pin_arg1, inputs);
auto X = pointer_to_index(rcall.pin_arg2, inputs);


auto layout = rcall.layout;
auto trans_char = rcall.targ1;
auto trans = !(trans_char == 'N' || trans_char == 'n');
auto trans = !is_normal(trans_char);
auto M = rcall.iarg1;
auto N = rcall.iarg2;
auto alpha = rcall.farg1;
Expand Down
6 changes: 3 additions & 3 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ void emit_handleBLAS(ArrayRef<TGPattern> blasPatterns, raw_ostream &os) {
<< " bool result = true; \n"
<< " if (!gutils->isConstantInstruction(&call)) { \n"
<< " Type *fpType; \n"
<< " if (blas.floatType == \"d\") { \n"
<< " if (blas.floatType == \"d\" || blas.floatType == \"D\") { \n"
<< " fpType = Type::getDoubleTy(call.getContext()); \n"
<< " } else if (blas.floatType == \"s\") { \n"
<< " } else if (blas.floatType == \"s\" || blas.floatType == \"S\"){\n"
<< " fpType = Type::getFloatTy(call.getContext()); \n"
<< " } else { \n"
<< " assert(false && \"Unreachable\"); \n"
Expand Down Expand Up @@ -453,7 +453,7 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) {
"cublasOperation_t::CUBLAS_OP_T);\n"
<< " // TODO lascl not available in cublas, nor op G\n"
<< " valueG = ConstantInt::get(cublasEnumType, "
"cublasOperation_t::CUBLAS_OP_N);\n"
"'G');\n"
<< " } else {\n"
<< " valueN = ConstantInt::get(charType, 'N');\n"
<< " valueT = ConstantInt::get(charType, 'T');\n"
Expand Down
Loading

0 comments on commit 5974276

Please sign in to comment.