From beb41cc04cbb564349b6b1578c1dbe2514781190 Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 10 Nov 2021 17:01:11 +0000 Subject: [PATCH 1/7] Make literal_getproperty type stable --- src/lib/array.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 7734ad5ca..cd40006e7 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -741,20 +741,24 @@ end return ((uplo=nothing, info=nothing, factors=nothing),) end end -@adjoint function literal_getproperty(C::Cholesky, ::Val{:U}) +@adjoint function literal_getproperty( + C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}, ::Val{:U} +) return literal_getproperty(C, Val(:U)), function(Δ) - Δ_factors = C.uplo == 'U' ? UpperTriangular(Δ) : LowerTriangular(copy(Δ')) + Δ_factors = C.uplo == 'U' ? triu!(collect(Δ)) : tril!(collect(Δ')) return ((uplo=nothing, info=nothing, factors=Δ_factors),) end end -@adjoint function literal_getproperty(C::Cholesky, ::Val{:L}) +@adjoint function literal_getproperty( + C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}, ::Val{:L} +) return literal_getproperty(C, Val(:L)), function(Δ) - Δ_factors = C.uplo == 'L' ? LowerTriangular(Δ) : UpperTriangular(copy(Δ')) + Δ_factors = C.uplo == 'L' ? tril!(collect(Δ)) : triu(collect(Δ')) return ((uplo=nothing, info=nothing, factors=Δ_factors),) end end -@adjoint function logdet(C::Cholesky) +@adjoint function logdet(C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}) return logdet(C), function(Δ) return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),) end From 8aefe018d69dffd42a4febe55d1329384409084e Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 10 Nov 2021 17:03:01 +0000 Subject: [PATCH 2/7] Revert logdet change --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index cd40006e7..8f258bfa7 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -758,7 +758,7 @@ end end end -@adjoint function logdet(C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}) +@adjoint function logdet(C::Cholesky) return logdet(C), function(Δ) return ((uplo=nothing, info=nothing, factors=Diagonal(2 .* Δ ./ diag(C.factors))),) end From a2df037aab7229df55079ed49826feca639a9700 Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 10 Nov 2021 17:03:28 +0000 Subject: [PATCH 3/7] Check inference --- test/gradcheck.jl | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index ef958da48..50455f989 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -801,6 +801,17 @@ end @test gradtest(A->logdet(cholesky(A' * A + I)), A) @test gradtest(B->cholesky(Symmetric(B)).U, A * A' + I) @test gradtest(B->logdet(cholesky(Symmetric(B))), A * A' + I) + + @testset "inference" begin + out, pb = _pullback(Context(), C -> C.U, cholesky(Symmetric(A'A + I, :U))) + @inferred pb(out) + out, pb = _pullback(Context(), C -> C.U, cholesky(Symmetric(A'A + I, :L))) + @inferred pb(out) + out, pb = _pullback(Context(), C -> C.L, cholesky(Symmetric(A'A + I, :U))) + @inferred pb(out) + out, pb = _pullback(Context(), C -> C.L, cholesky(Symmetric(A'A + I, :L))) + @inferred pb(out) + end end @testset "cholesky - scalar" begin rng = MersenneTwister(123456) From 2845f095a33b90bb1aa95a44f14fa300c21a00bd Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 10 Nov 2021 17:03:41 +0000 Subject: [PATCH 4/7] Import _pullback / Context into tests --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index 17ebb3997..870d87b32 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Zygote, Test -using Zygote: gradient, ZygoteRuleConfig +using Zygote: gradient, ZygoteRuleConfig, _pullback, Context using CUDA using CUDA: has_cuda From 0548ad5e5427cb83b670163ddac0780c6d081e13 Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 10 Nov 2021 17:04:03 +0000 Subject: [PATCH 5/7] Bump patch --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a8ea16a25..717c707ab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.30" +version = "0.6.31" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" From e244f44beb1f1326b734dc0be2fe9d91bbf7df22 Mon Sep 17 00:00:00 2001 From: WT Date: Wed, 10 Nov 2021 17:09:30 +0000 Subject: [PATCH 6/7] Fix typo --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 8f258bfa7..659b8d89a 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -753,7 +753,7 @@ end C::Cholesky{T, <:StridedMatrix{T}} where {T<:Real}, ::Val{:L} ) return literal_getproperty(C, Val(:L)), function(Δ) - Δ_factors = C.uplo == 'L' ? tril!(collect(Δ)) : triu(collect(Δ')) + Δ_factors = C.uplo == 'L' ? tril!(collect(Δ)) : triu!(collect(Δ')) return ((uplo=nothing, info=nothing, factors=Δ_factors),) end end From 23c728b16d64cd578314d74dd2b5cf8b9f15a3bc Mon Sep 17 00:00:00 2001 From: WT Date: Thu, 11 Nov 2021 09:28:13 +0000 Subject: [PATCH 7/7] _pullback -> pullback --- test/gradcheck.jl | 8 ++++---- test/runtests.jl | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 50455f989..fa2f4dddb 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -803,13 +803,13 @@ end @test gradtest(B->logdet(cholesky(Symmetric(B))), A * A' + I) @testset "inference" begin - out, pb = _pullback(Context(), C -> C.U, cholesky(Symmetric(A'A + I, :U))) + out, pb = pullback(C -> C.U, cholesky(Symmetric(A'A + I, :U))) @inferred pb(out) - out, pb = _pullback(Context(), C -> C.U, cholesky(Symmetric(A'A + I, :L))) + out, pb = pullback(C -> C.U, cholesky(Symmetric(A'A + I, :L))) @inferred pb(out) - out, pb = _pullback(Context(), C -> C.L, cholesky(Symmetric(A'A + I, :U))) + out, pb = pullback(C -> C.L, cholesky(Symmetric(A'A + I, :U))) @inferred pb(out) - out, pb = _pullback(Context(), C -> C.L, cholesky(Symmetric(A'A + I, :L))) + out, pb = pullback(C -> C.L, cholesky(Symmetric(A'A + I, :L))) @inferred pb(out) end end diff --git a/test/runtests.jl b/test/runtests.jl index 870d87b32..17ebb3997 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,5 @@ using Zygote, Test -using Zygote: gradient, ZygoteRuleConfig, _pullback, Context +using Zygote: gradient, ZygoteRuleConfig using CUDA using CUDA: has_cuda