Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Upgrade Flax NNX Gemma Sampling Inference doc #4325

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

8bitmp3
Copy link
Collaborator

@8bitmp3 8bitmp3 commented Oct 23, 2024

Preview: https://flax--4325.org.readthedocs.build/en/4325/guides/gemma.html

Also fixes broken code after:

! git clone https://github.com/google/flax.git flax_examples
...
- sys.path.append("./flax_examples/flax/nnx/examples/gemma")
+ sys.path.append("./flax_examples/examples/gemma")
...

@8bitmp3 8bitmp3 self-assigned this Oct 23, 2024
Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
"In this tutorial, you will learn step-by-step how to use Flax NNX to load the [Gemma](https://ai.google.dev/gemma) open model files and use them to perform sampling/inference for generating text. You will use the [Flax NNX `gemma` code](https://github.com/google/flax.git) that was written with Flax and JAX.\n",
"\n",
"> Gemma is a family of lightweight, state-of-the-art open models based on Google DeepMind’s [Gemini](https://deepmind.google/technologies/gemini/#introduction). Read more about [Gemma](https://blog.google/technology/developers/gemma-open-models/) and [Gemma 2](https://blog.google/technology/developers/google-gemma-2/).\n",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similar to what we did in other Gemma docs - added some background.

"\n",
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"guide" in the title, "tutorial" in the first paragraph -> let's use "tutorial".

"\n",
"Now select and download the checkpoint you want to try. Note that you will need an A100 runtime for the 7b models."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cgarciae Since there are checkpoints and tokenizer files, changed to "model" instead of "checkpoint".

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking if free TPU v2-8 is sufficient.

@@ -19,16 +19,24 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide\n",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cgarciae Adding "inference" next to "sampling" for search.

"\n",
"1. Visit https://www.kaggle.com/ and create an account.\n",
"2. Go to your account settings, then the 'API' section.\n",
"3. Click 'Create new token' to download your key.\n",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cgarciae Adding Step 3 "OPTIONAL" and removing "OPTIONAL". since Colab asks for access here after running the code below, so users won't have to manually entering the details if they are stored in Colab:

import kagglehub
kagglehub.login()

"1. To create an account, visit Kaggle and click on 'Register'."
"2. If/once you have an account, you need to sign in, go to your 'Settings', and under 'API' click on 'Create New Token' to generate and download your Kaggle API key."
"3. OPTIONAL: In Google Colab, under 'Secrets' add your Kaggle username and API key, storing the username as KAGGLE_USERNAME and the key as KAGGLE_KEY. If you are using a Kaggle Notebook for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys."

@@ -82,13 +90,21 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If everything went well, you should see:\n",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding an extra optional step here, similar to what we have in the Gemma docs.

"Note: In Google Colab, you can instead authenticate into Kaggle using the code below after following the optional step 3 from above...."

@@ -124,7 +140,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@cgarciae Edited:

"To interact with the Gemma model, you will use the Flax NNX Gemma code from google/flax examples on GitHub. Since it is not exposed as packages, you need to use the following workaround in the next cells to import from the Flax NNX Gemma example."

@@ -195,7 +218,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the `transformer_lib.TransformerConfig.from_params` function to automatically load the correct configuration from a checkpoint. Note that the vocabulary size is smaller than the number of input embeddings due to unused tokens in this release."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the source code that has more docstring(s) since transformer_lib is an alias for Flax NNX examples -> gemma.transformer:

"Then, use the Flax NNX transformer_lib.TransformerConfig.from_params function to automatically load the correct configuration from a checkpoint."

@@ -212,7 +237,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, build a sampler on top of your model and your tokenizer."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the source code with the docstring.

"Build a Flax NNX Sampler on top of your model and tokenizer with the right parameter shapes."

@@ -235,7 +261,11 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You're ready to start sampling ! This sampler uses just-in-time compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some background on JAX JIT after studying the source code (it's not NNX JIT).

"Note: This Flax NNX gemma.Sampler uses JAX’s just-in-time (JIT) compilation, so changing the input shape triggers recompilation, which can slow things down. For the fastest and most efficient results, keep your batch size consistent."

@@ -136,6 +152,14 @@
"! git clone https://github.com/google/flax.git flax_examples"
Copy link
Collaborator Author

@8bitmp3 8bitmp3 Oct 23, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixing

! git clone https://github.com/google/flax.git flax_examples
...
- sys.path.append("./flax_examples/flax/nnx/examples/gemma")
+ sys.path.append("./flax_examples/examples/gemma")
...

@8bitmp3
Copy link
Collaborator Author

8bitmp3 commented Oct 23, 2024

Sampler configuration (https://github.com/google/flax/blob/main/examples/gemma/sampler.py):

    transformer=transformer,
    vocab=vocab,
    params=params['transformer'],
)

Throws an error

TypeError                                 Traceback (most recent call last)
...
in <cell line: 1>()
----> 1 sampler = sampler_lib.Sampler(
      2     transformer=transformer,
      3     vocab=vocab,
      4     params=params['transformer'],
      5 )

TypeError: Sampler.__init__() got an unexpected keyword argument 'params'

@cgarciae PTAL

"3. Click 'Create new token' to download your key.\n",
"1. To create an account, visit [Kaggle](https://www.kaggle.com/) and click on 'Register'.\n",
"2. If/once you have an account, you need to sign in, go to your ['Settings'](https://www.kaggle.com/settings), and under 'API' click on 'Create New Token' to generate and download your Kaggle API key.\n",
"3. OPTIONAL: In [Google Colab](https://colab.research.google.com/), under 'Secrets' add your Kaggle username and API key, storing the username as `KAGGLE_USERNAME` and the key as `KAGGLE_KEY`. If you are using a [Kaggle Notebook](https://www.kaggle.com/code) for free TPU or other hardware acceleration, it has a key storage feature under 'Add-ons' > 'Secrets', along with instructions for accessing stored keys.\n",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Should remove Optional for Colab users?

@8bitmp3 8bitmp3 marked this pull request as ready for review October 28, 2024 21:43
@cgarciae
Copy link
Collaborator

Hey @8bitmp3! I cleaned up this guide a little bit. Can you take a look at the new version?

@8bitmp3
Copy link
Collaborator Author

8bitmp3 commented Oct 31, 2024

thanks @cgarciae 👍
on it

@8bitmp3
Copy link
Collaborator Author

8bitmp3 commented Nov 4, 2024

Reopening after #4334 fixes

@8bitmp3 8bitmp3 reopened this Nov 4, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants