diff --git a/docs/guides/flax_basics.ipynb b/docs/guides/flax_basics.ipynb
index ac1cdad7b8..e9b4df0192 100644
--- a/docs/guides/flax_basics.ipynb
+++ b/docs/guides/flax_basics.ipynb
@@ -69,7 +69,7 @@
     "import jax\n",
     "from typing import Any, Callable, Sequence\n",
     "from jax import lax, random, numpy as jnp\n",
-    "from flax.core import freeze, unfreeze\n",
+    "import flax\n",
     "from flax import linen as nn"
    ]
   },
@@ -246,7 +246,7 @@
     "W = random.normal(k1, (x_dim, y_dim))\n",
     "b = random.normal(k2, (y_dim,))\n",
     "# Store the parameters in a FrozenDict pytree.\n",
-    "true_params = freeze({'params': {'bias': b, 'kernel': W}})\n",
+    "true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n",
     "\n",
     "# Generate samples with additional noise.\n",
     "key_sample, key_noise = random.split(k1)\n",
@@ -604,7 +604,7 @@
     "params = model.init(key2, x)\n",
     "y = model.apply(params, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -706,7 +706,7 @@
     "params = model.init(key2, x)\n",
     "y = model.apply(params, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -929,8 +929,8 @@
     "for val in [1.0, 2.0, 3.0]:\n",
     "  x = val * jnp.ones((10,5))\n",
     "  y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
-    "  old_state, params = variables.pop('params')\n",
-    "  variables = freeze({'params': params, **updated_state})\n",
+    "  old_state, params = flax.core.pop(variables, 'params')\n",
+    "  variables = flax.core.freeze({'params': params, **updated_state})\n",
     "  print('updated state:\\n', updated_state) # Shows only the mutable part"
    ]
   },
@@ -994,7 +994,7 @@
     "\n",
     "x = jnp.ones((10,5))\n",
     "variables = model.init(random.PRNGKey(0), x)\n",
-    "state, params = variables.pop('params')\n",
+    "state, params = flax.core.pop(variables, 'params')\n",
     "del variables\n",
     "tx = optax.sgd(learning_rate=0.02)\n",
     "opt_state = tx.init(params)\n",
diff --git a/docs/guides/flax_basics.md b/docs/guides/flax_basics.md
index 81682e3140..dfaa012052 100644
--- a/docs/guides/flax_basics.md
+++ b/docs/guides/flax_basics.md
@@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook.
 import jax
 from typing import Any, Callable, Sequence
 from jax import lax, random, numpy as jnp
-from flax.core import freeze, unfreeze
+import flax
 from flax import linen as nn
 ```

@@ -132,7 +132,7 @@ k1, k2 = random.split(key)
 W = random.normal(k1, (x_dim, y_dim))
 b = random.normal(k2, (y_dim,))
 # Store the parameters in a FrozenDict pytree.
-true_params = freeze({'params': {'bias': b, 'kernel': W}})
+true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

 # Generate samples with additional noise.
 key_sample, key_noise = random.split(k1)
@@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5])
 params = model.init(key2, x)
 y = model.apply(params, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
 print('output:\n', y)
 ```

@@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5])
 params = model.init(key2, x)
 y = model.apply(params, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
 print('output:\n', y)
 ```

@@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by
 for val in [1.0, 2.0, 3.0]:
   x = val * jnp.ones((10,5))
   y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
-  old_state, params = variables.pop('params')
-  variables = freeze({'params': params, **updated_state})
+  old_state, params = flax.core.pop(variables, 'params')
+  variables = flax.core.freeze({'params': params, **updated_state})
   print('updated state:\n', updated_state) # Shows only the mutable part
 ```

@@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state):

 x = jnp.ones((10,5))
 variables = model.init(random.PRNGKey(0), x)
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
 del variables
 tx = optax.sgd(learning_rate=0.02)
 opt_state = tx.init(params)
diff --git a/docs/notebooks/linen_intro.ipynb b/docs/notebooks/linen_intro.ipynb
index ec16074ae1..ab272e83cd 100644
--- a/docs/notebooks/linen_intro.ipynb
+++ b/docs/notebooks/linen_intro.ipynb
@@ -75,7 +75,7 @@
     "from typing import Any, Callable, Sequence, Optional\n",
     "import jax\n",
     "from jax import lax, random, numpy as jnp\n",
-    "from flax.core import freeze, unfreeze\n",
+    "import flax\n",
     "from flax import linen as nn"
    ]
   },
@@ -306,7 +306,7 @@
     "init_variables = model.init(key2, x)\n",
     "y = model.apply(init_variables, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -370,7 +370,7 @@
     "init_variables = model.init(key2, x)\n",
     "y = model.apply(init_variables, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -743,7 +743,7 @@
     "model = Block(features=3, training=True)\n",
     "\n",
     "init_variables = model.init({'params': key2, 'dropout': key3}, x)\n",
-    "_, init_params = init_variables.pop('params')\n",
+    "_, init_params = flax.core.pop(init_variables, 'params')\n",
     "\n",
     "# When calling `apply` with mutable kinds, returns a pair of output,\n",
     "# mutated_variables.\n",
@@ -752,8 +752,8 @@
     "\n",
     "# Now we reassemble the full variables from the updates (in a real training\n",
     "# loop, with the updated params from an optimizer).\n",
-    "updated_variables = freeze(dict(params=init_params,\n",
-    "                                **mutated_variables))\n",
+    "updated_variables = flax.core.freeze(dict(params=init_params,\n",
+    "                                          **mutated_variables))\n",
     "\n",
     "print('updated variables:\\n', updated_variables)\n",
     "print('initialized variable shapes:\\n',\n",
@@ -842,7 +842,7 @@
     "init_variables = model.init(key2, x)\n",
     "y = model.apply(init_variables, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -913,7 +913,7 @@
     "init_variables = model.init(key2, x)\n",
     "y = model.apply(init_variables, x)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
     "print('output:\\n', y)"
    ]
   },
@@ -1069,7 +1069,7 @@
     "  batch_axes=(0,))\n",
     "\n",
     "init_variables = model(train=False).init({'params': key2}, x, x)\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
     "\n",
     "y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n",
     "print('output:\\n', y.shape)"
@@ -1149,7 +1149,7 @@
     "model = SimpleScan(2)\n",
     "init_variables = model.init(key2, xs)\n",
     "\n",
-    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+    "print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
     "\n",
     "y = model.apply(init_variables, xs)\n",
     "print('output:\\n', y)"
diff --git a/docs/notebooks/linen_intro.md b/docs/notebooks/linen_intro.md
index 70dbc8337b..b3c94187fa 100644
--- a/docs/notebooks/linen_intro.md
+++ b/docs/notebooks/linen_intro.md
@@ -53,7 +53,7 @@ import functools
 from typing import Any, Callable, Sequence, Optional
 import jax
 from jax import lax, random, numpy as jnp
-from flax.core import freeze, unfreeze
+import flax
 from flax import linen as nn
 ```

@@ -154,7 +154,7 @@ model = ExplicitMLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```

@@ -189,7 +189,7 @@ model = SimpleMLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```

@@ -357,7 +357,7 @@ x = random.uniform(key1, (3,4,4))
 model = Block(features=3, training=True)

 init_variables = model.init({'params': key2, 'dropout': key3}, x)
-_, init_params = init_variables.pop('params')
+_, init_params = flax.core.pop(init_variables, 'params')

 # When calling `apply` with mutable kinds, returns a pair of output,
 # mutated_variables.
@@ -366,8 +366,8 @@ y, mutated_variables = model.apply(

 # Now we reassemble the full variables from the updates (in a real training
 # loop, with the updated params from an optimizer).
-updated_variables = freeze(dict(params=init_params,
-                                **mutated_variables))
+updated_variables = flax.core.freeze(dict(params=init_params,
+                                          **mutated_variables))

 print('updated variables:\n', updated_variables)
 print('initialized variable shapes:\n',
@@ -419,7 +419,7 @@ model = MLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```

@@ -459,7 +459,7 @@ model = RematMLP(features=[3,4,5])
 init_variables = model.init(key2, x)
 y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
 print('output:\n', y)
 ```

@@ -587,7 +587,7 @@ model = functools.partial(
   batch_axes=(0,))

 init_variables = model(train=False).init({'params': key2}, x, x)
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))

 y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
 print('output:\n', y.shape)
@@ -635,7 +635,7 @@ xs = random.uniform(key1, (1, 5, 2))
 model = SimpleScan(2)
 init_variables = model.init(key2, xs)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))

 y = model.apply(init_variables, xs)
 print('output:\n', y)
diff --git a/docs/notebooks/state_params.ipynb b/docs/notebooks/state_params.ipynb
index ab3f8a9f6a..580b254fac 100644
--- a/docs/notebooks/state_params.ipynb
+++ b/docs/notebooks/state_params.ipynb
@@ -55,6 +55,7 @@
     "from jax import random\n",
     "import optax\n",
     "\n",
+    "import flax\n",
     "from flax import linen as nn\n",
     "\n",
     "\n",
@@ -152,7 +153,7 @@
     "model = BiasAdderWithRunningMean()\n",
     "variables = model.init(random.PRNGKey(0), dummy_input)\n",
     "# Split state and params (which are updated by optimizer).\n",
-    "state, params = variables.pop('params')\n",
+    "state, params = flax.core.pop(variables, 'params')\n",
     "del variables # Delete variables to avoid wasting resources\n",
     "tx = optax.sgd(learning_rate=0.02)\n",
     "opt_state = tx.init(params)\n",
@@ -271,7 +272,7 @@
     "model = MLP(hidden_size=10, out_size=1)\n",
     "variables = model.init(random.PRNGKey(0), dummy_input)\n",
     "# Split state and params (which are updated by optimizer).\n",
-    "state, params = variables.pop('params')\n",
+    "state, params = flax.core.pop(variables, 'params')\n",
     "del variables # Delete variables to avoid wasting resources\n",
     "tx = optax.sgd(learning_rate=0.02)\n",
     "opt_state = tx.init(params)\n",
diff --git a/docs/notebooks/state_params.md b/docs/notebooks/state_params.md
index 9535da13d2..b1fb9dce31 100644
--- a/docs/notebooks/state_params.md
+++ b/docs/notebooks/state_params.md
@@ -41,6 +41,7 @@ from jax import numpy as jnp
 from jax import random
 import optax

+import flax
 from flax import linen as nn


@@ -113,7 +114,7 @@ Then we can write the actual training code.
 model = BiasAdderWithRunningMean()
 variables = model.init(random.PRNGKey(0), dummy_input)
 # Split state and params (which are updated by optimizer).
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
 del variables # Delete variables to avoid wasting resources
 tx = optax.sgd(learning_rate=0.02)
 opt_state = tx.init(params)
@@ -202,7 +203,7 @@ Note that we also need to specify that the model state does not have a batch dim
 model = MLP(hidden_size=10, out_size=1)
 variables = model.init(random.PRNGKey(0), dummy_input)
 # Split state and params (which are updated by optimizer).
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
 del variables # Delete variables to avoid wasting resources
 tx = optax.sgd(learning_rate=0.02)
 opt_state = tx.init(params)
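
Taken together, the hunks above replace the convenience imports `from flax.core import freeze, unfreeze` and the `FrozenDict.pop` method with the namespaced helpers `flax.core.freeze`, `flax.core.unfreeze`, and `flax.core.pop`. The snippet below is not part of the patch; it is a minimal, self-contained sketch of the migrated call pattern, using an arbitrary `nn.Dense` module and input shape chosen purely for illustration.

```python
# Minimal sketch of the call pattern the patch migrates to. The module and
# shapes here are placeholders, not taken from the notebooks themselves.
import jax
from jax import random, numpy as jnp
import flax
from flax import linen as nn

model = nn.Dense(features=3)
x = jnp.ones((1, 4))
variables = model.init(random.PRNGKey(0), x)

# flax.core.pop replaces variables.pop('params'): it returns the remaining
# collections together with the popped 'params' collection.
state, params = flax.core.pop(variables, 'params')

# flax.core.unfreeze replaces the bare `unfreeze` when inspecting shapes.
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))

# flax.core.freeze replaces the bare `freeze` when reassembling the variable
# dict, e.g. after an optimizer update in a training loop.
variables = flax.core.freeze({'params': params, **state})
```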