-
Notifications
You must be signed in to change notification settings - Fork 32
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Module Slicing #115
Open
alexander-g
wants to merge
22
commits into
poets-ai:master
Choose a base branch
from
alexander-g:slicing_modules
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Module Slicing #115
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit
Hold shift + click to select a range
e622fde
experimental module slicing
alexander-g 83d2c0c
added networkx dependency
alexander-g 0ffc36f
sliced module is now retrainable
alexander-g c4ec83d
exception handling
alexander-g c525614
refactoring
alexander-g 2f18a9a
resnet50 fix
alexander-g a75c7e3
Slicing with multi-input modules between start_module and end_module …
alexander-g 41fa916
Docs for add_summary
alexander-g 15943e3
black
alexander-g 250e783
Merge branch 'master' into slicing_modules
alexander-g f57d41b
Merge branch 'master' into _slicing
alexander-g aa02024
fixing poetry
alexander-g 3b3d88a
Module.slice()
alexander-g 06cfd4c
Merge branch 'master' into slicing_modules
cgarciae e77a207
circular dependency fix
alexander-g 9a59c93
can now specify inputs as an output target for Module.slice()
alexander-g c89df43
slicing deferred call bugfix
alexander-g c2808a4
Merge branch 'master' into slicing_modules
alexander-g f5d669d
Merge branch 'master' into slicing_modules
alexander-g d908398
update to 0.6.0
alexander-g 330c3e9
Merge branch 'master' into slicing_modules
alexander-g 2c9cf1c
test fixes and black
alexander-g File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ | |
- reset | ||
- init | ||
- initialized | ||
- slice | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ | |
- reset | ||
- init | ||
- initialized | ||
- slice | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,4 +13,5 @@ | |
- reset | ||
- init | ||
- initialized | ||
- slice | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,9 @@ | |
from elegy.random import RNG | ||
from elegy.utils import EMPTY, Empty, Mode, ModuleOrderError | ||
|
||
# imported later because of a circular dependency | ||
# from elegy.module_slicing import slice_module_from_to | ||
|
||
__all__ = [ | ||
"Module", | ||
"to_module", | ||
|
@@ -244,6 +247,7 @@ class Module(metaclass=ModuleMeta): | |
"reset", | ||
"init", | ||
"initialized", | ||
"slice", | ||
] | ||
|
||
def __init__(self, name: tp.Optional[str] = None, dtype: np.dtype = jnp.float32): | ||
|
@@ -338,7 +342,7 @@ def __call__(self, *args, **kwargs) -> tp.Any: | |
else: | ||
outputs = self.call(*args, **kwargs) | ||
|
||
add_summary(self, outputs) | ||
add_summary(self, outputs, (args, kwargs)) | ||
|
||
return outputs | ||
|
||
|
@@ -578,30 +582,83 @@ def states_bytes(self, include_submodules: bool = True): | |
) | ||
) | ||
|
||
def slice( | ||
self, | ||
start_module: tp.Union["Module", str, None], | ||
end_module: tp.Union[ | ||
"Module", str, None, tp.List[tp.Union["Module", str, None]] | ||
], | ||
sample_input: np.ndarray, | ||
) -> "Module": | ||
""" | ||
Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`. | ||
|
||
Current limitations: | ||
|
||
- all operations between `start_module` and `end_module` must be performed by modules | ||
i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()` | ||
- only one `start_module` is supported | ||
- all modules between `start_module` and `end_module` must have a single output | ||
|
||
|
||
Example usage: | ||
``` | ||
x = jnp.zeros((2, 224, 224, 3)) | ||
resnet = elegy.nets.resnet.ResNet18() | ||
submodule = resnet.slice( | ||
start_module=None, | ||
end_module=["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7" ], | ||
sample_input=x, | ||
) | ||
outputs = elegy.Model(submodule).predict(x) | ||
assert outputs[0].shape == (2, 56, 56, 64) | ||
assert outputs[1].shape == (2, 28, 28, 128) | ||
assert outputs[2].shape == (2, 14, 14, 256) | ||
assert outputs[3].shape == (2, 7, 7, 512) | ||
``` | ||
|
||
Arguments: | ||
start_module: Child module or name of a child module which will be the input module of the resulting module. | ||
If `None`, the first module is used. | ||
end_module: Child module, name of child module, `None` or a list thereof which will be the output module(s) of the resulting module. | ||
If `None`, the last module is used. | ||
sample_input: An array representing a sample input to the parent module. | ||
""" | ||
# importing here because of a circular dependency | ||
from elegy.module_slicing import slice_module_from_to | ||
|
||
return slice_module_from_to(self, start_module, end_module, sample_input) | ||
|
||
|
||
# ------------------------------------------------------------- | ||
# hooks | ||
# ------------------------------------------------------------- | ||
|
||
|
||
def add_summary(module_or_name: tp.Union[Module, str], value: np.ndarray) -> None: | ||
def add_summary( | ||
module_or_name: tp.Union[Module, str], | ||
value: np.ndarray, | ||
input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None, | ||
) -> None: | ||
""" | ||
A hook that lets you define a summary in the current module. Its primary | ||
use is to keep track of certain values as they flow through the network | ||
so [`Model.summary`][elegy.model.model.Model.summary] can show a representation of architecture. | ||
so [`Model.summary`][elegy.model.model.Model.summary] can show a representation of architecture | ||
and to get the graph structure to slice modules. | ||
|
||
```python | ||
def call(self, x): | ||
... | ||
y = jax.nn.relu(x) | ||
elegy.add_summary("relu", y) | ||
elegy.add_summary("relu", y, ((x,), {})) | ||
... | ||
``` | ||
|
||
Arguments: | ||
module_or_name: The name of the summary or alternatively the module that this summary will represent. | ||
If a summary with the same name already exists a unique identifier will be generated. | ||
value: The value for the summary. | ||
input_values: The input arguments for the module, required for slicing. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you elaborate on the structure of the tuple? |
||
""" | ||
|
||
if LOCAL.summaries is None: | ||
|
@@ -616,7 +673,7 @@ def call(self, x): | |
else: | ||
module = module_or_name | ||
|
||
LOCAL.summaries.append((module, name, value)) | ||
LOCAL.summaries.append((module, name, value, input_values)) | ||
|
||
|
||
def add_loss(name: str, value: np.ndarray) -> None: | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
based on your comment this is not a limitation now?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still is.
It's now possible to have inner modules that have multiple inputs but the result module still must have only one input. Single output limitation also holds.