From 0b55e253c9f9543ac5bb789df4a81875d76e69fb Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Thu, 8 Aug 2024 19:15:08 +0100 Subject: [PATCH] symm rrule implementation (#219) * Fix typo * symm rrule implementation * Bump patch * Use BLAS for multithreading * Optimise for common edge case * Improve error message further --- Project.toml | 2 +- src/interpreter/s2s_reverse_mode_ad.jl | 6 ++ src/rrules/blas.jl | 106 ++++++++++++++++++++++++- 3 files changed, 112 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 0136553b..2c9c3f7e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Tapir" uuid = "07d77754-e150-4737-8c94-cd238a1fb45b" authors = ["Will Tebbutt, Hong Ge, and contributors"] -version = "0.2.31" +version = "0.2.32" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/interpreter/s2s_reverse_mode_ad.jl b/src/interpreter/s2s_reverse_mode_ad.jl index 0f615974..9d975f40 100644 --- a/src/interpreter/s2s_reverse_mode_ad.jl +++ b/src/interpreter/s2s_reverse_mode_ad.jl @@ -1254,6 +1254,12 @@ function (rule::LazyDerivedRule{T, Trule})(args::Vararg{Any, N}) where {N, T, Tr rule.rule = derived_rule else @warn "Unable to put rule in rule field. Rule should error." 
+ println("MethodInstance is") + display(rule.mi) + println() + println("with signature") + display(rule.mi.specTypes) + println() println("derived_rule is of type") display(typeof(derived_rule)) println() diff --git a/src/rrules/blas.jl b/src/rrules/blas.jl index 066db692..b047b6d3 100644 --- a/src/rrules/blas.jl +++ b/src/rrules/blas.jl @@ -388,6 +388,93 @@ for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32)) end end +@is_primitive( + MinimalCtx, + Tuple{ + typeof(BLAS.symm!), + Char, + Char, + T, + MatrixOrView{T}, + MatrixOrView{T}, + T, + Matrix{T}, + } where {T<:Union{Float32, Float64}}, +) + +function rrule!!( + ::CoDual{typeof(BLAS.symm!)}, + side::CoDual{Char}, + uplo::CoDual{Char}, + alpha::CoDual{T}, + A_dA::CoDual{<:MatrixOrView{T}}, + B_dB::CoDual{<:MatrixOrView{T}}, + beta::CoDual{T}, + C_dC::CoDual{Matrix{T}}, +) where {T<:Union{Float32, Float64}} + + # Extract primals. + s = primal(side) + ul = primal(uplo) + α = primal(alpha) + β = primal(beta) + A, dA = viewify(A_dA) + B, dB = viewify(B_dB) + C, dC = viewify(C_dC) + + # In this rule we optimise carefully for the special case α == 1 && β == 0, which + # corresponds to simply multiplying symm(A) and B together, and writing the result to C. + # This is an extremely common edge case, so it's important to do well for it. NOTE(review): `tmp_ref` below is written but never read within this rule. + C_copy = copy(C) + tmp_ref = Ref{Matrix{T}}() + if (α == 1 && β == 0) + BLAS.symm!(s, ul, α, A, B, β, C) + else + tmp = BLAS.symm(s, ul, one(T), A, B) + tmp_ref[] = tmp + BLAS.axpby!(α, tmp, β, C) + end + + function symm!_adjoint(::NoRData) + + if (α == 1 && β == 0) + dα = dot(dC, C) + BLAS.copyto!(C, C_copy) + else + # Reset C to the values it held before the forwards pass. + BLAS.copyto!(C, C_copy) + + # gradient w.r.t. α. C was restored just above, so C_copy is free to be overwritten as scratch. + BLAS.symm!(s, ul, one(T), A, B, zero(T), C_copy) + dα = dot(dC, C_copy) + end + + # gradient w.r.t. A. Both triangles of dA_tmp are accumulated below, so the diagonal is added twice; the loop subtracts it once. + dA_tmp = s == 'L' ? 
dC * B' : B' * dC + if ul == 'L' + dA .+= α .* LowerTriangular(dA_tmp) + dA .+= α .* UpperTriangular(dA_tmp)' + else + dA .+= α .* LowerTriangular(dA_tmp)' + dA .+= α .* UpperTriangular(dA_tmp) + end + @inbounds for n in diagind(dA) + dA[n] -= α * dA_tmp[n] + end + + # gradient w.r.t. B. Uses dC before it is rescaled at the end of this function. + BLAS.symm!(s, ul, α, A, dC, one(T), dB) + + # gradient w.r.t. β. C holds its pre-call values again at this point. + dβ = dot(dC, C) + + # gradient w.r.t. C. Done last, because dC is rescaled in place. + BLAS.scal!(β, dC) + + return NoRData(), NoRData(), NoRData(), dα, NoRData(), NoRData(), dβ, NoRData() + end + return C_dC, symm!_adjoint +end for (syrk, elty) in ((:dsyrk_, :Float64), (:ssyrk_, :Float32)) @eval function rrule!!( @@ -600,7 +687,7 @@ end function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) t_flags = ['N', 'T', 'C'] - alphas = [0.0, -0.25] + alphas = [1.0, -0.25] betas = [0.0, 0.33] test_cases = vcat( @@ -626,6 +713,23 @@ function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas}) end end), )), + + # symm! + vec(reduce( + vcat, + vec(map(product(['L', 'R'], ['L', 'U'], alphas, betas)) do (side, uplo, α, β) + nA = side == 'L' ? 5 : 7 + A = randn(nA, nA) + vA = view(randn(15, 15), 1:nA, 1:nA) + B = randn(5, 7) + vB = view(randn(15, 15), 1:5, 1:7) + C = randn(5, 7) + return Any[ + (false, :stability, nothing, BLAS.symm!, side, uplo, α, A, B, β, C), + (false, :stability, nothing, BLAS.symm!, side, uplo, α, vA, vB, β, C), + ] + end) + )), ) memory = Any[]