Skip to content

Commit

Permalink
Fix for #287 (#288)
Browse files Browse the repository at this point in the history
* added has_constant_bijector and made bijector of product distributions
return the identity whenever possible

* no need to limit ourselves to identity for constant bijectors

* no need to limit ourselves to identity for Product

* bump patch version

* Update test/bijectors/ordered.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* fixed tests

* attempt at fix for ordered MvTDist test

* dispatch on GenericMvTDist instead of TDist

* Update test/bijectors/ordered.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

* added some tests for MvTDist

* make elementwise acting on identity return identity

* fixed bug in error

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
  • Loading branch information
torfjelde and github-actions[bot] authored Sep 4, 2023
1 parent 9419d4f commit 04b79dd
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 18 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.13.6"
version = "0.13.7"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
3 changes: 2 additions & 1 deletion src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ In the case where `f::ComposedFunction`, the result is
`Base.Fix1(broadcast, f)`.
"""
elementwise(f) = Base.Fix1(broadcast, f)
elementwise(f::typeof(identity)) = identity
# TODO: This is makes dispatching quite a bit easier, but uncertain if this is really
# the way to go.
function elementwise(f::ComposedFunction)
Expand Down Expand Up @@ -91,7 +92,7 @@ function transform(t::Transform, x)
res = with_logabsdet_jacobian(t, x)
if res isa ChangesOfVariables.NoLogAbsDetJacobian
error(
"`transform` not implemented for $(typeof(b)); implement `transform` and/or `with_logabsdet_jacobian`.",
"`transform` not implemented for $(typeof(f)); implement `transform` and/or `with_logabsdet_jacobian`.",
)
end

Expand Down
35 changes: 34 additions & 1 deletion src/transformed_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,43 @@ function bijector(td::TransformedDistribution)
b = bijector(td.dist)
return b === identity ? inverse(td.transform) : b inverse(td.transform)
end

"""
has_constant_bijector(dist_type::Type)
Returns `true` if the distribution type `dist_type` has a constant bijector,
i.e. the return-value of [`bijector`](@ref) does not depend on runtime information.
"""
has_constant_bijector(d::Type) = false
has_constant_bijector(d::Type{<:Normal}) = true
has_constant_bijector(d::Type{<:Distributions.AbstractMvNormal}) = true
has_constant_bijector(d::Type{<:Distributions.AbstractMvLogNormal}) = true
has_constant_bijector(d::Type{<:TDist}) = true
has_constant_bijector(d::Type{<:Distributions.GenericMvTDist}) = true
has_constant_bijector(d::Type{<:PositiveDistribution}) = true
has_constant_bijector(d::Type{<:SimplexDistribution}) = true
has_constant_bijector(d::Type{<:KSOneSided}) = true
function has_constant_bijector(::Type{<:Product{Continuous,D}}) where {D}
return has_constant_bijector(D)
end

# Container distributions.
bijector(d::DiscreteUnivariateDistribution) = identity
bijector(d::DiscreteMultivariateDistribution) = identity
bijector(d::ContinuousUnivariateDistribution) = TruncatedBijector(minimum(d), maximum(d))
bijector(d::Product{Discrete}) = identity
function bijector(d::Product{Continuous})
return TruncatedBijector(_minmax(d.v)...)
D = eltype(d.v)
return if has_constant_bijector(D)
elementwise(bijector(d.v[1]))
else
# FIXME: This is not great. Should use something like
# `Stacked(map(bijector, d.v))` instead.
# TODO: Specialize. F.ex. for FillArrays.jl we can do much better.
TruncatedBijector(_minmax(d.v)...)
end
end

@generated function _minmax(d::AbstractArray{T}) where {T}
try
min, max = minimum(T), maximum(T)
Expand All @@ -63,9 +93,12 @@ end
end
end

# Specialized implementations.
bijector(d::Normal) = identity
bijector(d::Distributions.AbstractMvNormal) = identity
bijector(d::Distributions.AbstractMvLogNormal) = elementwise(log)
bijector(d::TDist) = identity
bijector(d::Distributions.GenericMvTDist) = identity
bijector(d::PositiveDistribution) = elementwise(log)
bijector(d::SimplexDistribution) = SimplexBijector()
bijector(d::KSOneSided) = Logit(zero(eltype(d)), one(eltype(d)))
Expand Down
33 changes: 18 additions & 15 deletions test/bijectors/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,23 @@ using LinearAlgebra
end

@testset "ordered" begin
d = MvNormal(1:5, Diagonal(6:10))
d_ordered = ordered(d)
@test d_ordered isa Bijectors.TransformedDistribution
@test d_ordered.dist === d
@test d_ordered.transform isa OrderedBijector
y = randn(5)
x = inverse(bijector(d_ordered))(y)
@test issorted(x)
@testset "$d" for d in [
MvNormal(1:5, Diagonal(6:10)),
MvTDist(1, collect(1.0:5), Matrix(I(5))),
product_distribution(fill(Normal(), 5)),
product_distribution(fill(TDist(1), 5)),
]
d_ordered = ordered(d)
@test d_ordered isa Bijectors.TransformedDistribution
@test d_ordered.dist === d
@test d_ordered.transform isa OrderedBijector
y = randn(5)
x = inverse(bijector(d_ordered))(y)
@test issorted(x)
end

d = Product(fill(Normal(), 5))
# currently errors because `bijector(Product(fill(Normal(), 5)))` is not an `Identity`
@test_broken ordered(d) isa Bijectors.TransformedDistribution

# non-Identity bijector is not supported
d = Dirichlet(ones(5))
@test_throws ArgumentError ordered(d)
@testset "non-identity bijector is not supported" begin
d = Dirichlet(ones(5))
@test_throws ArgumentError ordered(d)
end
end
2 changes: 2 additions & 0 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ end
MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
Dirichlet([1000 * one(Float64), eps(Float64)]),
Dirichlet([eps(Float64), 1000 * one(Float64)]),
MvTDist(1, randn(10), Matrix(Diagonal(exp.(randn(10))))),
transformed(MvNormal(randn(10), Diagonal(exp.(randn(10))))),
transformed(MvLogNormal(MvNormal(randn(10), Diagonal(exp.(randn(10)))))),
transformed(reshape(product_distribution(fill(InverseGamma(2, 3), 6)), 2, 3)),
Expand Down Expand Up @@ -200,6 +201,7 @@ end
TuringInverseWishart(v, S),
LKJ(3, 1.0),
reshape(MvNormal(zeros(6), I), 2, 3),
product_distribution(fill(InverseGamma(2, 3), 6)),
]

for dist in matrix_dists
Expand Down

2 comments on commit 04b79dd

@torfjelde
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/90817

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.13.7 -m "<description of version>" 04b79dd46eca8cea2f988348c47bd5e720a2b9a4
git push origin v0.13.7

Please sign in to comment.