From 3b3d88afb3446208a45c55a04b6de0fb936a6399 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 23 Dec 2020 17:30:41 +0100 Subject: [PATCH] Module.slice() --- docs/api/Module.md | 1 + docs/api/module/Module.md | 1 + docs/api/nn/Sequential.md | 1 + elegy/module.py | 51 ++++++++++++++++++++++++++++++++++++ elegy/module_slicing.py | 11 +------- elegy/module_slicing_test.py | 23 +++++----------- 6 files changed, 61 insertions(+), 27 deletions(-) diff --git a/docs/api/Module.md b/docs/api/Module.md index bb0cdb27..0304f149 100644 --- a/docs/api/Module.md +++ b/docs/api/Module.md @@ -13,4 +13,5 @@ - reset - init - initialized + - slice \ No newline at end of file diff --git a/docs/api/module/Module.md b/docs/api/module/Module.md index 4fce8e01..05526de1 100644 --- a/docs/api/module/Module.md +++ b/docs/api/module/Module.md @@ -13,4 +13,5 @@ - reset - init - initialized + - slice \ No newline at end of file diff --git a/docs/api/nn/Sequential.md b/docs/api/nn/Sequential.md index 660c84cd..c0ae3fe0 100644 --- a/docs/api/nn/Sequential.md +++ b/docs/api/nn/Sequential.md @@ -13,4 +13,5 @@ - reset - init - initialized + - slice \ No newline at end of file diff --git a/elegy/module.py b/elegy/module.py index 84f6edff..652eebb7 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -17,6 +17,9 @@ from elegy.random import RNG from elegy.utils import EMPTY, Empty, Mode, ModuleOrderError +# imported later because of a circular dependency +# from elegy.module_slicing import slice_module_from_to + __all__ = [ "Module", "to_module", @@ -244,6 +247,7 @@ class Module(metaclass=ModuleMeta): "reset", "init", "initialized", + "slice", ] def __init__(self, name: tp.Optional[str] = None, dtype: np.dtype = jnp.float32): @@ -572,6 +576,53 @@ def states_bytes(self, include_submodules: bool = True): ) ) + def slice( + self, + start_module: tp.Union["Module", str, None], + end_module: tp.Union[ + "Module", str, None, tp.List[tp.Union["Module", str, None]] + ], + sample_input: np.ndarray, + ) -> "Module": + """ + Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`. + + Current limitations: + + - all operations between `start_module` and `end_module` must be performed by modules + i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()` + - only one `start_module` is supported + - all modules between `start_module` and `end_module` must have a single output + + + Example usage: + ``` + x = jnp.zeros((2, 224, 224, 3)) + resnet = elegy.nets.resnet.ResNet18() + submodule = resnet.slice( + start_module=None, + end_module=["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7" ], + sample_input=x, + ) + outputs = elegy.Model(submodule).predict(x) + assert outputs[0].shape == (2, 56, 56, 64) + assert outputs[1].shape == (2, 28, 28, 128) + assert outputs[2].shape == (2, 14, 14, 256) + assert outputs[3].shape == (2, 7, 7, 512) + ``` + + Arguments: + start_module: Child module or name of a child module which will be the input module of the resulting module. + If `None`, the first module is used. + end_module: Child module, name of child module, `None` or a list thereof which will be the output module(s) of the resulting module. + If `None`, the last module is used. + sample_input: An array representing a sample input to the parent module. + """ + # importing here because of a circular dependency + from elegy.module_slicing import slice_module_from_to + + return slice_module_from_to(self, start_module, end_module, sample_input) + # ------------------------------------------------------------- # hooks diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index a3fd7519..9c4b8b93 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -1,13 +1,11 @@ import networkx as nx import elegy -from elegy import Module +from elegy.module import Module import jax import itertools import typing as tp import numpy as np -__all__ = ["slice_module_from_to"] - def slice_module_from_to( module: Module, @@ -15,13 +13,6 @@ def slice_module_from_to( end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]], sample_input: np.ndarray, ) -> Module: - """Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`. - Current limitations: - - only one `start_module` is supported - - all operations between `start_module` and `end_module` must be performed by modules - i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()` - - all modules between `start_module` and `end_module` must have a single output - """ assert not isinstance( start_module, (tp.Tuple, tp.List) ), "Multiple inputs not yet supported" diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index ba58ee97..3eadfd96 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -10,9 +10,7 @@ def test_basic_slice_by_ref(self): x = jnp.zeros((32, 100)) basicmodule = BasicModule0() basicmodule(x) # trigger creation of weights and submodules - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, basicmodule.linear0, basicmodule.linear1, x - ) + submodule = basicmodule.slice(basicmodule.linear0, basicmodule.linear1, x) submodel = elegy.Model(submodule) submodel.summary(x) assert submodel.predict(x).shape == (32, 10) @@ -24,9 +22,7 @@ def test_basic_slice_by_name(self): for start, end in START_END_COMBOS: print(start, end) basicmodule = BasicModule0() - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, start, end, x - ) + submodule = basicmodule.slice(start, end, x) submodel = elegy.Model(submodule) submodel.summary(x) assert submodel.predict(x).shape == (32, 10) @@ -35,8 +31,7 @@ def test_basic_slice_by_name(self): def test_resnet_multi_out(self): x = jnp.zeros((2, 224, 224, 3)) resnet = elegy.nets.resnet.ResNet18() - submodule = elegy.module_slicing.slice_module_from_to( - resnet, + submodule = resnet.slice( start_module=None, end_module=[ "/res_net_block_1", @@ -66,9 +61,7 @@ def test_retrain(self): y = jnp.zeros((32, 10)) basicmodule = BasicModule0() - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, "linear0", "linear1", x - ) + submodule = basicmodule.slice("linear0", "linear1", x) submodel = elegy.Model( submodule, loss=elegy.losses.MeanAbsoluteError(), @@ -91,9 +84,7 @@ def test_no_path(self): basicmodule = BasicModule0() for start_module in ["linear2", "linear1"]: try: - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, start_module, "linear0", x - ) + submodule = basicmodule.slice(start_module, "linear0", x) except RuntimeError as e: assert e.args[0].startswith(f"No path from /{start_module} to /linear0") else: @@ -106,9 +97,7 @@ def test_multi_input_modules(self): model = elegy.Model(module) model.summary(x) - submodule = elegy.module_slicing.slice_module_from_to( - module, None, "/multi_input_module", x - ) + submodule = module.slice(None, "/multi_input_module", x) submodel = elegy.Model(submodule) submodel.summary(x) print(submodule.get_parameters())