Commit 36f0331

Migrating Flax from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](#3191).

PiperOrigin-RevId: 547394685
chiamp authored and Flax Authors committed Jul 19, 2023
1 parent 15d6857 commit 36f0331
Showing 8 changed files with 996 additions and 1,062 deletions.
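Before the per-file diffs, here is a minimal sketch of the behaviour change described in the commit message (the module, shapes, and variable names below are illustrative, not taken from the diff): with `flax_return_frozendict` defaulting to `False`, `Module.init` and `Module.apply` return plain Python dicts, and `flax.core.freeze` remains available for callers who still want an immutable tree.

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import freeze

model = nn.Dense(features=3)                      # illustrative module
x = jnp.ones((1, 4))                              # dummy input

variables = model.init(jax.random.PRNGKey(0), x)  # plain dict under the new default
print(type(variables))                            # <class 'dict'>

y = model.apply(variables, x)                     # apply accepts the plain dict as before

frozen = freeze(variables)                        # opt back into an immutable FrozenDict
```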
2,012 changes: 991 additions & 1,021 deletions docs/guides/flax_basics.ipynb

Large diffs are not rendered by default.

18 changes: 1 addition & 17 deletions docs/guides/flax_basics.md
@@ -96,22 +96,6 @@ The result is what we expect: bias and kernel parameters of the correct size. Un
* Initialization functions are called to generate the initial set of parameters that the model will use. Those are functions that take as arguments `(PRNG Key, shape, dtype)` and return an Array of shape `shape`.
* The init function returns the initialized set of parameters (you can also get the output of the forward pass on the dummy input with the same syntax by using the `init_with_output` method instead of `init`).

+++ {"id": "3yL9mKk7naJn"}

The output shows that the parameters are stored in a `FrozenDict` instance, which helps deal with the functional nature of JAX by preventing any mutation of the underlying dict and making the user aware of it. Read more about it in the [`flax.core.frozen_dict.FrozenDict` API docs](https://flax.readthedocs.io/en/latest/api_reference/flax.core.frozen_dict.html#flax.core.frozen_dict.FrozenDict).

As a consequence, the following doesn't work:

```{code-cell}
:id: HtOFWeiynaJo
:outputId: 689b4230-2a3d-4823-d103-2858e6debc4d
try:
params['new_key'] = jnp.ones((2,2))
except ValueError as e:
print("Error: ", e)
```

+++ {"id": "M1qo9M3_naJo"}

To conduct a forward pass with the model with a given set of parameters (which are never stored with the model), we just use the `apply` method by providing it the parameters to use as well as the input:
@@ -546,4 +530,4 @@ Flax provides a handy wrapper - `TrainState` - that simplifies the above code. C

### Exporting to Tensorflow's SavedModel with jax2tf

JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
JAX released an experimental converter called [jax2tf](https://github.com/google/jax/tree/main/jax/experimental/jax2tf), which allows converting trained Flax models into Tensorflow's SavedModel format (so it can be used for [TF Hub](https://www.tensorflow.org/hub), [TF.lite](https://www.tensorflow.org/lite), [TF.js](https://www.tensorflow.org/js), or other downstream applications). The repository contains more documentation and has various examples for Flax.
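The deleted section of `flax_basics.md` above removed a cell that demonstrated `FrozenDict` rejecting item assignment with a `ValueError`. A short sketch of what that looks like after this change (the `params` literal below is a stand-in for the tree returned by `init`): mutating the returned dict now succeeds, and users who want the old guard back can freeze the tree explicitly.

```python
import jax.numpy as jnp
from flax.core import freeze

# Stand-in for the parameter tree that `init` now returns as a plain dict.
params = {'params': {'kernel': jnp.ones((4, 3)), 'bias': jnp.zeros((3,))}}
params['new_key'] = jnp.ones((2, 2))       # mutation now succeeds silently

frozen_params = freeze(params)             # restore immutability explicitly
try:
    frozen_params['new_key'] = jnp.zeros((2, 2))
except ValueError as e:
    print('Error:', e)                     # FrozenDict still raises on item assignment
```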
7 changes: 4 additions & 3 deletions flax/configurations.py
@@ -109,11 +109,12 @@ def temp_flip_flag(var_name: str, var_value: bool):
default=False,
help=("When adopting outside modules, don't clobber existing names."))

#TODO(marcuschiam): remove this feature flag once regular dict migration is complete
# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = define_bool_state(
name='return_frozendict',
default=True,
help=('Whether to return FrozenDicts when calling init or apply.'))
default=False,
help='Whether to return FrozenDicts when calling init or apply.',
)

flax_fix_rng = define_bool_state(
name ='fix_rng_separator',
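The test-only diffs that follow all delete `@temp_flip_flag('return_frozendict', False)` decorators, which are redundant now that the flag defaults to `False`. A hedged sketch of the pattern (the import path, class, and test names are assumptions; the decorator usage mirrors the tests below, and a test that still needs the legacy output could pin the flag to `True` instead):

```python
from absl.testing import absltest
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import FrozenDict
from flax.configurations import temp_flip_flag  # assumed import path


class ReturnTypeTest(absltest.TestCase):

  def test_default_returns_plain_dict(self):
    variables = nn.Dense(2).init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
    self.assertIsInstance(variables, dict)

  @temp_flip_flag('return_frozendict', True)  # pin legacy behaviour for one test
  def test_legacy_returns_frozendict(self):
    variables = nn.Dense(2).init(jax.random.PRNGKey(0), jnp.ones((1, 3)))
    self.assertIsInstance(variables, FrozenDict)


if __name__ == '__main__':
  absltest.main()
```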
1 change: 0 additions & 1 deletion tests/core/core_lift_test.py
@@ -164,7 +164,6 @@ def false_fn(scope, x):
self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1})
np.testing.assert_allclose(y1, -y2)

@temp_flip_flag('return_frozendict', False)
def test_switch(self):
def f(scope, x, index):
scope.variable('state', 'a_count', lambda: 0)
1 change: 0 additions & 1 deletion tests/linen/linen_attention_test.py
@@ -144,7 +144,6 @@ def test_causal_mask_1d(self):
np.testing.assert_allclose(mask_1d, mask_1d_simple,)

@parameterized.parameters([((5,), (1,)), ((6, 5), (2,))])
@temp_flip_flag('return_frozendict', False)
def test_decoding(self, spatial_shape, attn_dims):
bs = 2
num_heads = 3
7 changes: 0 additions & 7 deletions tests/linen/linen_module_test.py
@@ -1098,7 +1098,6 @@ def test(self):
A().test()
self.assertFalse(setup_called)

@temp_flip_flag('return_frozendict', False)
def test_module_pass_as_attr(self):

class A(nn.Module):
@@ -1129,7 +1128,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

@temp_flip_flag('return_frozendict', False)
def test_module_pass_in_closure(self):
a = nn.Dense(2)

@@ -1154,7 +1152,6 @@ def __call__(self, x):
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))
self.assertIsNone(a.name)

@temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption(self):

class Encoder(nn.Module):
@@ -1210,7 +1207,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

@temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption_pytree(self):

class A(nn.Module):
@@ -1254,7 +1250,6 @@ def __call__(self, c, x):
lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7),
counters, ref_counters)))

@temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption_sharing(self):
dense = functools.partial(nn.Dense, use_bias=False)

@@ -1305,7 +1300,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

@temp_flip_flag('return_frozendict', False)
def test_toplevel_named_submodule_adoption(self):
dense = functools.partial(nn.Dense, use_bias=False)

@@ -1360,7 +1354,6 @@ def __call__(self, x):
}
self.assertTrue(tree_equals(var_shapes, ref_var_shapes))

@temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_pytree_adoption_sharing(self):

class A(nn.Module):
5 changes: 0 additions & 5 deletions tests/linen/linen_transforms_test.py
@@ -807,7 +806,6 @@ def __call__(self, x):
y3 = Ctrafo(a2, b).apply(p2, x)
np.testing.assert_allclose(y1, y3, atol=1e-7)

@temp_flip_flag('return_frozendict', False)
def test_toplevel_submodule_adoption_pytree_transform(self):
class A(nn.Module):
@nn.compact
@@ -852,7 +851,6 @@ def __call__(self, c, x):
cntrs, ref_cntrs)
))

@temp_flip_flag('return_frozendict', False)
def test_partially_applied_module_constructor_transform(self):
k = random.PRNGKey(0)
x = jnp.ones((3,4,4))
@@ -870,7 +868,6 @@ def test_partially_applied_module_constructor_transform(self):
}
self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes))

@temp_flip_flag('return_frozendict', False)
def test_partial_module_method(self):
k = random.PRNGKey(0)
x = jnp.ones((3,4,4))
@@ -1505,7 +1502,6 @@ def false_fn(mdl, x):

return nn.cond(pred, true_fn, false_fn, self, x)

@temp_flip_flag('return_frozendict', False)
def test_switch(self):
class Foo(nn.Module):
@nn.compact
@@ -1540,7 +1536,6 @@ def c_fn(mdl, x):
self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1})
np.testing.assert_allclose(y1, y3)

@temp_flip_flag('return_frozendict', False)
def test_switch_multihead(self):
class Foo(nn.Module):
def setup(self) -> None:
7 changes: 0 additions & 7 deletions tests/linen/summary_test.py
@@ -181,7 +181,6 @@ def test_module_summary(self):
row.counted_variables,
)

@temp_flip_flag('return_frozendict', False)
def test_module_summary_with_depth(self):
"""
This test creates a Table using `module_summary` set the `depth` argument to `1`,
@@ -240,7 +239,6 @@ def test_module_summary_with_depth(self):
self.assertEqual(table[3].module_variables, table[3].counted_variables)


@temp_flip_flag('return_frozendict', False)
def test_tabulate(self):
"""
This test creates a string representation of a Module using `Module.tabulate`
@@ -323,7 +321,6 @@ def test_tabulate_with_method(self):
self.assertIn("(block_method)", module_repr)
self.assertIn("(cnn_method)", module_repr)

@temp_flip_flag('return_frozendict', False)
def test_tabulate_function(self):
"""
This test creates a string representation of a Module using `Module.tabulate`
@@ -370,7 +367,6 @@ def test_tabulate_function(self):
self.assertIn("79.4 KB", lines[-3])


@temp_flip_flag('return_frozendict', False)
def test_lifted_transform(self):
class LSTM(nn.Module):
features: int
@@ -406,7 +402,6 @@ def __call__(self, x):
self.assertIn("ScanLSTM/ii", lines[13])
self.assertIn("Dense", lines[13])

@temp_flip_flag('return_frozendict', False)
def test_lifted_transform_no_rename(self):
class LSTM(nn.Module):
features: int
@@ -442,7 +437,6 @@ def __call__(self, x):
self.assertIn("ScanLSTMCell_0/ii", lines[13])
self.assertIn("Dense", lines[13])

@temp_flip_flag('return_frozendict', False)
def test_module_reuse(self):
class ConvBlock(nn.Module):
@nn.compact
@@ -524,7 +518,6 @@ def __call__(self):
self.assertIn('x: 3.141592', lines[7])
self.assertIn('4.141592', lines[7])

@temp_flip_flag('return_frozendict', False)
def test_partitioned_params(self):

class Classifier(nn.Module):
