AD fix for CorrBijector #281

Merged
torfjelde merged 13 commits into torfjelde/pd-fix from torfjelde/corr-ad-update
Aug 12, 2023


torfjelde
Member

This is a sibling PR of #280, making use of the functionality introduced there to ensure that CorrBijector and its siblings also work as intended.

Comment on lines 244 to 246
# TODO: Remove these as soon as https://github.com/FluxML/Zygote.jl/pull/1444 is merged.
@adjoint LinearAlgebra.parent(x::LinearAlgebra.UpperTriangular) = parent(x), ȳ -> (LinearAlgebra.UpperTriangular(ȳ),)
@adjoint LinearAlgebra.parent(x::LinearAlgebra.LowerTriangular) = parent(x), ȳ -> (LinearAlgebra.LowerTriangular(ȳ),)
Member Author

I've added this here to check if tests are succeeding; will remove as soon as the mentioned PR goes through.

@@ -321,7 +325,7 @@ function _link_chol_lkj(W::UpperTriangular)
return z
end

_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W))
_link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpose_eager(W))
Member Author

This one is a bit weird because I just saw that it has its own chain rule defined, while the evaluation itself does not have a specialized implementation.

@yebai
Member

yebai commented Aug 10, 2023

There is a reproducible Cholesky factorization failure. I couldn't find an obvious cause for it.

* Update chainrules.jl

* Update corr.jl

* Revert changes to transform.
@sunxd3

sunxd3 commented Aug 11, 2023

Also, when I try to run

@testset "LKJCholesky" begin
    dist = LKJCholesky(3, 1)
    single_sample_tests(dist)
    x = rand(dist)
    upperinds = [
        LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]
    ]
    J = ForwardDiff.jacobian(x -> link(dist, x), x.U)
    J = J[:, upperinds]
    logpdf_turing = logpdf_with_trans(dist, x, true)
    @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
end

I encounter

julia> dist = LKJCholesky(3, 1)
LKJCholesky{Float64}(
d: 3
η: 1.0
uplo: L
)


julia> x = rand(dist)
Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LowerTriangular{Float64, Matrix{Float64}}:
  1.0                  
  0.198156  0.980171    
 -0.797754  0.257329  0.545316


julia> J = ForwardDiff.jacobian(x -> link(dist, x), x.U)
ERROR: PosDefException: matrix is not positive definite; Cholesky factorization failed.
Stacktrace:
  [1] checkpositivedefinite
    @ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/factorization.jl:18 [inlined]
  [2] #cholesky!#152
    @ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:268 [inlined]
  [3] cholesky! (repeats 2 times)
    @ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:266 [inlined]
  [4] #cholesky#162
    @ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:400 [inlined]
  [5] cholesky (repeats 2 times)
    @ /path/to/julia/stdlib/v1.9/LinearAlgebra/src/cholesky.jl:400 [inlined]
  [6] cholesky_lower(X::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
    @ Bijectors /path/to/Bijectors/src/utils.jl:31
  [7] transform(b::Bijectors.VecCholeskyBijector, X::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
    @ Bijectors /path/to/Bijectors/src/bijectors/corr.jl:220
  [8] Transform
    @ /path/to/Bijectors/src/interface.jl:80 [inlined]
  [9] link(d::LKJCholesky{Float64}, x::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
    @ Bijectors /path/to/Bijectors/src/Bijectors.jl:128
 [10] (::var"#9#10")(x::UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}})
    @ Main ./REPL[3]:1
 [11] vector_mode_dual_eval!
    @ /path/to/ForwardDiff/src/apiutils.jl:24 [inlined]
 [12] vector_mode_jacobian(f::var"#9#10", x::UpperTriangular{Float64, Matrix{Float64}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9, UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}}})
    @ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:125
 [13] jacobian(f::Function, x::UpperTriangular{Float64, Matrix{Float64}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9, UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}}}, ::Val{true})
    @ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:21
 [14] jacobian(f::Function, x::UpperTriangular{Float64, Matrix{Float64}}, cfg::ForwardDiff.JacobianConfig{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9, UpperTriangular{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}, Matrix{ForwardDiff.Dual{ForwardDiff.Tag{var"#9#10", Float64}, Float64, 9}}}})
    @ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:19
 [15] jacobian(f::Function, x::UpperTriangular{Float64, Matrix{Float64}})
    @ ForwardDiff /path/to/ForwardDiff/src/jacobian.jl:19
 [16] top-level scope
    @ REPL[3]:1

@sunxd3

sunxd3 commented Aug 11, 2023

Maybe related, LKJCholesky default uplo='L'; Hermitian default uplo='U'

@torfjelde
Member Author

> In latest version cholesky takes AbstractMatrix.

But we only call cholesky with Hermitian, which is defined, no?

> Maybe related, LKJCholesky default uplo='L'; Hermitian default uplo='U'

The issue is two-fold. First, we're giving the transformation in ForwardDiff.jacobian(...) an upper-triangular matrix, but it expects a lower-triangular one (it ends up calling cholesky_lower).
But even if we fix this, i.e. pass x.L instead, it just returns 0s:

julia> using Bijectors

julia> dist = LKJCholesky(3, 1)
LKJCholesky{Float64}(
d: 3
η: 1.0
uplo: L
)


julia> x = rand(dist)
LinearAlgebra.Cholesky{Float64, Matrix{Float64}}
L factor:
3×3 LinearAlgebra.LowerTriangular{Float64, Matrix{Float64}}:
  1.0                   
 -0.107859   0.994166    
 -0.38706   -0.246285  0.888554

julia> link(dist, x)
3-element Vector{Float64}:
 -0.10828016710654258
 -0.40833723913774234
 -0.2737439074161078

julia> link(dist, x.L)
3-element Vector{Float64}:
 0.0
 0.0
 0.0

This was caused by this line:

cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L))

Here we ended up calling Hermitian(X, :U) on an X that is lower-triangular; I've now made this call Hermitian(X, :L). Now we have transpose(cholesky_lower(x.L)) == cholesky_upper(x.U), as expected.
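
For anyone following along, here is a minimal sketch of why the `uplo` argument matters, using only the LinearAlgebra standard library. The matrix is copied from the REPL session above; `chol_lower` and `chol_upper` are hypothetical stand-ins written for this illustration and are not the Bijectors.jl definitions of `cholesky_lower` and `cholesky_upper`.

```julia
using LinearAlgebra

# Lower-triangular factor copied from the REPL session above.
X = [
     1.0       0.0       0.0
    -0.107859  0.994166  0.0
    -0.38706  -0.246285  0.888554
]

# Hermitian defaults to uplo = :U, so only the upper triangle of X is read.
# For a lower-triangular X that is just the diagonal.
H_upper = Hermitian(X)
@assert H_upper.uplo == 'U'
@assert isdiag(Matrix(H_upper))

# The Cholesky factor of a diagonal matrix is itself diagonal, so every
# off-diagonal entry of the factor is zero.
@assert isdiag(Matrix(cholesky(H_upper).L))

# With uplo = :L the off-diagonal entries of X are kept and mirrored.
H_lower = Hermitian(X, :L)
@assert H_lower[1, 2] == X[2, 1]

# Hypothetical helpers (assumptions for this sketch, not the package code).
chol_lower(A) = cholesky(Hermitian(A, :L)).L
chol_upper(A) = cholesky(Hermitian(A, :U)).U

# The lower factor of X and the upper factor of its transpose agree.
Xt = Matrix(transpose(X))
@assert Matrix(transpose(chol_lower(X))) ≈ Matrix(chol_upper(Xt))
```

If the sketch is right, the diagonal factor in the default `:U` case is consistent with the zeros returned by link(dist, x.L) above, since only off-diagonal entries would carry any information there.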

@torfjelde
Member Author

Alrighty; tests are at least passing locally for me now. Let's see if Julia 1.6 also works.

@torfjelde
Member Author

Looks like we're ready to go:)

@torfjelde torfjelde merged commit 52ee210 into torfjelde/pd-fix Aug 12, 2023
21 checks passed
@delete-merged-branch delete-merged-branch bot deleted the torfjelde/corr-ad-update branch August 12, 2023 10:47
torfjelde added a commit that referenced this pull request Aug 12, 2023
* added cholesky_lower and cholesky_triangular

* updated PD to use new cholesky_lower and cholesky_upper

* simplified imports in BijectorsReverseDiffExtx

* added ChainRules as a dep since we need the chain rules for cholesky, etc.

* forgot to update Project.toml in previous commit

* added explicit implementation of with_logabsdet_jacobian for PDBijector

* Update src/utils.jl

* added ProjectTo in rrules for cholesky_lower and cholesky_upper to be proper

* added ProjectTo for cholesky_upper too

* added transpose_eager as a alias for permutedims to allow definition
of AD rules without type piracy

* allow usage of ForwardDiff gradient as ground-truth

* added AD tests for PDVecBijector

* added AD tests for PDVecBijector to runtests and commented out all
other tests for the sake of reproducing ReverseDiff bug

* forgot to remove type-piracy def of ReverseDiff rule for permutedims

* use ReverseDiff.@grad instead of ReverseDiff.@grad_from_chainrules

* only define cholesky_lower and cholesky_upper rules for ReverseDiff, remove rules ChainRules defs

* formatting

* parameterise gradient test for PD bijector properly instead of using
ForwardDiff as per suggestion of @devmotion

* reversed chagne to test_ad

* reactivate tests

* updated doocstrings

* improved PDVecBijector AD tests a bit

* AD fix for CorrBijector (#281)

* removed redundant imports to BijectorsZygoteExt

* use cholesky_upper and cholesky_lower instead of cholesky_factor, etc.

* added tests for CorrVecBijector

* name testset correctly

* use cholesky_lower and cholesky_upper instead of cholesky_factor

* removed now-redundant cholesky_factor

* Fix obsolete function references in tests.  (#282)

* Update chainrules.jl

* Update corr.jl

* Revert changes to transform.

* removed type-piracy that has been addressed upstream and bumped Zygote
version in test

* use :L for Hermitian in `cholesky_lower`

* fixed ForwardDiff tests for LKJCholesky

* fixed tests for matrix dists and added tests for both values of uplo
in LKJCholesky tests

* another attempt at fixing Julia 1.6 tests

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
@yebai
Member

yebai commented Aug 12, 2023

Thanks @sunxd3 and @torfjelde!
