Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Module Slicing #115

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/Module.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
- reset
- init
- initialized
- slice

1 change: 1 addition & 0 deletions docs/api/module/Module.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
- reset
- init
- initialized
- slice

1 change: 1 addition & 0 deletions docs/api/nn/Sequential.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
- reset
- init
- initialized
- slice

2 changes: 1 addition & 1 deletion elegy/model/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,7 +454,7 @@ def format_size(size):

table: tp.List = [["Inputs", format_output(x), "0", "0"]]

for module, base_name, value in summaries:
for module, base_name, value, _ in summaries:
base_name_parts = base_name.split("/")[1:]
module_depth = len(base_name_parts)

Expand Down
67 changes: 62 additions & 5 deletions elegy/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from elegy.random import RNG
from elegy.utils import EMPTY, Empty, Mode, ModuleOrderError

# imported later because of a circular dependency
# from elegy.module_slicing import slice_module_from_to

__all__ = [
"Module",
"to_module",
Expand Down Expand Up @@ -244,6 +247,7 @@ class Module(metaclass=ModuleMeta):
"reset",
"init",
"initialized",
"slice",
]

def __init__(self, name: tp.Optional[str] = None, dtype: np.dtype = jnp.float32):
Expand Down Expand Up @@ -338,7 +342,7 @@ def __call__(self, *args, **kwargs) -> tp.Any:
else:
outputs = self.call(*args, **kwargs)

add_summary(self, outputs)
add_summary(self, outputs, (args, kwargs))

return outputs

Expand Down Expand Up @@ -578,30 +582,83 @@ def states_bytes(self, include_submodules: bool = True):
)
)

def slice(
self,
start_module: tp.Union["Module", str, None],
end_module: tp.Union[
"Module", str, None, tp.List[tp.Union["Module", str, None]]
],
sample_input: np.ndarray,
) -> "Module":
"""
Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`.

Current limitations:

- all operations between `start_module` and `end_module` must be performed by modules
i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- only one `start_module` is supported
- all modules between `start_module` and `end_module` must have a single output
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

based on your comment this is not a limitation now?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still is.
It's now possible to have inner modules that have multiple inputs but the result module still must have only one input. Single output limitation also holds.



Example usage:
```
x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
submodule = resnet.slice(
start_module=None,
end_module=["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7" ],
sample_input=x,
)
outputs = elegy.Model(submodule).predict(x)
assert outputs[0].shape == (2, 56, 56, 64)
assert outputs[1].shape == (2, 28, 28, 128)
assert outputs[2].shape == (2, 14, 14, 256)
assert outputs[3].shape == (2, 7, 7, 512)
```

Arguments:
start_module: Child module or name of a child module which will be the input module of the resulting module.
If `None`, the first module is used.
end_module: Child module, name of child module, `None` or a list thereof which will be the output module(s) of the resulting module.
If `None`, the last module is used.
sample_input: An array representing a sample input to the parent module.
"""
# importing here because of a circular dependency
from elegy.module_slicing import slice_module_from_to

return slice_module_from_to(self, start_module, end_module, sample_input)


# -------------------------------------------------------------
# hooks
# -------------------------------------------------------------


def add_summary(module_or_name: tp.Union[Module, str], value: np.ndarray) -> None:
def add_summary(
module_or_name: tp.Union[Module, str],
value: np.ndarray,
input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None,
) -> None:
"""
A hook that lets you define a summary in the current module. Its primary
use is to keep track of certain values as they flow through the network
so [`Model.summary`][elegy.model.model.Model.summary] can show a representation of architecture.
so [`Model.summary`][elegy.model.model.Model.summary] can show a representation of architecture
and to get the graph structure to slice modules.

```python
def call(self, x):
...
y = jax.nn.relu(x)
elegy.add_summary("relu", y)
elegy.add_summary("relu", y, ((x,), {}))
...
```

Arguments:
module_or_name: The name of the summary or alternatively the module that this summary will represent.
If a summary with the same name already exists a unique identifier will be generated.
value: The value for the summary.
input_values: The input arguments for the module, required for slicing.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you elaborate on the structure of the tuple?

"""

if LOCAL.summaries is None:
Expand All @@ -616,7 +673,7 @@ def call(self, x):
else:
module = module_or_name

LOCAL.summaries.append((module, name, value))
LOCAL.summaries.append((module, name, value, input_values))


def add_loss(name: str, value: np.ndarray) -> None:
Expand Down
Loading