From ad4a492a524ed6fbe9d47ebeb933c9b0f2ab992f Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 2 Nov 2023 10:32:15 -0500 Subject: [PATCH] Fix macos test [blas] (#1518) --- enzyme/Enzyme/Utils.cpp | 60 +++++++++++----------- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 22 ++++++-- 2 files changed, 48 insertions(+), 34 deletions(-) diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 00a497f05d4f..44ee737cee1c 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -2552,9 +2552,10 @@ llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, IntegerType *charTy = IntegerType::get(trans->getContext(), 8); trans = B.CreateLoad(charTy, trans, "loaded.trans"); + auto isN = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N')); + auto isn = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n')); // fortran blas - return B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n')), - B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N'))); + return B.CreateOr(isn, isN); } else { // we can inspect scalars return B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111)); @@ -2570,33 +2571,32 @@ llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, llvm::Value *transpose(IRBuilder<> &B, llvm::Value *V, bool cublas) { llvm::Type *T = V->getType(); if (cublas) { - return B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(T, 1)), - ConstantInt::get(V->getType(), 0), - B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(T, 0)), - ConstantInt::get(V->getType(), 1), - ConstantInt::get(V->getType(), 42))); + auto isT1 = B.CreateICmpEQ(V, ConstantInt::get(T, 1)); + auto isT0 = B.CreateICmpEQ(V, ConstantInt::get(T, 0)); + return B.CreateSelect(isT1, ConstantInt::get(V->getType(), 0), + B.CreateSelect(isT0, + ConstantInt::get(V->getType(), 1), + ConstantInt::get(V->getType(), 42))); } else if (T->isIntegerTy(8)) { - return B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'T')), - ConstantInt::get(V->getType(), 'N'), - B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 't')), - ConstantInt::get(V->getType(), 'n'), - B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N')), - ConstantInt::get(V->getType(), 'T'), - B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n')), - ConstantInt::get(V->getType(), 't'), - ConstantInt::get(V->getType(), 0))))); + auto isn = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'n')); + auto sel1 = B.CreateSelect(isn, ConstantInt::get(V->getType(), 't'), + ConstantInt::get(V->getType(), 0)); + + auto isN = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'N')); + auto sel2 = B.CreateSelect(isN, ConstantInt::get(V->getType(), 'T'), sel1); + + auto ist = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 't')); + auto sel3 = B.CreateSelect(ist, ConstantInt::get(V->getType(), 'n'), sel2); + + auto isT = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 'T')); + return B.CreateSelect(isT, ConstantInt::get(V->getType(), 'N'), sel3); + } else if (T->isIntegerTy(32)) { - return B.CreateSelect( - B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)), - ConstantInt::get(V->getType(), 112), - B.CreateSelect(B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)), - ConstantInt::get(V->getType(), 111), - ConstantInt::get(V->getType(), 0))); + auto is111 = B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 111)); + auto sel1 = B.CreateSelect( + B.CreateICmpEQ(V, ConstantInt::get(V->getType(), 112)), + ConstantInt::get(V->getType(), 111), ConstantInt::get(V->getType(), 0)); + return B.CreateSelect(is111, ConstantInt::get(V->getType(), 112), sel1); } else { std::string s; llvm::raw_string_ostream ss(s); @@ -2695,9 +2695,9 @@ SmallVector get_blas_row(llvm::IRBuilder<> &B, if (!byRef) { cond = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111)); } else { - cond = B.CreateOr( - B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')), - B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n'))); + auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')); + auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')); + cond = B.CreateOr(isN, isn); } } else { // CUBLAS_OP_N = 0, CUBLAS_OP_T = 1, CUBLAS_OP_C = 2 diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 66a46e56098f..3cdbfb989b1d 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -1050,15 +1050,29 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, if (Def->isSubClassOf("MagicInst")) { if (Def->getName() == "Rows") { + os << "({"; + for (size_t i = Dag->getNumArgs() - 1;; i--) { + os << "auto brow_" << i << " = "; + rev_call_arg(Dag, rule, actArg, i, os); + os << "; "; + if (i == 0) + break; + } os << "get_blas_row(Builder2, "; for (size_t i = 0; i < Dag->getNumArgs(); i++) { - rev_call_arg(Dag, rule, actArg, i, os); + os << "brow_" << i; os << ", "; } - os << "byRef, cublas)"; + os << "byRef, cublas);})"; return; } if (Def->getName() == "Concat") { + os << "({"; + for (size_t i = 0; i < Dag->getNumArgs(); i++) { + os << "auto concat_" << i << " = "; + rev_call_arg(Dag, rule, actArg, i, os); + os << "; "; + } os << "concat_values<"; for (size_t i = 0; i < Dag->getNumArgs(); i++) { if (i != 0) @@ -1069,9 +1083,9 @@ void rev_call_arg(DagInit *ruleDag, Rule &rule, size_t actArg, size_t pos, for (size_t i = 0; i < Dag->getNumArgs(); i++) { if (i != 0) os << ", "; - rev_call_arg(Dag, rule, actArg, i, os); + os << "concat_" << i; } - os << ")"; + os << "); })"; return; } if (Def->getName() == "ld") {