-
Notifications
You must be signed in to change notification settings - Fork 33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Vector-version for PDBijector
#271
Merged
Merged
Changes from 22 commits
Commits
Show all changes
27 commits
Select commit
Hold shift + click to select a range
14f053d
initial work on PDVecBijector
torfjelde 3396d33
added output_length and output_size to compute output, well, leengths
torfjelde 57988a2
added tests for size of transformed dist using VcCorrBijector
torfjelde 6b65c75
use already constructed transfrormation
torfjelde 367f261
TransformedDistribution should now also have correct variate form
torfjelde fcee1fe
added proper variateform handling for VecCholeskyBijector too
torfjelde bc38e64
Apply suggestions from code review
torfjelde 977f39b
added output_size impl for Reshape too
torfjelde 2d27739
added output_size for PDVecBijector annd tests
torfjelde 3194e17
made bijector for PD distributions use PDVecBijcetor
torfjelde 42209be
bump minor version
torfjelde ee550b2
Update src/bijectors/pd.jl
torfjelde 7867bc6
move utilities from bijectors/corr.jl to utils.jl
torfjelde 1424e2c
fixed Tracker for PD matrices
torfjelde 4beb7a6
Apply suggestions from code review
torfjelde 2885937
fix for matrix AD tests
torfjelde c92af34
Merge branch 'master' into torfjelde/pd-vec
torfjelde d6faf97
Merge branch 'master' into torfjelde/pd-vec
torfjelde 5a5ce4a
bumped patch version
torfjelde 0677529
revert patch version
torfjelde c78c80e
Apply suggestions from code review
torfjelde 70ebe5b
Update src/utils.jl
torfjelde eb63c2b
removed unnecessary hacks for importing chainrules rule into ReverseDiff
torfjelde 85d188c
markk triu_mask as non-differentiable
torfjelde 35e38b7
shiften some methods around to help with readability
torfjelde fa2000b
removed redundant wrap_chainrules_output in BijectorsReverseDiffExt
torfjelde 0f04fb5
renamed confusing name in pd tests
torfjelde File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,37 @@ | ||
using Bijectors, DistributionsAD, LinearAlgebra, Test | ||
using Bijectors: PDBijector | ||
using Bijectors: PDBijector, PDVecBijector | ||
|
||
@testset "PDBijector" begin | ||
d = 5 | ||
b = PDBijector() | ||
dist = Wishart(d, Matrix{Float64}(I, d, d)) | ||
x = rand(dist) | ||
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian` | ||
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. | ||
# Hence, we disable those tests. | ||
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) | ||
for d in [2, 5] | ||
b = PDBijector() | ||
dist = Wishart(d, Matrix{Float64}(I, d, d)) | ||
x = rand(dist) | ||
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian` | ||
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. | ||
# Hence, we disable those tests. | ||
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) | ||
end | ||
end | ||
|
||
@testset "PDVecBijector" begin | ||
for d in [2, 5] | ||
b = PDVecBijector() | ||
dist = Wishart(d, Matrix{Float64}(I, d, d)) | ||
x = rand(dist) | ||
y = b(x) | ||
|
||
# NOTE: `PDBijector` technically isn't bijective, and so the default `getjacobian` | ||
# used in the ChangesOfVariables.jl tests will fail as the jacobian will have determinant 0. | ||
# Hence, we disable those tests. | ||
test_bijector(b, x; test_not_identity=true, changes_of_variables_test=false) | ||
|
||
# Check that output sizes are computed correctly. | ||
tdist = transformed(dist, b) | ||
@test length(tdist) == length(y) | ||
@test tdist isa MultivariateDistribution | ||
|
||
dist_unconstrained = transformed(MvNormal(zeros(length(tdist)), I), inverse(b)) | ||
harisorgn marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@test size(dist_unconstrained) == size(x) | ||
@test dist_unconstrained isa MatrixDistribution | ||
end | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this covered by the tests?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No it isn't! But do we want to? 😕
I just started writing a test, and the realized it means I have to move
wrap_chainrules__output
to Bijectors.jl itself rather than as an extension. But it's only really used for Tracker (it is also used for one part in ReverseDiffjl, but this can be dropped in favor of the macro that has been added to ReverseDiff.jl).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
But does that mean it's not needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, well, yes 🙃
Also, just realized, the reason why it's probably not there is because this will only be called from
Tracker.@grad function
which of course should never use@thunk
😬 I'll remove it 👍