Module Slicing #115

Open
wants to merge 22 commits into
base: master

alexander-g
Contributor

Experimental support for slicing modules. e.g:

import elegy
import jax.numpy as jnp

x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
# Request four intermediate residual blocks as the outputs of the sliced module.
submodule = elegy.module_slicing.slice_module_from_to(
    resnet,
    start_module=None,
    end_module=["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7"],
    sample_input=x,
)
outputs = elegy.Model(submodule).predict(x)
# One feature map per requested end module.
assert outputs[0].shape == (2, 56, 56, 64)
assert outputs[1].shape == (2, 28, 28, 128)
assert outputs[2].shape == (2, 14, 14, 256)
assert outputs[3].shape == (2, 7, 7, 512)

This currently requires the additional package networkx. This could be removed with some more work if you don't want to introduce another dependency.

Limitations:

  • All operations between start_module and end_module must be performed by modules
    i.e. jax.nn.relu() or x+1 is not allowed but can be converted by wrapping with elegy.to_module()
  • Only one input module is supported
  • All modules between start_module and end_module must have a single input and a single output
  • Resulting module is currently not trainable: .get_parameters() does not return any weights. Need some hints how to fix that.
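
To illustrate the first limitation, here is a minimal sketch of what wrapping a plain function in a module can look like. The actual signature of `elegy.to_module()` is not shown in this thread, so `Module` and `to_module_sketch` below are hypothetical stand-ins written in plain Python, not Elegy's API.

```python
class Module:
    """Hypothetical minimal base class; not Elegy's actual Module."""

    def __call__(self, *args, **kwargs):
        return self.call(*args, **kwargs)


def to_module_sketch(fn):
    """Wrap a plain function in a Module subclass so that a tracer
    which only sees module calls can also see this operation."""

    class FunctionModule(Module):
        def call(self, *args, **kwargs):
            return fn(*args, **kwargs)

    return FunctionModule


def relu(x):
    # Scalar stand-in for jax.nn.relu, to keep the sketch dependency-free.
    return max(x, 0.0)


Relu = to_module_sketch(relu)
AddOne = to_module_sketch(lambda x: x + 1)

relu_layer = Relu()
add_one_layer = AddOne()

# Both operations are now performed by Module instances.
result = add_one_layer(relu_layer(-3.0))
```

The point of the wrapper is only that every operation on the path becomes a module call that can be recorded; the computation itself is unchanged.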

@codecov-io

codecov-io commented Nov 25, 2020

Codecov Report

Merging #115 (2c9cf1c) into master (87e18c1) will increase coverage by 0.81%.
The diff coverage is 98.56%.

@@            Coverage Diff             @@
##           master     #115      +/-   ##
==========================================
+ Coverage   85.04%   85.86%   +0.81%     
==========================================
  Files         131      133       +2     
  Lines        6963     7300     +337     
==========================================
+ Hits         5922     6268     +346     
+ Misses       1041     1032       -9     
| Impacted Files | Coverage Δ |
|---|---|
| elegy/module_test.py | 99.30% <ø> (ø) |
| elegy/module_slicing.py | 97.76% <97.76%> (ø) |
| elegy/module_slicing_test.py | 99.33% <99.33%> (ø) |
| elegy/hooks.py | 85.61% <100.00%> (ø) |
| elegy/hooks_test.py | 100.00% <100.00%> (ø) |
| elegy/model/model.py | 96.41% <100.00%> (ø) |
| elegy/model/model_core.py | 90.04% <100.00%> (+0.09%) ⬆️ |
| elegy/module.py | 95.52% <100.00%> (+0.26%) ⬆️ |
| elegy/types.py | 94.35% <100.00%> (+0.08%) ⬆️ |
| ... and 3 more | |

Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 87e18c1...2c9cf1c. Read the comment docs.

@alexander-g
Contributor Author

The resulting module can now be retrained.

@alexander-g alexander-g marked this pull request as ready for review November 25, 2020 11:59
@cgarciae
Collaborator

Hey @alexander-g ! This is amazing 😃 Can you share a bit about what strategy you are using to construct the new model?
I think this is a key feature to enable easier Transfer Learning and thus making Elegy more appealing for real use-cases.

@cgarciae
Collaborator

cc @charlielito

@alexander-g
Contributor Author

I collect information about the main module and its submodules from the summaries feature. Then I construct a directed graph with the modules representing the edges and the inputs/outputs of the modules representing the nodes. If the outputs of module A are the same as the inputs of module B (as returned by id() ^1) then they are connected.
Then I simply search for the shortest path between start_module and the end_module and finally go along this path and execute the corresponding edges/modules.

^1: luckily this seems to work well in JAX: even after x += 1 the value returned by id() changes (JAX arrays are immutable, so the operation creates a new array), whereas in NumPy the in-place update keeps the same object.
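
The strategy described above can be sketched roughly as follows. This is not the PR's code: the PR uses networkx, while this sketch uses a plain breadth-first search to stay dependency-free, and the record format `(module_name, id_of_input, id_of_output)` and all function names are assumptions made for illustration.

```python
from collections import deque

# Plain objects stand in for arrays; only their identity (id) matters here.
x, a, b, c = object(), object(), object(), object()

# Hypothetical call records: (module_name, id_of_input, id_of_output).
records = [
    ("/conv", id(x), id(a)),
    ("/block_1", id(a), id(b)),
    ("/block_2", id(b), id(c)),
]


def build_graph(records):
    """Values are nodes (keyed by id); each module is an edge from its input to its output."""
    graph = {}
    for name, in_id, out_id in records:
        graph.setdefault(in_id, []).append((out_id, name))
    return graph


def shortest_module_path(graph, start_id, end_id):
    """Breadth-first search returning the module names along the shortest path."""
    queue = deque([(start_id, [])])
    visited = {start_id}
    while queue:
        node, path = queue.popleft()
        if node == end_id:
            return path
        for next_node, name in graph.get(node, []):
            if next_node not in visited:
                visited.add(next_node)
                queue.append((next_node, path + [name]))
    return None  # no chain of recorded modules connects the two values


graph = build_graph(records)
path = shortest_module_path(graph, id(x), id(b))
```

Executing the modules named in `path` in order would then reproduce the sub-computation from the start value to the end value.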

@alexander-g
Contributor Author

  • Modules with multiple inputs between start_module and end_module should now work too (e.g. skip connections in ResNet or U-Net). However this is getting more complex than I initially thought. I try to add as many comments as possible and cover everything with test cases but I sometimes get confused by this code myself.
  • What would be a good API for this functionality? Maybe add this as a method: Module.slice()?

@cgarciae
Collaborator

cgarciae commented Dec 14, 2020

I'll try to test this out so I can give a better opinion about the API. In a previous version we actually had something in the spirit of slice, but it was only for the parameters dictionary. Module.slice seems more intuitive IMO.

@alexander-g
Contributor Author

  • Added the Module.slice() method

One ugly detail: I import the module_slicing python module inside the function in module.py:line 622 because of a circular dependency. Was not able to fix that.
Moreover, because Module is the parent class of many other Modules like Conv, BatchNorm etc. mkdocs wants to add this method to the docs although it doesn't really make sense. Can I prevent that? Or should I add them anyway?

Apart from that, I think this is usable and can be merged. I'd like to use it in #126

@cgarciae
Collaborator

One ugly detail: I import the module_slicing python module inside the function in module.py:line 622 because of a circular dependency. Was not able to fix that.

I ended up moving all hooks like add_loss, add_summary, etc. from elegy.hooks to module because of this. An alternative strategy is to set a dummy reference to module_slice on module and patch it when module_slice is first imported:

# module.py
module_slice = None
...

# module_slice.py
import sys
from . import module

current_module = sys.modules[__name__]
module.module_slice = current_module
...

Not sure this pattern is strictly better.
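
A self-contained sketch of the placeholder-and-patch pattern, for comparison with the deferred import. The two files are simulated with in-memory `types.ModuleType` objects so that the snippet runs as a single script; the names `core_demo`, `slicing_demo`, `slicing_backend` and `slice_from_to` are hypothetical and do not correspond to Elegy's code.

```python
import sys
import types

# Stand-in for module.py: it only holds a placeholder, so it never
# needs to import the slicing code at load time.
core = types.ModuleType("core_demo")
core.slicing_backend = None


def slice_(obj, start, end):
    # The backend is looked up at call time, after it has been patched in.
    if core.slicing_backend is None:
        raise RuntimeError("slicing backend has not been registered")
    return core.slicing_backend.slice_from_to(obj, start, end)


core.slice = slice_
sys.modules["core_demo"] = core

# Stand-in for the slicing module: on import it registers itself on core.
backend = types.ModuleType("slicing_demo")
backend.slice_from_to = lambda obj, start, end: obj[start:end]
sys.modules["slicing_demo"] = backend
sys.modules["core_demo"].slicing_backend = backend

sliced = core.slice(["a", "b", "c", "d"], 1, 3)
```

One trade-off of this pattern is that calling `core.slice` before the slicing module has been imported fails at run time, which is why the sketch raises an explicit error in that case.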

- all operations between `start_module` and `end_module` must be performed by modules
i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- only one `start_module` is supported
- all modules between `start_module` and `end_module` must have a single output
Collaborator

Based on your comment, is this no longer a limitation?

Contributor Author

Still is.
It's now possible to have inner modules that have multiple inputs but the result module still must have only one input. Single output limitation also holds.

elegy/module.py Outdated
...
```

Arguments:
module_or_name: The name of the summary or alternatively the module that this summary will represent.
If a summary with the same name already exists a unique identifier will be generated.
value: The value for the summary.
input_values: The input arguments for the module, required for slicing.
Collaborator

can you elaborate on the structure of the tuple?

@alexander-g alexander-g mentioned this pull request Jan 15, 2021
7 tasks
@alexander-g
Contributor Author

Requesting review/testing @cgarciae
This is usable and #126 depends on it.

@alexander-g
Contributor Author

Updated to 0.6.0

@cgarciae What is your opinion on this PR? Shall we continue, or do you have other ideas on how to do transfer learning?

@cgarciae
Collaborator

cgarciae commented Feb 20, 2021

My current opinion about this feature is this.

Pros

  • When it works it's really simple
  • You can potentially use any intermediate layer

Cons

  • Can be brittle / silently fail if the author of the Module is not careful.
  • Currently only properly supported in Elegy Modules.

The only alternative I can think of right now is to have an optional flag (e.g. return_multiple) that returns a dictionary with the multiple outputs you would like to expose to the user, e.g.:

def call(...):
    if self.return_multiple:
        return dict(layer1=out1, layer2=out2, ...)
    else:
        return out

This solution is framework agnostic and easy to implement but is not automatic and requires effort from the author of the Module.
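
A fuller, runnable version of that idea might look like the following sketch. `TinyNet` and its arithmetic "layers" are hypothetical placeholders chosen only to show the control flow; the flag name `return_multiple` comes from the comment above.

```python
class TinyNet:
    """Hypothetical three-stage model; the stages are plain arithmetic placeholders."""

    def __init__(self, return_multiple=False):
        self.return_multiple = return_multiple

    def __call__(self, x):
        out1 = x * 2      # stage 1
        out2 = out1 + 3   # stage 2
        out = out2 ** 2   # final stage
        if self.return_multiple:
            # The author decides ahead of time which intermediates to expose.
            return dict(layer1=out1, layer2=out2, out=out)
        return out


final = TinyNet()(1)
features = TinyNet(return_multiple=True)(1)
```

Only the intermediates the author chose to name are reachable this way; any other inner value would require editing the module.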

@alexander-g
Contributor Author

The only alternative I can think of right now is to have an optional flag (e.g. return_multiple) that returns a dictionary with the multiple outputs you would like to expose to the user, e.g.:

This is exactly what I want to avoid. It requires too much thinking ahead, which is difficult, especially if the module was written by another author. Say I want to inspect an arbitrary inner layer of a ResNet: I'd then need to rewrite the call() of this Module.

Can be brittle / silently fail if the author of the Module is not careful.

Could you explain what you mean by this or give an example? What should the author take care of? It is indeed brittle, but rather because of the inner graph logic, which is not yet fully "battle-tested".

@cgarciae
Collaborator

cgarciae commented Feb 20, 2021

This is exactly what I want to avoid. This requires too much ahead-thinking, which is difficult to do especially if the module was written by another author. Say, I want to inspect an arbitrary inner layer of a ResNet, then I'd need to rewrite the call() of this Module

I think we can add this feature as a show-case of what can be done even if only certain Modules support it.

Could you explain what you mean with this or give an example? What should the author take care of? It is indeed brittle but rather because of the inner graph logic which is not yet fully "battle-tested".

If I remember correctly, this works with the hooks.add_summary feature and only works properly if all intermediate operations are added to the summaries, so a casual call to e.g. relu will not be registered and might give an incorrect slice.
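
A small sketch of how an unregistered operation could break the chain, assuming (as described earlier in this thread) that modules are connected by matching the id() of one module's output to the id() of the next module's input. The `recorded` helper and the trace format are hypothetical and only mimic that bookkeeping with Python lists.

```python
# Hypothetical trace: each recorded module call stores (name, id(input), id(output)).
trace = []


def recorded(name, fn, value):
    out = fn(value)
    trace.append((name, id(value), id(out)))
    return out


x = [1.0, -2.0]
h = recorded("/dense_1", lambda v: [2 * e for e in v], x)

# Casual, unrecorded op: it creates a new object that no module produced.
h_relu = [max(e, 0.0) for e in h]
y = recorded("/dense_2", lambda v: [e + 1 for e in v], h_relu)

# For contrast, a module fed directly with a recorded output.
z = recorded("/dense_3", lambda v: [e - 1 for e in v], h)

# Map each output id to the module that produced it, then look up
# which recorded module (if any) produced each module's input.
produced = {out_id: name for name, _, out_id in trace}
producers = {name: produced.get(in_id) for name, in_id, _ in trace}
```

Here `/dense_3` is linked back to `/dense_1`, but `/dense_2` has no known producer because the intermediate relu was never recorded, so a path search from `/dense_1` to `/dense_2` would find nothing.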

@cgarciae
Collaborator

I was told that in PyTorch you do this by taking a slice from Sequential. I like this API because it's simple and easy to implement, since Sequential already has all the machinery to do this.
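
As a rough illustration of that API style, here is a minimal container that supports slicing. The `Sequential` class below is a hypothetical plain-Python sketch, not Elegy's or PyTorch's implementation, and the layers are simple lambdas.

```python
class Sequential:
    """Hypothetical minimal container that applies its layers in order."""

    def __init__(self, *layers):
        self.layers = list(layers)

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

    def __getitem__(self, index):
        # A slice returns a new Sequential sharing the same layer objects.
        if isinstance(index, slice):
            return Sequential(*self.layers[index])
        return self.layers[index]


model = Sequential(lambda x: x + 1, lambda x: x * 10, lambda x: x - 3)
backbone = model[:2]          # keep only the first two layers

full_out = model(1)           # ((1 + 1) * 10) - 3
backbone_out = backbone(1)    # (1 + 1) * 10
```

This only works for strictly linear stacks of layers; branches such as skip connections cannot be expressed as a slice of a list.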

@cgarciae cgarciae closed this Feb 22, 2021
@cgarciae cgarciae reopened this Feb 22, 2021
@alexander-g
Contributor Author

I was told that in PyTorch you do this by taking a slice from Sequential.

Sounds limited.

torchvision creates an additional class IntermediateLayerGetter to extract intermediate feature maps from a ResNet. I don't like this approach at all, it's not simple for the user.

I've started experimenting with a new method based on jax.make_jaxpr. This would allow using arbitrary JAX code like jax.nn.relu(x) or x+1, but start and end targets still need to be Elegy modules.
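
For readers unfamiliar with jax.make_jaxpr, the sketch below shows what it exposes for a small function that mixes jax.nn.relu(x) and x + 1. It is not the experimental implementation mentioned above; the function `f` is a made-up example, and the exact primitive that relu appears under can differ between JAX versions.

```python
import jax
import jax.numpy as jnp


def f(x):
    h = jax.nn.relu(x)   # functional op, not wrapped in any module
    return h + 1         # plain arithmetic


# Trace f with an example input; nothing is actually computed here.
closed_jaxpr = jax.make_jaxpr(f)(jnp.zeros((3,)))

# Each equation records the primitive that was applied and its
# input/output variables, which is enough to rebuild a dataflow graph.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]

print(closed_jaxpr)
print(primitive_names)
```

Note that the equations carry primitive names rather than module names, so any mapping from Elegy module paths to positions in the jaxpr would have to be tracked separately.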

@cgarciae
Collaborator

cgarciae commented Feb 23, 2021

I've started experimenting with a new method based on jax.make_jaxpr. This would allow using arbitrary JAX code like jax.nn.relu(x) or x+1, but start and end targets still need to be Elegy modules.

This seems super useful, can you show how it could look?
My main worry with something so raw is that the names might be difficult to find (I am just guessing).

@alexander-g alexander-g mentioned this pull request Feb 28, 2021