From e622fdeba3d01fb9a5915344139c091adffd56e9 Mon Sep 17 00:00:00 2001 From: alexander-g <3867427+alexander-g@users.noreply.github.com> Date: Mon, 23 Nov 2020 17:20:17 +0100 Subject: [PATCH] experimental module slicing --- elegy/model/model_base.py | 2 +- elegy/module.py | 8 +- elegy/module_slicing.py | 193 +++++++++++++++++++++++++++++++++++ elegy/module_slicing_test.py | 75 ++++++++++++++ elegy/nets/resnet.py | 55 +++++++--- 5 files changed, 313 insertions(+), 20 deletions(-) create mode 100644 elegy/module_slicing.py create mode 100644 elegy/module_slicing_test.py diff --git a/elegy/model/model_base.py b/elegy/model/model_base.py index be1d8f42..8f4ca582 100644 --- a/elegy/model/model_base.py +++ b/elegy/model/model_base.py @@ -436,7 +436,7 @@ def format_size(size): table: tp.List = [["Inputs", format_output(x), "0", "0"]] - for module, base_name, value in summaries: + for module, base_name, value, _ in summaries: base_name_parts = base_name.split("/")[1:] module_depth = len(base_name_parts) diff --git a/elegy/module.py b/elegy/module.py index 8d589a6a..3228232f 100644 --- a/elegy/module.py +++ b/elegy/module.py @@ -331,7 +331,7 @@ def __call__(self, *args, **kwargs) -> tp.Any: else: outputs = self.call(*args, **kwargs) - add_summary(self, outputs) + add_summary(self, outputs, (args, kwargs)) return outputs @@ -577,7 +577,9 @@ def states_bytes(self, include_submodules: bool = True): # ------------------------------------------------------------- -def add_summary(module_or_name: tp.Union[Module, str], value: np.ndarray) -> None: +def add_summary( + module_or_name: tp.Union[Module, str], value: np.ndarray, input_values=None +) -> None: """ A hook that lets you define a summary in the current module. Its primary use is to keep track of certain values as they flow through the network @@ -609,7 +611,7 @@ def call(self, x): else: module = module_or_name - LOCAL.summaries.append((module, name, value)) + LOCAL.summaries.append((module, name, value, input_values)) def add_loss(name: str, value: np.ndarray) -> None: diff --git a/elegy/module_slicing.py b/elegy/module_slicing.py new file mode 100644 index 00000000..7a4ed9d4 --- /dev/null +++ b/elegy/module_slicing.py @@ -0,0 +1,193 @@ +import networkx as nx +import elegy +from elegy import Module +import jax +import itertools +import typing as tp +import numpy as np + +__all__ = ["slice_module_from_to"] + + +def slice_module_from_to( + module: Module, + start_module: tp.Union[Module, str, None], + end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]], + sample_input: np.ndarray, +) -> Module: + """Creates a new submodule starting from the input of 'start_module' to the outputs of 'end_module'. + Current limitations: + - only one input module is supported + - all operations between start_module and end_module must be performed by modules + i.e. jax.nn.relu() or x+1 is not allowed but can be converted by wrapping with elegy.to_module() + - all modules between start_module and end_module must have a single input and a single output + - resulting module is currently not trainable + """ + assert not isinstance( + start_module, (tp.Tuple, tp.List) + ), "Multiple inputs not yet supported" + + # get info about the module structure via summaries + model = elegy.Model(module) + with elegy.hooks_context(summaries=True): + model.predict_fn(sample_input) + summaries = elegy.get_summaries() + + edges = [Edge(summ) for summ in summaries] + start_id = get_input_id(edges, start_module) + if not isinstance(end_module, (tp.Tuple, tp.List)): + end_module = [end_module] + end_ids = [get_output_id(edges, m) for m in end_module] + + graph = construct_graph(edges) + paths = [find_path(graph, start_id, end_id) for end_id in end_ids] + tree = combine_paths(paths) + submodule_call = construct_call(tree) + submodule = elegy.to_module(submodule_call)() + return submodule + + +class Edge: + """A struct to hold edge data""" + + def __init__(self, summary: tp.Tuple[Module, str, np.ndarray, tp.Any]): + self.module = summary[0] + # remove the full module name, leave the leading '/' + self.modulename = ( + summary[1][summary[1].find("/") :] if "/" in summary[1] else "/" + ) + # convert the output and input arrays in the summary to unique IDs as returned by id() + self.output_ids = jax.tree_leaves(jax.tree_map(id, summary[2])) + self.input_ids = jax.tree_map(id, summary[3]) + + +def search_edges( + edges: tp.List[Edge], searchtarget: tp.Union[Module, str, None] +) -> Edge: + """Searches 'edges' for 'searchtarget' which can be a module, name of a module or None""" + if searchtarget is None: + # None means input/output of the full module, which is the last edge + return edges[-1] + elif isinstance(searchtarget, str): + # search by name, with or without leading '/' + if not searchtarget.startswith("/"): + searchtarget = "/" + searchtarget + edges = [e for e in edges if e.modulename == searchtarget] + elif isinstance(searchtarget, Module): + # search by reference + edges = [e for e in edges if e.module == searchtarget] + assert len(edges) > 0, f"Could not find module {searchtarget}" + assert len(edges) < 2, f"Found {len(edges)} modules for {searchtarget}" + return edges[0] + + +def get_input_id(edges: tp.List[Edge], module: tp.Union[Module, str, None]) -> int: + """Searches for module in the list of edges and returns the ID of its input array""" + edge = search_edges(edges, module) + input_ids = jax.tree_leaves(edge.input_ids) + assert len(input_ids) == 1, "Multi-input modules not yet supported" + return input_ids[0] + + +def get_output_id(edges: tp.List[Edge], module: tp.Union[Module, str, None]) -> int: + """Searches for module in the list of edges and returns the ID of its output array""" + edge = search_edges(edges, module) + assert len(edge.output_ids) == 1, "Multi-output modules not yet supported" + return edge.output_ids[0] + + +def merge_args_kwargs(*args, **kwargs) -> tp.List[tp.Tuple[tp.Any, tp.Any]]: + """Merges args and kwargs and their indices to a list of tuples + e.g. merge_args_kwargs(0, 77, a=-2) returns [(0,0), (1,77), ('a',-2)]""" + return list(enumerate(args)) + list(kwargs.items()) + + +def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph: + """Constructs a directed graph with IDs of input/output arrays representing the nodes + and modules (and some more infos) representing the edges""" + G = nx.DiGraph() + for e in edges: + merged_args_kwargs = merge_args_kwargs(*e.input_ids[0], **e.input_ids[1]) + inout_combos = itertools.product(merged_args_kwargs, enumerate(e.output_ids)) + for ((inkey, input_id), (outkey, output_id)) in inout_combos: + depth = e.modulename.count("/") + # it can happen that there are multiple connections between two nodes + # e.g. when a simple parent module has only one child module + # use the one with the lowest depth, i.e. the parent module + if ((input_id, output_id) not in G.edges) or ( + G[input_id][output_id].depth > depth + ): + G.add_edge( + input_id, + output_id, + inkey=inkey, + outkey=outkey, + depth=depth, + **e.__dict__, + ) + return G + + +def find_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph: + """Returns a new graph with only nodes and edges from start_node to end_node""" + # TODO: catch exceptions + pathnodes = nx.shortest_path(graph, start_node, end_node) + pathgraph = graph.subgraph(pathnodes).copy() + # pathgraph is unordered, need to mark input and output edges + pathgraph[pathnodes[0]][pathnodes[1]]["is_input"] = True + pathgraph[pathnodes[-2]][pathnodes[-1]]["is_output"] = True + return pathgraph + + +def combine_paths(paths: tp.List[nx.DiGraph]) -> nx.DiGraph: + return nx.algorithms.compose_all(paths) + + +def construct_call(tree: nx.DiGraph) -> tp.Callable: + """Returns a new function that represents the __call__ of the new sliced submodule""" + + def visit_edge(edge, x, next_node): + assert edge["inkey"] == 0, "inputs other than 0 not yet implemented" + x = edge["module"](x) + + if isinstance(x, (tuple, list)): + # XXX: what if the whole tuple/list is needed as input later? + x = x[edge["outkey"]] + + outputs = [] + if edge.get("is_output", False): + outputs.append(x) + + if len(tree[next_node]): + # continue traversing the graph if there are more edges + for next_node, next_edge in tree[next_node].items(): + nextx = visit_edge(next_edge, x, next_node) + if not isinstance(nextx, tp.Tuple): + nextx = (nextx,) + outputs.extend(nextx) + # else: no more edges + + outputs = tuple(outputs) + if len(outputs) == 1: + outputs = outputs[0] + return outputs + + def call(x, *args, **kwargs): + input_nodes = [ + nodes[0] + for nodes, edge in tree.edges.items() + if edge.get("is_input", False) + ] + assert len(set(input_nodes)), "multi-inputs not yet supported" + start_node = input_nodes[0] + + x = [ + visit_edge(next_edge, x, next_node) + for next_node, next_edge in tree[start_node].items() + ] + x = tuple(x) + if len(x) == 1: + x = x[0] + return x + + return call diff --git a/elegy/module_slicing_test.py b/elegy/module_slicing_test.py new file mode 100644 index 00000000..ddd33654 --- /dev/null +++ b/elegy/module_slicing_test.py @@ -0,0 +1,75 @@ +import elegy +import elegy.module_slicing +from unittest import TestCase +import jax, jax.numpy as jnp + + +class ModuleSlicingTest(TestCase): + def test_basic_slice_by_ref(self): + x = jnp.zeros((32, 100)) + basicmodule = BasicModule0() + basicmodule(x) # trigger creation of weights and submodules + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, basicmodule.linear0, basicmodule.linear1, x + ) + submodel = elegy.Model(submodule) + submodel.summary(x) + assert submodel.predict(x).shape == (32, 10) + assert jnp.all(submodel.predict(x) == basicmodule.test_call(x)) + + def test_basic_slice_by_name(self): + x = jnp.zeros((32, 100)) + START_END_COMBOS = [("linear0", "linear1"), (None, "/linear1")] + for start, end in START_END_COMBOS: + print(start, end) + basicmodule = BasicModule0() + submodule = elegy.module_slicing.slice_module_from_to( + basicmodule, start, end, x + ) + submodel = elegy.Model(submodule) + submodel.summary(x) + assert submodel.predict(x).shape == (32, 10) + assert jnp.all(submodel.predict(x) == basicmodule.test_call(x)) + + def test_resnet_multi_out(self): + x = jnp.zeros((2, 224, 224, 3)) + resnet = elegy.nets.resnet.ResNet18() + submodule = elegy.module_slicing.slice_module_from_to( + resnet, + start_module=None, + end_module=[ + "/res_net_block_1", + "/res_net_block_3", + "/res_net_block_5", + "/res_net_block_6", + "/res_net_block_7", + ], + sample_input=x, + ) + submodel = elegy.Model(submodule) + # submodel.summary(x) + outputs = submodel.predict(x) + print(jax.tree_map(jnp.shape, outputs)) + assert len(outputs) == 5 + assert outputs[0].shape == (2, 56, 56, 64) + assert outputs[1].shape == (2, 28, 28, 128) + assert outputs[2].shape == (2, 14, 14, 256) + assert outputs[3].shape == (2, 7, 7, 512) + assert outputs[4].shape == (2, 7, 7, 512) + + print(jax.tree_map(jnp.shape, resnet.get_parameters())) + print(jax.tree_map(jnp.shape, submodel.get_parameters())) + # assert False + + +class BasicModule0(elegy.Module): + def call(self, x): + x = elegy.nn.Linear(25, name="linear0")(x) + x = elegy.nn.Linear(10, name="linear1")(x) + x = elegy.nn.Linear(5, name="linear2")(x) + return x + + def test_call(self, x): + x = self.linear0(x) + x = self.linear1(x) + return x diff --git a/elegy/nets/resnet.py b/elegy/nets/resnet.py index ce1d6940..7879ef3f 100644 --- a/elegy/nets/resnet.py +++ b/elegy/nets/resnet.py @@ -7,44 +7,65 @@ class ResNetBlock(module.Module): """ResNet (identity) block""" - def call(self, x, n_filters, strides=(1, 1)): + def __init__(self, n_filters, strides=(1, 1), *args, **kwargs): + super().__init__(*args, **kwargs) + self.n_filters = n_filters + self.strides = strides + + def call(self, x): x0 = x x = nn.Conv2D( - n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters, + (3, 3), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) x = jax.nn.relu(x) - x = nn.Conv2D(n_filters, (3, 3), with_bias=False, dtype=self.dtype)(x) + x = nn.Conv2D(self.n_filters, (3, 3), with_bias=False, dtype=self.dtype)(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) if x0.shape != x.shape: x0 = nn.Conv2D( - n_filters, (1, 1), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters, + (1, 1), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x0) x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0) return jax.nn.relu(x0 + x) -class BottleneckResNetBlock(module.Module): +class BottleneckResNetBlock(ResNetBlock): """ResNet Bottleneck block.""" def call(self, x, n_filters, strides=(1, 1)): x0 = x - x = nn.Conv2D(n_filters, (1, 1), with_bias=False, dtype=self.dtype)(x) + x = nn.Conv2D(self.n_filters, (1, 1), with_bias=False, dtype=self.dtype)(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) x = jax.nn.relu(x) x = nn.Conv2D( - n_filters, (3, 3), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters, + (3, 3), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) x = jax.nn.relu(x) - x = nn.Conv2D(n_filters * 4, (1, 1), with_bias=False, dtype=self.dtype)(x) + x = nn.Conv2D(self.n_filters * 4, (1, 1), with_bias=False, dtype=self.dtype)(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5, scale_init=jnp.zeros)(x) if x0.shape != x.shape: x0 = nn.Conv2D( - n_filters * 4, (1, 1), with_bias=False, stride=strides, dtype=self.dtype + self.n_filters * 4, + (1, 1), + with_bias=False, + stride=self.strides, + dtype=self.dtype, )(x0) x0 = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x0) return jax.nn.relu(x0 + x) @@ -63,18 +84,20 @@ def call(self, x): 64, (7, 7), stride=(2, 2), padding="SAME", with_bias=False, dtype=self.dtype )(x) x = nn.BatchNormalization(decay_rate=0.9, eps=1e-5)(x) - x = jax.nn.relu(x) + x = module.to_module(jax.nn.relu)()(x) - x = nn.linear.hk.max_pool( - x, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME" - ) + x = module.to_module( + lambda _x: nn.linear.hk.max_pool( + _x, window_shape=(1, 3, 3, 1), strides=(1, 2, 2, 1), padding="SAME" + ) + )()(x) for i, block_size in enumerate(self.stages): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) - x = self.block_type(dtype=self.dtype)(x, 64 * 2 ** i, strides=strides) - x = jnp.mean(x, axis=(1, 2)) + x = self.block_type(64 * 2 ** i, strides=strides, dtype=self.dtype)(x) + x = module.to_module(lambda _x: jnp.mean(_x, axis=(1, 2)))()(x) x = nn.Linear(1000, dtype=self.dtype)(x) - x = jnp.asarray(x, jnp.float32) + x = module.to_module(lambda _x: jnp.asarray(_x, jnp.float32))()(x) return x