Upgrade Flax NNX Gemma Sampling Inference doc #4325
Open

8bitmp3 wants to merge 1 commit into google:main from 8bitmp3:update-nnx-gemma
Changes from all commits:
@@ -4,16 +4,22 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "# Example: Using Pretrained Gemma\n",
+     "# Example: Using pretrained Gemma for inference with Flax NNX\n",
      "\n",
-     "You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
+     "This example shows how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use [Flax NNX `gemma` modules](https://github.com/google/flax/tree/main/examples/gemma) written with Flax and JAX for model parameter configuration and inference.\n",
+     "\n",
+     "> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).\n",
+     "\n",
+     "You are recommended to use [Google Colab](https://colab.research.google.com/) with access to A100 GPU acceleration to run the code."
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "## Installation"
+     "## Installation\n",
+     "\n",
+     "Install the necessary dependencies, including `kagglehub`."
     ]
    },
    {
@@ -23,20 +29,20 @@
     "outputs": [],
     "source": [
      "! pip install --no-deps -U flax\n",
-     "! pip install jaxtyping kagglehub treescope"
+     "! pip install jaxtyping kagglehub penzai"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "## Downloading the checkpoint\n",
+     "## Download the model\n",
      "\n",
-     "\"To use Gemma's checkpoints, you'll need a Kaggle account and API key. Here's how to get them:\n",
+     "To use Gemma model, you'll need a [Kaggle](https://www.kaggle.com/models/google/gemma/) account and API key:\n",
      "\n",
-     "1. Visit https://www.kaggle.com/ and create an account.\n",
-     "2. Go to your account settings, then the 'API' section.\n",
-     "3. Click 'Create new token' to download your key.\n",
+     "1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.\n",
+     "2. If/once you have an account, you need to sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.\n",
+     "3. In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.\n",
      "\n",
      "Then run the cell below."
     ]
@@ -70,13 +76,21 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "If everything went well, you should see:\n",
-     "```\n",
-     "Kaggle credentials set.\n",
-     "Kaggle credentials successfully validated.\n",
+     "If everything went well, it should say `Kaggle credentials set. Kaggle credentials successfully validated.`.\n",
+     "\n",
+     "**Note:** In Google Colab, you can instead authenticate into Kaggle using the code below after following the optional step 3 from above.\n",
+     "\n",
+     "```\n",
+     "import os\n",
+     "from google.colab import userdata # `userdata` is a Colab API.\n",
+     "\n",
+     "os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
+     "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
+     "```\n",
+     "\n",
-     "Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
+     "Now, load the Gemma model you want to try. The code in the next cell utilizes [`kagglehub.model_download`](https://github.com/Kaggle/kagglehub/blob/8efe3e99477aa4f41885840de6903e61a49df4aa/src/kagglehub/models.py#L16) to download model files.\n",
+     "\n",
+     "**Note:** For larger models, such as `gemma 7b` and `gemma 7b-it` (instruct), you may require a hardware accelerator with plenty of memory, such as the NVIDIA A100."
     ]
    },
    {
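Outside Colab, the `userdata` API is unavailable, but `kagglehub` also picks up credentials from environment variables. A minimal sketch of that route follows; the credential values are hypothetical placeholders, not real keys.

```python
import os

# Hypothetical placeholder values -- substitute your own Kaggle username and
# API key. Outside Colab, kagglehub reads these environment variables
# (it can also use a ~/.kaggle/kaggle.json file).
os.environ["KAGGLE_USERNAME"] = "your_kaggle_username"
os.environ["KAGGLE_KEY"] = "your_kaggle_api_key"

print("Kaggle credentials set.")
```

Setting the variables before the first `kagglehub` call avoids the interactive login prompt.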
@@ -90,9 +104,7 @@
      "VARIANT = '2b-it' # @param ['2b', '2b-it', '7b', '7b-it'] {type:\"string\"}\n",
      "weights_dir = kagglehub.model_download(f'google/gemma/Flax/{VARIANT}')\n",
      "ckpt_path = f'{weights_dir}/{VARIANT}'\n",
-     "vocab_path = f'{weights_dir}/tokenizer.model'\n",
-     "\n",
-     "clear_output()"
+     "vocab_path = f'{weights_dir}/tokenizer.model'"
     ]
    },
    {

[Review comment] @cgarciae Since there are checkpoints and tokenizer files, changed to "model" instead of "checkpoint".
[Review comment] Checking if free TPU v2-8 is sufficient.
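The two paths built above can be sketched without performing a download; here `weights_dir` is a hypothetical local directory standing in for what `kagglehub.model_download` would return.

```python
# Sketch of the path layout the notebook assumes after `kagglehub.model_download`.
# No download happens here; `weights_dir` is a hypothetical local directory.
VARIANT = "2b-it"
weights_dir = f"/root/.cache/kagglehub/models/google/gemma/Flax/{VARIANT}"  # hypothetical

ckpt_path = f"{weights_dir}/{VARIANT}"         # checkpoint directory
vocab_path = f"{weights_dir}/tokenizer.model"  # SentencePiece tokenizer file

print(ckpt_path)
print(vocab_path)
```

The checkpoint directory and the tokenizer file live side by side under the downloaded model directory, which is why the PR renames the section from "checkpoint" to "model".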
@@ -116,7 +128,7 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
+     "To interact with the Gemma model, you will use the Flax NNX `gemma` code from [`google/flax` examples on GitHub](https://github.com/google/flax/tree/main/examples/gemma). Since it is not exposed as a package, you need to use the following workaround to import from the Flax NNX `examples/gemma` on GitHub."
     ]
    },
    {
@@ -141,10 +153,9 @@
     "source": [
      "import sys\n",
      "import tempfile\n",
-     "\n",
      "with tempfile.TemporaryDirectory() as tmp:\n",
-     "  # Here we create a temporary directory and clone the flax repo\n",
-     "  # Then we append the examples/gemma folder to the path to load the gemma modules\n",
+     "  # Create a temporary directory and clone the `flax` repo.\n",
+     "  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.\n",
      "  ! git clone https://github.com/google/flax.git {tmp}/flax\n",
      "  sys.path.append(f\"{tmp}/flax/examples/gemma\")\n",
      "  import params as params_lib\n",
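The clone-then-`sys.path.append` workaround relies on ordinary Python import resolution. This self-contained sketch demonstrates the same mechanism with a tiny stand-in module written to a temporary directory instead of a network clone; `params_stub` and its function body are illustrative inventions, not the real `examples/gemma/params.py`.

```python
# Demonstrates the sys.path-append import pattern used by the notebook,
# without needing git or network access.
import os
import sys
import tempfile

tmp = tempfile.mkdtemp()  # kept alive so the import below keeps working
module_dir = os.path.join(tmp, "flax", "examples", "gemma")
os.makedirs(module_dir)

# Hypothetical stand-in for a module such as `params.py` in `examples/gemma`.
with open(os.path.join(module_dir, "params_stub.py"), "w") as f:
    f.write("def load_and_format_params(path):\n    return {'ckpt': path}\n")

sys.path.append(module_dir)
import params_stub as params_lib  # mirrors `import params as params_lib`

print(params_lib.load_and_format_params("/some/ckpt"))
```

Note the notebook itself uses a `with tempfile.TemporaryDirectory()` block, which deletes the clone after the imports complete; the imported modules stay usable because Python caches them in memory.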
@@ -157,9 +168,9 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "## Start Generating with Your Model\n",
+     "## Load and prepare the Gemma model\n",
      "\n",
-     "Load and prepare your LLM's checkpoint for use with Flax."
+     "First, load the Gemma model parameters for use with Flax."
     ]
    },
    {
@@ -170,15 +181,14 @@
    },
    "outputs": [],
    "source": [
-     "# Load parameters\n",
     "params = params_lib.load_and_format_params(ckpt_path)"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-     "Load your tokenizer, which we'll construct using the [SentencePiece](https://github.com/google/sentencepiece) library."
+     "Next, load the tokenizer file constructed using the [SentencePiece](https://github.com/google/sentencepiece) library."
    ]
   },
   {
@@ -208,7 +218,9 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
+     "Then, use the Flax NNX [`gemma.transformer.TransformerConfig.from_params`](https://github.com/google/flax/blob/3f3c03b23d4fd3d85d1c5d4d97381a8a2c48b475/examples/gemma/transformer.py#L193) function to automatically load the correct configuration from a checkpoint.\n",
+     "\n",
+     "**Note:** The vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
     ]
    },
    {
@@ -250,7 +262,9 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "Finally, build a sampler on top of your model and your tokenizer."
+     "## Perform sampling/inference\n",
+     "\n",
+     "Build a Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) on top of your model and tokenizer with the right parameter shapes."
     ]
    },
    {
}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a sampler with the right param shapes.\n", | ||
"sampler = sampler_lib.Sampler(\n", | ||
" transformer=transformer,\n", | ||
" vocab=vocab,\n", | ||
|
@@ -272,7 +285,11 @@ | |
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent." | ||
"You're ready to start sampling!\n", | ||
"\n", | ||
"**Note:** This Flax NNX [`gemma.Sampler`](https://github.com/google/flax/blob/main/examples/gemma/sampler.py) uses JAX’s [just-in-time (JIT) compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html), so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent.\n", | ||
"\n", | ||
"Write a prompt in `input_batch` and perform inference. Feel free to tweak `total_generation_steps` (the number of steps performed when generating a response)." | ||
] | ||
}, | ||
{ | ||
|
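One practical way to keep the batch size consistent, as the note above advises, is to pad every prompt list to a fixed length before handing it to the sampler. The helper below is a hypothetical sketch, not part of the `gemma` API.

```python
# Hedged sketch: pad a prompt list to a fixed BATCH_SIZE so the sampler always
# sees the same input shape and JIT does not recompile between calls.
# `pad_batch` and BATCH_SIZE are illustrative, not part of the gemma modules.
BATCH_SIZE = 4

def pad_batch(prompts, batch_size, pad_value=""):
    """Pad a prompt list to a fixed length with empty strings."""
    if len(prompts) > batch_size:
        raise ValueError("too many prompts for one batch")
    return prompts + [pad_value] * (batch_size - len(prompts))

batch = pad_batch(["What is JAX?", "Define JIT."], BATCH_SIZE)
print(len(batch))
```

Completions generated for the padding entries are simply discarded; the cost of sampling a few empty prompts is usually far smaller than a recompilation.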
@@ -342,12 +359,12 @@
     ],
     "source": [
      "input_batch = [\n",
-     "  \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
-     "]\n",
+     "    \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
+     "  ]\n",
      "\n",
      "out_data = sampler(\n",
      "  input_strings=input_batch,\n",
-     "  total_generation_steps=300, # number of steps performed when generating\n",
+     "  total_generation_steps=300, # The number of steps performed when generating a response.\n",
      "  )\n",
      "\n",
      "for input_string, out_string in zip(input_batch, out_data.text):\n",
@@ -360,7 +377,7 @@
     "cell_type": "markdown",
     "metadata": {},
     "source": [
-     "You should get an implementation of bubble sort."
+     "You should get a Python implementation of the bubble sort algorithm."
     ]
    }
   ],
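For reference, a completion of the `bubbleSort(arr):` prompt might resemble the standard algorithm below. This is only an illustration of the expected shape of the output; the model's actual text will vary between runs.

```python
# A standard bubble sort, illustrating what the sampler's completion of the
# "def bubbleSort(arr):" prompt might look like (illustrative only).
def bubbleSort(arr):
    n = len(arr)
    for i in range(n):
        # After pass i, the last i elements are already in their final place.
        for j in range(0, n - i - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr

print(bubbleSort([64, 34, 25, 12, 22, 11, 90]))
```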
[Review comment] "guide" in the title, "tutorial" in the first paragraph -> let's use "tutorial".