diff --git a/ext/BijectorsReverseDiffExt.jl b/ext/BijectorsReverseDiffExt.jl index ef0cfebd..7a57f3a2 100644 --- a/ext/BijectorsReverseDiffExt.jl +++ b/ext/BijectorsReverseDiffExt.jl @@ -250,32 +250,14 @@ end end # `OrderedBijector` -function _transform_ordered(y::Union{TrackedVector,TrackedMatrix}) - return track(_transform_ordered, y) -end -@grad function _transform_ordered(y::AbstractVecOrMat) - x, dx = ChainRulesCore.rrule(_transform_ordered, value(y)) - return x, (wrap_chainrules_output ∘ Base.tail ∘ dx) -end - -function _transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix}) - return track(_transform_inverse_ordered, x) -end -@grad function _transform_inverse_ordered(x::AbstractVecOrMat) - y, dy = ChainRulesCore.rrule(_transform_inverse_ordered, value(x)) - return y, (wrap_chainrules_output ∘ Base.tail ∘ dy) -end +@grad_from_chainrules _transform_ordered(y::Union{TrackedVector,TrackedMatrix}) +@grad_from_chainrules _transform_inverse_ordered(x::Union{TrackedVector,TrackedMatrix}) @grad_from_chainrules update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) @grad_from_chainrules _link_chol_lkj(x::TrackedMatrix) @grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector) -# NOTE: Probably doesn't work in complete generality. -wrap_chainrules_output(x) = x -wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing -wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) - if VERSION <= v"1.8.0-DEV.1526" # HACK: This dispatch does not wrap X in Hermitian before calling cholesky. 
# cholesky does not work with AbstractMatrix in julia versions before the compared one, diff --git a/ext/BijectorsTrackerExt.jl b/ext/BijectorsTrackerExt.jl index bc12c46d..b44cf3a3 100644 --- a/ext/BijectorsTrackerExt.jl +++ b/ext/BijectorsTrackerExt.jl @@ -532,4 +532,21 @@ wrap_chainrules_output(x) = x wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x) +# `update_triu_from_vec` +function Bijectors.update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) + return track(Bijectors.update_triu_from_vec, vals, k, dim) +end + +@grad function Bijectors.update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int) + # HACK: This doesn't support higher order! + y, dy = ChainRulesCore.rrule(Bijectors.update_triu_from_vec, data(vals), k, dim) + return y, (wrap_chainrules_output ∘ Base.tail ∘ dy) +end + +Bijectors.upper_triangular(A::TrackedMatrix) = track(Bijectors.upper_triangular, A) +@grad function Bijectors.upper_triangular(A::AbstractMatrix) + Ad = data(A) + return Bijectors.upper_triangular(Ad), Δ -> (Bijectors.upper_triangular(Δ),) +end + end diff --git a/src/bijectors/corr.jl b/src/bijectors/corr.jl index a4ed4740..0ec6297f 100644 --- a/src/bijectors/corr.jl +++ b/src/bijectors/corr.jl @@ -89,97 +89,6 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real}) return -logabsdetjac(inverse(b), (b(X))) end -""" - triu_mask(X::AbstractMatrix, k::Int) - -Return a mask for elements of `X` above the `k`th diagonal. -""" -function triu_mask(X::AbstractMatrix, k::Int) - # Ensure that we're working with a square matrix. - LinearAlgebra.checksquare(X) - - # Using `similar` allows us to respect device of array, etc., e.g. `CuArray`. 
- m = similar(X, Bool) - return triu(.~m .| m, k) -end - -triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)] - -function update_triu_from_vec!( - vals::AbstractVector{<:Real}, k::Int, X::AbstractMatrix{<:Real} -) - # Ensure that we're working with one-based indexing. - # `triu` requires this too. - LinearAlgebra.require_one_based_indexing(X) - - # Set the values. - idx = 1 - m, n = size(X) - for j in 1:n - for i in 1:min(j - k, m) - X[i, j] = vals[idx] - idx += 1 - end - end - - return X -end - -function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int) - X = similar(vals, dim, dim) - # TODO: Do we need this? - X .= 0 - return update_triu_from_vec!(vals, k, X) -end - -function ChainRulesCore.rrule( - ::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int -) - function update_triu_from_vec_pullback(ΔX) - return ( - ChainRulesCore.NoTangent(), - triu_to_vec(ChainRulesCore.unthunk(ΔX), k), - ChainRulesCore.NoTangent(), - ChainRulesCore.NoTangent(), - ) - end - return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback -end - -# n * (n - 1) / 2 = d -# ⟺ n^2 - n - 2d = 0 -# ⟹ n = (1 + sqrt(1 + 8d)) / 2 -_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 - -""" - triu1_to_vec(X::AbstractMatrix{<:Real}) - -Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector. -""" -triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1) - -inverse(::typeof(triu1_to_vec)) = vec_to_triu1 - -""" - vec_to_triu1(x::AbstractVector{<:Real}) - -Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`. -""" -function vec_to_triu1(x::AbstractVector) - n = _triu1_dim_from_length(length(x)) - X = update_triu_from_vec(x, 1, n) - return upper_triangular(X) -end - -inverse(::typeof(vec_to_triu1)) = triu1_to_vec - -function vec_to_triu1_row_index(idx) - # Assumes that vector was saved in a column-major order - # and that vector is one-based indexed. 
- M = _triu1_dim_from_length(idx - 1) - return idx - (M * (M - 1) ÷ 2) -end - """ VecCorrBijector <: Bijector diff --git a/src/bijectors/pd.jl b/src/bijectors/pd.jl index 6e74523e..4839b83e 100644 --- a/src/bijectors/pd.jl +++ b/src/bijectors/pd.jl @@ -40,3 +40,30 @@ end function with_logabsdet_jacobian(b::PDBijector, X) return transform(b, X), logabsdetjac(b, X) end + +struct PDVecBijector <: Bijector end + +transform(::PDVecBijector, X::AbstractMatrix{<:Real}) = pd_vec_link(X) +pd_vec_link(X) = triu_to_vec(transpose(pd_link(X))) + +function transform(::Inverse{PDVecBijector}, y::AbstractVector{<:Real}) + Y = permutedims(vec_to_triu(y)) + return transform(inverse(PDBijector()), Y) +end + +logabsdetjac(::PDVecBijector, X::AbstractMatrix{<:Real}) = logabsdetjac(PDBijector(), X) + +function with_logabsdet_jacobian(b::PDVecBijector, X) + return transform(b, X), logabsdetjac(b, X) +end + +function output_size(::PDVecBijector, sz::Tuple{Int,Int}) + n = first(sz) + d = (n^2 + n) ÷ 2 + return (d,) +end + +function output_size(::Inverse{PDVecBijector}, sz::Tuple{Int}) + n = _triu_dim_from_length(first(sz)) + return (n, n) +end diff --git a/src/transformed_distribution.jl b/src/transformed_distribution.jl index 04c3a559..9c14fbf5 100644 --- a/src/transformed_distribution.jl +++ b/src/transformed_distribution.jl @@ -82,8 +82,8 @@ bijector(d::BoundedDistribution) = bijector_bounded(d) const LowerboundedDistribution = Union{Pareto,Levy} bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d) -bijector(d::PDMatDistribution) = PDBijector() -bijector(d::MatrixBeta) = PDBijector() +bijector(d::PDMatDistribution) = PDVecBijector() +bijector(d::MatrixBeta) = PDVecBijector() bijector(d::LKJ) = VecCorrBijector() bijector(d::LKJCholesky) = VecCholeskyBijector(d.uplo) diff --git a/src/utils.jl b/src/utils.jl index 439a2d0c..f511cf7b 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -18,3 +18,123 @@ cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X))) 
cholesky_factor(X::Cholesky) = X.U cholesky_factor(X::UpperTriangular) = X cholesky_factor(X::LowerTriangular) = X + +""" + triu_mask(X::AbstractMatrix, k::Int) + +Return a `Bool` mask for the elements of `X` on or above the `k`th superdiagonal. +""" +function triu_mask(X::AbstractMatrix, k::Int) + # Ensure that we're working with a square matrix. + LinearAlgebra.checksquare(X) + + # Using `similar` allows us to respect device of array, etc., e.g. `CuArray`. + m = similar(X, Bool) + return triu!(fill!(m, true), k) +end + +ChainRulesCore.@non_differentiable triu_mask(X::AbstractMatrix, k::Int) + +_triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)] + +function update_triu_from_vec!( + vals::AbstractVector{<:Real}, k::Int, X::AbstractMatrix{<:Real} +) + # Ensure that we're working with one-based indexing. + # `triu` requires this too. + LinearAlgebra.require_one_based_indexing(X) + + # Set the values. + idx = 1 + m, n = size(X) + for j in 1:n + for i in 1:min(j - k, m) + X[i, j] = vals[idx] + idx += 1 + end + end + + return X +end + +function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int) + X = similar(vals, dim, dim) + # TODO: Do we need this? `similar` leaves `X` uninitialized. + fill!(X, 0) + return update_triu_from_vec!(vals, k, X) +end + +function ChainRulesCore.rrule( + ::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int +) + function update_triu_from_vec_pullback(ΔX) + return ( + ChainRulesCore.NoTangent(), + _triu_to_vec(ChainRulesCore.unthunk(ΔX), k), + ChainRulesCore.NoTangent(), + ChainRulesCore.NoTangent(), + ) + end + return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback +end + +# n * (n - 1) / 2 = d +# ⟺ n^2 - n - 2d = 0 +# ⟹ n = (1 + sqrt(1 + 8d)) / 2 +_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2 + +""" + triu1_to_vec(X::AbstractMatrix{<:Real}) + +Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector. 
+""" +triu1_to_vec(X::AbstractMatrix) = _triu_to_vec(X, 1) + +inverse(::typeof(triu1_to_vec)) = vec_to_triu1 + +""" + vec_to_triu1(x::AbstractVector{<:Real}) + +Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`. +""" +function vec_to_triu1(x::AbstractVector) + n = _triu1_dim_from_length(length(x)) + X = update_triu_from_vec(x, 1, n) + return upper_triangular(X) +end + +inverse(::typeof(vec_to_triu1)) = triu1_to_vec + +function vec_to_triu1_row_index(idx) + # Assumes that the vector was saved in column-major order + # and that the vector uses one-based indexing. + M = _triu1_dim_from_length(idx - 1) + return idx - (M * (M - 1) ÷ 2) +end + +# Triangular matrix including the main diagonal. + +# (n^2 + n) / 2 = d +# ⟺ n^2 + n - 2d = 0 +# ⟹ n = (-1 + sqrt(1 + 8d)) / 2 +_triu_dim_from_length(d) = (-1 + isqrt(1 + 8 * d)) ÷ 2 + +""" + triu_to_vec(X::AbstractMatrix{<:Real}) + +Extracts the upper triangle of `X`, including the main diagonal, and returns it as a vector. +""" +triu_to_vec(X::AbstractMatrix) = _triu_to_vec(X, 0) + +""" + vec_to_triu(x::AbstractVector{<:Real}) + +Constructs a matrix from a vector `x` by filling the upper triangle, including the main diagonal. +""" +function vec_to_triu(x::AbstractVector) + n = _triu_dim_from_length(length(x)) + X = update_triu_from_vec(x, 0, n) + return upper_triangular(X) +end + +inverse(::typeof(vec_to_triu)) = triu_to_vec diff --git a/test/bijectors/pd.jl b/test/bijectors/pd.jl index f429c5c2..2375d27a 100644 --- a/test/bijectors/pd.jl +++ b/test/bijectors/pd.jl @@ -1,13 +1,37 @@ using Bijectors, DistributionsAD, LinearAlgebra, Test -using Bijectors: PDBijector +using Bijectors: PDBijector, PDVecBijector @testset "PDBijector" begin - d = 5 - b = PDBijector() - dist = Wishart(d, Matrix{Float64}(I, d, d)) - x = rand(dist) - # NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian` - # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. - # Hence, we disable those tests. 
- test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) + for d in [2, 5] + b = PDBijector() + dist = Wishart(d, Matrix{Float64}(I, d, d)) + x = rand(dist) + # NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) + end +end + +@testset "PDVecBijector" begin + for d in [2, 5] + b = PDVecBijector() + dist = Wishart(d, Matrix{Float64}(I, d, d)) + x = rand(dist) + y = b(x) + + # NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian` + # used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. + # Hence, we disable those tests. + test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) + + # Check that output sizes are computed correctly. + tdist = transformed(dist, b) + @test length(tdist) == length(y) + @test tdist isa MultivariateDistribution + + dist_transformed = transformed(MvNormal(zeros(length(tdist)), I), inverse(b)) + @test size(dist_transformed) == size(x) + @test dist_transformed isa MatrixDistribution + end end diff --git a/test/transform.jl b/test/transform.jl index 62cee973..a123aec5 100644 --- a/test/transform.jl +++ b/test/transform.jl @@ -183,15 +183,12 @@ end x = rand(dist) x = x + x' + 2I - lowerinds = [ - LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[1] >= I[2] - ] upperinds = [ LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1] ] logpdf_turing = logpdf_with_trans(dist, x, true) J = ForwardDiff.jacobian(x -> link(dist, x), x) - J = J[lowerinds, upperinds] + J = J[:, upperinds] @test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing end end