Skip to content

Commit

Permalink
Potrs (#1965)
Browse files Browse the repository at this point in the history
* potrs infra

* rev potrs checks

* fix

* fix
  • Loading branch information
wsmoses authored Jul 6, 2024
1 parent 803ee2e commit 273a3a7
Show file tree
Hide file tree
Showing 6 changed files with 325 additions and 12 deletions.
55 changes: 54 additions & 1 deletion enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -604,7 +604,7 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
[cblas_layout, uplo, len, mld<["uplo", "n", "n"]>, len],
[
/* A */
(Seq<["tri", "triangular", "n"], [], 0>
(Seq<["tri", "triangular", "n"], [], 1>
(BlasCall<"lascl"> $layout, (flip_uplo $uplo), ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"0.0">, $n, $n, use<"tri">, $n, Alloca<1>),
(BlasCall<"lacpy"> $layout, $uplo, $n, $n, (Shadow $A), use<"tri">, $n),

Expand Down Expand Up @@ -688,3 +688,56 @@ def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info),
)
)
>;


def potrs: CallBlasPattern<(Op $layout, $uplo, $n, $nrhs, $A, $lda, $B, $ldb, $info),
["B"],
[cblas_layout, uplo, len, len, mld<["uplo", "n", "n"]>, mld<["n", "nrhs"]>, len],
[
(Seq<["tri", "triangular", "n"], [], 1>

(BlasCall<"syr2k"> $layout, Char<"U">, Char<"N">, $n, $nrhs, Constant<"1.0">, input<"B">, $n, (Shadow $B), Constant<"0.0">, use<"tri">, $n),
(CopyLowerToUpper<""> $layout, Char<"U">, (Concat use<"tri">, $n), $n),

(BlasCall<"trsm"> $layout, (uplo_to_rside $uplo), $uplo, Char<"T">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), use<"tri">, $n),

(BlasCall<"potrs"> $layout, $uplo, $n, $n, $A, (ld $A, Char<"N">, $lda, $n, $n), use<"tri">, $n, Alloca<1>),

(For<"i", 0> $n,
(BlasCall<"axpy">
(Sub $n, $i),
Constant<"-1.0">,
(First
(Lookup $layout,
(Concat use<"tri">, $n),
$i,
$i
)
),
(First
(Lookup $layout,
(Concat ConstantInt<0>, $n),
(ISelect (is_lower $uplo), ConstantInt<1>, ConstantInt<0>),
(ISelect (is_lower $uplo), ConstantInt<0>, ConstantInt<1>)
)
),
(First
(Lookup $layout,
(Shadow $A),
$i,
$i
)
),
(First
(Lookup $layout,
(Concat ConstantInt<0>, $lda),
(ISelect (is_lower $uplo), ConstantInt<1>, ConstantInt<0>),
(ISelect (is_lower $uplo), ConstantInt<0>, ConstantInt<1>)
)
)
)
)
),
(BlasCall<"potrs"> $layout, $uplo, $n, $nrhs, $A, (ld $A, Char<"N">, $lda, $n, $n), (Shadow $B), Alloca<1>)
]
>;
8 changes: 4 additions & 4 deletions enzyme/Enzyme/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2836,10 +2836,10 @@ std::optional<BlasInfo> extractBLAS(llvm::StringRef in)
llvm::Optional<BlasInfo> extractBLAS(llvm::StringRef in)
#endif
{
const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm",
"spmv", "syrk", "nrm2", "trmm", "trmv",
"symm", "potrf", "copy", "spmv", "syr2k",
"potrs", "getrf", "getrs", "trtrs", "getri"};
const char *extractable[] = {
"dot", "scal", "axpy", "gemv", "gemm", "spmv", "syrk",
"nrm2", "trmm", "trmv", "symm", "potrf", "potrs", "copy",
"spmv", "syr2k", "potrs", "getrf", "getrs", "trtrs", "getri"};
const char *floatType[] = {"s", "d", "c", "z"};
const char *prefixes[] = {"" /*Fortran*/, "cblas_"};
const char *suffixes[] = {"", "_", "64_", "_64_"};
Expand Down
163 changes: 163 additions & 0 deletions enzyme/test/Integration/ReverseMode/blas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ void ow_potrf(char layout, char uplo, int N, double *__restrict__ A, int lda) {
inDerivative = true;
}

void my_potrs(char layout, char uplo, int N, int Nrhs, double *__restrict__ A, int lda, double *__restrict__ B, int ldb) {
int info;
cblas_dpotrs(layout, uplo, N, Nrhs, A, lda, B, ldb, &info);
inDerivative = true;
}

static void dotTests() {

std::string Test = "DOT active both ";
Expand Down Expand Up @@ -1232,6 +1238,161 @@ static void potrfTests() {
}
}

static void potrsTests() {
int N = 17;
int Nrhs = M;
// N means normal matrix, T means transposed
for (char layout : {CblasColMajor, CblasRowMajor}) {
for (auto uplo : {'U', 'u', 'L', 'l'})

{
BlasInfo inputs[6] = {
/*A*/ BlasInfo(A, layout, N, N, lda),
/*B*/ BlasInfo(B, layout, N, Nrhs, incB),
/*C*/ BlasInfo(),
BlasInfo(),
BlasInfo(),
BlasInfo(),
};
{

std::string Test = "POTRS active A, B";
init();

my_potrs(layout, uplo, N, Nrhs, A, lda, B, incB);

assert(calls.size() == 1);
assert(calls[0].inDerivative == false);
assert(calls[0].type == CallType::POTRS);
assert(calls[0].pout_arg1 == B);
assert(calls[0].pin_arg1 == A);
assert(calls[0].pin_arg2 == UNUSED_POINTER);
assert(calls[0].farg1 == UNUSED_DOUBLE);
assert(calls[0].farg2 == UNUSED_DOUBLE);
assert(calls[0].layout == layout);
assert(calls[0].targ1 == UNUSED_TRANS);
assert(calls[0].targ2 == UNUSED_TRANS);
assert(calls[0].iarg1 == N);
assert(calls[0].iarg2 == Nrhs);
assert(calls[0].iarg3 == UNUSED_INT);
assert(calls[0].iarg4 == lda);
assert(calls[0].iarg5 == incB);
assert(calls[0].iarg6 == UNUSED_INT);
assert(calls[0].side == UNUSED_TRANS);
assert(calls[0].uplo == uplo);
assert(calls[0].diag == UNUSED_TRANS);

// Check memory of primal on own.
checkMemoryTrace(inputs, "Primal " + Test, calls);

init();
__enzyme_autodiff((void *)my_potrs, enzyme_const, layout, enzyme_const,
uplo, enzyme_const, N, enzyme_const, Nrhs, enzyme_dup, A, dA,
enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB);
foundCalls = calls;
init();

assert(foundCalls[0].type == CallType::LACPY);
double *inpB = (double *)foundCalls[0].pout_arg1;
inputs[3] = BlasInfo(inpB, layout, N, Nrhs, N);
cblas_dlacpy(layout, '\0', N, Nrhs, B, incB, inpB, N);

my_potrs(layout, uplo, N, Nrhs, A, lda, B, incB);

inDerivative = true;

assert(foundCalls[2].type == CallType::SYR2K);
double *tri = (double *)foundCalls[2].pout_arg1;
inputs[4] = BlasInfo(tri, layout, N, N, N);
cblas_dsyr2k(layout, 'U', 'N', N, Nrhs, 1.0, inpB, N, dB, incB, 0.0,
tri, N);

#define triv(r, c) \
tri[(r) * (layout == CblasRowMajor ? N : 1) + \
(c) * (layout == CblasRowMajor ? 1 : N)]

bool is_lower = uplo == 'L' || uplo == 'l';
int upperinc = (&triv(0, 1) - &triv(0,0));
int lowerinc = (&triv(1, 0) - &triv(0,0));
if (layout == CblasColMajor) {
assert(upperinc == N);
assert(lowerinc == 1);
} else {
assert(upperinc == 1);
assert(lowerinc == N);
}
for (int i = 0; i < N - 1; i++) {
cblas_dcopy(N - i - 1, &triv(i, i + 1), upperinc, &triv(i + 1, i),
lowerinc);
}

cblas_dtrsm(layout, uplo_to_rside(uplo), uplo, 'T', 'N', N, N, 1.0, A,
lda, tri, N);

cblas_dpotrs(layout, uplo, N, N, A, lda, tri, N, nullptr);

#define Av(r, c) \
dA[(r) * (layout == CblasRowMajor ? lda : 1) + \
(c) * (layout == CblasRowMajor ? 1 : lda)]

int Aupperinc = (&Av(0, 1) - &Av(0,0));
int Alowerinc = (&Av(1, 0) - &Av(0,0));
if (layout == CblasColMajor) {
assert(Aupperinc == lda);
assert(Alowerinc == 1);
} else {
assert(Aupperinc == 1);
assert(Alowerinc == lda);
}

for (int i = 0; i < N; i++) {
cblas_daxpy(N - i, -1.0, &triv(i, i), is_lower ? lowerinc : upperinc,
&Av(i, i), is_lower ? Alowerinc : Aupperinc);
}

cblas_dpotrs(layout, uplo, N, Nrhs, A, lda, dB, incB, nullptr);

checkTest(Test);

SkipVecIncCheck = true;
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);

// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
SkipVecIncCheck = false;
}
{

std::string Test = "POTRS active B";

init();
__enzyme_autodiff((void *)my_potrs, enzyme_const, layout, enzyme_const,
uplo, enzyme_const, N, enzyme_const, Nrhs, enzyme_const, A,
enzyme_const, lda, enzyme_dup, B, dB, enzyme_const, incB);
foundCalls = calls;
init();

my_potrs(layout, uplo, N, Nrhs, A, lda, B, incB);

inDerivative = true;


cblas_dpotrs(layout, uplo, N, Nrhs, A, lda, dB, incB, nullptr);

checkTest(Test);
// Check memory of primal of expected derivative
checkMemoryTrace(inputs, "Expected " + Test, calls);

// Check memory of primal of our derivative (if equal above, it
// should be the same).
checkMemoryTrace(inputs, "Found " + Test, foundCalls);
}
}
}
}

int main() {
dotTests();

Expand All @@ -1248,4 +1409,6 @@ int main() {
syrkTests();

potrfTests();

potrsTests();
}
2 changes: 1 addition & 1 deletion enzyme/test/Integration/ReverseMode/cublas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ static void gemmTests() {

// TODO we are currently faking support here, this needs to be actually implemented
double c10 = 1.0;
cublasDlascl(handle, (cublasOperation_t)'G', 0, 0, &c10, &beta, M, N,
cublasDlascl(handle, (cublasOperation_t)2, 0, 0, &c10, &beta, M, N,
dC, incC, 0);

checkTest(Test);
Expand Down
84 changes: 82 additions & 2 deletions enzyme/test/Integration/blasinfra.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,6 +316,7 @@ enum class CallType {
SYMM,
NRM2,
POTRF,
POTRS,
TRSM,
};

Expand Down Expand Up @@ -445,6 +446,9 @@ void printty(CallType v) {
case CallType::POTRF:
printf("POTRF");
return;
case CallType::POTRS:
printf("POTRS");
return;
case CallType::TRSM:
printf("TRSM");
return;
Expand Down Expand Up @@ -957,6 +961,29 @@ void printcall(BlasCall rcall) {
printty(rcall.iarg4);
printf(")");
return;
case CallType::POTRS:
printf("POTRS(abi=");
printty(rcall.abi);
printf(", handle=");
printty(rcall.handle);
printf(", layout=");
printty(rcall.layout);
printf(", uplo=");
printty(rcall.uplo);
printf(", N=");
printty(rcall.iarg1);
printf(", Nrhs=");
printty(rcall.iarg2);
printf(", A=");
printty(rcall.pin_arg1);
printf(", lda=");
printty(rcall.iarg4);
printf(", B=");
printty(rcall.pout_arg1);
printf(", ldb=");
printty(rcall.iarg5);
printf(")");
return;
case CallType::SYRK:
printf("SYRK(abi=");
printty(rcall.abi);
Expand Down Expand Up @@ -1857,6 +1884,37 @@ __attribute__((noinline)) void cblas_dpotrf(char layout, char uplo,
calls.push_back(call);
}

// The factorization has the form
// A = U**T * U, if UPLO = 'U', or
// A = L * L**T, if UPLO = 'L',
__attribute__((noinline)) void cblas_dpotrs(char layout, char uplo,
int N, int Nrhs, double *A, int lda,
double *B, int ldb, int* info) {
BlasCall call = {ABIType::CBLAS,
UNUSED_HANDLE,
inDerivative,
CallType::POTRS,
B,
A,
UNUSED_POINTER,
UNUSED_DOUBLE,
UNUSED_DOUBLE,
layout,
UNUSED_TRANS,
UNUSED_TRANS,
N,
Nrhs,
UNUSED_INT,
lda,
ldb,
UNUSED_INT,
UNUSED_INT,
UNUSED_TRANS,
uplo,
UNUSED_TRANS};
calls.push_back(call);
}

// Solve op( A )*X = alpha*B, or X*op( A ) = alpha*B
__attribute__((noinline)) void cblas_dtrsm(char layout, char side, char uplo,
char trans, char diag, int M, int N,
Expand Down Expand Up @@ -2286,8 +2344,11 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test,
auto lda = rcall.iarg4;

// = 'G': A is a full matrix.
assert(type == 'G' || type == 'L' || type == 'l' || type == 'U' ||
type == 'u');
if (rcall.abi == ABIType::CUBLAS || rcall.abi == ABIType::CUBLASv2)
assert(type == (char)2);
else
assert(type == 'G' || type == 'L' || type == 'l' || type == 'U' ||
type == 'u');

// A is an m-by-n matrix
checkMatrix(A, "A", layout, /*rows=*/M, /*cols=*/N, /*ld=*/lda, test, rcall,
Expand Down Expand Up @@ -2599,6 +2660,25 @@ void checkMemory(BlasCall rcall, BlasInfo inputs[6], std::string test,
/*cols=*/N, /*ld=*/lda, test, rcall, trace);
return;
}
case CallType::POTRS: {
auto B = pointer_to_index(rcall.pout_arg1, inputs);
auto A = pointer_to_index(rcall.pin_arg1, inputs);

auto lda = rcall.iarg4;
auto ldb = rcall.iarg5;
auto layout = rcall.layout;
auto N = rcall.iarg1;
auto Nrhs = rcall.iarg2;

auto uplo_char = rcall.uplo;

checkMatrix(A, "A", layout, /*rows=*/N,
/*cols=*/N, /*ld=*/lda, test, rcall, trace);

checkMatrix(B, "B", layout, /*rows=*/N,
/*cols=*/Nrhs, /*ld=*/ldb, test, rcall, trace);
return;
}
case CallType::SYR2K: {
// C := alpha*A*B**T + alpha*B*A**T + beta*C or C := alpha*A**T*B + alpha*B**T*A + beta*C
auto C = pointer_to_index(rcall.pout_arg1, inputs);
Expand Down
Loading

0 comments on commit 273a3a7

Please sign in to comment.