Skip to content

Commit

Permalink
Fix potrf memset (#1968)
Browse files: browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jul 6, 2024
1 parent 28bc30c commit f80c238
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 15 deletions.
7 changes: 3 additions & 4 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -514,10 +514,10 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta
// d(A^T) −= dB B2^T
def trtrs : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $nrhs, $a, $lda, $b, $ldb, $info),
["b"],
[cblas_layout, uplo, trans, diag, len, len, mld<["n", "n"]>, vinc<["n"]>, len],
[cblas_layout, uplo, trans, diag, len, len, mld<["uplo", "n", "n"]>, vinc<["n"]>, len],
[
/* a */ (AssertingInactiveArg),
/* b */ (AssertingInactiveArg),
/* b */ (BlasCall<"trtrs"> $layout, $uplo, $diag, $n, $nrhs, $a, $lda, (Shadow $b), Alloca<1>),
]
>;

Expand Down Expand Up @@ -604,8 +604,7 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
[cblas_layout, uplo, len, mld<["uplo", "n", "n"]>, len],
[
/* A */
(Seq<["tri", "triangular", "n"], [], 1>
(BlasCall<"lascl"> $layout, (flip_uplo $uplo), ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"0.0">, $n, $n, use<"tri">, $n, Alloca<1>),
(Seq<["tri", "zerotriangular", "n"], [], 1>
(BlasCall<"lacpy"> $layout, $uplo, $n, $n, (Shadow $A), use<"tri">, $n),

(BlasCall<"trmm"> $layout, (uplo_to_side $uplo), $uplo, Char<"T">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), use<"tri">, $n),
Expand Down
14 changes: 6 additions & 8 deletions enzyme/test/Integration/ReverseMode/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1088,19 +1088,18 @@ static void potrfTests() {
inDerivative = true;

assert(foundCalls.size() >= 2);
assert(foundCalls[1].type == CallType::LASCL);
assert(foundCalls[1].type == CallType::LACPY);
double *tri = (double *)foundCalls[1].pout_arg1;
inputs[3] = BlasInfo(tri, layout, N, N, N);
cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0);

cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N);

cblas_dtrmm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0,
A, lda, tri, N);

assert(foundCalls.size() >= 5);
assert(foundCalls[4].type == CallType::COPY);
double *tmp = (double *)foundCalls[4].pout_arg1;
assert(foundCalls[3].type == CallType::COPY);
double *tmp = (double *)foundCalls[3].pout_arg1;
inputs[4] = BlasInfo(tmp, N, 1);

cblas_dcopy(N, tri, N + 1, tmp, 1);
Expand Down Expand Up @@ -1175,19 +1174,18 @@ static void potrfTests() {
cblas_dscal(1, 0.0, dA, lda);

assert(foundCalls.size() >= 2);
assert(foundCalls[4].type == CallType::LASCL);
assert(foundCalls[4].type == CallType::LACPY);
double *tri = (double *)foundCalls[4].pout_arg1;
inputs[3] = BlasInfo(tri, (char)layout, N, N, N);
cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0);

cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N);

cblas_dtrmm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0,
cacheA, N, tri, N);

assert(foundCalls.size() >= 5);
assert(foundCalls[7].type == CallType::COPY);
double *tmp = (double *)foundCalls[7].pout_arg1;
assert(foundCalls[6].type == CallType::COPY);
double *tmp = (double *)foundCalls[6].pout_arg1;
inputs[4] = BlasInfo(tmp, N, 1);

cblas_dcopy(N, tri, N + 1, tmp, 1);
Expand Down
12 changes: 9 additions & 3 deletions enzyme/tools/enzyme-tblgen/blas-tblgen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1458,7 +1458,8 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) {
assert(args.size() >= 2);
auto action = args[1];
assert(action == "product" || action == "is_normal" ||
action == "triangular" || action == "vector");
action == "triangular" || action == "vector" ||
action == "zerotriangular");
if (action == "product") {
const auto matName = args[0];
const auto dim1 = "arg_" + args[2];
Expand Down Expand Up @@ -1489,7 +1490,7 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) {
os << " Value *len1 = load_if_ref(" << builder << ", intType," << dim1
<< ", byRef);\n";
os << " Value *size_" << vecName << " = len1;\n";
} else if (action == "triangular") {
} else if (action == "triangular" || action == "zerotriangular") {
assert(args.size() == 3);
const auto vecName = args[0];
const auto dim1 = "arg_" + args[2];
Expand All @@ -1502,8 +1503,13 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) {
}
const auto matName = args[0];
const auto allocName = "mat_" + matName;
if (action == "zerotriangular")
os << " Instruction * zero = nullptr;\n";
os << " Value * true_" << allocName << " = CreateAllocation(" << builder
<< ", fpType, size_" << matName << ", \"" << allocName << "\");\n"
<< ", fpType, size_" << matName << ", \"" << allocName << "\", nullptr";
if (action == "zerotriangular")
os << ", &zero";
os << ");\n"
<< " Value * " << allocName << " = true_" << allocName << ";\n"
<< " if (type_vec_like->isIntegerTy()) {\n"
<< " " << allocName << " = " << builder << ".CreatePtrToInt("
Expand Down

0 comments on commit f80c238

Please sign in to comment.