Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor gemm #155

Merged
merged 7 commits into from
May 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "Tapir"
uuid = "07d77754-e150-4737-8c94-cd238a1fb45b"
authors = ["Will Tebbutt, Hong Ge, and contributors"]
version = "0.2.13"
version = "0.2.14"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
124 changes: 116 additions & 8 deletions src/rrules/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,89 @@ end
# LEVEL 3
#

# Either a dense `Matrix{T}`, or a two-dimensional `SubArray` whose parent array is a
# `Matrix{T}`. Used below to constrain the `A` and `B` arguments of `BLAS.gemm!`.
const MatrixOrView{T} = Union{Matrix{T}, SubArray{T, 2, Matrix{T}}}

# Register `BLAS.gemm!(transA, transB, alpha, A, B, beta, C)` with `@is_primitive` in
# `MinimalCtx` for `T` equal to `Float32` or `Float64`. `A` and `B` may each be a matrix or
# a matrix view (`MatrixOrView{T}`); `C` must be a plain `Matrix{T}`.
# NOTE(review): presumably this makes the hand-written `rrule!!` for `BLAS.gemm!` the rule
# that is used for this signature -- confirm against the `@is_primitive` documentation.
@is_primitive(
MinimalCtx,
Tuple{
typeof(BLAS.gemm!),
Char,
Char,
T,
MatrixOrView{T},
MatrixOrView{T},
T,
Matrix{T},
} where {T<:Union{Float32, Float64}},
)

# Reverse-mode rule for `BLAS.gemm!(transA, transB, alpha, A, B, beta, C)` with `Float32`
# or `Float64` elements.
#
# Forward pass: overwrites the primal of `C` in place and returns the `C` `CoDual` together
# with the pullback `gemm!_pb!!`.
#
# Pullback: takes `NoRData`, restores the primal of `C` to the value it had before the
# forward pass, increments the cotangents of `A` and `B`, scales the cotangent of `C` by
# `beta`, and returns the scalar cotangents of `alpha` and `beta` in positions 4 and 7 of
# the result (every other position is `NoRData()`).
function rrule!!(
    ::CoDual{typeof(BLAS.gemm!)},
    transA::CoDual{Char},
    transB::CoDual{Char},
    alpha::CoDual{T},
    A::CoDual{<:MatrixOrView{T}},
    B::CoDual{<:MatrixOrView{T}},
    beta::CoDual{T},
    C::CoDual{Matrix{T}},
) where {T<:Union{Float32, Float64}}

    # Primal values of the transpose flags and of the two scalars.
    flag_A, flag_B = primal(transA), primal(transB)
    scale_AB, scale_C = primal(alpha), primal(beta)

    # (primal, cotangent) pairs for each matrix argument, as produced by `viewify`.
    A_p, A_bar = viewify(A)
    B_p, B_bar = viewify(B)
    C_p, C_bar = viewify(C)

    # alpha == 1 && beta == 0 is the plain "multiply A and B, write the result to C" case.
    # It is very common, so it gets a dedicated path which calls `gemm!` directly and does
    # not allocate a separate matrix for the product.
    plain_product = scale_AB == 1 && scale_C == 0

    # Snapshot of C taken before it is overwritten; the pullback restores C from it.
    C_saved = copy(C_p)

    # Holds the unscaled product op(A) * op(B). Only assigned (and only read) on the
    # general path.
    product_ref = Ref{Matrix{T}}()
    if plain_product
        BLAS.gemm!(flag_A, flag_B, scale_AB, A_p, B_p, scale_C, C_p)
    else
        unscaled = BLAS.gemm(flag_A, flag_B, one(T), A_p, B_p)
        product_ref[] = unscaled
        BLAS.axpby!(scale_AB, unscaled, scale_C, C_p)
    end

    function gemm!_pb!!(::NoRData)

        # Cotangent of alpha: inner product of dC with the unscaled product. On the fast
        # path C itself still holds that product, so this must run before C is restored.
        da = if plain_product
            dot(C_bar, C_p)
        else
            dot(C_bar, product_ref[])
        end

        # Put C back into the state it was in before the forward pass.
        BLAS.copyto!(C_p, C_saved)

        # Cotangent of beta: inner product of dC with the restored (original) C.
        db = BLAS.dot(C_bar, C_p)

        A_untransposed = flag_A == 'N'
        B_untransposed = flag_B == 'N'

        # Accumulate into dA. Which operand order / flags to use depends on whether A and
        # B were transposed in the forward pass.
        if A_untransposed
            flag = B_untransposed ? 'T' : 'N'
            BLAS.gemm!('N', flag, scale_AB, C_bar, B_p, one(T), A_bar)
        else
            flag = B_untransposed ? 'N' : 'T'
            BLAS.gemm!(flag, 'T', scale_AB, B_p, C_bar, one(T), A_bar)
        end

        # Accumulate into dB, by the same reasoning.
        if B_untransposed
            flag = A_untransposed ? 'T' : 'N'
            BLAS.gemm!(flag, 'N', scale_AB, A_p, C_bar, one(T), B_bar)
        else
            flag = A_untransposed ? 'N' : 'T'
            BLAS.gemm!('T', flag, scale_AB, C_bar, A_p, one(T), B_bar)
        end

        # The original C contributed to the output scaled by beta, so dC is scaled by beta.
        # This has to come last: dA and dB above need the unscaled dC.
        BLAS.scal!(scale_C, C_bar)

        return NoRData(), NoRData(), NoRData(), da, NoRData(), NoRData(), db, NoRData()
    end
    return C, gemm!_pb!!
end

# For a `Matrix`-valued `CoDual`, return `(primal_view, tangent_view)`, where each element
# is a full-range (`:, :`) view of the primal / tangent matrix respectively.
function viewify(A::CoDual{<:Matrix})
    primal_view = view(primal(A), :, :)
    tangent_view = view(tangent(A), :, :)
    return primal_view, tangent_view
end
# For a `SubArray`-valued `CoDual`, return `(primal, tangent_view)`. The primal is returned
# as-is; `tangent_view` is a `SubArray` of the same type `P`, built over the parent array
# found at `tangent(A).data.parent`, re-using the primal's `indices`, `offset1` and
# `stride1`.
# NOTE(review): relies on the tangent exposing its fields via `.data` -- confirm against
# the tangent type used for `SubArray`.
function viewify(A::CoDual{P}) where {P<:SubArray}
    primal_view = primal(A)
    tangent_parent = tangent(A).data.parent
    tangent_view = P(
        tangent_parent, primal_view.indices, primal_view.offset1, primal_view.stride1
    )
    return primal_view, tangent_view
end

for (gemm, elty) in ((:dgemm_, :Float64), (:sgemm_, :Float32))
@eval function rrule!!(
::CoDual{typeof(_foreigncall_)},
Expand Down Expand Up @@ -515,7 +598,39 @@ for (trsm, elty) in ((:dtrsm_, :Float64), (:strsm_, :Float32))
end
end

# Build the test cases for the hand-written BLAS rules.
#
# Returns `(test_cases, memory)`:
# - `test_cases` is a vector with one tuple per case, of the form
#   `(false, :stability, nothing, BLAS.gemm!, tA, tB, alpha, A, B, beta, C)`, covering every
#   combination of transpose flags, `alpha`, `beta`, and `A` / `B` each given as either a
#   `Matrix` or a view into a larger matrix.
# - `memory` is an empty `Any[]`.
#
# `rng_ctor` is accepted for interface compatibility but is not used here: all arrays are
# drawn with the global `randn`.
function generate_hand_written_rrule!!_test_cases(rng_ctor, ::Val{:blas})
    t_flags = ['N', 'T', 'C']

    # alpha == 1.0 combined with beta == 0.0 exercises the dedicated plain-product path of
    # the `gemm!` rule; the remaining combinations exercise the general path.
    alphas = [0.0, 1.0, -0.25]
    betas = [0.0, 0.33]

    test_cases = vcat(
        # gemm!
        vec(reduce(
            vcat,
            vec(map(product(t_flags, t_flags, alphas, betas)) do (tA, tB, a, b)
                # op(A) must be 3 x 4 and op(B) must be 4 x 5, so that C is 3 x 5.
                As = if tA == 'N'
                    [randn(3, 4), view(randn(15, 15), 2:4, 3:6)]
                else
                    [randn(4, 3), view(randn(15, 15), 2:5, 3:5)]
                end
                Bs = if tB == 'N'
                    [randn(4, 5), view(randn(15, 15), 1:4, 2:6)]
                else
                    [randn(5, 4), view(randn(15, 15), 1:5, 3:6)]
                end
                return map(product(As, Bs)) do (A, B)
                    # `gemm!` mutates C, so every case gets its own freshly-drawn C rather
                    # than sharing one array between cases.
                    C = randn(3, 5)
                    (false, :stability, nothing, BLAS.gemm!, tA, tB, a, A, B, b, C)
                end
            end),
        )),
    )

    memory = Any[]
    return test_cases, memory
end

function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
t_flags = ['N', 'T', 'C']
Expand Down Expand Up @@ -572,13 +687,6 @@ function generate_derived_rrule!!_test_cases(rng_ctor, ::Val{:blas})
# BLAS LEVEL 3
#

# gemm!
vec(map(product(t_flags, t_flags)) do (tA, tB)
A = tA == 'N' ? randn(3, 4) : randn(4, 3)
B = tB == 'N' ? randn(4, 5) : randn(5, 4)
(false, :none, nothing, BLAS.gemm!, tA, tB, randn(), A, B, randn(), randn(3, 5))
end),

# aliased gemm!
vec(map(product(t_flags, t_flags)) do (tA, tB)
A = randn(5, 5)
Expand Down
Loading