Numerical improvements to correlation bijectors #313

Merged
merged 21 commits into from
Jun 5, 2024

sethaxen
Member

This PR implements the numerical suggestions in #301. It does not make any of the suggested renaming changes, which will be left for a future PR.

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Member

@torfjelde torfjelde left a comment

This will be really nice @sethaxen :)

I've added a few comments. It also seems as if the chain rule is somehow not type stable? 😕

@@ -268,7 +268,6 @@ end
@grad_from_chainrules _link_chol_lkj(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_upper(x::TrackedMatrix)
@grad_from_chainrules _link_chol_lkj_from_lower(x::TrackedMatrix)
@grad_from_chainrules _inv_link_chol_lkj(x::TrackedVector)
Member

Also, is this intentional?

Member Author

Yes, the problem is that neither @grad nor @grad_from_chainrules supports multi-output functions (JuliaDiff/ReverseDiff.jl#221), so we cannot use this macro. At the same time, everything in the function should be differentiable by ReverseDiff without a custom rule, so I just removed the rule.

However, we have the same problem with Tracker. Tracker.@grad seems to not support multi-output functions, and I'm still working out how to AD through the primal (I have an idea for a fix).

Member Author

If you have an idea how to get this working for ReverseDiff, let me know. It would be great to use the manual pullback.

Member Author

Nope, my Tracker idea did not work.

Member Author

Never mind, it works!

Member

Lovely:)

Member

Actually, how did you achieve it?

Member Author

See the changes in commit a2eac95. In Tracker, the cotangent of a multi-output function ends up being a TrackedTuple, which doesn't support iteration, so I use indexing to split the tuple instead. I also made pd_from_lower and pd_from_upper use the same tricks as cholesky_upper and cholesky_lower.
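
To illustrate the iteration-versus-indexing point, here is a minimal sketch that does not use Tracker itself. `IndexOnlyPair` is a hypothetical stand-in for a tuple-like wrapper that supports `getindex` but not `iterate`; it is not Tracker's TrackedTuple, and the actual fix is the one in the commit mentioned above.

```julia
# Hypothetical stand-in (an assumption for illustration, not Tracker's type):
# a pair-like wrapper that can be indexed but does not implement `iterate`.
struct IndexOnlyPair{A,B}
    first::A
    second::B
end

function Base.getindex(p::IndexOnlyPair, i::Int)
    i == 1 && return p.first
    i == 2 && return p.second
    throw(BoundsError(p, i))
end

p = IndexOnlyPair([1.0, 2.0], 0.5)

# Destructuring (`a, b = p`) lowers to iteration, which this type lacks,
# so it throws a MethodError.
destructure_fails = try
    a, b = p
    false
catch err
    err isa MethodError
end

# Splitting by index only needs `getindex`, so it works.
y = p[1]
logJ = p[2]
```

The same pattern applies to any container that can be indexed but not iterated: replace `a, b = x` with `a = x[1]; b = x[2]`.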

Member

Uuuuh I didn't know that that was the reason why it was an issue! Dopey:)

@sethaxen
Member Author

sethaxen commented Jun 3, 2024

Remaining errors seem to be ones introduced in #304 and unrelated to this PR

@torfjelde
Member

Remaining errors seem to be ones introduced in #304 and unrelated to this PR

Hmm, that's strange o.O I don't understand why this wasn't failing in the original PR. Or it might be because it hit the cholesky error and thus never ran the interface tests on Julia 1.6. It should be a quick Compat.jl inclusion though; let me check.

Member

@torfjelde torfjelde left a comment

Lovely stuff :) Happy to merge as soon as tests pass (which should only need a quick merge with master after #314).


@yebai
Member

yebai commented Jun 3, 2024

It seems #314 does not address the issue on Julia 1.6.

@sethaxen
Member Author

sethaxen commented Jun 4, 2024

I don't think #314 was completed before merging. Its CI was still failing with a similar error.

@torfjelde
Member

Yes, #314 was indeed not ready for a merge, @yebai; why was it merged?

@yebai
Member

yebai commented Jun 4, 2024

Yes, #314 was indeed not ready for a merge, @yebai; why was it merged?

my bad -- it shouldn't have been merged.

@torfjelde
Member

@sethaxen I just pushed the fix directly to this branch. Let's see if CI succeeds now:)

@torfjelde
Member

Damn, even this doesn't work because eachslice only supports a single value as the dims arg 😕

Really sorry about this, @sethaxen; this bug was hidden behind an unrelated numerical issue that kept this particular test from ever running on 1.6. Still, the change that caused it shouldn't have been merged.

I'll just disable those tests on this PR and then we'll have to fix it in a separate PR.
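
For reference, a minimal sketch of one possible workaround for the `eachslice` limitation on older Julia versions, assuming the goal is to slice a 3-dimensional array over its last two dimensions. The array `A` and the chosen dimensions are illustrative assumptions, not the code from the failing test, and this is not necessarily the fix that was adopted later.

```julia
# Illustrative 2×3×4 array (an assumption; not taken from the test suite).
A = reshape(collect(1:24), 2, 3, 4)

# Newer Julia releases accept a tuple here, e.g. eachslice(A; dims=(2, 3)).
# On versions where `dims` must be a single integer, a comprehension over
# the trailing axes builds the same collection of slices as views.
slices = [view(A, :, i, j) for i in axes(A, 2), j in axes(A, 3)]

# One slice per (i, j) pair, each of length size(A, 1).
n_slices = size(slices)
one_slice = slices[2, 3]
```

Using `view` avoids copying each slice; replace it with `A[:, i, j]` if independent copies are needed.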

@sethaxen
Member Author

sethaxen commented Jun 5, 2024

Seems like it worked!

EDIT: and, no problem, @torfjelde !

@torfjelde
Member

Lovely! Feel free to hit the big green button:)

@sethaxen
Member Author

sethaxen commented Jun 5, 2024

Sadly, I am not an "authorized user," and the button is gray.

@torfjelde
Member

Want me to do it then? Also happy to give you authorization, given your involvement in Bijectors.jl, if you want :)

@sethaxen
Member Author

sethaxen commented Jun 5, 2024

Want me to do it then? Also happy to give you authorization, given your involvement in Bijectors.jl, if you want :)

Sure to both!

@torfjelde torfjelde merged commit fd53666 into master Jun 5, 2024
23 checks passed
@delete-merged-branch delete-merged-branch bot deleted the patch branch June 5, 2024 08:16
@torfjelde
Member

Done :) Wonderful stuff, @sethaxen; thanks!
