Skip to content

Commit

Permalink
AD fix for CorrBijector (#281)
Browse files Browse the repository at this point in the history
* removed redundant imports to BijectorsZygoteExt

* use cholesky_upper and cholesky_lower instead of cholesky_factor, etc.

* added tests for CorrVecBijector

* name testset correctly

* use cholesky_lower and cholesky_upper instead of cholesky_factor

* removed now-redundant cholesky_factor

* Fix obsolete function references in tests.  (#282)

* Update chainrules.jl

* Update corr.jl

* Revert changes to transform.

* removed type-piracy that has been addressed upstream and bumped Zygote
version in test

* use :L for Hermitian in `cholesky_lower`

* fixed ForwardDiff tests for LKJCholesky

* fixed tests for matrix dists and added tests for both values of uplo
in LKJCholesky tests

* another attempt at fixing Julia 1.6 tests

---------

Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
  • Loading branch information
torfjelde and yebai authored Aug 12, 2023
1 parent e87a2aa commit 52ee210
Show file tree
Hide file tree
Showing 10 changed files with 81 additions and 52 deletions.
11 changes: 2 additions & 9 deletions ext/BijectorsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ end
)

@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_upper(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_lower(x::TrackedMatrix)
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)

cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X)
Expand Down Expand Up @@ -308,13 +310,4 @@ transpose_eager(X::TrackedMatrix) = track(transpose_eager, X)
return y, transpose_eager_pullback
end

if VERSION <= v"1.8.0-DEV.1526"
# HACK: This dispatch does not wrap X in Hermitian before calling cholesky.
# cholesky does not work with AbstractMatrix in julia versions before the compared one,
# and it would error with Hermitian{ReverseDiff.TrackedArray}.
# See commit when the fix was introduced :
# https://github.com/JuliaLang/julia/commit/635449dabee81bba315ab066627a98f856141969
cholesky_factor(X::ReverseDiff.TrackedArray) = cholesky_factor(cholesky(X))
end

end
5 changes: 1 addition & 4 deletions ext/BijectorsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ if isdefined(Base, :get_extension)
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
Expand Down Expand Up @@ -55,8 +53,6 @@ else
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
Expand Down Expand Up @@ -244,4 +240,5 @@ end
return replace_diag(log, Y)
end
end

end
22 changes: 13 additions & 9 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ struct CorrBijector <: Bijector end
with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x)

function transform(b::CorrBijector, X::AbstractMatrix{<:Real})
w = upper_triangular(parent(cholesky(X).U)) # keep LowerTriangular until here can avoid some computation
w = cholesky_upper(X)
r = _link_chol_lkj(w)
return r + zero(X)
# This dense format itself is required by a test, though I can't get the point.
# https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
return r
end

function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
Expand Down Expand Up @@ -127,7 +125,7 @@ struct VecCorrBijector <: Bijector end

with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x)

transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X))
transform(::VecCorrBijector, X) = _link_chol_lkj_from_upper(cholesky_upper(X))

function logabsdetjac(b::VecCorrBijector, x)
return -logabsdetjac(inverse(b), b(x))
Expand Down Expand Up @@ -215,7 +213,13 @@ end
# TODO: Implement directly to make use of shared computations.
with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x)

transform(::VecCholeskyBijector, X) = _link_chol_lkj(cholesky_factor(X))
function transform(b::VecCholeskyBijector, X)
return if b.mode === :U
_link_chol_lkj_from_upper(cholesky_upper(X))
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
_link_chol_lkj_from_lower(cholesky_lower(X))
end
end

function logabsdetjac(b::VecCholeskyBijector, x)
return -logabsdetjac(inverse(b), b(x))
Expand All @@ -229,7 +233,7 @@ function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
# HACK: Need to make materialize the transposed matrix to avoid numerical instabilities.
# If we don't, the return-type can be both `Matrix` and `Transposed`.
return Cholesky(permutedims(_inv_link_chol_lkj(y), (2, 1)), 'L', 0)
return Cholesky(transpose_eager(_inv_link_chol_lkj(y)), 'L', 0)
end
end

Expand Down Expand Up @@ -299,7 +303,7 @@ function _link_chol_lkj(W::AbstractMatrix)
return z
end

function _link_chol_lkj(W::UpperTriangular)
function _link_chol_lkj_from_upper(W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2 # {K \choose 2} free parameters

Expand All @@ -321,7 +325,7 @@ function _link_chol_lkj(W::UpperTriangular)
return z
end

_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W))
_link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpose_eager(W))

"""
_inv_link_chol_lkj(y)
Expand Down
12 changes: 6 additions & 6 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM
return y, _transform_inverse_ordered_adjoint
end

function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_upper), W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2

Expand All @@ -178,7 +178,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
end
end

function pullback_link_chol_lkj(Δz_thunked)
function pullback_link_chol_lkj_from_upper(Δz_thunked)
Δz = ChainRulesCore.unthunk(Δz_thunked)

ΔW = similar(W)
Expand Down Expand Up @@ -208,10 +208,10 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
return ChainRulesCore.NoTangent(), ΔW
end

return z, pullback_link_chol_lkj
return z, pullback_link_chol_lkj_from_upper
end

function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_lower), W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2

Expand All @@ -233,7 +233,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
end
end

function pullback_link_chol_lkj(Δz_thunked)
function pullback_link_chol_lkj_from_lower(Δz_thunked)
Δz = ChainRulesCore.unthunk(Δz_thunked)

ΔW = similar(W)
Expand Down Expand Up @@ -263,7 +263,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
return ChainRulesCore.NoTangent(), ΔW
end

return z, pullback_link_chol_lkj
return z, pullback_link_chol_lkj_from_lower
end

function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector)
Expand Down
9 changes: 3 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A))
pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)'
pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X)

cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X)))
cholesky_factor(X::Cholesky) = X.U
cholesky_factor(X::UpperTriangular) = X
cholesky_factor(X::LowerTriangular) = X

# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
transpose_eager(X::AbstractMatrix) = permutedims(X)

Expand All @@ -33,7 +28,8 @@ rather than `LowerTriangular`.
This is a thin wrapper around `cholesky(Hermitian(X)).L`
that returns a `Matrix` rather than `LowerTriangular`.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L))
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X, :L)).L))
cholesky_lower(X::Cholesky) = X.L

"""
cholesky_upper(X)
Expand All @@ -46,6 +42,7 @@ rather than `UpperTriangular`.
that returns a `Matrix` rather than `UpperTriangular`.
"""
cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U))
cholesky_upper(X::Cholesky) = X.U

"""
triu_mask(X::AbstractMatrix, k::Int)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ LazyArrays = "1"
LogExpFunctions = "0.3.1"
ReverseDiff = "1.4.2"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
Zygote = "0.6.63"
julia = "1.3"
4 changes: 2 additions & 2 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
# LKJ and LKJCholesky bijector
dist = LKJCholesky(3, 4)
x = rand(dist)
test_rrule(Bijectors._link_chol_lkj, x.U)
test_rrule(Bijectors._link_chol_lkj, x.L)
test_rrule(Bijectors._link_chol_lkj_from_upper, x.U)
test_rrule(Bijectors._link_chol_lkj_from_lower, x.L)

b = bijector(dist)
y = b(x)
Expand Down
35 changes: 35 additions & 0 deletions test/ad/corr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@testset "AD for VecCorrBijector" begin
d = 4
dist = LKJ(d, 2.0)
b = bijector(dist)
binv = inverse(b)

x = rand(dist)
y = b(x)

test_ad(y) do x
sum(transform(b, binv(x)))
end

test_ad(y) do y
sum(transform(binv, y))
end
end

@testset "AD for VecCholeskyBijector" begin
d = 4
dist = LKJCholesky(d, 2.0)
b = bijector(dist)
binv = inverse(b)

x = rand(dist)
y = b(x)

test_ad(y) do y
sum(transform(b, binv(y)))
end

test_ad(y) do y
sum(Bijectors.cholesky_upper(transform(binv, y)))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ if GROUP == "All" || GROUP == "AD"
include("ad/chainrules.jl")
include("ad/flows.jl")
include("ad/pd.jl")
include("ad/corr.jl")
end
32 changes: 17 additions & 15 deletions test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,17 @@ end
@testset "matrix" begin
let
matrix_dists = [Wishart(7, [1 0.5; 0.5 1]), InverseWishart(2, [1 0.5; 0.5 1])]
for dist in matrix_dists
@testset "$dist" for dist in matrix_dists
single_sample_tests(dist)

x = rand(dist)
x = x + x' + 2I
upperinds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1]
inds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] <= I[1]
]
logpdf_turing = logpdf_with_trans(dist, x, true)
J = ForwardDiff.jacobian(x -> link(dist, x), x)
J = J[:, upperinds]
J = J[:, inds]
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing
end
end
Expand All @@ -214,19 +214,21 @@ end
end

@testset "LKJCholesky" begin
dist = LKJCholesky(3, 1)
@testset "uplo: $uplo" for uplo in [:L, :U]
dist = LKJCholesky(3, 1, uplo)
single_sample_tests(dist)

single_sample_tests(dist)

x = rand(dist)
x = rand(dist)

upperinds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]
]
J = ForwardDiff.jacobian(x -> link(dist, x), x.U)
J = J[:, upperinds]
logpdf_turing = logpdf_with_trans(dist, x, true)
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing
inds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
(uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]
logpdf_turing = logpdf_with_trans(dist, x, true)
@test logpdf(dist, x) - _logabsdet(J) logpdf_turing
end
end

################################## Miscelaneous old tests ##################################
Expand Down

0 comments on commit 52ee210

Please sign in to comment.