From e622fdeba3d01fb9a5915344139c091adffd56e9 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Mon, 23 Nov 2020 17:20:17 +0100 Subject: [PATCH 01/16] experimental module slicing --- elegy/model/model_base.py | 2 +- elegy/module.py | 8 +- elegy/module_slicing.py | 193 +++++++++++++++++++++++++++++++++++ elegy/module_slicing_test.py | 75 ++++++++++++++ elegy/nets/resnet.py | 55 +++++++--- 5 files changed, 313 insertions(+), 20 deletions(-) create mode 100644 elegy/module_slicing.py create mode 100644 elegy/module_slicing_test.py diff --git a/elegy/model/model_base.py b/elegy/model/model_base.py index be1d8f42..8f4ca582 100644 --- a/elegy/model/model_base.py +++ b/elegy/model/model_base.py @@ -436,7 +436,7 @@ def format_size(size): table: tp.List = [["Inputs", format_output(x), "0", "0"]] - for module, base_name, value in summaries: + for module, base_name, value, _ in summaries: base_name_parts = base_name.split("/")[1:] module_depth = len(base_name_parts) diff --git a/elegy/module.py b/elegy/module.py index 8d589a6a..3228232f 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -331,7 +331,7 @@ def __call__(self, *args, **kwargs) -> tp.Any: else: outputs = self.call(*args, **kwargs) - add_summary(self, outputs) + add_summary(self, outputs, (args, kwargs)) return outputs @@ -577,7 +577,9 @@ def states_bytes(self, include_submodules: bool = True): # ------------------------------------------------------------- -def add_summary(module_or_name: tp.Union[Module, str], value: np.ndarray) -> None: +def add_summary( + module_or_name: tp.Union[Module, str], value: np.ndarray, input_values=None +) -> None: """ A hook that lets you define a summary in the current module. Its primary use is to keep track of certain values as they flow through the network @@ -609,7 +611,7 @@ def call(self, x): else: module = module_or_name - LOCAL.summaries.append((module, name, value)) + LOCAL.summaries.append((module, name, value, input_values)) def add_loss(name: str, value: np.ndarray) -> None: diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py new file mode 100644 index 00000000..7a4ed9d4 --- /dev/null +++ b/elegy/module_slicing.py @@ -0,0 +1,193 @@ +import networkx as nx +import elegy +from elegy import Module +import jax +import itertools +import typing as tp +import numpy as np + +__all__ = ["slice_module_from_to"] + + +def slice_module_from_to( + module: Module, + start_module: tp.Union[Module, str, None], + end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]], + sample_input: np.ndarray, +) -> Module: + """Creates a new submodule starting from the input of 'start_module' to the outputs of 'end_module'. + Current limitations: + - only one input module is supported + - all operations between start_module and end_module must be performed by modules + i.e. jax.nn.relu() or x+1 is not allowed but can be converted by wrapping with elegy.to_module() + - all modules between start_module and end_module must have a single input and a single output + - resulting module is currently not trainable + """ + assert not isinstance( + start_module, (tp.Tuple, tp.List) + ), "Multiple inputs not yet supported" + + # get info about the module structure via summaries + model = elegy.Model(module) + with elegy.hooks_context(summaries=True): + model.predict_fn(sample_input) + summaries = elegy.get_summaries() + + edges = [Edge(summ) for summ in summaries] + start_id = get_input_id(edges, start_module) + if not isinstance(end_module, (tp.Tuple, tp.List)): + end_module = [end_module] + end_ids = [get_output_id(edges, m) for m in end_module] + + graph = construct_graph(edges) + paths = [find_path(graph, start_id, end_id) for end_id in end_ids] + tree = combine_paths(paths) + submodule_call = construct_call(tree) + submodule = elegy.to_module(submodule_call)() + return submodule + + +class Edge: + """A struct to hold edge data""" + + def __init__(self, summary: tp.Tuple[Module, str, np.ndarray, tp.Any]): + self.module = summary[0] + # remove the full module name, leave the leading '/' + self.modulename = ( + summary[1][summary[1].find("/") :] if "/" in summary[1] else "/" + ) + # convert the output and input arrays in the summary to unique IDs as returned by id() + self.output_ids = jax.tree_leaves(jax.tree_map(id, summary[2])) + self.input_ids = jax.tree_map(id, summary[3]) + + +def search_edges( + edges: tp.List[Edge], searchtarget: tp.Union[Module, str, None] +) -> Edge: + """Searches 'edges' for 'searchtarget' which can be a module, name of a module or None""" + if searchtarget is None: + # None means input/output of the full module, which is the last edge + return edges[-1] + elif isinstance(searchtarget, str): + # search by name, with or without leading '/' + if not searchtarget.startswith("/"): + searchtarget = "/" + searchtarget + edges = [e for e in edges if e.modulename == searchtarget] + elif isinstance(searchtarget, Module): + # search by reference + edges = [e for e in edges if e.module == searchtarget] + assert len(edges) > 0, f"Could not find module {searchtarget}" + assert len(edges) < 2, f"Found {len(edges)} modules for {searchtarget}" + return edges[0] + + +def get_input_id(edges: tp.List[Edge], module: tp.Union[Module, str, None]) -> int: + """Searches for module in the list of edges and returns the ID of its input array""" + edge = search_edges(edges, module) + input_ids = jax.tree_leaves(edge.input_ids) + assert len(input_ids) == 1, "Multi-input modules not yet supported" + return input_ids[0] + + +def get_output_id(edges: tp.List[Edge], module: tp.Union[Module, str, None]) -> int: + """Searches for module in the list of edges and returns the ID of its output array""" + edge = search_edges(edges, module) + assert len(edge.output_ids) == 1, "Multi-output modules not yet supported" + return edge.output_ids[0] + + +def merge_args_kwargs(*args, **kwargs) -> tp.List[tp.Tuple[tp.Any, tp.Any]]: + """Merges args and kwargs and their indices to a list of tuples + e.g. merge_args_kwargs(0, 77, a=-2) returns [(0,0), (1,77), ('a',-2)]""" + return list(enumerate(args)) + list(kwargs.items()) + + +def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: + """Constructs a directed graph with IDs of input/output arrays representing the nodes + and modules (and some more infos) representing the edges""" + G = nx.DiGraph() + for e in edges: + merged_args_kwargs = merge_args_kwargs(*e.input_ids[0], **e.input_ids[1]) + inout_combos = itertools.product(merged_args_kwargs, enumerate(e.output_ids)) + for ((inkey, input_id), (outkey, output_id)) in inout_combos: + depth = e.modulename.count("/") + # it can happen that there are multiple connections between two nodes + # e.g. when a simple parent module has only one child module + # use the one with the lowest depth, i.e. the parent module + if ((input_id, output_id) not in G.edges) or ( + G[input_id][output_id].depth > depth + ): + G.add_edge( + input_id, + output_id, + inkey=inkey, + outkey=outkey, + depth=depth, + **e.__dict__, + ) + return G + + +def find_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph: + """Returns a new graph with only nodes and edges from start_node to end_node""" + # TODO: catch exceptions + pathnodes = nx.shortest_path(graph, start_node, end_node) + pathgraph = graph.subgraph(pathnodes).copy() + # pathgraph is unordered, need to mark input and output edges + pathgraph[pathnodes[0]][pathnodes[1]]["is_input"] = True + pathgraph[pathnodes[-2]][pathnodes[-1]]["is_output"] = True + return pathgraph + + +def combine_paths(paths: tp.List[nx.DiGraph]) -> nx.DiGraph: + return nx.algorithms.compose_all(paths) + + +def construct_call(tree: nx.DiGraph) -> tp.Callable: + """Returns a new function that represents the __call__ of the new sliced submodule""" + + def visit_edge(edge, x, next_node): + assert edge["inkey"] == 0, "inputs other than 0 not yet implemented" + x = edge["module"](x) + + if isinstance(x, (tuple, list)): + # XXX: what if the whole tuple/list is needed as input later? + x = x[edge["outkey"]] + + outputs = [] + if edge.get("is_output", False): + outputs.append(x) + + if len(tree[next_node]): + # continue traversing the graph if there are more edges + for next_node, next_edge in tree[next_node].items(): + nextx = visit_edge(next_edge, x, next_node) + if not isinstance(nextx, tp.Tuple): + nextx = (nextx,) + outputs.extend(nextx) + # else: no more edges + + outputs = tuple(outputs) + if len(outputs) == 1: + outputs = outputs[0] + return outputs + + def call(x, *args, **kwargs): + input_nodes = [ + nodes[0] + for nodes, edge in tree.edges.items() + if edge.get("is_input", False) + ] + assert len(set(input_nodes)), "multi-inputs not yet supported" + start_node = input_nodes[0] + + x = [ + visit_edge(next_edge, x, next_node) + for next_node, next_edge in tree[start_node].items() + ] + x = tuple(x) + if len(x) == 1: + x = x[0] + return x + + return call diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py new file mode 100644 index 00000000..ddd33654 --- /dev/null +++ b/elegy/module_slicing_test.py @@ -0,0 +1,75 @@ +import elegy +import elegy.module_slicing +from unittest import TestCase +import jax, jax.numpy as jnp + + +class ModuleSlicingTest(TestCase): + def test_basic_slice_by_ref(self): + x = jnp.zeros((32, 100)) + basicmodule = BasicModule0() + basicmodule(x) # trigger creation of weights and submodules + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, basicmodule.linear0, basicmodule.linear1, x + ) + submodel = elegy.Model(submodule) + submodel.summary(x) + assert submodel.predict(x).shape == (32, 10) + assert jnp.all(submodel.predict(x) == basicmodule.test_call(x)) + + def test_basic_slice_by_name(self): + x = jnp.zeros((32, 100)) + START_END_COMBOS = [("linear0", "linear1"), (None, "/linear1")] + for start, end in START_END_COMBOS: + print(start, end) + basicmodule = BasicModule0() + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, start, end, x + ) + submodel = elegy.Model(submodule) + submodel.summary(x) + assert submodel.predict(x).shape == (32, 10) + assert jnp.all(submodel.predict(x) == basicmodule.test_call(x)) + + def test_resnet_multi_out(self): + x = jnp.zeros((2, 224, 224, 3)) + resnet = elegy.nets.resnet.ResNet18() + submodule = elegy.module_slicing.slice_module_from_to( + resnet, + start_module=None, + end_module=[ + "/res_net_block_1", + "/res_net_block_3", + "/res_net_block_5", + "/res_net_block_6", + "/res_net_block_7", + ], + sample_input=x, + ) + submodel = elegy.Model(submodule) + # submodel.summary(x) + outputs = submodel.predict(x) + print(jax.tree_map(jnp.shape, outputs)) + assert len(outputs) == 5 + assert outputs[0].shape == (2, 56, 56, 64) + assert outputs[1].shape == (2, 28, 28, 128) + assert outputs[2].shape == (2, 14, 14, 256) + assert outputs[3].shape == (2, 7, 7, 512) + assert outputs[4].shape == (2, 7, 7, 512) + + print(jax.tree_map(jnp.shape, resnet.get_parameters())) + print(jax.tree_map(jnp.shape, submodel.get_parameters())) + # assert False + + +class BasicModule0(elegy.Module): + def call(self, x): + x = elegy.nn.Linear(25, name="linear0")(x) + x = elegy.nn.Linear(10, name="linear1")(x) + x = elegy.nn.Linear(5, name="linear2")(x) + return x + + def test_call(self, x): + x = self.linear0(x) + x = self.linear1(x) + return x diff --git a/elegy/nets/resnet.py b/elegy/nets/resnet.py index ce1d6940..7879ef3f 100644 --- a/elegy/nets/resnet.py +++ b/elegy/nets/resnet.py @@ -7,44 +7,65 @@ class ResNetBlock(module.Module): """ResNet (identity) block""" - def call(self, x, n_filters, strides=(1, 1)): + def __init__(self, n_filters, strides=(1, 1), *args, **kwargs): + super().__init__(*args, **kwargs) + self.n_filters = n_filters + self.strides = strides + + def call(self, x): x0 = x x = nn.Conv2D( - n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters, + (3, 3), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) x = jax.nn.relu(x) - x = nn.Conv2D(n_filters, (3, 3), with_bias=False, dtype=self.dtype)(x) + x = nn.Conv2D(self.n_filters, (3, 3), with_bias=False, dtype=self.dtype)(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) if x0.shape != x.shape: x0 = nn.Conv2D( - n_filters, (1, 1), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters, + (1, 1), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x0) x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0) return jax.nn.relu(x0 + x) -class BottleneckResNetBlock(module.Module): +class BottleneckResNetBlock(ResNetBlock): """ResNet Bottleneck block.""" def call(self, x, n_filters, strides=(1, 1)): x0 = x - x = nn.Conv2D(n_filters, (1, 1), with_bias=False, dtype=self.dtype)(x) + x = nn.Conv2D(self.n_filters, (1, 1), with_bias=False, dtype=self.dtype)(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) x = jax.nn.relu(x) x = nn.Conv2D( - n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters, + (3, 3), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) x = jax.nn.relu(x) - x = nn.Conv2D(n_filters * 4, (1, 1), with_bias=False, dtype=self.dtype)(x) + x = nn.Conv2D(self.n_filters * 4, (1, 1), with_bias=False, dtype=self.dtype)(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5, scale_init=jnp.zeros)(x) if x0.shape != x.shape: x0 = nn.Conv2D( - n_filters * 4, (1, 1), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters * 4, + (1, 1), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x0) x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0) return jax.nn.relu(x0 + x) @@ -63,18 +84,20 @@ def call(self, x): 64, (7, 7), stride=(2, 2), padding="SAME", with_bias=False, dtype=self.dtype )(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) - x = jax.nn.relu(x) + x = module.to_module(jax.nn.relu)()(x) - x = nn.linear.hk.max_pool( - x, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME" - ) + x = module.to_module( + lambda _x: nn.linear.hk.max_pool( + _x, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME" + ) + )()(x) for i, block_size in enumerate(self.stages): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) - x = self.block_type(dtype=self.dtype)(x, 64 * 2 ** i, strides=strides) - x = jnp.mean(x, axis=(1, 2)) + x = self.block_type(64 * 2 ** i, strides=strides, dtype=self.dtype)(x) + x = module.to_module(lambda _x: jnp.mean(_x, axis=(1, 2)))()(x) x = nn.Linear(1000, dtype=self.dtype)(x) - x = jnp.asarray(x, jnp.float32) + x = module.to_module(lambda _x: jnp.asarray(_x, jnp.float32))()(x) return x From 83d2c0c3369392f0f1d8ea5bc0055d41a66e1db6 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 25 Nov 2020 08:34:53 +0100 Subject: [PATCH 02/16] added networkx dependency --- poetry.lock | 36 ++++++++++++++++++++++++++++++++++-- pyproject.toml | 1 + 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/poetry.lock b/poetry.lock index 1c375cf9..c296c1fd 100644 --- a/poetry.lock +++ b/poetry.lock @@ -289,7 +289,7 @@ python-versions = ">=2.7,!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*" name = "decorator" version = "4.4.2" description = "Decorators for Humans" -category = "dev" +category = "main" optional = false python-versions = ">=2.6, !=3.0.*, !=3.1.*" @@ -1056,6 +1056,30 @@ traitlets = ">=4.1" [package.extras] test = ["pytest", "pytest-cov", "testpath"] +[[package]] +name = "networkx" +version = "2.5" +description = "Python package for creating and manipulating graphs and networks" +category = "main" +optional = false +python-versions = ">=3.6" + +[package.dependencies] +decorator = ">=4.3.0" + +[package.extras] +all = ["numpy", "scipy", "pandas", "matplotlib", "pygraphviz", "pydot", "pyyaml", "lxml", "pytest"] +gdal = ["gdal"] +lxml = ["lxml"] +matplotlib = ["matplotlib"] +numpy = ["numpy"] +pandas = ["pandas"] +pydot = ["pydot"] +pygraphviz = ["pygraphviz"] +pytest = ["pytest"] +pyyaml = ["pyyaml"] +scipy = ["scipy"] + [[package]] name = "nltk" version = "3.5" @@ -1925,7 +1949,7 @@ testing = ["jaraco.itertools", "func-timeout"] [metadata] lock-version = "1.1" python-versions = "^3.6.1" -content-hash = "76722876470254aec7faca8ce53381def647fb96d894408d785e858d9aa34609" +content-hash = "24c692af727d8acc8161a3a99abc13bf1ff48401705ac2bfcaa93de5884031c4" [metadata.files] absl-py = [ @@ -2526,6 +2550,10 @@ nbformat = [ {file = "nbformat-5.0.7-py3-none-any.whl", hash = "sha256:ea55c9b817855e2dfcd3f66d74857342612a60b1f09653440f4a5845e6e3523f"}, {file = "nbformat-5.0.7.tar.gz", hash = "sha256:54d4d6354835a936bad7e8182dcd003ca3dc0cedfee5a306090e04854343b340"}, ] +networkx = [ + {file = "networkx-2.5-py3-none-any.whl", hash = "sha256:8c5812e9f798d37c50570d15c4a69d5710a18d77bafc903ee9c5fba7454c616c"}, + {file = "networkx-2.5.tar.gz", hash = "sha256:7978955423fbc9639c10498878be59caf99b44dc304c2286162fd24b458c1602"}, +] nltk = [ {file = "nltk-3.5.zip", hash = "sha256:845365449cd8c5f9731f7cb9f8bd6fd0767553b9d53af9eb1b3abf7700936b35"}, ] @@ -2822,6 +2850,8 @@ pyyaml = [ {file = "PyYAML-5.3.1-cp37-cp37m-win_amd64.whl", hash = "sha256:73f099454b799e05e5ab51423c7bcf361c58d3206fa7b0d555426b1f4d9a3eaf"}, {file = "PyYAML-5.3.1-cp38-cp38-win32.whl", hash = "sha256:06a0d7ba600ce0b2d2fe2e78453a470b5a6e000a985dd4a4e54e436cc36b0e97"}, {file = "PyYAML-5.3.1-cp38-cp38-win_amd64.whl", hash = "sha256:95f71d2af0ff4227885f7a6605c37fd53d3a106fcab511b8860ecca9fcf400ee"}, + {file = "PyYAML-5.3.1-cp39-cp39-win32.whl", hash = "sha256:ad9c67312c84def58f3c04504727ca879cb0013b2517c85a9a253f0cb6380c0a"}, + {file = "PyYAML-5.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:6034f55dab5fea9e53f436aa68fa3ace2634918e8b5994d82f3621c04ff5ed2e"}, {file = "PyYAML-5.3.1.tar.gz", hash = "sha256:b8eac752c5e14d3eca0e6dd9199cd627518cb5ec06add0de9d32baeee6fe645d"}, ] pyzmq = [ @@ -2967,6 +2997,8 @@ tables = [ {file = "tables-3.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:eed1e030bb077476d585697e37f2b8e37db4157ff93b485b43f374254cff8698"}, {file = "tables-3.6.1-cp38-cp38-manylinux1_i686.whl", hash = "sha256:7acbf0e2fb7132a40f441ebb53b53c97cee05fb88ce743afdd97c681d1d377d7"}, {file = "tables-3.6.1-cp38-cp38-manylinux1_x86_64.whl", hash = "sha256:94d7ccac04277089e3bb466bf5c8f7038dd53bb8f19ea9679b7fea62c5c3ae8f"}, + {file = "tables-3.6.1-cp39-cp39-manylinux2010_i686.whl", hash = "sha256:da9e1ee83c01ed4d1382c7b186d77b4c0ef80b340a48d11a66346e30342c5929"}, + {file = "tables-3.6.1-cp39-cp39-manylinux2010_x86_64.whl", hash = "sha256:dedb959c00ac9e84562a69e80fa858d7aa06d91f96c6cb8cccbbbaf7a879436b"}, {file = "tables-3.6.1.tar.gz", hash = "sha256:49a972b8a7c27a8a173aeb05f67acb45fe608b64cd8e9fa667c0962a60b71b49"}, ] tabulate = [ diff --git a/pyproject.toml b/pyproject.toml index 1bd78b39..31cd4380 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ pyyaml = "^5.3.1" pytest-cov = "^2.10.0" dm-haiku = "^0.0.2" optax = "^0.0.1" +networkx = "^2.5" [tool.poetry.dev-dependencies] pytest = "^5.2" From 0ffc36f4ac5e9c374226efbc7c0c3cfc5a0f346d Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 25 Nov 2020 10:53:29 +0100 Subject: [PATCH 03/16] sliced module is now retrainable --- elegy/module_slicing.py | 63 ++++++++++++++++++++---------------- elegy/module_slicing_test.py | 29 ++++++++++++++++- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index 7a4ed9d4..5d03a621 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -42,8 +42,7 @@ def slice_module_from_to( graph = construct_graph(edges) paths = [find_path(graph, start_id, end_id) for end_id in end_ids] tree = combine_paths(paths) - submodule_call = construct_call(tree) - submodule = elegy.to_module(submodule_call)() + submodule = SlicedModule(tree) return submodule @@ -143,10 +142,38 @@ def combine_paths(paths: tp.List[nx.DiGraph]) -> nx.DiGraph: return nx.algorithms.compose_all(paths) -def construct_call(tree: nx.DiGraph) -> tp.Callable: - """Returns a new function that represents the __call__ of the new sliced submodule""" +class SlicedModule(elegy.Module): + def __init__(self, tree: nx.DiGraph): + super().__init__() + #adding the all modules as attributes so that they get recognized by .get_parameters() + for edge in tree.edges.values(): + attrname = edge['modulename'][1:].replace('/', '_') + setattr(self, attrname, edge['module']) + + assert not hasattr(self, '_tree'), 'Modules with the name "_tree" are prohibited' + self._tree = tree - def visit_edge(edge, x, next_node): + def call(self, x): + input_nodes = [ + nodes[0] + for nodes, edge in self._tree.edges.items() + if edge.get("is_input", False) + ] + #should not happen + assert len(set(input_nodes))>0, "could not find any input nodes" + assert len(set(input_nodes))<2, "multi-inputs not yet supported" + start_node = input_nodes[0] + + x = [ + self.visit_edge(next_edge, x, next_node) + for next_node, next_edge in self._tree[start_node].items() + ] + x = tuple(x) + if len(x) == 1: + x = x[0] + return x + + def visit_edge(self, edge, x, next_node): assert edge["inkey"] == 0, "inputs other than 0 not yet implemented" x = edge["module"](x) @@ -158,10 +185,10 @@ def visit_edge(edge, x, next_node): if edge.get("is_output", False): outputs.append(x) - if len(tree[next_node]): + if len(self._tree[next_node]): # continue traversing the graph if there are more edges - for next_node, next_edge in tree[next_node].items(): - nextx = visit_edge(next_edge, x, next_node) + for next_node, next_edge in self._tree[next_node].items(): + nextx = self.visit_edge(next_edge, x, next_node) if not isinstance(nextx, tp.Tuple): nextx = (nextx,) outputs.extend(nextx) @@ -171,23 +198,3 @@ def visit_edge(edge, x, next_node): if len(outputs) == 1: outputs = outputs[0] return outputs - - def call(x, *args, **kwargs): - input_nodes = [ - nodes[0] - for nodes, edge in tree.edges.items() - if edge.get("is_input", False) - ] - assert len(set(input_nodes)), "multi-inputs not yet supported" - start_node = input_nodes[0] - - x = [ - visit_edge(next_edge, x, next_node) - for next_node, next_edge in tree[start_node].items() - ] - x = tuple(x) - if len(x) == 1: - x = x[0] - return x - - return call diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index ddd33654..72a2c5b0 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -2,6 +2,7 @@ import elegy.module_slicing from unittest import TestCase import jax, jax.numpy as jnp +import optax class ModuleSlicingTest(TestCase): @@ -59,7 +60,33 @@ def test_resnet_multi_out(self): print(jax.tree_map(jnp.shape, resnet.get_parameters())) print(jax.tree_map(jnp.shape, submodel.get_parameters())) - # assert False + + + def test_retrain(self): + x = jnp.ones((32, 100)) + y = jnp.zeros((32, 10)) + + basicmodule = BasicModule0() + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, "linear0", "linear1", x + ) + submodel = elegy.Model(submodule, loss=elegy.losses.MeanAbsoluteError(), optimizer=optax.adamw(1e-3),) + y0 = submodel.predict(x) + y1 = basicmodule.test_call(x) + + submodel.fit(x,y, epochs=3, verbose=2) + + y2 = submodel.predict(x) + y3 = basicmodule.test_call(x) + + assert jnp.all(y2 == y3) + #output after training should be closer to zero because targets are zero + assert jnp.abs(y2.mean()) < jnp.abs(y0.mean()) + + + + + class BasicModule0(elegy.Module): From c4ec83d96662c3e2a440544a108030306d4e604f Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 25 Nov 2020 11:35:52 +0100 Subject: [PATCH 04/16] exception handling --- elegy/module_slicing.py | 36 +++++++++++++++++++++++---------- elegy/module_slicing_test.py | 39 +++++++++++++++++++++++++++--------- 2 files changed, 55 insertions(+), 20 deletions(-) diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index 5d03a621..ff5f60c5 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -129,8 +129,20 @@ def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: def find_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph: """Returns a new graph with only nodes and edges from start_node to end_node""" - # TODO: catch exceptions - pathnodes = nx.shortest_path(graph, start_node, end_node) + + startname = list(graph[start_node].values())[0]["modulename"] + endname = list(graph.reverse()[end_node].values())[0]["modulename"] + + try: + pathnodes = nx.shortest_path(graph, start_node, end_node) + except nx.NetworkXNoPath: + raise RuntimeError( + f"No path from {startname} to {endname}. Make sure all operations inbetween are performed by modules." + ) from None + if len(pathnodes) < 2: + raise RuntimeError( + f"No operations between the input of {startname} and the output of {endname}." + ) from None pathgraph = graph.subgraph(pathnodes).copy() # pathgraph is unordered, need to mark input and output edges pathgraph[pathnodes[0]][pathnodes[1]]["is_input"] = True @@ -145,12 +157,14 @@ def combine_paths(paths: tp.List[nx.DiGraph]) -> nx.DiGraph: class SlicedModule(elegy.Module): def __init__(self, tree: nx.DiGraph): super().__init__() - #adding the all modules as attributes so that they get recognized by .get_parameters() + # adding the all modules as attributes so that they get recognized by .get_parameters() for edge in tree.edges.values(): - attrname = edge['modulename'][1:].replace('/', '_') - setattr(self, attrname, edge['module']) - - assert not hasattr(self, '_tree'), 'Modules with the name "_tree" are prohibited' + attrname = edge["modulename"][1:].replace("/", "_") + setattr(self, attrname, edge["module"]) + + assert not hasattr( + self, "_tree" + ), 'Modules with the name "_tree" are prohibited' self._tree = tree def call(self, x): @@ -159,9 +173,9 @@ def call(self, x): for nodes, edge in self._tree.edges.items() if edge.get("is_input", False) ] - #should not happen - assert len(set(input_nodes))>0, "could not find any input nodes" - assert len(set(input_nodes))<2, "multi-inputs not yet supported" + # should not happen + assert len(set(input_nodes)) > 0, "could not find any input nodes" + assert len(set(input_nodes)) < 2, "multi-inputs not yet supported" start_node = input_nodes[0] x = [ @@ -172,7 +186,7 @@ def call(self, x): if len(x) == 1: x = x[0] return x - + def visit_edge(self, edge, x, next_node): assert edge["inkey"] == 0, "inputs other than 0 not yet implemented" x = edge["module"](x) diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index 72a2c5b0..4eea2b24 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -60,7 +60,6 @@ def test_resnet_multi_out(self): print(jax.tree_map(jnp.shape, resnet.get_parameters())) print(jax.tree_map(jnp.shape, submodel.get_parameters())) - def test_retrain(self): x = jnp.ones((32, 100)) @@ -68,25 +67,47 @@ def test_retrain(self): basicmodule = BasicModule0() submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, "linear0", "linear1", x - ) - submodel = elegy.Model(submodule, loss=elegy.losses.MeanAbsoluteError(), optimizer=optax.adamw(1e-3),) + basicmodule, "linear0", "linear0", x + ) + submodel = elegy.Model( + submodule, + loss=elegy.losses.MeanAbsoluteError(), + optimizer=optax.adamw(1e-3), + ) y0 = submodel.predict(x) y1 = basicmodule.test_call(x) - submodel.fit(x,y, epochs=3, verbose=2) + submodel.fit(x, y, epochs=3, verbose=2) y2 = submodel.predict(x) y3 = basicmodule.test_call(x) assert jnp.all(y2 == y3) - #output after training should be closer to zero because targets are zero + # output after training should be closer to zero because targets are zero assert jnp.abs(y2.mean()) < jnp.abs(y0.mean()) + def test_no_path(self): + x = jnp.ones((32, 100)) + basicmodule = BasicModule0() + try: + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, "linear2", "linear0", x + ) + except RuntimeError as e: + assert e.args[0].startswith("No path from /linear2 to /linear0") + else: + assert False, "No error or wrong error raised" - - - + try: + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, "linear1", "linear0", x + ) + except RuntimeError as e: + assert e.args[0].startswith( + "No operations between the input of /linear1 and the output of /linear0" + ) + else: + assert False, "No error or wrong error raised" class BasicModule0(elegy.Module): From c5256146db39790c6d5485ca8cd09cc509fa0a00 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 25 Nov 2020 12:54:15 +0100 Subject: [PATCH 05/16] refactoring --- elegy/module_slicing.py | 45 ++++++++++++++++-------------------- elegy/module_slicing_test.py | 2 +- elegy/nets/resnet.py | 14 +++++------ 3 files changed, 28 insertions(+), 33 deletions(-) diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index ff5f60c5..be17197e 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -164,10 +164,10 @@ def __init__(self, tree: nx.DiGraph): assert not hasattr( self, "_tree" - ), 'Modules with the name "_tree" are prohibited' + ), 'Modules with the name "_tree" are prohibited' # can this happen? self._tree = tree - def call(self, x): + def call(self, x: tp.Any) -> tp.Union[tp.Any, tp.Tuple[tp.Any]]: input_nodes = [ nodes[0] for nodes, edge in self._tree.edges.items() @@ -178,37 +178,32 @@ def call(self, x): assert len(set(input_nodes)) < 2, "multi-inputs not yet supported" start_node = input_nodes[0] - x = [ - self.visit_edge(next_edge, x, next_node) - for next_node, next_edge in self._tree[start_node].items() - ] - x = tuple(x) - if len(x) == 1: - x = x[0] - return x + outputs = self.visit_node(start_node, x) + + outputs = tuple(outputs) + if len(outputs) == 1: + outputs = outputs[0] + return outputs - def visit_edge(self, edge, x, next_node): + def visit_edge(self, edge: tp.Dict, x: tp.Any) -> tp.Any: + """Performs the operation to get from node A to node B which the parameter "edge" connects""" assert edge["inkey"] == 0, "inputs other than 0 not yet implemented" + x = edge["module"](x) if isinstance(x, (tuple, list)): # XXX: what if the whole tuple/list is needed as input later? x = x[edge["outkey"]] + return x + + def visit_node(self, node: int, x: tp.Any) -> tp.List[tp.Any]: + """Recursively visits all nodes starting from the parameter "node" and collects outputs.""" outputs = [] - if edge.get("is_output", False): - outputs.append(x) - - if len(self._tree[next_node]): - # continue traversing the graph if there are more edges - for next_node, next_edge in self._tree[next_node].items(): - nextx = self.visit_edge(next_edge, x, next_node) - if not isinstance(nextx, tp.Tuple): - nextx = (nextx,) - outputs.extend(nextx) - # else: no more edges + for nextnode, edge in self._tree[node].items(): + y = self.visit_edge(edge, x) + if edge.get("is_output", False): + outputs.append(y) + outputs.extend(self.visit_node(nextnode, y)) - outputs = tuple(outputs) - if len(outputs) == 1: - outputs = outputs[0] return outputs diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index 4eea2b24..073ca446 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -67,7 +67,7 @@ def test_retrain(self): basicmodule = BasicModule0() submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, "linear0", "linear0", x + basicmodule, "linear0", "linear1", x ) submodel = elegy.Model( submodule, diff --git a/elegy/nets/resnet.py b/elegy/nets/resnet.py index 7879ef3f..2b274d52 100644 --- a/elegy/nets/resnet.py +++ b/elegy/nets/resnet.py @@ -86,18 +86,18 @@ def call(self, x): x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) x = module.to_module(jax.nn.relu)()(x) - x = module.to_module( - lambda _x: nn.linear.hk.max_pool( - _x, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME" - ) - )()(x) + x = nn.MaxPool(window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME")( + x + ) for i, block_size in enumerate(self.stages): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = self.block_type(64 * 2 ** i, strides=strides, dtype=self.dtype)(x) - x = module.to_module(lambda _x: jnp.mean(_x, axis=(1, 2)))()(x) + GAP = lambda x: jnp.mean(x, axis=(1, 2)) + x = module.to_module(GAP)(name="global_average_pooling")(x) x = nn.Linear(1000, dtype=self.dtype)(x) - x = module.to_module(lambda _x: jnp.asarray(_x, jnp.float32))()(x) + to_float32 = lambda x: jnp.asarray(x, jnp.float32) + x = module.to_module(to_float32)(name="to_float32")(x) return x From 2f18a9a2c4a15cb2e2bcecde9f85335e6122d773 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 2 Dec 2020 17:34:09 +0100 Subject: [PATCH 06/16] resnet50 fix --- elegy/module_slicing.py | 1 - elegy/nets/resnet.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index be17197e..08bad27a 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -21,7 +21,6 @@ def slice_module_from_to( - all operations between start_module and end_module must be performed by modules i.e. jax.nn.relu() or x+1 is not allowed but can be converted by wrapping with elegy.to_module() - all modules between start_module and end_module must have a single input and a single output - - resulting module is currently not trainable """ assert not isinstance( start_module, (tp.Tuple, tp.List) diff --git a/elegy/nets/resnet.py b/elegy/nets/resnet.py index 2b274d52..57401d7e 100644 --- a/elegy/nets/resnet.py +++ b/elegy/nets/resnet.py @@ -42,7 +42,7 @@ def call(self, x): class BottleneckResNetBlock(ResNetBlock): """ResNet Bottleneck block.""" - def call(self, x, n_filters, strides=(1, 1)): + def call(self, x): x0 = x x = nn.Conv2D(self.n_filters, (1, 1), with_bias=False, dtype=self.dtype)(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) From a75c7e34d9e30d1127b0c6870fe9d54bc4643435 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Fri, 4 Dec 2020 10:41:34 +0100 Subject: [PATCH 07/16] Slicing with multi-input modules between start_module and end_module now possible --- elegy/module_slicing.py | 161 +++++++++++++++++++++++++++++------ elegy/module_slicing_test.py | 115 +++++++++++++++++++++---- 2 files changed, 231 insertions(+), 45 deletions(-) diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index 08bad27a..f9eafe9b 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -15,12 +15,12 @@ def slice_module_from_to( end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]], sample_input: np.ndarray, ) -> Module: - """Creates a new submodule starting from the input of 'start_module' to the outputs of 'end_module'. + """Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`. Current limitations: - - only one input module is supported - - all operations between start_module and end_module must be performed by modules - i.e. jax.nn.relu() or x+1 is not allowed but can be converted by wrapping with elegy.to_module() - - all modules between start_module and end_module must have a single input and a single output + - only one `start_module` is supported + - all operations between `start_module` and `end_module` must be performed by modules + i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()` + - all modules between `start_module` and `end_module` must have a single output """ assert not isinstance( start_module, (tp.Tuple, tp.List) @@ -39,8 +39,8 @@ def slice_module_from_to( end_ids = [get_output_id(edges, m) for m in end_module] graph = construct_graph(edges) - paths = [find_path(graph, start_id, end_id) for end_id in end_ids] - tree = combine_paths(paths) + dag_paths = [find_dag_path(graph, start_id, end_id) for end_id in end_ids] + tree = combine_paths(dag_paths) #not really a tree submodule = SlicedModule(tree) return submodule @@ -99,6 +99,17 @@ def merge_args_kwargs(*args, **kwargs) -> tp.List[tp.Tuple[tp.Any, tp.Any]]: e.g. merge_args_kwargs(0, 77, a=-2) returns [(0,0), (1,77), ('a',-2)]""" return list(enumerate(args)) + list(kwargs.items()) +def split_merged_args_kwargs(args_kwargs: tp.List[tp.Tuple[tp.Any, tp.Any]]) -> tp.Tuple[tp.Tuple, tp.Dict]: + '''Reverse operation of merge_args_kwargs(). + e.g. split_merged_args_kwargs([(0,0), (1,77), ('a':-2)]) -> (0,77), {'a':-2}''' + args,kwargs = list(), dict() + for key,value in args_kwargs: + if isinstance(key, int): + args.append(value) + else: + kwargs[key]=value + return tuple(args), kwargs + def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: """Constructs a directed graph with IDs of input/output arrays representing the nodes @@ -126,27 +137,97 @@ def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: return G -def find_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph: - """Returns a new graph with only nodes and edges from start_node to end_node""" +def are_paths_computationally_equivalent(path0: nx.DiGraph, path1: nx.DiGraph) -> bool: + '''Checks two paths for computaional equivalence i.e. whether or not they differ only in depth of modules. + E.g if node B is computed by a module composed of several submodules with subnodes B0 and B1 + then paths A->B->C and A->B0->B1->B->C are computationally equivalent. + On the other hand, this does not apply to branches A->B->C vs A->D->C. + Importantly, the edge["inkey"] attributes must be the same: + A->C != A->B->C if C if computed by a dual-input module (e.g. C = A+B)''' + #traverse both paths and check if nodes path0 are in path1 or vice versa + #get nodes from both paths, make sure they are ordered + #skip the first one assuming both have the same source node + nodes0 = list(nx.dfs_postorder_nodes(path0))[::-1][1:] + nodes1 = list(nx.dfs_postorder_nodes(path1))[::-1][1:] + while len(nodes0) and len(nodes1): + #currently traversed nodes from both paths + n0, n1 = nodes0[0], nodes1[0] + + if n0 in nodes1: + #current node of path0 is in path1, still need to check 'inkey' + inkey0 = path0.get_edge_data(*list(path0.in_edges(n0))[0])['inkey'] + inkey1 = path1.get_edge_data(*list(path1.in_edges(n0))[0])['inkey'] + if inkey0 == inkey1: + #all ok, continue traversing paths + nodes1 = nodes1[nodes1.index(n0)+1:] + nodes0 = nodes0[1:] + continue + else: + #inkey is not the same, must be a multi-input module -> reject + return False + elif n1 in nodes0: + #current node of path1 is in path0, still need to check 'inkey' + inkey0 = path0.get_edge_data(*list(path0.in_edges(n1))[0])['inkey'] + inkey1 = path1.get_edge_data(*list(path1.in_edges(n1))[0])['inkey'] + if inkey0 == inkey1: + #all ok, continue traversing paths + nodes0 = nodes0[nodes0.index(n1)+1:] + nodes1 = nodes1[1:] + continue + else: + #inkey is not the same, must be a multi-input module -> reject + return False + else: + #neither path contains the current node of the other path -> reject + return False + if len(nodes0)>0 or len(nodes1)>0: + #should not happen because our paths have the same first and last nodes + return False + #traversed both paths until the end + return True + + +def filter_computationally_equivalent_paths(paths: tp.List[nx.DiGraph]) -> tp.List[nx.DiGraph]: + '''Removes paths with deep modules if there are paths with equivalent, shallow modules. + E.g: remove A->B0->B1->B->C in favor of A->B->C''' + filtered = set() #contains indices of paths to be removed + for i,j in itertools.combinations(range(len(paths)), 2): + if i in filtered or j in filtered: + continue + if are_paths_computationally_equivalent(paths[i], paths[j]): + #keep the shorter path + if len(paths[i]) > len(paths[j]): + filtered.add(i) + else: + filtered.add(j) + paths = [paths[i] for i in range(len(paths)) if i not in filtered] + return paths + + +def find_dag_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph: + """Returns a new (possibly multi-path) graph with only nodes and edges from start_node to end_node""" startname = list(graph[start_node].values())[0]["modulename"] endname = list(graph.reverse()[end_node].values())[0]["modulename"] try: - pathnodes = nx.shortest_path(graph, start_node, end_node) + edge_paths = list(nx.all_simple_edge_paths(graph, start_node, end_node)) #list of lists of tuples + if len(edge_paths)==0: + raise nx.NetworkXNoPath except nx.NetworkXNoPath: raise RuntimeError( f"No path from {startname} to {endname}. Make sure all operations inbetween are performed by modules." ) from None - if len(pathnodes) < 2: - raise RuntimeError( - f"No operations between the input of {startname} and the output of {endname}." - ) from None - pathgraph = graph.subgraph(pathnodes).copy() - # pathgraph is unordered, need to mark input and output edges - pathgraph[pathnodes[0]][pathnodes[1]]["is_input"] = True - pathgraph[pathnodes[-2]][pathnodes[-1]]["is_output"] = True - return pathgraph + + graph_paths = [nx.edge_subgraph(graph, path) for path in edge_paths] #list of nx.DiGraphs + graph_paths = filter_computationally_equivalent_paths(graph_paths) + dag_graph = nx.algorithms.compose_all(graph_paths) + #dag_graph is unordered, need to mark input and output edges + for _,_, edgedata in dag_graph.out_edges(start_node, data=True): + edgedata['is_input'] = True + for _,_, edgedata in dag_graph.in_edges(end_node, data=True): + edgedata['is_output'] = True + return dag_graph def combine_paths(paths: tp.List[nx.DiGraph]) -> nx.DiGraph: @@ -172,23 +253,39 @@ def call(self, x: tp.Any) -> tp.Union[tp.Any, tp.Tuple[tp.Any]]: for nodes, edge in self._tree.edges.items() if edge.get("is_input", False) ] + # should not happen assert len(set(input_nodes)) > 0, "could not find any input nodes" assert len(set(input_nodes)) < 2, "multi-inputs not yet supported" start_node = input_nodes[0] - outputs = self.visit_node(start_node, x) + outputs = self.visit_node(start_node, x, deferred_call_args=dict()) outputs = tuple(outputs) if len(outputs) == 1: outputs = outputs[0] return outputs - def visit_edge(self, edge: tp.Dict, x: tp.Any) -> tp.Any: + def visit_edge(self, edge: tp.Dict, x: tp.Any, deferred_call_args: tp.Dict) -> tp.Any: """Performs the operation to get from node A to node B which the parameter "edge" connects""" - assert edge["inkey"] == 0, "inputs other than 0 not yet implemented" - - x = edge["module"](x) + n_inputs = len(jax.tree_leaves(edge['input_ids'])) + if n_inputs==1: + #a single-input module, simply call it with the input + x = edge["module"](x) + else: + #multi-input module + #check if all the inputs are ready + call_args = deferred_call_args.get(edge['modulename'], dict()) + call_args[edge['inkey']] = x + if len(call_args) == n_inputs: + #all inputs are ready, call module + args, kwargs = split_merged_args_kwargs(call_args.items()) + x = edge['module'](*args, **kwargs) + del deferred_call_args[edge['modulename']] + else: + #still missing some inputs, continue traversing the graph + deferred_call_args[edge['modulename']] = call_args + return DeferredCall if isinstance(x, (tuple, list)): # XXX: what if the whole tuple/list is needed as input later? @@ -196,13 +293,23 @@ def visit_edge(self, edge: tp.Dict, x: tp.Any) -> tp.Any: return x - def visit_node(self, node: int, x: tp.Any) -> tp.List[tp.Any]: + def visit_node(self, node: int, x: tp.Any, deferred_call_args: tp.Dict) -> tp.List[tp.Any]: """Recursively visits all nodes starting from the parameter "node" and collects outputs.""" outputs = [] for nextnode, edge in self._tree[node].items(): - y = self.visit_edge(edge, x) + y = self.visit_edge(edge, x, deferred_call_args) + if y==DeferredCall: + #visited edge module is missing some inputs, will come back here later + continue if edge.get("is_output", False): outputs.append(y) - outputs.extend(self.visit_node(nextnode, y)) + outputs.extend(self.visit_node(nextnode, y, deferred_call_args)) return outputs + + +class DeferredCall: + '''Dummy class that indicates that a call has to be deferred''' + ... + + diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index 073ca446..dd4bf6d3 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -89,25 +89,88 @@ def test_retrain(self): def test_no_path(self): x = jnp.ones((32, 100)) basicmodule = BasicModule0() - try: - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, "linear2", "linear0", x - ) - except RuntimeError as e: - assert e.args[0].startswith("No path from /linear2 to /linear0") - else: - assert False, "No error or wrong error raised" + for start_module in ['linear2', 'linear1']: + try: + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, start_module, "linear0", x + ) + except RuntimeError as e: + assert e.args[0].startswith(f"No path from /{start_module} to /linear0") + else: + assert False, "No error or wrong error raised" + + def test_multi_input_modules(self): + x = jnp.ones((32, 100)) + + module = ContainsMultiInputModule() + model = elegy.Model(module) + model.summary(x) + + submodule = elegy.module_slicing.slice_module_from_to(module, None, '/multi_input_module', x) + submodel = elegy.Model(submodule) + submodel.summary(x) + print(submodule.get_parameters()) + + y = submodel.predict(x) + print(y.shape) + assert(y.shape==(32,25)) + assert(jnp.allclose(y, module.test_call(x) )) + + + def test_computationally_equivalent_paths(self): + import networkx as nx + G = nx.DiGraph() + G.add_edge(0,1, inkey=0) + G.add_edge(1,2, inkey=0) + G.add_edge(0,2, inkey=0) #0->2 is equivalent to the path 0->1->2 + G.add_edge(2,3, inkey=0) + G.add_edge(3,4, inkey=0) + + g0 = G.edge_subgraph([(0,1), (1,2), (2,3)]).copy() + g1 = G.edge_subgraph([(0,2), (2,3)]).copy() + + apce = elegy.module_slicing.are_paths_computationally_equivalent + fcep = elegy.module_slicing.filter_computationally_equivalent_paths + + assert apce(g0,g1) + assert apce(g1,g0) + filtered_paths = fcep([g0,g1]) + assert len(filtered_paths) == 1 + assert filtered_paths[0] == g1 + + G = nx.DiGraph() + G.add_edge(0,1, inkey=0) + G.add_edge(1,2, inkey=0) + G.add_edge(0,2, inkey=1) #not equivalent, multi-input module + G.add_edge(2,3, inkey=0) + G.add_edge(3,4, inkey=0) + + g0 = G.edge_subgraph([(0,1), (1,2), (2,3)]).copy() + g1 = G.edge_subgraph([(0,2), (2,3)]).copy() + g2 = G.edge_subgraph([(0,2), (2,3), (3,4)]).copy() + + apce = elegy.module_slicing.are_paths_computationally_equivalent + assert not apce(g0,g1) + assert not apce(g1,g0) + assert not apce(g1,g2) + filtered_paths = fcep([g0,g1,g2]) + assert len(filtered_paths) == 3 + assert g0 in filtered_paths and g1 in filtered_paths and g2 in filtered_paths + + + + def test_split_merge_args_kwargs(self): + args_kwargs = elegy.module_slicing.merge_args_kwargs(0,101,-2,a=65,b=77) + assert len(args_kwargs)==5 + for x in [(0,0), (1,101), (2,-2), ('a',65), ('b',77)]: + assert x in args_kwargs + + args,kwargs = elegy.module_slicing.split_merged_args_kwargs(args_kwargs) + assert args==(0,101,-2) + assert len(kwargs)==2 + assert kwargs['a']==65 and kwargs['b']==77 + - try: - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, "linear1", "linear0", x - ) - except RuntimeError as e: - assert e.args[0].startswith( - "No operations between the input of /linear1 and the output of /linear0" - ) - else: - assert False, "No error or wrong error raised" class BasicModule0(elegy.Module): @@ -121,3 +184,19 @@ def test_call(self, x): x = self.linear0(x) x = self.linear1(x) return x + +class MultiInputModule(elegy.Module): + def call(self, x0, x1): + return x0[...,:25]+x1[...,:25] + +class ContainsMultiInputModule(elegy.Module): + def call(self, x): + x0 = elegy.nn.Linear(25, name='linear0')(x) + x = MultiInputModule(name='multi_input_module')(x,x0) + x = elegy.nn.Linear(10)(x) + return x + + def test_call(self, x): + x0 = self.linear0(x) + x = self.multi_input_module(x, x0) + return x \ No newline at end of file From 41fa9169cb816465528cdf74b2ddc5498012aa1d Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Fri, 4 Dec 2020 10:58:40 +0100 Subject: [PATCH 08/16] Docs for add_summary --- elegy/module.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/elegy/module.py b/elegy/module.py index 3228232f..05f40122 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -578,18 +578,19 @@ def states_bytes(self, include_submodules: bool = True): def add_summary( - module_or_name: tp.Union[Module, str], value: np.ndarray, input_values=None + module_or_name: tp.Union[Module, str], value: np.ndarray, input_values:tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]]=None ) -> None: """ A hook that lets you define a summary in the current module. Its primary use is to keep track of certain values as they flow through the network - so [`Model.summary`][elegy.model.model.Model.summary] can show a representation of architecture. + so [`Model.summary`][elegy.model.model.Model.summary] can show a representation of architecture + and to get the graph structure to slice modules. ```python def call(self, x): ... y = jax.nn.relu(x) - elegy.add_summary("relu", y) + elegy.add_summary("relu", y, ((x,), {})) ... ``` @@ -597,6 +598,7 @@ def call(self, x): module_or_name: The name of the summary or alternatively the module that this summary will represent. If a summary with the same name already exists a unique identifier will be generated. value: The value for the summary. + input_values: The input arguments for the module, required for slicing. """ if LOCAL.summaries is None: From 15943e3b0da93b3dd08957121219d9ba815d4e9f Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Fri, 4 Dec 2020 11:01:07 +0100 Subject: [PATCH 09/16] black --- elegy/module.py | 4 +- elegy/module_slicing.py | 149 +++++++++++++++++++---------------- elegy/module_slicing_test.py | 90 ++++++++++----------- 3 files changed, 128 insertions(+), 115 deletions(-) diff --git a/elegy/module.py b/elegy/module.py index 05f40122..384fd4e0 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -578,7 +578,9 @@ def states_bytes(self, include_submodules: bool = True): def add_summary( - module_or_name: tp.Union[Module, str], value: np.ndarray, input_values:tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]]=None + module_or_name: tp.Union[Module, str], + value: np.ndarray, + input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None, ) -> None: """ A hook that lets you define a summary in the current module. Its primary diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index f9eafe9b..a3fd7519 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -40,7 +40,7 @@ def slice_module_from_to( graph = construct_graph(edges) dag_paths = [find_dag_path(graph, start_id, end_id) for end_id in end_ids] - tree = combine_paths(dag_paths) #not really a tree + tree = combine_paths(dag_paths) # not really a tree submodule = SlicedModule(tree) return submodule @@ -99,17 +99,20 @@ def merge_args_kwargs(*args, **kwargs) -> tp.List[tp.Tuple[tp.Any, tp.Any]]: e.g. merge_args_kwargs(0, 77, a=-2) returns [(0,0), (1,77), ('a',-2)]""" return list(enumerate(args)) + list(kwargs.items()) -def split_merged_args_kwargs(args_kwargs: tp.List[tp.Tuple[tp.Any, tp.Any]]) -> tp.Tuple[tp.Tuple, tp.Dict]: - '''Reverse operation of merge_args_kwargs(). - e.g. split_merged_args_kwargs([(0,0), (1,77), ('a':-2)]) -> (0,77), {'a':-2}''' - args,kwargs = list(), dict() - for key,value in args_kwargs: + +def split_merged_args_kwargs( + args_kwargs: tp.List[tp.Tuple[tp.Any, tp.Any]] +) -> tp.Tuple[tp.Tuple, tp.Dict]: + """Reverse operation of merge_args_kwargs(). + e.g. split_merged_args_kwargs([(0,0), (1,77), ('a':-2)]) -> (0,77), {'a':-2}""" + args, kwargs = list(), dict() + for key, value in args_kwargs: if isinstance(key, int): args.append(value) else: - kwargs[key]=value + kwargs[key] = value return tuple(args), kwargs - + def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: """Constructs a directed graph with IDs of input/output arrays representing the nodes @@ -137,66 +140,67 @@ def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: return G - def are_paths_computationally_equivalent(path0: nx.DiGraph, path1: nx.DiGraph) -> bool: - '''Checks two paths for computaional equivalence i.e. whether or not they differ only in depth of modules. - E.g if node B is computed by a module composed of several submodules with subnodes B0 and B1 - then paths A->B->C and A->B0->B1->B->C are computationally equivalent. - On the other hand, this does not apply to branches A->B->C vs A->D->C. - Importantly, the edge["inkey"] attributes must be the same: - A->C != A->B->C if C if computed by a dual-input module (e.g. C = A+B)''' - #traverse both paths and check if nodes path0 are in path1 or vice versa - #get nodes from both paths, make sure they are ordered - #skip the first one assuming both have the same source node + """Checks two paths for computaional equivalence i.e. whether or not they differ only in depth of modules. + E.g if node B is computed by a module composed of several submodules with subnodes B0 and B1 + then paths A->B->C and A->B0->B1->B->C are computationally equivalent. + On the other hand, this does not apply to branches A->B->C vs A->D->C. + Importantly, the edge["inkey"] attributes must be the same: + A->C != A->B->C if C if computed by a dual-input module (e.g. C = A+B)""" + # traverse both paths and check if nodes path0 are in path1 or vice versa + # get nodes from both paths, make sure they are ordered + # skip the first one assuming both have the same source node nodes0 = list(nx.dfs_postorder_nodes(path0))[::-1][1:] nodes1 = list(nx.dfs_postorder_nodes(path1))[::-1][1:] while len(nodes0) and len(nodes1): - #currently traversed nodes from both paths + # currently traversed nodes from both paths n0, n1 = nodes0[0], nodes1[0] if n0 in nodes1: - #current node of path0 is in path1, still need to check 'inkey' - inkey0 = path0.get_edge_data(*list(path0.in_edges(n0))[0])['inkey'] - inkey1 = path1.get_edge_data(*list(path1.in_edges(n0))[0])['inkey'] + # current node of path0 is in path1, still need to check 'inkey' + inkey0 = path0.get_edge_data(*list(path0.in_edges(n0))[0])["inkey"] + inkey1 = path1.get_edge_data(*list(path1.in_edges(n0))[0])["inkey"] if inkey0 == inkey1: - #all ok, continue traversing paths - nodes1 = nodes1[nodes1.index(n0)+1:] + # all ok, continue traversing paths + nodes1 = nodes1[nodes1.index(n0) + 1 :] nodes0 = nodes0[1:] continue else: - #inkey is not the same, must be a multi-input module -> reject + # inkey is not the same, must be a multi-input module -> reject return False elif n1 in nodes0: - #current node of path1 is in path0, still need to check 'inkey' - inkey0 = path0.get_edge_data(*list(path0.in_edges(n1))[0])['inkey'] - inkey1 = path1.get_edge_data(*list(path1.in_edges(n1))[0])['inkey'] + # current node of path1 is in path0, still need to check 'inkey' + inkey0 = path0.get_edge_data(*list(path0.in_edges(n1))[0])["inkey"] + inkey1 = path1.get_edge_data(*list(path1.in_edges(n1))[0])["inkey"] if inkey0 == inkey1: - #all ok, continue traversing paths - nodes0 = nodes0[nodes0.index(n1)+1:] + # all ok, continue traversing paths + nodes0 = nodes0[nodes0.index(n1) + 1 :] nodes1 = nodes1[1:] continue else: - #inkey is not the same, must be a multi-input module -> reject + # inkey is not the same, must be a multi-input module -> reject return False else: - #neither path contains the current node of the other path -> reject + # neither path contains the current node of the other path -> reject return False - if len(nodes0)>0 or len(nodes1)>0: - #should not happen because our paths have the same first and last nodes + if len(nodes0) > 0 or len(nodes1) > 0: + # should not happen because our paths have the same first and last nodes return False - #traversed both paths until the end + # traversed both paths until the end return True -def filter_computationally_equivalent_paths(paths: tp.List[nx.DiGraph]) -> tp.List[nx.DiGraph]: - '''Removes paths with deep modules if there are paths with equivalent, shallow modules. - E.g: remove A->B0->B1->B->C in favor of A->B->C''' - filtered = set() #contains indices of paths to be removed - for i,j in itertools.combinations(range(len(paths)), 2): +def filter_computationally_equivalent_paths( + paths: tp.List[nx.DiGraph], +) -> tp.List[nx.DiGraph]: + """Removes paths with deep modules if there are paths with equivalent, shallow modules. + E.g: remove A->B0->B1->B->C in favor of A->B->C""" + filtered = set() # contains indices of paths to be removed + for i, j in itertools.combinations(range(len(paths)), 2): if i in filtered or j in filtered: continue if are_paths_computationally_equivalent(paths[i], paths[j]): - #keep the shorter path + # keep the shorter path if len(paths[i]) > len(paths[j]): filtered.add(i) else: @@ -211,22 +215,26 @@ def find_dag_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGra endname = list(graph.reverse()[end_node].values())[0]["modulename"] try: - edge_paths = list(nx.all_simple_edge_paths(graph, start_node, end_node)) #list of lists of tuples - if len(edge_paths)==0: + edge_paths = list( + nx.all_simple_edge_paths(graph, start_node, end_node) + ) # list of lists of tuples + if len(edge_paths) == 0: raise nx.NetworkXNoPath except nx.NetworkXNoPath: raise RuntimeError( f"No path from {startname} to {endname}. Make sure all operations inbetween are performed by modules." ) from None - - graph_paths = [nx.edge_subgraph(graph, path) for path in edge_paths] #list of nx.DiGraphs + + graph_paths = [ + nx.edge_subgraph(graph, path) for path in edge_paths + ] # list of nx.DiGraphs graph_paths = filter_computationally_equivalent_paths(graph_paths) dag_graph = nx.algorithms.compose_all(graph_paths) - #dag_graph is unordered, need to mark input and output edges - for _,_, edgedata in dag_graph.out_edges(start_node, data=True): - edgedata['is_input'] = True - for _,_, edgedata in dag_graph.in_edges(end_node, data=True): - edgedata['is_output'] = True + # dag_graph is unordered, need to mark input and output edges + for _, _, edgedata in dag_graph.out_edges(start_node, data=True): + edgedata["is_input"] = True + for _, _, edgedata in dag_graph.in_edges(end_node, data=True): + edgedata["is_output"] = True return dag_graph @@ -266,25 +274,27 @@ def call(self, x: tp.Any) -> tp.Union[tp.Any, tp.Tuple[tp.Any]]: outputs = outputs[0] return outputs - def visit_edge(self, edge: tp.Dict, x: tp.Any, deferred_call_args: tp.Dict) -> tp.Any: + def visit_edge( + self, edge: tp.Dict, x: tp.Any, deferred_call_args: tp.Dict + ) -> tp.Any: """Performs the operation to get from node A to node B which the parameter "edge" connects""" - n_inputs = len(jax.tree_leaves(edge['input_ids'])) - if n_inputs==1: - #a single-input module, simply call it with the input + n_inputs = len(jax.tree_leaves(edge["input_ids"])) + if n_inputs == 1: + # a single-input module, simply call it with the input x = edge["module"](x) else: - #multi-input module - #check if all the inputs are ready - call_args = deferred_call_args.get(edge['modulename'], dict()) - call_args[edge['inkey']] = x + # multi-input module + # check if all the inputs are ready + call_args = deferred_call_args.get(edge["modulename"], dict()) + call_args[edge["inkey"]] = x if len(call_args) == n_inputs: - #all inputs are ready, call module + # all inputs are ready, call module args, kwargs = split_merged_args_kwargs(call_args.items()) - x = edge['module'](*args, **kwargs) - del deferred_call_args[edge['modulename']] + x = edge["module"](*args, **kwargs) + del deferred_call_args[edge["modulename"]] else: - #still missing some inputs, continue traversing the graph - deferred_call_args[edge['modulename']] = call_args + # still missing some inputs, continue traversing the graph + deferred_call_args[edge["modulename"]] = call_args return DeferredCall if isinstance(x, (tuple, list)): @@ -293,13 +303,15 @@ def visit_edge(self, edge: tp.Dict, x: tp.Any, deferred_call_args: tp.Dict) -> t return x - def visit_node(self, node: int, x: tp.Any, deferred_call_args: tp.Dict) -> tp.List[tp.Any]: + def visit_node( + self, node: int, x: tp.Any, deferred_call_args: tp.Dict + ) -> tp.List[tp.Any]: """Recursively visits all nodes starting from the parameter "node" and collects outputs.""" outputs = [] for nextnode, edge in self._tree[node].items(): y = self.visit_edge(edge, x, deferred_call_args) - if y==DeferredCall: - #visited edge module is missing some inputs, will come back here later + if y == DeferredCall: + # visited edge module is missing some inputs, will come back here later continue if edge.get("is_output", False): outputs.append(y) @@ -309,7 +321,6 @@ def visit_node(self, node: int, x: tp.Any, deferred_call_args: tp.Dict) -> tp.Li class DeferredCall: - '''Dummy class that indicates that a call has to be deferred''' - ... - + """Dummy class that indicates that a call has to be deferred""" + ... diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index dd4bf6d3..ba58ee97 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -89,7 +89,7 @@ def test_retrain(self): def test_no_path(self): x = jnp.ones((32, 100)) basicmodule = BasicModule0() - for start_module in ['linear2', 'linear1']: + for start_module in ["linear2", "linear1"]: try: submodule = elegy.module_slicing.slice_module_from_to( basicmodule, start_module, "linear0", x @@ -98,7 +98,7 @@ def test_no_path(self): assert e.args[0].startswith(f"No path from /{start_module} to /linear0") else: assert False, "No error or wrong error raised" - + def test_multi_input_modules(self): x = jnp.ones((32, 100)) @@ -106,71 +106,69 @@ def test_multi_input_modules(self): model = elegy.Model(module) model.summary(x) - submodule = elegy.module_slicing.slice_module_from_to(module, None, '/multi_input_module', x) - submodel = elegy.Model(submodule) + submodule = elegy.module_slicing.slice_module_from_to( + module, None, "/multi_input_module", x + ) + submodel = elegy.Model(submodule) submodel.summary(x) print(submodule.get_parameters()) y = submodel.predict(x) print(y.shape) - assert(y.shape==(32,25)) - assert(jnp.allclose(y, module.test_call(x) )) + assert y.shape == (32, 25) + assert jnp.allclose(y, module.test_call(x)) - def test_computationally_equivalent_paths(self): import networkx as nx + G = nx.DiGraph() - G.add_edge(0,1, inkey=0) - G.add_edge(1,2, inkey=0) - G.add_edge(0,2, inkey=0) #0->2 is equivalent to the path 0->1->2 - G.add_edge(2,3, inkey=0) - G.add_edge(3,4, inkey=0) + G.add_edge(0, 1, inkey=0) + G.add_edge(1, 2, inkey=0) + G.add_edge(0, 2, inkey=0) # 0->2 is equivalent to the path 0->1->2 + G.add_edge(2, 3, inkey=0) + G.add_edge(3, 4, inkey=0) - g0 = G.edge_subgraph([(0,1), (1,2), (2,3)]).copy() - g1 = G.edge_subgraph([(0,2), (2,3)]).copy() + g0 = G.edge_subgraph([(0, 1), (1, 2), (2, 3)]).copy() + g1 = G.edge_subgraph([(0, 2), (2, 3)]).copy() apce = elegy.module_slicing.are_paths_computationally_equivalent fcep = elegy.module_slicing.filter_computationally_equivalent_paths - assert apce(g0,g1) - assert apce(g1,g0) - filtered_paths = fcep([g0,g1]) + assert apce(g0, g1) + assert apce(g1, g0) + filtered_paths = fcep([g0, g1]) assert len(filtered_paths) == 1 assert filtered_paths[0] == g1 G = nx.DiGraph() - G.add_edge(0,1, inkey=0) - G.add_edge(1,2, inkey=0) - G.add_edge(0,2, inkey=1) #not equivalent, multi-input module - G.add_edge(2,3, inkey=0) - G.add_edge(3,4, inkey=0) + G.add_edge(0, 1, inkey=0) + G.add_edge(1, 2, inkey=0) + G.add_edge(0, 2, inkey=1) # not equivalent, multi-input module + G.add_edge(2, 3, inkey=0) + G.add_edge(3, 4, inkey=0) - g0 = G.edge_subgraph([(0,1), (1,2), (2,3)]).copy() - g1 = G.edge_subgraph([(0,2), (2,3)]).copy() - g2 = G.edge_subgraph([(0,2), (2,3), (3,4)]).copy() + g0 = G.edge_subgraph([(0, 1), (1, 2), (2, 3)]).copy() + g1 = G.edge_subgraph([(0, 2), (2, 3)]).copy() + g2 = G.edge_subgraph([(0, 2), (2, 3), (3, 4)]).copy() apce = elegy.module_slicing.are_paths_computationally_equivalent - assert not apce(g0,g1) - assert not apce(g1,g0) - assert not apce(g1,g2) - filtered_paths = fcep([g0,g1,g2]) + assert not apce(g0, g1) + assert not apce(g1, g0) + assert not apce(g1, g2) + filtered_paths = fcep([g0, g1, g2]) assert len(filtered_paths) == 3 assert g0 in filtered_paths and g1 in filtered_paths and g2 in filtered_paths - - def test_split_merge_args_kwargs(self): - args_kwargs = elegy.module_slicing.merge_args_kwargs(0,101,-2,a=65,b=77) - assert len(args_kwargs)==5 - for x in [(0,0), (1,101), (2,-2), ('a',65), ('b',77)]: + args_kwargs = elegy.module_slicing.merge_args_kwargs(0, 101, -2, a=65, b=77) + assert len(args_kwargs) == 5 + for x in [(0, 0), (1, 101), (2, -2), ("a", 65), ("b", 77)]: assert x in args_kwargs - - args,kwargs = elegy.module_slicing.split_merged_args_kwargs(args_kwargs) - assert args==(0,101,-2) - assert len(kwargs)==2 - assert kwargs['a']==65 and kwargs['b']==77 - + args, kwargs = elegy.module_slicing.split_merged_args_kwargs(args_kwargs) + assert args == (0, 101, -2) + assert len(kwargs) == 2 + assert kwargs["a"] == 65 and kwargs["b"] == 77 class BasicModule0(elegy.Module): @@ -185,18 +183,20 @@ def test_call(self, x): x = self.linear1(x) return x + class MultiInputModule(elegy.Module): def call(self, x0, x1): - return x0[...,:25]+x1[...,:25] + return x0[..., :25] + x1[..., :25] + class ContainsMultiInputModule(elegy.Module): def call(self, x): - x0 = elegy.nn.Linear(25, name='linear0')(x) - x = MultiInputModule(name='multi_input_module')(x,x0) + x0 = elegy.nn.Linear(25, name="linear0")(x) + x = MultiInputModule(name="multi_input_module")(x, x0) x = elegy.nn.Linear(10)(x) return x - + def test_call(self, x): x0 = self.linear0(x) x = self.multi_input_module(x, x0) - return x \ No newline at end of file + return x From aa020243deb49cae6ba841bd7855d03fb65ed18f Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 23 Dec 2020 16:13:29 +0100 Subject: [PATCH 10/16] fixing poetry --- poetry.lock | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/poetry.lock b/poetry.lock index 487bd039..015c4e57 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1981,7 +1981,7 @@ testing = ["jaraco.itertools", "func-timeout"] [metadata] lock-version = "1.1" python-versions = "^3.6.1" -content-hash = "bbc2c87ac11ebfaf59a88dfd6492af93fb52b567fcd2742e5ed97f8ce7c04f9e" +content-hash = "749fa52b414dfcf88cda6336f12473dad382593ea8c8d0358417e35fcafd1440" [metadata.files] absl-py = [ From 3b3d88afb3446208a45c55a04b6de0fb936a6399 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Wed, 23 Dec 2020 17:30:41 +0100 Subject: [PATCH 11/16] Module.slice() --- docs/api/Module.md | 1 + docs/api/module/Module.md | 1 + docs/api/nn/Sequential.md | 1 + elegy/module.py | 51 ++++++++++++++++++++++++++++++++++++ elegy/module_slicing.py | 11 +------- elegy/module_slicing_test.py | 23 +++++----------- 6 files changed, 61 insertions(+), 27 deletions(-) diff --git a/docs/api/Module.md b/docs/api/Module.md index bb0cdb27..0304f149 100644 --- a/docs/api/Module.md +++ b/docs/api/Module.md @@ -13,4 +13,5 @@ - reset - init - initialized + - slice \ No newline at end of file diff --git a/docs/api/module/Module.md b/docs/api/module/Module.md index 4fce8e01..05526de1 100644 --- a/docs/api/module/Module.md +++ b/docs/api/module/Module.md @@ -13,4 +13,5 @@ - reset - init - initialized + - slice \ No newline at end of file diff --git a/docs/api/nn/Sequential.md b/docs/api/nn/Sequential.md index 660c84cd..c0ae3fe0 100644 --- a/docs/api/nn/Sequential.md +++ b/docs/api/nn/Sequential.md @@ -13,4 +13,5 @@ - reset - init - initialized + - slice \ No newline at end of file diff --git a/elegy/module.py b/elegy/module.py index 84f6edff..652eebb7 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -17,6 +17,9 @@ from elegy.random import RNG from elegy.utils import EMPTY, Empty, Mode, ModuleOrderError +# imported later because of a circular dependency +# from elegy.module_slicing import slice_module_from_to + __all__ = [ "Module", "to_module", @@ -244,6 +247,7 @@ class Module(metaclass=ModuleMeta): "reset", "init", "initialized", + "slice", ] def __init__(self, name: tp.Optional[str] = None, dtype: np.dtype = jnp.float32): @@ -572,6 +576,53 @@ def states_bytes(self, include_submodules: bool = True): ) ) + def slice( + self, + start_module: tp.Union["Module", str, None], + end_module: tp.Union[ + "Module", str, None, tp.List[tp.Union["Module", str, None]] + ], + sample_input: np.ndarray, + ) -> "Module": + """ + Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`. + + Current limitations: + + - all operations between `start_module` and `end_module` must be performed by modules + i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()` + - only one `start_module` is supported + - all modules between `start_module` and `end_module` must have a single output + + + Example usage: + ``` + x = jnp.zeros((2, 224, 224, 3)) + resnet = elegy.nets.resnet.ResNet18() + submodule = resnet.slice( + start_module=None, + end_module=["/res_net_block_1", "/res_net_block_3", "/res_net_block_5", "/res_net_block_7" ], + sample_input=x, + ) + outputs = elegy.Model(submodule).predict(x) + assert outputs[0].shape == (2, 56, 56, 64) + assert outputs[1].shape == (2, 28, 28, 128) + assert outputs[2].shape == (2, 14, 14, 256) + assert outputs[3].shape == (2, 7, 7, 512) + ``` + + Arguments: + start_module: Child module or name of a child module which will be the input module of the resulting module. + If `None`, the first module is used. + end_module: Child module, name of child module, `None` or a list thereof which will be the output module(s) of the resulting module. + If `None`, the last module is used. + sample_input: An array representing a sample input to the parent module. + """ + # importing here because of a circular dependency + from elegy.module_slicing import slice_module_from_to + + return slice_module_from_to(self, start_module, end_module, sample_input) + # ------------------------------------------------------------- # hooks diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index a3fd7519..9c4b8b93 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -1,13 +1,11 @@ import networkx as nx import elegy -from elegy import Module +from elegy.module import Module import jax import itertools import typing as tp import numpy as np -__all__ = ["slice_module_from_to"] - def slice_module_from_to( module: Module, @@ -15,13 +13,6 @@ def slice_module_from_to( end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]], sample_input: np.ndarray, ) -> Module: - """Creates a new submodule starting from the input of `start_module` to the outputs of `end_module`. - Current limitations: - - only one `start_module` is supported - - all operations between `start_module` and `end_module` must be performed by modules - i.e. `jax.nn.relu()` or `x+1` is not allowed but can be converted by wrapping with `elegy.to_module()` - - all modules between `start_module` and `end_module` must have a single output - """ assert not isinstance( start_module, (tp.Tuple, tp.List) ), "Multiple inputs not yet supported" diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index ba58ee97..3eadfd96 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -10,9 +10,7 @@ def test_basic_slice_by_ref(self): x = jnp.zeros((32, 100)) basicmodule = BasicModule0() basicmodule(x) # trigger creation of weights and submodules - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, basicmodule.linear0, basicmodule.linear1, x - ) + submodule = basicmodule.slice(basicmodule.linear0, basicmodule.linear1, x) submodel = elegy.Model(submodule) submodel.summary(x) assert submodel.predict(x).shape == (32, 10) @@ -24,9 +22,7 @@ def test_basic_slice_by_name(self): for start, end in START_END_COMBOS: print(start, end) basicmodule = BasicModule0() - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, start, end, x - ) + submodule = basicmodule.slice(start, end, x) submodel = elegy.Model(submodule) submodel.summary(x) assert submodel.predict(x).shape == (32, 10) @@ -35,8 +31,7 @@ def test_basic_slice_by_name(self): def test_resnet_multi_out(self): x = jnp.zeros((2, 224, 224, 3)) resnet = elegy.nets.resnet.ResNet18() - submodule = elegy.module_slicing.slice_module_from_to( - resnet, + submodule = resnet.slice( start_module=None, end_module=[ "/res_net_block_1", @@ -66,9 +61,7 @@ def test_retrain(self): y = jnp.zeros((32, 10)) basicmodule = BasicModule0() - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, "linear0", "linear1", x - ) + submodule = basicmodule.slice("linear0", "linear1", x) submodel = elegy.Model( submodule, loss=elegy.losses.MeanAbsoluteError(), @@ -91,9 +84,7 @@ def test_no_path(self): basicmodule = BasicModule0() for start_module in ["linear2", "linear1"]: try: - submodule = elegy.module_slicing.slice_module_from_to( - basicmodule, start_module, "linear0", x - ) + submodule = basicmodule.slice(start_module, "linear0", x) except RuntimeError as e: assert e.args[0].startswith(f"No path from /{start_module} to /linear0") else: @@ -106,9 +97,7 @@ def test_multi_input_modules(self): model = elegy.Model(module) model.summary(x) - submodule = elegy.module_slicing.slice_module_from_to( - module, None, "/multi_input_module", x - ) + submodule = module.slice(None, "/multi_input_module", x) submodel = elegy.Model(submodule) submodel.summary(x) print(submodule.get_parameters()) From e77a2074fcaf85c8f3a4f685d379be1cd2e897a4 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Sat, 26 Dec 2020 07:07:46 +0100 Subject: [PATCH 12/16] circular dependency fix --- elegy/module.py | 15 ++++++++------- elegy/module_slicing.py | 5 +++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/elegy/module.py b/elegy/module.py index 57a6667c..527d9f3f 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -17,8 +17,10 @@ from elegy.random import RNG from elegy.utils import EMPTY, Empty, Mode, ModuleOrderError -# imported later because of a circular dependency -# from elegy.module_slicing import slice_module_from_to +# placeholder for module module_slicing.py +# injected from inside the module because of a circular dependency +module_slicing = None + __all__ = [ "Module", @@ -624,10 +626,9 @@ def slice( If `None`, the last module is used. sample_input: An array representing a sample input to the parent module. """ - # importing here because of a circular dependency - from elegy.module_slicing import slice_module_from_to - - return slice_module_from_to(self, start_module, end_module, sample_input) + return module_slicing.slice_module_from_to( + self, start_module, end_module, sample_input + ) # ------------------------------------------------------------- @@ -658,7 +659,7 @@ def call(self, x): module_or_name: The name of the summary or alternatively the module that this summary will represent. If a summary with the same name already exists a unique identifier will be generated. value: The value for the summary. - input_values: The input arguments for the module, required for slicing. + input_values: Input arguments (args, kwargs) as used to call the module (required for slicing). """ if LOCAL.summaries is None: diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index 9c4b8b93..d8a2cb01 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -6,6 +6,11 @@ import typing as tp import numpy as np +import sys +from . import module + +module.module_slicing = sys.modules[__name__] + def slice_module_from_to( module: Module, From 9a59c93c5f6493ef9ddc453d3b1c5ec3b445424d Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Sat, 2 Jan 2021 08:31:22 +0100 Subject: [PATCH 13/16] can now specify inputs as an output target for Module.slice() --- elegy/__init__.py | 1 + elegy/module_slicing.py | 35 ++++++++++++++++++++++++++++++++--- elegy/module_slicing_test.py | 11 +++++++++++ 3 files changed, 44 insertions(+), 3 deletions(-) diff --git a/elegy/__init__.py b/elegy/__init__.py index f0497247..4c0c3929 100644 --- a/elegy/__init__.py +++ b/elegy/__init__.py @@ -41,6 +41,7 @@ training_context, value_and_grad, ) +from . import module_slicing __all__ = [ "Loss", diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index d8a2cb01..63df32b6 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -29,10 +29,17 @@ def slice_module_from_to( summaries = elegy.get_summaries() edges = [Edge(summ) for summ in summaries] + if start_module in ["/input", "input"]: + start_module = None start_id = get_input_id(edges, start_module) if not isinstance(end_module, (tp.Tuple, tp.List)): end_module = [end_module] - end_ids = [get_output_id(edges, m) for m in end_module] + end_ids = [ + get_output_id(edges, m) + if m not in ["/input", "input"] + else get_input_id(edges, None) + for m in end_module + ] graph = construct_graph(edges) dag_paths = [find_dag_path(graph, start_id, end_id) for end_id in end_ids] @@ -133,6 +140,22 @@ def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: depth=depth, **e.__dict__, ) + + # adding dummy edges from inputs to inputs + e = edges[-1] # edge representing the full module + merged_args_kwargs = merge_args_kwargs(*e.input_ids[0], **e.input_ids[1]) + for key, node_id in merged_args_kwargs: + G.add_edge( + node_id, + node_id, + inkey=key, + outkey=key, + depth=0, + module=lambda x: x, + modulename="Inputs", + input_ids=[node_id], + output_ids=[node_id], + ) return G @@ -215,7 +238,11 @@ def find_dag_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGra nx.all_simple_edge_paths(graph, start_node, end_node) ) # list of lists of tuples if len(edge_paths) == 0: - raise nx.NetworkXNoPath + if start_node == end_node and (start_node, end_node) in graph.edges: + # input -> input + edge_paths = [[(start_node, end_node)]] + else: + raise nx.NetworkXNoPath except nx.NetworkXNoPath: raise RuntimeError( f"No path from {startname} to {endname}. Make sure all operations inbetween are performed by modules." @@ -311,7 +338,9 @@ def visit_node( continue if edge.get("is_output", False): outputs.append(y) - outputs.extend(self.visit_node(nextnode, y, deferred_call_args)) + if node != nextnode: + outputs.extend(self.visit_node(nextnode, y, deferred_call_args)) + # else: input -> input return outputs diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index 3eadfd96..19ae98f8 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -28,6 +28,17 @@ def test_basic_slice_by_name(self): assert submodel.predict(x).shape == (32, 10) assert jnp.all(submodel.predict(x) == basicmodule.test_call(x)) + def test_slice_return_input(self): + x = jnp.zeros((32, 100)) + basicmodule = BasicModule0() + submodule = basicmodule.slice("input", ["/linear1", "input"], x) + submodel = elegy.Model(submodule) + submodel.summary(x) + ypred = submodel.predict(x) + assert jnp.all(ypred[1] == x) + assert ypred[0].shape == (32, 10) + assert jnp.all(ypred[0] == basicmodule.test_call(x)) + def test_resnet_multi_out(self): x = jnp.zeros((2, 224, 224, 3)) resnet = elegy.nets.resnet.ResNet18() From c89df43b1b2d152c7fb3e7cc40d72d5119bcd7e0 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Fri, 15 Jan 2021 13:38:00 +0100 Subject: [PATCH 14/16] slicing deferred call bugfix --- elegy/module_slicing.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index 63df32b6..80aca870 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -333,7 +333,7 @@ def visit_node( outputs = [] for nextnode, edge in self._tree[node].items(): y = self.visit_edge(edge, x, deferred_call_args) - if y == DeferredCall: + if y is DeferredCall: # visited edge module is missing some inputs, will come back here later continue if edge.get("is_output", False): From d9083981111977daff39a4122977385a5509d288 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Sat, 20 Feb 2021 16:50:16 +0100 Subject: [PATCH 15/16] update to 0.6.0 --- elegy/hooks.py | 4 ++-- elegy/model/model.py | 4 +++- elegy/model/model_core.py | 6 ++++++ elegy/module.py | 6 +++--- elegy/module_slicing.py | 20 ++++++++------------ elegy/module_slicing_test.py | 24 +++++++++++------------- elegy/types.py | 19 ++++++++++++++----- 7 files changed, 47 insertions(+), 36 deletions(-) diff --git a/elegy/hooks.py b/elegy/hooks.py index 0838abea..ed469403 100644 --- a/elegy/hooks.py +++ b/elegy/hooks.py @@ -105,6 +105,7 @@ def add_summary( path: types.Path, module: tp.Any, value: tp.Any, + input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None, ) -> None: """ A hook that lets you define a summary in the current module. Its primary @@ -127,8 +128,7 @@ def call(self, x): if not summaries_active(): return - - LOCAL.summaries.append(types.Summary(path, module, value)) + LOCAL.summaries.append(types.Summary(path, module, value, input_values)) def get_losses() -> types.Logs: diff --git a/elegy/model/model.py b/elegy/model/model.py index ea20dab1..3b5b76a2 100644 --- a/elegy/model/model.py +++ b/elegy/model/model.py @@ -223,7 +223,7 @@ def summary_step( entries: tp.List[types.SummaryTableEntry] = [] - for path, module, value in summaries: + for path, module, value, input_values in summaries: module_params, module_states = self.api_module.get_summary_params( path=path, @@ -239,7 +239,9 @@ def summary_step( module_type_name=( module.__class__.__name__ if is_generalizable(module) else "" ), + module=module, output_value=value, + input_value=input_values, trainable_params_count=( utils.parameters_count(module_params) if module_params is not None diff --git a/elegy/model/model_core.py b/elegy/model/model_core.py index edecc703..9e9f4fb0 100644 --- a/elegy/model/model_core.py +++ b/elegy/model/model_core.py @@ -431,6 +431,7 @@ def summary( depth: int = 2, tablefmt: str = "fancy_grid", return_repr: bool = False, + return_raw_entries: bool = False, **tablulate_kwargs, ) -> tp.Optional[str]: """ @@ -468,6 +469,9 @@ def summary( total_entry = entries[-1] entries = entries[:-1] + if return_raw_entries: + return entries + depth_groups: tp.Dict[str, tp.List[types.SummaryTableEntry]] = toolz.groupby( lambda entry: "/".join(entry.path.split("/")[:depth]), entries ) @@ -480,7 +484,9 @@ def get_grouped_entry( return types.SummaryTableEntry( path=entry.path, module_type_name=entry.module_type_name, + module=entry.module, output_value=entry.output_value, + input_value=entry.input_value, trainable_params_count=sum( entry_.trainable_params_count for entry_ in group ), diff --git a/elegy/module.py b/elegy/module.py index 7855c2ef..62bd9c3f 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -374,7 +374,7 @@ def __call__(self, *args, **kwargs) -> tp.Any: if hooks.summaries_active(): path = get_module_path(self) assert path is not None - hooks.add_summary(path, self, outputs) + hooks.add_summary(path, self, outputs, (args, kwargs)) return outputs @@ -382,11 +382,11 @@ def __call__(self, *args, **kwargs) -> tp.Any: def call(self, *args, **kwargs): ... - def add_summary(self, name: str, f: tp.Any, value: tp.Any): + def add_summary(self, name: str, f: tp.Any, value: tp.Any, input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None): if hooks.summaries_active(): path = get_module_path(self) + (name,) assert path is not None - hooks.add_summary(path, f, value) + hooks.add_summary(path, f, value, input_values) def init( self, diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index 80aca870..08288520 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -23,10 +23,8 @@ def slice_module_from_to( ), "Multiple inputs not yet supported" # get info about the module structure via summaries - model = elegy.Model(module) - with elegy.hooks_context(summaries=True): - model.predict_fn(sample_input) - summaries = elegy.get_summaries() + model = elegy.Model(module, run_eagerly=True) + summaries = model.summary(sample_input, return_raw_entries=True) edges = [Edge(summ) for summ in summaries] if start_module in ["/input", "input"]: @@ -51,15 +49,13 @@ def slice_module_from_to( class Edge: """A struct to hold edge data""" - def __init__(self, summary: tp.Tuple[Module, str, np.ndarray, tp.Any]): - self.module = summary[0] - # remove the full module name, leave the leading '/' - self.modulename = ( - summary[1][summary[1].find("/") :] if "/" in summary[1] else "/" - ) + def __init__(self, summary: elegy.types.SummaryTableEntry): + self.module = summary.module + # standardize paths with a leading '/' + self.modulename = '/'+summary.path # convert the output and input arrays in the summary to unique IDs as returned by id() - self.output_ids = jax.tree_leaves(jax.tree_map(id, summary[2])) - self.input_ids = jax.tree_map(id, summary[3]) + self.output_ids = jax.tree_leaves(jax.tree_map(id, summary.output_value)) + self.input_ids = jax.tree_map(id, summary.input_value) def search_edges( diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index 19ae98f8..29571731 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -9,7 +9,7 @@ class ModuleSlicingTest(TestCase): def test_basic_slice_by_ref(self): x = jnp.zeros((32, 100)) basicmodule = BasicModule0() - basicmodule(x) # trigger creation of weights and submodules + basicmodule.init(rng=elegy.RNGSeq(0), set_defaults=True)(x) submodule = basicmodule.slice(basicmodule.linear0, basicmodule.linear1, x) submodel = elegy.Model(submodule) submodel.summary(x) @@ -24,13 +24,15 @@ def test_basic_slice_by_name(self): basicmodule = BasicModule0() submodule = basicmodule.slice(start, end, x) submodel = elegy.Model(submodule) - submodel.summary(x) + basicmodule.init(rng=elegy.RNGSeq(0), set_defaults=True)(x) + #submodel.summary(x) assert submodel.predict(x).shape == (32, 10) assert jnp.all(submodel.predict(x) == basicmodule.test_call(x)) def test_slice_return_input(self): x = jnp.zeros((32, 100)) basicmodule = BasicModule0() + basicmodule.init(rng=elegy.RNGSeq(0), set_defaults=True)(x) submodule = basicmodule.slice("input", ["/linear1", "input"], x) submodel = elegy.Model(submodule) submodel.summary(x) @@ -42,6 +44,7 @@ def test_slice_return_input(self): def test_resnet_multi_out(self): x = jnp.zeros((2, 224, 224, 3)) resnet = elegy.nets.resnet.ResNet18() + resnet.init(rng=elegy.RNGSeq(0), set_defaults=True)(x) submodule = resnet.slice( start_module=None, end_module=[ @@ -64,14 +67,12 @@ def test_resnet_multi_out(self): assert outputs[3].shape == (2, 7, 7, 512) assert outputs[4].shape == (2, 7, 7, 512) - print(jax.tree_map(jnp.shape, resnet.get_parameters())) - print(jax.tree_map(jnp.shape, submodel.get_parameters())) - def test_retrain(self): x = jnp.ones((32, 100)) y = jnp.zeros((32, 10)) basicmodule = BasicModule0() + basicmodule.init(rng=elegy.RNGSeq(0), set_defaults=True)(x) submodule = basicmodule.slice("linear0", "linear1", x) submodel = elegy.Model( submodule, @@ -84,9 +85,6 @@ def test_retrain(self): submodel.fit(x, y, epochs=3, verbose=2) y2 = submodel.predict(x) - y3 = basicmodule.test_call(x) - - assert jnp.all(y2 == y3) # output after training should be closer to zero because targets are zero assert jnp.abs(y2.mean()) < jnp.abs(y0.mean()) @@ -105,13 +103,13 @@ def test_multi_input_modules(self): x = jnp.ones((32, 100)) module = ContainsMultiInputModule() + module.init(rng=elegy.RNGSeq(0), set_defaults=True)(x) model = elegy.Model(module) model.summary(x) submodule = module.slice(None, "/multi_input_module", x) submodel = elegy.Model(submodule) submodel.summary(x) - print(submodule.get_parameters()) y = submodel.predict(x) print(y.shape) @@ -179,8 +177,8 @@ def call(self, x): return x def test_call(self, x): - x = self.linear0(x) - x = self.linear1(x) + x = self.linear0.call_with_defaults()(x) + x = self.linear1.call_with_defaults()(x) return x @@ -197,6 +195,6 @@ def call(self, x): return x def test_call(self, x): - x0 = self.linear0(x) - x = self.multi_input_module(x, x0) + x0 = self.linear0.call_with_defaults()(x) + x = self.multi_input_module.call_with_defaults()(x, x0) return x diff --git a/elegy/types.py b/elegy/types.py index db533b8c..05633549 100644 --- a/elegy/types.py +++ b/elegy/types.py @@ -107,16 +107,17 @@ class Summary(tp.NamedTuple): path: Path module: tp.Optional[SummaryModule] value: SummaryValue + input_values: tp.Union[tp.Tuple[tp.Tuple, tp.Dict], None] = None def tree_flatten(self): - return ((self.value,), (self.path, self.module)) + return ((self.value,self.input_values), (self.path, self.module)) @classmethod def tree_unflatten(cls, aux_data, children): - (value,) = children + (value,input_values) = children path, module = aux_data - return cls(path, module, value) + return cls(path, module, value, input_values) Summaries = tp.List[Summary] @@ -126,7 +127,9 @@ def tree_unflatten(cls, aux_data, children): class SummaryTableEntry(tp.NamedTuple): path: str module_type_name: str + module: tp.Any output_value: Pytree + input_value: Pytree trainable_params_count: int trainable_params_size: int non_trainable_params_count: int @@ -143,7 +146,9 @@ def totals_entry( return cls( path="", module_type_name="", + module=None, output_value=None, + input_value=None, trainable_params_count=trainable_params_count, trainable_params_size=trainable_params_size, non_trainable_params_count=non_trainable_params_count, @@ -152,10 +157,11 @@ def totals_entry( def tree_flatten(self): return ( - (self.output_value,), + (self.output_value,self.input_value), ( self.path, self.module_type_name, + self.module, self.trainable_params_count, self.trainable_params_size, self.non_trainable_params_count, @@ -168,17 +174,20 @@ def tree_unflatten(cls, aux_data, children): ( path, module_type_name, + module, trainable_params_count, trainable_params_size, non_trainable_params_count, non_trainable_params_size, ) = aux_data - (output_value,) = children + (output_value,input_value) = children return cls( path=path, module_type_name=module_type_name, + module=module, output_value=output_value, + input_value=input_value, trainable_params_count=trainable_params_count, trainable_params_size=trainable_params_size, non_trainable_params_count=non_trainable_params_count, From 2c9cf1c624a0b5df37181e846178798df4831a93 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Sat, 20 Feb 2021 17:09:27 +0100 Subject: [PATCH 16/16] test fixes and black --- elegy/hooks_test.py | 4 ++-- elegy/module.py | 9 +++++++-- elegy/module_slicing.py | 2 +- elegy/module_slicing_test.py | 2 +- elegy/module_test.py | 4 ++++ elegy/types.py | 8 ++++---- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/elegy/hooks_test.py b/elegy/hooks_test.py index f517617d..54952731 100644 --- a/elegy/hooks_test.py +++ b/elegy/hooks_test.py @@ -32,7 +32,7 @@ def test_summaries(self): elegy.hooks.add_summary(("a", 0, "b"), None, 2.0) summaries = elegy.hooks.get_summaries() - assert summaries[0] == (("a", 0, "b"), None, 2.0) + assert summaries[0] == (("a", 0, "b"), None, 2.0, None) def test_no_summaries(self): assert not elegy.hooks.summaries_active() @@ -65,4 +65,4 @@ def f(x): assert x == 6 assert losses["x_loss"] == 6 assert metrics["x"] == 7 - assert summaries[0] == (("a", 0, "b"), jax.nn.relu, 8) + assert summaries[0] == (("a", 0, "b"), jax.nn.relu, 8, None) diff --git a/elegy/module.py b/elegy/module.py index 4b482f77..92cae6dd 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -382,7 +382,13 @@ def __call__(self, *args, **kwargs) -> tp.Any: def call(self, *args, **kwargs): ... - def add_summary(self, name: str, f: tp.Any, value: tp.Any, input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None): + def add_summary( + self, + name: str, + f: tp.Any, + value: tp.Any, + input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None, + ): if hooks.summaries_active(): path = get_module_path(self) + (name,) assert path is not None @@ -623,7 +629,6 @@ def slice( self, start_module, end_module, sample_input ) - def update_parameter(self, name: str, value: tp.Any) -> None: """ Update a parameter of the current module. diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py index 08288520..8f163b2b 100644 --- a/elegy/module_slicing.py +++ b/elegy/module_slicing.py @@ -52,7 +52,7 @@ class Edge: def __init__(self, summary: elegy.types.SummaryTableEntry): self.module = summary.module # standardize paths with a leading '/' - self.modulename = '/'+summary.path + self.modulename = "/" + summary.path # convert the output and input arrays in the summary to unique IDs as returned by id() self.output_ids = jax.tree_leaves(jax.tree_map(id, summary.output_value)) self.input_ids = jax.tree_map(id, summary.input_value) diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py index 29571731..d6ad9af4 100644 --- a/elegy/module_slicing_test.py +++ b/elegy/module_slicing_test.py @@ -25,7 +25,7 @@ def test_basic_slice_by_name(self): submodule = basicmodule.slice(start, end, x) submodel = elegy.Model(submodule) basicmodule.init(rng=elegy.RNGSeq(0), set_defaults=True)(x) - #submodel.summary(x) + # submodel.summary(x) assert submodel.predict(x).shape == (32, 10) assert jnp.all(submodel.predict(x) == basicmodule.test_call(x)) diff --git a/elegy/module_test.py b/elegy/module_test.py index 16b371cf..59c9d23d 100644 --- a/elegy/module_test.py +++ b/elegy/module_test.py @@ -214,11 +214,13 @@ def call(self, x): ("ais/1",), m.ais[1], 12, + ((2.0,), {}), ), ( (), m, 13, + ((2.0,), {}), ), ] assert parameters == { @@ -256,11 +258,13 @@ def call(self, x): ("a_1",), m.a_1, 12, + ((2.0,), {}), ), ( (), m, 13, + ((2.0,), {}), ), ] assert params == { diff --git a/elegy/types.py b/elegy/types.py index 862ed98a..de2b12db 100644 --- a/elegy/types.py +++ b/elegy/types.py @@ -121,11 +121,11 @@ class Summary(tp.NamedTuple): input_values: tp.Union[tp.Tuple[tp.Tuple, tp.Dict], None] = None def tree_flatten(self): - return ((self.value,self.input_values), (self.path, self.module)) + return ((self.value, self.input_values), (self.path, self.module)) @classmethod def tree_unflatten(cls, aux_data, children): - (value,input_values) = children + (value, input_values) = children path, module = aux_data return cls(path, module, value, input_values) @@ -168,7 +168,7 @@ def totals_entry( def tree_flatten(self): return ( - (self.output_value,self.input_value), + (self.output_value, self.input_value), ( self.path, self.module_type_name, @@ -191,7 +191,7 @@ def tree_unflatten(cls, aux_data, children): non_trainable_params_count, non_trainable_params_size, ) = aux_data - (output_value,input_value) = children + (output_value, input_value) = children return cls( path=path,