Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AD fix for CorrBijector #281

Merged
merged 13 commits into from
Aug 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 2 additions & 9 deletions ext/BijectorsReverseDiffExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,8 @@ end
)

@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
@grad_from_chainrules _link_chol_lkj_from_upper(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_lower(x::TrackedMatrix)
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)

cholesky_lower(X::TrackedMatrix) = track(cholesky_lower, X)
Expand Down Expand Up @@ -308,13 +310,4 @@ transpose_eager(X::TrackedMatrix) = track(transpose_eager, X)
return y, transpose_eager_pullback
end

if VERSION <= v"1.8.0-DEV.1526"
# HACK: This dispatch does not wrap X in Hermitian before calling cholesky.
# cholesky does not work with AbstractMatrix in julia versions before the compared one,
# and it would error with Hermitian{ReverseDiff.TrackedArray}.
# See commit when the fix was introduced :
# https://github.com/JuliaLang/julia/commit/635449dabee81bba315ab066627a98f856141969
cholesky_factor(X::ReverseDiff.TrackedArray) = cholesky_factor(cholesky(X))
end

end
5 changes: 1 addition & 4 deletions ext/BijectorsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ if isdefined(Base, :get_extension)
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
Expand Down Expand Up @@ -55,8 +53,6 @@ else
_simplex_inv_bijector,
replace_diag,
jacobian,
_inv_link_chol_lkj,
_link_chol_lkj,
_transform_ordered,
_transform_inverse_ordered,
find_alpha,
Expand Down Expand Up @@ -244,4 +240,5 @@ end
return replace_diag(log, Y)
end
end

end
22 changes: 13 additions & 9 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,9 @@ struct CorrBijector <: Bijector end
with_logabsdet_jacobian(b::CorrBijector, x) = transform(b, x), logabsdetjac(b, x)

function transform(b::CorrBijector, X::AbstractMatrix{<:Real})
w = upper_triangular(parent(cholesky(X).U)) # keep LowerTriangular until here can avoid some computation
w = cholesky_upper(X)
r = _link_chol_lkj(w)
return r + zero(X)
# This dense format itself is required by a test, though I can't get the point.
# https://github.com/TuringLang/Bijectors.jl/blob/b0aaa98f90958a167a0b86c8e8eca9b95502c42d/test/transform.jl#L67
return r
end

function transform(ib::Inverse{CorrBijector}, y::AbstractMatrix{<:Real})
Expand Down Expand Up @@ -127,7 +125,7 @@ struct VecCorrBijector <: Bijector end

with_logabsdet_jacobian(b::VecCorrBijector, x) = transform(b, x), logabsdetjac(b, x)

transform(::VecCorrBijector, X) = _link_chol_lkj(cholesky_factor(X))
transform(::VecCorrBijector, X) = _link_chol_lkj_from_upper(cholesky_upper(X))

function logabsdetjac(b::VecCorrBijector, x)
return -logabsdetjac(inverse(b), b(x))
Expand Down Expand Up @@ -215,7 +213,13 @@ end
# TODO: Implement directly to make use of shared computations.
with_logabsdet_jacobian(b::VecCholeskyBijector, x) = transform(b, x), logabsdetjac(b, x)

transform(::VecCholeskyBijector, X) = _link_chol_lkj(cholesky_factor(X))
function transform(b::VecCholeskyBijector, X)
return if b.mode === :U
_link_chol_lkj_from_upper(cholesky_upper(X))
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
_link_chol_lkj_from_lower(cholesky_lower(X))
end
end

function logabsdetjac(b::VecCholeskyBijector, x)
return -logabsdetjac(inverse(b), b(x))
Expand All @@ -229,7 +233,7 @@ function transform(b::Inverse{VecCholeskyBijector}, y::AbstractVector{<:Real})
else # No need to check for === :L, as it is checked in the VecCholeskyBijector constructor.
# HACK: Need to make materialize the transposed matrix to avoid numerical instabilities.
# If we don't, the return-type can be both `Matrix` and `Transposed`.
return Cholesky(permutedims(_inv_link_chol_lkj(y), (2, 1)), 'L', 0)
return Cholesky(transpose_eager(_inv_link_chol_lkj(y)), 'L', 0)
end
end

Expand Down Expand Up @@ -299,7 +303,7 @@ function _link_chol_lkj(W::AbstractMatrix)
return z
end

function _link_chol_lkj(W::UpperTriangular)
function _link_chol_lkj_from_upper(W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2 # {K \choose 2} free parameters

Expand All @@ -321,7 +325,7 @@ function _link_chol_lkj(W::UpperTriangular)
return z
end

_link_chol_lkj(W::LowerTriangular) = _link_chol_lkj(transpose(W))
_link_chol_lkj_from_lower(W::AbstractMatrix) = _link_chol_lkj_from_upper(transpose_eager(W))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This one is a bit weird because I just saw that it has it's own chainrule defined while the evaluation itself does not have a specialized implementation..


"""
_inv_link_chol_lkj(y)
Expand Down
12 changes: 6 additions & 6 deletions src/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ function ChainRulesCore.rrule(::typeof(_transform_inverse_ordered), x::AbstractM
return y, _transform_inverse_ordered_adjoint
end

function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_upper), W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2

Expand All @@ -178,7 +178,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
end
end

function pullback_link_chol_lkj(Δz_thunked)
function pullback_link_chol_lkj_from_upper(Δz_thunked)
Δz = ChainRulesCore.unthunk(Δz_thunked)

ΔW = similar(W)
Expand Down Expand Up @@ -208,10 +208,10 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::UpperTriangular)
return ChainRulesCore.NoTangent(), ΔW
end

return z, pullback_link_chol_lkj
return z, pullback_link_chol_lkj_from_upper
end

function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
function ChainRulesCore.rrule(::typeof(_link_chol_lkj_from_lower), W::AbstractMatrix)
K = LinearAlgebra.checksquare(W)
N = ((K - 1) * K) ÷ 2

Expand All @@ -233,7 +233,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
end
end

function pullback_link_chol_lkj(Δz_thunked)
function pullback_link_chol_lkj_from_lower(Δz_thunked)
Δz = ChainRulesCore.unthunk(Δz_thunked)

ΔW = similar(W)
Expand Down Expand Up @@ -263,7 +263,7 @@ function ChainRulesCore.rrule(::typeof(_link_chol_lkj), W::LowerTriangular)
return ChainRulesCore.NoTangent(), ΔW
end

return z, pullback_link_chol_lkj
return z, pullback_link_chol_lkj_from_lower
end

function ChainRulesCore.rrule(::typeof(_inv_link_chol_lkj), y::AbstractVector)
Expand Down
9 changes: 3 additions & 6 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,6 @@ upper_triangular(A::AbstractMatrix) = convert(typeof(A), UpperTriangular(A))
pd_from_lower(X) = LowerTriangular(X) * LowerTriangular(X)'
pd_from_upper(X) = UpperTriangular(X)' * UpperTriangular(X)

cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X)))
cholesky_factor(X::Cholesky) = X.U
cholesky_factor(X::UpperTriangular) = X
cholesky_factor(X::LowerTriangular) = X

# HACK: Allows us to define custom chain rules while we wait for upstream fixes.
transpose_eager(X::AbstractMatrix) = permutedims(X)

Expand All @@ -33,7 +28,8 @@ rather than `LowerTriangular`.
This is a thin wrapper around `cholesky(Hermitian(X)).L`
that returns a `Matrix` rather than `LowerTriangular`.
"""
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X)).L))
cholesky_lower(X::AbstractMatrix) = lower_triangular(parent(cholesky(Hermitian(X, :L)).L))
cholesky_lower(X::Cholesky) = X.L

"""
cholesky_upper(X)
Expand All @@ -46,6 +42,7 @@ rather than `UpperTriangular`.
that returns a `Matrix` rather than `UpperTriangular`.
"""
cholesky_upper(X::AbstractMatrix) = upper_triangular(parent(cholesky(Hermitian(X)).U))
cholesky_upper(X::Cholesky) = X.U

"""
triu_mask(X::AbstractMatrix, k::Int)
Expand Down
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,5 @@ LazyArrays = "1"
LogExpFunctions = "0.3.1"
ReverseDiff = "1.4.2"
Tracker = "0.2.11"
Zygote = "0.5.4, 0.6"
Zygote = "0.6.63"
julia = "1.3"
4 changes: 2 additions & 2 deletions test/ad/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
# LKJ and LKJCholesky bijector
dist = LKJCholesky(3, 4)
x = rand(dist)
test_rrule(Bijectors._link_chol_lkj, x.U)
test_rrule(Bijectors._link_chol_lkj, x.L)
test_rrule(Bijectors._link_chol_lkj_from_upper, x.U)
test_rrule(Bijectors._link_chol_lkj_from_lower, x.L)

b = bijector(dist)
y = b(x)
Expand Down
35 changes: 35 additions & 0 deletions test/ad/corr.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
@testset "AD for VecCorrBijector" begin
d = 4
dist = LKJ(d, 2.0)
b = bijector(dist)
binv = inverse(b)

x = rand(dist)
y = b(x)

test_ad(y) do x
sum(transform(b, binv(x)))
end

test_ad(y) do y
sum(transform(binv, y))
end
end

@testset "AD for VecCholeskyBijector" begin
d = 4
dist = LKJCholesky(d, 2.0)
b = bijector(dist)
binv = inverse(b)

x = rand(dist)
y = b(x)

test_ad(y) do y
sum(transform(b, binv(y)))
end

test_ad(y) do y
sum(Bijectors.cholesky_upper(transform(binv, y)))
end
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,4 +54,5 @@ if GROUP == "All" || GROUP == "AD"
include("ad/chainrules.jl")
include("ad/flows.jl")
include("ad/pd.jl")
include("ad/corr.jl")
end
32 changes: 17 additions & 15 deletions test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -178,17 +178,17 @@ end
@testset "matrix" begin
let
matrix_dists = [Wishart(7, [1 0.5; 0.5 1]), InverseWishart(2, [1 0.5; 0.5 1])]
for dist in matrix_dists
@testset "$dist" for dist in matrix_dists
single_sample_tests(dist)

x = rand(dist)
x = x + x' + 2I
upperinds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1]
inds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] <= I[1]
]
logpdf_turing = logpdf_with_trans(dist, x, true)
J = ForwardDiff.jacobian(x -> link(dist, x), x)
J = J[:, upperinds]
J = J[:, inds]
@test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
end
end
Expand All @@ -214,19 +214,21 @@ end
end

@testset "LKJCholesky" begin
dist = LKJCholesky(3, 1)
@testset "uplo: $uplo" for uplo in [:L, :U]
dist = LKJCholesky(3, 1, uplo)
single_sample_tests(dist)

single_sample_tests(dist)

x = rand(dist)
x = rand(dist)

upperinds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] > I[1]
]
J = ForwardDiff.jacobian(x -> link(dist, x), x.U)
J = J[:, upperinds]
logpdf_turing = logpdf_with_trans(dist, x, true)
@test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
inds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if
(uplo === :L && I[2] < I[1]) || (uplo === :U && I[2] > I[1])
]
J = ForwardDiff.jacobian(z -> link(dist, Cholesky(z, x.uplo, x.info)), x.UL)
J = J[:, inds]
logpdf_turing = logpdf_with_trans(dist, x, true)
@test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
end
end

################################## Miscelaneous old tests ##################################
Expand Down
Loading