diff --git a/src/Nabla.jl b/src/Nabla.jl index 6c9dca12..cdba98c7 100644 --- a/src/Nabla.jl +++ b/src/Nabla.jl @@ -58,5 +58,6 @@ module Nabla include("sensitivities/linalg/triangular.jl") include("sensitivities/linalg/factorization/cholesky.jl") include("sensitivities/linalg/factorization/svd.jl") + include("sensitivities/linalg/factorization/qr.jl") end # module Nabla diff --git a/src/sensitivities/linalg/factorization/qr.jl b/src/sensitivities/linalg/factorization/qr.jl new file mode 100644 index 00000000..c53e4042 --- /dev/null +++ b/src/sensitivities/linalg/factorization/qr.jl @@ -0,0 +1,54 @@ +import LinearAlgebra: qr +import Base: getproperty + +const QRLike = Union{QR, LinearAlgebra.QRCompactWY} + +@explicit_intercepts qr Tuple{AbstractMatrix{<:Real}} + +function ∇( + ::typeof(qr), + ::Type{Arg{1}}, + p, + Y::QRLike, + Ȳ::NamedTuple{(:Q,:R)}, + A::AbstractMatrix, +) + Q, R = Y + Q̄, R̄ = Ȳ + triu!(R̄) + M = R*R̄' + M .-= Q̄'Q + return (Q̄ + Q*Symmetric(M, :L)) / R' +end + +@explicit_intercepts getproperty Tuple{QRLike, Symbol} [true, false] + +function ∇(::typeof(getproperty), ::Type{Arg{1}}, p, y, ȳ, F::QRLike, x::Symbol) + if x === :Q + return (Q=reshape(ȳ, size(F.Q)), R=zeroslike(F.R)) + elseif x === :R + return (Q=zeroslike(F.Q), R=reshape(ȳ, size(F.R))) + else + throw(ArgumentError("unrecognized property $x; expected Q or R")) + end +end + +function ∇( + x̄::NamedTuple{(:Q,:R)}, + ::typeof(getproperty), + ::Type{Arg{1}}, + p, y, ȳ, + F::QRLike, + x::Symbol, +) + x̄_update = ∇(getproperty, Arg{1}, p, y, ȳ, F, x) + if x === :Q + return (Q=update!(x̄.Q, x̄_update.Q), R=x̄.R) + elseif x === :R + return (Q=x̄.Q, R=update!(x̄.R, x̄_update.R)) + end +end + +Base.iterate(qr::Branch{<:QRLike}) = (qr.Q, Val(:R)) +Base.iterate(qr::Branch{<:QRLike}, ::Val{:R}) = (qr.R, Val(:done)) +Base.iterate(qr::Branch{<:QRLike}, ::Val{:done}) = nothing diff --git a/test/runtests.jl b/test/runtests.jl index fe596fdd..304c797a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -43,6 +43,7 @@ end @testset "Factorisations" begin include("sensitivities/linalg/factorization/cholesky.jl") include("sensitivities/linalg/factorization/svd.jl") + include("sensitivities/linalg/factorization/qr.jl") end end end diff --git a/test/sensitivities/linalg/factorization/qr.jl b/test/sensitivities/linalg/factorization/qr.jl new file mode 100644 index 00000000..44913bb8 --- /dev/null +++ b/test/sensitivities/linalg/factorization/qr.jl @@ -0,0 +1,37 @@ +@testset "QR" begin + @testset "Comparison with finite differencing" begin + rng = MersenneTwister(123456) + n = 5 + A = randn(rng, n, n) + VA = randn(rng, n, n) + @test check_errs(X->qr(X).Q, randn(rng, n, n), A, VA) + @test check_errs(X->qr(X).R, randn(rng, n, n), A, VA) + end + + @testset "Branch consistency" begin + X_ = Matrix(1.0I, 5, 3) + X = Leaf(Tape(), X_) + F = qr(X) + @test F isa Branch{<:LinearAlgebra.QRCompactWY} + @test getfield(F, :f) == qr + @test unbox(F.Q) ≈ Matrix(1.0I, 5, 5) + @test unbox(F.R) ≈ Matrix(1.0I, 3, 3) + # Destructuring via iteration + Q, R = F + @test Q isa Branch{<:LinearAlgebra.QRCompactWYQ} + @test R isa Branch{<:Matrix} + end + + @testset "Tape updating" begin + t = Tape() + X_ = Matrix(1.0I, 4, 4) + X = Leaf(t, X_) + Q, R = qr(X) + Y = Q*R + Z = Q*Y*R + rt = ∇(Z, X_) + @test rt[2] isa NamedTuple{(:Q,:R)} + @test rt[2].Q ≈ Matrix(2.0I, 4, 4) + @test rt[2].R ≈ Matrix(2.0I, 4, 4) + end +end