diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 8b7285ede3f..3bf5b42bc36 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -336,7 +336,6 @@ def ger : CallBlasPattern<(Op $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A, // } // C := alpha*op( A )*op( B ) + beta*C -// FWD: dC = dalpha A B + alpha dA B + alpha A dB + dbeta C + beta dC def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), ["C"], [cblas_layout, trans, trans, len, len, len, fp, mld<["transa", "m", "k"]>, mld<["transb", "k", "n"]>, fp, mld<["m", "n"]>], @@ -365,6 +364,7 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A /* beta */ (FrobInnerProd<""> $m, $n, (Shadow $C), input<"C">), /* C */ (BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, (Shadow $C), Alloca<1>) ], + // FWD: dC = dalpha A B + alpha dA B + alpha A dB + dbeta C + beta dC (Seq<[], ["beta1"], 1> (BlasCall<"axpy"> (AssertingInactiveArg), (Shadow $beta), $C, (Shadow $C)), (BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, (ld $A, $transa, $lda, $k, $m), (Shadow $B), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), @@ -411,7 +411,7 @@ def trmm : CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $a ConstantInt<1> ) ), - /* B */ (BlasCall<"trmm"> $layout, $side, $uplo, transpose<"transa">, $diag, $m, $n, $alpha, $A, (ld $A, $side, $lda, $m, $n), (Shadow $B)) + /* B */ (BlasCall<"trmm"> $layout, $side, $uplo, transpose<"transa">, $diag, $m, $n, $alpha, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $B)) ] >; @@ -420,11 +420,56 @@ def symm: CallBlasPattern<(Op $layout, $side, $uplo, $m, $n, $alpha, $A, $lda, $ [cblas_layout, side, uplo, len, len, fp, mld<["side", "m", "n"]>, mld<["m", "n"]>, fp, mld<["m", "n"]>], [ /*alpha*/ (AssertingInactiveArg), - /*A*/ (AssertingInactiveArg), - /*B*/ (AssertingInactiveArg), + /*A*/ (Seq<["tmp", "is_left", "side", "m", "n"], [], 1> + (BlasCall<"copy"> + (ISelect (is_left $side), $m, $n), + (First (Shadow $A)), + (Add $lda, ConstantInt<1>), + use<"tmp">, + ConstantInt<1> + ), + (BlasCall<"syr2k"> + $layout, + $uplo, + (side_to_trans $side), + (ISelect (is_left $side), $m, $n), + (ISelect (is_left $side), $n, $m), + $alpha, + $B, (ld $B, Char<"N">, $ldb, $m, $m), + (Shadow $C), + Constant<"1.0">, + (Shadow $A) + ), + (BlasCall<"axpy"> + (ISelect (is_left $side), $m, $n), + Constant<"-1">, + (First (Shadow $A)), + (Add $lda, ConstantInt<1>), + use<"tmp">, + ConstantInt<1> + ), + (BlasCall<"axpy"> + (ISelect (is_left $side), $m, $n), + Constant<"0.5">, + use<"tmp">, + ConstantInt<1>, + (First (Shadow $A)), + (Add $lda, ConstantInt<1>) + ) + ), + /*B*/ (BlasCall<"symm"> $layout, $side, $uplo, $m, $n, $alpha, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $C), Constant<"1">, (Shadow $B)), /*beta*/ (AssertingInactiveArg), - /*C*/ (AssertingInactiveArg), - ]>; + /*C*/ (BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, (Shadow $C), Alloca<1>) + ], + // FWD: dC = dalpha A B + alpha dA B + alpha A dB + dbeta C + beta dC + (Seq<[], ["beta1"], 1> + (BlasCall<"axpy"> (AssertingInactiveArg), (Shadow $beta), $C, (Shadow $C)), + (BlasCall<"symm"> $layout, $side, $uplo, $m, $n, $alpha, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $B), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), + (BlasCall<"symm"> $layout, $side, $uplo, $m, $n, $alpha, (Shadow $A), $B, (ld $B, Char<"N">, $ldb, $m, $m), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), + (BlasCall<"symm"> $layout, $side, $uplo, $m, $n, (Shadow $alpha), $A, (ld $A, (side_to_trans $side), $lda, $n, $m), $B, (ld $B, Char<"N">, $ldb, $m, $m), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), + (FirstUse<"beta1"> (BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, (Shadow $C), Alloca<1>)) + ) + >; def syr2k : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc), ["C"], @@ -444,7 +489,7 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, [ /* alpha */ (AssertingInactiveArg), /*(Seq<["AB", "product", "m", "n"], [], true> - (BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n + (BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $n, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, (Shadow $C), use<"AB">)),*/ /* A */ (Seq<[], [], 1> diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index e4f9dbc339f..88b19c5af36 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -68,6 +68,14 @@ void my_dtrmm(char layout, char side, char uplo, inDerivative = true; } +void ow_dtrmm(char layout, char side, char uplo, + char trans, char diag, int M, int N, + double alpha, double * A, int lda, + double * B, int ldb) { + cblas_dtrmm(layout, side, uplo, trans, diag, M, N, alpha, A, lda, B, ldb); + inDerivative = true; +} + void my_dsyrk(char layout, char uplo, char trans, int N, int K, double alpha, double *__restrict__ A, int lda, double beta, @@ -111,6 +119,24 @@ void ow_trtrs(char layout, char uplo, char trans, char diag, int N, int Nrhs, inDerivative = true; } +void my_symm(char layout, char side, char uplo, + int M, int N, double alpha, + double * __restrict__ A, int lda, double * __restrict__ B, + int ldb, double beta, double * __restrict__ C, + int ldc) { + cblas_dsymm(layout, side, uplo, M, N, alpha, A, lda, B, ldb, beta, C, ldc); + inDerivative = true; +} + +void ow_symm(char layout, char side, char uplo, + int M, int N, double alpha, + double * A, int lda, double * B, + int ldb, double beta, double * C, + int ldc) { + cblas_dsymm(layout, side, uplo, M, N, alpha, A, lda, B, ldb, beta, C, ldc); + inDerivative = true; +} + static void dotTests() { std::string Test = "DOT active both "; @@ -896,6 +922,157 @@ static void trmmTests() { checkMemoryTrace(inputs, "Found " + Test, foundCalls); } + + { + + bool trans = !is_normal(transA); + std::string Test = "TRMM overwrite active A, B "; + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, (side == 'L' || side == 'l') ? M : N, (side == 'L' || side == 'l') ? M : N, lda), + /*B*/ BlasInfo(B, layout, M, N, incB), + BlasInfo(), + BlasInfo(), + BlasInfo(), + BlasInfo() + }; + init(); + + ow_dtrmm(layout, side, uplo, (char)transA, diag, M, N, alpha, A, lda, B, incB); + + // Check memory of primal on own. + checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void*) ow_dtrmm, + enzyme_const, layout, + enzyme_const, side, + enzyme_const, uplo, + enzyme_const, transA, + enzyme_const, diag, + enzyme_const, M, + enzyme_const, N, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB); + foundCalls = calls; + init(); + + + double* cacheA = (double*)foundCalls[0].pout_arg1; + + cblas_dlacpy(layout, '\0', is_left(side) ? M : N, is_left(side) ? M : N, + A, + lda, cacheA, is_left(side) ? M : N); + inputs[5] = BlasInfo(cacheA, layout, is_left(side) ? M : N, is_left(side) ? M : N, is_left(side) ? M : N); + + double* cacheB = (double*)foundCalls[1].pout_arg1; + + cblas_dlacpy(layout, '\0', M, N, + B, + incB, cacheB, M); + inputs[4] = BlasInfo(cacheB, layout, M, N, M); + + ow_dtrmm(layout, side, uplo, (char)transA, diag, M, N, alpha, A, lda, B, incB); + + assert(foundCalls.size() >= 2); + assert(foundCalls[0].type == CallType::LACPY); + inDerivative = true; + + auto d = (diag == 'n' || diag == 'N') ? 0 : 1; + + #define B0(r,c) cacheB[(r-1)*(layout == CblasRowMajor ? M : 1) + (c-1)*(layout == CblasRowMajor ? 1 : M) ] + #define Ba(r,c) dB[(r-1)*(layout == CblasRowMajor ? incB : 1) + (c-1)*(layout == CblasRowMajor ? 1 : incB) ] + #define Aa(r,c) dA[(r-1)*(layout == CblasRowMajor ? lda : 1) + (c-1)*(layout == CblasRowMajor ? 1 : lda) ] + + auto ldb = incB; + + char toTrans; + if (side == 'l') + toTrans = 'n'; + else if (side == 'L') + toTrans = 'N'; + else if (side == 'r') + toTrans = 't'; + else if (side == 'R') + toTrans = 'T'; + + if (side == 'l' || side == 'L') { + if (is_normal(transA)) { + // BLAS operation + // B = alpha*A*B0 + // RMD operation + // Aa += alpha*Ba*B0' + if(uplo == 'u' || uplo == 'U') { + // A is upper triangular + for (int i=1; i<=M; i++) + cblas_dgemv(layout, toTrans,i-d,N, alpha,dB,incB,&B0(i, 1),M,1.0,&Aa(1, i),1); + } else { + // A is lower triangular + for (int i=1; i<=M-d; i++) + cblas_dgemv(layout, toTrans,M-i+1-d,N,alpha,&Ba(i+d,1),ldb,&B0(i,1),M,1.0, &Aa(i+d,i),1); + } + } else { + // BLAS operation + // B = alpha*A'*B0 + // RMD operation + // Aa += alpha*B*Ba' + if(uplo == 'u' || uplo == 'U') { + // A is upper triangular + for (int i=1; i<=M; i++) + cblas_dgemv(layout, toTrans,i-d,N, alpha,&B0(1,1),M,&Ba(i,1),ldb,1.0,&Aa(1,i),1); + } else { + // A is lower triangular + for (int i=1; i<=M-d; i++) + cblas_dgemv(layout, toTrans,M-i+1-d,N,alpha,&B0(i+d,1),M,&Ba(i,1),ldb,1.0, &Aa(i+d,i),1); + } + } + } else { + if (is_normal(transA)) { + // BLAS operation + // B = alpha*B0*A + // RMD operation + // Aa += alpha*B0'*Ba + if(uplo == 'u' || uplo == 'U') { + // A is upper triangular + for (int i=1; i<=N; i++) + cblas_dgemv(layout, toTrans,M,i-d,alpha,&B0(1,1),M,&Ba(1,i),1, 1.0,&Aa(1,i),1); + } else { + // A is lower triangular + for (int i=1; i<=N-d; i++) + cblas_dgemv(layout, toTrans,M,N-i+1-d,alpha,&B0(1,i+d),M,&Ba(1,i),1, 1.0, & + Aa(i+d,i),1); + } + } else { + // BLAS operation + // B = alpha*B0*A' + // RMD operation + // Aa += alpha*Ba'*B0 + if(uplo == 'u' || uplo == 'U') { + // A is upper triangular + for (int i=1; i<=N; i++) + cblas_dgemv(layout, toTrans,M,i-d,alpha,&Ba(1,1),ldb,&B0(1,i),1, 1.0,&Aa(1,i),1); + } else { + // A is lower triangular + for (int i=1; i<=N-d; i++) + cblas_dgemv(layout, toTrans,M,N-i+1-d,alpha,&Ba(1,i+d),ldb,&B0(1,i),1, 1.0, &Aa(i+d,i),1); + } + } + } + + cblas_dtrmm(layout, side, uplo, (char)transpose(transA), diag, M, N, alpha, cacheA, is_left(side) ? M : N, dB, incB); + + checkTest(Test); + + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); + + } } @@ -1624,6 +1801,202 @@ static void trtrsTests() { } } +static void symmTests() { + int N = 17; + int M = 9; + // N means normal matrix, T means transposed + for (char layout : {CblasColMajor, CblasRowMajor}) { + for (auto uplo : {'U', 'u', 'L', 'l'}) + for (auto side : {'L', 'l', 'R', 'r'}) { + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, is_left(side) ? M : N, is_left(side) ? M : N, lda), + /*B*/ BlasInfo(B, layout, M, N, incB), + /*C*/ BlasInfo(C, layout, M, N, incC), + BlasInfo(), + BlasInfo(), + BlasInfo(), + }; + { + + std::string Test = "SYMM active A, B, C"; + init(); + + my_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); + + assert(calls.size() == 1); + assert(calls[0].inDerivative == false); + assert(calls[0].type == CallType::SYMM); + assert(calls[0].pout_arg1 == C); + assert(calls[0].pin_arg1 == A); + assert(calls[0].pin_arg2 == B); + assert(calls[0].farg1 == alpha); + assert(calls[0].farg2 == beta); + assert(calls[0].layout == layout); + assert(calls[0].targ1 == UNUSED_TRANS); + assert(calls[0].targ2 == UNUSED_TRANS); + assert(calls[0].iarg1 == M); + assert(calls[0].iarg2 == N); + assert(calls[0].iarg3 == UNUSED_INT); + assert(calls[0].iarg4 == lda); + assert(calls[0].iarg5 == incB); + assert(calls[0].iarg6 == incC); + assert(calls[0].side == side); + assert(calls[0].uplo == uplo); + assert(calls[0].diag == UNUSED_TRANS); + + // Check memory of primal on own. + checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void *)my_symm, + enzyme_const, layout, + enzyme_const, side, + enzyme_const, uplo, + enzyme_const, M, + enzyme_const, N, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + my_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + + assert(foundCalls[1].type == CallType::COPY); + double *tmp = (double *)foundCalls[1].pout_arg1; + cblas_dcopy(is_left(side) ? M : N, dA, lda+1, tmp, 1); + inputs[3] = BlasInfo(tmp, is_left(side) ? M : N, 1); + + // ssyr2k(uplo, 'n', m, n, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) + // ssyr2k(uplo,'t', n,m, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) + cblas_dsyr2k(layout, + uplo, + side_to_trans(side), + is_left(side) ? M : N, + is_left(side) ? N : M, + alpha, + B, + incB, + dC, + incC, + 1.0, + dA, + lda); + + cblas_daxpy(is_left(side) ? M : N, -1, dA, lda+1, tmp, 1); + cblas_daxpy(is_left(side) ? M : N, 0.5, tmp, 1, dA, lda+1); + + cblas_dsymm(layout, side, uplo, M, N, alpha, A, lda, dC, incC, 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 ); + + checkTest(Test); + + SkipVecIncCheck = true; + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); + SkipVecIncCheck = false; + } + { + + std::string Test = "SYMM overwriten active A, B, C"; + init(); + + ow_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); + + // Check memory of primal on own. + checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void *)ow_symm, + enzyme_const, layout, + enzyme_const, side, + enzyme_const, uplo, + enzyme_const, M, + enzyme_const, N, + enzyme_const, alpha, + enzyme_dup, A, dA, + enzyme_const, lda, + enzyme_dup, B, dB, + enzyme_const, incB, + enzyme_const, beta, + enzyme_dup, C, dC, + enzyme_const, incC); + foundCalls = calls; + init(); + + double *cacheA = (double *)foundCalls[0].pout_arg1; + inputs[4] = BlasInfo(cacheA, layout, is_left(side) ? M : N, is_left(side) ? M : N, is_left(side) ? M : N); + assert(inputs[4].ty == ValueType::Matrix); + cblas_dlacpy(layout, '\0', is_left(side) ? M : N, is_left(side) ? M : N, A, lda, cacheA, is_left(side) ? M : N); + + double *cacheB = (double *)foundCalls[1].pout_arg1; + inputs[5] = BlasInfo(cacheB, layout, M, N, M); + assert(inputs[5].ty == ValueType::Matrix); + cblas_dlacpy(layout, '\0', M, N, B, incB, cacheB, M); + + ow_symm(layout, side, uplo, M, N, alpha, A, lda, B, incB, beta, C, incC); + + inDerivative = true; + + //cblas_dscal(1, 0.0, dA, lda); + + + //assert(foundCalls[1].type == CallType::COPY); + double *tmp = (double *)foundCalls[3].pout_arg1; + cblas_dcopy(is_left(side) ? M : N, dA, lda+1, tmp, 1); + inputs[3] = BlasInfo(tmp, is_left(side) ? M : N, 1); + + // ssyr2k(uplo, 'n', m, n, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) + // ssyr2k(uplo,'t', n,m, alpha,B,ldb,Ca,ldc, 1.0,Aa,lda) + cblas_dsyr2k(layout, + uplo, + side_to_trans(side), + is_left(side) ? M : N, + is_left(side) ? N : M, + alpha, + cacheB, + M, + dC, + incC, + 1.0, + dA, + lda); + + cblas_daxpy(is_left(side) ? M : N, -1, dA, lda+1, tmp, 1); + cblas_daxpy(is_left(side) ? M : N, 0.5, tmp, 1, dA, lda+1); + + cblas_dsymm(layout, side, uplo, M, N, alpha, cacheA, is_left(side) ? M : N, dC, incC, 1.0, dB, incB); + + cblas_dlascl(layout, 'G', 0, 0, 1.0, beta, M, N, dC, incC, 0 ); + + checkTest(Test); + + SkipVecIncCheck = true; + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). + checkMemoryTrace(inputs, "Found " + Test, foundCalls); + SkipVecIncCheck = false; + } + } + } +} + int main() { dotTests(); @@ -1644,4 +2017,6 @@ int main() { potrsTests(); trtrsTests(); + + symmTests(); } diff --git a/enzyme/test/Integration/blasinfra.h b/enzyme/test/Integration/blasinfra.h index b716f250673..b16cf53e054 100644 --- a/enzyme/test/Integration/blasinfra.h +++ b/enzyme/test/Integration/blasinfra.h @@ -74,6 +74,38 @@ enum class CBLAS_TRANSPOSE : char { CblasConjTrans = 113 }; +bool is_left(char c) { + switch (c) { + case 'L': + return true; + case 'l': + return true; + case 'R': + return false; + case 'r': + return false; + default: + printf("Illegal isleft of '%c' %d\n", c, c); + exit(1); + } +} + +char side_to_trans(char c) { + switch (c) { + case 'L': + return 'N'; + case 'l': + return 'n'; + case 'R': + return 'T'; + case 'r': + return 't'; + default: + printf("Illegal side_to_trans of '%c' %d\n", c, c); + exit(1); + } +} + bool is_normal(char c) { switch (c) { case 'N': @@ -1102,6 +1134,8 @@ void printcall(BlasCall rcall) { printty(rcall.pin_arg2); printf(", ldb="); printty(rcall.iarg5); + printf(", beta="); + printty(rcall.farg2); printf(", C="); printty(rcall.pout_arg1); printf(", ldc="); @@ -2845,7 +2879,7 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test, auto left = side_char == 'L' || side_char == 'l'; checkMatrix(C, "C", layout, /*rows=*/M, - /*cols=*/N, /*ld=*/ldb, test, rcall, trace); + /*cols=*/N, /*ld=*/ldc, test, rcall, trace); checkMatrix(B, "B", layout, /*rows=*/M, /*cols=*/N, /*ld=*/ldb, test, rcall, trace); diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index da2938c6979..2528aff1eac 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -1473,7 +1473,7 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { auto action = args[1]; assert(action == "product" || action == "is_normal" || action == "triangular" || action == "vector" || - action == "zerotriangular"); + action == "zerotriangular" || action == "is_left"); if (action == "product") { const auto matName = args[0]; const auto dim1 = "arg_" + args[2]; @@ -1497,6 +1497,19 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { os << " Value *size_" << vecName << " = " << builder << ".CreateSelect(is_normal(" << builder << ", " << trans << ", byRef, cublas), len1, len2);\n"; + } else if (action == "is_left") { + assert(args.size() == 5); + const auto vecName = args[0]; + const auto trans = "arg_" + args[2]; + const auto dim1 = "arg_" + args[3]; + const auto dim2 = "arg_" + args[4]; + os << " Value *len1 = load_if_ref(" << builder << ", intType," << dim1 + << ", byRef);\n" + << " Value *len2 = load_if_ref(" << builder << ", intType," << dim2 + << ", byRef);\n"; + os << " Value *size_" << vecName << " = " << builder + << ".CreateSelect(is_left(" << builder << ", " << trans + << ", byRef, cublas), len1, len2);\n"; } else if (action == "vector") { assert(args.size() == 3); const auto vecName = args[0]; diff --git a/enzyme/tools/enzyme-tblgen/caching.cpp b/enzyme/tools/enzyme-tblgen/caching.cpp index affd25685b6..f1753634adf 100644 --- a/enzyme/tools/enzyme-tblgen/caching.cpp +++ b/enzyme/tools/enzyme-tblgen/caching.cpp @@ -272,8 +272,7 @@ uplostr = " Value *uplo = arg_" + nameVec[dimensions[0]] + ";\n"; } else if (startty == ArgType::side) { os << " Value *normal = is_left(BuilderZ, arg_" << nameVec[dimensions[0]] << ", byRef, cublas);\n" -<< " M = BuilderZ.CreateSelect(normal, " << dim1 << ", " << dim2 << ");\n" -<< " N = BuilderZ.CreateSelect(normal, " << dim2 << ", " << dim1 << ");\n"; +<< " M = N = BuilderZ.CreateSelect(normal, " << dim1 << ", " << dim2 << ");\n"; } else { assert(0 &&" unknown startty"); }