Skip to content

Commit

Permalink
Fix Dirichlet logpdf_with_trans to work with a Vector{Real}
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Aug 29, 2024
1 parent dc6b21f commit a17b734
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ function logpdf_with_trans(d::Distribution, x, transform::Bool)
if ispd(d)
return pd_logpdf_with_trans(d, x, transform)
elseif isdirichlet(d)
l = logpdf(d, x .+ eps(eltype(x)))
l = logpdf(d, x .+ _eps(eltype(x)))
else
l = logpdf(d, x)
end
Expand Down
31 changes: 27 additions & 4 deletions test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,28 +42,51 @@ function single_sample_tests(dist)

# Check that invlink is inverse of link.
x = rand(dist)
_single_sample_tests_inner(dist, x, ϵ)

# If the sample is a vector of scalars, check that we can run the tests even if the
# vector has the abstract element type Real. Skip type stability tests though.
if x isa Vector{<:Real}
_single_sample_tests_inner(dist, Vector{Real}(x), ϵ, false)
end
end

function _single_sample_tests_inner(dist, x, ϵ, test_type_stability=true)
if dist isa LKJCholesky
x_inv = @inferred Cholesky{Float64,Matrix{Float64}} invlink(
dist, link(dist, copy(x))
)
@test x_inv.UL x.UL atol = 1e-9
else
@test @inferred(invlink(dist, link(dist, copy(x)))) x atol = 1e-9
x_reconstructed = if test_type_stability
@inferred invlink(dist, link(dist, copy(x)))
else
invlink(dist, link(dist, copy(x)))
end
@test x_reconstructed x atol = 1e-9
end

# Check that link is inverse of invlink. Hopefully this just holds given the above...
y = @inferred(link(dist, x))
y = if test_type_stability
@inferred(link(dist, x))
else
link(dist, x)
end
y_reconstructed = if test_type_stability
@inferred(link(dist, invlink(dist, copy(y))))
else
link(dist, invlink(dist, copy(y)))
end
if dist isa Dirichlet
# `logit` and `logistic` are not perfect inverses. This leads to a diversion.
# Example:
# julia> logit(logistic(0.9999999999999998))
# 1.0
# julia> logistic(logit(0.9999999999999998))
# 0.9999999999999998
@test @inferred(link(dist, invlink(dist, copy(y)))) y atol = 0.5
@test y_reconstructed y atol = 0.5
else
@test @inferred(link(dist, invlink(dist, copy(y)))) y atol = 1e-9
@test y_reconstructed y atol = 1e-9
end
if dist isa SimplexDistribution
# This should probably be exact.
Expand Down

0 comments on commit a17b734

Please sign in to comment.