Minor tweaks to JAX Gemma docs (google#290)
* Minor tweaks to JAX Gemma docs

* Run nbfmt

---------

Co-authored-by: Mark McDonald <macd@google.com>
8bitmp3 and markmcd authored Mar 15, 2024
1 parent c908a50 commit 4b2cbf8
Showing 2 changed files with 13 additions and 13 deletions.
site/en/gemma/docs/jax_finetune.ipynb (14 changes: 7 additions & 7 deletions)
@@ -102,7 +102,7 @@
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
"\n",
"### Set environment variables\n",
"### 2. Set environment variables\n",
"\n",
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
]
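For reference, the cell under this heading reads the Kaggle secrets out of Colab's Secrets manager and exports them as environment variables, roughly like this sketch (assuming the secrets are stored under the same names):

```python
import os
from google.colab import userdata

# Read the Kaggle credentials from Colab's Secrets manager and expose
# them as environment variables so the Kaggle download can authenticate.
os.environ["KAGGLE_USERNAME"] = userdata.get("KAGGLE_USERNAME")
os.environ["KAGGLE_KEY"] = userdata.get("KAGGLE_KEY")
```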
@@ -128,7 +128,7 @@
"id": "m1UE1CEnE9ql"
},
"source": [
"### 2. Install the `gemma` library\n",
"### 3. Install the `gemma` library\n",
"\n",
"Free Colab hardware acceleration is currently *insufficient* to run this notebook. If you are using [Colab Pay As You Go or Colab Pro](https://colab.research.google.com/signup), click on **Edit** > **Notebook settings** > Select **A100 GPU** > **Save** to enable hardware acceleration.\n",
"\n",
@@ -170,7 +170,7 @@
"id": "-mRkkT-iPYoq"
},
"source": [
"### 3. Import libraries\n",
"### 4. Import libraries\n",
"\n",
"This notebook uses [Flax](https://flax.readthedocs.io) (for neural networks), core [JAX](https://jax.readthedocs.io), [SentencePiece](https://github.com/google/sentencepiece) (for tokenization), [Chex](https://chex.readthedocs.io/en/latest/) (a library of utilities for writing reliable JAX code), and TensorFlow Datasets."
]
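As a rough sketch, the import cell behind this section pulls in those libraries along the following lines (the aliases are assumptions, not necessarily the notebook's exact ones):

```python
import chex                              # assertions and utilities for reliable JAX code
import jax
import jax.numpy as jnp
import sentencepiece as spm              # tokenizer used by Gemma
import tensorflow_datasets as tfds       # training data pipeline
from flax import linen as nn             # neural-network layers
from gemma import params as params_lib   # checkpoint loading helpers
from gemma import transformer as transformer_lib
```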
@@ -912,9 +912,9 @@
"source": [
"## Configure the model\n",
"\n",
"Before you begin fine-tuning the Gemma model, configure it as follows:\n",
"Before you begin fine-tuning the Gemma model, you need to configure it.\n",
"\n",
"Load and format the Gemma model checkpoint with the [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py) method:"
"First, load and format the Gemma model checkpoint with the [`gemma.params.load_and_format_params`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/params.py#L27) method:"
]
},
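The loading step this hunk documents amounts to a single call; in sketch form (with `CKPT_PATH` standing in for the downloaded checkpoint directory):

```python
# Load the raw checkpoint and reshape it into the nested parameter
# tree that the Flax Transformer module expects.
params = params_lib.load_and_format_params(CKPT_PATH)
```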
{
@@ -934,7 +934,7 @@
"id": "BtJhJkkZzsy1"
},
"source": [
"To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `transformer` cache. Afterwards, instantiate the Gemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html).\n",
"To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `Transformer` cache. Afterwards, instantiate the Gemma model as `model_2b` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).\n",
"\n",
"**Note:** The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release."
]
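Put together, the configuration step described above looks roughly like this (a sketch against the gemma library version linked in the hunk; `cache_size=30` mirrors the fine-tuning tutorial, and the variable names follow the new wording of the diff):

```python
# Infer the model configuration (layer count, vocab size, and so on)
# from the loaded checkpoint rather than hard-coding it.
config_2b = transformer_lib.TransformerConfig.from_params(
    params,
    cache_size=30,  # number of time steps kept in the transformer cache
)

# Instantiate the model itself; Transformer is a flax.linen.Module.
model_2b = transformer_lib.Transformer(config=config_2b)
```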
@@ -1375,7 +1375,7 @@
"source": [
"## Learn more\n",
"\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of methods you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of modules you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"[`gemma.transformer`](https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py), and\n",
"[`gemma.sampler`](https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py).\n",
"- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), [Chex](https://chex.readthedocs.io/en/latest/), [Optax](https://optax.readthedocs.io/en/latest/), and [Orbax](https://orbax.readthedocs.io/).\n",
site/en/gemma/docs/jax_inference.ipynb (12 changes: 6 additions & 6 deletions)
@@ -100,7 +100,7 @@
"\n",
"After you've completed the Gemma setup, move on to the next section, where you'll set environment variables for your Colab environment.\n",
"\n",
"### Set environment variables\n",
"### 2. Set environment variables\n",
"\n",
"Set environment variables for `KAGGLE_USERNAME` and `KAGGLE_KEY`. When prompted with the \"Grant access?\" messages, agree to provide secret access."
]
@@ -126,7 +126,7 @@
"id": "AO7a1Q4Yyc9Z"
},
"source": [
"### 2. Install the `gemma` library\n",
"### 3. Install the `gemma` library\n",
"\n",
"This notebook focuses on using a free Colab GPU. To enable hardware acceleration, click on **Edit** > **Notebook settings** > Select **T4 GPU** > **Save**.\n",
"\n",
@@ -291,7 +291,7 @@
"id": "aEe3p8geqekV"
},
"source": [
"1. Load and format the Gemma model checkpoint with the [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py) method:"
"1. Load and format the Gemma model checkpoint with the [`gemma.params.load_and_format_params`](https://github.com/google-deepmind/gemma/blob/c6bd156c246530e1620a7c62de98542a377e3934/gemma/params.py#L27) method:"
]
},
{
@@ -347,7 +347,7 @@
"id": "IkAf4fkNrY-3"
},
"source": [
"3. To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `transformer` cache. Afterwards, instantiate the Gemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).\n",
"3. To automatically load the correct configuration from the Gemma model checkpoint, use [`gemma.transformer.TransformerConfig`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L65). The `cache_size` argument is the number of time steps in the Gemma `Transformer` cache. Afterwards, instantiate the Gemma model as `transformer` with [`gemma.transformer.Transformer`](https://github.com/google-deepmind/gemma/blob/56e501ce147af4ea5c23cc0ddf5a9c4a6b7bd0d0/gemma/transformer.py#L136) (which inherits from [`flax.linen.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.linen/module.html)).\n",
"\n",
"**Note:** The vocabulary size is smaller than the number of input embeddings because of unused tokens in the current Gemma release."
]
@@ -452,7 +452,7 @@
"id": "njxRJy3qsBWw"
},
"source": [
"5. (Option) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the `sampler` again in step 3 and customize and run the prompt in step 4."
"5. (Optional) Run this cell to free up memory if you have completed the notebook and want to try another prompt. Afterwards, you can instantiate the `sampler` again in step 3 and customize and run the prompt in step 4."
]
},
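The cleanup cell referenced here boils down to dropping the sampler so its parameters can be reclaimed; a minimal sketch (the explicit `gc.collect()` is an added assumption, not necessarily in the notebook):

```python
import gc

# Delete the Sampler, releasing its reference to the model parameters,
# then ask Python to collect so the backing memory can be freed.
del sampler
gc.collect()
```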
{
@@ -474,7 +474,7 @@
"source": [
"## Learn more\n",
"\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of methods you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"- You can learn more about the Google DeepMind [`gemma` library on GitHub](https://github.com/google-deepmind/gemma), which contains docstrings of modules you used in this tutorial, such as [`gemma.params`](https://github.com/google-deepmind/gemma/blob/main/gemma/params.py),\n",
"[`gemma.transformer`](https://github.com/google-deepmind/gemma/blob/main/gemma/transformer.py), and\n",
"[`gemma.sampler`](https://github.com/google-deepmind/gemma/blob/main/gemma/sampler.py).\n",
"- The following libraries have their own documentation sites: [core JAX](https://jax.readthedocs.io), [Flax](https://flax.readthedocs.io), and [Orbax](https://orbax.readthedocs.io/).\n",
