Skip to content

Commit

Permalink
[nnx] cleanup gemma notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Oct 25, 2024
1 parent 40f080e commit 2a091a5
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 104 deletions.
78 changes: 30 additions & 48 deletions docs_nnx/guides/gemma.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright 2024 The Flax Authors.\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License. You may obtain a copy of the License at\n",
"\n",
"http://www.apache.org/licenses/LICENSE-2.0\n",
"\n",
"Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.\n",
"\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide\n",
"# Getting Started with Gemma Sampling\n",
"\n",
"You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it."
]
Expand All @@ -33,12 +18,12 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"! pip install --no-deps -U flax\n",
"! pip install jaxtyping kagglehub penzai"
"# ! pip install --no-deps -U flax\n",
"# ! pip install jaxtyping kagglehub penzai"
]
},
{
Expand All @@ -58,19 +43,22 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'kagglehub'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mkagglehub\u001b[39;00m\n\u001b[1;32m 2\u001b[0m kagglehub\u001b[38;5;241m.\u001b[39mlogin()\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'kagglehub'"
]
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0d6c3894765240b098f75e98d73acb41",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"VBox(children=(HTML(value='<center> <img\\nsrc=https://www.kaggle.com/static/images/site-logo.png\\nalt=\\'Kaggle…"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
Expand Down Expand Up @@ -127,28 +115,24 @@
"Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"! git clone https://github.com/google/flax.git flax_examples"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import tempfile\n",
"\n",
"sys.path.append(\"./flax_examples/flax/nnx/examples/gemma\")\n",
"import params as params_lib\n",
"import sampler as sampler_lib\n",
"import transformer as transformer_lib\n",
"sys.path.pop();"
"with tempfile.TemporaryDirectory() as tmp:\n",
" # Here we create a temporary directory and clone the flax repo\n",
" # Then we append the examples/gemma folder to the path to load the gemma modules\n",
" ! git clone https://github.com/google/flax.git {tmp}/flax\n",
" sys.path.append(f\"{tmp}/flax/examples/gemma\")\n",
" import params as params_lib\n",
" import sampler as sampler_lib\n",
" import transformer as transformer_lib\n",
" sys.path.pop();"
]
},
{
Expand Down Expand Up @@ -227,7 +211,6 @@
"sampler = sampler_lib.Sampler(\n",
" transformer=transformer,\n",
" vocab=vocab,\n",
" params=params['transformer'],\n",
")"
]
},
Expand All @@ -247,9 +230,8 @@
"outputs": [],
"source": [
"input_batch = [\n",
" \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
" \"What are the planets of the solar system?\",\n",
" ]\n",
" \"\\n# Python program for implementation of Bubble Sort\\n\\ndef bubbleSort(arr):\",\n",
"]\n",
"\n",
"out_data = sampler(\n",
" input_strings=input_batch,\n",
Expand All @@ -266,7 +248,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"You should get an implementation of bubble sort and a description of the solar system."
"You should get an implementation of bubble sort."
]
}
],
Expand Down
47 changes: 17 additions & 30 deletions docs_nnx/guides/gemma.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,7 @@ jupytext:
jupytext_version: 1.13.8
---

Copyright 2024 The Flax Authors.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.

---

+++

# Getting Started with Gemma Sampling using NNX: A Step-by-Step Guide
# Getting Started with Gemma Sampling

You will find in this colab a detailed tutorial explaining how to use NNX to load a Gemma checkpoint and sample from it.

Expand All @@ -29,8 +17,8 @@ You will find in this colab a detailed tutorial explaining how to use NNX to loa
## Installation

```{code-cell} ipython3
! pip install --no-deps -U flax
! pip install jaxtyping kagglehub penzai
# ! pip install --no-deps -U flax
# ! pip install jaxtyping kagglehub penzai
```

## Downloading the checkpoint
Expand Down Expand Up @@ -72,18 +60,19 @@ import sentencepiece as spm

Flax examples are not exposed as packages so you need to use the workaround in the next cells to import from NNX's Gemma example.

```{code-cell} ipython3
! git clone https://github.com/google/flax.git flax_examples
```

```{code-cell} ipython3
import sys
sys.path.append("./flax_examples/flax/nnx/examples/gemma")
import params as params_lib
import sampler as sampler_lib
import transformer as transformer_lib
sys.path.pop();
import tempfile
with tempfile.TemporaryDirectory() as tmp:
# Here we create a temporary directory and clone the flax repo
# Then we append the examples/gemma folder to the path to load the gemma modules
! git clone https://github.com/google/flax.git {tmp}/flax
sys.path.append(f"{tmp}/flax/examples/gemma")
import params as params_lib
import sampler as sampler_lib
import transformer as transformer_lib
sys.path.pop();
```

## Start Generating with Your Model
Expand Down Expand Up @@ -122,7 +111,6 @@ Finally, build a sampler on top of your model and your tokenizer.
sampler = sampler_lib.Sampler(
transformer=transformer,
vocab=vocab,
params=params['transformer'],
)
```

Expand All @@ -132,9 +120,8 @@ You're ready to start sampling ! This sampler uses just-in-time compilation, so
:cellView: form
input_batch = [
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
"What are the planets of the solar system?",
]
"\n# Python program for implementation of Bubble Sort\n\ndef bubbleSort(arr):",
]
out_data = sampler(
input_strings=input_batch,
Expand All @@ -147,4 +134,4 @@ for input_string, out_string in zip(input_batch, out_data.text):
print(10*'#')
```

You should get an implementation of bubble sort and a description of the solar system.
You should get an implementation of bubble sort.
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,18 @@ docs = [
"sphinx-design",
"jupytext==1.13.8",
"dm-haiku",

# Need to pin docutils to 0.16 to make bulleted lists appear correctly on
# ReadTheDocs: https://stackoverflow.com/a/68008428
"docutils==0.16",

# The next packages are for notebooks.
"matplotlib",
"scikit-learn",
# The next packages are used in testcode blocks.
"ml_collections",
# notebooks
"einops",
"kagglehub>=0.3.3",
"ipywidgets>=8.1.5",
]
dev = [
"pre-commit>=3.8.0",
Expand Down
Loading

0 comments on commit 2a091a5

Please sign in to comment.