-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Ad extension [WIP] #85
Conversation
AD extension package
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #85 +/- ##
==========================================
+ Coverage 82.05% 84.64% +2.59%
==========================================
Files 27 31 +4
Lines 2753 3322 +569
==========================================
+ Hits 2259 2812 +553
- Misses 494 510 +16 ☔ View full report in Codecov by Sentry. |
* Add untested svdsolve rrule * Fix typos * Add svdsolve rrule to extension folder * Delete src/adrules/svdsolve.jl --------- Co-authored-by: Jutho <Jutho@users.noreply.github.com>
I think this is now mostly ready, up to some cleanup and streamlining of the interface. Maybe @lkdvos wants to review? |
There is one more significant TODO: The eigenvalue approach for solving the linear problem/Sylvester problem in both the |
I haven't looked into this in detail, but this sounds like it could also be solved with appropriate calls to |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is definitely a lot of nice work, looks great!
I think I mainly have some minor typos and small nitpicking things, but maybe as a more general comment:
I am not a huge fan of the dummy (; alg_rrule=nothing)
in the keyword arguments of the forward passes. Conceptually, I quite like that the method definition does not need to know anything about the AD that may or not happen, and it feels a bit unnatural that if I were to decide to implement eg a MinRes
linsolver, I need to remember to add the keyword argument.
I think, once we have some benchmarks, it should be possible to have decent default rrule algorithm defaults based on the forward algorithm, and then any expert user who still wants to play around with the different implementations, or experiment with new ones could do the (little bit of) extra work of doing something along the lines of hook_pullback(f, args...; kwargs..., alg_rrule=my_alg)
, which we can even hide in a macro: @hook_pullback f(args...; kwargs..., alg_rrule=myalg)
or @alg_rrule f(args...; kwargs..., alg_rrule=myalg)
or something similar.
This being said, this could also just be me, and if it does not bother you as much, it might not be worth it to change it.
Finally,
if n == 0 | ||
∂f = ZeroTangent() | ||
return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like a type instability (I assume this is what you were referring to this week?)
Do you think we can avoid this by:
- inserting
pullback_eigsolve(ΔX::Tuple{AbstractZero, AbstractZero, Any}) = [...]
(with∂f = ZeroTangent()
) - throwing a warning and explicitly computing the zero pullback
I am honestly not sure if the second case ever happens, as I think this implies that the dependence on the eigenvalues and eigenvectors is exactly zero, which sounds incredibly implausible with floating point accuracy. This would both mean that the regular (most common) case is now type-stable, and that the case where n = 0
gets handled properly when both inputs are AbstractZero
(which afaik e.g. Zygote would never even generate either)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes I will get rid of this.
function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T, | ||
alg_primal::Arnoldi, alg_rrule::Arnoldi) | ||
n = length(Δvecs) | ||
G = zeros(T, n, n) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
G = zeros(T, n, n) | |
G = Matrix{T}(undef, n, n) # eigenvector overlap matrix |
eigsolve(W₀, n, :LR, alg_rrule) do w | ||
x, y, z = w |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
eigsolve(W₀, n, :LR, alg_rrule) do w | |
x, y, z = w | |
eigsolve(W₀, n, :LR, alg_rrule) do (x, y, z) |
Co-authored-by: Lukas <37111893+lkdvos@users.noreply.github.com>
Ok, I think this is mostly ready. Maybe I can add a few more tests to improve coverage (e.g. print warnings and test for them). The TODO for not having to go via complex values in case of Arnoldi About the interface; I did not really study the comment of @lkdvos above. What do you dislike about the keyword argument |
I think my main argument against the I would much rather have a different implementation that keeps the primal computations clean. One such way is to simply add a wrapper function with the I am a bit more in favour of something like: vals, vecs, info = eigsolve(A, x, num, which, alg) # can infer default AD algorithm
# option 1:
vals, vecs, info = hook_pullback(eigsolve, A, x, num, which, alg; alg_rrule) # expert mode -- specifies rrule algorithm
# option 2:
@alg_rrule vals, vecs, info = eigsolve(A, x, num, which, alg; alg_rrule) # looks like current implementation, but expands to option 1 This being said, in principle this is mostly a conceptual issue, and this does not really change all that much. I like keeping AD separated, but this is obviously strictly necessary. Concerning the test coverage, it might be a good idea to add a sparse matrix to the set of tests. This seems like a good candidate for something that checks if our assumptions about what we can/cannot do with |
Ok I see; I agree that it ads a keyword that is irrelevant to the forward computation. As that is mostly a "burden" on the developer side, I don't mind too much. If there would be some centralised infrastructure to do the hook or macro solution, I would use it, but for now, I think there is less overhead in simply adding those keywords rather than developing it within the scope of this package. From the user side, I think the current approach is fine right? If they don't need AD, they don't need to care about this keyword, and if they do, it is easy to try out the different choices without any significant change to the code. |
yes, I think so |
Ok, I think this is more or less ready. I started modifying the tests by using The other remaining questions are interface, i.e. the keyword argument |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I quickly read through the remaining changes, I left a small comment about the _realview
/_imagview
implementations.
Considering the interface, I am ok with keeping the alg_rrule
kwarg in its current state, it seems like this might be the most elegant solution for now anyways.
For the logging, I would also rather have the warnings unaffected by the verbosity
keyword, as this is typically not the behaviour you would want. Can we just turn of logging in the case where the Jacobian is constructed and expected to not respect gauge invariance?
The easiest solution is something like this:
using Logging
Logging.with_logger(NullLogger()) do
# all logs that originate here get discarded
end
A more advanced solution is:
struct IgnoreWarningLogger{L}
parent::L
end
Logging.min_enabled_level(logger::IgnoreWarningLogger) = min_enabled_level(logger.parent)
Logging.catch_exceptions(logger::IgnoreWarningLogger) = catch_exceptions(logger.parent)
function Logging.shouldlog(logger::IgnoreWarningLogger, level, _module, group, id)
id == (##insert warning log id here ##) && return false
return shouldlog(logger.parent, level, _module, group, id)
end
function Logging.handle_message(logger::IgnoreWarningLogger,
args...; kwargs...)
handle_message(logger.parent, args...; kwargs...)
end
function _realview(v::AbstractVector{Complex{T}}) where {T} | ||
v_real = reinterpret(T, v) | ||
return view(v_real, axes(v_real, 1)[begin:2:end]) | ||
end | ||
|
||
function _imagview(v::AbstractVector{Complex{T}}) where {T} | ||
v_real = reinterpret(T, v) | ||
return view(v_real, axes(v_real, 1)[(begin + 1):2:end]) | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we might want to explicitly restrict these definitions to BlasFloats, considering it relies heavily on how the data is stored? I am not sure how the semantics of reinterpret
are defined for eg Complex{BigFloat}
, and I guess that KrylovKit currently would not work with these types anyways, but it might be safer to just explicitly restrict and if necessary extend afterwards?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think its up to reinterpret to handle the more complicated types correctly; the semantics of reinterpret
are used correctly. I think reinterpret
is only defined for isbits types, and will thus error for anything else.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There might be user defined Number
subtypes for which this is correct, which we would exclude with the restriction (not that we currently support this in the rest of KrylovKit, but you never know).
Thanks for the logging solution; that is indeed useful. However, another reason why I decided to hide the warning behind the Another solution could thus be to reduce it to |
🥳 |
No description provided.