diff --git a/README.md b/README.md index ccc4c6e..bc54599 100644 --- a/README.md +++ b/README.md @@ -25,4 +25,5 @@ sys.path.insert(0, os.path.expanduser("~/micrograd-pp/python")) ## Examples * [Train a simple feedforward neural network on MNIST to classify handwritten digits](https://nbviewer.org/github/parsiad/micrograd-pp/blob/main/examples/mnist.ipynb) -* [Learn an n-gram model](https://nbviewer.org/github/parsiad/micrograd-pp/blob/main/examples/n-gram.ipynb) +* [Learn an n-gram model to generate text](https://nbviewer.org/github/parsiad/micrograd-pp/blob/main/examples/n-gram.ipynb) +* [Train a decoder-only transformer to generate text](https://nbviewer.org/github/parsiad/micrograd-pp/blob/main/examples/transformer.ipynb) diff --git a/examples/transformer.ipynb b/examples/transformer.ipynb new file mode 100644 index 0000000..1d259e0 --- /dev/null +++ b/examples/transformer.ipynb @@ -0,0 +1,455 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c56226d9-be44-4db2-ae09-04f355dced98", + "metadata": {}, + "source": [ + "# Transformer\n", + "\n", + "This notebook trains a decoder-only transformer to perform next-token prediction on the Tiny Shakespeare dataset.\n", + "\n", + "This should be considered a purely educational exercise: since Micrograd++ does not yet support GPU training, the parameters are kept fairly restrictive (e.g., a short context and a small validation set) to make CPU training tolerable." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31c06100-f075-403d-91e2-60bc4bd142a1", + "metadata": {}, + "outputs": [], + "source": [ + "import micrograd_pp as mpp\n", + "import numpy as np\n", + "import numpy.typing as npt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "9e15b90d-c80a-4bc5-8cca-f2fd4655fb3d", + "metadata": {}, + "outputs": [], + "source": [ + "CONTEXT_WIDTH = 32\n", + "DROPOUT = 0.0\n", + "EMBEDDING_DIM = 128\n", + "EVAL_FREQ = 500\n", + "HIDDEN_SIZE = EMBEDDING_DIM * 4\n", + "LEARNING_RATE = 0.1\n", + "NUM_BLOCKS = 3\n", + "NUM_HEADS = 4\n", + "NUM_ITERS = 10_000\n", + "TRAIN_BATCH_SIZE = 96\n", + "TRAIN_FRAC = 0.99\n", + "VAL_BATCH_SIZE = 4_096" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "4e23ae66-1627-438f-8de7-d09524bf6b36", + "metadata": {}, + "outputs": [], + "source": [ + "text = mpp.datasets.load_tiny_shakespeare()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "288a0a7d-c6ef-4fd8-8716-c390bc0ca3cf", + "metadata": {}, + "outputs": [], + "source": [ + "vocab = sorted(set(text))\n", + "vocab_size = len(vocab)\n", + "\n", + "char2token = {char: token for token, char in enumerate(vocab)}\n", + "all_tokens = np.array([char2token[char] for char in text], dtype=np.int32)\n", + "\n", + "first_val_index = int(TRAIN_FRAC * all_tokens.size)\n", + "train_tokens = all_tokens[:first_val_index]\n", + "val_tokens = all_tokens[first_val_index:]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "0b48de65-5e4c-4b2d-a5f5-55c3b0811fb2", + "metadata": {}, + "outputs": [], + "source": [ + "class Block:\n", + " def __init__(self) -> None:\n", + " self._ln1 = mpp.LayerNorm(EMBEDDING_DIM)\n", + " self._attn = mpp.MultiheadAttention(\n", + " embed_dim=EMBEDDING_DIM,\n", + " num_heads=NUM_HEADS,\n", + " batch_first=True\n", + " )\n", + " attn_mask_np = np.zeros((CONTEXT_WIDTH, CONTEXT_WIDTH))\n", + " attn_mask_np[np.triu_indices_from(attn_mask_np, k=1)] = -np.inf\n", + " self._attn_mask = mpp.Constant(attn_mask_np)\n", + " self._dropout = 
mpp.Dropout(DROPOUT)\n", + " self._ln2 = mpp.LayerNorm(EMBEDDING_DIM)\n", + " self._ff = mpp.Sequential(\n", + " mpp.Linear(in_features=EMBEDDING_DIM, out_features=HIDDEN_SIZE),\n", + " mpp.ReLU(),\n", + " mpp.Linear(in_features=HIDDEN_SIZE, out_features=EMBEDDING_DIM),\n", + " mpp.Dropout(DROPOUT),\n", + " )\n", + "\n", + " def __call__(\n", + " self,\n", + " x: mpp.Expr # (N, L, E)\n", + " ) -> mpp.Expr:\n", + " x = self._ln1(x)\n", + " x = x + self._dropout(self._attn(x, x, x, attn_mask=self._attn_mask)[0])\n", + " x = self._ln2(x)\n", + " x = x + self._ff(x)\n", + " return x # (N, L, E)\n", + "\n", + "class DecoderOnlyTransformer:\n", + " def __init__(self) -> None:\n", + " self._tok_embedding = mpp.Embedding(\n", + " num_embeddings=vocab_size,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " label=\"token_embedding\",\n", + " )\n", + " self._pos_embedding = mpp.Embedding(\n", + " num_embeddings=CONTEXT_WIDTH,\n", + " embedding_dim=EMBEDDING_DIM,\n", + " label=\"positional_embedding\",\n", + " )\n", + " self._blocks = mpp.Sequential(*[Block() for _ in range(NUM_BLOCKS)])\n", + " self._ln = mpp.LayerNorm(EMBEDDING_DIM)\n", + " self._output_proj = mpp.Linear(\n", + " in_features=EMBEDDING_DIM,\n", + " out_features=vocab_size,\n", + " label=\"output_projection\",\n", + " )\n", + "\n", + " def __call__(\n", + " self,\n", + " tokens: npt.NDArray, # (N, L)\n", + " ) -> mpp.Expr:\n", + " t = self._tok_embedding(tokens) # (N, L, E)\n", + " p = self._pos_embedding(np.arange(CONTEXT_WIDTH)) # (L, E)\n", + " x = t + p # (N, L, E)\n", + " x = self._blocks(x) # (N, L, E)\n", + " x = self._ln(x) # (N, L, E)\n", + " return self._output_proj(x) # (N, L, V)\n", + "\n", + "def loss(model: mpp.Module, indices: npt.NDArray, user_data: npt.NDArray) -> mpp.Expr:\n", + " \"\"\"Compute loss on a random batch.\"\"\"\n", + " x = np.stack([user_data[index - CONTEXT_WIDTH :index ] for index in indices]) # (N, L)\n", + " y = np.stack([user_data[index - CONTEXT_WIDTH + 1:index + 1] for index in indices]) # (N, L)\n", + " yhat = model(x).reshape((-1, vocab_size)) # (N * L, V)\n", + " y = y.reshape(-1) # (N * L,)\n", + " return mpp.cross_entropy_loss(yhat, y)\n", + "\n", + "def train_loss(model: mpp.Module) -> mpp.Expr:\n", + " \"\"\"Compute loss on a random batch from the training set.\"\"\"\n", + " indices = np.random.randint(low=CONTEXT_WIDTH, high=train_tokens.size, size=(TRAIN_BATCH_SIZE,))\n", + " return loss(model=model, indices=indices, user_data=train_tokens)\n", + "\n", + "def val_loss(model: mpp.Module) -> np.floating:\n", + " \"\"\"Approximate loss on the validation set.\"\"\"\n", + " losses = []\n", + " with mpp.eval(), mpp.no_grad():\n", + " n = 0\n", + " low = CONTEXT_WIDTH\n", + " while low < val_tokens.size:\n", + " high = min(low + VAL_BATCH_SIZE, val_tokens.size)\n", + " indices = np.arange(low, high)\n", + " item = loss(model=model, indices=indices, user_data=val_tokens).value\n", + " losses.append(item)\n", + " low = high\n", + " return np.array(losses).mean()\n", + "\n", + "def generate_sentence(model: mpp.Module, init: npt.NDArray | None = None, length: int = 64) -> str:\n", + " \"\"\"Use a learned decoder-only transformer to generate a sentence.\"\"\"\n", + " with mpp.eval(), mpp.no_grad():\n", + " if init is None:\n", + " init = np.zeros((CONTEXT_WIDTH,), dtype=np.int32)\n", + " context = init.copy() # copy so the caller's array is not mutated in place\n", + " tokens = []\n", + " for _ in range(length):\n", + " logits = model(context.reshape(1, -1))\n", + " pvals = mpp.softmax(logits, dim=-1)[0, -1, :]\n", + " token = np.random.multinomial(n=1, 
pvals=pvals.value).argmax().item()\n", + " context[:-1] = context[1:]\n", + " context[-1] = token\n", + " tokens.append(token)\n", + " return ''.join(vocab[token] for token in tokens)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "78207a31-d5a7-4fc3-b86f-ee93b6804bb0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Uninitialized Embedding\n", + "-----------------------\n", + "Loss: 4.411328242953009\n", + "Random sentence: oxNr\n", + "sMy,C\n", + "AIseiPmrraFfaFMWHwlHsiFjMA&fxRH,pNhixpRM$EEYSPsxPfEoD\n", + "\n" + ] + } + ], + "source": [ + "np.random.seed(0)\n", + "model = DecoderOnlyTransformer()\n", + "\n", + "print(f\"\"\"\n", + "Uninitialized Embedding\n", + "-----------------------\n", + "Loss: {val_loss(model).item()}\n", + "Random sentence: {generate_sentence(model)}\n", + "\"\"\")" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "64662026-ea81-4c9e-aede-0dbf484424a2", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Iteration 0\n", + "------------------\n", + "Loss: 4.411328242953009\n", + "Random sentence: Lxs$DhMcWFQxPscUC$W!EosAUsANgL\n", + "vjY$NCRdDOqQqZuRICWC$EjNGboej&Pc!\n", + "\n", + "\n", + "Iteration 500\n", + "------------------\n", + "Loss: 2.533599337553858\n", + "Random sentence: GElaRKeeris, r t be ailll hee wotur:ye ed:\n", + "The,\n", + "$ y?Angrs;e ber \n", + "\n", + "\n", + "Iteration 1000\n", + "------------------\n", + "Loss: 2.3411745737318306\n", + "Random sentence: ANESTENYANCENGAET:\n", + "Whimy Whe:\n", + "Noorind Pind anir, neamig?\n", + "\n", + "\n", + "\n", + "PYIG\n", + "\n", + "\n", + "Iteration 1500\n", + "------------------\n", + "Loss: 2.2217165534240615\n", + "Random sentence: NULENUBELIUCER:\n", + "NRy theme?H my wheass as mpuges pulons.\n", + "\n", + "YCDASTR\n", + "\n", + "\n", + "Iteration 2000\n", + "------------------\n", + "Loss: 2.1437060877395724\n", + "Random sentence: YUETV:\n", + "Diz ViBA freters be sire neeed.\n", + "\n", + "VOREY Awhy lege\n", + "That shi\n", + "\n", + "\n", + "Iteration 2500\n", + "------------------\n", + "Loss: 2.089483798097002\n", + "Random sentence: OQENGUEELEES:\n", + "AR I, Wean not gerd shall now amby hushe im.\n", + "And s\n", + "\n", + "\n", + "Iteration 3000\n", + "------------------\n", + "Loss: 2.072011124940682\n", + "Random sentence: LETYUGHE'R:\n", + "O, go shal, my sil; Sill be to dess,\n", + "shed fee acher,\n", + "\n", + "\n", + "Iteration 3500\n", + "------------------\n", + "Loss: 2.020927447962387\n", + "Random sentence: MICHARY:\n", + "He bard! 
nont, ins; all time in of my herdonds, are his\n", + "\n", + "\n", + "Iteration 4000\n", + "------------------\n", + "Loss: 1.9881086504413765\n", + "Random sentence: NLARUDYIUB:\n", + "Pet so, in the the sehrouds retch,\n", + "My band, Brans an\n", + "\n", + "\n", + "Iteration 4500\n", + "------------------\n", + "Loss: 1.9580338393000023\n", + "Random sentence: TAUTELLENUS:\n", + "No goord thyssingies are.\n", + "\n", + "GRO:\n", + "Shall his irs ritot\n", + "\n", + "\n", + "Iteration 5000\n", + "------------------\n", + "Loss: 1.958173437010472\n", + "Random sentence: Secounds sight, eye theak: as the the the Lords.\n", + "\n", + "MNOUCESTER:\n", + "Ay\n", + "\n", + "\n", + "Iteration 5500\n", + "------------------\n", + "Loss: 1.9308816189733313\n", + "Random sentence: RISCAMIO:\n", + "\n", + "RUTOUMER:\n", + "Wood their bown cleading: betougner the is,\n", + "\n", + "\n", + "Iteration 6000\n", + "------------------\n", + "Loss: 1.911466234016684\n", + "Random sentence: MERKETANBE:\n", + "I well proy's onger beentlemance wan meather with ma\n", + "\n", + "\n", + "Iteration 6500\n", + "------------------\n", + "Loss: 1.9035973708617508\n", + "Random sentence: TyNLESTES:\n", + "It a seem'd enors where lend wherefore Cather death, \n", + "\n", + "\n", + "Iteration 7000\n", + "------------------\n", + "Loss: 1.8735182510486765\n", + "Random sentence: Turt:\n", + "Sign burid, then so golve butice't what, muy ever you\n", + "For \n", + "\n", + "\n", + "Iteration 7500\n", + "------------------\n", + "Loss: 1.869562540186564\n", + "Random sentence: MARIS.BARDAP:\n", + "Ist feat breast-raid campts yield\n", + "A clrowd ans; wh\n", + "\n", + "\n", + "Iteration 8000\n", + "------------------\n", + "Loss: 1.85178764820606\n", + "Random sentence: SAMINDALY:\n", + "Godtilive not dide to fromptrator:\n", + "Father worrown inM\n", + "\n", + "\n", + "Iteration 8500\n", + "------------------\n", + "Loss: 1.840864339445898\n", + "Random sentence: ANTELA:\n", + "Notio, or vosper of basband;\n", + "Wern thou grand my lord. 
Th\n", + "\n", + "\n", + "Iteration 9000\n", + "------------------\n", + "Loss: 1.8231779725821042\n", + "Random sentence: Set RIVINCEN York me his valousand\n", + "That iffock to upon him\n", + "Bode \n", + "\n", + "\n", + "Iteration 9500\n", + "------------------\n", + "Loss: 1.815877081143661\n", + "Random sentence: DUivETH's all you that means him; for joys:\n", + "Sick he love.\n", + "Can I \n", + "\n", + "\n", + "Iteration 10000\n", + "------------------\n", + "Loss: 1.7962866239903075\n", + "Random sentence: First MARINA:\n", + "Gentleman.\n", + "\n", + "HERRANIONE:\n", + "Well, me nile my men, whos\n", + "\n" + ] + } + ], + "source": [ + "opt = mpp.SGD(lr=LEARNING_RATE)\n", + "\n", + "n = 0\n", + "while True:\n", + " if n % EVAL_FREQ == 0:\n", + " print(f\"\"\"\n", + "Iteration {n:8d}\n", + "------------------\n", + "Loss: {val_loss(model).item()}\n", + "Random sentence: {generate_sentence(model)}\n", + "\"\"\")\n", + "\n", + " if n >= NUM_ITERS:\n", + " break\n", + "\n", + " train_loss(model).backward(opt=opt)\n", + " opt.step()\n", + "\n", + " n += 1" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.4" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}