Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Oct 16, 2023
1 parent aea349c commit ee7dc8d
Show file tree
Hide file tree
Showing 3 changed files with 208 additions and 5 deletions.
8 changes: 7 additions & 1 deletion enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ class use<string _name> {

class FrobInnerProd<string _tmp>;

// when use-lapack=1 is set and we are not on a gpu,
// calls out to lascl, otherwise we will loop over the rows and call scal.
// If the matrix is continuous, we use a single scal as minor optimization.
class ScaleMatrix<string _tmp>;

class DiagUpdateSPMV<string _tmp>;

// General note: If return is scalar, return it. If return is vec, update it.
Expand Down Expand Up @@ -245,7 +250,8 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A
(Concat adj<"C">, $A, (ld $A, $transa, $lda, $m, $k))),
Constant<"1.0">, adj<"B">),
/* beta */ (FrobInnerProd<""> $m, $n, adj<"C">, input<"C">),
/* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, Alloca<1>)
/* C */ (ScaleMatrix<""> $m, $n, $beta, adj<"C">)
///* C */ (b<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">, Alloca<1>)
]
>;

Expand Down
180 changes: 180 additions & 0 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,186 @@ void callMemcpyStridedLapack(llvm::IRBuilder<> &B, llvm::Module &M,
B.CreateCall(fn, args, bundles);
}

void callScaleMatrix(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::Value *valueG, llvm::Value *cublas_handle,
ArrayRef<Value *> args,
llvm::ArrayRef<llvm::OperandBundleDef> bundles, bool byRef, bool cublas) {
// if we are on a cpu and can use lapack, use lascl.
// Otherwise, use scal per row.
// If the matrix is continuous, use a single large scal.
// If nrows or ncols = 0, return.

Value *m = args[0];
Value *n = args[1];
Value *alpha = args[2];
Value *matrix = args[3];
Value *ldmatrix = args[4];

if (!cublas && EnzymeLapackCopy) {
// (b<"lascl"> Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, adj<"C">)
std::string lascl_name = blas.floatType + "lascl" + blas.suffix;

SmallVector<Type *, 1> tys;
std::vector<Value *> args;
Value *LapackInfoInt = B.createAlloca(blas.intType, nullptr, "LapackInfoInt");

args.push_back(valueG);
// The lower bandwidth of A.
args.push_back(ConstantInt::get(blas.intType, 0));
// The upper bandwidth of A.
args.push_back(ConstantInt::get(blas.intType, 0));
args.push_back(ConstantFP::get(blas.fpType, 1.0));
args.push_back(alpha);
args.push_back(m);
args.push_back(n);
args.push_back(matrix);
args.push_back(ldmatrix);
args.push_back(LapackInfoInt);

for (auto arg : args)
tys.push_back(arg->getType());

auto FT = FunctionType::get(Type::getVoidTy(M.getContext()), tys, false);
auto fn = M.getOrInsertFunction(lascl_name, FT);
if (auto F = dyn_cast<Function>(fn.getCallee()))
{
attribute_lascl(blas, F);
}
B.CreateCall(fn, args, bundles);
return;
}

if (cublas)
assert(cublas_handle);

std::string fnc_name =
"__enzyme_scale_matrix_" + blas.prefix + blas.floatType + blas.suffix;
FunctionType *scalMatFT = nullptr;

if (!F->empty()) {
B.CreateCall(F, args, bundles);
return;
}
BasicBlock *entry = BasicBlock::Create(M.getContext(), "entry", F);
BasicBlock *check = BasicBlock::Create(M.getContext(), "init", F);
BasicBlock *large_scal = BasicBlock::Create(M.getContext(), "uper", F);
BasicBlock *multi_scal = BasicBlock::Create(M.getContext(), "lower", F);
BasicBlock *end = BasicBlock::Create(M.getContext(), "end", F);

// void scaleMat(m, n, alpha, A, lda)
// void cuscaleMat(handle, m, n, alpha, A, lda)
// void scal(n, alpha, x, incx=1)
auto firstarg = F->arg_begin();
if (cublas) {
firstarg->setName("handle");
firstarg++;
}
auto blasm = firstarg;
blasm->setName("m");
auto blasn = blasm + 1;
blasn->setName("blasn");
auto blasalpha = blasn + 1;
blasalpha->setName("blasalpha");
auto blasA = blasalpha + 1;
blasA->setName("blasA");
auto blaslda = blasA + 1;
blaslda->setName("blaslda");

llvm::Value *byRefAlpha = nullptr;
if (cublas) {
if (blasalpha->getType()->isPointerTy()) {
byRefAlpha = blasalpha;
} else {
byRefAlpha = entry.CreateAlloca(BlasFPT, nullptr, "alpha");
entry.CreateStore(blasalpha, byRefAlpha);
}
}



{
// if (m == 0 || n == 0) return;
IRBuilder<> B(entry);
Value *isZero =
B.CreateOr(B.CreateICmpEQ(m, ConstantInt::get(m->getType(), 0)),
B.CreateICmpEQ(n, ConstantInt::get(n->getType(), 0)));
B.CreateCondBr(isZero, end, check);
}

{
// check if matrix is continuous
IRBuilder<> B2(check);
Value *isCont = B2.CreateICmpEQ(blaslda, blasm);
B2.CreateCondBr(isCont, large_scal, multi_scal);
}

{
// if the matrix is continuous, use a single large scale
IRBuilder<> B3(large_scal);
B3.setFastMathFlags(getFast());
llvm::Value *mat_len = B3.CreateMul(blasm, blasn);
llvm::ArrayRef<llvm::Value *> args;
if (cublas) {
args.push_back(cublas_handle);
}
args.push_back(mat_len);
if (cublas | byRef) {
args.push_back(byRefAlpha);
} else {
args.push_back(blasalpha);
}
args.push_back(blasA);
args.push_back(blasOne);

B3.CreateCall(F, args, bundles);
B3.CreateBr(end);
}

{
// if the matrix is not continuous, use a scale per row
IRBuilder<> B4(multi_scal);
B4.setFastMathFlags(getFast());
PHINode *Aidx = B4.CreatePHI(IT, 2, "Aidx");
PHINode *iter = B4.CreatePHI(IT, 2, "iteration");
PHINode *sum = B4.CreatePHI(fpTy, 2, "sum");
Aidx->addIncoming(ConstantInt::get(IT, 0), init);
iter->addIncoming(ConstantInt::get(IT, 0), init);
sum->addIncoming(ConstantFP::get(fpTy, 0.0), init);

Value *Ai = B4.CreateInBoundsGEP(fpTy, Afloat, Aidx, "A.i");
Value *AiScal = B4.CreatePointerCast(Ai, BlasPT);
llvm::ArrayRef<llvm::Value *> args;
if (cublas) {
args.push_back(cublas_handle);
}
args.push_back(n);
if (cublas | byRef) {
args.push_back(byRefAlpha);
} else {
args.push_back(blasalpha);
}
args.push_back(blasAiScal);
args.push_back(blasOne);
Value *newScal = B4.CreateCall(FScal, args, bundles);

Value *Anext = B4.CreateNUWAdd(Aidx, lda, "Aidx.next");
Value *iternext = B4.CreateAdd(iter, ConstantInt::get(IT, 1), "iter.next");
Value *sumnext = B4.CreateFAdd(sum, newDot);

iter->addIncoming(iternext, body);
Aidx->addIncoming(Anext, body);
sum->addIncoming(sumnext, body);

B4.CreateCondBr(B4.CreateICmpEQ(iter, m), end, body);
}

{
IRBuilder<> B5(end);
B5.CreateRetVoid();
}

B.CreateCall(F, args, bundles);
}

void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas,
IntegerType *IT, Type *BlasCT, Type *BlasFPT,
Type *BlasPT, Type *BlasIT, Type *fpTy,
Expand Down
25 changes: 21 additions & 4 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1240,6 +1240,7 @@ void rev_call_args(StringRef argName, Rule &rule, size_t actArg,
// Distinguish later trough byRef if it is cblas (thus has layout)
os << " if (cblas) " << argName << ".push_back(arg_layout);\n";
}
// handle exist only under the cublas ABI, but there for all fncs.
os << " if (cublas) " << argName << ".push_back(arg_handle);\n";

for (size_t pos = fncHasLayout ? 1 : 0; pos < numArgs; pos++) {
Expand Down Expand Up @@ -1304,8 +1305,6 @@ void emit_fret_call(StringRef dfnc_name, StringRef argName, StringRef name,
os << "}\n";
}

// todo: update rt_active_<X> to use actual dag requirements,
// possibly by or-ing them
void emit_runtime_condition(DagInit *ruleDag, StringRef name, StringRef tab,
StringRef B, bool isFP, raw_ostream &os) {
os << tab << "BasicBlock *nextBlock_" << name << " = nullptr;\n"
Expand Down Expand Up @@ -1518,9 +1517,11 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
os << " if (!cublas) {\n";
emit_fret_call(dfnc_name, "ArrayRef<Value *>(args1)", name, "Builder2",
os);
// TODO: think again about this cublas float ret part
os << " } else {\n";
} else {
os << " SmallVector<Type*, 1> tys; for (auto arg : args1) "
os << " SmallVector<Type*, 1> tys;\n"
<< "for (auto arg : args1) "
"tys.push_back(arg->getType());\n";
std::string dfnc_ret_ty = get_blas_ret_ty(dfnc_name);

Expand Down Expand Up @@ -1553,15 +1554,31 @@ void emit_rev_rewrite_rules(const StringMap<TGPattern> &patternMap,
} else if (Def->isSubClassOf("MagicInst") && Def->getName() == "noop") {
} else if (Def->isSubClassOf("MagicInst") && Def->getName() == "inactive") {
os << " assert(!active_" << name << ");\n";
} else if (Def->isSubClassOf("ScaleMatrix")) {
// /* C */ (ScaleMatrix<""> $m, $n, $beta, adj<"C">, $ldc)
assert(ty == ArgType::mldData);
os << " // ScaleMatrix\n";
emit_if_rule_condition(ruleDag, name, " ", os);
emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os);
rev_call_args("args1", rule, actArg, os, -1, "");
os << " const auto Defs = gutils->getInvertedBundles(&call, {"
<< valueTypes << "}, Builder2, /* lookup */ true);\n";
// Now that we have the defs, we can create the call
os << "callScaleMatrix(Builder2, *gutils->oldFunc->getParent(), blas, "
"intType, blasCharType, blasFPType, type_vec_like, type_n, fpType, "
"ArrayRef<Value *>(args1), "
"Defs, byRef, cublas);\n";
emit_runtime_continue(ruleDag, name, " ", "Builder2", true, os);
os << " }\n";
} else if (Def->isSubClassOf("DiagUpdateSPMV")) {
assert(ty == ArgType::ap);
os << " // DiagUpdateSPMV\n";
emit_if_rule_condition(ruleDag, name, " ", os);
emit_runtime_condition(ruleDag, name, " ", "Builder2", true, os);
rev_call_args("args1", rule, actArg, os, -1, "");
os << " const auto Defs = gutils->getInvertedBundles(&call, {"
<< valueTypes << "}, Builder2, /* lookup */ true);\n";
// Now that we have the defs, we can create the call
assert(ty == ArgType::ap);
os << "callSPMVDiagUpdate(Builder2, *gutils->oldFunc->getParent(), blas, "
"intType, blasCharType, blasFPType, type_vec_like, type_n, fpType, "
"ArrayRef<Value *>(args1), "
Expand Down

0 comments on commit ee7dc8d

Please sign in to comment.