Commit
Merge pull request #3215 from chiamp:utility_fn
PiperOrigin-RevId: 549457987
Flax Authors committed Jul 19, 2023
2 parents ac6cde0 + a0dbe39 commit dfb4fc3
Showing 3 changed files with 11 additions and 11 deletions.
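All three files apply the same change: instead of importing `freeze` from `flax.core.frozen_dict` and calling the `.unfreeze()` method on a `FrozenDict`, the docs now use the module-level `flax.core.freeze` and `flax.core.unfreeze` utility functions, which accept both frozen and regular dicts. A minimal sketch of the two styles (the toy `nn.Dense(3)` model is illustrative and not part of the diff):

```python
import jax
import jax.numpy as jnp
import flax
import flax.linen as nn

# Toy model, purely for illustration.
variables = nn.Dense(3).init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
params = variables['params']

# Old style (what the docs used before this commit):
# from flax.core.frozen_dict import freeze
# params = params.unfreeze()   # method only exists on FrozenDict
# params = freeze(params)

# New style (what this commit switches the docs to):
params = flax.core.unfreeze(params)  # works on FrozenDict and plain dict
params = flax.core.freeze(params)    # returns a FrozenDict
```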
2 changes: 1 addition & 1 deletion docs/guides/regular_dict_upgrade_guide.rst
@@ -32,7 +32,7 @@ The following are the utility functions and example upgrade patterns:
import jax.numpy as jnp

x = jnp.empty((1,3))
-variables = nn.Dense(5).init(jax.random.PRNGKey(0), x)
+variables = flax.core.freeze(nn.Dense(5).init(jax.random.PRNGKey(0), x))

other_variables = jnp.array([1, 1, 1, 1, 1], dtype=jnp.float32)

10 changes: 5 additions & 5 deletions docs/guides/transfer_learning.ipynb
@@ -174,11 +174,11 @@
"metadata": {},
"outputs": [],
"source": [
"from flax.core.frozen_dict import freeze\n",
"import flax\n",
"\n",
"params = params.unfreeze()\n",
"params = flax.core.unfreeze(params)\n",
"params['backbone'] = vision_model_vars['params']\n",
"params = freeze(params)"
"params = flax.core.freeze(params)"
]
},
{
@@ -247,13 +247,13 @@
"import optax\n",
"\n",
"partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}\n",
"param_partitions = freeze(traverse_util.path_aware_map(\n",
"param_partitions = flax.core.freeze(traverse_util.path_aware_map(\n",
" lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params))\n",
"tx = optax.multi_transform(partition_optimizers, param_partitions)\n",
"\n",
"# visualize a subset of the param_partitions structure\n",
"flat = list(traverse_util.flatten_dict(param_partitions).items())\n",
"freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))"
"flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))"
]
},
{
10 changes: 5 additions & 5 deletions docs/guides/transfer_learning.md
@@ -114,11 +114,11 @@ params = variables['params']
Since `params` are currently random, the pretrained parameters from `vision_model_vars` have to be transferred to the `params` structure at the appropriate location. This can be done by unfreezing `params`, updating the `backbone` parameters, and freezing `params` again:

```{code-cell} ipython3
-from flax.core.frozen_dict import freeze
+import flax
-params = params.unfreeze()
+params = flax.core.unfreeze(params)
params['backbone'] = vision_model_vars['params']
-params = freeze(params)
+params = flax.core.freeze(params)
```

**Note:** if the model contains other variable collections such as `batch_stats`, these have to be transferred as well.
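As an illustration of that note, a hedged sketch of transferring a `batch_stats` collection the same way, assuming the combined model actually has one (e.g. it uses BatchNorm) and reusing the `variables` and `vision_model_vars` names from the guide; this code is not part of the diff:

```python
# Illustrative only: transfer the pretrained batch_stats into the backbone
# slot, mirroring the params transfer above.
if 'batch_stats' in vision_model_vars:
    batch_stats = flax.core.unfreeze(variables['batch_stats'])
    batch_stats['backbone'] = vision_model_vars['batch_stats']
    batch_stats = flax.core.freeze(batch_stats)
```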
@@ -153,13 +153,13 @@ from flax import traverse_util
import optax
partition_optimizers = {'trainable': optax.adam(5e-3), 'frozen': optax.set_to_zero()}
-param_partitions = freeze(traverse_util.path_aware_map(
+param_partitions = flax.core.freeze(traverse_util.path_aware_map(
lambda path, v: 'frozen' if 'backbone' in path else 'trainable', params))
tx = optax.multi_transform(partition_optimizers, param_partitions)
# visualize a subset of the param_partitions structure
flat = list(traverse_util.flatten_dict(param_partitions).items())
-freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))
+flax.core.freeze(traverse_util.unflatten_dict(dict(flat[:2] + flat[-2:])))
```

To implement [differential learning rates](https://blog.slavv.com/differential-learning-rates-59eff5209a4f), `optax.set_to_zero` can be replaced with any other optimizer; different optimizers and partitioning schemes can be selected depending on the task. For more information on advanced optimizers, refer to Optax's [Combining Optimizers](https://optax.readthedocs.io/en/latest/api.html#combining-optimizers) documentation.
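For instance, a hedged sketch that fine-tunes the backbone with a small learning rate instead of zeroing its updates; the `1e-5` value and the reuse of the `'frozen'` label (so the existing `param_partitions` mapping still applies) are illustrative choices, not from the guide:

```python
# Illustrative only: replace optax.set_to_zero() with a small-learning-rate
# optimizer so the backbone is fine-tuned rather than fully frozen.
partition_optimizers = {
    'trainable': optax.adam(5e-3),  # head / newly initialized parameters
    'frozen': optax.adam(1e-5),     # pretrained backbone, small updates
}
tx = optax.multi_transform(partition_optimizers, param_partitions)
```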
