Upgrade Flax NNX Gemma Sampling Inference doc #4325

Open · wants to merge 1 commit into base: main
83 changes: 50 additions & 33 deletions docs_nnx/guides/gemma.ipynb
@@ -4,16 +4,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: Using Pretrained Gemma\n",
"# Example: Using pretrained Gemma for inference with Flax NNX\n",
"\n",
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
"This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.\n",
"\n",
"> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).\n",
"\n",
"You are recommended to use [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration to run the code."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Installation"
"## Installation\n",
"\n",
"Install the necessary dependencies, including `kagglehub`."
]
},
{
@@ -23,20 +29,20 @@
"outputs": [],
"source": [
Review comment (Collaborator Author): "guide" in the title, "tutorial" in the first paragraph -> let's use "tutorial".

"! pip install --no-deps -U flax\n",
"! pip install jaxtyping kagglehub treescope"
"! pip install jaxtyping kagglehub penzai"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Downloading the checkpoint\n",
"## Download the model\n",
"\n",
"\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n",
"To use Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:\n",
"\n",
"1. Visit https://www.kaggle.com/ and create an account.\n",
"2. Go to your account settings, then the 'API' section.\n",
"3. Click 'Create new token' to download your key.\n",
"1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.\n",
"2. If/once you have an account, you need to sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.\n",
"3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.\n",
"\n",
"Then run the cell below."
]
@@ -70,13 +76,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If everything went well, you should see:\n",
"```\n",
"Kaggle credentials set.\n",
"Kaggle credentials successfully validated.\n",
"If everything went well, it should say `Kaggle credentials set. Kaggle credentials successfully validated.`.\n",
"\n",
"**Note:** In Google Colab, you can instead authenticate into Kaggle using the code below after following the optional step 3 from above.\n",
"\n",
"```\n",
"import os\n",
"from google.colab import userdata # `userdata` is a Colab API.\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
"``` \n",
"\n",
"Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
"Now, load the Gemma model you want to try. The code in the next cell utilizes [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download model files.\n",
"\n",
"**Note:** For larger models, such as `gemma 7b` and `gemma 7b-it` (instruct), you may require a hardware accelerator with plenty of memory, such as the NVIDIA A100."
]
},
{
@@ -90,9 +104,7 @@
"VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n",
"weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n",
Review comment (Collaborator Author): @cgarciae Since there are checkpoints and tokenizer files, changed to "model" instead of "checkpoint".

Review comment (Collaborator Author): Checking if free TPU v2-8 is sufficient.

"ckpt_path = f'{weights_dir}/{VARIANT}'\n",
"vocab_path = f'{weights_dir}/tokenizer.model'\n",
"\n",
"clear_output()"
"vocab_path = f'{weights_dir}/tokenizer.model'"
]
},
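{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick check (a minimal sketch, assuming the download above succeeded), you can list the files that `kagglehub` fetched:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"# The directory should contain the checkpoint folder and `tokenizer.model`.\n",
"print(os.listdir(weights_dir))"
]
},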
{
@@ -116,7 +128,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
"To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import from the Flax NNX `examples/gemma` on GitHub."
]
},
{
@@ -141,10 +153,9 @@
"source": [
"import sys\n",
"import tempfile\n",
"\n",
"with tempfile.TemporaryDirectory() as tmp:\n",
" # Here we create a temporary directory and clone the flax repo\n",
" # Then we append the examples/gemma folder to the path to load the gemma modules\n",
" # Create a temporary directory and clone the `flax` repo.\n",
" # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.\n",
" ! git clone https://github.com/google/flax.git {tmp}/flax\n",
" sys.path.append(f\"{tmp}/flax/examples/gemma\")\n",
" import params as params_lib\n",
@@ -157,9 +168,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Start Generating with Your Model\n",
"## Load and prepare the Gemma model\n",
"\n",
"Load and prepare your LLM's checkpoint for use with Flax."
"First, load the Gemma model parameters for use with Flax."
]
},
{
@@ -170,15 +181,14 @@
},
"outputs": [],
"source": [
"# Load parameters\n",
"params = params_lib.load_and_format_params(ckpt_path)"
]
},
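{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check (a minimal sketch, assuming `params` is a nested dictionary of arrays, as `load_and_format_params` returns), you can inspect the shapes in the parameter tree:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"\n",
"# Print the shape of every array in the nested parameter tree.\n",
"print(jax.tree_util.tree_map(lambda x: x.shape, params))"
]
},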
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
"Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library."
]
},
{
@@ -208,7 +218,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
"Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint.\n",
"\n",
"**Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
]
},
{
@@ -250,7 +262,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, build a sampler on top of your model and your tokenizer."
"## Perform sampling/inference\n",
"\n",
"Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes."
]
},
{
@@ -261,7 +275,6 @@
},
"outputs": [],
"source": [
"# Create a sampler with the right param shapes.\n",
"sampler = sampler_lib.Sampler(\n",
" transformer=transformer,\n",
" vocab=vocab,\n",
@@ -272,7 +285,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
"You're ready to start sampling!\n",
"\n",
"**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.\n",
"\n",
"Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response)."
]
},
{
@@ -342,12 +359,12 @@
],
"source": [
"input_batch = [\n",
" \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
"]\n",
" \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
" ]\n",
"\n",
"out_data = sampler(\n",
" input_strings=input_batch,\n",
" total_generation_steps=300, # number of steps performed when generating\n",
" total_generation_steps=300, # The number of steps performed when generating a response.\n",
" )\n",
"\n",
"for input_string, out_string in zip(input_batch, out_data.text):\n",
@@ -360,7 +377,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You should get an implementation of bubble sort."
"You should get a Python implementation of the bubble sort algorithm."
]
}
],
79 changes: 48 additions & 31 deletions docs_nnx/guides/gemma.md
@@ -8,26 +8,32 @@
jupytext:
jupytext_version: 1.13.8
---

# Example: Using Pretrained Gemma
# Example: Using pretrained Gemma for inference with Flax NNX

You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.
This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.

> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).

It is recommended to run the code in [Google Colab](https://colab.research.google.com/) with access to an A100 GPU.

+++

## Installation

Install the necessary dependencies, including `kagglehub`.

```{code-cell} ipython3
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub treescope
! pip install jaxtyping kagglehub penzai
```

## Downloading the checkpoint
## Download the model

"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:
To use the Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:

1. Visit https://www.kaggle.com/ and create an account.
2. Go to your account settings, then the 'API' section.
3. Click 'Create new token' to download your key.
1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.
2. Once you have an account, sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.
3. (Optional) In [Google Colab](https://colab.research.google.com/), under 'Secrets', add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.

Then run the cell below.

@@ -36,13 +42,21 @@
import kagglehub
kagglehub.login()
```

If everything went well, you should see:
```
Kaggle credentials set.
Kaggle credentials successfully validated.
If everything went well, it should say `Kaggle credentials set. Kaggle credentials successfully validated.`.

**Note:** In Google Colab, you can instead authenticate with Kaggle using the code below after completing the optional step 3 above.

```
import os
from google.colab import userdata # `userdata` is a Colab API.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')
```

Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models.
Now, select and download the Gemma model you want to try. The code in the next cell uses [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download the model files.

**Note:** For larger models, such as `gemma 7b` and `gemma 7b-it` (instruct), you may require a hardware accelerator with plenty of memory, such as the NVIDIA A100.

```{code-cell} ipython3
from IPython.display import clear_output
@@ -51,8 +65,6 @@
VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:"string"}
weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')
ckpt_path = f'{weights_dir}/{VARIANT}'
vocab_path = f'{weights_dir}/tokenizer.model'

clear_output()
```
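
As a quick check (a minimal sketch, assuming the download above succeeded), you can list the files that `kagglehub` fetched:

```
import os

# The directory should contain the checkpoint folder and `tokenizer.model`.
print(os.listdir(weights_dir))
```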

## Python imports
@@ -62,15 +74,14 @@
from flax import nnx
import sentencepiece as spm
```

Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example.
To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import the `gemma` modules.

```{code-cell} ipython3
import sys
import tempfile

with tempfile.TemporaryDirectory() as tmp:
# Here we create a temporary directory and clone the flax repo
# Then we append the examples/gemma folder to the path to load the gemma modules
# Create a temporary directory and clone the `flax` repo.
# Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.
! git clone https://github.com/google/flax.git {tmp}/flax
sys.path.append(f"{tmp}/flax/examples/gemma")
import params as params_lib
@@ -79,18 +90,17 @@
with tempfile.TemporaryDirectory() as tmp:
sys.path.pop();
```

## Start Generating with Your Model
## Load and prepare the Gemma model

Load and prepare your LLM's checkpoint for use with Flax.
First, load the Gemma model parameters for use with Flax.

```{code-cell} ipython3
:cellView: form

# Load parameters
params = params_lib.load_and_format_params(ckpt_path)
```
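
As a sanity check (a minimal sketch, assuming `params` is a nested dictionary of arrays, as `load_and_format_params` returns), you can inspect the shapes in the parameter tree:

```
import jax

# Print the shape of every array in the nested parameter tree.
print(jax.tree_util.tree_map(lambda x: x.shape, params))
```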

Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library.
Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library.

```{code-cell} ipython3
:cellView: form
@@ -99,37 +109,44 @@
vocab = spm.SentencePieceProcessor()
vocab.Load(vocab_path)
```
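
To confirm the tokenizer loaded correctly, here is a minimal sketch that round-trips a short string using standard SentencePiece APIs (`EncodeAsIds`/`DecodeIds`):

```
ids = vocab.EncodeAsIds('Hello, Gemma!')
print(ids)                   # Token ids for the input string.
print(vocab.DecodeIds(ids))  # Should round-trip back to the original text.
```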

Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.
Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint.

**Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release.

```{code-cell} ipython3
transformer = transformer_lib.Transformer.from_params(params)
nnx.display(transformer)
```
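
To see the gap the note above describes, you can compare the tokenizer's vocabulary size with the model's embedding table size. This is a sketch under stated assumptions: that the `Transformer` instance keeps its `TransformerConfig` accessible as `transformer.config` with a `num_embed` field, as in the `examples/gemma` code at the pinned commit.

```
# Both attribute names below are assumptions about the examples/gemma code, not a stable API.
print('Tokenizer vocabulary size:', vocab.GetPieceSize())
print('Model input embeddings (num_embed):', transformer.config.num_embed)
```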

Finally, build a sampler on top of your model and your tokenizer.
## Perform sampling/inference

Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes.

```{code-cell} ipython3
:cellView: form

# Create a sampler with the right param shapes.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
)
```

You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.
You're ready to start sampling!

**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.

Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response).

```{code-cell} ipython3
:cellView: form

input_batch = [
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]

out_data = sampler(
input_strings=input_batch,
total_generation_steps=300, # number of steps performed when generating
total_generation_steps=300, # The number of steps performed when generating a response.
)

for input_string, out_string in zip(input_batch, out_data.text):
Expand All @@ -138,4 +155,4 @@ for input_string, out_string in zip(input_batch, out_data.text):
print(10*'#')
```

You should get an implementation of bubble sort.
You should get a Python implementation of the bubble sort algorithm.
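
Because the sampler was compiled for a batch of one prompt, a follow-up call with the same batch size reuses the compiled computation. As a sketch (the prompt below is only an illustration), try another single-prompt batch:

```
out_data = sampler(
    input_strings=[
        "\n# Python function that checks whether a number is prime.\n\ndef is_prime(n):",
    ],
    total_generation_steps=300,  # The number of steps performed when generating a response.
)
print(out_data.text[0])
```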