Skip to content

Commit

Permalink
test fixes and black
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-g committed Feb 20, 2021
1 parent 330c3e9 commit 14046c2
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 10 deletions.
4 changes: 2 additions & 2 deletions elegy/hooks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_summaries(self):
elegy.hooks.add_summary(("a", 0, "b"), None, 2.0)
summaries = elegy.hooks.get_summaries()

assert summaries[0] == (("a", 0, "b"), None, 2.0)
assert summaries[0] == (("a", 0, "b"), None, 2.0, None)

def test_no_summaries(self):
assert not elegy.hooks.summaries_active()
Expand Down Expand Up @@ -65,4 +65,4 @@ def f(x):
assert x == 6
assert losses["x_loss"] == 6
assert metrics["x"] == 7
assert summaries[0] == (("a", 0, "b"), jax.nn.relu, 8)
assert summaries[0] == (("a", 0, "b"), jax.nn.relu, 8, None)
9 changes: 7 additions & 2 deletions elegy/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,13 @@ def __call__(self, *args, **kwargs) -> tp.Any:
def call(self, *args, **kwargs):
...

def add_summary(self, name: str, f: tp.Any, value: tp.Any, input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None):
def add_summary(
self,
name: str,
f: tp.Any,
value: tp.Any,
input_values: tp.Optional[tp.Tuple[tp.Tuple, tp.Dict]] = None,
):
if hooks.summaries_active():
path = get_module_path(self) + (name,)
assert path is not None
Expand Down Expand Up @@ -623,7 +629,6 @@ def slice(
self, start_module, end_module, sample_input
)


def update_parameter(self, name: str, value: tp.Any) -> None:
"""
Update a parameter of the current module.
Expand Down
2 changes: 1 addition & 1 deletion elegy/module_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ class Edge:
def __init__(self, summary: elegy.types.SummaryTableEntry):
self.module = summary.module
# standardize paths with a leading '/'
self.modulename = '/'+summary.path
self.modulename = "/" + summary.path
# convert the output and input arrays in the summary to unique IDs as returned by id()
self.output_ids = jax.tree_leaves(jax.tree_map(id, summary.output_value))
self.input_ids = jax.tree_map(id, summary.input_value)
Expand Down
2 changes: 1 addition & 1 deletion elegy/module_slicing_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_basic_slice_by_name(self):
submodule = basicmodule.slice(start, end, x)
submodel = elegy.Model(submodule)
basicmodule.init(rng=elegy.RNGSeq(0), set_defaults=True)(x)
#submodel.summary(x)
# submodel.summary(x)
assert submodel.predict(x).shape == (32, 10)
assert jnp.all(submodel.predict(x) == basicmodule.test_call(x))

Expand Down
4 changes: 4 additions & 0 deletions elegy/module_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,11 +214,13 @@ def call(self, x):
("ais/1",),
m.ais[1],
12,
((2.0,), {}),
),
(
(),
m,
13,
((2.0,), {}),
),
]
assert parameters == {
Expand Down Expand Up @@ -256,11 +258,13 @@ def call(self, x):
("a_1",),
m.a_1,
12,
((2.0,), {}),
),
(
(),
m,
13,
((2.0,), {}),
),
]
assert params == {
Expand Down
8 changes: 4 additions & 4 deletions elegy/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,11 +121,11 @@ class Summary(tp.NamedTuple):
input_values: tp.Union[tp.Tuple[tp.Tuple, tp.Dict], None] = None

def tree_flatten(self):
return ((self.value,self.input_values), (self.path, self.module))
return ((self.value, self.input_values), (self.path, self.module))

@classmethod
def tree_unflatten(cls, aux_data, children):
(value,input_values) = children
(value, input_values) = children
path, module = aux_data

return cls(path, module, value, input_values)
Expand Down Expand Up @@ -168,7 +168,7 @@ def totals_entry(

def tree_flatten(self):
return (
(self.output_value,self.input_value),
(self.output_value, self.input_value),
(
self.path,
self.module_type_name,
Expand All @@ -191,7 +191,7 @@ def tree_unflatten(cls, aux_data, children):
non_trainable_params_count,
non_trainable_params_size,
) = aux_data
(output_value,input_value) = children
(output_value, input_value) = children

return cls(
path=path,
Expand Down

0 comments on commit 14046c2

Please sign in to comment.