Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vector-version for PDBijector #271

Merged
merged 27 commits into from
Jun 20, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
14f053d
initial work on PDVecBijector
torfjelde Jun 17, 2023
3396d33
added output_length and output_size to compute output, well, leengths
torfjelde Jun 17, 2023
57988a2
added tests for size of transformed dist using VcCorrBijector
torfjelde Jun 17, 2023
6b65c75
use already constructed transfrormation
torfjelde Jun 17, 2023
367f261
TransformedDistribution should now also have correct variate form
torfjelde Jun 17, 2023
fcee1fe
added proper variateform handling for VecCholeskyBijector too
torfjelde Jun 17, 2023
bc38e64
Apply suggestions from code review
torfjelde Jun 17, 2023
977f39b
added output_size impl for Reshape too
torfjelde Jun 17, 2023
2d27739
added output_size for PDVecBijector annd tests
torfjelde Jun 18, 2023
3194e17
made bijector for PD distributions use PDVecBijcetor
torfjelde Jun 18, 2023
42209be
bump minor version
torfjelde Jun 18, 2023
ee550b2
Update src/bijectors/pd.jl
torfjelde Jun 18, 2023
7867bc6
move utilities from bijectors/corr.jl to utils.jl
torfjelde Jun 18, 2023
1424e2c
fixed Tracker for PD matrices
torfjelde Jun 18, 2023
4beb7a6
Apply suggestions from code review
torfjelde Jun 18, 2023
2885937
fix for matrix AD tests
torfjelde Jun 18, 2023
c92af34
Merge branch 'master' into torfjelde/pd-vec
torfjelde Jun 18, 2023
d6faf97
Merge branch 'master' into torfjelde/pd-vec
torfjelde Jun 19, 2023
5a5ce4a
bumped patch version
torfjelde Jun 19, 2023
0677529
revert patch version
torfjelde Jun 19, 2023
c78c80e
Apply suggestions from code review
torfjelde Jun 20, 2023
70ebe5b
Update src/utils.jl
torfjelde Jun 20, 2023
eb63c2b
removed unnecessary hacks for importing chainrules rule into ReverseDiff
torfjelde Jun 20, 2023
85d188c
markk triu_mask as non-differentiable
torfjelde Jun 20, 2023
35e38b7
shiften some methods around to help with readability
torfjelde Jun 20, 2023
fa2000b
removed redundant wrap_chainrules_output in BijectorsReverseDiffExt
torfjelde Jun 20, 2023
0f04fb5
renamed confusing name in pd tests
torfjelde Jun 20, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions ext/BijectorsTrackerExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,27 @@ end

# NOTE: Probably doesn't work in complete generality.
wrap_chainrules_output(x) = x
function wrap_chainrules_output(x::ChainRulesCore.Thunk)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this covered by the tests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it isn't! But do we want to? 😕
I just started writing a test, and the realized it means I have to move wrap_chainrules__output to Bijectors.jl itself rather than as an extension. But it's only really used for Tracker (it is also used for one part in ReverseDiffjl, but this can be dropped in favor of the macro that has been added to ReverseDiff.jl).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No it isn't!

But does that mean it's not needed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, well, yes 🙃

Also, just realized, the reason why it's probably not there is because this will only be called from Tracker.@grad function which of course should never use @thunk 😬 I'll remove it 👍

return wrap_chainrules_output(ChainRulesCore.unthunk(x))
end
wrap_chainrules_output(x::ChainRulesCore.AbstractZero) = nothing
wrap_chainrules_output(x::Tuple) = map(wrap_chainrules_output, x)

# `update_triu_from_vec`
function Bijectors.update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)
return track(Bijectors.update_triu_from_vec, vals, k, dim)
end

@grad function Bijectors.update_triu_from_vec(vals::TrackedVector{<:Real}, k::Int, dim::Int)
# HACK: This doesn't support higher order!
y, dy = ChainRulesCore.rrule(Bijectors.update_triu_from_vec, data(vals), k, dim)
return y, (wrap_chainrules_output ∘ Base.tail ∘ dy)
end

Bijectors.upper_triangular(A::TrackedMatrix) = track(Bijectors.upper_triangular, A)
@grad function Bijectors.upper_triangular(A::AbstractMatrix)
Ad = data(A)
return Bijectors.upper_triangular(Ad), Δ -> (Bijectors.upper_triangular(Δ),)
end

end
91 changes: 0 additions & 91 deletions src/bijectors/corr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,97 +89,6 @@ function logabsdetjac(b::CorrBijector, X::AbstractMatrix{<:Real})
return -logabsdetjac(inverse(b), (b(X)))
end

"""
triu_mask(X::AbstractMatrix, k::Int)

Return a mask for elements of `X` above the `k`th diagonal.
"""
function triu_mask(X::AbstractMatrix, k::Int)
# Ensure that we're working with a square matrix.
LinearAlgebra.checksquare(X)

# Using `similar` allows us to respect device of array, etc., e.g. `CuArray`.
m = similar(X, Bool)
return triu(.~m .| m, k)
end

triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)]

function update_triu_from_vec!(
vals::AbstractVector{<:Real}, k::Int, X::AbstractMatrix{<:Real}
)
# Ensure that we're working with one-based indexing.
# `triu` requires this too.
LinearAlgebra.require_one_based_indexing(X)

# Set the values.
idx = 1
m, n = size(X)
for j in 1:n
for i in 1:min(j - k, m)
X[i, j] = vals[idx]
idx += 1
end
end

return X
end

function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int)
X = similar(vals, dim, dim)
# TODO: Do we need this?
X .= 0
return update_triu_from_vec!(vals, k, X)
end

function ChainRulesCore.rrule(
::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int
)
function update_triu_from_vec_pullback(ΔX)
return (
ChainRulesCore.NoTangent(),
triu_to_vec(ChainRulesCore.unthunk(ΔX), k),
ChainRulesCore.NoTangent(),
ChainRulesCore.NoTangent(),
)
end
return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback
end

# n * (n - 1) / 2 = d
# ⟺ n^2 - n - 2d = 0
# ⟹ n = (1 + sqrt(1 + 8d)) / 2
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2

"""
triu1_to_vec(X::AbstractMatrix{<:Real})

Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector.
"""
triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1)

inverse(::typeof(triu1_to_vec)) = vec_to_triu1

"""
vec_to_triu1(x::AbstractVector{<:Real})

Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`.
"""
function vec_to_triu1(x::AbstractVector)
n = _triu1_dim_from_length(length(x))
X = update_triu_from_vec(x, 1, n)
return upper_triangular(X)
end

inverse(::typeof(vec_to_triu1)) = triu1_to_vec

function vec_to_triu1_row_index(idx)
# Assumes that vector was saved in a column-major order
# and that vector is one-based indexed.
M = _triu1_dim_from_length(idx - 1)
return idx - (M * (M - 1) ÷ 2)
end

"""
VecCorrBijector <: Bijector

Expand Down
45 changes: 45 additions & 0 deletions src/bijectors/pd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,48 @@ end
function with_logabsdet_jacobian(b::PDBijector, X)
return transform(b, X), logabsdetjac(b, X)
end

struct PDVecBijector <: Bijector end

function _triu_dim_from_length(d)
harisorgn marked this conversation as resolved.
Show resolved Hide resolved
# (n^2 + n) / 2 = d
# n² + n - 2d = 0
# n = (-1 + sqrt(1 + 8d)) / 2
return (-1 + isqrt(1 + 8 * d)) ÷ 2
end

"""
vec_to_triu(x::AbstractVector{<:Real})

Constructs a matrix from a vector `x` by filling the upper triangle.
"""
function vec_to_triu(x::AbstractVector)
n = _triu_dim_from_length(length(x))
X = update_triu_from_vec(x, 0, n)
return upper_triangular(X)
end

transform(::PDVecBijector, X::AbstractMatrix{<:Real}) = pd_vec_link(X)
pd_vec_link(X) = triu_to_vec(transpose(pd_link(X)), 0)

function transform(::Inverse{PDVecBijector}, y::AbstractVector{<:Real})
Y = permutedims(vec_to_triu(y))
return transform(inverse(PDBijector()), Y)
end

logabsdetjac(::PDVecBijector, X::AbstractMatrix{<:Real}) = logabsdetjac(PDBijector(), X)

function with_logabsdet_jacobian(b::PDVecBijector, X)
return transform(b, X), logabsdetjac(b, X)
end

function output_size(::PDVecBijector, sz::Tuple{Int,Int})
n = first(sz)
d = (n^2 + n) ÷ 2
return (d,)
end

function output_size(::Inverse{PDVecBijector}, sz::Tuple{Int})
n = _triu_dim_from_length(first(sz))
return (n, n)
end
4 changes: 2 additions & 2 deletions src/transformed_distribution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ bijector(d::BoundedDistribution) = bijector_bounded(d)
const LowerboundedDistribution = Union{Pareto,Levy}
bijector(d::LowerboundedDistribution) = bijector_lowerbounded(d)

bijector(d::PDMatDistribution) = PDBijector()
bijector(d::MatrixBeta) = PDBijector()
bijector(d::PDMatDistribution) = PDVecBijector()
bijector(d::MatrixBeta) = PDVecBijector()

bijector(d::LKJ) = VecCorrBijector()
bijector(d::LKJCholesky) = VecCholeskyBijector(d.uplo)
Expand Down
91 changes: 91 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,94 @@ cholesky_factor(X::AbstractMatrix) = cholesky_factor(cholesky(Hermitian(X)))
cholesky_factor(X::Cholesky) = X.U
cholesky_factor(X::UpperTriangular) = X
cholesky_factor(X::LowerTriangular) = X

"""
triu_mask(X::AbstractMatrix, k::Int)

Return a mask for elements of `X` above the `k`th diagonal.
"""
function triu_mask(X::AbstractMatrix, k::Int)
# Ensure that we're working with a square matrix.
LinearAlgebra.checksquare(X)

# Using `similar` allows us to respect device of array, etc., e.g. `CuArray`.
m = similar(X, Bool)
return triu(.~m .| m, k)
torfjelde marked this conversation as resolved.
Show resolved Hide resolved
end

triu_to_vec(X::AbstractMatrix{<:Real}, k::Int) = X[triu_mask(X, k)]

function update_triu_from_vec!(
vals::AbstractVector{<:Real}, k::Int, X::AbstractMatrix{<:Real}
)
# Ensure that we're working with one-based indexing.
# `triu` requires this too.
LinearAlgebra.require_one_based_indexing(X)

# Set the values.
idx = 1
m, n = size(X)
for j in 1:n
for i in 1:min(j - k, m)
X[i, j] = vals[idx]
idx += 1
end
end

return X
end

function update_triu_from_vec(vals::AbstractVector{<:Real}, k::Int, dim::Int)
X = similar(vals, dim, dim)
# TODO: Do we need this?
fill!(X, 0)
return update_triu_from_vec!(vals, k, X)
end

function ChainRulesCore.rrule(
::typeof(update_triu_from_vec), x::AbstractVector{<:Real}, k::Int, dim::Int
)
function update_triu_from_vec_pullback(ΔX)
return (
ChainRulesCore.NoTangent(),
triu_to_vec(ChainRulesCore.unthunk(ΔX), k),
ChainRulesCore.NoTangent(),
ChainRulesCore.NoTangent(),
)
end
return update_triu_from_vec(x, k, dim), update_triu_from_vec_pullback
end

# n * (n - 1) / 2 = d
# ⟺ n^2 - n - 2d = 0
# ⟹ n = (1 + sqrt(1 + 8d)) / 2
_triu1_dim_from_length(d) = (1 + isqrt(1 + 8d)) ÷ 2

"""
triu1_to_vec(X::AbstractMatrix{<:Real})

Extracts elements from upper triangle of `X` with offset `1` and returns them as a vector.
"""
triu1_to_vec(X::AbstractMatrix) = triu_to_vec(X, 1)

inverse(::typeof(triu1_to_vec)) = vec_to_triu1

"""
vec_to_triu1(x::AbstractVector{<:Real})

Constructs a matrix from a vector `x` by filling the upper triangle with offset `1`.
"""
function vec_to_triu1(x::AbstractVector)
n = _triu1_dim_from_length(length(x))
X = update_triu_from_vec(x, 1, n)
return upper_triangular(X)
end

inverse(::typeof(vec_to_triu1)) = triu1_to_vec

function vec_to_triu1_row_index(idx)
# Assumes that vector was saved in a column-major order
# and that vector is one-based indexed.
M = _triu1_dim_from_length(idx - 1)
return idx - (M * (M - 1) ÷ 2)
end
42 changes: 33 additions & 9 deletions test/bijectors/pd.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,37 @@
using Bijectors, DistributionsAD, LinearAlgebra, Test
using Bijectors: PDBijector
using Bijectors: PDBijector, PDVecBijector

@testset "PDBijector" begin
d = 5
b = PDBijector()
dist = Wishart(d, Matrix{Float64}(I, d, d))
x = rand(dist)
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian`
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0.
# Hence, we disable those tests.
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false)
for d in [2, 5]
b = PDBijector()
dist = Wishart(d, Matrix{Float64}(I, d, d))
x = rand(dist)
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian`
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0.
# Hence, we disable those tests.
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false)
end
end

@testset "PDVecBijector" begin
for d in [2, 5]
b = PDVecBijector()
dist = Wishart(d, Matrix{Float64}(I, d, d))
x = rand(dist)
y = b(x)

# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian`
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0.
# Hence, we disable those tests.
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false)

# Check that output sizes are computed correctly.
tdist = transformed(dist, b)
@test length(tdist) == length(y)
@test tdist isa MultivariateDistribution

dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(b))
harisorgn marked this conversation as resolved.
Show resolved Hide resolved
@test size(dist_unconstrained) == size(x)
@test dist_unconstrained isa MatrixDistribution
end
end
5 changes: 1 addition & 4 deletions test/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -183,15 +183,12 @@ end

x = rand(dist)
x = x + x' + 2I
lowerinds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[1] >= I[2]
]
upperinds = [
LinearIndices(size(x))[I] for I in CartesianIndices(size(x)) if I[2] >= I[1]
]
logpdf_turing = logpdf_with_trans(dist, x, true)
J = ForwardDiff.jacobian(x -> link(dist, x), x)
J = J[lowerinds, upperinds]
J = J[:, upperinds]
@test logpdf(dist, x) - _logabsdet(J) ≈ logpdf_turing
end
end
Expand Down