Migrating Flax from returning FrozenDicts to returning regular dicts. More details can be found in this [announcement](#3191)

PiperOrigin-RevId: 547394685
chiamp authored and Flax Authors committed Jul 19, 2023
1 parent 54a0296 commit 7509a93
Showing 12 changed files with 42 additions and 64 deletions.
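For context on what the flag flip below means in user code: after this change, `Module.init` and `Module.apply` return nested plain dicts rather than `flax.core.FrozenDict`s. The sketch below (a toy `nn.Dense` model, not part of this commit) shows how `flax.core.freeze` and `flax.core.unfreeze` can normalize the result either way, since both helpers accept both container types.

```python
import jax.numpy as jnp
from jax import random
import flax
from flax import linen as nn

# Toy model for illustration only.
model = nn.Dense(features=3)
x = jnp.ones((1, 4))

# With flax_return_frozendict=False (the new default in this commit),
# `variables` is a nested plain dict; previously it was a FrozenDict.
variables = model.init(random.PRNGKey(0), x)

frozen = flax.core.freeze(variables)    # normalize to a FrozenDict
plain = flax.core.unfreeze(variables)   # normalize to a nested plain dict
print(type(variables), type(frozen), type(plain))
```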
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.ipynb
@@ -69,7 +69,7 @@
"import jax\n",
"from typing import Any, Callable, Sequence\n",
"from jax import lax, random, numpy as jnp\n",
"import flax\n",
"from flax.core import freeze, unfreeze\n",
"from flax import linen as nn"
]
},
@@ -246,7 +246,7 @@
"W = random.normal(k1, (x_dim, y_dim))\n",
"b = random.normal(k2, (y_dim,))\n",
"# Store the parameters in a FrozenDict pytree.\n",
"true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n",
"true_params = freeze({'params': {'bias': b, 'kernel': W}})\n",
"\n",
"# Generate samples with additional noise.\n",
"key_sample, key_noise = random.split(k1)\n",
@@ -604,7 +604,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -706,7 +706,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -929,8 +929,8 @@
"for val in [1.0, 2.0, 3.0]:\n",
" x = val * jnp.ones((10,5))\n",
" y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
" old_state, params = flax.core.pop(variables, 'params')\n",
" variables = flax.core.freeze({'params': params, **updated_state})\n",
" old_state, params = variables.pop('params')\n",
" variables = freeze({'params': params, **updated_state})\n",
" print('updated state:\\n', updated_state) # Shows only the mutable part"
]
},
@@ -994,7 +994,7 @@
"\n",
"x = jnp.ones((10,5))\n",
"variables = model.init(random.PRNGKey(0), x)\n",
"state, params = flax.core.pop(variables, 'params')\n",
"state, params = variables.pop('params')\n",
"del variables\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.md
@@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook.
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
```

@@ -132,7 +132,7 @@ k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
true_params = freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
@@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
```

@@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
print('output:\n', y)
```

@@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10,5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
old_state, params = flax.core.pop(variables, 'params')
variables = flax.core.freeze({'params': params, **updated_state})
old_state, params = variables.pop('params')
variables = freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # Shows only the mutable part
```
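The loop in the hunk above threads the refreshed `batch_stats` back into `variables` (both the `flax.core.pop`/`flax.core.freeze` spelling and the FrozenDict-method spelling appear in the diff). The same mechanics can be exercised end to end with a stock `nn.BatchNorm` layer; this is a standalone sketch, not code from the guide:

```python
import jax.numpy as jnp
from jax import random
import flax
from flax import linen as nn

# Stock BatchNorm layer as a stand-in; use_running_average=False means the
# batch statistics are updated on every call.
model = nn.BatchNorm(use_running_average=False)
variables = model.init(random.PRNGKey(0), jnp.ones((10, 5)))

for val in [1.0, 2.0, 3.0]:
  x = val * jnp.ones((10, 5))
  # Declaring 'batch_stats' mutable makes apply return (output, updated_state).
  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
  # Keep the unchanged params and splice in the refreshed batch statistics;
  # flax.core.pop and flax.core.freeze accept plain dicts and FrozenDicts alike.
  _, params = flax.core.pop(variables, 'params')
  variables = flax.core.freeze({'params': params, **updated_state})
```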

@@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state):
x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
state, params = flax.core.pop(variables, 'params')
state, params = variables.pop('params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
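The hunk header above shows the signature `def update_step(tx, apply_fn, x, opt_state, params, state)`, and the lines just above initialize the optimizer it consumes. A condensed sketch of such a step; the squared-error loss and the exact way state is threaded through `apply` are illustrative assumptions rather than the guide's verbatim body:

```python
import functools
import jax
import jax.numpy as jnp
import optax

@functools.partial(jax.jit, static_argnums=(0, 1))
def update_step(tx, apply_fn, x, opt_state, params, state):
  def loss_fn(p):
    # Re-bundle params with the non-trainable collections and mark the latter
    # mutable so apply returns their updated values as an auxiliary output.
    y, updated_state = apply_fn(
        {'params': p, **state}, x, mutable=list(state.keys()))
    return jnp.mean(y ** 2), updated_state  # illustrative loss

  (loss, updated_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
  updates, opt_state = tx.update(grads, opt_state, params)
  params = optax.apply_updates(params, updates)
  return opt_state, params, updated_state
```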
20 changes: 10 additions & 10 deletions docs/notebooks/linen_intro.ipynb
@@ -75,7 +75,7 @@
"from typing import Any, Callable, Sequence, Optional\n",
"import jax\n",
"from jax import lax, random, numpy as jnp\n",
"import flax\n",
"from flax.core import freeze, unfreeze\n",
"from flax import linen as nn"
]
},
@@ -306,7 +306,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -370,7 +370,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -743,7 +743,7 @@
"model = Block(features=3, training=True)\n",
"\n",
"init_variables = model.init({'params': key2, 'dropout': key3}, x)\n",
"_, init_params = flax.core.pop(init_variables, 'params')\n",
"_, init_params = init_variables.pop('params')\n",
"\n",
"# When calling `apply` with mutable kinds, returns a pair of output,\n",
"# mutated_variables.\n",
@@ -752,8 +752,8 @@
"\n",
"# Now we reassemble the full variables from the updates (in a real training\n",
"# loop, with the updated params from an optimizer).\n",
"updated_variables = flax.core.freeze(dict(params=init_params,\n",
" **mutated_variables))\n",
"updated_variables = freeze(dict(params=init_params,\n",
" **mutated_variables))\n",
"\n",
"print('updated variables:\\n', updated_variables)\n",
"print('initialized variable shapes:\\n',\n",
@@ -842,7 +842,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -913,7 +913,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -1069,7 +1069,7 @@
" batch_axes=(0,))\n",
"\n",
"init_variables = model(train=False).init({'params': key2}, x, x)\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"\n",
"y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n",
"print('output:\\n', y.shape)"
@@ -1149,7 +1149,7 @@
"model = SimpleScan(2)\n",
"init_variables = model.init(key2, xs)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"\n",
"y = model.apply(init_variables, xs)\n",
"print('output:\\n', y)"
20 changes: 10 additions & 10 deletions docs/notebooks/linen_intro.md
@@ -53,7 +53,7 @@ import functools
from typing import Any, Callable, Sequence, Optional
import jax
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
```

@@ -154,7 +154,7 @@ model = ExplicitMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
```
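The `jax.tree_util.tree_map(jnp.shape, ...)` line in the snippet above is a shape-inspection idiom: it maps `jnp.shape` over every leaf of the variables pytree. A tiny standalone sketch (toy `nn.Dense` model assumed):

```python
import jax
import jax.numpy as jnp
from jax import random
from flax import linen as nn
from flax.core import unfreeze

model = nn.Dense(features=5)
params = model.init(random.PRNGKey(0), jnp.ones((1, 3)))

# unfreeze accepts plain dicts as well as FrozenDicts, so the same line
# works before and after this migration.
print(jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
```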

@@ -189,7 +189,7 @@ model = SimpleMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
```

@@ -357,7 +357,7 @@ x = random.uniform(key1, (3,4,4))
model = Block(features=3, training=True)
init_variables = model.init({'params': key2, 'dropout': key3}, x)
_, init_params = flax.core.pop(init_variables, 'params')
_, init_params = init_variables.pop('params')
# When calling `apply` with mutable kinds, returns a pair of output,
# mutated_variables.
@@ -366,8 +366,8 @@ y, mutated_variables = model.apply(
# Now we reassemble the full variables from the updates (in a real training
# loop, with the updated params from an optimizer).
updated_variables = flax.core.freeze(dict(params=init_params,
**mutated_variables))
updated_variables = freeze(dict(params=init_params,
**mutated_variables))
print('updated variables:\n', updated_variables)
print('initialized variable shapes:\n',
@@ -419,7 +419,7 @@ model = MLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
```

@@ -459,7 +459,7 @@ model = RematMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
print('output:\n', y)
```

@@ -587,7 +587,7 @@ model = functools.partial(
batch_axes=(0,))
init_variables = model(train=False).init({'params': key2}, x, x)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
print('output:\n', y.shape)
@@ -635,7 +635,7 @@ xs = random.uniform(key1, (1, 5, 2))
model = SimpleScan(2)
init_variables = model.init(key2, xs)
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
y = model.apply(init_variables, xs)
print('output:\n', y)
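Several hunks in this file pass a dict of RNGs to `init` and an `rngs=` dict to `apply` so that dropout draws from its own stream. A self-contained sketch of that pattern; the `DropBlock` module is illustrative, not the notebook's `Block`:

```python
import jax.numpy as jnp
from jax import random
from flax import linen as nn

class DropBlock(nn.Module):
  features: int

  @nn.compact
  def __call__(self, x):
    x = nn.Dense(self.features)(x)
    # deterministic=False means dropout is active and needs a 'dropout' RNG.
    return nn.Dropout(rate=0.5, deterministic=False)(x)

key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)
x = random.uniform(key1, (3, 4))
model = DropBlock(features=3)

# Separate RNG streams: 'params' for initialization, 'dropout' for the mask.
init_variables = model.init({'params': key2, 'dropout': key3}, x)
y = model.apply(init_variables, x, rngs={'dropout': key4})
```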
5 changes: 2 additions & 3 deletions docs/notebooks/state_params.ipynb
@@ -55,7 +55,6 @@
"from jax import random\n",
"import optax\n",
"\n",
"import flax\n",
"from flax import linen as nn\n",
"\n",
"\n",
@@ -153,7 +152,7 @@
"model = BiasAdderWithRunningMean()\n",
"variables = model.init(random.PRNGKey(0), dummy_input)\n",
"# Split state and params (which are updated by optimizer).\n",
"state, params = flax.core.pop(variables, 'params')\n",
"state, params = variables.pop('params')\n",
"del variables # Delete variables to avoid wasting resources\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
@@ -272,7 +271,7 @@
"model = MLP(hidden_size=10, out_size=1)\n",
"variables = model.init(random.PRNGKey(0), dummy_input)\n",
"# Split state and params (which are updated by optimizer).\n",
"state, params = flax.core.pop(variables, 'params')\n",
"state, params = variables.pop('params')\n",
"del variables # Delete variables to avoid wasting resources\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
5 changes: 2 additions & 3 deletions docs/notebooks/state_params.md
@@ -41,7 +41,6 @@ from jax import numpy as jnp
from jax import random
import optax
import flax
from flax import linen as nn
@@ -114,7 +113,7 @@ Then we can write the actual training code.
model = BiasAdderWithRunningMean()
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
state, params = variables.pop('params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
@@ -203,7 +202,7 @@ Note that we also need to specify that the model state does not have a batch dim
model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
state, params = flax.core.pop(variables, 'params')
state, params = variables.pop('params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
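For reference, the split in the lines above can be run standalone with any small module; apart from the stand-in `nn.Dense` model, the calls mirror the guide:

```python
import jax.numpy as jnp
from jax import random
import optax
import flax
from flax import linen as nn

model = nn.Dense(features=1)            # stand-in for the guide's model
dummy_input = jnp.ones((1, 3))
variables = model.init(random.PRNGKey(0), dummy_input)

# Split state and params; only params are handed to the optimizer.
# (For a plain Dense layer `state` is empty; real models keep batch stats here.)
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
```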
7 changes: 4 additions & 3 deletions flax/configurations.py
@@ -109,11 +109,12 @@ def temp_flip_flag(var_name: str, var_value: bool):
default=False,
help=("When adopting outside modules, don't clobber existing names."))

#TODO(marcuschiam): remove this feature flag once regular dict migration is complete
# TODO(marcuschiam): remove this feature flag once regular dict migration is complete
flax_return_frozendict = define_bool_state(
name='return_frozendict',
default=True,
help=('Whether to return FrozenDicts when calling init or apply.'))
default=False,
help='Whether to return FrozenDicts when calling init or apply.',
)

flax_fix_rng = define_bool_state(
name ='fix_rng_separator',
1 change: 0 additions & 1 deletion tests/core/core_lift_test.py
@@ -164,7 +164,6 @@ def false_fn(scope, x):
self.assertEqual(vars['state'], {'true_count': 1, 'false_count': 1})
np.testing.assert_allclose(y1, -y2)

@temp_flip_flag('return_frozendict', False)
def test_switch(self):
def f(scope, x, index):
scope.variable('state', 'a_count', lambda: 0)
1 change: 0 additions & 1 deletion tests/linen/linen_attention_test.py
@@ -144,7 +144,6 @@ def test_causal_mask_1d(self):
np.testing.assert_allclose(mask_1d, mask_1d_simple,)

@parameterized.parameters([((5,), (1,)), ((6, 5), (2,))])
@temp_flip_flag('return_frozendict', False)
def test_decoding(self, spatial_shape, attn_dims):
bs = 2
num_heads = 3
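The two test files above simply drop `@temp_flip_flag('return_frozendict', False)` decorators that are redundant now that `False` is the default. While the flag still exists, the same helper can pin a test to the legacy behavior; a sketch, with the import path of `temp_flip_flag` assumed from the `flax/configurations.py` hunk above:

```python
import jax.numpy as jnp
from absl.testing import absltest
from jax import random
from flax import linen as nn
from flax.configurations import temp_flip_flag  # assumed import path
from flax.core import FrozenDict


class ReturnTypeTest(absltest.TestCase):

  # Pins this one test to the pre-migration behavior while the flag exists.
  @temp_flip_flag('return_frozendict', True)
  def test_init_returns_frozendict(self):
    params = nn.Dense(features=2).init(random.PRNGKey(0), jnp.ones((1, 3)))
    self.assertIsInstance(params, FrozenDict)


if __name__ == '__main__':
  absltest.main()
```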