Skip to content

Commit

Permalink
Add symm (#2037)
Browse files Browse the repository at this point in the history
* Add symm

* wip

* Reverse symm

* more testing

* fix
  • Loading branch information
wsmoses authored Aug 11, 2024
1 parent 5ab5470 commit 636db9f
Show file tree
Hide file tree
Showing 5 changed files with 477 additions and 11 deletions.
59 changes: 52 additions & 7 deletions enzyme/Enzyme/BlasDerivatives.td
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,6 @@ def ger : CallBlasPattern<(Op $layout, $m, $n, $alpha, $x, $incx, $y, $incy, $A,
// }

// C := alpha*op( A )*op( B ) + beta*C
// FWD: dC = dalpha A B + alpha dA B + alpha A dB + dbeta C + beta dC
def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc),
["C"],
[cblas_layout, trans, trans, len, len, len, fp, mld<["transa", "m", "k"]>, mld<["transb", "k", "n"]>, fp, mld<["m", "n"]>],
Expand Down Expand Up @@ -365,6 +364,7 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A
/* beta */ (FrobInnerProd<""> $m, $n, (Shadow $C), input<"C">),
/* C */ (BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, (Shadow $C), Alloca<1>)
],
// FWD: dC = dalpha A B + alpha dA B + alpha A dB + dbeta C + beta dC
(Seq<[], ["beta1"], 1>
(BlasCall<"axpy"> (AssertingInactiveArg), (Shadow $beta), $C, (Shadow $C)),
(BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, (ld $A, $transa, $lda, $k, $m), (Shadow $B), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)),
Expand Down Expand Up @@ -411,7 +411,7 @@ def trmm : CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $a
ConstantInt<1>
)
),
/* B */ (BlasCall<"trmm"> $layout, $side, $uplo, transpose<"transa">, $diag, $m, $n, $alpha, $A, (ld $A, $side, $lda, $m, $n), (Shadow $B))
/* B */ (BlasCall<"trmm"> $layout, $side, $uplo, transpose<"transa">, $diag, $m, $n, $alpha, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $B))
]
>;

Expand All @@ -420,11 +420,56 @@ def symm: CallBlasPattern<(Op $layout, $side, $uplo, $m, $n, $alpha, $A, $lda, $
[cblas_layout, side, uplo, len, len, fp, mld<["side", "m", "n"]>, mld<["m", "n"]>, fp, mld<["m", "n"]>],
[
/*alpha*/ (AssertingInactiveArg),
/*A*/ (AssertingInactiveArg),
/*B*/ (AssertingInactiveArg),
/*A*/ (Seq<["tmp", "is_left", "side", "m", "n"], [], 1>
(BlasCall<"copy">
(ISelect (is_left $side), $m, $n),
(First (Shadow $A)),
(Add $lda, ConstantInt<1>),
use<"tmp">,
ConstantInt<1>
),
(BlasCall<"syr2k">
$layout,
$uplo,
(side_to_trans $side),
(ISelect (is_left $side), $m, $n),
(ISelect (is_left $side), $n, $m),
$alpha,
$B, (ld $B, Char<"N">, $ldb, $m, $m),
(Shadow $C),
Constant<"1.0">,
(Shadow $A)
),
(BlasCall<"axpy">
(ISelect (is_left $side), $m, $n),
Constant<"-1">,
(First (Shadow $A)),
(Add $lda, ConstantInt<1>),
use<"tmp">,
ConstantInt<1>
),
(BlasCall<"axpy">
(ISelect (is_left $side), $m, $n),
Constant<"0.5">,
use<"tmp">,
ConstantInt<1>,
(First (Shadow $A)),
(Add $lda, ConstantInt<1>)
)
),
/*B*/ (BlasCall<"symm"> $layout, $side, $uplo, $m, $n, $alpha, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $C), Constant<"1">, (Shadow $B)),
/*beta*/ (AssertingInactiveArg),
/*C*/ (AssertingInactiveArg),
]>;
/*C*/ (BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, (Shadow $C), Alloca<1>)
],
// FWD: dC = dalpha A B + alpha dA B + alpha A dB + dbeta C + beta dC
(Seq<[], ["beta1"], 1>
(BlasCall<"axpy"> (AssertingInactiveArg), (Shadow $beta), $C, (Shadow $C)),
(BlasCall<"symm"> $layout, $side, $uplo, $m, $n, $alpha, $A, (ld $A, (side_to_trans $side), $lda, $n, $m), (Shadow $B), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)),
(BlasCall<"symm"> $layout, $side, $uplo, $m, $n, $alpha, (Shadow $A), $B, (ld $B, Char<"N">, $ldb, $m, $m), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)),
(BlasCall<"symm"> $layout, $side, $uplo, $m, $n, (Shadow $alpha), $A, (ld $A, (side_to_trans $side), $lda, $n, $m), $B, (ld $B, Char<"N">, $ldb, $m, $m), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)),
(FirstUse<"beta1"> (BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, (Shadow $C), Alloca<1>))
)
>;

def syr2k : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, $B, $ldb, $beta, $C, $ldc),
["C"],
Expand All @@ -444,7 +489,7 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda,
[

/* alpha */ (AssertingInactiveArg), /*(Seq<["AB", "product", "m", "n"], [], true>
(BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n
(BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $n, $k), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n
(FrobInnerProd<""> $m, $n, (Shadow $C), use<"AB">)),*/
/* A */
(Seq<[], [], 1>
Expand Down
Loading

0 comments on commit 636db9f

Please sign in to comment.