Skip to content

Commit

Permalink
Merge pull request #3223 from IvyZX:orbax
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 550742979
  • Loading branch information
Flax Authors committed Jul 25, 2023
2 parents 9ba7cc0 + 2857f30 commit 2497f82
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 83 deletions.
120 changes: 45 additions & 75 deletions docs/guides/use_checkpointing.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -55,21 +55,6 @@
"Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e80f8743",
"metadata": {
"tags": [
"skip-execution"
]
},
"outputs": [],
"source": [
"# replace with `pip install -U flax` after release 0.6.9.\n",
"! pip install -U -qq \"git+https://github.com/google/flax.git@main#egg=flax\""
]
},
{
"attachments": {},
"cell_type": "markdown",
Expand Down Expand Up @@ -101,15 +86,7 @@
"metadata": {
"id": "SJT9DTxTytjn"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n"
]
}
],
"outputs": [],
"source": [
"from typing import Optional, Any\n",
"import shutil\n",
Expand Down Expand Up @@ -183,7 +160,7 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState())),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState())),\n",
" 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n",
" 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ], dtype=float32)]}"
]
Expand Down Expand Up @@ -457,17 +434,17 @@
{
"data": {
"text/plain": [
"{'model': {'step': 1,\n",
"{'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n",
" 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n",
" dtype=float32)},\n",
" 'model': {'opt_state': {'0': None, '1': None},\n",
" 'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),\n",
" 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n",
" [ 0.11050402, -0.8765793 , 0.9800635 ],\n",
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},\n",
" 'opt_state': {'0': {}, '1': {}}},\n",
" 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n",
" 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n",
" dtype=float32)}}"
" 'step': 1}}"
]
},
"execution_count": 11,
Expand Down Expand Up @@ -532,7 +509,7 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}"
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}"
]
},
"execution_count": 12,
Expand Down Expand Up @@ -591,7 +568,7 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}"
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}"
]
},
"execution_count": 13,
Expand Down Expand Up @@ -629,7 +606,7 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState())),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState())),\n",
" 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n",
" 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n",
" dtype=float32)]}"
Expand Down Expand Up @@ -767,7 +744,7 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}"
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}"
]
},
"execution_count": 16,
Expand Down Expand Up @@ -799,7 +776,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 20,
"id": "a5d14c9f",
"metadata": {},
"outputs": [
Expand Down Expand Up @@ -834,10 +811,10 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}"
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}"
]
},
"execution_count": 17,
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -850,10 +827,11 @@
" print('')\n",
"\n",
"# Step 4 is an original `TrainState`, without the `batch_stats`\n",
"restored = checkpoint_manager.restore(4, items=custom_target, \n",
" restore_kwargs={'transforms': {}})\n",
"custom_restore_args = orbax_utils.restore_args_from_target(custom_target)\n",
"restored = checkpoint_manager.restore(4, items=custom_target,\n",
" restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args})\n",
"assert type(restored['model']) == CustomTrainState\n",
"np.testing.assert_equal(restored['model'].batch_stats, \n",
"np.testing.assert_equal(restored['model'].batch_stats,\n",
" custom_target['model'].batch_stats)\n",
"restored"
]
Expand All @@ -870,7 +848,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 21,
"id": "29fd1e33",
"metadata": {
"id": "29fd1e33",
Expand All @@ -897,13 +875,13 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),\n",
" 'config': None,\n",
" 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n",
" dtype=float32)]}"
]
},
"execution_count": 18,
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -935,7 +913,7 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 22,
"id": "051e7a16",
"metadata": {},
"outputs": [
Expand All @@ -959,13 +937,13 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),\n",
" 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n",
" 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n",
" dtype=float32)]}"
]
},
"execution_count": 19,
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1007,7 +985,7 @@
},
{
"cell_type": "code",
"execution_count": 20,
"execution_count": 23,
"id": "85be68a6",
"metadata": {
"id": "85be68a6",
Expand Down Expand Up @@ -1037,10 +1015,10 @@
" [ 0.36260957, 0.18276349, -0.6856061 ],\n",
" [-0.8519373 , -0.6416717 , -0.4818122 ],\n",
" [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n",
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x12ab6d1f0>, update=<function chain.<locals>.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}"
" }), tx=GradientTransformation(init=<function chain.<locals>.init_fn at 0x158e7cd30>, update=<function chain.<locals>.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}"
]
},
"execution_count": 20,
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1077,7 +1055,7 @@
},
{
"cell_type": "code",
"execution_count": 21,
"execution_count": 24,
"id": "af33b138",
"metadata": {},
"outputs": [],
Expand Down Expand Up @@ -1105,7 +1083,7 @@
},
{
"cell_type": "code",
"execution_count": 22,
"execution_count": 25,
"id": "ubdUvyMrhD-1",
"metadata": {
"id": "ubdUvyMrhD-1"
Expand All @@ -1128,7 +1106,7 @@
},
{
"cell_type": "code",
"execution_count": 23,
"execution_count": 26,
"id": "a669bc05",
"metadata": {},
"outputs": [],
Expand All @@ -1153,24 +1131,20 @@
},
{
"cell_type": "code",
"execution_count": 24,
"execution_count": 27,
"id": "b8e7daaa",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'model': array([[ 0, 1],\n",
" [ 2, 3],\n",
" [ 4, 5],\n",
" [ 6, 7],\n",
" [ 8, 9],\n",
" [10, 11],\n",
" [12, 13],\n",
" [14, 15]], dtype=int32)}"
"{'model': Array([[0, 1],\n",
" [2, 3],\n",
" [4, 5],\n",
" [6, 7]], dtype=int32)}"
]
},
"execution_count": 24,
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1203,7 +1177,7 @@
},
{
"cell_type": "code",
"execution_count": 25,
"execution_count": 28,
"id": "5d10039b",
"metadata": {
"id": "5d10039b",
Expand All @@ -1216,7 +1190,7 @@
"'tmp/checkpoint_3'"
]
},
"execution_count": 25,
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -1233,7 +1207,7 @@
},
{
"cell_type": "code",
"execution_count": 26,
"execution_count": 29,
"id": "a9f9724c",
"metadata": {
"id": "a9f9724c",
Expand All @@ -1243,17 +1217,13 @@
{
"data": {
"text/plain": [
"{'model': array([[ 0, 1],\n",
" [ 2, 3],\n",
" [ 4, 5],\n",
" [ 6, 7],\n",
" [ 8, 9],\n",
" [10, 11],\n",
" [12, 13],\n",
" [14, 15]], dtype=int32)}"
"{'model': Array([[0, 1],\n",
" [2, 3],\n",
" [4, 5],\n",
" [6, 7]], dtype=int32)}"
]
},
"execution_count": 26,
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
Expand Down Expand Up @@ -1336,7 +1306,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.15"
"version": "3.9.16"
}
},
"nbformat": 4,
Expand Down
12 changes: 4 additions & 8 deletions docs/guides/use_checkpointing.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,6 @@ If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](ht
Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation).
<!-- #endregion -->

```python tags=["skip-execution"]
# replace with `pip install -U flax` after release 0.6.9.
! pip install -U -qq "git+https://github.com/google/flax.git@main#egg=flax"
```

<!-- #region id="-icO30rwmKYj" -->
Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell.
<!-- #endregion -->
Expand Down Expand Up @@ -318,10 +313,11 @@ except KeyError as e:
print('')

# Step 4 is an original `TrainState`, without the `batch_stats`
restored = checkpoint_manager.restore(4, items=custom_target,
restore_kwargs={'transforms': {}})
custom_restore_args = orbax_utils.restore_args_from_target(custom_target)
restored = checkpoint_manager.restore(4, items=custom_target,
restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args})
assert type(restored['model']) == CustomTrainState
np.testing.assert_equal(restored['model'].batch_stats,
np.testing.assert_equal(restored['model'].batch_stats,
custom_target['model'].batch_stats)
restored
```
Expand Down

0 comments on commit 2497f82

Please sign in to comment.