
Tests on generic structs #343

Open
gdalle opened this issue Jul 1, 2024 · 10 comments
Labels: downstream (Related to downstream compatibility), test (Related to the testing subpackage)

Comments

gdalle (Owner) commented Jul 1, 2024

What kind of structs should we add to enable deep learning applications?

gdalle added the downstream and test labels on Jul 1, 2024
adrhill (Collaborator) commented Jul 1, 2024

The answer will be different for Flux and Lux; the former will most likely require support for Functors.jl.
It's worth noting that parameters in DL are somewhat of a mess: some are trainable arrays, some are trainable scalar values, and others are non-trainable parameters, e.g. the moving statistics in BatchNorm layers.

avik-pal commented Jul 2, 2024

> It's worth noting that parameters in DL are somewhat of a mess: some are trainable arrays, some are trainable scalar values, and others are non-trainable parameters, e.g. the moving statistics in BatchNorm layers.

This is what Lux and its derivative frameworks were designed to fix. LuxCore.jl is basically the interface specification that Lux models need to abide by. Lux simply specifies bindings to LuxLib/NNlib and some utilities users need for deep learning abiding by those specifications. To summarize the interface specification, you have 4 parts:

  1. model: an immutable description of the problem (more specifically, the neural network architecture). Its values never participate in gradient calculation.
  2. ps: parameters. All of these are trainable. They may contain scalars but generally don't. When we want to compute gradients, this is typically what we differentiate with respect to.
  3. st: states (think of them as non-trainable parameters). These can be anything: arrays, Vals, Numbers, etc. These values never participate in gradient calculation.
  4. data: you might need to compute gradients for this if users want it, but typically it is not required.
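To make this split concrete, here is a minimal sketch of the Lux workflow (the layer sizes and the choice of Zygote as the AD backend are illustrative assumptions, not part of the interface specification):

```julia
using Lux, Random, Zygote

# model: immutable architecture description; never differentiated
model = Chain(Dense(2 => 8, tanh), Dense(8 => 1))

# ps: trainable parameters, st: non-trainable states
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)

# data: inputs, differentiated only if the user asks for it
x = randn(rng, Float32, 2, 16)

# gradients are taken w.r.t. ps only; model, st, and x are closed over
loss(p) = sum(abs2, first(model(x, p, st)))
grads = Zygote.gradient(loss, ps)
```
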

> The answer will be different for Flux and Lux, the former most likely requiring support for Functors.jl

Lux uses a restrictive definition of fmap (for operations on parameters) for type stability (https://lux.csail.mit.edu/stable/api/Lux/utilities#Lux.recursive_map), but fmap is the general solution for both.
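For context, the fmap mechanism being referred to works roughly like this (a minimal sketch using Functors.jl; the Layer struct is hypothetical):

```julia
using Functors

# a hypothetical layer; @functor makes its fields visible to fmap
struct Layer
    W::Matrix{Float64}
    b::Vector{Float64}
end
@functor Layer

l = Layer([1.0 2.0; 3.0 4.0], [5.0, 6.0])

# apply a function to every leaf (here: every array), preserving structure
l2 = fmap(x -> 2 .* x, l)
```
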

willtebbutt (Contributor)

I don't have a strong view on exactly what structs you should test, but I do know of several things that you will need to make decisions about, based on my experience helping @oxinabox design the tangent type system in ChainRules, and my experience with Tapir.jl.

Firstly, there are a couple of edge cases that you'll probably want to actively ensure that people avoid in order to reduce the number of tests you have to write:

  1. (mutable) structs which may have undefined fields. I have to support them in Tapir.jl, because you occasionally hit structs used internally which don't have all fields defined, but since (I'm assuming) this interface is only concerned with user-facing structs, you can just insist that people don't try to take gradients etc. w.r.t. them. The problem is that they're a pain to test and handle properly -- it's much easier to enforce that all fields are always defined.
  2. Self-referential stuff. What do you want to do about structs which reference themselves? Moreover, what should their tangent type be?
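To make the second point concrete, here is a minimal example of the kind of self-referential struct in question (the names are illustrative):

```julia
# a mutable struct whose field can point back to itself
mutable struct Node
    value::Float64
    next::Union{Node, Nothing}
end

n = Node(1.0, nothing)
n.next = n  # n now references itself

# any tangent type for Node faces the same question:
# should the tangent's `next` field reference the tangent itself?
```
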

Additionally, you'll need to consider whether to follow ChainRules' v1 approach and be flexible regarding what type is used to represent the tangent of a given struct of a given type, or whether to go down the route that Tapir.jl and Enzyme.jl take of insisting on there being a unique tangent type for each primal type. If you choose the former, you massively blow up the interface surface that you'll have to test. Moreover, you run the risk of different AD backends giving different answers and them both technically being "correct".

Personally, I would encourage you to take an opinionated view, and insist upon unique tangent types. I doubt it will matter too much what types you pick, but my experience is that being restrictive makes your life much easier.

I hope the above is helpful. I'm very excited to see what we wind up with here!

(Also, I'm on holiday at the minute, so I probably won't be super responsive to this thread until next week. Apologies in advance!)

gdalle (Owner, Author) commented Jul 15, 2024

Thanks for your advice!
Before you go, I'd love it if you could take a look at my multi-argument/activity proposal in #311 (comment) and see whether there are any obvious things we cannot do with it.

willtebbutt (Contributor) commented Oct 8, 2024

@gdalle did you ever think any more about this?

The release of 1.11 has prompted me to restart this discussion because Arrays are now generic structs. This is relevant because (once I've finished upgrading Mooncake for 1.11) if you take the gradient of a function w.r.t. an Array, you should no longer expect to get an Array back by default (you'll get a Mooncake.MutableTangent, or something like that). I'm assuming that this wasn't a problem before because DI's tests all use Arrays(?).

While non-array like things are the correct thing to use internally in Mooncake, they're probably not what we want to be presenting to users. I'm keen to write some convenience functionality on my end to provide translations (for some types), but before doing that I would like to know what you would like in DI.

For example, I'm reasonably sure we would agree that an acceptable type for the gradient of a function w.r.t.

  1. a Vector{Float64} is another Vector{Float64},
  2. an Array{Float64, N} is another Array{Float64, N},
  3. a Float64 is another Float64.

But what about more generic types, e.g. Diagonal, component arrays, etc.? I can't see this formalised anywhere, so it would be good to agree on it. What happens if we have complicated element types in a given array?

Maybe a useful exercise would be to define for some specific types what the type of the result ought to be, and to clearly state which set of types DI has strong opinions on, and which it does not yet have strong opinions on.

gdalle (Owner, Author) commented Oct 8, 2024

The goal of DI is to be as unopinionated as possible, so I probably won't be taking sides here. Think of DI as a fancy argument-passer, which returns whatever the backends return.

There have been endless discussions on the meaning of derivatives when you're on a manifold, and this meaning differs between backends. From what I understand, ChainRules tries to preserve structure while Enzyme takes a more cartesian approach, so there is no universally right answer. If I try to unify return types for structured objects, I will definitely make a lot of people unhappy, and probably trash performance in the process.

There are also differences in how each backend handles certain fields in a struct. Some backends error on integers (ChainRules?), others just ignore them as inactive values (Enzyme?), others differentiate them fine (FiniteDiff?). Some backends even ignore numbers and differentiate only arrays (Tracker?).

Similarly, some backends accept arbitrary tangent types, while other backends (Enzyme and Mooncake) are stricter. For the stricter ones, I implement automatic conversion, but not automatic structure adaptation. In other words, if convert(correct_tangent_type, tangent) fails, you're on your own.

DI is thoroughly tested with the standard Array type, but the test suite is implemented with isapprox, so there is no requirement to return the same type as the reference we compare against. You can even pass your own isapprox function for structured outputs, if your test scenario has specific semantics (e.g. you want to ignore a subset of fields in the struct).
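For instance, a scenario-specific comparison that ignores a non-trainable field might look like this (the struct and field names are hypothetical, not part of DI's API):

```julia
# hypothetical gradient-like struct with a non-trainable field
struct LayerGrad
    scale::Vector{Float64}         # trainable: should match the reference
    running_mean::Vector{Float64}  # moving statistic: ignored in comparison
end

# custom comparison passed to the test suite instead of Base.isapprox
function grad_isapprox(a::LayerGrad, b::LayerGrad; kwargs...)
    return isapprox(a.scale, b.scale; kwargs...)
end
```
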

TLDR: Everything is in place to differentiate non-Arrays, but the semantics are up to the backend to decide.

willtebbutt (Contributor)

Fair enough. In that case I'll ignore this issue until the upgrades are done, and figure out how to make everything work on the DI end when we get to it :)

gdalle (Owner, Author) commented Oct 8, 2024

Does your new tangent type behave like an array? Can one index it, sum it, etc.?

willtebbutt (Contributor) commented Oct 8, 2024

It almost certainly won't by default.

edit: I say "almost" because I'm not 100% sure what the best choice is from Mooncake's perspective yet.

gdalle (Owner, Author) commented Oct 8, 2024

Let's discuss it in compintell/Mooncake.jl#286?
