Module Slicing #115
62d3e3d
to
83d2c0c
Compare
Codecov Report
@@           Coverage Diff            @@
##           master    #115    +/-    ##
==========================================
+ Coverage   85.04%  85.86%  +0.81%
==========================================
  Files         131     133      +2
  Lines        6963    7300    +337
==========================================
+ Hits         5922    6268    +346
+ Misses       1041    1032      -9
The resulting module can now be retrained.
Hey @alexander-g! This is amazing 😃 Can you share a bit about what strategy you are using to construct the new model?
cc @charlielito
I collect information about the main module and its submodules from the ^1: luckily this seems to work well in JAX, even when you do a
I'll try to test this out so I can give a better opinion about the API. In a previous version we actually had something in the spirit of
64bdaa4
to
aa02024
Compare
ec22e53
to
3b3d88a
Compare
One ugly detail: I import the
Apart from that, I think this is usable and can be merged. I'd like to use it in #126
I ended up moving all hooks like
# module.py
module_slice = None
...
# module_slice.py
import sys
from . import module
current_module = sys.modules[__name__]
module.module_slice = current_module
...
Not sure this pattern is strictly better.
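The hook pattern above can be illustrated in a single file. This is only a sketch: the two files are simulated with `types.ModuleType` objects, and `describe` and `slice_name` are hypothetical names that do not appear in the PR. The point is that the core module only declares an empty slot and reads it at call time, so it never has to import the slicing code.

```python
import types

# Hypothetical stand-in for module.py: declares the hook slot, imports nothing.
core = types.ModuleType("module")
core.module_slice = None


def describe():
    # The hook is looked up at call time, so it may be registered later.
    hook = core.module_slice
    return "no slicing hook" if hook is None else hook.slice_name()


core.describe = describe

# Before the slicing module is "imported", the slot is still empty.
before = core.describe()

# Hypothetical stand-in for module_slice.py: on import it registers itself,
# mirroring `module.module_slice = sys.modules[__name__]`.
slicer = types.ModuleType("module_slice")
slicer.slice_name = lambda: "slicing hook registered"
core.module_slice = slicer

after = core.describe()
print(before, "->", after)
```

The trade-off is the one raised in the comment: the circular import goes away, but the dependency becomes implicit and only works once the second module has been imported.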
- all operations between `start_module` and `end_module` must be performed by modules
  i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- only one `start_module` is supported
- all modules between `start_module` and `end_module` must have a single output
Based on your comment, is this no longer a limitation?
Still is.
It's now possible to have inner modules with multiple inputs, but the resulting module must still have only one input. The single-output limitation also holds.
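To make the "operations must be performed by modules" limitation concrete, here is a rough sketch of what wrapping a plain function into a module could look like. It does not use elegy: `Module` and `to_module_sketch` are hypothetical stand-ins written for this example, and the real `elegy.to_module()` may differ in signature and behaviour.

```python
from typing import Callable


class Module:
    """Hypothetical minimal base class; not elegy's Module."""

    def call(self, x):
        raise NotImplementedError

    def __call__(self, x):
        return self.call(x)


def to_module_sketch(fn: Callable) -> type:
    """Wrap a single-input function in a Module subclass (assumed behaviour)."""

    class FunctionModule(Module):
        def call(self, x):
            return fn(x)

    FunctionModule.__name__ = getattr(fn, "__name__", "FunctionModule")
    return FunctionModule


def relu(x):
    # Scalar stand-in for jax.nn.relu, to keep the sketch dependency-free.
    return max(x, 0.0)


Relu = to_module_sketch(relu)
AddOne = to_module_sketch(lambda x: x + 1)

# Both operations are now module calls, so a tracer could see them.
y = AddOne()(Relu()(-2.0))
print(y)
```

Once every step is a module call, each one has a single input and a single output that can be recorded, which is what the slicing logic relies on.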
elegy/module.py
Outdated
...
```

Arguments:
module_or_name: The name of the summary or alternatively the module that this summary will represent.
If a summary with the same name already exists a unique identifier will be generated.
value: The value for the summary.
input_values: The input arguments for the module, required for slicing.
Can you elaborate on the structure of the tuple?
14046c2
to
e9ae2c2
Compare
e9ae2c2
to
2c9cf1c
Compare
Updated to 0.6.0.
@cgarciae What is your opinion on this PR? Shall we continue, or do you have other ideas on how to do transfer learning?
My current opinion about this feature is this.
Pros
Cons
The only alternative I can think of right now is to have an optional flag (e.g.
def call(...):
    if self.return_multiple:
        return dict(layer1=out1, layer2=out2, ...)
    else:
        return out
This solution is framework agnostic and easy to implement, but it is not automatic and requires effort from the author of the Module.
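A runnable version of the flag-based alternative might look like the following. It is a plain-Python sketch, not elegy code: `TinyNet` and its two arithmetic "layers" are made up for illustration, and only the `return_multiple` flag comes from the snippet above.

```python
class TinyNet:
    """Hypothetical two-layer module with an opt-in multi-output flag."""

    def __init__(self, return_multiple: bool = False):
        self.return_multiple = return_multiple

    def call(self, x):
        out1 = x * 2      # stand-in for the first layer
        out2 = out1 + 1   # stand-in for the second layer
        if self.return_multiple:
            # Expose every intermediate activation by name.
            return dict(layer1=out1, layer2=out2)
        return out2


single = TinyNet().call(3)
multi = TinyNet(return_multiple=True).call(3)
print(single, multi)
```

The author has to decide in advance which intermediate values to expose, which is the manual effort the comment refers to.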
This is exactly what I want to avoid. This requires too much thinking ahead, which is difficult, especially if the module was written by another author. Say, I want to inspect an arbitrary inner layer of a ResNet, then I'd need to rewrite the
Could you explain what you mean by this or give an example? What should the author take care of? It is indeed brittle, but rather because of the inner graph logic, which is not yet fully "battle-tested".
I think we can add this feature as a show-case of what can be done even if only certain Modules support it.
If I remember correctly this works with the
I was told that in PyTorch you do this by taking a slice from
Sounds limited.
I've started experimenting with a new method based on
This seems super useful, can you show how it could look?
Experimental support for slicing modules. e.g.:
This currently requires the additional package `networkx`. This could be removed with some more work if you don't want to introduce another dependency.

Limitations:
- all operations between `start_module` and `end_module` must be performed by modules,
  i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- `start_module` and `end_module` must have a single input and a single output
- `get_parameters()` does not return any weights. Need some hints how to fix that.
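As a rough illustration of graph-based slicing under the single-input, single-output restriction, the sketch below records which module consumes which module's output during one forward pass and then searches for a path between two modules. It is not the PR's implementation: it uses a plain dict and breadth-first search instead of `networkx`, and `Traced`, `Layer`, `find_path` and `run_slice` are hypothetical names introduced here.

```python
from collections import deque


class Traced:
    """Value tagged with the name of the module that produced it."""

    def __init__(self, value, producer=None):
        self.value = value
        self.producer = producer  # None marks the model input


class Layer:
    """Single-input, single-output module that logs its incoming edge."""

    def __init__(self, name, fn, edges):
        self.name, self.fn, self.edges = name, fn, edges

    def __call__(self, x):
        self.edges.setdefault(x.producer, []).append(self.name)
        return Traced(self.fn(x.value), producer=self.name)


def find_path(edges, start, end):
    """Breadth-first search over the recorded producer -> consumer edges."""
    queue, seen = deque([[start]]), {start}
    while queue:
        path = queue.popleft()
        if path[-1] == end:
            return path
        for nxt in edges.get(path[-1], []):
            if nxt not in seen:
                seen.add(nxt)
                queue.append(path + [nxt])
    raise ValueError(f"no path from {start!r} to {end!r}")


edges = {}
layers = {
    "a": Layer("a", lambda v: v + 1, edges),
    "b": Layer("b", lambda v: v * 10, edges),
    "c": Layer("c", lambda v: v - 3, edges),
    "d": Layer("d", lambda v: -v, edges),
}


def model(x):
    h = layers["a"](x)
    layers["d"](h)  # side branch that a slice from "a" to "c" should skip
    h = layers["b"](h)
    return layers["c"](h)


def run_slice(path, value):
    # Apply only the modules on the path, in order.
    for name in path:
        value = layers[name].fn(value)
    return value


full = model(Traced(2)).value          # one traced forward pass fills `edges`
path = find_path(edges, "b", "c")      # modules needed for the slice
sliced = run_slice(path, 3)
print(full, path, sliced)
```

Any operation that is not a module call would be invisible to this kind of tracing, which is consistent with the first limitation listed above.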