Skip to content

Commit

Permalink
Fix macos test [blas] (#1518)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Nov 2, 2023
1 parent 04b2aa9 commit ad4a492
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 34 deletions.
60 changes: 30 additions & 30 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2552,9 +2552,10 @@ llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef,
IntegerType *charTy = IntegerType::get(trans->getContext(), 8);
trans = B.CreateLoad(charTy, trans, "loaded.trans");

auto isN = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N'));
auto isn = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n'));
// fortran blas
return B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n')),
B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N')));
return B.CreateOr(isn, isN);
} else {
// we can inspect scalars
return B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
Expand All @@ -2570,33 +2571,32 @@ llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef,
llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V, bool cublas) {
llvm::Type *T = V->getType();
if (cublas) {
return B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(T, 1)),
ConstantInt::get(V->getType(), 0),
B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(T, 0)),
ConstantInt::get(V->getType(), 1),
ConstantInt::get(V->getType(), 42)));
auto isT1 = B.CreateICmpEQ(V, ConstantInt::get(T, 1));
auto isT0 = B.CreateICmpEQ(V, ConstantInt::get(T, 0));
return B.CreateSelect(isT1, ConstantInt::get(V->getType(), 0),
B.CreateSelect(isT0,
ConstantInt::get(V->getType(), 1),
ConstantInt::get(V->getType(), 42)));
} else if (T->isIntegerTy(8)) {
return B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'T')),
ConstantInt::get(V->getType(), 'N'),
B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 't')),
ConstantInt::get(V->getType(), 'n'),
B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N')),
ConstantInt::get(V->getType(), 'T'),
B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n')),
ConstantInt::get(V->getType(), 't'),
ConstantInt::get(V->getType(), 0)))));
auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n'));
auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(), 't'),
ConstantInt::get(V->getType(), 0));

auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N'));
auto sel2 = B.CreateSelect(isN, ConstantInt::get(V->getType(), 'T'), sel1);

auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 't'));
auto sel3 = B.CreateSelect(ist, ConstantInt::get(V->getType(), 'n'), sel2);

auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'T'));
return B.CreateSelect(isT, ConstantInt::get(V->getType(), 'N'), sel3);

} else if (T->isIntegerTy(32)) {
return B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)),
ConstantInt::get(V->getType(), 112),
B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)),
ConstantInt::get(V->getType(), 111),
ConstantInt::get(V->getType(), 0)));
auto is111 = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111));
auto sel1 = B.CreateSelect(
B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)),
ConstantInt::get(V->getType(), 111), ConstantInt::get(V->getType(), 0));
return B.CreateSelect(is111, ConstantInt::get(V->getType(), 112), sel1);
} else {
std::string s;
llvm::raw_string_ostream ss(s);
Expand Down Expand Up @@ -2695,9 +2695,9 @@ SmallVector<llvm::Value *, 1> get_blas_row(llvm::IRBuilder<> &B,
if (!byRef) {
cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111));
} else {
cond = B.CreateOr(
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')),
B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')));
auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'));
auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N'));
cond = B.CreateOr(isN, isn);
}
} else {
// CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2
Expand Down
22 changes: 18 additions & 4 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1050,15 +1050,29 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,

if (Def->isSubClassOf("MagicInst")) {
if (Def->getName() == "Rows") {
os << "({";
for (size_t i = Dag->getNumArgs() - 1;; i--) {
os << "auto brow_" << i << " = ";
rev_call_arg(Dag, rule, actArg, i, os);
os << "; ";
if (i == 0)
break;
}
os << "get_blas_row(Builder2, ";
for (size_t i = 0; i < Dag->getNumArgs(); i++) {
rev_call_arg(Dag, rule, actArg, i, os);
os << "brow_" << i;
os << ", ";
}
os << "byRef, cublas)";
os << "byRef, cublas);})";
return;
}
if (Def->getName() == "Concat") {
os << "({";
for (size_t i = 0; i < Dag->getNumArgs(); i++) {
os << "auto concat_" << i << " = ";
rev_call_arg(Dag, rule, actArg, i, os);
os << "; ";
}
os << "concat_values<";
for (size_t i = 0; i < Dag->getNumArgs(); i++) {
if (i != 0)
Expand All @@ -1069,9 +1083,9 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos,
for (size_t i = 0; i < Dag->getNumArgs(); i++) {
if (i != 0)
os << ", ";
rev_call_arg(Dag, rule, actArg, i, os);
os << "concat_" << i;
}
os << ")";
os << "); })";
return;
}
if (Def->getName() == "ld") {
Expand Down

0 comments on commit ad4a492

Please sign in to comment.