Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Module Slicing #115

Open
wants to merge 22 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion elegy/model/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,7 @@ def format_size(size):

table: tp.List = [["Inputs", format_output(x), "0", "0"]]

for module, base_name, value in summaries:
for module, base_name, value, _ in summaries:
base_name_parts = base_name.split("/")[1:]
module_depth = len(base_name_parts)

Expand Down
8 changes: 5 additions & 3 deletions elegy/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def __call__(self, *args, **kwargs) -> tp.Any:
else:
outputs = self.call(*args, **kwargs)

add_summary(self, outputs)
add_summary(self, outputs, (args, kwargs))

return outputs

Expand Down Expand Up @@ -577,7 +577,9 @@ def states_bytes(self, include_submodules: bool = True):
# -------------------------------------------------------------


def add_summary(module_or_name: tp.Union[Module, str], value: np.ndarray) -> None:
def add_summary(
module_or_name: tp.Union[Module, str], value: np.ndarray, input_values=None
alexander-g marked this conversation as resolved.
Show resolved Hide resolved
) -> None:
"""
A hook that lets you define a summary in the current module. Its primary
use is to keep track of certain values as they flow through the network
Expand Down Expand Up @@ -609,7 +611,7 @@ def call(self, x):
else:
module = module_or_name

LOCAL.summaries.append((module, name, value))
LOCAL.summaries.append((module, name, value, input_values))


def add_loss(name: str, value: np.ndarray) -> None:
Expand Down
209 changes: 209 additions & 0 deletions elegy/module_slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import networkx as nx
import elegy
from elegy import Module
import jax
import itertools
import typing as tp
import numpy as np

__all__ = ["slice_module_from_to"]


def slice_module_from_to(
module: Module,
start_module: tp.Union[Module, str, None],
end_module: tp.Union[Module, str, None, tp.List[tp.Union[Module, str, None]]],
sample_input: np.ndarray,
) -> Module:
"""Creates a new submodule starting from the input of 'start_module' to the outputs of 'end_module'.
Current limitations:
- only one input module is supported
- all operations between start_module and end_module must be performed by modules
i.e. jax.nn.relu() or x+1 is not allowed but can be converted by wrapping with elegy.to_module()
- all modules between start_module and end_module must have a single input and a single output
- resulting module is currently not trainable
"""
assert not isinstance(
start_module, (tp.Tuple, tp.List)
), "Multiple inputs not yet supported"

# get info about the module structure via summaries
model = elegy.Model(module)
with elegy.hooks_context(summaries=True):
model.predict_fn(sample_input)
summaries = elegy.get_summaries()

edges = [Edge(summ) for summ in summaries]
start_id = get_input_id(edges, start_module)
if not isinstance(end_module, (tp.Tuple, tp.List)):
end_module = [end_module]
end_ids = [get_output_id(edges, m) for m in end_module]

graph = construct_graph(edges)
paths = [find_path(graph, start_id, end_id) for end_id in end_ids]
tree = combine_paths(paths)
submodule = SlicedModule(tree)
return submodule


class Edge:
"""A struct to hold edge data"""

def __init__(self, summary: tp.Tuple[Module, str, np.ndarray, tp.Any]):
self.module = summary[0]
# remove the full module name, leave the leading '/'
self.modulename = (
summary[1][summary[1].find("/") :] if "/" in summary[1] else "/"
)
# convert the output and input arrays in the summary to unique IDs as returned by id()
self.output_ids = jax.tree_leaves(jax.tree_map(id, summary[2]))
self.input_ids = jax.tree_map(id, summary[3])


def search_edges(
edges: tp.List[Edge], searchtarget: tp.Union[Module, str, None]
) -> Edge:
"""Searches 'edges' for 'searchtarget' which can be a module, name of a module or None"""
if searchtarget is None:
# None means input/output of the full module, which is the last edge
return edges[-1]
elif isinstance(searchtarget, str):
# search by name, with or without leading '/'
if not searchtarget.startswith("/"):
searchtarget = "/" + searchtarget
edges = [e for e in edges if e.modulename == searchtarget]
elif isinstance(searchtarget, Module):
# search by reference
edges = [e for e in edges if e.module == searchtarget]
assert len(edges) > 0, f"Could not find module {searchtarget}"
assert len(edges) < 2, f"Found {len(edges)} modules for {searchtarget}"
return edges[0]


def get_input_id(edges: tp.List[Edge], module: tp.Union[Module, str, None]) -> int:
"""Searches for module in the list of edges and returns the ID of its input array"""
edge = search_edges(edges, module)
input_ids = jax.tree_leaves(edge.input_ids)
assert len(input_ids) == 1, "Multi-input modules not yet supported"
return input_ids[0]


def get_output_id(edges: tp.List[Edge], module: tp.Union[Module, str, None]) -> int:
"""Searches for module in the list of edges and returns the ID of its output array"""
edge = search_edges(edges, module)
assert len(edge.output_ids) == 1, "Multi-output modules not yet supported"
return edge.output_ids[0]


def merge_args_kwargs(*args, **kwargs) -> tp.List[tp.Tuple[tp.Any, tp.Any]]:
"""Merges args and kwargs and their indices to a list of tuples
e.g. merge_args_kwargs(0, 77, a=-2) returns [(0,0), (1,77), ('a',-2)]"""
return list(enumerate(args)) + list(kwargs.items())


def construct_graph(edges: tp.List[Edge]) -> nx.DiGraph:
"""Constructs a directed graph with IDs of input/output arrays representing the nodes
and modules (and some more infos) representing the edges"""
G = nx.DiGraph()
for e in edges:
merged_args_kwargs = merge_args_kwargs(*e.input_ids[0], **e.input_ids[1])
inout_combos = itertools.product(merged_args_kwargs, enumerate(e.output_ids))
for ((inkey, input_id), (outkey, output_id)) in inout_combos:
depth = e.modulename.count("/")
# it can happen that there are multiple connections between two nodes
# e.g. when a simple parent module has only one child module
# use the one with the lowest depth, i.e. the parent module
if ((input_id, output_id) not in G.edges) or (
G[input_id][output_id].depth > depth
):
G.add_edge(
input_id,
output_id,
inkey=inkey,
outkey=outkey,
depth=depth,
**e.__dict__,
)
return G


def find_path(graph: nx.DiGraph, start_node: int, end_node: int) -> nx.DiGraph:
"""Returns a new graph with only nodes and edges from start_node to end_node"""

startname = list(graph[start_node].values())[0]["modulename"]
endname = list(graph.reverse()[end_node].values())[0]["modulename"]

try:
pathnodes = nx.shortest_path(graph, start_node, end_node)
except nx.NetworkXNoPath:
raise RuntimeError(
f"No path from {startname} to {endname}. Make sure all operations inbetween are performed by modules."
) from None
if len(pathnodes) < 2:
raise RuntimeError(
f"No operations between the input of {startname} and the output of {endname}."
) from None
pathgraph = graph.subgraph(pathnodes).copy()
# pathgraph is unordered, need to mark input and output edges
pathgraph[pathnodes[0]][pathnodes[1]]["is_input"] = True
pathgraph[pathnodes[-2]][pathnodes[-1]]["is_output"] = True
return pathgraph


def combine_paths(paths: tp.List[nx.DiGraph]) -> nx.DiGraph:
return nx.algorithms.compose_all(paths)


class SlicedModule(elegy.Module):
def __init__(self, tree: nx.DiGraph):
super().__init__()
# adding the all modules as attributes so that they get recognized by .get_parameters()
for edge in tree.edges.values():
attrname = edge["modulename"][1:].replace("/", "_")
setattr(self, attrname, edge["module"])

assert not hasattr(
self, "_tree"
), 'Modules with the name "_tree" are prohibited' # can this happen?
self._tree = tree

def call(self, x: tp.Any) -> tp.Union[tp.Any, tp.Tuple[tp.Any]]:
input_nodes = [
nodes[0]
for nodes, edge in self._tree.edges.items()
if edge.get("is_input", False)
]
# should not happen
assert len(set(input_nodes)) > 0, "could not find any input nodes"
assert len(set(input_nodes)) < 2, "multi-inputs not yet supported"
start_node = input_nodes[0]

outputs = self.visit_node(start_node, x)

outputs = tuple(outputs)
if len(outputs) == 1:
outputs = outputs[0]
return outputs

def visit_edge(self, edge: tp.Dict, x: tp.Any) -> tp.Any:
"""Performs the operation to get from node A to node B which the parameter "edge" connects"""
assert edge["inkey"] == 0, "inputs other than 0 not yet implemented"

x = edge["module"](x)

if isinstance(x, (tuple, list)):
# XXX: what if the whole tuple/list is needed as input later?
x = x[edge["outkey"]]

return x

def visit_node(self, node: int, x: tp.Any) -> tp.List[tp.Any]:
"""Recursively visits all nodes starting from the parameter "node" and collects outputs."""
outputs = []
for nextnode, edge in self._tree[node].items():
y = self.visit_edge(edge, x)
if edge.get("is_output", False):
outputs.append(y)
outputs.extend(self.visit_node(nextnode, y))

return outputs
123 changes: 123 additions & 0 deletions elegy/module_slicing_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import elegy
import elegy.module_slicing
from unittest import TestCase
import jax, jax.numpy as jnp
import optax


class ModuleSlicingTest(TestCase):
def test_basic_slice_by_ref(self):
x = jnp.zeros((32, 100))
basicmodule = BasicModule0()
basicmodule(x) # trigger creation of weights and submodules
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, basicmodule.linear0, basicmodule.linear1, x
)
submodel = elegy.Model(submodule)
submodel.summary(x)
assert submodel.predict(x).shape == (32, 10)
assert jnp.all(submodel.predict(x) == basicmodule.test_call(x))

def test_basic_slice_by_name(self):
x = jnp.zeros((32, 100))
START_END_COMBOS = [("linear0", "linear1"), (None, "/linear1")]
for start, end in START_END_COMBOS:
print(start, end)
basicmodule = BasicModule0()
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, start, end, x
)
submodel = elegy.Model(submodule)
submodel.summary(x)
assert submodel.predict(x).shape == (32, 10)
assert jnp.all(submodel.predict(x) == basicmodule.test_call(x))

def test_resnet_multi_out(self):
x = jnp.zeros((2, 224, 224, 3))
resnet = elegy.nets.resnet.ResNet18()
submodule = elegy.module_slicing.slice_module_from_to(
resnet,
start_module=None,
end_module=[
"/res_net_block_1",
"/res_net_block_3",
"/res_net_block_5",
"/res_net_block_6",
"/res_net_block_7",
],
sample_input=x,
)
submodel = elegy.Model(submodule)
# submodel.summary(x)
outputs = submodel.predict(x)
print(jax.tree_map(jnp.shape, outputs))
assert len(outputs) == 5
assert outputs[0].shape == (2, 56, 56, 64)
assert outputs[1].shape == (2, 28, 28, 128)
assert outputs[2].shape == (2, 14, 14, 256)
assert outputs[3].shape == (2, 7, 7, 512)
assert outputs[4].shape == (2, 7, 7, 512)

print(jax.tree_map(jnp.shape, resnet.get_parameters()))
print(jax.tree_map(jnp.shape, submodel.get_parameters()))

def test_retrain(self):
x = jnp.ones((32, 100))
y = jnp.zeros((32, 10))

basicmodule = BasicModule0()
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, "linear0", "linear1", x
)
submodel = elegy.Model(
submodule,
loss=elegy.losses.MeanAbsoluteError(),
optimizer=optax.adamw(1e-3),
)
y0 = submodel.predict(x)
y1 = basicmodule.test_call(x)

submodel.fit(x, y, epochs=3, verbose=2)

y2 = submodel.predict(x)
y3 = basicmodule.test_call(x)

assert jnp.all(y2 == y3)
# output after training should be closer to zero because targets are zero
assert jnp.abs(y2.mean()) < jnp.abs(y0.mean())

def test_no_path(self):
x = jnp.ones((32, 100))
basicmodule = BasicModule0()
try:
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, "linear2", "linear0", x
)
except RuntimeError as e:
assert e.args[0].startswith("No path from /linear2 to /linear0")
else:
assert False, "No error or wrong error raised"

try:
submodule = elegy.module_slicing.slice_module_from_to(
basicmodule, "linear1", "linear0", x
)
except RuntimeError as e:
assert e.args[0].startswith(
"No operations between the input of /linear1 and the output of /linear0"
)
else:
assert False, "No error or wrong error raised"


class BasicModule0(elegy.Module):
def call(self, x):
x = elegy.nn.Linear(25, name="linear0")(x)
x = elegy.nn.Linear(10, name="linear1")(x)
x = elegy.nn.Linear(5, name="linear2")(x)
return x

def test_call(self, x):
x = self.linear0(x)
x = self.linear1(x)
return x
Loading