From 2857f30fd68f6aa8b718f7a669ce5e3aae03ab65 Mon Sep 17 00:00:00 2001 From: ivyzheng Date: Mon, 24 Jul 2023 19:06:21 -0700 Subject: [PATCH] Fix checkpointing guide --- docs/guides/use_checkpointing.ipynb | 120 +++++++++++----------------- docs/guides/use_checkpointing.md | 12 +-- 2 files changed, 49 insertions(+), 83 deletions(-) diff --git a/docs/guides/use_checkpointing.ipynb b/docs/guides/use_checkpointing.ipynb index d2f12ebdec..1811f016f5 100644 --- a/docs/guides/use_checkpointing.ipynb +++ b/docs/guides/use_checkpointing.ipynb @@ -55,21 +55,6 @@ "Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation)." ] }, - { - "cell_type": "code", - "execution_count": 1, - "id": "e80f8743", - "metadata": { - "tags": [ - "skip-execution" - ] - }, - "outputs": [], - "source": [ - "# replace with `pip install -U flax` after release 0.6.9.\n", - "! pip install -U -qq \"git+https://github.com/google/flax.git@main#egg=flax\"" - ] - }, { "attachments": {}, "cell_type": "markdown", @@ -101,15 +86,7 @@ "metadata": { "id": "SJT9DTxTytjn" }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "WARNING:absl:Tensorflow library not found, tensorflow.io.gfile operations will use native shim calls. GCS paths (i.e. 'gs://...') cannot be accessed.\n" - ] - } - ], + "outputs": [], "source": [ "from typing import Optional, Any\n", "import shutil\n", @@ -183,7 +160,7 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState())),\n", + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState())),\n", " 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': [Array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ], dtype=float32)]}" ] @@ -457,17 +434,17 @@ { "data": { "text/plain": [ - "{'model': {'step': 1,\n", + "{'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", + " 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", + " dtype=float32)},\n", + " 'model': {'opt_state': {'0': None, '1': None},\n", " 'params': {'bias': array([-0.001, -0.001, -0.001], dtype=float32),\n", " 'kernel': array([[ 0.26048955, -0.61399287, -0.23458514],\n", " [ 0.11050402, -0.8765793 , 0.9800635 ],\n", " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32)},\n", - " 'opt_state': {'0': {}, '1': {}}},\n", - " 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", - " 'data': {'0': array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", - " dtype=float32)}}" + " 'step': 1}}" ] }, "execution_count": 11, @@ -532,7 +509,7 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}" + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 12, @@ -591,7 +568,7 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}" + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 13, @@ -629,7 +606,7 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState())),\n", + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState())),\n", " 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)]}" @@ -767,7 +744,7 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}" + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}" ] }, "execution_count": 16, @@ -799,7 +776,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 20, "id": "a5d14c9f", "metadata": {}, "outputs": [ @@ -834,10 +811,10 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}" + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))}" ] }, - "execution_count": 17, + "execution_count": 20, "metadata": {}, "output_type": "execute_result" } @@ -850,10 +827,11 @@ " print('')\n", "\n", "# Step 4 is an original `TrainState`, without the `batch_stats`\n", - "restored = checkpoint_manager.restore(4, items=custom_target, \n", - " restore_kwargs={'transforms': {}})\n", + "custom_restore_args = orbax_utils.restore_args_from_target(custom_target)\n", + "restored = checkpoint_manager.restore(4, items=custom_target,\n", + " restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args})\n", "assert type(restored['model']) == CustomTrainState\n", - "np.testing.assert_equal(restored['model'].batch_stats, \n", + "np.testing.assert_equal(restored['model'].batch_stats,\n", " custom_target['model'].batch_stats)\n", "restored" ] @@ -870,7 +848,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 21, "id": "29fd1e33", "metadata": { "id": "29fd1e33", @@ -897,13 +875,13 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),\n", + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])),\n", " 'config': None,\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)]}" ] }, - "execution_count": 18, + "execution_count": 21, "metadata": {}, "output_type": "execute_result" } @@ -935,7 +913,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 22, "id": "051e7a16", "metadata": {}, "outputs": [ @@ -959,13 +937,13 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),\n", + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()), batch_stats=array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])),\n", " 'config': {'dimensions': array([5, 3]), 'name': 'dense'},\n", " 'data': [array([0.59902626, 0.2172144 , 2.4202902 , 0.03266738, 1.2164948 ],\n", " dtype=float32)]}" ] }, - "execution_count": 19, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -1007,7 +985,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 23, "id": "85be68a6", "metadata": { "id": "85be68a6", @@ -1037,10 +1015,10 @@ " [ 0.36260957, 0.18276349, -0.6856061 ],\n", " [-0.8519373 , -0.6416717 , -0.4818122 ],\n", " [-0.6886102 , -0.33987316, -0.05898903]], dtype=float32),\n", - " }), tx=GradientTransformation(init=.init_fn at 0x12ab6d1f0>, update=.update_fn at 0x12abb7ca0>), opt_state=(EmptyState(), EmptyState()))}" + " }), tx=GradientTransformation(init=.init_fn at 0x158e7cd30>, update=.update_fn at 0x158eaa670>), opt_state=(EmptyState(), EmptyState()))}" ] }, - "execution_count": 20, + "execution_count": 23, "metadata": {}, "output_type": "execute_result" } @@ -1077,7 +1055,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 24, "id": "af33b138", "metadata": {}, "outputs": [], @@ -1105,7 +1083,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 25, "id": "ubdUvyMrhD-1", "metadata": { "id": "ubdUvyMrhD-1" @@ -1128,7 +1106,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 26, "id": "a669bc05", "metadata": {}, "outputs": [], @@ -1153,24 +1131,20 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 27, "id": "b8e7daaa", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{'model': array([[ 0, 1],\n", - " [ 2, 3],\n", - " [ 4, 5],\n", - " [ 6, 7],\n", - " [ 8, 9],\n", - " [10, 11],\n", - " [12, 13],\n", - " [14, 15]], dtype=int32)}" + "{'model': Array([[0, 1],\n", + " [2, 3],\n", + " [4, 5],\n", + " [6, 7]], dtype=int32)}" ] }, - "execution_count": 24, + "execution_count": 27, "metadata": {}, "output_type": "execute_result" } @@ -1203,7 +1177,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 28, "id": "5d10039b", "metadata": { "id": "5d10039b", @@ -1216,7 +1190,7 @@ "'tmp/checkpoint_3'" ] }, - "execution_count": 25, + "execution_count": 28, "metadata": {}, "output_type": "execute_result" } @@ -1233,7 +1207,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 29, "id": "a9f9724c", "metadata": { "id": "a9f9724c", @@ -1243,17 +1217,13 @@ { "data": { "text/plain": [ - "{'model': array([[ 0, 1],\n", - " [ 2, 3],\n", - " [ 4, 5],\n", - " [ 6, 7],\n", - " [ 8, 9],\n", - " [10, 11],\n", - " [12, 13],\n", - " [14, 15]], dtype=int32)}" + "{'model': Array([[0, 1],\n", + " [2, 3],\n", + " [4, 5],\n", + " [6, 7]], dtype=int32)}" ] }, - "execution_count": 26, + "execution_count": 29, "metadata": {}, "output_type": "execute_result" } @@ -1336,7 +1306,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.15" + "version": "3.9.16" } }, "nbformat": 4, diff --git a/docs/guides/use_checkpointing.md b/docs/guides/use_checkpointing.md index fa19b204ff..5d170b1a10 100644 --- a/docs/guides/use_checkpointing.md +++ b/docs/guides/use_checkpointing.md @@ -53,11 +53,6 @@ If you need to learn more about `orbax.checkpoint`, refer to the [Orbax docs](ht Install/upgrade Flax and [Orbax](https://github.com/google/orbax). For JAX installation with GPU/TPU support, visit [this section on GitHub](https://github.com/google/jax#installation). -```python tags=["skip-execution"] -# replace with `pip install -U flax` after release 0.6.9. -! pip install -U -qq "git+https://github.com/google/flax.git@main#egg=flax" -``` - Note: Before running `import jax`, create eight fake devices to mimic a [multi-host environment](https://jax.readthedocs.io/en/latest/jax-101/06-parallelism.html?#aside-hosts-and-devices-in-jax) in this notebook. Note that the order of imports is important here. The `os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=8'` command works only with the CPU backend, which means it won't work with GPU/TPU acceleration on if you're running this notebook in Google Colab. If you are already running the code on multiple devices (for example, in a 4x2 TPU environment), you can skip running the next cell. @@ -318,10 +313,11 @@ except KeyError as e: print('') # Step 4 is an original `TrainState`, without the `batch_stats` -restored = checkpoint_manager.restore(4, items=custom_target, - restore_kwargs={'transforms': {}}) +custom_restore_args = orbax_utils.restore_args_from_target(custom_target) +restored = checkpoint_manager.restore(4, items=custom_target, + restore_kwargs={'transforms': {}, 'restore_args': custom_restore_args}) assert type(restored['model']) == CustomTrainState -np.testing.assert_equal(restored['model'].batch_stats, +np.testing.assert_equal(restored['model'].batch_stats, custom_target['model'].batch_stats) restored ```