make flax_basics guide use utility fns #3214

Merged: 1 commit, Jul 19, 2023
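The change itself is mechanical: the guides drop `from flax.core import freeze, unfreeze` in favour of a plain `import flax`, and then call the helpers as `flax.core.freeze`, `flax.core.unfreeze`, and `flax.core.pop`. In particular, the `FrozenDict` method call `variables.pop('params')` becomes the utility call `flax.core.pop(variables, 'params')`. Below is a minimal sketch of the pattern the guides switch to; the tiny `nn.Dense` model is only illustrative and not taken from the diff.

```python
import jax
import flax
from jax import numpy as jnp, random
from flax import linen as nn

model = nn.Dense(features=3)
variables = model.init(random.PRNGKey(0), jnp.ones((1, 4)))

# Old style (a FrozenDict method):
#   state, params = variables.pop('params')
# New style (utility fn reached through flax.core):
state, params = flax.core.pop(variables, 'params')

# freeze/unfreeze are likewise called as flax.core.* instead of being
# imported individually.
frozen = flax.core.freeze({'params': params})
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(frozen)))
```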
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.ipynb
@@ -69,7 +69,7 @@
"import jax\n",
"from typing import Any, Callable, Sequence\n",
"from jax import lax, random, numpy as jnp\n",
-"from flax.core import freeze, unfreeze\n",
+"import flax\n",
"from flax import linen as nn"
]
},
@@ -246,7 +246,7 @@
"W = random.normal(k1, (x_dim, y_dim))\n",
"b = random.normal(k2, (y_dim,))\n",
"# Store the parameters in a FrozenDict pytree.\n",
-"true_params = freeze({'params': {'bias': b, 'kernel': W}})\n",
+"true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n",
"\n",
"# Generate samples with additional noise.\n",
"key_sample, key_noise = random.split(k1)\n",
@@ -604,7 +604,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -706,7 +706,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -929,8 +929,8 @@
"for val in [1.0, 2.0, 3.0]:\n",
" x = val * jnp.ones((10,5))\n",
" y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
" old_state, params = variables.pop('params')\n",
" variables = freeze({'params': params, **updated_state})\n",
" old_state, params = flax.core.pop(variables, 'params')\n",
" variables = flax.core.freeze({'params': params, **updated_state})\n",
" print('updated state:\\n', updated_state) # Shows only the mutable part"
]
},
@@ -994,7 +994,7 @@
"\n",
"x = jnp.ones((10,5))\n",
"variables = model.init(random.PRNGKey(0), x)\n",
-"state, params = variables.pop('params')\n",
+"state, params = flax.core.pop(variables, 'params')\n",
"del variables\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
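The `@@ -929,8 +929,8 @@` hunk above is the BatchNorm training loop from flax_basics, which now splits and reassembles `variables` with the utility functions. A self-contained sketch of that loop follows, using a stand-in module that only wraps `nn.BatchNorm` (the guide's actual model has more going on):

```python
import flax
from jax import numpy as jnp, random
from flax import linen as nn

class Norm(nn.Module):
    @nn.compact
    def __call__(self, x):
        # use_running_average=False makes apply() update 'batch_stats'.
        return nn.BatchNorm(use_running_average=False)(x)

model = Norm()
variables = model.init(random.PRNGKey(0), jnp.ones((10, 5)))

for val in [1.0, 2.0, 3.0]:
    x = val * jnp.ones((10, 5))
    y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
    # Split off the (unchanged) params, then rebuild the full variables
    # around the freshly mutated batch_stats.
    old_state, params = flax.core.pop(variables, 'params')
    variables = flax.core.freeze({'params': params, **updated_state})
```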
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.md
@@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook.
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
-from flax.core import freeze, unfreeze
+import flax
from flax import linen as nn
```

@@ -132,7 +132,7 @@ k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
-true_params = freeze({'params': {'bias': b, 'kernel': W}})
+true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})

# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
@@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```

@@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```

@@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10,5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
old_state, params = variables.pop('params')
variables = freeze({'params': params, **updated_state})
old_state, params = flax.core.pop(variables, 'params')
variables = flax.core.freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # Shows only the mutable part
```

@@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state):

x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
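The `@@ -509,7 +509,7 @@` hunk in flax_basics.md pops `params` out of `variables` so that only the parameters reach the optimizer. A compact end-to-end sketch of that flow is below; it assumes `optax` is installed, and the `nn.Dense` model, squared-error loss, and constant data are stand-ins rather than the guide's own:

```python
import jax
import flax
import optax
from jax import numpy as jnp, random
from flax import linen as nn

model = nn.Dense(features=1)
x = jnp.ones((10, 5))
y_target = jnp.zeros((10, 1))

variables = model.init(random.PRNGKey(0), x)
state, params = flax.core.pop(variables, 'params')  # optimizer sees only params
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)

@jax.jit
def update_step(params, opt_state):
    def loss_fn(p):
        # Rebuild the full variable dict around the current params.
        pred = model.apply({'params': p, **state}, x)
        return jnp.mean((pred - y_target) ** 2)
    loss, grads = jax.value_and_grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss

params, opt_state, loss = update_step(params, opt_state)
```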
20 changes: 10 additions & 10 deletions docs/notebooks/linen_intro.ipynb
@@ -75,7 +75,7 @@
"from typing import Any, Callable, Sequence, Optional\n",
"import jax\n",
"from jax import lax, random, numpy as jnp\n",
-"from flax.core import freeze, unfreeze\n",
+"import flax\n",
"from flax import linen as nn"
]
},
@@ -306,7 +306,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -370,7 +370,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -743,7 +743,7 @@
"model = Block(features=3, training=True)\n",
"\n",
"init_variables = model.init({'params': key2, 'dropout': key3}, x)\n",
-"_, init_params = init_variables.pop('params')\n",
+"_, init_params = flax.core.pop(init_variables, 'params')\n",
"\n",
"# When calling `apply` with mutable kinds, returns a pair of output,\n",
"# mutated_variables.\n",
@@ -752,8 +752,8 @@
"\n",
"# Now we reassemble the full variables from the updates (in a real training\n",
"# loop, with the updated params from an optimizer).\n",
"updated_variables = freeze(dict(params=init_params,\n",
" **mutated_variables))\n",
"updated_variables = flax.core.freeze(dict(params=init_params,\n",
" **mutated_variables))\n",
"\n",
"print('updated variables:\\n', updated_variables)\n",
"print('initialized variable shapes:\\n',\n",
@@ -842,7 +842,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -913,7 +913,7 @@
"init_variables = model.init(key2, x)\n",
"y = model.apply(init_variables, x)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"print('output:\\n', y)"
]
},
@@ -1069,7 +1069,7 @@
" batch_axes=(0,))\n",
"\n",
"init_variables = model(train=False).init({'params': key2}, x, x)\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"\n",
"y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})\n",
"print('output:\\n', y.shape)"
@@ -1149,7 +1149,7 @@
"model = SimpleScan(2)\n",
"init_variables = model.init(key2, xs)\n",
"\n",
-"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))\n",
+"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))\n",
"\n",
"y = model.apply(init_variables, xs)\n",
"print('output:\\n', y)"
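The `@@ -743,7 +743,7 @@` and `@@ -752,8 +752,8 @@` hunks above are the part of linen_intro that threads a second rng stream through `init` and then rebuilds the full variable dict from the mutated collections. A sketch of that pattern with the new calls, using a stand-in `Block` built from the same ingredients (Dense, BatchNorm, Dropout) rather than the guide's exact module:

```python
import flax
from jax import numpy as jnp, random
from flax import linen as nn

class Block(nn.Module):
    features: int
    training: bool

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.features)(x)
        x = nn.BatchNorm(use_running_average=not self.training)(x)
        x = nn.Dropout(rate=0.5, deterministic=not self.training)(x)
        return x

key1, key2, key3, key4 = random.split(random.PRNGKey(0), 4)
x = random.uniform(key1, (3, 4, 4))
model = Block(features=3, training=True)

# Dropout draws from its own rng stream, so init takes a dict of keys.
init_variables = model.init({'params': key2, 'dropout': key3}, x)
_, init_params = flax.core.pop(init_variables, 'params')

y, mutated_variables = model.apply(
    init_variables, x, rngs={'dropout': key4}, mutable=['batch_stats'])

# Reassemble the full variables from the updated collections.
updated_variables = flax.core.freeze(dict(params=init_params,
                                          **mutated_variables))
```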
20 changes: 10 additions & 10 deletions docs/notebooks/linen_intro.md
@@ -53,7 +53,7 @@ import functools
from typing import Any, Callable, Sequence, Optional
import jax
from jax import lax, random, numpy as jnp
-from flax.core import freeze, unfreeze
+import flax
from flax import linen as nn
```

@@ -154,7 +154,7 @@ model = ExplicitMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -189,7 +189,7 @@ model = SimpleMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -357,7 +357,7 @@ x = random.uniform(key1, (3,4,4))
model = Block(features=3, training=True)

init_variables = model.init({'params': key2, 'dropout': key3}, x)
-_, init_params = init_variables.pop('params')
+_, init_params = flax.core.pop(init_variables, 'params')

# When calling `apply` with mutable kinds, returns a pair of output,
# mutated_variables.
@@ -366,8 +366,8 @@ y, mutated_variables = model.apply(

# Now we reassemble the full variables from the updates (in a real training
# loop, with the updated params from an optimizer).
updated_variables = freeze(dict(params=init_params,
**mutated_variables))
updated_variables = flax.core.freeze(dict(params=init_params,
**mutated_variables))

print('updated variables:\n', updated_variables)
print('initialized variable shapes:\n',
@@ -419,7 +419,7 @@ model = MLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -459,7 +459,7 @@ model = RematMLP(features=[3,4,5])
init_variables = model.init(key2, x)
y = model.apply(init_variables, x)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
print('output:\n', y)
```

@@ -587,7 +587,7 @@ model = functools.partial(
batch_axes=(0,))

init_variables = model(train=False).init({'params': key2}, x, x)
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))

y = model(train=True).apply(init_variables, x, x, rngs={'dropout': key4})
print('output:\n', y.shape)
@@ -635,7 +635,7 @@ xs = random.uniform(key1, (1, 5, 2))
model = SimpleScan(2)
init_variables = model.init(key2, xs)

-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(init_variables)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))

y = model.apply(init_variables, xs)
print('output:\n', y)
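Most of the remaining hunks only touch the shape-printing line, so for reference this is the idiom in isolation (with `nn.Dense` standing in for whatever model the surrounding cell builds). `unfreeze` is not strictly required here, since `FrozenDict` is itself a pytree that `tree_map` can traverse, but unfreezing makes the printed result a plain nested dict:

```python
import jax
import flax
from jax import numpy as jnp, random
from flax import linen as nn

init_variables = nn.Dense(features=3).init(random.PRNGKey(0), jnp.ones((1, 4)))

# Map jnp.shape over every leaf of the variable tree.
print('initialized parameter shapes:\n',
      jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(init_variables)))
```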
5 changes: 3 additions & 2 deletions docs/notebooks/state_params.ipynb
@@ -55,6 +55,7 @@
"from jax import random\n",
"import optax\n",
"\n",
+"import flax\n",
"from flax import linen as nn\n",
"\n",
"\n",
@@ -152,7 +153,7 @@
"model = BiasAdderWithRunningMean()\n",
"variables = model.init(random.PRNGKey(0), dummy_input)\n",
"# Split state and params (which are updated by optimizer).\n",
-"state, params = variables.pop('params')\n",
+"state, params = flax.core.pop(variables, 'params')\n",
"del variables # Delete variables to avoid wasting resources\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
@@ -271,7 +272,7 @@
"model = MLP(hidden_size=10, out_size=1)\n",
"variables = model.init(random.PRNGKey(0), dummy_input)\n",
"# Split state and params (which are updated by optimizer).\n",
-"state, params = variables.pop('params')\n",
+"state, params = flax.core.pop(variables, 'params')\n",
"del variables # Delete variables to avoid wasting resources\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
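state_params only needs the added `import flax` plus the `flax.core.pop` swap. One practical reason to prefer the utility over the method (my reading of the change; the diff itself doesn't say): `flax.core.pop` keeps the `(rest, value)` return and leaves its input untouched even when `variables` is a plain Python dict, whereas `dict.pop` mutates in place and returns only the value. A tiny illustration with made-up values:

```python
import flax

# Illustrative values only, not taken from the guide.
variables = {'params': {'w': 1.0}, 'batch_stats': {'mean': 0.0}}

# Built-in dict.pop would mutate `variables` and return just the value:
#   params = variables.pop('params')
# flax.core.pop mirrors FrozenDict.pop instead, returning (rest, value):
state, params = flax.core.pop(variables, 'params')
print(state)   # {'batch_stats': {'mean': 0.0}}
print(params)  # {'w': 1.0}
```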
5 changes: 3 additions & 2 deletions docs/notebooks/state_params.md
@@ -41,6 +41,7 @@ from jax import numpy as jnp
from jax import random
import optax

+import flax
from flax import linen as nn


@@ -113,7 +114,7 @@ Then we can write the actual training code.
model = BiasAdderWithRunningMean()
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
@@ -202,7 +203,7 @@ Note that we also need to specify that the model state does not have a batch dim
model = MLP(hidden_size=10, out_size=1)
variables = model.init(random.PRNGKey(0), dummy_input)
# Split state and params (which are updated by optimizer).
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
del variables # Delete variables to avoid wasting resources
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)