diff --git a/Project.toml b/Project.toml index a8ea16a25..717c707ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.30" +version = "0.6.31" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/lib/array.jl b/src/lib/array.jl index 7734ad5ca..659b8d89a 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -741,15 +741,19 @@ end return ((uplo=nothing, info=nothing, factors=nothing),) end end -@adjoint function literal_getproperty(C::Cholesky, ::Val{:U}) +@adjoint function literal_getproperty( + C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}, ::Val{:U} +) return literal_getproperty(C, Val(:U)), function(Δ) - Δ_factors = C.uplo == 'U' ? UpperTriangular(Δ) : LowerTriangular(copy(Δ')) + Δ_factors = C.uplo == 'U' ? triu!(collect(Δ)) : tril!(collect(Δ')) return ((uplo=nothing, info=nothing, factors=Δ_factors),) end end -@adjoint function literal_getproperty(C::Cholesky, ::Val{:L}) +@adjoint function literal_getproperty( + C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}, ::Val{:L} +) return literal_getproperty(C, Val(:L)), function(Δ) - Δ_factors = C.uplo == 'L' ? LowerTriangular(Δ) : UpperTriangular(copy(Δ')) + Δ_factors = C.uplo == 'L' ? tril!(collect(Δ)) : triu!(collect(Δ')) return ((uplo=nothing, info=nothing, factors=Δ_factors),) end end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index ef958da48..fa2f4dddb 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -801,6 +801,17 @@ end @test gradtest(A->logdet(cholesky(A' * A + I)), A) @test gradtest(B->cholesky(Symmetric(B)).U, A * A' + I) @test gradtest(B->logdet(cholesky(Symmetric(B))), A * A' + I) + + @testset "inference" begin + out, pb = pullback(C -> C.U, cholesky(Symmetric(A'A + I, :U))) + @inferred pb(out) + out, pb = pullback(C -> C.U, cholesky(Symmetric(A'A + I, :L))) + @inferred pb(out) + out, pb = pullback(C -> C.L, cholesky(Symmetric(A'A + I, :U))) + @inferred pb(out) + out, pb = pullback(C -> C.L, cholesky(Symmetric(A'A + I, :L))) + @inferred pb(out) + end end @testset "cholesky - scalar" begin rng = MersenneTwister(123456)