Migrating Flax from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](#3191)

PiperOrigin-RevId: 547394685
chiamp authored and Flax Authors committed Jul 19, 2023
1 parent 290e50f commit c7980b2
Showing 8 changed files with 1,005 additions and 1,024 deletions.
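For context before the per-file diffs, here is a minimal sketch of the user-visible effect of this change, assuming the new `flax_return_frozendict=False` default (the model and shapes are illustrative, not taken from this commit):

```python
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

model = nn.Dense(features=3)  # illustrative model
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))

# Previously, init/apply returned a flax.core.FrozenDict by default;
# with this change they return a regular Python dict.
print(type(variables))  # <class 'dict'> under the new default

# Code that still needs the immutable container can opt back in explicitly.
frozen = flax.core.freeze(variables)
```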
1,987 changes: 994 additions & 993 deletions docs/guides/flax_basics.ipynb


14 changes: 7 additions & 7 deletions docs/guides/flax_basics.md
@@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook.
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
- from flax.core import freeze, unfreeze
+ import flax
from flax import linen as nn
```

@@ -132,7 +132,7 @@ k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
- true_params = freeze({'params': {'bias': b, 'kernel': W}})
+ true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
@@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```
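Since `init` now returns a regular dict by default, the `unfreeze` call in the snippet above is only needed for FrozenDict inputs. A minimal sketch (not part of the guide, continuing from the `params` initialized above) of inspecting shapes directly:

```python
# With a plain dict of variables, shapes can be inspected directly with
# jax.tree_util.tree_map; no FrozenDict conversion step is required.
param_shapes = jax.tree_util.tree_map(jnp.shape, params)
print('initialized parameter shapes:\n', param_shapes)
```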

@@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```

@@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10,5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
- old_state, params = variables.pop('params')
- variables = freeze({'params': params, **updated_state})
+ old_state, params = flax.core.pop(variables, 'params')
+ variables = flax.core.freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # Shows only the mutable part
```
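The pattern above generalizes: `flax.core.pop` splits one collection out of the variables and returns the remainder together with the popped value. A minimal sketch, continuing from the `variables` above and assuming they hold 'params' and 'batch_stats' collections:

```python
# flax.core.pop returns (remaining_variables, popped_collection).
state, params = flax.core.pop(variables, 'params')

# With regular dicts as the new default, the variables can also be
# reassembled with a plain dict merge instead of flax.core.freeze.
variables = {'params': params, **state}
```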

@@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state):
x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
- state, params = variables.pop('params')
+ state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
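For orientation, a rough sketch of how the popped `params` and `state` feed a training step with the same signature as the guide's `update_step` (this is not the guide's exact implementation; the quadratic loss is a placeholder):

```python
import optax

def update_step(tx, apply_fn, x, opt_state, params, state):
  def loss_fn(p):
    # Reassemble the variables dict from params and the remaining state.
    y, updated_state = apply_fn({'params': p, **state}, x,
                                mutable=list(state.keys()))
    return jnp.mean(y ** 2), updated_state  # placeholder loss
  (loss, state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return opt_state, params, state
```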
7 changes: 4 additions & 3 deletions flax/configurations.py
@@ -109,11 +109,12 @@ def temp_flip_flag(var_name: str, var_value: bool):
default=False,
help=("When adopting outside modules, don't clobber existing names."))

- #TODO(marcuschiam): remove this feature flag once regular dict migration is complete
+ # TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = define_bool_state(
name='return_frozendict',
- default=True,
- help=('Whether to return FrozenDicts when calling init or apply.'))
+ default=False,
+ help='Whether to return FrozenDicts when calling init or apply.',
+ )

flax_fix_rng = define_bool_state(
name ='fix_rng_separator',
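With the default flipped to False, tests that previously forced regular-dict behavior via `@temp_flip_flag('return_frozendict', False)` no longer need the override, which is why the decorators are deleted in the test files below. Conversely, a test that wants the legacy output could flip the flag the other way. A hedged sketch, assuming `temp_flip_flag` is importable from `flax.configurations` and that the test model is illustrative:

```python
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn
from absl.testing import absltest
from flax.configurations import temp_flip_flag  # assumed import path

class LegacyOutputTest(absltest.TestCase):

  @temp_flip_flag('return_frozendict', True)
  def test_returns_frozendict_when_flag_enabled(self):
    # Temporarily restore the pre-migration behavior for this test only.
    variables = nn.Dense(2).init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
    self.assertIsInstance(variables, flax.core.FrozenDict)
```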
1 change: 0 additions & 1 deletion tests/core/core_lift_test.py
@@ -164,7 +164,6 @@ def false_fn(scope, x):
self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1})
np.testing.assert_allclose(y1, -y2)

- @temp_flip_flag('return_frozendict', False)
def test_switch(self):
def f(scope, x, index):
scope.variable('state', 'a_count', lambda: 0)
1 change: 0 additions & 1 deletion tests/linen/linen_attention_test.py
@@ -144,7 +144,6 @@ def test_causal_mask_1d(self):
np.testing.assert_allclose(mask_1d, mask_1d_simple,)

@parameterized.parameters([((5,), (1,)), ((6, 5), (2,))])
- @temp_flip_flag('return_frozendict', False)
def test_decoding(self, spatial_shape, attn_dims):
bs = 2
num_heads = 3
7 changes: 0 additions & 7 deletions tests/linen/linen_module_test.py
@@ -1098,7 +1098,6 @@ def test(self):
A().test()
self.assertFalse(setup_called)

- @temp_flip_flag('return_frozendict', False)
def test_module_pass_as_attr(self):

class A(nn.Module):
@@ -1129,7 +1128,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

- @temp_flip_flag('return_frozendict', False)
def test_module_pass_in_closure(self):
a = nn.Dense(2)

@@ -1154,7 +1152,6 @@ def __call__(self, x):
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
self.assertIsNone(a.name)

- @temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption(self):

class Encoder(nn.Module):
@@ -1210,7 +1207,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

- @temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption_pytree(self):

class A(nn.Module):
@@ -1254,7 +1250,6 @@ def __call__(self, c, x):
lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
counters, ref_counters)))

- @temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption_sharing(self):
dense = functools.partial(nn.Dense, use_bias=False)

@@ -1305,7 +1300,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

- @temp_flip_flag('return_frozendict', False)
def test_toplevel_named_submodule_adoption(self):
dense = functools.partial(nn.Dense, use_bias=False)

@@ -1360,7 +1354,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

- @temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_pytree_adoption_sharing(self):

class A(nn.Module):
5 changes: 0 additions & 5 deletions tests/linen/linen_transforms_test.py
@@ -807,7 +807,6 @@ def __call__(self, x):
y3 = Ctrafo(a2, b).apply(p2, x)
np.testing.assert_allclose(y1, y3, atol=1e-7)

- @temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption_pytree_transform(self):
class A(nn.Module):
@nn.compact
@@ -852,7 +851,6 @@ def __call__(self, c, x):
cntrs, ref_cntrs)
))

- @temp_flip_flag('return_frozendict', False)
def test_partially_applied_module_constructor_transform(self):
k = random.PRNGKey(0)
x = jnp.ones((3,4,4))
@@ -870,7 +868,6 @@ def test_partially_applied_module_constructor_transform(self):
}
self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))

- @temp_flip_flag('return_frozendict', False)
def test_partial_module_method(self):
k = random.PRNGKey(0)
x = jnp.ones((3,4,4))
@@ -1505,7 +1502,6 @@ def false_fn(mdl, x):

return nn.cond(pred, true_fn, false_fn, self, x)

- @temp_flip_flag('return_frozendict', False)
def test_switch(self):
class Foo(nn.Module):
@nn.compact
@@ -1540,7 +1536,6 @@ def c_fn(mdl, x):
self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1})
np.testing.assert_allclose(y1, y3)

- @temp_flip_flag('return_frozendict', False)
def test_switch_multihead(self):
class Foo(nn.Module):
def setup(self) -> None:
7 changes: 0 additions & 7 deletions tests/linen/summary_test.py
@@ -181,7 +181,6 @@ def test_module_summary(self):
row.counted_variables,
)

- @temp_flip_flag('return_frozendict', False)
def test_module_summary_with_depth(self):
"""
This test creates a Table using `module_summary` set the `depth` argument to `1`,
@@ -240,7 +239,6 @@ def test_module_summary_with_depth(self):
self.assertEqual(table[3].module_variables, table[3].counted_variables)


- @temp_flip_flag('return_frozendict', False)
def test_tabulate(self):
"""
This test creates a string representation of a Module using `Module.tabulate`
@@ -323,7 +321,6 @@ def test_tabulate_with_method(self):
self.assertIn("(block_method)", module_repr)
self.assertIn("(cnn_method)", module_repr)

- @temp_flip_flag('return_frozendict', False)
def test_tabulate_function(self):
"""
This test creates a string representation of a Module using `Module.tabulate`
@@ -370,7 +367,6 @@ def test_tabulate_function(self):
self.assertIn("79.4 KB", lines[-3])


- @temp_flip_flag('return_frozendict', False)
def test_lifted_transform(self):
class LSTM(nn.Module):
features: int
@@ -406,7 +402,6 @@ def __call__(self, x):
self.assertIn("ScanLSTM/ii", lines[13])
self.assertIn("Dense", lines[13])

- @temp_flip_flag('return_frozendict', False)
def test_lifted_transform_no_rename(self):
class LSTM(nn.Module):
features: int
@@ -442,7 +437,6 @@ def __call__(self, x):
self.assertIn("ScanLSTMCell_0/ii", lines[13])
self.assertIn("Dense", lines[13])

- @temp_flip_flag('return_frozendict', False)
def test_module_reuse(self):
class ConvBlock(nn.Module):
@nn.compact
@@ -524,7 +518,6 @@ def __call__(self):
self.assertIn('x: 3.141592', lines[7])
self.assertIn('4.141592', lines[7])

- @temp_flip_flag('return_frozendict', False)
def test_partitioned_params(self):

class Classifier(nn.Module):
