Skip to content

Commit

Permalink
Fully functional cuda integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Oct 16, 2023
1 parent 5974276 commit 37fa7a0
Show file tree
Hide file tree
Showing 8 changed files with 168 additions and 72 deletions.
10 changes: 8 additions & 2 deletions enzyme/test/Integration/ReverseMode/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,20 @@ static void dotTests() {
enzyme_dup, A, dA, enzyme_const, incA, enzyme_dup, B, dB,
enzyme_const, incB);
foundCalls = calls;

auto stack_ret = (double*)foundCalls[1].pin_arg2;
inputs[4] = BlasInfo(stack_ret, 1, 1);

init();

my_ddot(handle, N, A, incA, B, incB);

calls[0].pout_arg1 = (double*)foundCalls[0].pout_arg1;

inDerivative = true;

cublasDaxpy(handle, N, 1.0, B, incB, dA, incA);
cublasDaxpy(handle, N, 1.0, A, incA, dB, incB);
cublasDaxpy(handle, N, stack_ret, B, incB, dA, incA);
cublasDaxpy(handle, N, stack_ret, A, incA, dB, incB);

checkTest(Test);

Expand Down
24 changes: 19 additions & 5 deletions enzyme/test/Integration/blasinfra.h
Original file line number Diff line number Diff line change
Expand Up @@ -466,8 +466,13 @@ void printcall(BlasCall rcall) {
printty(rcall.handle);
printf(", N=");
printty(rcall.iarg1);
printf(", alpha=");
printty(rcall.farg1);
if (rcall.abi != ABIType::CUBLAS) {
printf(", alpha=");
printty(rcall.farg1);
} else {
printf(", alphap=");
printty(rcall.pin_arg2);
}
printf(", X=");
printty(rcall.pin_arg1);
printf(", incx=");
Expand All @@ -493,6 +498,9 @@ void printcall(BlasCall rcall) {
printty(rcall.pin_arg2);
printf(", incy=");
printty(rcall.iarg5);
// Print the labelled result pointer only for the cuBLAS ABI. The braces keep
// the label and its value together; without them only the label was guarded
// and the pointer was printed, unlabelled, for every other ABI as well.
if (rcall.abi == ABIType::CUBLAS) {
  printf(", result=");
  printty(rcall.pout_arg1);
}
printf(")");
return;
case CallType::GEMV:
Expand Down Expand Up @@ -879,17 +887,17 @@ __attribute__((noinline)) cublasStatus_t cublasDdot(cublasHandle_t *handle,
return cublasStatus_t::CUBLAS_STATUS_SUCCESS;
}
__attribute__((noinline)) cublasStatus_t cublasDaxpy(cublasHandle_t *handle,
int N, double alpha,
int N, double *alpha,
double *X, int incx,
double *Y, int incy) {
BlasCall call = {ABIType::CUBLAS,handle,inDerivative,
CallType::AXPY,
Y,
X,
UNUSED_POINTER,
alpha,
UNUSED_DOUBLE,
CUBLAS_LAYOUT,
UNUSED_DOUBLE,
CUBLAS_LAYOUT,
UNUSED_TRANS,
UNUSED_TRANS,
N,
Expand Down Expand Up @@ -1144,12 +1152,18 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test,

auto alpha = rcall.farg1;

auto cualpha = pointer_to_index(rcall.pin_arg2, inputs);

auto N = rcall.iarg1;
auto incX = rcall.iarg4;
auto incY = rcall.iarg5;

checkVector(X, "X", /*len=*/N, /*inc=*/incX, test, rcall, trace);
checkVector(Y, "Y", /*len=*/N, /*inc=*/incY, test, rcall, trace);

if (rcall.abi == ABIType::CUBLAS) {
checkVector(cualpha, "alpha", /*len=*/1, /*inc=*/1, test, rcall, trace);
}
return;
}
case CallType::DOT: {
Expand Down
115 changes: 62 additions & 53 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ using namespace llvm;

// TODO: add this to .td file and generate it based on that
/// Returns the snippet naming the return type of BLAS function `dfnc_name`:
/// "fpType" when has_active_return(dfnc_name) holds, otherwise
/// "Builder2.getVoidTy()".
///
/// NOTE(review): this supersedes an inline name check (dot, asum, nrm2,
/// iamax, iamin, inner_prod). has_active_return is defined outside this
/// view; confirm it covers the same set of names.
std::string get_blas_ret_ty(StringRef dfnc_name) {
  if (has_active_return(dfnc_name))
    return "fpType";
  return "Builder2.getVoidTy()";
}

bool hasDiffeRet(Init *resultTree) {
Expand All @@ -50,6 +48,12 @@ bool hasDiffeRet(Init *resultTree) {
return true;
}
}
if (DefInit *DefArg = dyn_cast<DefInit>(resultTree)) {
auto Def = DefArg->getDef();
if (Def->isSubClassOf("DiffeRetIndex")) {
return true;
}
}
return false;
}

Expand Down Expand Up @@ -211,9 +215,11 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
bool lv23 = pattern.isBLASLevel2or3();
const auto mutArgSet = pattern.getMutableArgs();

os << " const bool byRef = blas.prefix == \"\" || blas.prefix == \"cublas_\";\n";
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == \"cublas\";\n";
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
"\"cublas\";\n";
os << " Value *cacheval = nullptr;\n\n";
// lv 2 or 3 functions have an extra arg under the cblas_ abi
os << " const int offset = (";
Expand Down Expand Up @@ -367,11 +373,11 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) {
"(Type*) Type::getInt8PtrTy(call.getContext()) : "
"(Type*) Type::getInt8Ty(call.getContext());\n";

os << " Type *cublasEnumType = nullptr;\n";
for (auto name : enumerate(nameVec)) {
assert(argTypeMap.count(name.index()) == 1);
auto ty = argTypeMap.lookup(name.index());
if (ty == ArgType::trans) {
os << " Type *cublasEnumType = nullptr;\n";
os << " if (cublas) cublasEnumType = type_" << name.value() << ";\n";
break;
}
Expand Down Expand Up @@ -443,22 +449,34 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) {
<< " if (julia_decl)\n"
<< " julia_decl_type = intType;\n";

os << " Value *valueN = nullptr;\n"
<< " Value *valueT = nullptr;\n"
<< " Value *valueG = nullptr;\n"
<< " if (cublas) {\n"
<< " valueN = ConstantInt::get(cublasEnumType, "
"cublasOperation_t::CUBLAS_OP_N);\n"
<< " valueT = ConstantInt::get(cublasEnumType, "
"cublasOperation_t::CUBLAS_OP_T);\n"
<< " // TODO lascl not available in cublas, nor op G\n"
<< " valueG = ConstantInt::get(cublasEnumType, "
"'G');\n"
<< " } else {\n"
<< " valueN = ConstantInt::get(charType, 'N');\n"
<< " valueT = ConstantInt::get(charType, 'T');\n"
<< " valueG = ConstantInt::get(charType, 'G');\n"
<< " }\n\n";
auto argTypeMap = pattern.getArgTypeMap();
bool hasTrans = false;
for (auto name : enumerate(nameVec)) {
assert(argTypeMap.count(name.index()) == 1);
auto ty = argTypeMap.lookup(name.index());
if (ty == ArgType::trans) {
hasTrans = true;
break;
}
}
if (hasTrans) {
os << " Value *valueN = nullptr;\n"
<< " Value *valueT = nullptr;\n"
<< " Value *valueG = nullptr;\n"
<< " if (cublas) {\n"
<< " valueN = ConstantInt::get(cublasEnumType, "
"cublasOperation_t::CUBLAS_OP_N);\n"
<< " valueT = ConstantInt::get(cublasEnumType, "
"cublasOperation_t::CUBLAS_OP_T);\n"
<< " // TODO lascl not available in cublas, nor op G\n"
<< " valueG = ConstantInt::get(cublasEnumType, "
"'G');\n"
<< " } else {\n"
<< " valueN = ConstantInt::get(charType, 'N');\n"
<< " valueT = ConstantInt::get(charType, 'T');\n"
<< " valueG = ConstantInt::get(charType, 'G');\n"
<< " }\n\n";
}
}

void extract_scalar(StringRef name, StringRef elemTy, raw_ostream &os) {
Expand Down Expand Up @@ -1278,8 +1296,7 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name,
os << " if (byRef) {\n"
<< " ((DiffeGradientUtils *)gutils)"
<< "->addToInvertedPtrDiffe(&call, nullptr, fpType, 0,"
<< "(blas.is64 ? 8 : 4), orig_" << name << ", cubcall, "
<< bb << ");\n"
<< "(blas.is64 ? 8 : 4), orig_" << name << ", cubcall, " << bb << ");\n"
<< " } else {\n"
<< " addToDiffe(orig_" << name << ", cubcall, " << bb
<< ", fpType);\n"
Expand Down Expand Up @@ -1359,36 +1376,21 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
// and we should emit the code for handling it.
bool hasDiffeRetVal = false;
for (auto derivOp : rules) {
DagInit *resultRoot = derivOp.getRuleDag(); // correct
for (size_t pos = 0; pos < resultRoot->getNumArgs(); pos++) {
Init *arg = resultRoot->getArg(pos);
if (DefInit *DefArg = dyn_cast<DefInit>(arg)) {
auto Def = DefArg->getDef();
if (Def->isSubClassOf("DiffeRetIndex")) {
hasDiffeRetVal = true;
}
}
}
auto opName = resultRoot->getOperator()->getAsString();
auto Def = cast<DefInit>(resultRoot->getOperator())->getDef();
if (opName == "DiffeRetIndex" || Def->isSubClassOf("DiffeRetIndex")) {
hasDiffeRetVal = true;
}
for (auto arg : resultRoot->getArgs()) {
hasDiffeRetVal |= hasDiffeRet(arg);
}
hasDiffeRetVal |= hasDiffeRet(derivOp.getRuleDag());
}

os << " /* rev-rewrite */ \n"
<< " if (Mode == DerivativeMode::ReverseModeCombined ||\n"
<< " Mode == DerivativeMode::ReverseModeGradient) {\n"
<< " Value *alloc = nullptr;\n"
<< " if (byRef) {\n"
<< " if (byRef && !cublas) {\n"
<< " alloc = allocationBuilder.CreateAlloca(fpType, nullptr, "
"\"ret\");\n"
<< " }\n\n";

if (hasDiffeRetVal) {
os << " Value *dif = diffe(&call, Builder2);\n";
os << " Value *dif = cublas ? gutils->invertPointerM(call.getArgOperand("
<< typeMap.size() << " + offset), Builder2) : diffe(&call, Builder2);\n";
}

// We only emit one derivcall per blas call type.
Expand Down Expand Up @@ -1469,7 +1471,7 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,

if (hasDiffeRetVal) {
os << ((first) ? "" : ", ") << "Value *dif) {\n"
<< " if (byRef) {\n"
<< " if (byRef && !cublas) {\n"
<< " Builder2.CreateStore(dif, alloc);\n"
<< " dif = alloc;\n"
<< " }\n";
Expand Down Expand Up @@ -1543,7 +1545,7 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
os << " Builder2.CreateCall(derivcall_" << dfnc_name
<< ", args1, Defs);\n";
}
if (ty == ArgType::fp)
if (ty == ArgType::fp)
os << " }\n";
emit_runtime_continue(ruleDag, name, " ", "Builder2",
(ty == ArgType::fp), os);
Expand Down Expand Up @@ -1659,6 +1661,11 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
PrintFatalError("Unhandled blas-rev case!");
}
}
if (hasDiffeRetVal) {
os << " if (cublas)\n";
os << " Builder2.CreateStore(Constant::getNullValue(fpType), dif);\n";
}

os << " },\n"
<< " ";

Expand All @@ -1674,11 +1681,13 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
first = false;
}
if (hasDiffeRetVal) {
os << ((first) ? "" : ", ") << "dif);\n"
<< " setDiffe(\n"
<< " &call,\n"
<< " Constant::getNullValue(gutils->getShadowType(call.getType())),\n"
<< " Builder2);\n";
os << ((first) ? "" : ", ") << "dif);\n";
os << " if (!cublas)\n"
<< " setDiffe(\n"
<< " &call,\n"
<< " "
"Constant::getNullValue(gutils->getShadowType(call.getType())),\n"
<< " Builder2);\n";
} else {
os << " );\n";
}
Expand Down
39 changes: 31 additions & 8 deletions enzyme/tools/enzyme-tblgen/blasDeclUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
auto name = pattern.getName();
bool lv23 = pattern.isBLASLevel2or3();
os << "void attribute_" << name << "(BlasInfo blas, llvm::Function *F) {\n";
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == \"cublas_\";\n";
os << " const bool byRef = blas.prefix == \"\" || blas.prefix == "
"\"cublas_\";\n";
os << " const bool cblas = blas.prefix == \"cblas_\";\n";
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == \"cublas\";\n";
os << " const bool cublas = blas.prefix == \"cublas_\" || blas.prefix == "
"\"cublas\";\n";
os << "#if LLVM_VERSION_MAJOR >= 16\n"
<< " F->setOnlyAccessesArgMemory();\n"
<< "#else\n"
Expand All @@ -43,12 +45,20 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
DenseSet<size_t> mutableArgs = pattern.getMutableArgs();

if (mutableArgs.size() == 0) {
// under cublas, these functions have an extra write-only return ptr
// argument
if (has_active_return(name)) {
os << " if (!cublas) {\n";
}
os << "#if LLVM_VERSION_MAJOR >= 16\n";
os << " F->setOnlyReadsMemory();\n";
os << "#else\n";
os << " F->removeFnAttr(llvm::Attribute::ReadNone);\n";
os << " F->addFnAttr(llvm::Attribute::ReadOnly);\n";
os << "#endif\n";
if (has_active_return(name)) {
os << " }\n";
}
}

os << " const int offset = (";
Expand Down Expand Up @@ -97,11 +107,11 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
if (is_char_arg(typeOfArg) || typeOfArg == ArgType::len ||
typeOfArg == ArgType::vincInc || typeOfArg == ArgType::fp ||
typeOfArg == ArgType::mldLD) {
os << " F->removeParamAttr(" << i << " + offset"
os << " F->removeParamAttr(" << i << " + offset"
<< ", llvm::Attribute::ReadNone);\n"
<< " F->addParamAttr(" << i << " + offset"
<< " F->addParamAttr(" << i << " + offset"
<< ", llvm::Attribute::ReadOnly);\n"
<< " F->addParamAttr(" << i << " + offset"
<< " F->addParamAttr(" << i << " + offset"
<< ", llvm::Attribute::NoCapture);\n";
}
}
Expand Down Expand Up @@ -134,14 +144,27 @@ void emit_attributeBLAS(const TGPattern &pattern, raw_ostream &os) {
<< ", llvm::Attribute::get(F->getContext(), \"enzyme_NoCapture\"));\n";
if (mutableArgs.count(argPos) == 0) {
// Only emit ReadOnly if the arg isn't mutable
os << " F->addParamAttr(" << i << " + offset"
os << " F->addParamAttr(" << i << " + offset"
<< ", llvm::Attribute::get(F->getContext(), "
"\"enzyme_ReadOnly\"));\n";
}
}
}
os << " }\n"
<< "}\n";
os << " }\n";

if (has_active_return(name)) {
// under cublas, these functions have an extra return ptr argument
size_t ptrRetArg = argTypeMap.size();
os << " if (cublas) {\n"
<< " F->removeParamAttr(" << ptrRetArg << " + offset"
<< ", llvm::Attribute::ReadNone);\n"
<< " F->addParamAttr(" << ptrRetArg << " + offset"
<< ", llvm::Attribute::WriteOnly);\n"
<< " F->addParamAttr(" << ptrRetArg << " + offset"
<< ", llvm::Attribute::NoCapture);\n"
<< " }\n";
}
os << "}\n";
}

void emitBlasDeclUpdater(const RecordKeeper &RK, raw_ostream &os) {
Expand Down
Loading

0 comments on commit 37fa7a0

Please sign in to comment.