From 7509a93084783563938fe294a2f675271210ac94 Mon Sep 17 00:00:00 2001 From: Marcus Chiam Date: Tue, 11 Jul 2023 23:01:45 -0700 Subject: [PATCH] Migrating Flax from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](https://github.com/google/flax/discussions/3191) PiperOrigin-RevId: 547394685 --- docs/guides/flax_basics.ipynb | 14 +++++++------- docs/guides/flax_basics.md | 14 +++++++------- docs/notebooks/linen_intro.ipynb | 20 ++++++++++---------- docs/notebooks/linen_intro.md | 20 ++++++++++---------- docs/notebooks/state_params.ipynb | 5 ++--- docs/notebooks/state_params.md | 5 ++--- flax/configurations.py | 7 ++++--- tests/core/core_lift_test.py | 1 - tests/linen/linen_attention_test.py | 1 - tests/linen/linen_module_test.py | 7 ------- tests/linen/linen_transforms_test.py | 5 ----- tests/linen/summary_test.py | 7 ------- 12 files changed, 42 insertions(+), 64 deletions(-) diff --git a/docs/guides/flax_basics.ipynb b/docs/guides/flax_basics.ipynb index e9b4df0192..ac1cdad7b8 100644 --- a/docs/guides/flax_basics.ipynb +++ b/docs/guides/flax_basics.ipynb @@ -69,7 +69,7 @@ "import jax\n", "from typing import Any, Callable, Sequence\n", "from jax import lax, random, numpy as jnp\n", - "import flax\n", + "from flax.core import freeze, unfreeze\n", "from flax import linen as nn" ] }, @@ -246,7 +246,7 @@ "W = random.normal(k1, (x_dim, y_dim))\n", "b = random.normal(k2, (y_dim,))\n", "# Store the parameters in a FrozenDict pytree.\n", - "true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n", + "true_params = freeze({'params': {'bias': b, 'kernel': W}})\n", "\n", "# Generate samples with additional noise.\n", "key_sample, key_noise = random.split(k1)\n", @@ -604,7 +604,7 @@ "params = model.init(key2, x)\n", "y = model.apply(params, x)\n", "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n", "print('output:\\n', y)" ] }, @@ -706,7 +706,7 @@ "params = model.init(key2, x)\n", "y = model.apply(params, x)\n", "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n", "print('output:\\n', y)" ] }, @@ -929,8 +929,8 @@ "for val in [1.0, 2.0, 3.0]:\n", " x = val * jnp.ones((10,5))\n", " y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n", - " old_state, params = flax.core.pop(variables, 'params')\n", - " variables = flax.core.freeze({'params': params, **updated_state})\n", + " old_state, params = variables.pop('params')\n", + " variables = freeze({'params': params, **updated_state})\n", " print('updated state:\\n', updated_state) # Shows only the mutable part" ] }, @@ -994,7 +994,7 @@ "\n", "x = jnp.ones((10,5))\n", "variables = model.init(random.PRNGKey(0), x)\n", - "state, params = flax.core.pop(variables, 'params')\n", + "state, params = variables.pop('params')\n", "del variables\n", "tx = optax.sgd(learning_rate=0.02)\n", "opt_state = tx.init(params)\n", diff --git a/docs/guides/flax_basics.md b/docs/guides/flax_basics.md index dfaa012052..81682e3140 100644 --- a/docs/guides/flax_basics.md +++ b/docs/guides/flax_basics.md @@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook. 
import jax from typing import Any, Callable, Sequence from jax import lax, random, numpy as jnp -import flax +from flax.core import freeze, unfreeze from flax import linen as nn ``` @@ -132,7 +132,7 @@ k1, k2 = random.split(key) W = random.normal(k1, (x_dim, y_dim)) b = random.normal(k2, (y_dim,)) # Store the parameters in a FrozenDict pytree. -true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}}) +true_params = freeze({'params': {'bias': b, 'kernel': W}}) # Generate samples with additional noise. key_sample, key_noise = random.split(k1) @@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5]) params = model.init(key2, x) y = model.apply(params, x) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params))) print('output:\n', y) ``` @@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5]) params = model.init(key2, x) y = model.apply(params, x) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params))) print('output:\n', y) ``` @@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by for val in [1.0, 2.0, 3.0]: x = val * jnp.ones((10,5)) y, updated_state = model.apply(variables, x, mutable=['batch_stats']) - old_state, params = flax.core.pop(variables, 'params') - variables = flax.core.freeze({'params': params, **updated_state}) + old_state, params = variables.pop('params') + variables = freeze({'params': params, **updated_state}) print('updated state:\n', updated_state) # Shows only the mutable part ``` @@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state): x = jnp.ones((10,5)) variables = model.init(random.PRNGKey(0), x) -state, params = flax.core.pop(variables, 'params') +state, params = variables.pop('params') del variables tx = optax.sgd(learning_rate=0.02) opt_state = tx.init(params) diff --git a/docs/notebooks/linen_intro.ipynb b/docs/notebooks/linen_intro.ipynb index ab272e83cd..ec16074ae1 100644 --- a/docs/notebooks/linen_intro.ipynb +++ b/docs/notebooks/linen_intro.ipynb @@ -75,7 +75,7 @@ "from typing import Any, Callable, Sequence, Optional\n", "import jax\n", "from jax import lax, random, numpy as jnp\n", - "import flax\n", + "from flax.core import freeze, unfreeze\n", "from flax import linen as nn" ] }, @@ -306,7 +306,7 @@ "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, @@ -370,7 +370,7 @@ "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, @@ -743,7 +743,7 @@ "model = Block(features=3, training=True)\n", "\n", "init_variables = model.init({'params': key2, 'dropout': key3}, x)\n", - "_, init_params = flax.core.pop(init_variables, 'params')\n", + "_, init_params = init_variables.pop('params')\n", "\n", "# When calling `apply` with 
mutable kinds, returns a pair of output,\n", "# mutated_variables.\n", @@ -752,8 +752,8 @@ "\n", "# Now we reassemble the full variables from the updates (in a real training\n", "# loop, with the updated params from an optimizer).\n", - "updated_variables = flax.core.freeze(dict(params=init_params,\n", - " **mutated_variables))\n", + "updated_variables = freeze(dict(params=init_params,\n", + " **mutated_variables))\n", "\n", "print('updated variables:\\n', updated_variables)\n", "print('initialized variable shapes:\\n',\n", @@ -842,7 +842,7 @@ "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, @@ -913,7 +913,7 @@ "init_variables = model.init(key2, x)\n", "y = model.apply(init_variables, x)\n", "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "print('output:\\n', y)" ] }, @@ -1069,7 +1069,7 @@ " batch_axes=(0,))\n", "\n", "init_variables = model(train=False).init({'params': key2}, x, x)\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "\n", "y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n", "print('output:\\n', y.shape)" @@ -1149,7 +1149,7 @@ "model = SimpleScan(2)\n", "init_variables = model.init(key2, xs)\n", "\n", - "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n", + "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n", "\n", "y = model.apply(init_variables, xs)\n", "print('output:\\n', y)" diff --git a/docs/notebooks/linen_intro.md b/docs/notebooks/linen_intro.md index b3c94187fa..70dbc8337b 100644 --- a/docs/notebooks/linen_intro.md +++ b/docs/notebooks/linen_intro.md @@ -53,7 +53,7 @@ import functools from typing import Any, Callable, Sequence, Optional import jax from jax import lax, random, numpy as jnp -import flax +from flax.core import freeze, unfreeze from flax import linen as nn ``` @@ -154,7 +154,7 @@ model = ExplicitMLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables))) print('output:\n', y) ``` @@ -189,7 +189,7 @@ model = SimpleMLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables))) print('output:\n', y) ``` @@ -357,7 +357,7 @@ x = random.uniform(key1, (3,4,4)) model = Block(features=3, training=True) init_variables = model.init({'params': key2, 'dropout': key3}, x) -_, init_params = flax.core.pop(init_variables, 'params') +_, init_params = init_variables.pop('params') # 
When calling `apply` with mutable kinds, returns a pair of output, # mutated_variables. @@ -366,8 +366,8 @@ y, mutated_variables = model.apply( # Now we reassemble the full variables from the updates (in a real training # loop, with the updated params from an optimizer). -updated_variables = flax.core.freeze(dict(params=init_params, - **mutated_variables)) +updated_variables = freeze(dict(params=init_params, + **mutated_variables)) print('updated variables:\n', updated_variables) print('initialized variable shapes:\n', @@ -419,7 +419,7 @@ model = MLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables))) print('output:\n', y) ``` @@ -459,7 +459,7 @@ model = RematMLP(features=[3,4,5]) init_variables = model.init(key2, x) y = model.apply(init_variables, x) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables))) print('output:\n', y) ``` @@ -587,7 +587,7 @@ model = functools.partial( batch_axes=(0,)) init_variables = model(train=False).init({'params': key2}, x, x) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables))) y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4}) print('output:\n', y.shape) @@ -635,7 +635,7 @@ xs = random.uniform(key1, (1, 5, 2)) model = SimpleScan(2) init_variables = model.init(key2, xs) -print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables))) +print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables))) y = model.apply(init_variables, xs) print('output:\n', y) diff --git a/docs/notebooks/state_params.ipynb b/docs/notebooks/state_params.ipynb index 580b254fac..ab3f8a9f6a 100644 --- a/docs/notebooks/state_params.ipynb +++ b/docs/notebooks/state_params.ipynb @@ -55,7 +55,6 @@ "from jax import random\n", "import optax\n", "\n", - "import flax\n", "from flax import linen as nn\n", "\n", "\n", @@ -153,7 +152,7 @@ "model = BiasAdderWithRunningMean()\n", "variables = model.init(random.PRNGKey(0), dummy_input)\n", "# Split state and params (which are updated by optimizer).\n", - "state, params = flax.core.pop(variables, 'params')\n", + "state, params = variables.pop('params')\n", "del variables # Delete variables to avoid wasting resources\n", "tx = optax.sgd(learning_rate=0.02)\n", "opt_state = tx.init(params)\n", @@ -272,7 +271,7 @@ "model = MLP(hidden_size=10, out_size=1)\n", "variables = model.init(random.PRNGKey(0), dummy_input)\n", "# Split state and params (which are updated by optimizer).\n", - "state, params = flax.core.pop(variables, 'params')\n", + "state, params = variables.pop('params')\n", "del variables # Delete variables to avoid wasting resources\n", "tx = optax.sgd(learning_rate=0.02)\n", "opt_state = tx.init(params)\n", diff --git a/docs/notebooks/state_params.md b/docs/notebooks/state_params.md index b1fb9dce31..9535da13d2 100644 --- a/docs/notebooks/state_params.md +++ b/docs/notebooks/state_params.md @@ -41,7 +41,6 @@ from jax import numpy as jnp from jax import random import 
optax -import flax from flax import linen as nn @@ -114,7 +113,7 @@ Then we can write the actual training code. model = BiasAdderWithRunningMean() variables = model.init(random.PRNGKey(0), dummy_input) # Split state and params (which are updated by optimizer). -state, params = flax.core.pop(variables, 'params') +state, params = variables.pop('params') del variables # Delete variables to avoid wasting resources tx = optax.sgd(learning_rate=0.02) opt_state = tx.init(params) @@ -203,7 +202,7 @@ Note that we also need to specify that the model state does not have a batch dim model = MLP(hidden_size=10, out_size=1) variables = model.init(random.PRNGKey(0), dummy_input) # Split state and params (which are updated by optimizer). -state, params = flax.core.pop(variables, 'params') +state, params = variables.pop('params') del variables # Delete variables to avoid wasting resources tx = optax.sgd(learning_rate=0.02) opt_state = tx.init(params) diff --git a/flax/configurations.py b/flax/configurations.py index d56e6a8022..a6e0ac2470 100644 --- a/flax/configurations.py +++ b/flax/configurations.py @@ -109,11 +109,12 @@ def temp_flip_flag(var_name: str, var_value: bool): default=False, help=("When adopting outside modules, don't clobber existing names.")) -#TODO(marcuschiam): remove this feature flag once regular dict migration is complete +# TODO(marcuschiam): remove this feature flag once regular dict migration is complete flax_return_frozendict = define_bool_state( name='return_frozendict', - default=True, - help=('Whether to return FrozenDicts when calling init or apply.')) + default=False, + help='Whether to return FrozenDicts when calling init or apply.', +) flax_fix_rng = define_bool_state( name ='fix_rng_separator', diff --git a/tests/core/core_lift_test.py b/tests/core/core_lift_test.py index c07fc84ec6..1794e35478 100644 --- a/tests/core/core_lift_test.py +++ b/tests/core/core_lift_test.py @@ -164,7 +164,6 @@ def false_fn(scope, x): self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1}) np.testing.assert_allclose(y1, -y2) - @temp_flip_flag('return_frozendict', False) def test_switch(self): def f(scope, x, index): scope.variable('state', 'a_count', lambda: 0) diff --git a/tests/linen/linen_attention_test.py b/tests/linen/linen_attention_test.py index 5ed4ede7c7..e16346c93b 100644 --- a/tests/linen/linen_attention_test.py +++ b/tests/linen/linen_attention_test.py @@ -144,7 +144,6 @@ def test_causal_mask_1d(self): np.testing.assert_allclose(mask_1d, mask_1d_simple,) @parameterized.parameters([((5,), (1,)), ((6, 5), (2,))]) - @temp_flip_flag('return_frozendict', False) def test_decoding(self, spatial_shape, attn_dims): bs = 2 num_heads = 3 diff --git a/tests/linen/linen_module_test.py b/tests/linen/linen_module_test.py index e48122c558..dd14f9e216 100644 --- a/tests/linen/linen_module_test.py +++ b/tests/linen/linen_module_test.py @@ -1098,7 +1098,6 @@ def test(self): A().test() self.assertFalse(setup_called) - @temp_flip_flag('return_frozendict', False) def test_module_pass_as_attr(self): class A(nn.Module): @@ -1129,7 +1128,6 @@ def __call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_module_pass_in_closure(self): a = nn.Dense(2) @@ -1154,7 +1152,6 @@ def __call__(self, x): self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) self.assertIsNone(a.name) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption(self): class Encoder(nn.Module): @@ -1210,7 +1207,6 @@ def 
__call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption_pytree(self): class A(nn.Module): @@ -1254,7 +1250,6 @@ def __call__(self, c, x): lambda x, y: np.testing.assert_allclose(x, y, atol=1e-7), counters, ref_counters))) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption_sharing(self): dense = functools.partial(nn.Dense, use_bias=False) @@ -1305,7 +1300,6 @@ def __call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_toplevel_named_submodule_adoption(self): dense = functools.partial(nn.Dense, use_bias=False) @@ -1360,7 +1354,6 @@ def __call__(self, x): } self.assertTrue(tree_equals(var_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_pytree_adoption_sharing(self): class A(nn.Module): diff --git a/tests/linen/linen_transforms_test.py b/tests/linen/linen_transforms_test.py index ee1a04bde5..6fb7c97b3d 100644 --- a/tests/linen/linen_transforms_test.py +++ b/tests/linen/linen_transforms_test.py @@ -807,7 +807,6 @@ def __call__(self, x): y3 = Ctrafo(a2, b).apply(p2, x) np.testing.assert_allclose(y1, y3, atol=1e-7) - @temp_flip_flag('return_frozendict', False) def test_toplevel_submodule_adoption_pytree_transform(self): class A(nn.Module): @nn.compact @@ -852,7 +851,6 @@ def __call__(self, c, x): cntrs, ref_cntrs) )) - @temp_flip_flag('return_frozendict', False) def test_partially_applied_module_constructor_transform(self): k = random.PRNGKey(0) x = jnp.ones((3,4,4)) @@ -870,7 +868,6 @@ def test_partially_applied_module_constructor_transform(self): } self.assertTrue(tree_equals(init_vars_shapes, ref_var_shapes)) - @temp_flip_flag('return_frozendict', False) def test_partial_module_method(self): k = random.PRNGKey(0) x = jnp.ones((3,4,4)) @@ -1505,7 +1502,6 @@ def false_fn(mdl, x): return nn.cond(pred, true_fn, false_fn, self, x) - @temp_flip_flag('return_frozendict', False) def test_switch(self): class Foo(nn.Module): @nn.compact @@ -1540,7 +1536,6 @@ def c_fn(mdl, x): self.assertEqual(vars['state'], {'a_count': 1, 'b_count': 1, 'c_count': 1}) np.testing.assert_allclose(y1, y3) - @temp_flip_flag('return_frozendict', False) def test_switch_multihead(self): class Foo(nn.Module): def setup(self) -> None: diff --git a/tests/linen/summary_test.py b/tests/linen/summary_test.py index 9f7049f3a9..f20924ad22 100644 --- a/tests/linen/summary_test.py +++ b/tests/linen/summary_test.py @@ -181,7 +181,6 @@ def test_module_summary(self): row.counted_variables, ) - @temp_flip_flag('return_frozendict', False) def test_module_summary_with_depth(self): """ This test creates a Table using `module_summary` set the `depth` argument to `1`, @@ -240,7 +239,6 @@ def test_module_summary_with_depth(self): self.assertEqual(table[3].module_variables, table[3].counted_variables) - @temp_flip_flag('return_frozendict', False) def test_tabulate(self): """ This test creates a string representation of a Module using `Module.tabulate` @@ -323,7 +321,6 @@ def test_tabulate_with_method(self): self.assertIn("(block_method)", module_repr) self.assertIn("(cnn_method)", module_repr) - @temp_flip_flag('return_frozendict', False) def test_tabulate_function(self): """ This test creates a string representation of a Module using `Module.tabulate` @@ -370,7 +367,6 @@ def test_tabulate_function(self): self.assertIn("79.4 KB", lines[-3]) - @temp_flip_flag('return_frozendict', False) 
def test_lifted_transform(self): class LSTM(nn.Module): features: int @@ -406,7 +402,6 @@ def __call__(self, x): self.assertIn("ScanLSTM/ii", lines[13]) self.assertIn("Dense", lines[13]) - @temp_flip_flag('return_frozendict', False) def test_lifted_transform_no_rename(self): class LSTM(nn.Module): features: int @@ -442,7 +437,6 @@ def __call__(self, x): self.assertIn("ScanLSTMCell_0/ii", lines[13]) self.assertIn("Dense", lines[13]) - @temp_flip_flag('return_frozendict', False) def test_module_reuse(self): class ConvBlock(nn.Module): @nn.compact @@ -524,7 +518,6 @@ def __call__(self): self.assertIn('x: 3.141592', lines[7]) self.assertIn('4.141592', lines[7]) - @temp_flip_flag('return_frozendict', False) def test_partitioned_params(self): class Classifier(nn.Module):
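
For downstream code that consumed the FrozenDict values returned by `init` and `apply`, the sketch below illustrates the adjustment this migration implies. It is a minimal example, not part of the patch: the `nn.Dense` model, input shape, and variable names are illustrative placeholders, and it leans only on the `flax.core` helpers already exercised in the guides touched above (`pop`, `freeze`, `unfreeze`).

```python
# Sketch: handling variables now that init()/apply() return plain dicts.
# The model and shapes are placeholders; only the flax.core helpers that
# appear in the migrated guides (pop, freeze, unfreeze) are relied upon.
import jax
from jax import random, numpy as jnp

import flax
from flax import linen as nn

model = nn.Dense(features=3)
x = jnp.ones((1, 4))

# With flax_return_frozendict now defaulting to False, this is a plain dict.
variables = model.init(random.PRNGKey(0), x)

# A plain dict's .pop() mutates in place and returns a single value; the
# functional helper keeps the (remaining variables, popped collection)
# tuple pattern used throughout the guides.
state, params = flax.core.pop(variables, 'params')

# Code that wants explicit immutability can still freeze the result.
frozen_variables = flax.core.freeze({'params': params, **state})

# unfreeze() continues to work for inspecting shapes, as in the updated docs.
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(frozen_variables)))
```

Note that the `flax_return_frozendict` flag is kept (its default simply flips to `False`) rather than removed, per the TODO in `flax/configurations.py`, so the previous behavior remains reachable during the transition.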
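
Because the flag still exists, a test that genuinely needs FrozenDict returns can opt back in with the same `temp_flip_flag` helper that the deleted decorators used, just with the opposite value. A minimal sketch, assuming the helper stays importable from `flax.configurations` as defined in this patch; the test class and model are illustrative, not from this change:

```python
# Sketch: opting a single test back into FrozenDict returns after the default
# flipped to False. temp_flip_flag is the same helper the removed decorators
# used; only the flag value differs.
import unittest

from jax import numpy as jnp, random

import flax
from flax import linen as nn
from flax.configurations import temp_flip_flag


class LegacyFrozenDictTest(unittest.TestCase):

  @temp_flip_flag('return_frozendict', True)
  def test_init_returns_frozendict(self):
    variables = nn.Dense(features=3).init(random.PRNGKey(0), jnp.ones((1, 4)))
    # With the flag flipped back on, init should return a FrozenDict.
    self.assertIsInstance(variables, flax.core.FrozenDict)


if __name__ == '__main__':
  unittest.main()
```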