diff --git a/README.md b/README.md index 761df44..d6b7c15 100644 --- a/README.md +++ b/README.md @@ -68,9 +68,9 @@ As presented in the code above, the model state is represented as a JAX PyTree o A full collection of examples is available: * [Scalify quickstart notebook](./examples/scalify-quickstart.ipynb): basics of `ScaledArray` and `scalify` transform; -* [MNIST FP16 training example](./experiments/mnist/mnist_classifier_from_scratch.py): adapting JAX MNIST example to `scalify`; -* [MNIST FP8 training example](./experiments/mnist/mnist_classifier_from_scratch.py): easy FP8 support in `scalify`; -* [CIFAR10 training](./experiments/mnist/cifar_training.py): `scalify` CIFAR10 training, with Optax optimizer integration; +* [MNIST FP16 training example](./examples/mnist/mnist_classifier_from_scratch.py): adapting JAX MNIST example to `scalify`; +* [MNIST FP8 training example](./examples/mnist/mnist_classifier_from_scratch_fp8.py): easy FP8 support in `scalify`; +* [MNIST Flax example](./examples/mnist/flax): `scalify` Flax training, with Optax optimizer integration; ## Installation diff --git a/experiments/mnist/cifar_training.py b/examples/cifar10/cifar10_training.py similarity index 92% rename from experiments/mnist/cifar_training.py rename to examples/cifar10/cifar10_training.py index d715d1d..7d4e89e 100644 --- a/experiments/mnist/cifar_training.py +++ b/examples/cifar10/cifar10_training.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified by Graphcore Ltd 2024. -"""A basic MNIST example using Numpy and JAX. +"""A basic CIFAR10 example using Numpy and JAX. -The primary aim here is simplicity and minimal dependencies. +CIFAR10 training using MLP network + raw SGD optimizer. """ - - import time -import datasets +import dataset_cifar10 import jax import jax.numpy as jnp import numpy as np @@ -100,7 +99,7 @@ def accuracy(params, batch): training_dtype = np.float16 scale_dtype = np.float32 - train_images, train_labels, test_images, test_labels = datasets.cifar() + train_images, train_labels, test_images, test_labels = dataset_cifar10.cifar() num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) @@ -118,7 +117,7 @@ def data_stream(): # Transform parameters to `ScaledArray` and proper dtype. if use_scalify: params = jsa.as_scaled_array(params, scale=scale_dtype(param_scale)) - params = jax.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) + params = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), params, is_leaf=jsa.core.is_scaled_leaf) @jit @scalify @@ -133,7 +132,7 @@ def update(params, batch): # Scaled micro-batch + training dtype cast. 
if use_scalify: batch = jsa.as_scaled_array(batch, scale=scale_dtype(param_scale)) - batch = jax.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) + batch = jax.tree_util.tree_map(lambda v: v.astype(training_dtype), batch, is_leaf=jsa.core.is_scaled_leaf) with jsa.ScalifyConfig(rounding_mode=jsa.Pow2RoundMode.DOWN, scale_dtype=scale_dtype): params = update(params, batch) diff --git a/experiments/mnist/optax_cifar_training.py b/examples/cifar10/cifar10_training_with_optax.py similarity index 94% rename from experiments/mnist/optax_cifar_training.py rename to examples/cifar10/cifar10_training_with_optax.py index 3a078b6..13e8e72 100644 --- a/experiments/mnist/optax_cifar_training.py +++ b/examples/cifar10/cifar10_training_with_optax.py @@ -11,16 +11,15 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified by Graphcore Ltd 2024. -"""A basic MNIST example using Numpy and JAX. - -The primary aim here is simplicity and minimal dependencies. +"""A basic CIFAR10 example using Numpy and JAX. """ import time -import datasets +import dataset_cifar10 import jax import jax.numpy as jnp import numpy as np @@ -65,10 +64,6 @@ def predict(params, inputs): final_w, final_b = params[-1] logits = jnp.dot(activations, final_w) + final_b - - # jsa.ops.debug_callback(partial(print_mean_std, "Logits"), logits) - # (logits,) = jsa.ops.debug_callback_grad(partial(print_mean_std, "LogitsGrad"), logits) - # Dynamic rescaling of the gradient, as logits gradient not properly scaled. logits = jsa.ops.dynamic_rescale_l2_grad(logits) output = logits - logsumexp(logits, axis=1, keepdims=True) @@ -102,7 +97,7 @@ def accuracy(params, batch): batch_size = 128 scale_dtype = np.float32 - train_images, train_labels, test_images, test_labels = datasets.cifar() + train_images, train_labels, test_images, test_labels = dataset_cifar10.cifar() num_train = train_images.shape[0] num_complete_batches, leftover = divmod(num_train, batch_size) num_batches = num_complete_batches + bool(leftover) diff --git a/examples/cifar10/dataset_cifar10.py b/examples/cifar10/dataset_cifar10.py new file mode 100644 index 0000000..38fe6be --- /dev/null +++ b/examples/cifar10/dataset_cifar10.py @@ -0,0 +1,154 @@ +# Copyright 2018 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Modified by Graphcore Ltd 2024. 
+ +"""Datasets used in examples.""" + + +import array +import gzip +import os +import pickle +import struct +import tarfile +import urllib.request +from os import path + +import numpy as np + +_DATA = "/tmp/jax_example_data/" + + +def _download(url, filename): + """Download a url to a file in the JAX data temp directory.""" + if not path.exists(_DATA): + os.makedirs(_DATA) + out_file = path.join(_DATA, filename) + if not path.isfile(out_file): + urllib.request.urlretrieve(url, out_file) + print(f"downloaded {url} to {_DATA}") + + +def _partial_flatten(x): + """Flatten all but the first dimension of an ndarray.""" + return np.reshape(x, (x.shape[0], -1)) + + +def _one_hot(x, k, dtype=np.float32): + """Create a one-hot encoding of x of size k.""" + return np.array(x[:, None] == np.arange(k), dtype) + + +def _unzip(file): + file = tarfile.open(file) + file.extractall(_DATA) + file.close() + return + + +def _unpickle(file): + with open(file, "rb") as fo: + dict = pickle.load(fo, encoding="bytes") + return dict + + +def mnist_raw(): + """Download and parse the raw MNIST dataset.""" + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + base_url = "https://storage.googleapis.com/cvdf-datasets/mnist/" + + def parse_labels(filename): + with gzip.open(filename, "rb") as fh: + _ = struct.unpack(">II", fh.read(8)) + return np.array(array.array("B", fh.read()), dtype=np.uint8) + + def parse_images(filename): + with gzip.open(filename, "rb") as fh: + _, num_data, rows, cols = struct.unpack(">IIII", fh.read(16)) + return np.array(array.array("B", fh.read()), dtype=np.uint8).reshape(num_data, rows, cols) + + for filename in [ + "train-images-idx3-ubyte.gz", + "train-labels-idx1-ubyte.gz", + "t10k-images-idx3-ubyte.gz", + "t10k-labels-idx1-ubyte.gz", + ]: + _download(base_url + filename, filename) + + train_images = parse_images(path.join(_DATA, "train-images-idx3-ubyte.gz")) + train_labels = parse_labels(path.join(_DATA, "train-labels-idx1-ubyte.gz")) + test_images = parse_images(path.join(_DATA, "t10k-images-idx3-ubyte.gz")) + test_labels = parse_labels(path.join(_DATA, "t10k-labels-idx1-ubyte.gz")) + + return train_images, train_labels, test_images, test_labels + + +def mnist(permute_train=False): + """Download, parse and process MNIST data to unit scale and one-hot labels.""" + train_images, train_labels, test_images, test_labels = mnist_raw() + + train_images = _partial_flatten(train_images) / np.float32(255.0) + test_images = _partial_flatten(test_images) / np.float32(255.0) + train_labels = _one_hot(train_labels, 10) + test_labels = _one_hot(test_labels, 10) + + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] + + return train_images, train_labels, test_images, test_labels + + +def cifar_raw(): + """Download, unzip and parse the raw cifar dataset.""" + + filename = "cifar-10-python.tar.gz" + url = "https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz" + _download(url, filename) + _unzip(path.join(_DATA, filename)) + + data_batches = ["data_batch_1", "data_batch_2", "data_batch_3", "data_batch_4", "data_batch_5"] + data = [] + labels = [] + for batch in data_batches: + tmp_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", batch)) + data.append(tmp_dict[b"data"]) + labels.append(tmp_dict[b"labels"]) + train_images = np.concatenate(data) + train_labels = np.concatenate(labels) + + test_dict = _unpickle(path.join(_DATA, "cifar-10-batches-py", "test_batch")) + test_images = 
test_dict[b"data"] + test_labels = np.array(test_dict[b"labels"]) + + return train_images, train_labels, test_images, test_labels + + +def cifar(permute_train=False): + """Download, parse and process cifar data to unit scale and one-hot labels.""" + + train_images, train_labels, test_images, test_labels = cifar_raw() + + train_images = train_images / np.float32(255.0) + test_images = test_images / np.float32(255.0) + train_labels = _one_hot(train_labels, 10) + test_labels = _one_hot(test_labels, 10) + + if permute_train: + perm = np.random.RandomState(0).permutation(train_images.shape[0]) + train_images = train_images[perm] + train_labels = train_labels[perm] + + return train_images, train_labels, test_images, test_labels diff --git a/experiments/mnist/datasets.py b/examples/mnist/datasets.py similarity index 99% rename from experiments/mnist/datasets.py rename to examples/mnist/datasets.py index 5f8eff3..e359955 100644 --- a/experiments/mnist/datasets.py +++ b/examples/mnist/datasets.py @@ -11,6 +11,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified by Graphcore Ltd 2024. + """Datasets used in examples.""" diff --git a/experiments/mnist/flax_example/README.md b/examples/mnist/flax/README.md similarity index 100% rename from experiments/mnist/flax_example/README.md rename to examples/mnist/flax/README.md diff --git a/experiments/mnist/flax_example/configs/__init__.py b/examples/mnist/flax/configs/__init__.py similarity index 100% rename from experiments/mnist/flax_example/configs/__init__.py rename to examples/mnist/flax/configs/__init__.py diff --git a/experiments/mnist/flax_example/configs/default.py b/examples/mnist/flax/configs/default.py similarity index 100% rename from experiments/mnist/flax_example/configs/default.py rename to examples/mnist/flax/configs/default.py diff --git a/experiments/mnist/flax_example/main.py b/examples/mnist/flax/main.py similarity index 95% rename from experiments/mnist/flax_example/main.py rename to examples/mnist/flax/main.py index 29cd53e..061cc89 100644 --- a/experiments/mnist/flax_example/main.py +++ b/examples/mnist/flax/main.py @@ -19,7 +19,8 @@ """ import jax -import tensorflow as tf + +# import tensorflow as tf import train from absl import app, flags, logging from clu import platform @@ -42,7 +43,7 @@ def main(argv): # Hide any GPUs from TensorFlow. Otherwise TF might reserve memory and make # it unavailable to JAX. 
- tf.config.experimental.set_visible_devices([], "GPU") + # tf.config.experimental.set_visible_devices([], "GPU") logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count()) logging.info("JAX local devices: %r", jax.local_devices()) diff --git a/experiments/mnist/flax_example/requirements.txt b/examples/mnist/flax/requirements.txt similarity index 100% rename from experiments/mnist/flax_example/requirements.txt rename to examples/mnist/flax/requirements.txt diff --git a/experiments/mnist/flax_example/train.py b/examples/mnist/flax/train.py similarity index 92% rename from experiments/mnist/flax_example/train.py rename to examples/mnist/flax/train.py index 3c09e2c..8e6e737 100644 --- a/experiments/mnist/flax_example/train.py +++ b/examples/mnist/flax/train.py @@ -28,8 +28,9 @@ import optax import tensorflow_datasets as tfds from absl import logging -from flax import linen as nn -from flax.metrics import tensorboard +from flax import linen as nn # type:ignore + +# from flax.metrics import tensorboard from flax.training import train_state import jax_scalify as jsa @@ -143,8 +144,8 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train train_ds, test_ds = get_datasets() rng = jax.random.key(0) - summary_writer = tensorboard.SummaryWriter(workdir) - summary_writer.hparams(dict(config)) + # summary_writer = tensorboard.SummaryWriter(workdir) + # summary_writer.hparams(dict(config)) rng, init_rng = jax.random.split(rng) init_rng = jax.random.PRNGKey(1) @@ -173,10 +174,10 @@ def train_and_evaluate(config: ml_collections.ConfigDict, workdir: str) -> train ) ) - summary_writer.scalar("train_loss", train_loss, epoch) - summary_writer.scalar("train_accuracy", train_accuracy, epoch) - summary_writer.scalar("test_loss", test_loss, epoch) - summary_writer.scalar("test_accuracy", test_accuracy, epoch) + # summary_writer.scalar("train_loss", train_loss, epoch) + # summary_writer.scalar("train_accuracy", train_accuracy, epoch) + # summary_writer.scalar("test_loss", test_loss, epoch) + # summary_writer.scalar("test_accuracy", test_accuracy, epoch) - summary_writer.flush() + # summary_writer.flush() return state diff --git a/experiments/mnist/mnist_classifier_from_scratch.py b/examples/mnist/mnist_classifier_from_scratch.py similarity index 96% rename from experiments/mnist/mnist_classifier_from_scratch.py rename to examples/mnist/mnist_classifier_from_scratch.py index 69cab06..fb54341 100644 --- a/experiments/mnist/mnist_classifier_from_scratch.py +++ b/examples/mnist/mnist_classifier_from_scratch.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified by Graphcore Ltd 2024. """A basic MNIST example using Numpy and JAX. @@ -78,9 +79,9 @@ def accuracy(params, batch): if __name__ == "__main__": - layer_sizes = [784, 1024, 1024, 10] - param_scale = 1.0 - step_size = 0.001 + layer_sizes = [784, 512, 512, 10] + param_scale = 0.1 + step_size = 0.1 num_epochs = 10 batch_size = 128 @@ -125,7 +126,7 @@ def update(params, batch): epoch_time = time.time() - start_time - # Evaluation in float32, for consistency. + # Evaluation in normal/unscaled float32, for consistency. 
raw_params = jsa.asarray(params, dtype=np.float32) train_acc = accuracy(raw_params, (train_images, train_labels)) test_acc = accuracy(raw_params, (test_images, test_labels)) diff --git a/experiments/mnist/mnist_classifier_from_scratch_fp8.py b/examples/mnist/mnist_classifier_from_scratch_fp8.py similarity index 97% rename from experiments/mnist/mnist_classifier_from_scratch_fp8.py rename to examples/mnist/mnist_classifier_from_scratch_fp8.py index 01912f3..5f142d1 100644 --- a/experiments/mnist/mnist_classifier_from_scratch_fp8.py +++ b/examples/mnist/mnist_classifier_from_scratch_fp8.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +# Modified by Graphcore Ltd 2024. """A basic MNIST example using Numpy and JAX. @@ -34,6 +35,7 @@ def print_mean_std(name, v): + """Debugging method/tool for JAX Scalify.""" data, scale = jsa.lax.get_data_scale(v) # Always use np.float32, to avoid floating errors in descaling + stats. data = jsa.asarray(data, dtype=np.float32) @@ -105,9 +107,9 @@ def accuracy(params, batch): if __name__ == "__main__": - layer_sizes = [784, 1024, 1024, 10] - param_scale = 1.0 - step_size = 0.001 + layer_sizes = [784, 512, 512, 10] + param_scale = 0.1 + step_size = 0.1 num_epochs = 10 batch_size = 128 diff --git a/experiments/mnist/mnist_classifier.py b/experiments/mnist/mnist_classifier.py deleted file mode 100644 index efb058a..0000000 --- a/experiments/mnist/mnist_classifier.py +++ /dev/null @@ -1,107 +0,0 @@ -# Copyright 2018 The JAX Authors. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""A basic MNIST example using JAX with the mini-libraries stax and optimizers. - -The mini-library jax.example_libraries.stax is for neural network building, and -the mini-library jax.example_libraries.optimizers is for first-order stochastic -optimization. 
-""" - - -import itertools -import time - -import datasets -import jax.numpy as jnp -import numpy as np -import numpy.random as npr -from jax import grad, jit, random -from jax.example_libraries import optimizers, stax -from jax.example_libraries.stax import Dense, LogSoftmax, Relu - -import jax_scalify as jsa - - -def loss(params, batch): - inputs, targets = batch - preds = predict(params, inputs) - return -jnp.mean(jnp.sum(preds * targets, axis=1)) - - -def accuracy(params, batch): - inputs, targets = batch - target_class = jnp.argmax(targets, axis=1) - predicted_class = jnp.argmax(predict(params, inputs), axis=1) - return jnp.mean(predicted_class == target_class) - - -init_random_params, predict = stax.serial(Dense(1024), Relu, Dense(1024), Relu, Dense(10), LogSoftmax) - -if __name__ == "__main__": - rng = random.PRNGKey(0) - - step_size = 0.001 - num_epochs = 10 - batch_size = 128 - momentum_mass = 0.9 - - train_images, train_labels, test_images, test_labels = datasets.mnist() - num_train = train_images.shape[0] - num_complete_batches, leftover = divmod(num_train, batch_size) - num_batches = num_complete_batches + bool(leftover) - - def data_stream(): - rng = npr.RandomState(0) - while True: - perm = rng.permutation(num_train) - for i in range(num_batches): - batch_idx = perm[i * batch_size : (i + 1) * batch_size] - yield train_images[batch_idx], train_labels[batch_idx] - - batches = data_stream() - - opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass) - - @jit - @jsa.scalify - def update(i, opt_state, batch): - params = get_params(opt_state) - return opt_update(i, grad(loss)(params, batch), opt_state) - - _, init_params = init_random_params(rng, (-1, 28 * 28)) - opt_state = opt_init(init_params) - itercount = itertools.count() - # Convert weights + optimizer state to scaled arrays (assuming unit scaling initialization). - opt_state = jsa.as_scaled_array(opt_state, scale=np.float32(1.0)) - - print("\nStarting training...") - for epoch in range(num_epochs): - start_time = time.time() - - for _ in range(num_batches): - batch = next(batches) - # Convert batch to ScaledArray (assuming proper normalized data). - batch = jsa.as_scaled_array(batch, scale=np.float32(1.0)) - opt_state = update(next(itercount), opt_state, batch) - epoch_time = time.time() - start_time - - params = get_params(opt_state) - # Evaluate model without scaling. - params = jsa.asarray(params) - train_acc = accuracy(params, (train_images, train_labels)) - test_acc = accuracy(params, (test_images, test_labels)) - print(f"Epoch {epoch} in {epoch_time:0.2f} sec") - print(f"Training set accuracy {train_acc}") - print(f"Test set accuracy {test_acc}") diff --git a/experiments/nanogpt/model.py b/experiments/nanogpt/model.py deleted file mode 100644 index d535346..0000000 --- a/experiments/nanogpt/model.py +++ /dev/null @@ -1,337 +0,0 @@ -""" -Full definition of a GPT Language Model, all of it in this single file. 
-References: -1) the official GPT-2 TensorFlow implementation released by OpenAI: -https://github.com/openai/gpt-2/blob/master/src/model.py -2) huggingface/transformers PyTorch implementation: -https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py -""" - -from dataclasses import dataclass -from typing import Any, Optional, Tuple - -import flax.linen as nn -import jax -import jax.numpy as jnp -import optax -from flax import traverse_util -from flax.core import freeze -from flax.training import train_state -from flax.traverse_util import path_aware_map - - -@dataclass -class GPTConfig: - block_size: int = 1024 - vocab_size: int = 50257 - n_layer: int = 12 - n_head: int = 12 - n_embd: int = 768 - dropout: float = 0.1 - - -class CausalSelfAttention(nn.Module): - config: GPTConfig - - def setup(self): - config = self.config - assert config.n_embd % config.n_head == 0 - # head_size = config.n_embd // config.n_head - # key, query, value projections for all heads, but in a batch - self.c_attn = nn.Dense(config.n_embd * 3) - # output projection - self.c_proj = nn.Dense(config.n_embd) - # regularization - self.attn_dropout = nn.Dropout(config.dropout) - self.resid_dropout = nn.Dropout(config.dropout) - # causal mask to ensure that attention is only applied to the left in the input sequence - self.n_head = config.n_head - self.n_embd = config.n_embd - - def __call__(self, x: jax.Array, *, train: bool) -> jax.Array: - B, T, C = x.shape # batch size, sequence length, embedding dimensionality (n_embd) - - # calculate query, key, values for all heads in batch and move head forward to be the batch dim - qkv = self.c_attn(x) - q, k, v = jnp.split(qkv, 3, axis=-1) - q = q.reshape(B, T, self.n_head, C // self.n_head).swapaxes(1, 2) # (B, nh, T, hs) - k = k.reshape(B, T, self.n_head, C // self.n_head).swapaxes(1, 2) # (B, nh, T, hs) - v = v.reshape(B, T, self.n_head, C // self.n_head).swapaxes(1, 2) # (B, nh, T, hs) - - mask = jnp.tril(jnp.ones((T, T))).reshape((1, 1, T, T)) - - # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) - att = (q @ k.swapaxes(-2, -1)) * (1.0 / jnp.sqrt(k.shape[-1])) - att = jnp.where(mask == 0, float("-inf"), att) - att = nn.softmax(att, axis=-1) - att = self.attn_dropout(att, deterministic=not train) - y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) - y = y.swapaxes(1, 2).reshape(B, T, C) # re-assemble all head outputs side by side - # output projection - y = self.resid_dropout(self.c_proj(y), deterministic=not train) - return y - - -class MLP(nn.Module): - config: GPTConfig - - def setup(self): - config = self.config - self.c_fc = nn.Dense(4 * config.n_embd) - self.c_proj = nn.Dense(config.n_embd) - self.dropout = nn.Dropout(config.dropout) - - def __call__(self, x: jax.Array, *, train: bool) -> jax.Array: - x = self.c_fc(x) - x = nn.gelu(x, approximate=True) - x = self.c_proj(x) - x = self.dropout(x, deterministic=not train) - return x - - -class Block(nn.Module): - config: GPTConfig - - def setup(self): - config = self.config - self.ln_1 = nn.LayerNorm(epsilon=1e-5) - self.attn = CausalSelfAttention(config) - self.ln_2 = nn.LayerNorm(epsilon=1e-5) - self.mlp = MLP(config) - - def __call__(self, x: jax.Array, *, train: bool) -> jax.Array: - x = x + self.attn(self.ln_1(x), train=train) - x = x + self.mlp(self.ln_2(x), train=train) - return x - - -class GPT(nn.Module): - config: GPTConfig - - def setup(self): - config = self.config - assert config.vocab_size is not None - assert 
config.block_size is not None - - self.wte = nn.Embed(config.vocab_size, config.n_embd) - self.wpe = nn.Embed(config.block_size, config.n_embd) - self.drop = nn.Dropout(config.dropout) - self.h = [Block(config) for _ in range(config.n_layer)] - self.ln_f = nn.LayerNorm() - - def __call__(self, idx: jax.Array, *, train: bool, targets: Optional[jax.Array] = None) -> jax.Array: - b, t = idx.shape - assert ( - t <= self.config.block_size - ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - pos = jnp.arange(0, t, dtype=jnp.int32)[None] # shape (1, t) - - # forward the GPT model itself - tok_emb = self.wte(idx) # token embeddings of shape (b, t, n_embd) - pos_emb = self.wpe(pos) # position embeddings of shape (1, t, n_embd) - x = self.drop(tok_emb + pos_emb, deterministic=not train) - for block in self.h: - x = block(x, train=train) - x = self.ln_f(x) - - logits = self.wte.attend(x) - - if targets is not None: - # if we are given some desired targets also calculate the loss - loss = optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean() - else: - loss = None - - return logits, loss - - def crop_block_size(self, params: Any, block_size: int) -> Any: - # model surgery to decrease the block size if necessary - # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) - # but want to use a smaller block size for some smaller, simpler model - - assert block_size <= self.config.block_size - self.config.block_size = block_size - - # self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) - def crop_weights(path: Tuple[str, ...], x: Any) -> Any: - if path[-2:] == ("wpe", "embedding"): - return x[:block_size] - return x - - return freeze(path_aware_map(crop_weights, params)) - - @classmethod - def from_pretrained(cls, model_type, override_args=None): - assert model_type in {"gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"} - override_args = override_args or {} # default to empty dict - # only dropout can be overridden see more notes below - assert all(k == "dropout" for k in override_args) - from transformers import GPT2LMHeadModel - - print("loading weights from pretrained gpt: %s" % model_type) - - # n_layer, n_head and n_embd are determined from model_type - config_args = { - "gpt2": dict(n_layer=12, n_head=12, n_embd=768), # 124M params - "gpt2-medium": dict(n_layer=24, n_head=16, n_embd=1024), # 350M params - "gpt2-large": dict(n_layer=36, n_head=20, n_embd=1280), # 774M params - "gpt2-xl": dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params - }[model_type] - # we can override the dropout rate - if "dropout" in override_args: - config_args["dropout"] = override_args["dropout"] - # block_size is always 1024 for GPT model checkpoints - # if one wants a lower block_size it has to be done through model surgery - # later, by calling crop_block_shape - - # create a from-scratch initialized minGPT model - config = GPTConfig(block_size=1024, **config_args) - model = GPT(config) - variables = jax.eval_shape( - lambda: model.init(jax.random.PRNGKey(0), jnp.ones((1, 1), dtype=jnp.int32), train=False) - ) - params = variables["params"] - flat_params = traverse_util.flatten_dict(params, sep=".") - - # init a huggingface/transformers model - model_hf = GPT2LMHeadModel.from_pretrained(model_type) - sd_hf = model_hf.state_dict() - - def copy_from(flax_name, pt_name, transpose=False, add_head_dim=False): - pt_tensor = sd_hf[pt_name] - jax_array = flat_params[flax_name] - if transpose: - pt_tensor = pt_tensor.t() 
- pt_array = pt_tensor.detach().cpu().numpy() - - if add_head_dim: - # pt_array = pt_array.reshape(*pt_array.shape[:-1], config.n_head, -1, 3) - pass - - assert pt_array.shape == jax_array.shape - - flat_params[flax_name] = pt_array - - # transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] - copy_from("wte.embedding", "transformer.wte.weight") - copy_from("wpe.embedding", "transformer.wpe.weight") - - for i in range(config.n_layer): - copy_from(f"h_{i}.ln_1.scale", f"transformer.h.{i}.ln_1.weight") - copy_from(f"h_{i}.ln_1.bias", f"transformer.h.{i}.ln_1.bias") - copy_from(f"h_{i}.attn.c_attn.kernel", f"transformer.h.{i}.attn.c_attn.weight", add_head_dim=True) - copy_from(f"h_{i}.attn.c_attn.bias", f"transformer.h.{i}.attn.c_attn.bias", add_head_dim=True) - copy_from(f"h_{i}.attn.c_proj.kernel", f"transformer.h.{i}.attn.c_proj.weight") - copy_from(f"h_{i}.attn.c_proj.bias", f"transformer.h.{i}.attn.c_proj.bias") - copy_from(f"h_{i}.ln_2.scale", f"transformer.h.{i}.ln_2.weight") - copy_from(f"h_{i}.ln_2.bias", f"transformer.h.{i}.ln_2.bias") - copy_from(f"h_{i}.mlp.c_fc.kernel", f"transformer.h.{i}.mlp.c_fc.weight") - copy_from(f"h_{i}.mlp.c_fc.bias", f"transformer.h.{i}.mlp.c_fc.bias") - copy_from(f"h_{i}.mlp.c_proj.kernel", f"transformer.h.{i}.mlp.c_proj.weight") - copy_from(f"h_{i}.mlp.c_proj.bias", f"transformer.h.{i}.mlp.c_proj.bias") - - copy_from("ln_f.scale", "transformer.ln_f.weight") - copy_from("ln_f.bias", "transformer.ln_f.bias") - - params = freeze(traverse_util.unflatten_dict(flat_params, sep=".")) - - return model, params - - def configure_optimizers(self, params, weight_decay, learning_rate, betas): - """ - This long function is unfortunately doing something very simple and is being very defensive: - We are separating out all parameters of the model into two buckets: those that will experience - weight decay for regularization and those that won't (biases, and layernorm/embedding weights). - We are then returning the PyTorch optimizer object. - """ - - def get_optimizer(decay): - return optax.adamw(learning_rate=learning_rate, b1=betas[0], b2=betas[1], weight_decay=decay) - - def partition_fn(path: Tuple[str, ...], x: Any) -> str: - if path[-1] in ("bias", "scale", "embedding"): - return "no_decay" - elif path[-1] in ("kernel",): - return "decay" - else: - raise ValueError(f"Unrecognized parameter: {path}") - - partition_optimizers = {"decay": get_optimizer(weight_decay), "no_decay": get_optimizer(0.0)} - param_partitions = freeze(path_aware_map(partition_fn, params)) - tx = optax.multi_transform(partition_optimizers, param_partitions) - - return tx - - # @torch.no_grad() - def generate(self, key, params, input_tokens, max_new_tokens, temperature=1.0, top_k=None): - """ - Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete - the sequence max_new_tokens times, feeding the predictions back into the model each time. - Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
- """ - B, T = input_tokens.shape - padding = jnp.zeros((B, max_new_tokens), dtype=jnp.int32) - tokens = jnp.concatenate([input_tokens, padding], axis=-1) - indexes = jnp.arange(T, T + max_new_tokens) - - # tokens index -> tokens None - def scan_f(tokens, i): - # l: x y - # t: a b - - - # i: 0 1 2 3 - step_key = jax.random.fold_in(key, i) - # if the sequence context is growing too long we must crop it at block_size - # idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] - # forward the model to get the logits for the index in the sequence - logits, _ = self.apply({"params": params}, tokens, train=False) - # pluck the logits at the final step and scale by desired temperature - logits = logits[:, i - 1, :] / temperature - # optionally crop the logits to only the top k options - # sample from the distribution - if top_k is not None: - top_logits, top_tokens = jax.lax.top_k(logits, min(top_k, logits.shape[-1])) - token_idx = jax.random.categorical(step_key, top_logits, axis=-1) - next_token = jnp.take_along_axis(top_tokens, token_idx[:, None], axis=-1).squeeze(-1) - else: - next_token = jax.random.categorical(step_key, logits, axis=-1) - # logits = jnp.where(logits < v[:, -1:], float('-inf'), logits) - # append sampled index to the running sequence and continue - tokens = tokens.at[:, i].set(next_token) - - return tokens, None - - tokens, _ = jax.lax.scan(scan_f, tokens, indexes) - - return tokens - - def create_state( - self, - learning_rate, - weight_decay, - beta1, - beta2, - decay_lr=None, - warmup_iters=None, - lr_decay_iters=None, - min_lr=None, - params=None, - **kwargs, - ): - if params is None: - variables = self.init(jax.random.PRNGKey(0), jnp.ones((1, 1), dtype=jnp.int32), train=False) - params = variables["params"] - if decay_lr: - assert warmup_iters is not None and lr_decay_iters is not None and min_lr is not None - lr_schedule = optax.warmup_cosine_decay_schedule( - init_value=0.0, - peak_value=learning_rate, - warmup_steps=warmup_iters, - decay_steps=lr_decay_iters, - end_value=min_lr, - ) - else: - lr_schedule = learning_rate - tx = self.configure_optimizers( - params, weight_decay=weight_decay, learning_rate=lr_schedule, betas=(beta1, beta2) - ) - return train_state.TrainState.create(apply_fn=self.apply, params=params, tx=tx) diff --git a/experiments/nanogpt/model_info.py b/experiments/nanogpt/model_info.py deleted file mode 100644 index 6400330..0000000 --- a/experiments/nanogpt/model_info.py +++ /dev/null @@ -1,28 +0,0 @@ -"""Testing NanoGPT JAX model definition. - -Inspired by: https://github.com/cgarciae/nanoGPT-jax/blob/master/train.py -""" - -import jax -import jax.numpy as jnp -from model import GPT, GPTConfig - -gpt2_tiny = GPTConfig(block_size=128, vocab_size=32000, n_layer=2, n_head=8, n_embd=512) -train_config = dict( - learning_rate=0.001, - weight_decay=0.1, - beta1=1, - beta2=1, -) - -rng_key = jax.random.PRNGKey(0) -init_value = jnp.ones((1, 1), dtype=jnp.int32) - -model = GPT(gpt2_tiny) -# initialize weights -# state = model.create_state(**train_config) -params = model.init(rng_key, init_value, train=False) -print("Model initialized...") - -# Model description -print(model.tabulate(rng_key, init_value, train=False)) diff --git a/pyproject.toml b/pyproject.toml index a88e1a1..083e0e1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,3 +69,4 @@ disallow_incomplete_defs = true # # disallow_subclassing_any = true # # for strict mypy: (this is the tricky one :-)) # disallow_untyped_defs = true +exclude = ['examples']