make flax_basics guide use utility fns
chiamp committed Jul 19, 2023
1 parent 290e50f commit 08cb7ef
Showing 2 changed files with 14 additions and 14 deletions.
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.ipynb
@@ -69,7 +69,7 @@
"import jax\n",
"from typing import Any, Callable, Sequence\n",
"from jax import lax, random, numpy as jnp\n",
"from flax.core import freeze, unfreeze\n",
"import flax\n",
"from flax import linen as nn"
]
},
@@ -246,7 +246,7 @@
"W = random.normal(k1, (x_dim, y_dim))\n",
"b = random.normal(k2, (y_dim,))\n",
"# Store the parameters in a FrozenDict pytree.\n",
"true_params = freeze({'params': {'bias': b, 'kernel': W}})\n",
"true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})\n",
"\n",
"# Generate samples with additional noise.\n",
"key_sample, key_noise = random.split(k1)\n",
@@ -604,7 +604,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -706,7 +706,7 @@
"params = model.init(key2, x)\n",
"y = model.apply(params, x)\n",
"\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))\n",
"print('initialized parameter shapes:\\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))\n",
"print('output:\\n', y)"
]
},
@@ -929,8 +929,8 @@
"for val in [1.0, 2.0, 3.0]:\n",
" x = val * jnp.ones((10,5))\n",
" y, updated_state = model.apply(variables, x, mutable=['batch_stats'])\n",
" old_state, params = variables.pop('params')\n",
" variables = freeze({'params': params, **updated_state})\n",
" old_state, params = flax.core.pop(variables, 'params')\n",
" variables = flax.core.freeze({'params': params, **updated_state})\n",
" print('updated state:\\n', updated_state) # Shows only the mutable part"
]
},
@@ -994,7 +994,7 @@
"\n",
"x = jnp.ones((10,5))\n",
"variables = model.init(random.PRNGKey(0), x)\n",
"state, params = variables.pop('params')\n",
"state, params = flax.core.pop(variables, 'params')\n",
"del variables\n",
"tx = optax.sgd(learning_rate=0.02)\n",
"opt_state = tx.init(params)\n",
14 changes: 7 additions & 7 deletions docs/guides/flax_basics.md
@@ -46,7 +46,7 @@ Here we provide the code needed to set up the environment for our notebook.
import jax
from typing import Any, Callable, Sequence
from jax import lax, random, numpy as jnp
-from flax.core import freeze, unfreeze
+import flax
from flax import linen as nn
```

@@ -132,7 +132,7 @@ k1, k2 = random.split(key)
W = random.normal(k1, (x_dim, y_dim))
b = random.normal(k2, (y_dim,))
# Store the parameters in a FrozenDict pytree.
-true_params = freeze({'params': {'bias': b, 'kernel': W}})
+true_params = flax.core.freeze({'params': {'bias': b, 'kernel': W}})
# Generate samples with additional noise.
key_sample, key_noise = random.split(k1)
@@ -306,7 +306,7 @@ model = ExplicitMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```

@@ -362,7 +362,7 @@ model = SimpleMLP(features=[3,4,5])
params = model.init(key2, x)
y = model.apply(params, x)
-print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, unfreeze(params)))
+print('initialized parameter shapes:\n', jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(params)))
print('output:\n', y)
```

@@ -476,8 +476,8 @@ Here, `updated_state` returns only the state variables that are being mutated by
for val in [1.0, 2.0, 3.0]:
x = val * jnp.ones((10,5))
y, updated_state = model.apply(variables, x, mutable=['batch_stats'])
-  old_state, params = variables.pop('params')
-  variables = freeze({'params': params, **updated_state})
+  old_state, params = flax.core.pop(variables, 'params')
+  variables = flax.core.freeze({'params': params, **updated_state})
print('updated state:\n', updated_state) # Shows only the mutable part
```

@@ -509,7 +509,7 @@ def update_step(tx, apply_fn, x, opt_state, params, state):
x = jnp.ones((10,5))
variables = model.init(random.PRNGKey(0), x)
-state, params = variables.pop('params')
+state, params = flax.core.pop(variables, 'params')
del variables
tx = optax.sgd(learning_rate=0.02)
opt_state = tx.init(params)
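
Both files make the same substitution: the guide now calls `flax.core.pop`, `flax.core.freeze`, and `flax.core.unfreeze` directly instead of importing `freeze`/`unfreeze` and using the `FrozenDict.pop` method. Below is a minimal sketch of the resulting pattern; the `Block` module and input shapes are illustrative only, assuming a Flax release that exposes these `flax.core` utilities (as this commit relies on):

```python
import jax
import jax.numpy as jnp
import flax
from flax import linen as nn

class Block(nn.Module):
    """Toy module with both parameters and batch statistics."""
    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Dense(features=3)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        return x

model = Block()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((4, 5)))

# flax.core.pop returns (remaining collections, popped collection),
# replacing the previous variables.pop('params') method call.
state, params = flax.core.pop(variables, 'params')

# flax.core.freeze / flax.core.unfreeze convert between plain dicts and
# FrozenDicts, replacing the previously imported freeze/unfreeze helpers.
rebuilt = flax.core.freeze({'params': params, **flax.core.unfreeze(state)})
print(jax.tree_util.tree_map(jnp.shape, flax.core.unfreeze(rebuilt)))
```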
