Skip to content

Commit

Permalink
Module.slice()
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-g committed Dec 23, 2020
1 parent aa02024 commit 3b3d88a
Show file tree
Hide file tree
Showing 6 changed files with 61 additions and 27 deletions.
1 change: 1 addition & 0 deletions docs/api/Module.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
- reset
- init
- initialized
- slice

1 change: 1 addition & 0 deletions docs/api/module/Module.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
- reset
- init
- initialized
- slice

1 change: 1 addition & 0 deletions docs/api/nn/Sequential.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@
- reset
- init
- initialized
- slice

51 changes: 51 additions & 0 deletions elegy/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
from elegy.random import RNG
from elegy.utils import EMPTY, Empty, Mode, ModuleOrderError

# imported later because of a circular dependency
# from elegy.module_slicing import slice_module_from_to

__all__ = [
"Module",
"to_module",
Expand Down Expand Up @@ -244,6 +247,7 @@ class Module(metaclass=ModuleMeta):
"reset",
"init",
"initialized",
"slice",
]

def __init__(self, name: tp.Optional[str] = None, dtype: np.dtype = jnp.float32):
Expand Down Expand Up @@ -572,6 +576,53 @@ def states_bytes(self, include_submodules: bool = True):
)
)

def slice(
self,
start_module: tp.Union["Module", str, None],
end_module: tp.Union[
"Module", str, None, tp.List[tp.Union["Module", str, None]]
],
sample_input: np.ndarray,
) -> "Module":
"""
Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`.
Current limitations:
- all operations between `start_module` and `end_module` must be performed by modules
i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- only one `start_module` is supported
- all modules between `start_module` and `end_module` must have a single output
Example usage:
```
x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
submodule = resnet.slice(
start_module=None,
end_module=["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7" ],
sample_input=x,
)
outputs = elegy.Model(submodule).predict(x)
assert outputs[0].shape == (2, 56, 56, 64)
assert outputs[1].shape == (2, 28, 28, 128)
assert outputs[2].shape == (2, 14, 14, 256)
assert outputs[3].shape == (2, 7, 7, 512)
```
Arguments:
start_module: Child module or name of a child module which will be the input module of the resulting module.
If `None`, the first module is used.
end_module: Child module, name of child module, `None` or a list thereof which will be the output module(s) of the resulting module.
If `None`, the last module is used.
sample_input: An array representing a sample input to the parent module.
"""
# importing here because of a circular dependency
from elegy.module_slicing import slice_module_from_to

return slice_module_from_to(self, start_module, end_module, sample_input)


# -------------------------------------------------------------
# hooks
Expand Down
11 changes: 1 addition & 10 deletions elegy/module_slicing.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,18 @@
import networkx as nx
import elegy
from elegy import Module
from elegy.module import Module
import jax
import itertools
import typing as tp
import numpy as np

__all__ = ["slice_module_from_to"]


def slice_module_from_to(
module: Module,
start_module: tp.Union[Module, str, None],
end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]],
sample_input: np.ndarray,
) -> Module:
"""Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`.
Current limitations:
- only one `start_module` is supported
- all operations between `start_module` and `end_module` must be performed by modules
i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()`
- all modules between `start_module` and `end_module` must have a single output
"""
assert not isinstance(
start_module, (tp.Tuple, tp.List)
), "Multiple inputs not yet supported"
Expand Down
23 changes: 6 additions & 17 deletions elegy/module_slicing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@ def test_basic_slice_by_ref(self):
x = jnp.zeros((32, 100))
basicmodule = BasicModule0()
basicmodule(x) # trigger creation of weights and submodules
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, basicmodule.linear0, basicmodule.linear1, x
)
submodule = basicmodule.slice(basicmodule.linear0, basicmodule.linear1, x)
submodel = elegy.Model(submodule)
submodel.summary(x)
assert submodel.predict(x).shape == (32, 10)
Expand All @@ -24,9 +22,7 @@ def test_basic_slice_by_name(self):
for start, end in START_END_COMBOS:
print(start, end)
basicmodule = BasicModule0()
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, start, end, x
)
submodule = basicmodule.slice(start, end, x)
submodel = elegy.Model(submodule)
submodel.summary(x)
assert submodel.predict(x).shape == (32, 10)
Expand All @@ -35,8 +31,7 @@ def test_basic_slice_by_name(self):
def test_resnet_multi_out(self):
x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
submodule = elegy.module_slicing.slice_module_from_to(
resnet,
submodule = resnet.slice(
start_module=None,
end_module=[
"/res_net_block_1",
Expand Down Expand Up @@ -66,9 +61,7 @@ def test_retrain(self):
y = jnp.zeros((32, 10))

basicmodule = BasicModule0()
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, "linear0", "linear1", x
)
submodule = basicmodule.slice("linear0", "linear1", x)
submodel = elegy.Model(
submodule,
loss=elegy.losses.MeanAbsoluteError(),
Expand All @@ -91,9 +84,7 @@ def test_no_path(self):
basicmodule = BasicModule0()
for start_module in ["linear2", "linear1"]:
try:
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, start_module, "linear0", x
)
submodule = basicmodule.slice(start_module, "linear0", x)
except RuntimeError as e:
assert e.args[0].startswith(f"No path from /{start_module} to /linear0")
else:
Expand All @@ -106,9 +97,7 @@ def test_multi_input_modules(self):
model = elegy.Model(module)
model.summary(x)

submodule = elegy.module_slicing.slice_module_from_to(
module, None, "/multi_input_module", x
)
submodule = module.slice(None, "/multi_input_module", x)
submodel = elegy.Model(submodule)
submodel.summary(x)
print(submodule.get_parameters())
Expand Down

0 comments on commit 3b3d88a

Please sign in to comment.