diff --git a/enzyme/Enzyme/BlasDerivatives.td b/enzyme/Enzyme/BlasDerivatives.td index 602d95cd9354..a0fb421db150 100644 --- a/enzyme/Enzyme/BlasDerivatives.td +++ b/enzyme/Enzyme/BlasDerivatives.td @@ -86,35 +86,20 @@ class IntMatchers _before, list after = _after; } -def trans_to_side : IntMatchers< - "charType", "charType", - ["'n'", "'N'", "'T'", "'t'"], - ["'l'", "'L'", "'U'", "'u'"] - >; - def side_to_trans : IntMatchers< "charType", "charType", ["'l'", "'L'", "'R'", "'r'"], ["'n'", "'N'", "'T'", "'t'"] >; -def is_upper : IntMatchers< - "charType", "Builder2.getInt1Ty()", - ["'u'", "'U'", "'L'", "'l'"], - ["true", "true", "false", "false"] - >; - def is_diag_int : IntMatchers< "charType", "intType", ["'u'", "'U'", "'N'", "'n'"], ["1", "1", "0", "0"] >; -def is_left : IntMatchers< - "charType", "Builder2.getInt1Ty()", - ["'l'", "'L'", "'R'", "'r'"], - ["true", "true", "false", "false"] - >; +def is_left : MagicInst; +def is_lower : MagicInst; def First : MagicInst; def Lookup : MagicInst; @@ -141,6 +126,12 @@ class input { string name = _name; } +// only applicable to triangular matrices, like a regular use of $A +// except that the non-set terms are assured to be set to 0 +class zero_cached { + string name = _name; +} + class Constant { string value = _value; } @@ -155,11 +146,12 @@ class transpose { string name = _name; } -class Seq _args = [], list _vars = []> { +class Seq _args = [], list _vars = [], bit _start = 1> { list args = _args; list vars = _vars; - + bit start = _start; } + class For { string idx = idx_; bit offset = offset_; @@ -189,20 +181,20 @@ class DiagUpdateSPMV { def scal : CallBlasPattern<(Op $n, $alpha, $x, $incx), ["x"],[len, fp, vinc<["n"]>], [ - // dot must proceed scal, because scal modifies (Shadow $x) + // dot must precede scal, because scal modifies (Shadow $x) (BlasCall<"dot"> $n, $x, (Shadow $x)), (BlasCall<"scal"> $n, $alpha, (Shadow $x)) ], - (Seq<[], []> (BlasCall<"scal"> $n, $alpha, (Shadow $x)), (BlasCall<"axpy"> $n, 
(Shadow $alpha), $x, (Shadow $x))) + (Seq<[], [], 1> (BlasCall<"scal"> $n, $alpha, (Shadow $x)), (BlasCall<"axpy"> $n, (Shadow $alpha), $x, (Shadow $x))) >; -// def lacpy : CallBlasPattern<(Op $layout, $m, $n, $A, $lda, $B, $ldb), -// ["B"],[cblas_layout, len, len, mld<["m", "n"]>, mld<["m", "n"]>], -// [ -// (AssertingInactiveArg), // from -// (AssertingInactiveArg), // to -// ] -// >; +def lacpy : CallBlasPattern<(Op $layout, $uplo, $m, $n, $A, $lda, $B, $ldb), + ["B"],[cblas_layout, uplo, len, len, mld<["m", "n"]>, mld<["m", "n"]>], + [ + (AssertingInactiveArg), // from + (AssertingInactiveArg), // to + ] + >; def lascl : CallBlasPattern<(Op $layout, $type, $kl, $ku, $cfrom, $cto, $m, $n, $A, $lda, $info), ["A"],[cblas_layout, uplo, len, len, fp, fp, len, len, mld<["m", "n"]>, len], @@ -222,7 +214,7 @@ def axpy : CallBlasPattern<(Op $n, $alpha, $x, $incx, $y, $incy), (BlasCall<"axpy"> $n, $alpha, (Shadow $y), (Shadow $x)), (InactiveArg) // y = alpha*x + y, so nothing to do here ], - (Seq<[], []> (BlasCall<"axpy"> $n, $alpha, (Shadow $x), (Shadow $y)), (BlasCall<"axpy"> $n, (Shadow $alpha), $x, (Shadow $y))) + (Seq<[], [], 1> (BlasCall<"axpy"> $n, $alpha, (Shadow $x), (Shadow $y)), (BlasCall<"axpy"> $n, (Shadow $alpha), $x, (Shadow $y))) >; // x * y @@ -250,7 +242,7 @@ def copy : CallBlasPattern<(Op $n, $x, $incx, $y, $incy), (InactiveArg),// copy moves x into y, so x is never modified. 
(BlasCall<"axpy"> $n, Constant<"1.0">, (Shadow $y), (Shadow $x)) ], - (Seq<[], ["beta1"]> + (Seq<[], ["beta1"], 1> (BlasCall<"copy"> (FirstUse<"beta1"> $n, $n), (Shadow $x), (Shadow $y)), (FirstUse<"beta1"> (BlasCall<"scal"> $n, Constant<"0">, (Shadow $y))) ) @@ -277,7 +269,7 @@ def asum : CallBlasPattern<(Op $n, $x, $incx), def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $incx, $beta, $y, $incy), ["y"], [cblas_layout, trans, len, len, fp, mld<["m", "n"]>, vinc<["transa", "n", "m"]>, fp, vinc<["transa", "m", "n"]>], [ - /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"], []> + /* alpha */ (Seq<["Ax", "is_normal", "transa", "m", "n"], [], 1> (BlasCall<"gemv"> $layout, $transa, $m, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $m, $m), $x, Constant<"0.0">, use<"Ax">, ConstantInt<1>), (BlasCall<"dot"> (Rows $transa, $m, $n), (Shadow $y), use<"Ax">, ConstantInt<1>)), @@ -291,7 +283,7 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ /* beta */ (BlasCall<"dot"> (Rows $transa, $m, $n), (Shadow $y), input<"y">), /* y */ (BlasCall<"scal"> (Rows $transa, $m, $n), $beta, (Shadow $y)) ], - (Seq<[], ["beta1"]> + (Seq<[], ["beta1"], 1> (BlasCall<"axpy"> (Rows $transa, $m, $n), (Shadow $beta), $y, (Shadow $y)), (BlasCall<"gemv"> $layout, $transa, $m, $n, $alpha, $A, (ld $A, Char<"N">, $lda, $m, $m), (Shadow $x), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)), (BlasCall<"gemv"> $layout, $transa, $m, $n, $alpha, (Shadow $A), $x, (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $y)), @@ -303,17 +295,17 @@ def gemv : CallBlasPattern<(Op $layout, $transa, $m, $n, $alpha, $A, $lda, $x, $ // x = Ax // currently assumes for vector dimensions that transa = 'N' and gets dimensions wrong otherwise def trmv : CallBlasPattern<(Op $layout, $uplo, $trans, $diag, $n, $A, $lda, $x, $incx), - ["x"], [cblas_layout, uplo, trans, diag, len, mld<["diag", "n", "n"]>, vinc<["n"]>], + ["x"], [cblas_layout, uplo, trans, 
diag, len, mld<["uplo", "n", "n"]>, vinc<["n"]>], [ - /* A */ (For<"i", 1> (ISelect (is_upper $uplo), $n, (Sub $n, (is_diag_int $diag))), + /* A */ (For<"i", 1> (ISelect (is_lower $uplo), (Sub $n, (is_diag_int $diag)), $n), (BlasCall<"axpy"> - (ISelect (is_upper $uplo), - (Sub $i, (is_diag_int $diag)), - (Add (Sub (Sub $n, (is_diag_int $diag)), $i), ConstantInt<1>) + (ISelect (is_lower $uplo), + (Add (Sub (Sub $n, (is_diag_int $diag)), $i), ConstantInt<1>), + (Sub $i, (is_diag_int $diag)) ), (LoadLookup $layout, (Rows $trans, input<"x">, (Shadow $x)), (Sub $i, ConstantInt<1>)), - (Lookup $layout, (Rows $trans, (Shadow $x), input<"x">), (ISelect (is_upper $uplo), ConstantInt<0>, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>))), - (First (Lookup $layout, (Shadow $A), (ISelect (is_upper $uplo), ConstantInt<0>, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>)), (Sub $i, ConstantInt<1>))), + (Lookup $layout, (Rows $trans, (Shadow $x), input<"x">), (ISelect (is_lower $uplo), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>)), + (First (Lookup $layout, (Shadow $A), (ISelect (is_lower $uplo), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>), (Sub $i, ConstantInt<1>))), ConstantInt<1> ) ), @@ -344,7 +336,7 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, $alpha, $A [cblas_layout, trans, trans, len, len, len, fp, mld<["transa", "m", "k"]>, mld<["transb", "k", "n"]>, fp, mld<["m", "n"]>], [ - /* alpha */ (Seq<["AB", "product", "m", "n"], []> + /* alpha */ (Seq<["AB", "product", "m", "n"], [], 1> (BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, (Shadow $C), use<"AB">)), /* A */ (BlasCall<"gemm"> $layout, (Rows $transa, @@ -367,7 +359,7 @@ def gemm : CallBlasPattern<(Op $layout, $transa, $transb, $m, $n, $k, 
$alpha, $A /* beta */ (FrobInnerProd<""> $m, $n, (Shadow $C), input<"C">), /* C */ (BlasCall<"lascl"> $layout, Char<"G">, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $m, $n, (Shadow $C), Alloca<1>) ], - (Seq<[], ["beta1"]> + (Seq<[], ["beta1"], 1> (BlasCall<"axpy"> (AssertingInactiveArg), (Shadow $beta), $C, (Shadow $C)), (BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, $A, (ld $A, $transa, $lda, $k, $m), (Shadow $B), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), (BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, $alpha, (Shadow $A), $B, (ld $B, $transb, $ldb, $n, $k), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), @@ -382,33 +374,33 @@ def trmm : CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $a [cblas_layout, side, uplo, trans, diag, len, len, fp, mld<["side", "m", "n"]>, mld<["m","n"]>], [ /*alpha*/ (AssertingInactiveArg), - /* A */ (For<"i", 1> (Sub (ISelect (is_left $side), $m, $n), (ISelect (is_upper $uplo), ConstantInt<0>, (is_diag_int $diag))), + /* A */ (For<"i", 1> (Sub (ISelect (is_left $side), $m, $n), (ISelect (is_lower $uplo), (is_diag_int $diag), ConstantInt<0>)), (BlasCall<"gemv"> $layout, (side_to_trans $side), (ISelect (is_left $side), (Concat - (ISelect (is_upper $uplo), - (Sub $i, (is_diag_int $diag)), - (Sub $m, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>))), + (ISelect (is_lower $uplo), + (Sub $m, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>)), + (Sub $i, (is_diag_int $diag))), $n), (Concat $m, - (ISelect (is_upper $uplo), - (Sub $i, (is_diag_int $diag)), - (Sub $n, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>))) + (ISelect (is_lower $uplo), + (Sub $n, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>)), + (Sub $i, (is_diag_int $diag))) ) ), $alpha, (Lookup $layout, (ISelect (BXor (is_left $side), (Not (Rows $transa))), (Shadow $B), (Concat input<"B">, $m)), - (ISelect (BAnd (is_left $side), (Not (is_upper $uplo))), (Sub (Add $i, (is_diag_int $diag)), 
ConstantInt<1>), ConstantInt<0>), - (ISelect (BAnd (Not (is_left $side)), (Not (is_upper $uplo))), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>)), + (ISelect (BAnd (is_left $side), (is_lower $uplo)), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>), + (ISelect (BAnd (Not (is_left $side)), (is_lower $uplo)), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>)), (First (Lookup $layout, (ISelect (BXor (is_left $side), (Not (Rows $transa))), (Concat input<"B">, $m), (Shadow $B)), (ISelect (is_left $side), (Sub $i, ConstantInt<1>), ConstantInt<0>), (ISelect (is_left $side), ConstantInt<0>, (Sub $i, ConstantInt<1>)))), (ISelect (is_left $side), (ISelect (BXor (is_left $side), (Not (Rows $transa))), $m, $ldb), ConstantInt<1>), Constant<"1">, (First (Lookup $layout, (Shadow $A), - (ISelect (is_upper $uplo), ConstantInt<0>, (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>)), + (ISelect (is_lower $uplo), (Sub (Add $i, (is_diag_int $diag)), ConstantInt<1>), ConstantInt<0>), (Sub $i, ConstantInt<1>))), ConstantInt<1> ) @@ -445,11 +437,11 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, [cblas_layout, uplo, trans, len, len, fp, mld<["trans", "n", "k"]>, fp, mld<["n", "n"]>], [ - /* alpha */ (AssertingInactiveArg), /*(Seq<["AB", "product", "m", "n"], []> + /* alpha */ (AssertingInactiveArg), /*(Seq<["AB", "product", "m", "n"], [], true> (BlasCall<"gemm"> $layout, $transa, $transb, $m, $n, $k, Constant<"1.0">, $A, (ld $A, $transa, $lda, $k, $m), $B, (ld $B, $transb, $ldb, $k, $n), Constant<"0.0">, use<"AB">, $m),// TODO: check if last arg should be $m or $n (FrobInnerProd<""> $m, $n, (Shadow $C), use<"AB">)),*/ /* A */ - (Seq<[], []> + (Seq<[], [], 1> (BlasCall<"symm"> $layout, (Rows $trans, Char<"l">, Char<"r">), @@ -489,7 +481,7 @@ def syrk : CallBlasPattern<(Op $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, /* beta */ (AssertingInactiveArg), /* C */ (BlasCall<"lascl"> $layout, 
$uplo, ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, $beta, $n, $n, (Shadow $C), Alloca<1>) ], - (Seq<[], ["beta1"]> + (Seq<[], ["beta1"], 1> (BlasCall<"axpy"> (AssertingInactiveArg), (Shadow $beta), $C, (Shadow $C)), (BlasCall<"syr2k"> $layout, $uplo, $trans, $n, $k, $alpha, $A, $lda, (Shadow $A), (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), (BlasCall<"syrk"> $layout, $uplo, $trans, $n, $k, (Shadow $alpha), $A, $lda, (FirstUse<"beta1"> $beta, Constant<"1">), (Shadow $C)), @@ -501,10 +493,10 @@ def spmv : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $ap, $x, $incx, $beta ["y"], [cblas_layout, uplo, len, fp, ap<["n"]>, vinc<["n"]>, fp, vinc<["n"]>], [ - /* alpha */ (Seq<["y0", "triangular", "n"], []> + /* alpha */ (Seq<["y0", "triangular", "n"], [], 1> (BlasCall<"spmv"> $layout, $uplo, $n, Constant<"1.0">, $ap, $x, Constant<"0.0">, use<"y0">, ConstantInt<1>), (BlasCall<"dot"> $n, (Shadow $y), use<"y0">, ConstantInt<1>)), - /* ap */ (Seq<[], []> + /* ap */ (Seq<[], [], 1> (BlasCall<"spr2"> $layout, $uplo, $n, $alpha, $x, (Shadow $y), (Shadow $ap)), (DiagUpdateSPMV<""> $uplo, $n, $alpha, $x, (Shadow $y), (Shadow $ap))), /* x */ (BlasCall<"spmv"> $layout, $uplo, $n, $alpha, $ap, (Shadow $y), Constant<"1.0">, (Shadow $x)), @@ -564,8 +556,133 @@ def spr2 : CallBlasPattern<(Op $layout, $uplo, $n, $alpha, $x, $incx, $y, $incy, // // Lv 3 // // -// def : CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $alpha, $a, $lda, $b, $ldb), -// ["trsm"], -// [cblas_layout, side, uplo, trans, diag, len, len, fp, vld, vld], -// [] -// >; + +// Solve op( A )*X = alpha*B, or X*op( A ) = alpha*B, +def trsm: CallBlasPattern<(Op $layout, $side, $uplo, $transa, $diag, $m, $n, $alpha, $A, $lda, $B, $ldb), + ["B"], + [cblas_layout, side, uplo, trans, diag, len, len, fp, mld<["side", "m", "n"]>, mld<["m","n"]>], + [ + /* alpha */ (AssertingInactiveArg), + /* A */ (AssertingInactiveArg), + /* B */ (AssertingInactiveArg), + ] + >; + +def uplo_to_normal : 
IntMatchers< + "charType", "charType", + ["'l'", "'L'", "'U'", "'u'"], + ["'n'", "'N'", "'T'", "'t'"] + >; +def uplo_to_trans : IntMatchers< + "charType", "charType", + ["'l'", "'L'", "'U'", "'u'"], + ["'t'", "'T'", "'N'", "'n'"] + >; +def flip_uplo : IntMatchers< + "charType", "charType", + ["'l'", "'L'", "'U'", "'u'"], + ["'u'", "'U'", "'L'", "'l'"] + >; +def uplo_to_side : IntMatchers< + "charType", "charType", + ["'l'", "'L'", "'U'", "'u'"], + ["'L'", "'L'", "'R'", "'R'"] + >; +def uplo_to_rside : IntMatchers< + "charType", "charType", + ["'l'", "'L'", "'U'", "'u'"], + ["'R'", "'R'", "'L'", "'L'"] + >; + + +def potrf: CallBlasPattern<(Op $layout, $uplo, $n, $A, $lda, $info), + ["A"], + [cblas_layout, uplo, len, mld<["uplo", "n", "n"]>, len], + [ + /* A */ + (Seq<["tri", "triangular", "n"], [], 0> + (BlasCall<"lascl"> $layout, (flip_uplo $uplo), ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"0.0">, $n, $n, use<"tri">, $n, Alloca<1>), + (BlasCall<"lacpy"> $layout, $uplo, $n, $n, (Shadow $A), use<"tri">, $n), + + (BlasCall<"trmm"> $layout, (uplo_to_side $uplo), $uplo, Char<"T">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), use<"tri">, $n), + + (Seq<["tmp", "vector", "n"], [], 1> + + // Zero out flipped side again + (BlasCall<"copy"> $n, use<"tri">, (Add $n, ConstantInt<1>), use<"tmp">, ConstantInt<1>), + (BlasCall<"scal"> $n, Constant<"0.5">, use<"tmp">, ConstantInt<1>), + + (BlasCall<"lascl"> $layout, (flip_uplo $uplo), ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"0.0">, $n, $n, use<"tri">, $n, Alloca<1>), + (BlasCall<"copy"> $n, use<"tmp">, ConstantInt<1>, use<"tri">, (Add $n, ConstantInt<1>)) + ), + + (BlasCall<"trsm"> $layout, (uplo_to_rside $uplo), $uplo, Char<"N">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), use<"tri">, $n), + + (BlasCall<"trsm"> $layout, (uplo_to_side $uplo), $uplo, Char<"T">, Char<"N">, $n, $n, Constant<"1.0">, $A, (ld $A, Char<"N">, $lda, $n, $n), 
use<"tri">, $n), + (For<"i", 0> (Sub $n, ConstantInt<1>), + (BlasCall<"axpy"> + (Sub (Sub $n, ConstantInt<1>), $i), + Constant<"1.0">, + (First + (Lookup $layout, + (Concat use<"tri">, $n), + $i, + (Add $i, ConstantInt<1>) + ) + ), + (First + (Lookup $layout, + (Concat ConstantInt<0>, $n), + ConstantInt<0>, + ConstantInt<1> + ) + ), + (First + (Lookup $layout, + (Concat use<"tri">, $n), + (Add $i, ConstantInt<1>), + $i + ) + ), + (First + (Lookup $layout, + (Concat ConstantInt<0>, $n), + ConstantInt<1>, + ConstantInt<0> + ) + ) + ) + ), + (BlasCall<"lacpy"> $layout, $uplo, $n, $n, use<"tri">, $n, (Shadow $A)) + ) + ], + (Seq<["tri", "triangular", "n"], [], 0> + (BlasCall<"lacpy"> $layout, (flip_uplo $uplo), $n, $n, (Shadow $A), use<"tri">, $n), + + // Zero the strictly other side of shadow [and not diagonal]. We need to save the diagonal and restore since otherwise it's overwritten + (BlasCall<"lascl"> $layout, (flip_uplo $uplo), ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"0.0">, $n, $n, (Shadow $A), Alloca<1>), + (BlasCall<"copy"> $n, use<"tri">, (Add $lda, ConstantInt<1>), (First (Shadow $A)), (Add $lda, ConstantInt<1>)), + + // Actual Math + (BlasCall<"trsm"> $layout, Char<"L">, $uplo, (uplo_to_normal $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, $lda, (Shadow $A)), + (BlasCall<"trsm"> $layout, Char<"R">, $uplo, (uplo_to_trans $uplo), Char<"N">, $n, $n, Constant<"1.0">, $A, $lda, (Shadow $A)), + (BlasCall<"scal"> $n, Constant<"0.5">, (First (Shadow $A)), (Add $lda, ConstantInt<1>)), + + (Seq<["tmp", "vector", "n"], [], 1> + + // Zero out flipped side again + (BlasCall<"copy"> $n, (First (Shadow $A)), (Add $lda, ConstantInt<1>), use<"tmp">, ConstantInt<1>), + + (BlasCall<"lascl"> $layout, (flip_uplo $uplo), ConstantInt<0>, ConstantInt<0>, Constant<"1.0">, Constant<"0.0">, $n, $n, (Shadow $A), Alloca<1>), + (BlasCall<"copy"> $n, use<"tmp">, ConstantInt<1>, (First (Shadow $A)), (Add $lda, ConstantInt<1>)), + + // More math + (BlasCall<"trmm"> 
$layout, (uplo_to_side $uplo), $uplo, Char<"N">, Char<"N">, $n, $n, Constant<"1.0">, $A, $lda, (Shadow $A)), + + // Restore the flipped side, but saving our computed diagonal + (BlasCall<"copy"> $n, (First (Shadow $A)), (Add $lda, ConstantInt<1>), use<"tmp">, ConstantInt<1>), + (BlasCall<"lacpy"> $layout, (flip_uplo $uplo), $n, $n, use<"tri">, $n, (Shadow $A)), + (BlasCall<"copy"> $n, use<"tmp">, ConstantInt<1>, (First (Shadow $A)), (Add $lda, ConstantInt<1>)) + ) + ) + >; diff --git a/enzyme/Enzyme/GradientUtils.cpp b/enzyme/Enzyme/GradientUtils.cpp index 0217fe99be0d..10ec661f6d30 100644 --- a/enzyme/Enzyme/GradientUtils.cpp +++ b/enzyme/Enzyme/GradientUtils.cpp @@ -8226,7 +8226,6 @@ void GradientUtils::eraseFictiousPHIs() { for (auto pair : phis) { auto pp = pair.first; if (pp->getNumUses() != 0) { - assert(0); if (CustomErrorHandler) { std::string str; raw_string_ostream ss(str); diff --git a/enzyme/Enzyme/Utils.cpp b/enzyme/Enzyme/Utils.cpp index 66ea9e337e20..a61348e830cb 100644 --- a/enzyme/Enzyme/Utils.cpp +++ b/enzyme/Enzyme/Utils.cpp @@ -659,8 +659,10 @@ void callMemcpyStridedBlas(llvm::IRBuilder<> &B, llvm::Module &M, BlasInfo blas, llvm::ArrayRef args, llvm::Type *copy_retty, llvm::ArrayRef bundles) { - auto copy_name = - std::string(blas.prefix) + blas.floatType + "copy" + blas.suffix; + const bool cublasv2 = + blas.prefix == "cublas" && StringRef(blas.suffix).contains("v2"); + auto copy_name = std::string(blas.prefix) + blas.floatType + "copy" + + (cublasv2 ? 
"" : blas.suffix); SmallVector tys; for (auto arg : args) @@ -793,9 +795,7 @@ void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas, cast(blasalpha->getType())->getAddressSpace())); alpha = B1.CreateLoad(fpTy, VP); } - Value *is_u = is_uper(B1, blasuplo, byRef); - // Value *k = B1.CreateSelect(is_u, ConstantInt::get(IT, 0), - // ConstantInt::get(IT, 1), "k"); + Value *is_l = is_lower(B1, blasuplo, byRef, /*cublas*/ false); B1.CreateCondBr(B1.CreateICmpEQ(n, ConstantInt::get(IT, 0)), end, init); IRBuilder<> B2(init); @@ -811,7 +811,7 @@ void callSPMVDiagUpdate(IRBuilder<> &B, Module &M, BlasInfo blas, blasdAP, PointerType::get( fpTy, cast(blasdAP->getType())->getAddressSpace())); - B2.CreateCondBr(is_u, uper_code, lower_code); + B2.CreateCondBr(is_l, lower_code, uper_code); IRBuilder<> B3(uper_code); B3.setFastMathFlags(getFast()); @@ -2654,9 +2654,11 @@ std::optional extractBLAS(llvm::StringRef in) llvm::Optional extractBLAS(llvm::StringRef in) #endif { - const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", "spmv", - "syrk", "nrm2", "trmm", "trmv", "symm"}; - const char *floatType[] = {"s", "d"}; // c, z + const char *extractable[] = {"dot", "scal", "axpy", "gemv", "gemm", + "spmv", "syrk", "nrm2", "trmm", "trmv", + "symm", "potrf", "copy", "spmv", "syr2k", + "potrs", "getrf", "getrs", "trtrs", "getri"}; + const char *floatType[] = {"s", "d", "c", "z"}; const char *prefixes[] = {"" /*Fortran*/, "cblas_"}; const char *suffixes[] = {"", "_", "64_", "_64_"}; for (auto t : floatType) { @@ -2674,8 +2676,8 @@ llvm::Optional extractBLAS(llvm::StringRef in) } } // c interface to cublas - const char *cuCFloatType[] = {"S", "D"}; // c, z - const char *cuFFloatType[] = {"s", "d"}; // c, z + const char *cuCFloatType[] = {"S", "D", "C", "Z"}; + const char *cuFFloatType[] = {"s", "d", "c", "z"}; const char *cuCPrefixes[] = {"cublas"}; const char *cuSuffixes[] = {"", "_v2", "_64", "_v2_64"}; for (auto t : llvm::enumerate(cuCFloatType)) { @@ -2737,13 
+2739,13 @@ llvm::FastMathFlags getFast() { void addValueToCache(llvm::Value *arg, bool cache_arg, llvm::Type *ty, llvm::SmallVectorImpl &cacheValues, llvm::IRBuilder<> &BuilderZ, const Twine &name) { + if (!cache_arg) + return; if (!arg->getType()->isPointerTy()) { assert(arg->getType() == ty); cacheValues.push_back(arg); return; } - if (!cache_arg) - return; #if LLVM_VERSION_MAJOR < 17 auto PT = cast(arg->getType()); #if LLVM_VERSION_MAJOR <= 14 @@ -2796,38 +2798,47 @@ llvm::Value *to_blas_fp_callconv(IRBuilder<> &B, llvm::Value *V, bool byRef, return allocV; } -llvm::Value *select_vec_dims(IRBuilder<> &B, llvm::Value *trans, - llvm::Value *dim1, llvm::Value *dim2, bool byRef, - bool cublas) { - auto norm = is_normal(B, trans, byRef, cublas); - Value *width = B.CreateSelect(norm, dim1, dim2); - - return width; -} - -Value *is_uper(IRBuilder<> &B, Value *trans, bool byRef) { - IntegerType *charTy; +Value *is_lower(IRBuilder<> &B, Value *uplo, bool byRef, bool cublas) { + if (cublas) { + Value *isNormal = nullptr; + isNormal = B.CreateICmpEQ( + uplo, ConstantInt::get(uplo->getType(), + /*cublasFillMode_t::CUBLAS_FILL_MODE_LOWER*/ 0)); + return isNormal; + } + if (auto CI = dyn_cast(uplo)) { + if (CI->getValue() == 'L' || CI->getValue() == 'l') + return ConstantInt::getTrue(B.getContext()); + if (CI->getValue() == 'U' || CI->getValue() == 'u') + return ConstantInt::getFalse(B.getContext()); + } if (byRef) { // can't inspect opaque ptr, so assume 8 (Julia) - charTy = IntegerType::get(trans->getContext(), 8); - trans = B.CreateLoad(charTy, trans, "loaded.trans"); + IntegerType *charTy = IntegerType::get(uplo->getContext(), 8); + uplo = B.CreateLoad(charTy, uplo, "loaded.trans"); + + auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L')); + auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l')); + // fortran blas + return B.CreateOr(isl, isL); } else { // we can inspect scalars - unsigned int len = 
trans->getType()->getScalarSizeInBits(); - charTy = IntegerType::get(trans->getContext(), len); + auto capi = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 122)); + // TODO we really should just return capi, but for sake of consistency, + // we will accept either here. + auto isL = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'L')); + auto isl = B.CreateICmpEQ(uplo, ConstantInt::get(uplo->getType(), 'l')); + return B.CreateOr(capi, B.CreateOr(isl, isL)); } - - Value *isUper = - B.CreateOr(B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'u')), - B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'U'))); - return isUper; } llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas) { if (cublas) { Value *isNormal = nullptr; - isNormal = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 0)); + isNormal = B.CreateICmpEQ( + trans, ConstantInt::get(trans->getType(), + /*cublasOperation_t::CUBLAS_OP_N*/ 0)); return isNormal; } // Explicitly support 'N' always, since we use in the rule infra @@ -2841,13 +2852,56 @@ llvm::Value *is_normal(IRBuilder<> &B, llvm::Value *trans, bool byRef, IntegerType *charTy = IntegerType::get(trans->getContext(), 8); trans = B.CreateLoad(charTy, trans, "loaded.trans"); - auto isN = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'N')); - auto isn = B.CreateICmpEQ(trans, ConstantInt::get(charTy, 'n')); + auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')); + auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')); // fortran blas return B.CreateOr(isn, isN); } else { + // TODO we really should just return capi, but for sake of consistency, + // we will accept either here. 
// we can inspect scalars - return B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111)); + auto capi = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 111)); + auto isN = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'N')); + auto isn = B.CreateICmpEQ(trans, ConstantInt::get(trans->getType(), 'n')); + // fortran blas + return B.CreateOr(capi, B.CreateOr(isn, isN)); + } +} + +llvm::Value *is_left(IRBuilder<> &B, llvm::Value *side, bool byRef, + bool cublas) { + if (cublas) { + Value *isNormal = nullptr; + isNormal = B.CreateICmpEQ( + side, ConstantInt::get(side->getType(), + /*cublasSideMode_t::CUBLAS_SIDE_LEFT*/ 0)); + return isNormal; + } + // Explicitly support 'L'/'R' always, since we use in the rule infra + if (auto CI = dyn_cast(side)) { + if (CI->getValue() == 'L' || CI->getValue() == 'l') + return ConstantInt::getTrue(B.getContext()); + if (CI->getValue() == 'R' || CI->getValue() == 'r') + return ConstantInt::getFalse(B.getContext()); + } + if (byRef) { + // can't inspect opaque ptr, so assume 8 (Julia) + IntegerType *charTy = IntegerType::get(side->getContext(), 8); + side = B.CreateLoad(charTy, side, "loaded.side"); + + auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'L')); + auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'l')); + // fortran blas + return B.CreateOr(isl, isL); + } else { + // TODO we really should just return capi, but for sake of consistency, + // we will accept either here. 
+ // we can inspect scalars + auto capi = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 141)); + auto isL = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'L')); + auto isl = B.CreateICmpEQ(side, ConstantInt::get(side->getType(), 'l')); + // fortran blas + return B.CreateOr(capi, B.CreateOr(isl, isL)); } } diff --git a/enzyme/Enzyme/Utils.h b/enzyme/Enzyme/Utils.h index 8a088827c1f8..2f98ebbdb00a 100644 --- a/enzyme/Enzyme/Utils.h +++ b/enzyme/Enzyme/Utils.h @@ -1848,10 +1848,11 @@ static inline llvm::SmallVector concat_values(T &&...t) { llvm::Value *is_normal(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef, bool cublas); -llvm::Value *is_uper(llvm::IRBuilder<> &B, llvm::Value *trans, bool byRef); -llvm::Value *select_vec_dims(llvm::IRBuilder<> &B, llvm::Value *trans, - llvm::Value *dim1, llvm::Value *dim2, bool byRef, - bool cublas); +llvm::Value *is_left(llvm::IRBuilder<> &B, llvm::Value *side, bool byRef, + bool cublas); +llvm::Value *is_lower(llvm::IRBuilder<> &B, llvm::Value *uplo, bool byRef, + bool cublas); + // first one assume V is an Integer llvm::Value *transpose(llvm::IRBuilder<> &B, llvm::Value *V, bool cublas); // secon one assume V is an Integer or a ptr to an int (depends on byRef) diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll index e8b47fcf5310..8f6ffe0db61d 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop.ll @@ -56,17 +56,16 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) -; CHECK-NEXT: %1 = select i1 true, i32 %N, i32 %N -; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8 +; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %N, 8 ; CHECK-NEXT: %malloccall2 = tail call noalias 
nonnull i8* @malloc(i32 %mallocsize1) -; CHECK-NEXT: %2 = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2 -; CHECK-NEXT: %3 = getelementptr inbounds i8*, i8** %2, i64 %iv -; CHECK-NEXT: store i8* %malloccall2, i8** %3, align 8, !invariant.group !7 -; CHECK-NEXT: %4 = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5 -; CHECK-NEXT: %5 = getelementptr inbounds i8*, i8** %4, i64 %iv -; CHECK-NEXT: store i8* %malloccall, i8** %5, align 8, !invariant.group !8 +; CHECK-NEXT: %[[i2:.+]] = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2 +; CHECK-NEXT: %[[i3:.+]] = getelementptr inbounds i8*, i8** %[[i2]], i64 %iv +; CHECK-NEXT: store i8* %malloccall2, i8** %[[i3]], align 8, !invariant.group !7 +; CHECK-NEXT: %[[i4:.+]] = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5 +; CHECK-NEXT: %[[i5:.+]] = getelementptr inbounds i8*, i8** %[[i4]], i64 %iv +; CHECK-NEXT: store i8* %malloccall, i8** %[[i5]], align 8, !invariant.group !8 ; CHECK-NEXT: %cache.x = bitcast i8* %malloccall2 to double* -; CHECK-NEXT: call void @cblas_dcopy(i32 %1, double* %x0, i32 1, double* %cache.x, i32 1) +; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x, i32 1) ; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1) ; CHECK-NEXT: %exitcond.not = icmp eq i64 %iv.next, 5000 ; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll index 57e2d4153baf..7e1a594d3dab 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll +++ 
b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop2.ll @@ -56,17 +56,16 @@ entry: ; CHECK-NEXT: %malloccall = tail call noalias nonnull i8* @malloc(i32 %mallocsize) ; CHECK-NEXT: %cache.A = bitcast i8* %malloccall to double* ; CHECK-NEXT: call void @cblas_dlacpy(i32 101, i8 0, i32 %N, i32 %N, double* %K, i32 %N, double* %cache.A, i32 %N) -; CHECK-NEXT: %1 = select i1 true, i32 %N, i32 %N -; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %1, 8 +; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %N, 8 ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1) -; CHECK-NEXT: %2 = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2 -; CHECK-NEXT: %3 = getelementptr inbounds i8*, i8** %2, i64 %iv -; CHECK-NEXT: store i8* %malloccall2, i8** %3, align 8, !invariant.group !7 -; CHECK-NEXT: %4 = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5 -; CHECK-NEXT: %5 = getelementptr inbounds i8*, i8** %4, i64 %iv -; CHECK-NEXT: store i8* %malloccall, i8** %5, align 8, !invariant.group !8 +; CHECK-NEXT: %[[i2:.+]] = load i8**, i8*** %malloccall2_cache, align 8, !dereferenceable !6, !invariant.group !2 +; CHECK-NEXT: %[[i3:.+]] = getelementptr inbounds i8*, i8** %[[i2]], i64 %iv +; CHECK-NEXT: store i8* %malloccall2, i8** %[[i3]], align 8, !invariant.group !7 +; CHECK-NEXT: %[[i4:.+]] = load i8**, i8*** %malloccall_cache, align 8, !dereferenceable !6, !invariant.group !5 +; CHECK-NEXT: %[[i5:.+]] = getelementptr inbounds i8*, i8** %[[i4]], i64 %iv +; CHECK-NEXT: store i8* %malloccall, i8** %[[i5]], align 8, !invariant.group !8 ; CHECK-NEXT: %cache.x = bitcast i8* %malloccall2 to double* -; CHECK-NEXT: call void @cblas_dcopy(i32 %1, double* %x0, i32 1, double* %cache.x, i32 1) +; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x, i32 1) ; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 
1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1) ; CHECK-NEXT: %exitcond.not = icmp eq i64 %iv.next, 5000 ; CHECK-NEXT: br i1 %exitcond.not, label %for.cond.cleanup, label %for.body diff --git a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll index a84be7d7d0b5..012861b841aa 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/gemv_c_loop3_matcopy.ll @@ -62,11 +62,10 @@ entry: ; CHECK-NEXT: br i1 %7, label %__enzyme_memcpy_double_mat_32.exit, label %init.idx.i ; CHECK: __enzyme_memcpy_double_mat_32.exit: ; preds = %entry, %init.end.i -; CHECK-NEXT: %8 = select i1 true, i32 %N, i32 %N -; CHECK-NEXT: %mallocsize22 = mul nuw nsw i32 %8, 8 +; CHECK-NEXT: %mallocsize22 = mul nuw nsw i32 %N, 8 ; CHECK-NEXT: %malloccall23 = tail call noalias nonnull i8* @malloc(i32 %mallocsize22) ; CHECK-NEXT: %cache.x24 = bitcast i8* %malloccall23 to double* -; CHECK-NEXT: call void @cblas_dcopy(i32 %8, double* %x0, i32 1, double* %cache.x24, i32 1) +; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x24, i32 1) ; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1) ; CHECK-NEXT: %[[i11:.+]] = mul i32 %N, %N ; CHECK-NEXT: %mallocsize11 = mul nuw nsw i32 %[[i11]], 8 @@ -100,11 +99,10 @@ entry: ; CHECK-NEXT: br i1 %[[i18:.+]], label %__enzyme_memcpy_double_mat_32.exit38, label %init.idx.i29 ; CHECK: __enzyme_memcpy_double_mat_32.exit38: ; preds = %__enzyme_memcpy_double_mat_32.exit, %init.end.i37 -; CHECK-NEXT: %[[i19:.+]] = select i1 true, i32 %N, i32 %N -; CHECK-NEXT: %mallocsize14 = mul nuw nsw i32 %[[i19]], 8 
+; CHECK-NEXT: %mallocsize14 = mul nuw nsw i32 %N, 8 ; CHECK-NEXT: %malloccall15 = tail call noalias nonnull i8* @malloc(i32 %mallocsize14) ; CHECK-NEXT: %cache.x16 = bitcast i8* %malloccall15 to double* -; CHECK-NEXT: call void @cblas_dcopy(i32 %[[i19]], double* %x0, i32 1, double* %cache.x16, i32 1) +; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x16, i32 1) ; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1) ; CHECK-NEXT: %[[i22:.+]] = mul i32 %N, %N ; CHECK-NEXT: %mallocsize3 = mul nuw nsw i32 %[[i22]], 8 @@ -138,11 +136,10 @@ entry: ; CHECK-NEXT: br i1 %[[i29]], label %__enzyme_memcpy_double_mat_32.exit50, label %init.idx.i41 ; CHECK: __enzyme_memcpy_double_mat_32.exit50: ; preds = %__enzyme_memcpy_double_mat_32.exit38, %init.end.i49 -; CHECK-NEXT: %[[i30:.+]] = select i1 true, i32 %N, i32 %N -; CHECK-NEXT: %mallocsize6 = mul nuw nsw i32 %[[i30]], 8 +; CHECK-NEXT: %mallocsize6 = mul nuw nsw i32 %N, 8 ; CHECK-NEXT: %malloccall7 = tail call noalias nonnull i8* @malloc(i32 %mallocsize6) ; CHECK-NEXT: %cache.x8 = bitcast i8* %malloccall7 to double* -; CHECK-NEXT: call void @cblas_dcopy(i32 %[[i30]], double* %x0, i32 1, double* %cache.x8, i32 1) +; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x8, i32 1) ; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1) ; CHECK-NEXT: %[[i33:.+]] = mul i32 %N, %N ; CHECK-NEXT: %mallocsize = mul nuw nsw i32 %[[i33]], 8 @@ -176,11 +173,10 @@ entry: ; CHECK-NEXT: br i1 %[[i40]], label %__enzyme_memcpy_double_mat_32.exit62, label 
%init.idx.i53 ; CHECK: __enzyme_memcpy_double_mat_32.exit62: ; preds = %__enzyme_memcpy_double_mat_32.exit50, %init.end.i61 -; CHECK-NEXT: %[[i41:.+]] = select i1 true, i32 %N, i32 %N -; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %[[i41]], 8 +; CHECK-NEXT: %mallocsize1 = mul nuw nsw i32 %N, 8 ; CHECK-NEXT: %malloccall2 = tail call noalias nonnull i8* @malloc(i32 %mallocsize1) ; CHECK-NEXT: %cache.x = bitcast i8* %malloccall2 to double* -; CHECK-NEXT: call void @cblas_dcopy(i32 %[[i41]], double* %x0, i32 1, double* %cache.x, i32 1) +; CHECK-NEXT: call void @cblas_dcopy(i32 %N, double* %x0, i32 1, double* %cache.x, i32 1) ; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1) ; CHECK-NEXT: tail call void @cblas_dgemv(i32 noundef 101, i32 noundef 111, i32 noundef %N, i32 noundef %N, double noundef 1.000000e-03, double* noundef %K, i32 noundef %N, double* noundef %x0, i32 noundef 1, double noundef 1.000000e+00, double* noundef %v0, i32 noundef 1) ; CHECK-NEXT: br label %invertentry diff --git a/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll b/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll index ea99f479c8d8..d870496ece07 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/spmv_f_c_lacpy.ll @@ -92,10 +92,8 @@ entry: ; CHECK-NEXT: %[[i6:.+]] = bitcast i8* %n_p to i64* ; CHECK-NEXT: %[[i7:.+]] = load i64, i64* %[[i6]] -; CHECK-NEXT: %[[i8:.+]] = add i64 %[[i7]], 1 -; CHECK-NEXT: %square_mat_size_y0 = mul i64 %[[i7]], %[[i8]] -; CHECK-NEXT: %size_y0 = udiv i64 %square_mat_size_y0, 2 -; CHECK-NEXT: %mallocsize4 = mul nuw nsw i64 %size_y0, 8 +; CHECK-NEXT: %[[i8:.+]] = mul i64 %[[i7]], %[[i7]] +; CHECK-NEXT: %mallocsize4 = mul nuw nsw i64 %[[i8]], 8 ; CHECK-NEXT: %malloccall5 = tail call 
noalias nonnull i8* @malloc(i64 %mallocsize4) ; CHECK-NEXT: %[[mat_y0:.+]] = bitcast i8* %malloccall5 to double* ; CHECK-NEXT: %[[i9:.+]] = bitcast double* %[[mat_y0]] to i8* @@ -124,8 +122,8 @@ entry: ; CHECK-NEXT: %[[i21:.+]] = bitcast i8* %alpha to double* ; CHECK-NEXT: %[[i22:.+]] = load double, double* %[[i21]] ; CHECK-NEXT: %loaded.trans.i = load i8, i8* %uplo -; CHECK-DAG: %[[i0:.+]] = icmp eq i8 %loaded.trans.i, 85 -; CHECK-DAG: %[[i1:.+]] = icmp eq i8 %loaded.trans.i, 117 +; CHECK-DAG: %[[i0:.+]] = icmp eq i8 %loaded.trans.i, 76 +; CHECK-DAG: %[[i1:.+]] = icmp eq i8 %loaded.trans.i, 108 ; CHECK-NEXT: %[[i25:.+]] = or i1 %[[i1]], %[[i0]] ; CHECK-NEXT: %[[i26:.+]] = icmp eq i64 %[[i17]], 0 ; CHECK-NEXT: br i1 %[[i26]], label %__enzyme_spmv_diagd_64_.exit, label %init.i @@ -134,7 +132,7 @@ entry: ; CHECK-NEXT: %[[i27:.+]] = bitcast i8* %X to double* ; CHECK-NEXT: %[[i28:.+]] = bitcast i8* %incx_p to double* ; CHECK-NEXT: %[[i29:.+]] = bitcast i8* %incy_p to double* -; CHECK-NEXT: br i1 %[[i25]], label %uper.i, label %lower.i +; CHECK-NEXT: br i1 %[[i25]], label %lower.i, label %uper.i ; CHECK: uper.i: ; preds = %uper.i, %init.i ; CHECK-NEXT: %iteration.i = phi i64 [ 0, %init.i ], [ %iter.next.i, %uper.i ] @@ -206,8 +204,8 @@ entry: ; CHECK-NEXT: %7 = bitcast i8* %blasalpha to double* ; CHECK-NEXT: %8 = load double, double* %7 ; CHECK-NEXT: %loaded.trans = load i8, i8* %blasuplo -; CHECK-DAG: %[[i9:.+]] = icmp eq i8 %loaded.trans, 85 -; CHECK-DAG: %[[i10:.+]] = icmp eq i8 %loaded.trans, 117 +; CHECK-DAG: %[[i9:.+]] = icmp eq i8 %loaded.trans, 76 +; CHECK-DAG: %[[i10:.+]] = icmp eq i8 %loaded.trans, 108 ; CHECK-NEXT: %11 = or i1 %[[i10]], %[[i9]] ; CHECK-NEXT: %12 = icmp eq i64 %2, 0 ; CHECK-NEXT: br i1 %12, label %for.end, label %init @@ -216,7 +214,7 @@ entry: ; CHECK-NEXT: %13 = bitcast i8* %blasx to double* ; CHECK-NEXT: %14 = bitcast i8* %blasdy to double* ; CHECK-NEXT: %15 = bitcast i8* %blasdAP to double* -; CHECK-NEXT: br i1 %11, label %uper, label 
%lower +; CHECK-NEXT: br i1 %11, label %lower, label %uper ; CHECK: uper: ; preds = %uper, %init ; CHECK-NEXT: %iteration = phi i64 [ 0, %init ], [ %iter.next, %uper ] diff --git a/enzyme/test/Enzyme/ReverseMode/blas/syrk_f.ll b/enzyme/test/Enzyme/ReverseMode/blas/syrk_f.ll index 73c3d16fa964..91f424768727 100644 --- a/enzyme/test/Enzyme/ReverseMode/blas/syrk_f.ll +++ b/enzyme/test/Enzyme/ReverseMode/blas/syrk_f.ll @@ -104,124 +104,112 @@ entry: ; CHECK-NEXT: store i8 114, i8* %byref.constant.char.r, align 1 ; CHECK-NEXT: store i8 108, i8* %byref.constant.char.l, align 1 ; CHECK-NEXT: %ld.row.trans = load i8, i8* %trans, align 1 -; CHECK-NEXT: %1 = icmp eq i8 %ld.row.trans, 110 -; CHECK-NEXT: %2 = icmp eq i8 %ld.row.trans, 78 -; CHECK-NEXT: %3 = or i1 %2, %1 -; CHECK-NEXT: %4 = select i1 %3, i8* %byref.constant.char.l, i8* %byref.constant.char.r +; CHECK-NEXT: %[[i1:.+]] = icmp eq i8 %ld.row.trans, 110 +; CHECK-NEXT: %[[i2:.+]] = icmp eq i8 %ld.row.trans, 78 +; CHECK-NEXT: %[[i3:.+]] = or i1 %[[i2]], %[[i1]] +; CHECK-NEXT: %[[i4:.+]] = select i1 %[[i3]], i8* %byref.constant.char.l, i8* %byref.constant.char.r ; CHECK-NEXT: %ld.row.trans1 = load i8, i8* %trans, align 1 -; CHECK-NEXT: %5 = icmp eq i8 %ld.row.trans1, 110 -; CHECK-NEXT: %6 = icmp eq i8 %ld.row.trans1, 78 -; CHECK-NEXT: %7 = or i1 %6, %5 -; CHECK-NEXT: %8 = select i1 %7, i8* %n_p, i8* %k_p -; CHECK-NEXT: %9 = select i1 %7, i8* %k_p, i8* %n_p +; CHECK-NEXT: %[[i5:.+]] = icmp eq i8 %ld.row.trans1, 110 +; CHECK-NEXT: %[[i6:.+]] = icmp eq i8 %ld.row.trans1, 78 +; CHECK-NEXT: %[[i7:.+]] = or i1 %[[i6]], %[[i5]] +; CHECK-NEXT: %[[i8:.+]] = select i1 %[[i7:.+]], i8* %n_p, i8* %k_p +; CHECK-NEXT: %[[i9:.+]] = select i1 %[[i7:.+]], i8* %k_p, i8* %n_p ; CHECK-NEXT: store double 1.000000e+00, double* %byref.constant.fp.1, align 8 ; CHECK-NEXT: %fpcast.constant.fp.1 = bitcast double* %byref.constant.fp.1 to i8* -; CHECK-NEXT: call void @dsymm_64_(i8* %4, i8* %uplo, i8* %8, i8* %9, i8* %alpha_p, i8* %"C'", i8* 
%ldc_p, i8* %A, i8* %lda_p, i8* %fpcast.constant.fp.1, i8* %"A'", i8* %lda_p) -; CHECK-NEXT: %10 = bitcast i8* %n_p to i64* -; CHECK-NEXT: %11 = load i64, i64* %10, align 4 -; CHECK-NEXT: %12 = icmp eq i64 %11, 0 -; CHECK-NEXT: br i1 %12, label %invertentry_end, label %invertentry_loop +; CHECK-NEXT: call void @dsymm_64_(i8* %[[i4]], i8* %uplo, i8* %[[i8]], i8* %[[i9]], i8* %alpha_p, i8* %"C'", i8* %ldc_p, i8* %A, i8* %lda_p, i8* %fpcast.constant.fp.1, i8* %"A'", i8* %lda_p) +; CHECK-NEXT: %[[i10:.+]] = bitcast i8* %n_p to i64* +; CHECK-NEXT: %[[i11:.+]] = load i64, i64* %[[i10]], align 4 +; CHECK-NEXT: %[[i12:.+]] = icmp eq i64 %[[i11]], 0 +; CHECK-NEXT: br i1 %[[i12]], label %invertentry_end, label %invertentry_loop ; CHECK: invertentry_loop: ; preds = %invertentry_loop, %invertentry -; CHECK-NEXT: %13 = phi i64 [ 0, %invertentry ], [ %14, %invertentry_loop ] -; CHECK-NEXT: %14 = add nuw nsw i64 %13, 1 -; CHECK-NEXT: store i64 %13, i64* %byref.for.i, align 4 +; CHECK-NEXT: %[[i13:.+]] = phi i64 [ 0, %invertentry ], [ %[[i14:.+]], %invertentry_loop ] +; CHECK-NEXT: %[[i14]] = add nuw nsw i64 %[[i13]], 1 +; CHECK-NEXT: store i64 %[[i13]], i64* %byref.for.i, align 4 ; CHECK-NEXT: %intcast.for.i = bitcast i64* %byref.for.i to i8* -; CHECK-NEXT: %15 = bitcast i8* %"C'" to double* -; CHECK-NEXT: %16 = bitcast i8* %ldc_p to i64* -; CHECK-NEXT: %17 = load i64, i64* %16, align 4 -; CHECK-NEXT: %18 = load i8, i8* %uplo, align 1 -; CHECK-NEXT: %19 = icmp eq i8 %18, 101 -; CHECK-NEXT: %20 = select i1 %19, i64 %17, i64 1 -; CHECK-NEXT: %21 = bitcast i8* %intcast.for.i to i64* -; CHECK-NEXT: %22 = load i64, i64* %21, align 4 -; CHECK-NEXT: %23 = mul i64 %22, %20 -; CHECK-NEXT: %24 = select i1 %19, i64 1, i64 %17 -; CHECK-NEXT: %25 = bitcast i8* %intcast.for.i to i64* -; CHECK-NEXT: %26 = load i64, i64* %25, align 4 -; CHECK-NEXT: %27 = mul i64 %26, %24 -; CHECK-NEXT: %28 = add i64 %23, %27 -; CHECK-NEXT: %29 = getelementptr double, double* %15, i64 %28 -; CHECK-NEXT: %30 = 
load double, double* %29, align 8 -; CHECK-NEXT: %31 = bitcast i8* %alpha_p to double* -; CHECK-NEXT: %32 = load double, double* %31, align 8 -; CHECK-NEXT: %33 = fmul fast double %32, %30 -; CHECK-NEXT: store double %33, double* %byref.FMul, align 8 +; CHECK-NEXT: %[[i16:.+]] = bitcast i8* %ldc_p to i64* +; CHECK-NEXT: %[[i17:.+]] = load i64, i64* %[[i16]], align 4 +; CHECK-NEXT: %[[i21:.+]] = bitcast i8* %intcast.for.i to i64* +; CHECK-NEXT: %[[i22:.+]] = load i64, i64* %[[i21]], align 4 +; CHECK-NEXT: %[[i23:.+]] = mul i64 %[[i22]], 1 +; CHECK-NEXT: %[[i25:.+]] = bitcast i8* %intcast.for.i to i64* +; CHECK-NEXT: %[[i26:.+]] = load i64, i64* %[[i25]], align 4 +; CHECK-NEXT: %[[i27:.+]] = mul i64 %[[i26]], %[[i17]] +; CHECK-NEXT: %[[i28:.+]] = add i64 %[[i23]], %[[i27]] +; CHECK-NEXT: %[[i15:.+]] = bitcast i8* %"C'" to double* +; CHECK-NEXT: %[[i29:.+]] = getelementptr double, double* %[[i15]], i64 %[[i28]] +; CHECK-NEXT: %[[i30:.+]] = load double, double* %[[i29]], align 8 +; CHECK-NEXT: %[[i31:.+]] = bitcast i8* %alpha_p to double* +; CHECK-NEXT: %[[i32:.+]] = load double, double* %[[i31]], align 8 +; CHECK-NEXT: %[[i33:.+]] = fmul fast double %[[i32]], %[[i30]] +; CHECK-NEXT: store double %[[i33]], double* %byref.FMul, align 8 ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.0, align 4 ; CHECK-NEXT: %intcast.constant.int.0 = bitcast i64* %byref.constant.int.0 to i8* ; CHECK-NEXT: %ld.row.trans2 = load i8, i8* %trans, align 1 -; CHECK-NEXT: %34 = icmp eq i8 %ld.row.trans2, 110 -; CHECK-NEXT: %35 = icmp eq i8 %ld.row.trans2, 78 -; CHECK-NEXT: %36 = or i1 %35, %34 -; CHECK-NEXT: %37 = select i1 %36, i8* %intcast.for.i, i8* %intcast.constant.int.0 +; CHECK-NEXT: %[[i34:.+]] = icmp eq i8 %ld.row.trans2, 110 +; CHECK-NEXT: %[[i35:.+]] = icmp eq i8 %ld.row.trans2, 78 +; CHECK-NEXT: %[[i36:.+]] = or i1 %[[i35]], %[[i34]] +; CHECK-NEXT: %[[i37:.+]] = select i1 %[[i36]], i8* %intcast.for.i, i8* %intcast.constant.int.0 ; CHECK-NEXT: store i64 0, i64* 
%byref.constant.int.03, align 4 ; CHECK-NEXT: %intcast.constant.int.04 = bitcast i64* %byref.constant.int.03 to i8* ; CHECK-NEXT: %ld.row.trans5 = load i8, i8* %trans, align 1 -; CHECK-NEXT: %38 = icmp eq i8 %ld.row.trans5, 110 -; CHECK-NEXT: %39 = icmp eq i8 %ld.row.trans5, 78 -; CHECK-NEXT: %40 = or i1 %39, %38 -; CHECK-NEXT: %41 = select i1 %40, i8* %intcast.constant.int.04, i8* %intcast.for.i -; CHECK-NEXT: %42 = bitcast i8* %A to double* -; CHECK-NEXT: %43 = bitcast i8* %lda_p to i64* -; CHECK-NEXT: %44 = load i64, i64* %43, align 4 -; CHECK-NEXT: %45 = load i8, i8* %uplo, align 1 -; CHECK-NEXT: %46 = icmp eq i8 %45, 101 -; CHECK-NEXT: %47 = select i1 %46, i64 %44, i64 1 -; CHECK-NEXT: %48 = bitcast i8* %37 to i64* -; CHECK-NEXT: %49 = load i64, i64* %48, align 4 -; CHECK-NEXT: %50 = mul i64 %49, %47 -; CHECK-NEXT: %51 = select i1 %46, i64 1, i64 %44 -; CHECK-NEXT: %52 = bitcast i8* %41 to i64* -; CHECK-NEXT: %53 = load i64, i64* %52, align 4 -; CHECK-NEXT: %54 = mul i64 %53, %51 -; CHECK-NEXT: %55 = add i64 %50, %54 -; CHECK-NEXT: %56 = getelementptr double, double* %42, i64 %55 +; CHECK-NEXT: %[[i38:.+]] = icmp eq i8 %ld.row.trans5, 110 +; CHECK-NEXT: %[[i39:.+]] = icmp eq i8 %ld.row.trans5, 78 +; CHECK-NEXT: %[[i40:.+]] = or i1 %[[i39]], %[[i38]] +; CHECK-NEXT: %[[i41:.+]] = select i1 %[[i40]], i8* %intcast.constant.int.04, i8* %intcast.for.i +; CHECK-NEXT: %[[i43:.+]] = bitcast i8* %lda_p to i64* +; CHECK-NEXT: %[[i44:.+]] = load i64, i64* %[[i43]], align 4 +; CHECK-NEXT: %[[i48:.+]] = bitcast i8* %[[i37]] to i64* +; CHECK-NEXT: %[[i49:.+]] = load i64, i64* %[[i48]], align 4 +; CHECK-NEXT: %[[i50:.+]] = mul i64 %[[i49]], 1 +; CHECK-NEXT: %[[i52:.+]] = bitcast i8* %[[i41]] to i64* +; CHECK-NEXT: %[[i53:.+]] = load i64, i64* %[[i52]], align 4 +; CHECK-NEXT: %[[i54:.+]] = mul i64 %[[i53]], %[[i44]] +; CHECK-NEXT: %[[i55:.+]] = add i64 %[[i50]], %[[i54]] +; CHECK-NEXT: %[[i42:.+]] = bitcast i8* %A to double* +; CHECK-NEXT: %[[i56:.+]] = getelementptr double, 
double* %[[i42]], i64 %[[i55]] ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.1, align 4 ; CHECK-NEXT: %intcast.constant.int.1 = bitcast i64* %byref.constant.int.1 to i8* ; CHECK-NEXT: %ld.row.trans6 = load i8, i8* %trans, align 1 -; CHECK-NEXT: %57 = icmp eq i8 %ld.row.trans6, 110 -; CHECK-NEXT: %58 = icmp eq i8 %ld.row.trans6, 78 -; CHECK-NEXT: %59 = or i1 %58, %57 -; CHECK-NEXT: %60 = select i1 %59, i8* %lda_p, i8* %intcast.constant.int.1 +; CHECK-NEXT: %[[i57:.+]] = icmp eq i8 %ld.row.trans6, 110 +; CHECK-NEXT: %[[i58:.+]] = icmp eq i8 %ld.row.trans6, 78 +; CHECK-NEXT: %[[i59:.+]] = or i1 %[[i58]], %[[i57]] +; CHECK-NEXT: %[[i60:.+]] = select i1 %[[i59]], i8* %lda_p, i8* %intcast.constant.int.1 ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.07, align 4 ; CHECK-NEXT: %intcast.constant.int.08 = bitcast i64* %byref.constant.int.07 to i8* ; CHECK-NEXT: %ld.row.trans9 = load i8, i8* %trans, align 1 -; CHECK-NEXT: %61 = icmp eq i8 %ld.row.trans9, 110 -; CHECK-NEXT: %62 = icmp eq i8 %ld.row.trans9, 78 -; CHECK-NEXT: %63 = or i1 %62, %61 -; CHECK-NEXT: %64 = select i1 %63, i8* %intcast.for.i, i8* %intcast.constant.int.08 +; CHECK-NEXT: %[[i61:.+]] = icmp eq i8 %ld.row.trans9, 110 +; CHECK-NEXT: %[[i62:.+]] = icmp eq i8 %ld.row.trans9, 78 +; CHECK-NEXT: %[[i63:.+]] = or i1 %[[i62]], %[[i61]] +; CHECK-NEXT: %[[i64:.+]] = select i1 %[[i63]], i8* %intcast.for.i, i8* %intcast.constant.int.08 ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.010, align 4 ; CHECK-NEXT: %intcast.constant.int.011 = bitcast i64* %byref.constant.int.010 to i8* ; CHECK-NEXT: %ld.row.trans12 = load i8, i8* %trans, align 1 -; CHECK-NEXT: %65 = icmp eq i8 %ld.row.trans12, 110 -; CHECK-NEXT: %66 = icmp eq i8 %ld.row.trans12, 78 -; CHECK-NEXT: %67 = or i1 %66, %65 -; CHECK-NEXT: %68 = select i1 %67, i8* %intcast.constant.int.011, i8* %intcast.for.i -; CHECK-NEXT: %69 = bitcast i8* %"A'" to double* -; CHECK-NEXT: %70 = bitcast i8* %lda_p to i64* -; CHECK-NEXT: %71 = load i64, i64* %70, align 
4 -; CHECK-NEXT: %72 = load i8, i8* %uplo, align 1 -; CHECK-NEXT: %73 = icmp eq i8 %72, 101 -; CHECK-NEXT: %74 = select i1 %73, i64 %71, i64 1 -; CHECK-NEXT: %75 = bitcast i8* %64 to i64* -; CHECK-NEXT: %76 = load i64, i64* %75, align 4 -; CHECK-NEXT: %77 = mul i64 %76, %74 -; CHECK-NEXT: %78 = select i1 %73, i64 1, i64 %71 -; CHECK-NEXT: %79 = bitcast i8* %68 to i64* -; CHECK-NEXT: %80 = load i64, i64* %79, align 4 -; CHECK-NEXT: %81 = mul i64 %80, %78 -; CHECK-NEXT: %82 = add i64 %77, %81 -; CHECK-NEXT: %83 = getelementptr double, double* %69, i64 %82 +; CHECK-NEXT: %[[i65:.+]] = icmp eq i8 %ld.row.trans12, 110 +; CHECK-NEXT: %[[i66:.+]] = icmp eq i8 %ld.row.trans12, 78 +; CHECK-NEXT: %[[i67:.+]] = or i1 %[[i66]], %[[i65]] +; CHECK-NEXT: %[[i68:.+]] = select i1 %[[i67]], i8* %intcast.constant.int.011, i8* %intcast.for.i +; CHECK-NEXT: %[[i70:.+]] = bitcast i8* %lda_p to i64* +; CHECK-NEXT: %[[i71:.+]] = load i64, i64* %[[i70]], align 4 +; CHECK-NEXT: %[[i75:.+]] = bitcast i8* %[[i64]] to i64* +; CHECK-NEXT: %[[i76:.+]] = load i64, i64* %[[i75]], align 4 +; CHECK-NEXT: %[[i77:.+]] = mul i64 %[[i76]], 1 +; CHECK-NEXT: %[[i79:.+]] = bitcast i8* %[[i68]] to i64* +; CHECK-NEXT: %[[i80:.+]] = load i64, i64* %[[i79]], align 4 +; CHECK-NEXT: %[[i81:.+]] = mul i64 %[[i80]], %[[i71]] +; CHECK-NEXT: %[[i82:.+]] = add i64 %[[i77]], %[[i81]] +; CHECK-NEXT: %[[i69:.+]] = bitcast i8* %"A'" to double* +; CHECK-NEXT: %[[i83:.+]] = getelementptr double, double* %[[i69]], i64 %[[i82]] ; CHECK-NEXT: store i64 1, i64* %byref.constant.int.113, align 4 ; CHECK-NEXT: %intcast.constant.int.114 = bitcast i64* %byref.constant.int.113 to i8* ; CHECK-NEXT: %ld.row.trans15 = load i8, i8* %trans, align 1 -; CHECK-NEXT: %84 = icmp eq i8 %ld.row.trans15, 110 -; CHECK-NEXT: %85 = icmp eq i8 %ld.row.trans15, 78 -; CHECK-NEXT: %86 = or i1 %85, %84 -; CHECK-NEXT: %87 = select i1 %86, i8* %lda_p, i8* %intcast.constant.int.114 -; CHECK-NEXT: call void @daxpy_64_(i8* %k_p, double* %byref.FMul, double* 
%56, i8* %60, double* %83, i8* %87) -; CHECK-NEXT: %88 = icmp eq i64 %11, %14 -; CHECK-NEXT: br i1 %88, label %invertentry_end, label %invertentry_loop +; CHECK-NEXT: %[[i84:.+]] = icmp eq i8 %ld.row.trans15, 110 +; CHECK-NEXT: %[[i85:.+]] = icmp eq i8 %ld.row.trans15, 78 +; CHECK-NEXT: %[[i86:.+]] = or i1 %[[i85]], %[[i84]] +; CHECK-NEXT: %[[i87:.+]] = select i1 %[[i86]], i8* %lda_p, i8* %intcast.constant.int.114 +; CHECK-NEXT: call void @daxpy_64_(i8* %k_p, double* %byref.FMul, double* %[[i56]], i8* %[[i60]], double* %[[i83]], i8* %[[i87]]) +; CHECK-NEXT: %[[i88:.+]] = icmp eq i64 %[[i11]], %[[i14]] +; CHECK-NEXT: br i1 %[[i88]], label %invertentry_end, label %invertentry_loop ; CHECK: invertentry_end: ; preds = %invertentry_loop, %invertentry ; CHECK-NEXT: store i64 0, i64* %byref.constant.int.016, align 4 diff --git a/enzyme/test/Integration/ForwardMode/blas.cpp b/enzyme/test/Integration/ForwardMode/blas.cpp index 952d067c02fb..3a0467862010 100644 --- a/enzyme/test/Integration/ForwardMode/blas.cpp +++ b/enzyme/test/Integration/ForwardMode/blas.cpp @@ -64,6 +64,11 @@ void my_dsyrk(char layout, char uplo, char trans, C, ldc); } +void my_potrf(char layout, char uplo, int N, double *__restrict__ A, int lda) { + int info; + cblas_dpotrf(layout, uplo, N, A, lda, &info); +} + static void dotTests() { { std::string Test = "DOT active both "; @@ -560,7 +565,6 @@ static void gemmTests() { static void syrkTests() { // N means normal matrix, T means transposed - // TODO: row major is presently an exepcted failure. We should re-enable. 
for (char layout : {CblasColMajor, CblasRowMajor}) { for (auto transA : @@ -693,6 +697,104 @@ static void syrkTests() { } } +static void potrfTests() { + int N = 17; + // N means normal matrix, T means transposed + for (char layout : {CblasColMajor, CblasRowMajor}) { + for (auto uplo : {'U', 'u', 'L', 'l'}) + + { + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, N, N, lda), + /*B*/ BlasInfo(), + /*C*/ BlasInfo(), + BlasInfo(), + BlasInfo(), + BlasInfo(), + }; + { + + std::string Test = "POTRF active A "; + init(); + + my_potrf(layout, uplo, N, A, lda); + + assert(calls.size() == 1); + assert(calls[0].inDerivative == false); + assert(calls[0].type == CallType::POTRF); + assert(calls[0].pout_arg1 == A); + assert(calls[0].pin_arg1 == UNUSED_POINTER); + assert(calls[0].pin_arg2 == UNUSED_POINTER); + assert(calls[0].farg1 == UNUSED_DOUBLE); + assert(calls[0].farg2 == UNUSED_DOUBLE); + assert(calls[0].layout == layout); + assert(calls[0].targ1 == UNUSED_TRANS); + assert(calls[0].targ2 == UNUSED_TRANS); + assert(calls[0].iarg1 == N); + assert(calls[0].iarg2 == UNUSED_INT); + assert(calls[0].iarg3 == UNUSED_INT); + assert(calls[0].iarg4 == lda); + assert(calls[0].iarg5 == UNUSED_INT); + assert(calls[0].iarg6 == UNUSED_INT); + assert(calls[0].side == UNUSED_TRANS); + assert(calls[0].uplo == uplo); + assert(calls[0].diag == UNUSED_TRANS); + + // Check memory of primal on own. 
+ checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_fwddiff( + (void *)my_potrf, enzyme_const, layout, enzyme_const, uplo, + enzyme_const, N, enzyme_dup, A, dA, enzyme_const, lda); + foundCalls = calls; + init(); + + my_potrf(layout, uplo, N, A, lda); + + assert(foundCalls.size() >= 2); + assert(foundCalls[1].type == CallType::LACPY); + double* tri = (double*)foundCalls[1].pout_arg1; + inputs[3] = BlasInfo(tri, layout, N, N, N); + cblas_dlacpy(layout, flip_uplo(uplo), N, N, dA, lda, tri, N); + + cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, dA, lda, 0); + cblas_dcopy(N, tri, lda+1, dA, lda+1); + + cblas_dtrsm(layout, 'L', uplo, uplo_to_normal(uplo), 'N', N, N, 1.0, A, lda, dA, lda); + cblas_dtrsm(layout, 'R', uplo, uplo_to_trans(uplo), 'N', N, N, 1.0, A, lda, dA, lda); + cblas_dscal(N, 0.5, dA, lda+1); + + assert(foundCalls.size() >= 9); + assert(foundCalls[7].type == CallType::COPY); + double* tmp = (double*)foundCalls[7].pout_arg1; + inputs[4] = BlasInfo(tmp, N, 1); + + cblas_dcopy(N, dA, lda+1, tmp, 1); + cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, dA, lda, 0); + cblas_dcopy(N, tmp, 1, dA, lda+1); + cblas_dtrmm(layout, uplo_to_side(uplo), uplo, 'N', 'N', N, N, 1.0, A, lda, dA, lda); + + cblas_dcopy(N, dA, lda+1, tmp, 1); + cblas_dlacpy(layout, flip_uplo(uplo), N, N, tri, N, dA, lda); + cblas_dcopy(N, tmp, 1, dA, lda+1); + + checkTest(Test); + + SkipVecIncCheck = true; + // Check memory of primal of expected derivative + checkMemoryTrace(inputs, "Expected " + Test, calls); + + // Check memory of primal of our derivative (if equal above, it + // should be the same). 
+ checkMemoryTrace(inputs, "Found " + Test, foundCalls); + SkipVecIncCheck = false; + } + + } + } +} + int main() { dotTests(); @@ -703,4 +805,6 @@ int main() { gemmTests(); syrkTests(); + + potrfTests(); } diff --git a/enzyme/test/Integration/ReverseMode/blas.cpp b/enzyme/test/Integration/ReverseMode/blas.cpp index 0a8410f4d4e6..a85a7cdbed36 100644 --- a/enzyme/test/Integration/ReverseMode/blas.cpp +++ b/enzyme/test/Integration/ReverseMode/blas.cpp @@ -78,6 +78,12 @@ void my_dsyrk(char layout, char uplo, char trans, inDerivative = true; } +void my_potrf(char layout, char uplo, int N, double *__restrict__ A, int lda) { + int info; + cblas_dpotrf(layout, uplo, N, A, lda, &info); + inDerivative = true; +} + static void dotTests() { std::string Test = "DOT active both "; @@ -1012,7 +1018,128 @@ static void syrkTests() { free(dC); } -int main() { +static void potrfTests() { + int N = 17; + // N means normal matrix, T means transposed + for (char layout : {CblasColMajor, CblasRowMajor}) { + for (auto uplo : {'U', 'u', 'L', 'l'}) + + { + BlasInfo inputs[6] = { + /*A*/ BlasInfo(A, layout, N, N, lda), + /*B*/ BlasInfo(), + /*C*/ BlasInfo(), + BlasInfo(), + BlasInfo(), + BlasInfo(), + }; + { + + std::string Test = "POTRF active A "; + init(); + + my_potrf(layout, uplo, N, A, lda); + + assert(calls.size() == 1); + assert(calls[0].inDerivative == false); + assert(calls[0].type == CallType::POTRF); + assert(calls[0].pout_arg1 == A); + assert(calls[0].pin_arg1 == UNUSED_POINTER); + assert(calls[0].pin_arg2 == UNUSED_POINTER); + assert(calls[0].farg1 == UNUSED_DOUBLE); + assert(calls[0].farg2 == UNUSED_DOUBLE); + assert(calls[0].layout == layout); + assert(calls[0].targ1 == UNUSED_TRANS); + assert(calls[0].targ2 == UNUSED_TRANS); + assert(calls[0].iarg1 == N); + assert(calls[0].iarg2 == UNUSED_INT); + assert(calls[0].iarg3 == UNUSED_INT); + assert(calls[0].iarg4 == lda); + assert(calls[0].iarg5 == UNUSED_INT); + assert(calls[0].iarg6 == UNUSED_INT); + assert(calls[0].side == 
UNUSED_TRANS); + assert(calls[0].uplo == uplo); + assert(calls[0].diag == UNUSED_TRANS); + + // Check memory of primal on own. + checkMemoryTrace(inputs, "Primal " + Test, calls); + + init(); + __enzyme_autodiff((void *)my_potrf, enzyme_const, layout, enzyme_const, + uplo, enzyme_const, N, enzyme_dup, A, dA, + enzyme_const, lda); + foundCalls = calls; + init(); + + my_potrf(layout, uplo, N, A, lda); + + inDerivative = true; + + assert(foundCalls.size() >= 2); + assert(foundCalls[1].type == CallType::LASCL); + double *tri = (double *)foundCalls[1].pout_arg1; + inputs[3] = BlasInfo(tri, layout, N, N, N); + cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0); + + cblas_dlacpy(layout, uplo, N, N, dA, lda, tri, N); + + cblas_dtrmm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0, + A, lda, tri, N); + + assert(foundCalls.size() >= 5); + assert(foundCalls[4].type == CallType::COPY); + double *tmp = (double *)foundCalls[4].pout_arg1; + inputs[4] = BlasInfo(tmp, N, 1); + + cblas_dcopy(N, tri, N + 1, tmp, 1); + cblas_dscal(N, 0.5, tmp, 1); + cblas_dlascl(layout, flip_uplo(uplo), 0, 0, 1.0, 0.0, N, N, tri, N, 0); + cblas_dcopy(N, tmp, 1, tri, N + 1); + + cblas_dtrsm(layout, uplo_to_rside(uplo), uplo, 'N', 'N', N, N, 1.0, + A, lda, tri, N); + cblas_dtrsm(layout, uplo_to_side(uplo), uplo, 'T', 'N', N, N, 1.0, + A, lda, tri, N); +#define triv(r, c) \ + tri[(r) * (layout == CblasRowMajor ? N : 1) + \ + (c) * (layout == CblasRowMajor ? 
1 : N)] + + int upperinc = (&triv(0, 1) - &triv(0,0)); + int lowerinc = (&triv(1, 0) - &triv(0,0)); + if (layout == CblasColMajor) { + assert(upperinc == N); + assert(lowerinc == 1); + } else { + assert(upperinc == 1); + assert(lowerinc == N); + } + for (int i=0; i blasPatterns, raw_ostream &os) { << " fpType = Type::getDoubleTy(call.getContext()); \n" << " } else if (blas.floatType == \"s\" || blas.floatType == \"S\"){\n" << " fpType = Type::getFloatTy(call.getContext()); \n" + << " } else if (blas.floatType == \"c\" || blas.floatType == \"C\"){\n" + << " fpType = " + "llvm::VectorType::get(Type::getFloatTy(call.getContext()), 2, false); " + " \n" + << " } else if (blas.floatType == \"z\" || blas.floatType == \"Z\"){\n" + << " fpType = " + "llvm::VectorType::get(Type::getDoubleTy(call.getContext()), 2, " + "false); \n" << " } else { \n" << " assert(false && \"Unreachable\"); \n" << " } \n"; @@ -325,13 +333,16 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) { os << "// Next ones shall only be called in the cblas case,\n" << "// they have incorrect meaning otherwise\n" << " const int pos_" << name << " = 0;\n" - << " const auto orig_" << name << " = call.getArgOperand(pos_" << name - << ");\n" - << " auto arg_" << name << " = gutils->getNewFromOriginal(orig_" << name - << ");\n" - << " const auto type_" << name << " = arg_" << name << "->getType();\n" + << " Value *const orig_" << name << " = cblas ? call.getArgOperand(pos_" + << name << ") : nullptr;\n" + << " Value * arg_" << name + << " = cblas ? gutils->getNewFromOriginal(orig_" << name + << ") : nullptr;\n" + << " const auto type_" << name << " = cblas ? arg_" << name + << "->getType() : nullptr;\n" << " const bool overwritten_" << name - << " = (cacheMode ? overwritten_args[pos_" << name << "] : false);\n\n"; + << " = ((cacheMode && cblas) ? 
overwritten_args[pos_" << name + << "] : false);\n\n"; } auto actArgs = pattern.getActiveArgs(); @@ -344,8 +355,11 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) { << " auto arg_" << name << " = gutils->getNewFromOriginal(orig_" << name << ");\n" << " const auto type_" << name << " = arg_" << name << "->getType();\n" - << " const bool overwritten_" << name - << " = (cacheMode ? overwritten_args[pos_" << name << "] : false);\n"; + << " const bool overwritten_" << name; + // if (pattern.getMutableArgs().count(i)) + // os << " = (cacheMode ? true : false);\n"; + // else + os << " = (cacheMode ? overwritten_args[pos_" << name << "] : false);\n"; if (std::count(actArgs.begin(), actArgs.end(), i)) { os << " bool active_" << name << " = !gutils->isConstantValue(orig_" << name << ");\n" @@ -453,7 +467,7 @@ void emit_helper(const TGPattern &pattern, raw_ostream &os) { for (auto name : enumerate(nameVec)) { assert(argTypeMap.count(name.index()) == 1); auto ty = argTypeMap.lookup(name.index()); - if (ty == ArgType::trans) { + if (ty == ArgType::trans || ty == ArgType::side || ty == ArgType::uplo) { os << " Type *cublasEnumType = nullptr;\n"; os << " if (cublas) cublasEnumType = type_" << name.value() << ";\n"; break; @@ -533,7 +547,7 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) { for (auto name : enumerate(nameVec)) { assert(argTypeMap.count(name.index()) == 1); auto ty = argTypeMap.lookup(name.index()); - if (ty == ArgType::trans) { + if (ty == ArgType::trans || ty == ArgType::side || ty == ArgType::uplo) { hasTrans = true; break; } @@ -544,6 +558,8 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) { << " Value *valueG = nullptr;\n" << " Value *valuer = nullptr;\n" << " Value *valuel = nullptr;\n" + << " Value *valueR = nullptr;\n" + << " Value *valueL = nullptr;\n" << " if (cublas) {\n" << " valueN = ConstantInt::get(cublasEnumType, " "cublasOperation_t::CUBLAS_OP_N);\n" @@ -553,6 +569,10 @@ void 
emit_scalar_types(const TGPattern &pattern, raw_ostream &os) { "cublasSideMode_t::CUBLAS_SIDE_LEFT);\n" << " valuer = ConstantInt::get(cublasEnumType, " "cublasSideMode_t::CUBLAS_SIDE_RIGHT);\n" + << " valueL = ConstantInt::get(cublasEnumType, " + "cublasSideMode_t::CUBLAS_SIDE_LEFT);\n" + << " valueR = ConstantInt::get(cublasEnumType, " + "cublasSideMode_t::CUBLAS_SIDE_RIGHT);\n" << " // TODO lascl not available in cublas, nor op G\n" << " valueG = ConstantInt::get(cublasEnumType, " "'G');\n" @@ -562,6 +582,8 @@ void emit_scalar_types(const TGPattern &pattern, raw_ostream &os) { << " valueG = ConstantInt::get(charType, 'G');\n" << " valuer = ConstantInt::get(charType, 'r');\n" << " valuel = ConstantInt::get(charType, 'l');\n" + << " valueR = ConstantInt::get(charType, 'R');\n" + << " valueL = ConstantInt::get(charType, 'L');\n" << " }\n\n"; } } @@ -918,6 +940,24 @@ void rev_call_arg(bool forward, DagInit *ruleDag, const TGPattern &pattern, << ", cache_" << matName << ", byRef, cublas)}"; return; } + if (Def->getName() == "is_left") { + if (Dag->getNumArgs() != 1) + PrintFatalError(pattern.getLoc(), "only 1-arg ld operands supported"); + const auto sideName = Dag->getArgNameStr(0); + os << "{to_blas_callconv(Builder2, is_left(Builder2, arg_" << sideName + << ", byRef, cublas), byRef, cublas, julia_decl_type, " + "allocationBuilder, \"isleft\")}"; + return; + } + if (Def->getName() == "is_lower") { + if (Dag->getNumArgs() != 1) + PrintFatalError(pattern.getLoc(), "only 1-arg ld operands supported"); + const auto uploName = Dag->getArgNameStr(0); + os << "{to_blas_callconv(Builder2, is_lower(Builder2, arg_" << uploName + << ", byRef, cublas), byRef, cublas, julia_decl_type, " + "allocationBuilder, \"isleft\")}"; + return; + } } else if (Def->getName() == "Shadow" || Def->isSubClassOf("Shadow")) { if (Dag->getNumArgs() != 1) PrintFatalError(pattern.getLoc(), "only single op shadow supported"); @@ -1095,8 +1135,12 @@ void rev_call_arg(bool forward, DagInit *ruleDag, 
const TGPattern &pattern, os << " auto derivcall_" << dfnc_name << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" - << " blas.prefix + blas.floatType + \"" << dfnc_name - << "\" + blas.suffix, FT" << dfnc_name << ");\n"; + << " blas.prefix + blas.floatType + \"" << dfnc_name; + + if (dfnc_name == "copy") + os << "\" + (cublasv2 ? \"\" : blas.suffix), FT" << dfnc_name << ");\n"; + else + os << "\" + blas.suffix, FT" << dfnc_name << ");\n"; os << " if (auto F = dyn_cast(derivcall_" << dfnc_name << ".getCallee()))\n" @@ -1153,29 +1197,24 @@ void rev_call_arg(bool forward, DagInit *ruleDag, const TGPattern &pattern, } os << " Value *ptr = larg_1[0];\n"; - os << " if (ptr->getType()->isIntegerTy()) ptr = " - "Builder2.CreateIntToPtr(ptr, PointerType::getUnqual(fpType));\n"; - - os << "#if LLVM_VERSION_MAJOR < 17\n"; - os << "#if LLVM_VERSION_MAJOR >= 15\n"; - os << " if (ptr->getContext().supportsTypedPointers()) {\n"; - os << "#endif\n"; - os << " if (fpType != ptr->getType()->getPointerElementType()) {\n"; - os << " ptr = Builder2.CreatePointerCast(ptr, " - "PointerType::get(fpType, " - "cast(ptr->getType())->getAddressSpace()));\n"; - os << " }\n"; - os << "#if LLVM_VERSION_MAJOR >= 15\n"; - os << " }\n"; - os << "#endif\n"; - os << "#endif\n"; os << " Value *ld_lookup = load_if_ref(Builder2, intType, larg_1[1], " "byRef);\n"; + + auto SDI = dyn_cast(Dag->getArg(1)); + auto SDI2 = + SDI ? dyn_cast(SDI->getOperator())->getDef() : nullptr; + auto SDI3 = (SDI2 && SDI2->getName() == "Concat") + ? dyn_cast(SDI->getArg(0)) + : nullptr; + bool constint = SDI3 && SDI3->getDef()->isSubClassOf("ConstantInt"); + if (Dag->getNumArgs() == 4) { - os << " Value *layoutptr = load_if_ref(Builder2, charType, larg_0[0], " - "byRef);\n"; - os << " Value* is_row_maj = Builder2.CreateICmpEQ(layoutptr, " - "ConstantInt::get(layoutptr->getType(), 101));\n"; + os << " Value *layoutptr = cblas ? 
load_if_ref(Builder2, charType, " + "larg_0[0], " + "byRef) : nullptr;\n"; + os << " Value* is_row_maj = cblas ? Builder2.CreateICmpEQ(layoutptr, " + "ConstantInt::get(layoutptr->getType(), 101)) : " + "Builder2.getFalse();\n"; os << " Value* offset = Builder2.CreateMul(load_if_ref(Builder2, " "intType, larg_2[0], byRef), CreateSelect(Builder2, is_row_maj, " "ld_lookup, ConstantInt::get(intType, 1)));\n"; @@ -1183,11 +1222,36 @@ void rev_call_arg(bool forward, DagInit *ruleDag, const TGPattern &pattern, "Builder2.CreateMul(load_if_ref(Builder2, " "intType, larg_3[0], byRef), CreateSelect(Builder2, is_row_maj, " "ConstantInt::get(intType, 1), ld_lookup)));\n"; - } else { - os << " Value* offset = Builder2.CreateMul(load_if_ref(Builder2, " - "intType, larg_2[0], byRef), ld_lookup);\n"; + if (constint) + os << " ptr = to_blas_callconv(Builder2, offset, byRef, cublas, " + "nullptr, " + "allocationBuilder, \"offset\");\n"; + } + + if (!constint) { + os << " if (ptr->getType()->isIntegerTy()) ptr = " + "Builder2.CreateIntToPtr(ptr, PointerType::getUnqual(fpType));\n"; + + os << "#if LLVM_VERSION_MAJOR < 17\n"; + os << "#if LLVM_VERSION_MAJOR >= 15\n"; + os << " if (ptr->getContext().supportsTypedPointers()) {\n"; + os << "#endif\n"; + os << " if (fpType != ptr->getType()->getPointerElementType()) {\n"; + os << " ptr = Builder2.CreatePointerCast(ptr, " + "PointerType::get(fpType, " + "cast(ptr->getType())->getAddressSpace()));\n"; + os << " }\n"; + os << "#if LLVM_VERSION_MAJOR >= 15\n"; + os << " }\n"; + os << "#endif\n"; + os << "#endif\n"; + if (Dag->getNumArgs() == 4) { + } else { + os << " Value* offset = Builder2.CreateMul(load_if_ref(Builder2, " + "intType, larg_2[0], byRef), ld_lookup);\n"; + } + os << " ptr = Builder2.CreateGEP(fpType, ptr, offset);\n"; } - os << " ptr = Builder2.CreateGEP(fpType, ptr, offset);\n"; if (Def->getName() == "LoadLookup") { os << " if (!byRefFloat) ptr = Builder2.CreateLoad(fpType, ptr);\n"; os << " SmallVector vals = { ptr };\n"; @@ 
-1254,10 +1318,16 @@ void rev_call_arg(bool forward, DagInit *ruleDag, const TGPattern &pattern, } else if (val == "l") { os << "{to_blas_callconv(Builder2, valuel, byRef, cublas, nullptr, " "allocationBuilder, \"constant.char.l\")}"; + } else if (val == "R") { + os << "{to_blas_callconv(Builder2, valueR, byRef, cublas, nullptr, " + "allocationBuilder, \"constant.char.R\")}"; + } else if (val == "L") { + os << "{to_blas_callconv(Builder2, valueL, byRef, cublas, nullptr, " + "allocationBuilder, \"constant.char.L\")}"; // C is not supported yet //} else if (val == "C") { } else { - errs() << "unknown char: " << val << "\n"; + errs() << "unknown char: '" << val << "'\n"; PrintFatalError(Def->getLoc(), "unknown char"); } } else if (Def->isSubClassOf("Alloca")) { @@ -1284,6 +1354,7 @@ void rev_call_arg(bool forward, DagInit *ruleDag, const TGPattern &pattern, } else { auto name = ruleDag->getArgNameStr(pos); if (name == "") { + llvm::errs() << "ruleDag: " << *ruleDag << "\n"; PrintFatalError(pattern.getLoc(), "arg has no name!" 
+ std::to_string(pos)); assert(name != ""); @@ -1365,13 +1436,13 @@ void rev_call_args(bool forward, Twine argName, const TGPattern &pattern, } os << " if (byRef) {\n"; int n = 0; - if (func == "gemv" || func == "lascl") + if (func == "gemv" || func == "lascl" || func == "potrs") n = 1; if (func == "gemm" || func == "syrk" || func == "syr2k") n = 2; if (func == "trmv") n = 3; - if (func == "trmm") + if (func == "trmm" || func == "trsm") n = 4; for (int i = 0; i < n; i++) os << " " << argName @@ -1403,7 +1474,7 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { assert(args.size() >= 2); auto action = args[1]; assert(action == "product" || action == "is_normal" || - action == "triangular"); + action == "triangular" || action == "vector"); if (action == "product") { const auto matName = args[0]; const auto dim1 = "arg_" + args[2]; @@ -1427,6 +1498,13 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { os << " Value *size_" << vecName << " = " << builder << ".CreateSelect(is_normal(" << builder << ", " << trans << ", byRef, cublas), len1, len2);\n"; + } else if (action == "vector") { + assert(args.size() == 3); + const auto vecName = args[0]; + const auto dim1 = "arg_" + args[2]; + os << " Value *len1 = load_if_ref(" << builder << ", intType," << dim1 + << ", byRef);\n"; + os << " Value *size_" << vecName << " = len1;\n"; } else if (action == "triangular") { assert(args.size() == 3); const auto vecName = args[0]; @@ -1436,13 +1514,7 @@ void emit_tmp_creation(Record *Def, raw_ostream &os, StringRef builder) { // Size has to be (at least) // ( ( n*( n + 1 ) )/2 ) os << " Value *size_" << vecName << " = " << builder - << ".CreateMul(len, " << builder - << ".CreateAdd(len, " - "ConstantInt::get(intType, 1)), \"square_mat_size_" - << vecName << "\");\n" - << " size_" << vecName << " = " << builder << ".CreateUDiv(size_" - << vecName << ", ConstantInt::get(intType, 2), \"size_" << vecName - << "\");\n"; + << ".CreateMul(len, 
len);\n"; } const auto matName = args[0]; const auto allocName = "mat_" + matName; @@ -1563,8 +1635,12 @@ void emit_dag(bool forward, Twine resultVarName, DagInit *ruleDag, os << " auto derivcall_" << dfnc_name << " = gutils->oldFunc->getParent()->getOrInsertFunction(\n" - << " blas.prefix + blas.floatType + \"" << dfnc_name - << "\" + blas.suffix, FT" << dfnc_name << ");\n"; + << " blas.prefix + blas.floatType + \"" << dfnc_name; + + if (dfnc_name == "copy") + os << "\" + (cublasv2 ? \"\" : blas.suffix), FT" << dfnc_name << ");\n"; + else + os << "\" + blas.suffix, FT" << dfnc_name << ");\n"; os << " if (auto F = dyn_cast(derivcall_" << dfnc_name << ".getCallee()))\n" @@ -1916,6 +1992,16 @@ void emit_fwd_rewrite_rules(const TGPattern &pattern, raw_ostream &os) { } } + auto duals = pattern.getDuals(); + const auto Def = cast(duals->getOperator())->getDef(); + + if (Def->isSubClassOf("Seq")) { + if (!Def->getValueAsBit("start")) { + os << "Builder2.SetInsertPoint(gutils->getNewFromOriginal(&call)->" + "getNextNode());\n"; + } + } + os << " Value *dres = applyChainRule(\n" << " call.getType(), Builder2,\n" << " [&]("; @@ -1929,7 +2015,7 @@ void emit_fwd_rewrite_rules(const TGPattern &pattern, raw_ostream &os) { << " Value *dres = nullptr;\n"; StringMap vars; - emit_dag(/*forward*/ true, "dres", pattern.getDuals(), "args", os, "", + emit_dag(/*forward*/ true, "dres", duals, "args", os, "", /*actArg*/ -1, pattern, /*runtimeChecked*/ false, vars); os << " if (!dres && !call.getType()->isVoidTy()) dres = " @@ -1971,6 +2057,25 @@ void emit_rev_rewrite_rules(const StringMap &patternMap, os << " /* rev-rewrite */ \n" << " if (Mode == DerivativeMode::ReverseModeCombined ||\n" << " Mode == DerivativeMode::ReverseModeGradient) {\n" + << " if (blas.floatType == \"c\" || blas.floatType == \"C\" || " + "blas.floatType == \"z\" || blas.floatType == \"Z\") {\n" + << " std::string s;\n" + << " llvm::raw_string_ostream ss(s);\n" + << " ss << \"" << pattern.getName() << "\" << 
\"\\n\";\n" + << " ss << call.getDebugLoc() << \"\\n\";\n" + << " ss << \"Complex inputs not yet supported in reverse mode for " + "BLAS calls\" << " + "\"\\n\";\n" + << " if (CustomErrorHandler) {\n" + << " CustomErrorHandler(ss.str().c_str(), wrap(&call), " + "ErrorType::NoDerivative,\n" + << " gutils, nullptr, wrap(&Builder2));\n" + << " } else {\n" + << " EmitFailure(\"Unsupported Mode\", call.getDebugLoc(), &call, " + "ss.str());\n" + << " }\n" + << " }\n" + << " Value *alloc = nullptr;\n" << " if (byRef && !cublas) {\n" << " alloc = allocationBuilder.CreateAlloca(fpType, nullptr, " diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index c9893d5184f8..aa6f884eda56 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -55,8 +55,13 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { << ";\n"; os << " auto arg_" << name << " = CI->getArgOperand(pos_" << name << ");\n"; - os << " const bool overwritten_" << name - << " = (cacheMode ? (overwritten_args_ptr ? (*overwritten_args_ptr)[pos_" + os << " const bool overwritten_" << name; + + // if (pattern.getMutableArgs().count(argPos)) + // os << " = (cacheMode ? true : false);\n\n"; + // else + os << " = (cacheMode ? (overwritten_args_ptr ? 
" + "(*overwritten_args_ptr)[pos_" << name << "] : true ) : false);\n\n"; } diff --git a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h index 331f0b2eafe7..e8a9ae6f9e6b 100644 --- a/enzyme/tools/enzyme-tblgen/blasTAUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasTAUpdater.h @@ -14,9 +14,11 @@ void emit_BLASTypes(raw_ostream &os) { os << "TypeTree ttFloat;\n" << "llvm::Type *floatType; \n" - << "if (blas.floatType == \"s\" || blas.floatType == \"S\") {\n" + << "if (blas.floatType == \"s\" || blas.floatType == \"S\" || " + "blas.floatType == \"c\" || blas.floatType == \"C\") {\n" << " floatType = Type::getFloatTy(call.getContext());\n" - << "} else if (blas.floatType == \"d\" || blas.floatType == \"D\"){\n" + << "} else if (blas.floatType == \"d\" || blas.floatType == \"D\" || " + "blas.floatType == \"z\" || blas.floatType == \"Z\") {\n" << " floatType = Type::getDoubleTy(call.getContext());\n" << "} else {\n" << " llvm_unreachable(\"unknown float type of blas\");\n" diff --git a/enzyme/tools/enzyme-tblgen/caching.cpp b/enzyme/tools/enzyme-tblgen/caching.cpp index 6d0cbeedd06e..a8c1cdd2fd29 100644 --- a/enzyme/tools/enzyme-tblgen/caching.cpp +++ b/enzyme/tools/enzyme-tblgen/caching.cpp @@ -192,8 +192,11 @@ void emit_vec_like_copy(const TGPattern &pattern, raw_ostream &os) { << " Value *arg_malloc_size;\n"; if (dimensions.size() == 3) { - os -<< " malloc_size = select_vec_dims(BuilderZ, arg_" << nameVec[dimensions[0]] << ", arg_" << nameVec[dimensions[1]] << ", arg_" << nameVec[dimensions[2]] << ", byRef, cublas);\n"; + auto startty = pattern.getTypeOfArg(nameVec[dimensions[0]]); + assert(startty == ArgType::trans); +os +<< " auto norm = is_normal(BuilderZ, arg_" << nameVec[dimensions[0]] << ", byRef, cublas);\n" +<< " malloc_size = CreateSelect(BuilderZ, norm, arg_" << nameVec[dimensions[1]] << ", arg_" << nameVec[dimensions[2]] << ");\n"; } else { os << " malloc_size = arg_" << nameVec[dimensions[0]] << ";\n"; @@ -201,11 
+204,14 @@ void emit_vec_like_copy(const TGPattern &pattern, raw_ostream &os) { os << " arg_malloc_size = malloc_size;\n" << " malloc_size = load_if_ref(BuilderZ, intType, malloc_size, byRef);\n" -<< " auto malins = CreateAllocation(BuilderZ, fpType, malloc_size, \"cache." << vecName << "\");\n" +<< " Instruction *SubZero = nullptr;\n" +<< " auto malins = CreateAllocation(BuilderZ, fpType, malloc_size, \"cache." << vecName << "\", /*caller*/nullptr"; + if (pattern.getName() == "potrf") os << ", &SubZero"; + os << ");\n" << " ValueType valueTypes[] = {" << valueTypes << "};\n" << " valueTypes[" << argIdx << "] = ValueType::Primal;\n" << " if (byRef) valueTypes[" << argIdx+1 << "] = ValueType::Primal;\n"; - for (auto len_pos : pattern.getRelatedLengthArgs(argIdx) ) { + for (auto len_pos : pattern.getRelatedLengthArgs(argIdx, /*hideuplo*/true) ) { os << " if (byRef) valueTypes[" << len_pos << "] = ValueType::Primal;\n"; } os << " if (cublas) {\n" @@ -249,11 +255,27 @@ os << " if (cublas) {\n" << " auto charTy = IntegerType::get(intType->getContext(), 8);\n" << " Value *M, *N;\n"; + std::string uplostr = " Value *uplo = llvm::ConstantInt::get(charTy, 0);\n" // garbage data, just should not match U or L + " uplo = to_blas_callconv(BuilderZ, uplo, byRef, cublas, nullptr, allocationBuilder, \"copy.garbage\");\n"; if (dimensions.size() == 3) { + auto startty = pattern.getTypeOfArg(nameVec[dimensions[0]]); + if (startty == ArgType::trans) { os << " Value *normal = is_normal(BuilderZ, arg_" << nameVec[dimensions[0]] << ", byRef, cublas);\n" << " M = BuilderZ.CreateSelect(normal, " << dim1 << ", " << dim2 << ");\n" << " N = BuilderZ.CreateSelect(normal, " << dim2 << ", " << dim1 << ");\n"; + } else if (startty == ArgType::uplo) { +os << " M = " << dim1 << ";\n" +<< " N = " << dim2 << ";\n"; +uplostr = " Value *uplo = arg_" + nameVec[dimensions[0]] + ";\n"; + } else if (startty == ArgType::side) { +os +<< " Value *normal = is_left(BuilderZ, arg_" << nameVec[dimensions[0]] << 
", byRef, cublas);\n" +<< " M = BuilderZ.CreateSelect(normal, " << dim1 << ", " << dim2 << ");\n" +<< " N = BuilderZ.CreateSelect(normal, " << dim2 << ", " << dim1 << ");\n"; + } else { + assert(0 &&" unknown startty"); + } } else { os << " M = " << dim1 << ";\n" @@ -264,16 +286,18 @@ os << " if (cublas) {\n" << " auto *len1 = load_if_ref(BuilderZ, intType, M, byRef);\n" << " auto *len2 = load_if_ref(BuilderZ, intType, N, byRef);\n" << " auto *matSize = BuilderZ.CreateMul(len1, len2);\n" -<< " auto malins = CreateAllocation(BuilderZ, fpType, matSize, \"cache." << matName << "\");\n" +<< " Instruction *SubZero = nullptr;\n" +<< " auto malins = CreateAllocation(BuilderZ, fpType, matSize, \"cache." << matName << "\", /*caller*/nullptr"; + if (pattern.getName() == "potrf") os << ", &SubZero"; + os << ");\n" << " SmallVector valueTypes = {" << valueTypes << "};\n" <<" valueTypes[" << argIdx << "] = ValueType::Primal;\n" << " if (byRef) valueTypes[" << argIdx+1 << "] = ValueType::Primal;\n"; - for (auto len_pos : dimensions ) { + for (auto len_pos : pattern.getRelatedLengthArgs(argIdx, /*hideuplo*/true) ) { os << " if (byRef) valueTypes[" << len_pos << "] = ValueType::Primal;\n"; } os << " if (EnzymeLapackCopy) {\n" -<< " Value *uplo = llvm::ConstantInt::get(charTy, 0);\n" // garbage data, just should not match U or L -<< " uplo = to_blas_callconv(BuilderZ, uplo, byRef, cublas, nullptr, allocationBuilder, \"copy.garbage\");\n" +<< uplostr << " SmallVector args = {uplo, M, N, arg_" << matName << ", arg_" << ldName << ", malins, M};\n" << " if (!byRef) {\n" << " args.insert(args.begin(), arg_layout); valueTypes.insert(valueTypes.begin(), ValueType::Primal); }\n" diff --git a/enzyme/tools/enzyme-tblgen/datastructures.cpp b/enzyme/tools/enzyme-tblgen/datastructures.cpp index ebf0e367c806..f68ce257953d 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.cpp +++ b/enzyme/tools/enzyme-tblgen/datastructures.cpp @@ -393,7 +393,7 @@ void fillRelatedLenghts( 
assert(argTypes.lookup(lengths[1]) == ArgType::len); } else { assert(argTypes.lookup(lengths[0]) == ArgType::trans || - argTypes.lookup(lengths[0]) == ArgType::diag || + argTypes.lookup(lengths[0]) == ArgType::uplo || argTypes.lookup(lengths[0]) == ArgType::side); assert(argTypes.lookup(lengths[1]) == ArgType::len); assert(argTypes.lookup(lengths[2]) == ArgType::len); @@ -454,7 +454,8 @@ TGPattern::TGPattern(Record *r) fillArgUserMap(rules, args, posActArgs, argUsers); } -SmallVector TGPattern::getRelatedLengthArgs(size_t arg) const { +SmallVector TGPattern::getRelatedLengthArgs(size_t arg, + bool hideuplo) const { // other args are unrelated to length args assert(argTypes.lookup(arg) == ArgType::vincData || argTypes.lookup(arg) == ArgType::mldData || @@ -465,9 +466,10 @@ SmallVector TGPattern::getRelatedLengthArgs(size_t arg) const { if (related.size() == 3) { auto argTy = argTypes.lookup(related[0]); - assert(argTy == ArgType::trans || argTy == ArgType::diag || + assert(argTy == ArgType::trans || argTy == ArgType::uplo || argTy == ArgType::side); - (void)argTy; + if (hideuplo && argTy == ArgType::uplo) + related.erase(related.begin()); } return related; diff --git a/enzyme/tools/enzyme-tblgen/datastructures.h b/enzyme/tools/enzyme-tblgen/datastructures.h index 885686a39241..11638b76da44 100644 --- a/enzyme/tools/enzyme-tblgen/datastructures.h +++ b/enzyme/tools/enzyme-tblgen/datastructures.h @@ -123,7 +123,8 @@ class TGPattern { public: TGPattern(Record *r); - SmallVector getRelatedLengthArgs(size_t arg) const; + SmallVector getRelatedLengthArgs(size_t arg, + bool hideuplo = false) const; bool isBLASLevel2or3() const; const DenseMap> &getArgUsers() const; StringRef getName() const;