Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ad extension [WIP] #85

Merged
merged 23 commits into from
Jun 11, 2024
Merged

Ad extension [WIP] #85

merged 23 commits into from
Jun 11, 2024

Conversation

Jutho
Copy link
Owner

@Jutho Jutho commented May 13, 2024

No description provided.

Copy link

codecov bot commented May 13, 2024

Codecov Report

Attention: Patch coverage is 91.84290% with 54 lines in your changes missing coverage. Please review.

Project coverage is 84.64%. Comparing base (da91706) to head (f32f4f0).
Report is 1 commits behind head on master.

Current head f32f4f0 differs from pull request most recent head b60adc5

Please upload reports for the commit b60adc5 to get more accurate results.

Files Patch % Lines
ext/KrylovKitChainRulesCoreExt/eigsolve.jl 92.16% 21 Missing ⚠️
ext/KrylovKitChainRulesCoreExt/svdsolve.jl 92.62% 16 Missing ⚠️
ext/KrylovKitChainRulesCoreExt/linsolve.jl 89.79% 5 Missing ⚠️
src/eigsolve/arnoldi.jl 82.75% 5 Missing ⚠️
src/linsolve/linsolve.jl 58.33% 5 Missing ⚠️
src/eigsolve/eigsolve.jl 85.71% 1 Missing ⚠️
src/factorizations/lanczos.jl 91.66% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master      #85      +/-   ##
==========================================
+ Coverage   82.05%   84.64%   +2.59%     
==========================================
  Files          27       31       +4     
  Lines        2753     3322     +569     
==========================================
+ Hits         2259     2812     +553     
- Misses        494      510      +16     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

pbrehmer and others added 3 commits May 13, 2024 21:15
* Add untested svdsolve rrule

* Fix typos

* Add svdsolve rrule to extension folder

* Delete src/adrules/svdsolve.jl

---------

Co-authored-by: Jutho <Jutho@users.noreply.github.com>
@Jutho Jutho mentioned this pull request May 25, 2024
@Jutho
Copy link
Owner Author

Jutho commented May 25, 2024

I think this is now mostly ready, up to some cleanup and streamlining of the interface. Maybe @lkdvos wants to review?

@Jutho
Copy link
Owner Author

Jutho commented May 25, 2024

There is one more significant TODO:

The eigenvalue approach for solving the linear problem/Sylvester problem in both the rrule of eigsolve and svdsolve is nonhermitian, which means that the results are always obtained in complex arithmetic, even when the forward calculation can be completely real (namely for a real symmetric eigenvalue problem or a real singular value problem). This so far does not cause problems, as apparently the imaginary parts of the computed quantities is exactly zero and therefore it is implicitly converted back to real vectors. However, this might break with custom types, so it would be better to explicitly restrict to real arithmetic by using schursolve and a custom routine to extract the real eigenvectors associated with the results coming out of schursolve.

@lkdvos
Copy link
Collaborator

lkdvos commented May 25, 2024

There is one more significant TODO:

The eigenvalue approach for solving the linear problem/Sylvester problem in both the rrule of eigsolve and svdsolve is nonhermitian, which means that the results are always obtained in complex arithmetic, even when the forward calculation can be completely real (namely for a real symmetric eigenvalue problem or a real singular value problem). This so far does not cause problems, as apparently the imaginary parts of the computed quantities is exactly zero and therefore it is implicitly converted back to real vectors. However, this might break with custom types, so it would be better to explicitly restrict to real arithmetic by using schursolve and a custom routine to extract the real eigenvectors associated with the results coming out of schursolve.

I haven't looked into this in detail, but this sounds like it could also be solved with appropriate calls to ProjectTo, which should work even for custom types as there it gives a hook to add the correct projection methods.

Copy link
Collaborator

@lkdvos lkdvos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is definitely a lot of nice work, looks great!

I think I mainly have some minor typos and small nitpicking things, but maybe as a more general comment:
I am not a huge fan of the dummy (; alg_rrule=nothing) in the keyword arguments of the forward passes. Conceptually, I quite like that the method definition does not need to know anything about the AD that may or not happen, and it feels a bit unnatural that if I were to decide to implement eg a MinRes linsolver, I need to remember to add the keyword argument.

I think, once we have some benchmarks, it should be possible to have decent default rrule algorithm defaults based on the forward algorithm, and then any expert user who still wants to play around with the different implementations, or experiment with new ones could do the (little bit of) extra work of doing something along the lines of hook_pullback(f, args...; kwargs..., alg_rrule=my_alg), which we can even hide in a macro: @hook_pullback f(args...; kwargs..., alg_rrule=myalg) or @alg_rrule f(args...; kwargs..., alg_rrule=myalg) or something similar.
This being said, this could also just be me, and if it does not bother you as much, it might not be worth it to change it.

Finally,

ext/KrylovKitChainRulesCoreExt/eigsolve.jl Outdated Show resolved Hide resolved
Comment on lines 42 to 45
if n == 0
∂f = ZeroTangent()
return ∂self, ∂f, ∂x₀, ∂howmany, ∂which, ∂alg
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like a type instability (I assume this is what you were referring to this week?)
Do you think we can avoid this by:

  1. inserting pullback_eigsolve(ΔX::Tuple{AbstractZero, AbstractZero, Any}) = [...] (with ∂f = ZeroTangent())
  2. throwing a warning and explicitly computing the zero pullback

I am honestly not sure if the second case ever happens, as I think this implies that the dependence on the eigenvalues and eigenvectors is exactly zero, which sounds incredibly implausible with floating point accuracy. This would both mean that the regular (most common) case is now type-stable, and that the case where n = 0 gets handled properly when both inputs are AbstractZero (which afaik e.g. Zygote would never even generate either)

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes I will get rid of this.

ext/KrylovKitChainRulesCoreExt/eigsolve.jl Outdated Show resolved Hide resolved
ext/KrylovKitChainRulesCoreExt/eigsolve.jl Outdated Show resolved Hide resolved
function compute_eigsolve_pullback_data(Δvals, Δvecs, vals, vecs, info, which, fᴴ, T,
alg_primal::Arnoldi, alg_rrule::Arnoldi)
n = length(Δvecs)
G = zeros(T, n, n)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
G = zeros(T, n, n)
G = Matrix{T}(undef, n, n) # eigenvector overlap matrix

Comment on lines 205 to 206
eigsolve(W₀, n, :LR, alg_rrule) do w
x, y, z = w
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
eigsolve(W₀, n, :LR, alg_rrule) do w
x, y, z = w
eigsolve(W₀, n, :LR, alg_rrule) do (x, y, z)

ext/KrylovKitChainRulesCoreExt/svdsolve.jl Outdated Show resolved Hide resolved
ext/KrylovKitChainRulesCoreExt/svdsolve.jl Show resolved Hide resolved
ext/KrylovKitChainRulesCoreExt/utilities.jl Show resolved Hide resolved
test/ad.jl Show resolved Hide resolved
@Jutho
Copy link
Owner Author

Jutho commented May 30, 2024

Ok, I think this is mostly ready. Maybe I can add a few more tests to improve coverage (e.g. print warnings and test for them). The TODO for not having to go via complex values in case of Arnoldi rrule for svdsolve or hermitian eigsolve is also still open.

About the interface; I did not really study the comment of @lkdvos above. What do you dislike about the keyword argument ; alg_rrule = ... in the methods? Whether the actual values to this keyword need to have the values they currently have (recycling existing structures) or some new values is certainly open for debate. And a sensible default definitely needs to be in place after benchmarking. This also reminds me, documentation about this needs to be added.

@lkdvos
Copy link
Collaborator

lkdvos commented May 30, 2024

I think my main argument against the alg_rrule kwarg is that it "pollutes" the method definition with information that is only relevant to AD. In other words, if I were to for example write an implementation of minres, I now need to remember to add that keyword argument, even though this has nothing to do with the primal computation.

I would much rather have a different implementation that keeps the primal computations clean. One such way is to simply add a wrapper function with the alg_rrule kwarg added, which can then correctly distribute the args and kwargs to their relevant places.

I am a bit more in favour of something like:

vals, vecs, info = eigsolve(A, x, num, which, alg) # can infer default AD algorithm

# option 1:
vals, vecs, info = hook_pullback(eigsolve, A, x, num, which, alg; alg_rrule) # expert mode -- specifies rrule algorithm

# option 2:
@alg_rrule vals, vecs, info = eigsolve(A, x, num, which, alg; alg_rrule) # looks like current implementation, but expands to option 1

This being said, in principle this is mostly a conceptual issue, and this does not really change all that much. I like keeping AD separated, but this is obviously strictly necessary.


Concerning the test coverage, it might be a good idea to add a sparse matrix to the set of tests. This seems like a good candidate for something that checks if our assumptions about what we can/cannot do with AbstractMatrix types are fair

@Jutho
Copy link
Owner Author

Jutho commented May 30, 2024

Ok I see; I agree that it ads a keyword that is irrelevant to the forward computation. As that is mostly a "burden" on the developer side, I don't mind too much. If there would be some centralised infrastructure to do the hook or macro solution, I would use it, but for now, I think there is less overhead in simply adding those keywords rather than developing it within the scope of this package.

From the user side, I think the current approach is fine right? If they don't need AD, they don't need to care about this keyword, and if they do, it is easy to try out the different choices without any significant change to the code.

@lkdvos
Copy link
Collaborator

lkdvos commented May 30, 2024

yes, I think so

@Jutho
Copy link
Owner Author

Jutho commented Jun 10, 2024

Ok, I think this is more or less ready. I started modifying the tests by using @test_logs to test verbosity, but only in certain parts of the tests so far (because it conflicts with @constinferred and because I don't know how to test for a variable number of @info outputs). Making the tests more uniform can happen later.

The other remaining questions are interface, i.e. the keyword argument rrule_alg, and also the fact that the warning for the gauge-dependence of the adjoint is only printed for verbosity >= 1 (which then also implies that other @info statements are printed). Hence, there is currently no way to have the warning (which I think is important and relevant) without having additional @info output. However, if the warning is on by default (i.e. also for verbosity==0), then the tests would produce plenty of those, as computing the full Jacobian with respect to all variables in the eigenvectors is not a gauge-invariant operation.

Copy link
Collaborator

@lkdvos lkdvos left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I quickly read through the remaining changes, I left a small comment about the _realview/_imagview implementations.
Considering the interface, I am ok with keeping the alg_rrule kwarg in its current state, it seems like this might be the most elegant solution for now anyways.
For the logging, I would also rather have the warnings unaffected by the verbosity keyword, as this is typically not the behaviour you would want. Can we just turn of logging in the case where the Jacobian is constructed and expected to not respect gauge invariance?
The easiest solution is something like this:

using Logging
Logging.with_logger(NullLogger()) do
# all logs that originate here get discarded
end

A more advanced solution is:

struct IgnoreWarningLogger{L}
   parent::L
end
Logging.min_enabled_level(logger::IgnoreWarningLogger) = min_enabled_level(logger.parent)
Logging.catch_exceptions(logger::IgnoreWarningLogger) = catch_exceptions(logger.parent)
function Logging.shouldlog(logger::IgnoreWarningLogger, level, _module, group, id)
   id == (##insert warning log id here ##) && return false
   return shouldlog(logger.parent, level, _module, group, id)
end
function Logging.handle_message(logger::IgnoreWarningLogger,
                                args...; kwargs...)
    handle_message(logger.parent, args...; kwargs...)
end

Comment on lines +56 to +64
function _realview(v::AbstractVector{Complex{T}}) where {T}
v_real = reinterpret(T, v)
return view(v_real, axes(v_real, 1)[begin:2:end])
end

function _imagview(v::AbstractVector{Complex{T}}) where {T}
v_real = reinterpret(T, v)
return view(v_real, axes(v_real, 1)[(begin + 1):2:end])
end
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might want to explicitly restrict these definitions to BlasFloats, considering it relies heavily on how the data is stored? I am not sure how the semantics of reinterpret are defined for eg Complex{BigFloat}, and I guess that KrylovKit currently would not work with these types anyways, but it might be safer to just explicitly restrict and if necessary extend afterwards?

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think its up to reinterpret to handle the more complicated types correctly; the semantics of reinterpret are used correctly. I think reinterpret is only defined for isbits types, and will thus error for anything else.

Copy link
Owner Author

@Jutho Jutho Jun 11, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There might be user defined Number subtypes for which this is correct, which we would exclude with the restriction (not that we currently support this in the rest of KrylovKit, but you never know).

@Jutho
Copy link
Owner Author

Jutho commented Jun 11, 2024

Thanks for the logging solution; that is indeed useful. However, another reason why I decided to hide the warning behind the verbosity >= 1 flag is that I also want to make it easy to switch it off for users that know what they are doing.

Another solution could thus be to reduce it to verbosity >= 0, so that the gauge dependence warning is on by default, but can be switched off by setting verbosity=-1. At some point we should also switch to using the levels of the Logging system, but I do not want to invest the time in this right now.

@Jutho Jutho merged commit 27662a3 into master Jun 11, 2024
@lkdvos
Copy link
Collaborator

lkdvos commented Jun 11, 2024

🥳

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants