From f80c2386bcaf8223f4b8b9c92c68d950fccbbd8a Mon Sep 17 00:00:00 2001 From: William Moses Date: Sat, 6 Jul 2024 15:35:39 -0400 Subject: [PATCH] Fix potrf memset (#1968) --- enzyme/Enzyme/BlasDerivatives.td | 7 +++---- enzyme/test/Integration/ReverseMode/blas.cpp | 14 ++++++-------- enzyme/tools/enzyme-tblgen/blas-tblgen.cpp | 12 +++++++++--- 3 files changed, 18 insertions(+), 15 deletions(-) diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index a5472c3e9ef9..704a6f227428 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -514,10 +514,10 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta // d(A^T) −= dB B2^T def trtrs : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $nrhs, $a, $lda, $b, $ldb, $info), ["b"], - [cblas_layout, uplo, trans, diag, len, len, mld<["n", "n"]>, vinc<["n"]>, len], + [cblas_layout, uplo, trans, diag, len, len, mld<["uplo", "n", "n"]>, vinc<["n"]>, len], [ /* a */ (AssertingInactiveArg), - /* b */ (AssertingInactiveArg), + /* b */ (BlasCall<"trtrs"> $layout, $uplo, $diag, $n, $nrhs, $a, $lda, (Shadow $b), Alloca<1>), ] >; @@ -604,8 +604,7 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info), [cblas_layout, uplo, len, mld<["uplo", "n", "n"]>, len], [ /* A */ - (Seq<["tri", "triangular", "n"], [], 1> - (BlasCall<"lascl"> $layout, (flip_uplo $uplo), ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"0.0">, $n, $n, use<"tri">, $n, Alloca<1>), + (Seq<["tri", "zerotriangular", "n"], [], 1> (BlasCall<"lacpy"> $layout, $uplo, $n, $n, (Shadow $A), use<"tri">, $n), (BlasCall<"trmm"> $layout, (uplo_to_side $uplo), $uplo, Char<"T">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), use<"tri">, $n), diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index a35883ff8ef9..2fdbbe611425 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -1088,10 +1088,9 @@ static void potrfTests() { inDerivative = true; assert(foundCalls.size() >= 2); - assert(foundCalls[1].type == CallType::LASCL); + assert(foundCalls[1].type == CallType::LACPY); double *tri = (double *)foundCalls[1].pout_arg1; inputs[3] = BlasInfo(tri, layout, N, N, N); - cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0); cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N); @@ -1099,8 +1098,8 @@ static void potrfTests() { A, lda, tri, N); assert(foundCalls.size() >= 5); - assert(foundCalls[4].type == CallType::COPY); - double *tmp = (double *)foundCalls[4].pout_arg1; + assert(foundCalls[3].type == CallType::COPY); + double *tmp = (double *)foundCalls[3].pout_arg1; inputs[4] = BlasInfo(tmp, N, 1); cblas_dcopy(N, tri, N + 1, tmp, 1); @@ -1175,10 +1174,9 @@ static void potrfTests() { cblas_dscal(1, 0.0, dA, lda); assert(foundCalls.size() >= 2); - assert(foundCalls[4].type == CallType::LASCL); + assert(foundCalls[4].type == CallType::LACPY); double *tri = (double *)foundCalls[4].pout_arg1; inputs[3] = BlasInfo(tri, (char)layout, N, N, N); - cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0); cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N); @@ -1186,8 +1184,8 @@ static void potrfTests() { cacheA, N, tri, N); assert(foundCalls.size() >= 5); - assert(foundCalls[7].type == CallType::COPY); - double *tmp = (double *)foundCalls[7].pout_arg1; + assert(foundCalls[6].type == CallType::COPY); + double *tmp = (double *)foundCalls[6].pout_arg1; inputs[4] = BlasInfo(tmp, N, 1); cblas_dcopy(N, tri, N + 1, tmp, 1); diff --git a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp index 5437ed16b8ea..fbda8b08fca6 100644 --- a/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp +++ b/enzyme/tools/enzyme-tblgen/blas-tblgen.cpp @@ -1458,7 +1458,8 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { assert(args.size() >= 2); auto action = args[1]; assert(action == "product" || action == "is_normal" || - action == "triangular" || action == "vector"); + action == "triangular" || action == "vector" || + action == "zerotriangular"); if (action == "product") { const auto matName = args[0]; const auto dim1 = "arg_" + args[2]; @@ -1489,7 +1490,7 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { os << " Value *len1 = load_if_ref(" << builder << ", intType," << dim1 << ", byRef);\n"; os << " Value *size_" << vecName << " = len1;\n"; - } else if (action == "triangular") { + } else if (action == "triangular" || action == "zerotriangular") { assert(args.size() == 3); const auto vecName = args[0]; const auto dim1 = "arg_" + args[2]; @@ -1502,8 +1503,13 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { } const auto matName = args[0]; const auto allocName = "mat_" + matName; + if (action == "zerotriangular") + os << " Instruction * zero = nullptr;\n"; os << " Value * true_" << allocName << " = CreateAllocation(" << builder - << ", fpType, size_" << matName << ", \"" << allocName << "\");\n" + << ", fpType, size_" << matName << ", \"" << allocName << "\", nullptr"; + if (action == "zerotriangular") + os << ", &zero"; + os << ");\n" << " Value * " << allocName << " = true_" << allocName << ";\n" << " if (type_vec_like->isIntegerTy()) {\n" << " " << allocName << " = " << builder << ".CreatePtrToInt("