From 36a6b41cbfee5e8b143dce2ffd932bc49924f5e3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 5 Jun 2023 22:15:33 +0200 Subject: [PATCH 01/22] Remove unused proj field --- src/bijectors/simplex.jl | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 1fbc28d2..c4529a1a 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -1,8 +1,7 @@ #################### # Simplex bijector # #################### -struct SimplexBijector{T} <: Bijector end -SimplexBijector() = SimplexBijector{true}() +struct SimplexBijector <: Bijector end with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x) From fa503cb2586835d7049a1124f287d9d577f3ef26 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 5 Jun 2023 22:16:40 +0200 Subject: [PATCH 02/22] Update simplex bijector calls --- src/bijectors/simplex.jl | 54 ++++++++++++++++------------------------ 1 file changed, 22 insertions(+), 32 deletions(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index c4529a1a..eee3af2e 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -8,10 +8,16 @@ with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b transform(b::SimplexBijector, x) = _simplex_bijector(x, b) transform!(b::SimplexBijector, y, x) = _simplex_bijector!(y, x, b) -_simplex_bijector(x::AbstractArray, b::SimplexBijector) = _simplex_bijector!(similar(x), x, b) +function _simplex_bijector(x::AbstractArray, b::SimplexBijector) + sz = size(x) + K = size(x, 1) + y = similar(x, Base.setindex(sz, K - 1, 1)) + _simplex_bijector!(y, x, b) + return y +end # Vector implementation. -function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where {proj} +function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector) K = length(x) @assert K > 1 "x needs to be of length greater than 1" T = eltype(x) @@ -26,18 +32,11 @@ function _simplex_bijector!(y, x::AbstractVector, ::SimplexBijector{proj}) where z = (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) y[k] = LogExpFunctions.logit(z) + log(T(K - k)) end - @inbounds sum_tmp += x[K - 1] - @inbounds if proj - y[K] = zero(T) - else - y[K] = one(T) - sum_tmp - x[K] - end - return y end # Matrix implementation. -function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where {proj} +function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector) K, N = size(X, 1), size(X, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(X) @@ -51,12 +50,6 @@ function _simplex_bijector!(Y, X::AbstractMatrix, ::SimplexBijector{proj}) where z = (X[k, n] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp) Y[k, n] = LogExpFunctions.logit(z) + log(T(K - k)) end - sum_tmp += X[K-1, n] - if proj - Y[K, n] = zero(T) - else - Y[K, n] = one(T) - sum_tmp - X[K, n] - end end return Y @@ -72,10 +65,16 @@ function transform!( return _simplex_inv_bijector!(x, y, ib.orig) end -_simplex_inv_bijector(y, b) = _simplex_inv_bijector!(similar(y), y, b) +function _simplex_inv_bijector(y, b) + sz = size(y) + K = sz[1] + 1 + x = similar(y, Base.setindex(sz, K, 1)) + _simplex_inv_bijector!(x, y, b) + return x +end -function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) where {proj} - K = length(y) +function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector) + K = length(y) + 1 @assert K > 1 "x needs to be of length greater than 1" T = eltype(y) ϵ = _eps(T) @@ -88,17 +87,12 @@ function _simplex_inv_bijector!(x, y::AbstractVector, b::SimplexBijector{proj}) x[k] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1) end @inbounds sum_tmp += x[K - 1] - @inbounds if proj - x[K] = _clamp(one(T) - sum_tmp, 0, 1) - else - x[K] = _clamp(one(T) - sum_tmp - y[K], 0, 1) - end - + x[K] = _clamp(one(T) - sum_tmp, 0, 1) return x end -function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) where {proj} - K, N = size(Y, 1), size(Y, 2) +function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector) + K, N = size(Y, 1) + 1, size(Y, 2) @assert K > 1 "x needs to be of length greater than 1" T = eltype(Y) ϵ = _eps(T) @@ -111,11 +105,7 @@ function _simplex_inv_bijector!(X, Y::AbstractMatrix, b::SimplexBijector{proj}) X[k, n] = _clamp(((one(T) + ϵ) - sum_tmp) / (one(T) - 2ϵ) * z - ϵ, 0, 1) end sum_tmp += X[K - 1, n] - if proj - X[K, n] = _clamp(one(T) - sum_tmp, 0, 1) - else - X[K, n] = _clamp(one(T) - sum_tmp - Y[K, n], 0, 1) - end + X[K, n] = _clamp(one(T) - sum_tmp, 0, 1) end return X From 4abaae9eedb9ffa69acba9d045fa67f9926a174f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 5 Jun 2023 22:17:15 +0200 Subject: [PATCH 03/22] Update simplex jacobian calls --- src/bijectors/simplex.jl | 46 ++++++++++++---------------------------- 1 file changed, 13 insertions(+), 33 deletions(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index eee3af2e..c1921b99 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -200,13 +200,10 @@ function simplex_logabsdetjac_gradient(x::AbstractMatrix) return g end -function simplex_link_jacobian( - x::AbstractVector{T}, - ::Val{proj}=Val(true), -) where {T<:Real, proj} +function simplex_link_jacobian(x::AbstractVector{T}) where {T<:Real} K = length(x) @assert K > 1 "x needs to be of length greater than 1" - dydxt = similar(x, length(x), length(x)) + dydxt = similar(x, K, K - 1) @inbounds dydxt .= 0 ϵ = _eps(T) sum_tmp = zero(T) @@ -223,16 +220,10 @@ function simplex_link_jacobian( dydxt[i,k] = (1/z + 1/(1-z)) * (x[k] + ϵ)*(one(T) - 2ϵ)/((one(T) + ϵ) - sum_tmp)^2 end end - @inbounds sum_tmp += x[K - 1] - @inbounds if !proj - @simd for i in 1:K - dydxt[i,K] = -1 - end - end - return UpperTriangular(dydxt)' + return dydxt' end -function jacobian(b::SimplexBijector{proj}, x::AbstractVector{T}) where {proj, T} - return simplex_link_jacobian(x, Val(proj)) +function jacobian(b::SimplexBijector, x::AbstractVector{T}) where {T} + return simplex_link_jacobian(x) end #= @@ -301,13 +292,10 @@ function add_simplex_link_adjoint!( end =# -function simplex_invlink_jacobian( - y::AbstractVector{T}, - ::Val{proj}=Val(true), -) where {T<:Real, proj} - K = length(y) +function simplex_invlink_jacobian(y::AbstractVector{T}) where {T<:Real} + K = length(y) + 1 @assert K > 1 "x needs to be of length greater than 1" - dxdy = similar(y, length(y), length(y)) + dxdy = similar(y, K, K - 1) @inbounds dxdy .= 0 ϵ = _eps(T) @@ -333,16 +321,8 @@ function simplex_invlink_jacobian( end end @inbounds sum_tmp += clamped_x - @inbounds if proj - unclamped_x = one(T) - sum_tmp - clamped_x = _clamp(unclamped_x, 0, 1) - else - unclamped_x = one(T) - sum_tmp - y[K] - clamped_x = _clamp(unclamped_x, 0, 1) - if unclamped_x == clamped_x - dxdy[K,K] = -1 - end - end + unclamped_x = one(T) - sum_tmp + clamped_x = _clamp(unclamped_x, 0, 1) @inbounds if unclamped_x == clamped_x for i in 1:K-1 @simd for j in i:K-1 @@ -350,11 +330,11 @@ function simplex_invlink_jacobian( end end end - return LowerTriangular(dxdy) + return dxdy end # jacobian -function jacobian(ib::Inverse{<:SimplexBijector{proj}}, y::AbstractVector{T}) where {proj, T} - return simplex_invlink_jacobian(y, Val(proj)) +function jacobian(ib::Inverse{<:SimplexBijector}, y::AbstractVector{T}) where {T} + return simplex_invlink_jacobian(y) end #= From 579c808f580844f339353d1b6a3ee657eb52c287 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 5 Jun 2023 22:17:37 +0200 Subject: [PATCH 04/22] Remove proj type entry --- test/interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index e1f8d0e4..55b73e83 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -139,7 +139,7 @@ end # verify against AD # similar to what we do in test/transform.jl for Dirichlet if dist isa Dirichlet - b = Bijectors.SimplexBijector{false}() + b = Bijectors.SimplexBijector() # HACK(torfjelde): Calling `rand(dist)` will sometimes lead to `[0.999..., 0.0]` # which in turn will lead to differences between `ForwardDiff.jacobian` # and `logabsdetjac` due to how we handle the boundary values in `SimplexBijector`. From 31102701358af7b1baf12dd2b3aeb87d44772cd4 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 5 Jun 2023 22:17:56 +0200 Subject: [PATCH 05/22] Compute logdetjac from square part of jacobian --- test/interface.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 55b73e83..52922590 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -147,8 +147,8 @@ end x = any(rand(dist) .> 0.9999) ? [0.0, 1.0][sortperm(rand(dist))] : rand(dist) y = b(x) @test b(param(x)) isa TrackedArray - @test log(abs(det(ForwardDiff.jacobian(b, x)))) ≈ logabsdetjac(b, x) - @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)))) ≈ logabsdetjac(inverse(b), y) + @test log(abs(det(ForwardDiff.jacobian(b, x)[1:end,1:end-1]))) ≈ logabsdetjac(b, x) + @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)[1:end-1,1:end]))) ≈ logabsdetjac(inverse(b), y) else b = bijector(dist) x = rand(dist) From f21d3286ac419242c39be4d112751c0625553597 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 5 Jun 2023 22:18:29 +0200 Subject: [PATCH 06/22] Increment minor version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 98f1908f..08b29211 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.12.5" +version = "0.13.0" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" From 7e2927afaa47cd94b84a659496d6217e2534fcf6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 6 Jun 2023 22:14:53 +0200 Subject: [PATCH 07/22] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/interface.jl | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 79b61ebf..11f6e092 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -168,10 +168,11 @@ end end y = b(x) @test b(param(x)) isa TrackedArray - @test log(abs(det(ForwardDiff.jacobian(b, x)[1:end,1:end-1]))) ≈ + @test log(abs(det(ForwardDiff.jacobian(b, x)[1:end, 1:(end - 1)]))) ≈ logabsdetjac(b, x) - @test log(abs(det(ForwardDiff.jacobian(inverse(b), y)[1:end-1,1:end]))) ≈ - logabsdetjac(inverse(b), y) + @test log( + abs(det(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), 1:end])) + ) ≈ logabsdetjac(inverse(b), y) else b = bijector(dist) x = rand(dist) From ba21df56fa4aff7120caddf5b38c797fa5964565 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 6 Jun 2023 23:32:52 +0200 Subject: [PATCH 08/22] Apply suggestions from code review Co-authored-by: David Widmann --- src/bijectors/simplex.jl | 6 ++---- test/interface.jl | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index fd841eca..e8acc5a6 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -203,8 +203,7 @@ end function simplex_link_jacobian(x::AbstractVector{T}) where {T<:Real} K = length(x) @assert K > 1 "x needs to be of length greater than 1" - dydxt = similar(x, K, K - 1) - @inbounds dydxt .= 0 + dydxt = fill!(similar(x, K, K - 1), 0) ϵ = _eps(T) sum_tmp = zero(T) @@ -297,8 +296,7 @@ end function simplex_invlink_jacobian(y::AbstractVector{T}) where {T<:Real} K = length(y) + 1 @assert K > 1 "x needs to be of length greater than 1" - dxdy = similar(y, K, K - 1) - @inbounds dxdy .= 0 + dxdy = fill!(similar(y, K, K - 1), 0) ϵ = _eps(T) @inbounds z = LogExpFunctions.logistic(y[1] - log(T(K - 1))) diff --git a/test/interface.jl b/test/interface.jl index 11f6e092..28264691 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -168,11 +168,9 @@ end end y = b(x) @test b(param(x)) isa TrackedArray - @test log(abs(det(ForwardDiff.jacobian(b, x)[1:end, 1:(end - 1)]))) ≈ + @test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1] ≈ logabsdetjac(b, x) - @test log( - abs(det(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), 1:end])) - ) ≈ logabsdetjac(inverse(b), y) + @test logabsdet(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), :])[1] ≈ logabsdetjac(inverse(b), y) else b = bijector(dist) x = rand(dist) From 3ce84bb0bdb420204212aef9f2a4c8efe496ed1f Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 6 Jun 2023 23:39:45 +0200 Subject: [PATCH 09/22] Update test/interface.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/interface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/interface.jl b/test/interface.jl index 28264691..d7061419 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -170,7 +170,8 @@ end @test b(param(x)) isa TrackedArray @test logabsdet(ForwardDiff.jacobian(b, x)[:, 1:(end - 1)])[1] ≈ logabsdetjac(b, x) - @test logabsdet(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), :])[1] ≈ logabsdetjac(inverse(b), y) + @test logabsdet(ForwardDiff.jacobian(inverse(b), y)[1:(end - 1), :])[1] ≈ + logabsdetjac(inverse(b), y) else b = bijector(dist) x = rand(dist) From 776e4af315769661b87616f9dbedb73072619c4b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 19:25:52 +0100 Subject: [PATCH 10/22] fixed link and invlink for SimplexBijector --- src/Bijectors.jl | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 46b31fb3..7f776b82 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -206,26 +206,11 @@ isdirichlet(::Distribution) = false # ∑xᵢ = 1 # ########### -function link(d::Dirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true)) where {proj} - return SimplexBijector{proj}()(x) -end - -function link_jacobian( - d::Dirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return jacobian(SimplexBijector{proj}(), x) -end +link(d::Dirichlet, x::AbstractVecOrMat{<:Real}) = SimplexBijector()(x) +link_jacobian(d::Dirichlet, x::AbstractVector{<:Real}) = jacobian(SimplexBijector(), x) -function invlink( - d::Dirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return inverse(SimplexBijector{proj}())(y) -end -function invlink_jacobian( - d::Dirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return jacobian(inverse(SimplexBijector{proj}()), y) -end +invlink(d::Dirichlet, y::AbstractVecOrMat{<:Real}) = inverse(SimplexBijector())(y) +invlink_jacobian(d::Dirichlet, y::AbstractVector{<:Real}) = jacobian(inverse(SimplexBijector()), y) ## Matrix From 921f818d880a33da0de4764fc9ce4b60364794ee Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 21:24:26 +0100 Subject: [PATCH 11/22] Update src/Bijectors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/Bijectors.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/Bijectors.jl b/src/Bijectors.jl index 7f776b82..e2c4bd0c 100644 --- a/src/Bijectors.jl +++ b/src/Bijectors.jl @@ -210,7 +210,9 @@ link(d::Dirichlet, x::AbstractVecOrMat{<:Real}) = SimplexBijector()(x) link_jacobian(d::Dirichlet, x::AbstractVector{<:Real}) = jacobian(SimplexBijector(), x) invlink(d::Dirichlet, y::AbstractVecOrMat{<:Real}) = inverse(SimplexBijector())(y) -invlink_jacobian(d::Dirichlet, y::AbstractVector{<:Real}) = jacobian(inverse(SimplexBijector()), y) +function invlink_jacobian(d::Dirichlet, y::AbstractVector{<:Real}) + return jacobian(inverse(SimplexBijector()), y) +end ## Matrix From d934cfa66033b8a8db5f1d91ba414fcd7def5b8f Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 21:28:47 +0100 Subject: [PATCH 12/22] super-hacky fix to size issue of TransformedDistribution --- src/transformed_distribution.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index eccbb64c..4d296556 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -101,8 +101,8 @@ end ############################## # size -Base.length(td::Transformed) = length(td.dist) -Base.size(td::Transformed) = size(td.dist) +Base.length(td::Transformed) = length(td.transform(rand(td.dist))) +Base.size(td::Transformed) = size(td.transform(rand(td.dist))) function logpdf(td::UnivariateTransformed, y::Real) x, logjac = with_logabsdet_jacobian(inverse(td.transform), y) From 8f39b0df2ad930faa84628558fd817cc8b5e6b30 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Sat, 17 Jun 2023 21:29:22 +0100 Subject: [PATCH 13/22] added fixme comment --- src/transformed_distribution.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 4d296556..dd6cfdc2 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -100,6 +100,7 @@ end # Distributions.jl interface # ############################## +# FIXME: Do this properly yah fool # size Base.length(td::Transformed) = length(td.transform(rand(td.dist))) Base.size(td::Transformed) = size(td.transform(rand(td.dist))) From 8efd243e0c0429df897c5d13904e56ccdf80a70d Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 01:48:19 +0100 Subject: [PATCH 14/22] removed redundant constructor for Stacked --- src/bijectors/stacked.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/bijectors/stacked.jl b/src/bijectors/stacked.jl index 73abf51c..6febe140 100644 --- a/src/bijectors/stacked.jl +++ b/src/bijectors/stacked.jl @@ -85,9 +85,6 @@ end end end -# Avoid mixing tuples and arrays. -Stacked(bs::Tuple, ranges::AbstractArray) = Stacked(collect(bs), ranges) - Functors.@functor Stacked (bs,) function Base.show(io::IO, b::Stacked) From 9c164332bd0e71b451e51f13110da372ab2c2da5 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 01:52:55 +0100 Subject: [PATCH 15/22] added implementation of output_size for SimplexBijector --- src/bijectors/simplex.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index e8acc5a6..97536061 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -3,6 +3,12 @@ #################### struct SimplexBijector <: Bijector end +output_size(::SimplexBijector, sz::Tuple{Int}) = (first(sz) - 1,) +output_size(::Inverse{SimplexBijector}, sz::Tuple{Int}) = (first(sz) + 1,) + +output_size(::SimplexBijector, sz::Tuple{Int,Int}) = Base.setindex(sz, first(sz) - 1, 1) +output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int}) = Base.setindex(sz, first(sz) + 1, 1) + with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x) transform(b::SimplexBijector, x) = _simplex_bijector(x, b) From 78f015ec56484e5c7b7d4c3ea7a5a7497b7616ce Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 02:16:19 +0100 Subject: [PATCH 16/22] Update src/bijectors/simplex.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/bijectors/simplex.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index 97536061..f237c307 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -7,7 +7,9 @@ output_size(::SimplexBijector, sz::Tuple{Int}) = (first(sz) - 1,) output_size(::Inverse{SimplexBijector}, sz::Tuple{Int}) = (first(sz) + 1,) output_size(::SimplexBijector, sz::Tuple{Int,Int}) = Base.setindex(sz, first(sz) - 1, 1) -output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int}) = Base.setindex(sz, first(sz) + 1, 1) +function output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int}) + return Base.setindex(sz, first(sz) + 1, 1) +end with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x) From 144df86206cfacdd70c27a329e83415b6f37965b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 02:22:01 +0100 Subject: [PATCH 17/22] fixed tests --- test/interface.jl | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/test/interface.jl b/test/interface.jl index 22960858..207adec1 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -421,35 +421,37 @@ end b = SimplexBijector() ib = inverse(b) - x = ib(randn(10)) + d_x = 10 + x = ib(randn(d_x - 1)) y = b(x) @test Bijectors.jacobian(b, x) ≈ ForwardDiff.jacobian(b, x) @test Bijectors.jacobian(ib, y) ≈ ForwardDiff.jacobian(ib, y) # Just some additional computation so we also ensure the pullbacks are the same - weights = randn(10) + weights_x = randn(d_x) + weights_y = randn(d_x - 1) # Tracker.jl x_tracked = Tracker.param(x) - z = sum(weights .* b(x_tracked)) + z = sum(weights_y .* b(x_tracked)) Tracker.back!(z) Δ_tracker = Tracker.grad(x_tracked) # ForwardDiff.jl - Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* b(z)), x) + Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_y .* b(z)), x) # Compare @test Δ_forwarddiff ≈ Δ_tracker # Tracker.jl y_tracked = Tracker.param(y) - z = sum(weights .* ib(y_tracked)) + z = sum(weights_x .* ib(y_tracked)) Tracker.back!(z) Δ_tracker = Tracker.grad(y_tracked) # ForwardDiff.jl - Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights .* ib(z)), y) + Δ_forwarddiff = ForwardDiff.gradient(z -> sum(weights_x .* ib(z)), y) @test Δ_forwarddiff ≈ Δ_tracker end From a8e6e21934f42624cd0e06d5e5e51f7988b49324 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 03:06:44 +0100 Subject: [PATCH 18/22] removed more references to old SimplexBijector code --- ext/BijectorsDistributionsADExt.jl | 24 ++++++++++++------------ test/transform.jl | 22 +++++++--------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/ext/BijectorsDistributionsADExt.jl b/ext/BijectorsDistributionsADExt.jl index c34284a1..f4e41046 100644 --- a/ext/BijectorsDistributionsADExt.jl +++ b/ext/BijectorsDistributionsADExt.jl @@ -79,26 +79,26 @@ Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = tr Bijectors.isdirichlet(::TuringDirichlet) = true function Bijectors.link( - d::TuringDirichlet, x::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return Bijectors.SimplexBijector{proj}()(x) + d::TuringDirichlet, x::AbstractVecOrMat{<:Real} +) + return Bijectors.SimplexBijector()(x) end function Bijectors.link_jacobian( - d::TuringDirichlet, x::AbstractVector{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return jacobian(Bijectors.SimplexBijector{proj}(), x) + d::TuringDirichlet, x::AbstractVector{<:Real} +) + return jacobian(Bijectors.SimplexBijector(), x) end function Bijectors.invlink( - d::TuringDirichlet, y::AbstractVecOrMat{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return inverse(Bijectors.SimplexBijector{proj}())(y) + d::TuringDirichlet, y::AbstractVecOrMat{<:Real} +) + return inverse(Bijectors.SimplexBijector())(y) end function Bijectors.invlink_jacobian( - d::TuringDirichlet, y::AbstractVector{<:Real}, ::Val{proj}=Val(true) -) where {proj} - return jacobian(inverse(Bijectors.SimplexBijector{proj}()), y) + d::TuringDirichlet, y::AbstractVector{<:Real} +) + return jacobian(inverse(Bijectors.SimplexBijector()), y) end Bijectors.ispd(::TuringWishart) = true diff --git a/test/transform.jl b/test/transform.jl index 00e2ea79..31e40548 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -146,7 +146,7 @@ end end logpdf_turing = logpdf_with_trans(dist, x, true) - J = ForwardDiff.jacobian(x -> link(dist, x, Val(false)), x) + J = ForwardDiff.jacobian(x -> link(dist, x), x) @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing # Issue #12 @@ -272,25 +272,17 @@ end x = rand(dist) y = @inferred(link(dist, x)) - f1 = x -> link(dist, x, Val(true)) - f2 = x -> link(dist, x, Val(false)) - g1 = y -> invlink(dist, y, Val(true)) - g2 = y -> invlink(dist, y, Val(false)) + f1 = x -> link(dist, x) + g1 = y -> invlink(dist, y) @test @aeq ForwardDiff.jacobian(f1, x) @inferred( - Bijectors.simplex_link_jacobian(x, Val(true)) - ) - @test @aeq ForwardDiff.jacobian(f2, x) @inferred( - Bijectors.simplex_link_jacobian(x, Val(false)) + Bijectors.simplex_link_jacobian(x) ) @test @aeq ForwardDiff.jacobian(g1, y) @inferred( - Bijectors.simplex_invlink_jacobian(y, Val(true)) - ) - @test @aeq ForwardDiff.jacobian(g2, y) @inferred( - Bijectors.simplex_invlink_jacobian(y, Val(false)) + Bijectors.simplex_invlink_jacobian(y) ) - @test @aeq Bijectors.simplex_link_jacobian(x, Val(false)) * - Bijectors.simplex_invlink_jacobian(y, Val(false)) I + @test @aeq Bijectors.simplex_link_jacobian(x) * + Bijectors.simplex_invlink_jacobian(y) end for i in 1:4 test_link_and_invlink() From 852c82633d8aade00c59fa278c794378c64bd576 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 11:05:44 +0100 Subject: [PATCH 19/22] fixed more dirichlet tests --- test/transform.jl | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/test/transform.jl b/test/transform.jl index 31e40548..079885a1 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -11,9 +11,12 @@ _logabsdet(x::AbstractArray) = logabsdet(x)[1] _logabsdet(x::Real) = log(abs(x)) # Generate a (vector / matrix of) random number(s). -_rand_real(::Real) = randn() -_rand_real(x) = randn(size(x)) -_rand_real(x, e) = (y = randn(size(x)); y[end] = e; y) +_rand_real(dist, ::Real) = randn() +function _rand_real(dist, x) + b = bijector(dist) + sz = Bijectors.output_size(b, size(x)) + return randn(sz) +end # Standard tests for all distributions involving a single-sample. function single_sample_tests(dist, jacobian) @@ -68,13 +71,13 @@ function single_sample_tests(dist) # Check that invlink maps back to the apppropriate constrained domain. @test all( isfinite, - logpdf.(Ref(dist), [invlink(dist, _rand_real(x, 0)) .+ ϵ for _ in 1:100]), + logpdf.(Ref(dist), [invlink(dist, _rand_real(dist, x)) .+ ϵ for _ in 1:100]), ) else # This should probably be exact. @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) @test all( - isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(y)) for _ in 1:100]) + isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(dist, y)) for _ in 1:100]) ) end end @@ -144,14 +147,17 @@ end else rand(dist) end - - logpdf_turing = logpdf_with_trans(dist, x, true) - J = ForwardDiff.jacobian(x -> link(dist, x), x) - @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing + # `Dirichlet` is no longer mapping between spaces of the same dimensionality, + # so the block below no longer works. + if !(dist isa Dirichlet) + logpdf_turing = logpdf_with_trans(dist, x, true) + J = ForwardDiff.jacobian(x -> link(dist, x), x) + @test logpdf(dist, x .+ ϵ) - _logabsdet(J) ≈ logpdf_turing + end # Issue #12 stepsize = 1e10 - dim = length(dist) + dim = Bijectors.output_length(bijector(dist), length(dist)) x = [ logpdf_with_trans( dist, @@ -237,8 +243,8 @@ end # julia> logpdf_with_trans(Dirichlet([1., 1., 1.]), [-1., -2., -3.], true, true) # -3.006450206744678 d = Dirichlet([1.0, 1.0, 1.0]) -r = [-1000.0, -1000.0, 0.0] -r2 = [-1.0, -2.0, 0.0] +r = [-1000.0, -1000.0] +r2 = [-1.0, -2.0] # test vector invlink dist = Dirichlet(ones(5)) @@ -282,7 +288,7 @@ end Bijectors.simplex_invlink_jacobian(y) ) @test @aeq Bijectors.simplex_link_jacobian(x) * - Bijectors.simplex_invlink_jacobian(y) + Bijectors.simplex_invlink_jacobian(y) I end for i in 1:4 test_link_and_invlink() From 97af4417f2ea83738f02bbae47d3f5f101602633 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 11:06:13 +0100 Subject: [PATCH 20/22] formatting --- test/transform.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/transform.jl b/test/transform.jl index 079885a1..62cee973 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -77,7 +77,8 @@ function single_sample_tests(dist) # This should probably be exact. @test logpdf(dist, x) == logpdf_with_trans(dist, x, false) @test all( - isfinite, logpdf.(Ref(dist), [invlink(dist, _rand_real(dist, y)) for _ in 1:100]) + isfinite, + logpdf.(Ref(dist), [invlink(dist, _rand_real(dist, y)) for _ in 1:100]), ) end end @@ -281,9 +282,7 @@ end f1 = x -> link(dist, x) g1 = y -> invlink(dist, y) - @test @aeq ForwardDiff.jacobian(f1, x) @inferred( - Bijectors.simplex_link_jacobian(x) - ) + @test @aeq ForwardDiff.jacobian(f1, x) @inferred(Bijectors.simplex_link_jacobian(x)) @test @aeq ForwardDiff.jacobian(g1, y) @inferred( Bijectors.simplex_invlink_jacobian(y) ) From 1f8a0f159704496ccc666328ab66b13c81baea9b Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 11:16:26 +0100 Subject: [PATCH 21/22] possilby fixed weird formatting complaints --- ext/BijectorsDistributionsADExt.jl | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) diff --git a/ext/BijectorsDistributionsADExt.jl b/ext/BijectorsDistributionsADExt.jl index f4e41046..8a553285 100644 --- a/ext/BijectorsDistributionsADExt.jl +++ b/ext/BijectorsDistributionsADExt.jl @@ -78,26 +78,18 @@ Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:Dirichlet}) = true Bijectors.isdirichlet(::VectorOfMultivariate{Continuous,<:TuringDirichlet}) = true Bijectors.isdirichlet(::TuringDirichlet) = true -function Bijectors.link( - d::TuringDirichlet, x::AbstractVecOrMat{<:Real} -) +function Bijectors.link(d::TuringDirichlet, x::AbstractVecOrMat{<:Real}) return Bijectors.SimplexBijector()(x) end -function Bijectors.link_jacobian( - d::TuringDirichlet, x::AbstractVector{<:Real} -) +function Bijectors.link_jacobian(d::TuringDirichlet, x::AbstractVector{<:Real}) return jacobian(Bijectors.SimplexBijector(), x) end -function Bijectors.invlink( - d::TuringDirichlet, y::AbstractVecOrMat{<:Real} -) +function Bijectors.invlink(d::TuringDirichlet, y::AbstractVecOrMat{<:Real}) return inverse(Bijectors.SimplexBijector())(y) end -function Bijectors.invlink_jacobian( - d::TuringDirichlet, y::AbstractVector{<:Real} -) +function Bijectors.invlink_jacobian(d::TuringDirichlet, y::AbstractVector{<:Real}) return jacobian(inverse(Bijectors.SimplexBijector()), y) end From 5d394bb3ff03216d4945d918dac0ddae333d08d7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 19 Jun 2023 15:42:50 +0100 Subject: [PATCH 22/22] Apply suggestions from code review Co-authored-by: David Widmann --- src/bijectors/simplex.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bijectors/simplex.jl b/src/bijectors/simplex.jl index f237c307..c5bb4ef6 100644 --- a/src/bijectors/simplex.jl +++ b/src/bijectors/simplex.jl @@ -6,9 +6,9 @@ struct SimplexBijector <: Bijector end output_size(::SimplexBijector, sz::Tuple{Int}) = (first(sz) - 1,) output_size(::Inverse{SimplexBijector}, sz::Tuple{Int}) = (first(sz) + 1,) -output_size(::SimplexBijector, sz::Tuple{Int,Int}) = Base.setindex(sz, first(sz) - 1, 1) +output_size(::SimplexBijector, sz::Tuple{Int,Int}) = (first(sz) - 1, last(sz)) function output_size(::Inverse{SimplexBijector}, sz::Tuple{Int,Int}) - return Base.setindex(sz, first(sz) + 1, 1) + return (first(sz) + 1, last(sz)) end with_logabsdet_jacobian(b::SimplexBijector, x) = transform(b, x), logabsdetjac(b, x)