diff --git a/docs/guides/flax_basics.ipynb b/docs/guides/flax_basics.ipynb
index ac1cdad7b8..e9b4df0192 100644
--- a/docs/guides/flax_basics.ipynb
+++ b/docs/guides/flax_basics.ipynb
@@ -69,7 +69,7 @@
     "import jax\n",
     "from typing import Any, Callable, Sequence\n",
     "from jax import lax, random, numpy as jnp\n",
-    "from flax.core import freeze, unfreeze\n",
+    "import flax\n",
     "from flax import linen as nn"
    ]
   },
@@ -246,7 +246,7 @@
     "W = random.normal(k1, (x_dim, y_dim))\n",
     "b = random.normal(k2, (y_dim,))\n",
     "# Store the parameters in a FrozenDict pytree.\n",
-    "true_params = freeze({'params': {'bias': b, 'kernel': W}})\n",
+    "true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n",
     "\n",
     "# Generate samples with additional noise.\n",
     "key_sample, key_noise = random.split(k1)\n",
@@ -604,7 +604,7 @@
     "params = model.init(key2, x)\n",
     "y = model.apply(params, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -706,7 +706,7 @@
     "params = model.init(key2, x)\n",
     "y = model.apply(params, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -929,8 +929,8 @@
     "for val in [1.0, 2.0, 3.0]:\n",
     "  x = val * jnp.ones((10,5))\n",
     "  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
-    "  old_state, params = variables.pop('params')\n",
-    "  variables = freeze({'params': params, **updated_state})\n",
+    "  old_state, params = flax.core.pop(variables, 'params')\n",
+    "  variables = flax.core.freeze({'params': params, **updated_state})\n",
     "  print('updated state:\\n', updated_state) # Shows only the mutable part"
    ]
   },
@@ -994,7 +994,7 @@
     "\n",
     "x = jnp.ones((10,5))\n",
     "variables = model.init(random.PRNGKey(0), x)\n",
-    "state, params = variables.pop('params')\n",
+    "state, params = flax.core.pop(variables, 'params')\n",
     "del variables\n",
     "tx = optax.sgd(learning_rate=0.02)\n",
     "opt_state = tx.init(params)\n",
diff --git a/docs/guides/flax_basics.md b/docs/guides/flax_basics.md
index 81682e3140..dfaa012052 100644
--- a/docs/guides/flax_basics.md
+++ b/docs/guides/flax_basics.md
@@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook.
 import jax
 from typing import Any, Callable, Sequence
 from jax import lax, random, numpy as jnp
-from flax.core import freeze, unfreeze
+import flax
 from flax import linen as nn
 ```
 
@@ -132,7 +132,7 @@ k1, k2 = random.split(key)
 W = random.normal(k1, (x_dim, y_dim))
 b = random.normal(k2, (y_dim,))
 # Store the parameters in a FrozenDict pytree.
-true_params = freeze({'params': {'bias': b, 'kernel': W}})
+true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
 
 # Generate samples with additional noise.
 key_sample, key_noise = random.split(k1)
@@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5])
 params = model.init(key2, x)
 y = model.apply(params, x)
 
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
 print('output:\n', y)
 ```
 
@@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5])
 params = model.init(key2, x)
 y = model.apply(params, x)
 
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
 print('output:\n', y)
 ```
 
@@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by
 for val in [1.0, 2.0, 3.0]:
   x = val * jnp.ones((10,5))
   y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
-  old_state, params = variables.pop('params')
-  variables = freeze({'params': params, **updated_state})
+  old_state, params = flax.core.pop(variables, 'params')
+  variables = flax.core.freeze({'params': params, **updated_state})
   print('updated state:\n', updated_state) # Shows only the mutable part
 ```
 
@@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state):
 
 x = jnp.ones((10,5))
 variables = model.init(random.PRNGKey(0), x)
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
 del variables
 tx = optax.sgd(learning_rate=0.02)
 opt_state = tx.init(params)
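
For reviewers, a minimal sketch of the idiom these hunks migrate to: going through the top-level `flax` import for `flax.core.freeze`, `flax.core.unfreeze`, and `flax.core.pop`, instead of the removed `from flax.core import freeze, unfreeze` names and the `FrozenDict.pop` method. The `nn.Dense` model and input shape below are placeholders chosen for illustration, not taken from the guide.

```python
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn

# Placeholder model, only so there are variable collections to split and re-freeze.
model = nn.Dense(features=3)
x = jnp.ones((10, 5))
variables = model.init(random.PRNGKey(0), x)

# flax.core.pop(variables, 'params') returns (remaining collections, popped entry),
# replacing the FrozenDict.pop method call used in the old guide text.
state, params = flax.core.pop(variables, 'params')

# Inspect shapes on a plain dict, then rebuild a frozen variable dict for apply().
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
variables = flax.core.freeze({'params': params, **state})
y = model.apply(variables, x)
```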