
Commit

Merge pull request #3214 from chiamp:fix_flax_basics
PiperOrigin-RevId: 549425315
Flax Authors committed Jul 19, 2023
2 parents 290e50f + 93a38ff commit 54a0296
Showing 6 changed files with 40 additions and 38 deletions.
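
Across all six files the change is the same: the direct `from flax.core import freeze, unfreeze` import and the `FrozenDict.pop` method calls are replaced by a plain `import flax` plus the fully qualified `flax.core.freeze`, `flax.core.unfreeze`, and `flax.core.pop` helpers. A minimal, self-contained sketch of the new pattern (the `nn.Dense` model and the concrete shapes below are illustrative only; they are not taken from the diff):

```python
import jax
from jax import numpy as jnp, random

import flax
from flax import linen as nn

# A small model, used here only to produce a variable dict to work with.
model = nn.Dense(features=3)
variables = model.init(random.PRNGKey(0), jnp.ones((4, 2)))

# Old style removed by this commit:
#   from flax.core import freeze, unfreeze
#   true_params = freeze({'params': {'bias': b, 'kernel': W}})
#   state, params = variables.pop('params')

# New style: call the helpers through the flax.core namespace.
true_params = flax.core.freeze({'params': {'bias': jnp.zeros((3,)),
                                           'kernel': jnp.ones((2, 3))}})
print('shapes:', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(variables)))
# For nn.Dense the remaining state is empty; the call pattern is what matters here.
state, params = flax.core.pop(variables, 'params')
```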
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.ipynb
@@ -69,7 +69,7 @@
"import jax\n",
"from typing import Any, Callable, Sequence\n",
"from jax import lax, random, numpy as jnp\n",
"from flax.core import freeze, unfreeze\n",
"import flax\n",
"from flax import linen as nn"
]
},
@@ -246,7 +246,7 @@
"W = random.normal(k1, (x_dim, y_dim))\n",
"b = random.normal(k2, (y_dim,))\n",
"# Store the parameters in a FrozenDict pytree.\n",
"true_params = freeze({'params': {'bias': b, 'kernel': W}})\n",
"true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n",
"\n",
"# Generate samples with additional noise.\n",
"key_sample, key_noise = random.split(k1)\n",
@@ -604,7 +604,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -706,7 +706,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -929,8 +929,8 @@
"for val in [1.0, 2.0, 3.0]:\n",
" x = val * jnp.ones((10,5))\n",
" y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
" old_state, params = variables.pop('params')\n",
" variables = freeze({'params': params, **updated_state})\n",
" old_state, params = flax.core.pop(variables, 'params')\n",
" variables = flax.core.freeze({'params': params, **updated_state})\n",
" print('updated state:\\n', updated_state) # Shows only the mutable part"
]
},
@@ -994,7 +994,7 @@
"\n",
"x = jnp.ones((10,5))\n",
"variables = model.init(random.PRNGKey(0), x)\n",
"state, params = variables.pop('params')\n",
"state, params = flax.core.pop(variables, 'params')\n",
"del variables\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.md
@@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook.
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
- from flax.core import freeze, unfreeze
+ import flax
from flax import linen as nn
```

@@ -132,7 +132,7 @@ k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
- true_params = freeze({'params': {'bias': b, 'kernel': W}})
+ true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
@@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```

@@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```

@@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10,5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
- old_state, params = variables.pop('params')
- variables = freeze({'params': params, **updated_state})
+ old_state, params = flax.core.pop(variables, 'params')
+ variables = flax.core.freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # Shows only the mutable part
```

@@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state):
x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
- state, params = variables.pop('params')
+ state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
20 changes: 10 additions & 10 deletions docs/notebooks/linen_intro.ipynb
@@ -75,7 +75,7 @@
"from typing import Any, Callable, Sequence, Optional\n",
"import jax\n",
"from jax import lax, random, numpy as jnp\n",
"from flax.core import freeze, unfreeze\n",
"import flax\n",
"from flax import linen as nn"
]
},
@@ -306,7 +306,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -370,7 +370,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -743,7 +743,7 @@
"model = Block(features=3, training=True)\n",
"\n",
"init_variables = model.init({'params': key2, 'dropout': key3}, x)\n",
"_, init_params = init_variables.pop('params')\n",
"_, init_params = flax.core.pop(init_variables, 'params')\n",
"\n",
"# When calling `apply` with mutable kinds, returns a pair of output,\n",
"# mutated_variables.\n",
@@ -752,8 +752,8 @@
"\n",
"# Now we reassemble the full variables from the updates (in a real training\n",
"# loop, with the updated params from an optimizer).\n",
"updated_variables = freeze(dict(params=init_params,\n",
" **mutated_variables))\n",
"updated_variables = flax.core.freeze(dict(params=init_params,\n",
" **mutated_variables))\n",
"\n",
"print('updated variables:\\n', updated_variables)\n",
"print('initialized variable shapes:\\n',\n",
@@ -842,7 +842,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -913,7 +913,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -1069,7 +1069,7 @@
" batch_axes=(0,))\n",
"\n",
"init_variables = model(train=False).init({'params': key2}, x, x)\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"\n",
"y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n",
"print('output:\\n', y.shape)"
@@ -1149,7 +1149,7 @@
"model = SimpleScan(2)\n",
"init_variables = model.init(key2, xs)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"\n",
"y = model.apply(init_variables, xs)\n",
"print('output:\\n', y)"
20 changes: 10 additions & 10 deletions docs/notebooks/linen_intro.md
@@ -53,7 +53,7 @@ import functools
from typing import Any, Callable, Sequence, Optional
import jax
from jax import lax, random, numpy as jnp
- from flax.core import freeze, unfreeze
+ import flax
from flax import linen as nn
```

@@ -154,7 +154,7 @@ model = ExplicitMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -189,7 +189,7 @@ model = SimpleMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -357,7 +357,7 @@ x = random.uniform(key1, (3,4,4))
model = Block(features=3, training=True)
init_variables = model.init({'params': key2, 'dropout': key3}, x)
- _, init_params = init_variables.pop('params')
+ _, init_params = flax.core.pop(init_variables, 'params')
# When calling `apply` with mutable kinds, returns a pair of output,
# mutated_variables.
@@ -366,8 +366,8 @@ y, mutated_variables = model.apply(
# Now we reassemble the full variables from the updates (in a real training
# loop, with the updated params from an optimizer).
- updated_variables = freeze(dict(params=init_params,
- **mutated_variables))
+ updated_variables = flax.core.freeze(dict(params=init_params,
+ **mutated_variables))
print('updated variables:\n', updated_variables)
print('initialized variable shapes:\n',
@@ -419,7 +419,7 @@ model = MLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -459,7 +459,7 @@ model = RematMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -587,7 +587,7 @@ model = functools.partial(
batch_axes=(0,))
init_variables = model(train=False).init({'params': key2}, x, x)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
print('output:\n', y.shape)
@@ -635,7 +635,7 @@ xs = random.uniform(key1, (1, 5, 2))
model = SimpleScan(2)
init_variables = model.init(key2, xs)
- print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+ print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
y = model.apply(init_variables, xs)
print('output:\n', y)
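
The linen_intro changes above also rewrite the variable-reassembly step around `flax.core.pop` and `flax.core.freeze`. A runnable sketch of that pattern, with a simplified `Block` (Dense followed by BatchNorm, no dropout) standing in for the notebook's module:

```python
import jax
from jax import numpy as jnp, random

import flax
from flax import linen as nn

class Block(nn.Module):
    # Simplified stand-in: the module owns both a 'params' collection and a
    # mutable 'batch_stats' collection.
    features: int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.BatchNorm(use_running_average=not self.training)(x)
        return x

key1, key2 = random.split(random.PRNGKey(0))
x = random.uniform(key1, (3, 4, 4))
model = Block(features=3, training=True)

init_variables = model.init(key2, x)
_, init_params = flax.core.pop(init_variables, 'params')

# Calling `apply` with mutable collections returns (output, mutated_variables).
y, mutated_variables = model.apply(init_variables, x, mutable=['batch_stats'])

# Reassemble the full variables from the updated pieces (in a real training
# loop, with the updated params from an optimizer).
updated_variables = flax.core.freeze(dict(params=init_params,
                                          **mutated_variables))
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(updated_variables)))
```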
5 changes: 3 additions & 2 deletions docs/notebooks/state_params.ipynb
@@ -55,6 +55,7 @@
"from jax import random\n",
"import optax\n",
"\n",
"import flax\n",
"from flax import linen as nn\n",
"\n",
"\n",
@@ -152,7 +153,7 @@
"model = BiasAdderWithRunningMean()\n",
"variables = model.init(random.PRNGKey(0), dummy_input)\n",
"# Split state and params (which are updated by optimizer).\n",
"state, params = variables.pop('params')\n",
"state, params = flax.core.pop(variables, 'params')\n",
"del variables # Delete variables to avoid wasting resources\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
@@ -271,7 +272,7 @@
"model = MLP(hidden_size=10, out_size=1)\n",
"variables = model.init(random.PRNGKey(0), dummy_input)\n",
"# Split state and params (which are updated by optimizer).\n",
"state, params = variables.pop('params')\n",
"state, params = flax.core.pop(variables, 'params')\n",
"del variables # Delete variables to avoid wasting resources\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
5 changes: 3 additions & 2 deletions docs/notebooks/state_params.md
@@ -41,6 +41,7 @@ from jax import numpy as jnp
from jax import random
import optax
+ import flax
from flax import linen as nn
@@ -113,7 +114,7 @@ Then we can write the actual training code.
model = BiasAdderWithRunningMean()
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
- state, params = variables.pop('params')
+ state, params = flax.core.pop(variables, 'params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
@@ -202,7 +203,7 @@ Note that we also need to specify that the model state does not have a batch dim
model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
- state, params = variables.pop('params')
+ state, params = flax.core.pop(variables, 'params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
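
The state_params changes follow the same pattern in a training setup: the mutable collections are split from the parameters the optimizer updates via `flax.core.pop`. Below is a runnable sketch of that split together with a hypothetical `update_step(tx, apply_fn, x, opt_state, params, state)`, using a made-up `ToyModel` (Dense plus BatchNorm, so it carries a `batch_stats` collection) and a toy squared-output loss in place of the notebook's models and loss:

```python
import jax
from jax import numpy as jnp, random
import optax

import flax
from flax import linen as nn

class ToyModel(nn.Module):
    # Hypothetical stand-in with both learnable params and running statistics.
    hidden_size: int = 10

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Dense(self.hidden_size)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        return x.sum(axis=-1)

def update_step(tx, apply_fn, x, opt_state, params, state):
    # Toy loss: mean squared output. The mutated batch_stats are threaded
    # through as auxiliary (non-differentiated) output via has_aux.
    def loss_fn(params):
        y, updated_state = apply_fn({'params': params, **state},
                                    x, mutable=['batch_stats'])
        return jnp.mean(y ** 2), updated_state

    (loss, updated_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return opt_state, params, updated_state

dummy_input = jnp.ones((8, 4))
model = ToyModel()
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by the optimizer).
state, params = flax.core.pop(variables, 'params')
del variables  # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

for _ in range(3):
    opt_state, params, state = update_step(
        tx, model.apply, dummy_input, opt_state, params, state)
```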
