A JAX/Equinox rewrite of nanoGPT that prioritizes pedagogy and modern ML frameworks. This implementation reimagines the original repository using JAX's powerful array computation capabilities while maintaining PyTorch-like simplicity through Equinox. The code is clean and pedagogical: `model.py` is a ~300-line GPT model definition built from Equinox modules, and `train.py` implements the training loop with JAX's transformations.
📖 Want to understand every line? Check out our complete tutorial: nanoJAXGPT: A pedagogical introduction to JAX/Equinox
- JAX Native: Leverages JAX's transformations like `jit`, `grad`, and `vmap` for efficient training (see the sketch after this list)
- Equinox Integration: Maintains PyTorch-like simplicity while getting all of JAX's benefits
- SwiGLU Activation: Enhanced activation function replacing GELU in the original
- Educational First: Clear, documented code structure for learning JAX and Transformers
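To make the JAX-native point concrete, here is a minimal, self-contained sketch of how `jit`, `grad`, and `vmap` compose with an Equinox module. It is not code from this repository; all names are illustrative.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

# A toy Equinox module: its parameters are ordinary PyTree leaves.
class TinyMLP(eqx.Module):
    w1: jax.Array
    w2: jax.Array

    def __init__(self, key):
        k1, k2 = jax.random.split(key)
        self.w1 = jax.random.normal(k1, (8, 16))
        self.w2 = jax.random.normal(k2, (16, 1))

    def __call__(self, x):  # written for a single example
        return jnp.tanh(x @ self.w1) @ self.w2

@eqx.filter_jit  # JIT-compiles; array leaves are traced, everything else is static
def loss_fn(model, xs, ys):
    preds = jax.vmap(model)(xs)  # vmap maps the single-example model over the batch
    return jnp.mean((preds - ys) ** 2)

model = TinyMLP(jax.random.PRNGKey(0))
xs, ys = jnp.ones((32, 8)), jnp.zeros((32, 1))
grads = eqx.filter_grad(loss_fn)(model, xs, ys)  # gradients as a model-shaped PyTree
```

Because an Equinox model is just a PyTree, the same transformations that work on arrays work on whole models.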
First, clone the repository:
```bash
git clone https://github.com/surgeglobal/nanoJAXGPT
cd nanoJAXGPT
```
Then install the dependencies using the provided `requirements.txt`:

```bash
pip install -r requirements.txt
```
This will install all necessary dependencies:
- JAX and Equinox for core model implementation
- Transformers for loading GPT-2 checkpoints
- Datasets for data processing
- Tiktoken for BPE tokenization
- Wandb for experiment tracking
- Additional utilities (numpy, tqdm)
For GPU support, make sure you have CUDA installed and install the CUDA-enabled version of JAX:
```bash
# For CUDA 12
pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

# For CUDA 11
pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
To verify your installation:

```python
import jax
print("Available devices:", jax.devices())  # Should show your GPU(s) if available
```
Now you're ready to start training! Check out the Quick start section below.
If you're new to JAX or just want to get your feet wet, let's start with a simple training example.
📖 Want to train on your custom datasets? You can create custom scripts for your preferred datasets as we have done for the tinystories dataset.
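As a rough illustration of such a script (names like `input.txt` are placeholders, and we assume the usual nanoGPT convention of writing `train.bin`/`val.bin` token files), a custom prepare script might look like this:

```python
# Hypothetical prepare script for a custom plain-text dataset.
# Assumes the nanoGPT-style convention of uint16 token files.
import numpy as np
import tiktoken

enc = tiktoken.get_encoding("gpt2")  # GPT-2 BPE tokenizer

with open("input.txt", "r", encoding="utf-8") as f:
    text = f.read()

split = int(len(text) * 0.9)
train_ids = enc.encode_ordinary(text[:split])
val_ids = enc.encode_ordinary(text[split:])

# GPT-2 token ids fit in uint16, keeping the files compact.
np.array(train_ids, dtype=np.uint16).tofile("train.bin")
np.array(val_ids, dtype=np.uint16).tofile("val.bin")
```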
I have a GPU/TPU. Excellent! JAX shines on accelerators. Let's train an example storytelling model.
This project is set up to run locally by default, but you can easily switch to running on a remote GPU provided by modal.com.
⚠️ Before running the code (either locally or on the cloud), you need to be logged in to **modal.com**. This is because our script will try to attach a _Modal_ volume (similar to an S3 bucket, but free) regardless of the environment it is executed in.
NOTE: This requirement comes from _Modal_ itself; we intend to remove it for local runs (**PRs are welcome**)!
- Modal Setup:
  Before executing the code, you need to be logged in to Modal. To do this, run:

  ```bash
  modal setup
  ```

  This will prompt you to log in using your GitHub account. Follow the instructions in the terminal to authorize access. For more information, see the Modal.com Getting Started Guide.
- Run Locally or Remotely:
  By default, running the following command will execute the training script on your local machine:

  ```bash
  modal run app.py
  ```

  If you prefer to run on one of Modal's serverless remote GPUs (e.g., an A100), add the `--run-on-remote` option:

  ```bash
  modal run app.py --run-on-remote true
  ```
- Configuration File:
  Currently, the code doesn't directly accept a configuration file. However, you can modify parameters directly in the training script (`executables/train.py`) to customize the training process.

Make sure you are logged in to modal.com by running `modal setup` before executing the code.
I only have a CPU. No problem! JAX works great on CPU too. However, you may be limited to smaller models due to memory/computation constraints. You can try a scaled-down version of the model by setting the following values in the `train.py` file:

```python
block_size = 64
batch_size = 12
n_layer = 4
n_head = 4
n_embd = 128
max_iters = 2000
lr_decay_iters = 2000
```
⚠️ The sampling logic is available in the `sample.py` file in the root directory. However, it only supports remote sampling at the moment. While we intend to modify this to support both environments (again, PRs are welcome!), you may use that logic to set up your own sampling file.
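In the meantime, a minimal local sampling loop could look like the sketch below. This is an assumption-laden sketch, not the repository's code: it presumes a trained `model` that maps a `(T,)` token sequence to `(T, vocab_size)` logits, which may not match the actual call signature.

```python
import jax
import jax.numpy as jnp
import tiktoken

enc = tiktoken.get_encoding("gpt2")

def generate(model, prompt, max_new_tokens=100, temperature=0.8, seed=0):
    # Assumes model((T,) int tokens) -> (T, vocab_size) logits.
    tokens = jnp.array(enc.encode_ordinary(prompt), dtype=jnp.int32)
    key = jax.random.PRNGKey(seed)
    for _ in range(max_new_tokens):
        logits = model(tokens)[-1] / temperature  # next-token logits
        key, subkey = jax.random.split(key)
        next_tok = jax.random.categorical(subkey, logits)
        tokens = jnp.concatenate([tokens, next_tok[None]])
        # In practice, crop `tokens` to the model's block_size here.
    return enc.decode(tokens.tolist())
```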
The model implements a modern GPT architecture with a clean hierarchical structure:
```
GPT (eqx.Module)
├── TransformerLayer
│   ├── Embedding (wte)
│   ├── Embedding (wpe)
│   └── Block
│       ├── CausalSelfAttention
│       │   └── Linear
│       └── MLP
│           └── SwiGLU
```
Key components:
- `GPT`: The main model class containing the transformer and language model head
- `TransformerLayer`: Core component managing token/positional embeddings and transformer blocks
- `Block`: Structural component combining attention and MLP modules
- `CausalSelfAttention`: Implementation of masked self-attention with projection layers
- `MLP`: Feedforward network using projection layers and SwiGLU activation
- `SwiGLU`: Activation function with learnable parameters replacing traditional GELU
All components are implemented as Equinox modules, providing a PyTorch-like interface while leveraging JAX's functional programming model and computational benefits. This architecture maintains the core ideas of GPT while introducing modern improvements like SwiGLU activation and a cleaner organizational structure.
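For reference, a common SwiGLU formulation (Shazeer, 2020) written as an Equinox module looks roughly like the sketch below; the repository's exact implementation may differ in details such as biases or the hidden width.

```python
import jax
import equinox as eqx

class SwiGLU(eqx.Module):
    """Swish-gated linear unit: w_down(silu(w_gate(x)) * w_up(x))."""
    w_gate: eqx.nn.Linear
    w_up: eqx.nn.Linear
    w_down: eqx.nn.Linear

    def __init__(self, n_embd, hidden, key):
        k1, k2, k3 = jax.random.split(key, 3)
        self.w_gate = eqx.nn.Linear(n_embd, hidden, use_bias=False, key=k1)
        self.w_up = eqx.nn.Linear(n_embd, hidden, use_bias=False, key=k2)
        self.w_down = eqx.nn.Linear(hidden, n_embd, use_bias=False, key=k3)

    def __call__(self, x):
        # Operates on a single (n_embd,) vector; vmap over tokens in practice.
        return self.w_down(jax.nn.silu(self.w_gate(x)) * self.w_up(x))
```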
The code leverages several JAX optimizations:
- JIT compilation of training and inference loops (see the sketch after this list)
- Vectorization through `vmap`
- Automatic gradient computation with `grad`
- Efficient device parallelism
- PyTree-based parameter management through Equinox
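Put together, a jitted training step typically has the shape sketched below. This is an assumed outline, not the repository's exact API: `compute_loss`, the Optax optimizer, and the batching convention are all illustrative.

```python
import jax
import equinox as eqx
import optax

def compute_loss(model, xb, yb):
    # Assumes model maps a (T,) token sequence to (T, vocab) logits; vmap batches it.
    logits = jax.vmap(model)(xb)
    return optax.softmax_cross_entropy_with_integer_labels(logits, yb).mean()

optimizer = optax.adamw(learning_rate=3e-4)
# opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

@eqx.filter_jit
def train_step(model, opt_state, xb, yb):
    loss, grads = eqx.filter_value_and_grad(compute_loss)(model, xb, yb)
    updates, opt_state = optimizer.update(
        grads, opt_state, eqx.filter(model, eqx.is_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss
```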
- Add distributed data parallel (DDP) training.
- When running locally, avoid trying to create a volume on Modal.
- Allow passing in a configuration file to override the default training values. Currently, you have to set them in the `train.py` file.
- Modify `sample.py` to allow sampling from either a remote or local environment. Currently, only remote sampling is supported.
The most common issues relate to JAX device placement and JIT compilation. If you're running into errors:
- Ensure JAX can see your accelerators with `jax.devices()`
- Try running without JIT using `--compile=False`
- Check your JAX/CUDA version compatibility
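For example, a quick way to confirm which backend JAX selected:

```python
import jax

print(jax.default_backend())  # "gpu", "tpu", or "cpu"
print(jax.devices())
# To rule out GPU problems while debugging, you can force CPU execution by
# setting the environment variable JAX_PLATFORMS=cpu before launching Python.
```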
For more context on transformers and language modeling, check out Andrej Karpathy's Zero To Hero series.
For questions and discussions about this JAX implementation, feel free to open an issue!
- Andrej Karpathy's nanoGPT for the original implementation
- We are also grateful to Anh Tong, whose Equinox version of nanoGPT was a source of inspiration for our rewrite. We recommend checking out his version as well: https://github.com/anh-tong/nanoGPT-equinox
- The JAX team for an amazing framework
- The Equinox team for making JAX feel like PyTorch
- The Modal team for making serverless GPU usage accessible and affordable, and most importantly for providing $30 in free credits for each workspace in your account.