From ab3738e456f979b852797602c51f05335469e65d Mon Sep 17 00:00:00 2001
From: Parsiad Azimzadeh
Date: Fri, 16 Aug 2024 22:28:38 -0400
Subject: [PATCH] Add n-gram example

---
 examples/n-gram.ipynb | 317 ++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 317 insertions(+)
 create mode 100644 examples/n-gram.ipynb

diff --git a/examples/n-gram.ipynb b/examples/n-gram.ipynb
new file mode 100644
index 0000000..023e931
--- /dev/null
+++ b/examples/n-gram.ipynb
@@ -0,0 +1,317 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "c56226d9-be44-4db2-ae09-04f355dced98",
+   "metadata": {},
+   "source": [
+    "# n-gram language model\n",
+    "\n",
+    "An [n-gram language model](https://en.wikipedia.org/wiki/Word_n-gram_language_model) is a statistical model of language in which the distribution of the $k$-th token $X_k$ is conditioned on the previous $n - 1$ tokens $X_{k - 1}, \\ldots, X_{k - n + 1}$.\n",
+    "An efficient way to compute an estimator for the probabilities of this model is by [counting n-grams](https://en.wikipedia.org/wiki/Word_n-gram_language_model#Approximation_method).\n",
+    "\n",
+    "This estimator can also be approximated by performing gradient descent on the cross-entropy loss, as is done below.\n",
+    "\n",
+    "Although this is not an efficient way to compute the estimator, it is useful because it demonstrates how a language model without a closed-form solution (e.g., [recurrent neural networks](https://en.wikipedia.org/wiki/Recurrent_neural_network) and [large language models](https://en.wikipedia.org/wiki/Large_language_model)) can be learned."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "31c06100-f075-403d-91e2-60bc4bd142a1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import matplotlib.pyplot as plt\n",
+    "import micrograd_pp as mpp\n",
+    "import numpy as np\n",
+    "import numpy.typing as npt\n",
+    "import scipy.special"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "9e15b90d-c80a-4bc5-8cca-f2fd4655fb3d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "BATCH_SIZE = 32\n",
+    "CONTEXT_WIDTH = 2  # Trigram\n",
+    "NUM_ITERS = 1_000_000\n",
+    "TRAIN_FRAC = 0.9"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "4e23ae66-1627-438f-8de7-d09524bf6b36",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "text = mpp.datasets.load_tiny_shakespeare()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "7e70b1a3-195f-4996-823c-00cd12a18699",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "vocab = sorted(set(text))\n",
+    "\n",
+    "char2token = {char: token for token, char in enumerate(vocab)}\n",
+    "all_tokens = np.array([char2token[char] for char in text], dtype=np.int32)\n",
+    "\n",
+    "first_val_index = int(TRAIN_FRAC * all_tokens.size)\n",
+    "train_tokens = all_tokens[:first_val_index]\n",
+    "val_tokens = all_tokens[first_val_index:]\n",
+    "\n",
+    "def base_expand(context: npt.NDArray) -> npt.NDArray:\n",
+    "    \"\"\"Convert a context window of tokens into a single index by base-len(vocab) expansion.\"\"\"\n",
+    "    c = len(vocab)**np.arange(CONTEXT_WIDTH)\n",
+    "    return (c * context).sum(axis=-1)\n",
+    "\n",
+    "def loss(embedding: mpp.Embedding, val: bool = True) -> mpp.Expr:\n",
+    "    \"\"\"Compute loss on a random training batch or the validation set.\"\"\"\n",
+    "    if val:\n",
+    "        user_data = val_tokens\n",
+    "        indices = np.arange(start=CONTEXT_WIDTH, stop=val_tokens.size)\n",
+    "    else:\n",
+    "        user_data = train_tokens\n",
+    "        indices = np.random.randint(low=CONTEXT_WIDTH, high=train_tokens.size, size=(BATCH_SIZE,))\n",
+    "    x = np.stack([user_data[index - CONTEXT_WIDTH:index] for index in indices])  # (B, C)\n",
+    "    y = user_data[indices]  # (B,)\n",
+    "    logits = embedding(base_expand(x))\n",
+    "    return mpp.cross_entropy_loss(logits, y)\n",
+    "\n",
+    "def generate_sentence(embedding: mpp.Embedding, init: npt.NDArray | None = None, length: int = 64) -> str:\n",
+    "    \"\"\"Use a learned embedding to generate a sentence.\"\"\"\n",
+    "    if init is None:\n",
+    "        init = np.zeros((CONTEXT_WIDTH,), dtype=np.int32)\n",
+    "    context = init\n",
+    "    tokens = context.tolist()\n",
+    "    for _ in range(length):\n",
+    "        logits = embedding(base_expand(context[np.newaxis, ...]))\n",
+    "        pvals = scipy.special.softmax(logits.value.squeeze())\n",
+    "        token = np.random.multinomial(n=1, pvals=pvals).argmax().item()\n",
+    "        context[:-1] = context[1:]\n",
+    "        context[-1] = token\n",
+    "        tokens.append(token)\n",
+    "    return ''.join(vocab[token] for token in tokens)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "78207a31-d5a7-4fc3-b86f-ee93b6804bb0",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Uninitialized Embedding\n",
+      "-----------------------\n",
+      "Loss: 4.618406097403782\n",
+      "Random sentence: \n",
+      "\n",
+      "? lNtAp.'ZcS-Um\n",
+      "US:I!X.DC&VTej:XdX'QVMw3IK Fkv?rvkLnVqFZC\n",
+      "lB&TC$\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "np.random.seed(0)\n",
+    "embedding = mpp.Embedding(num_embeddings=len(vocab)**CONTEXT_WIDTH, embedding_dim=len(vocab))\n",
+    "\n",
+    "with mpp.eval(), mpp.no_grad():\n",
+    "    print(f\"\"\"\n",
+    "Uninitialized Embedding\n",
+    "-----------------------\n",
+    "Loss: {loss(embedding).value.item()}\n",
+    "Random sentence: {generate_sentence(embedding)}\n",
+    "\"\"\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "8bb757b6-e04e-48f5-b864-b14e09916e77",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "\n",
+      "Iteration        0\n",
+      "------------------\n",
+      "Loss: 4.618406097403782\n",
+      "Random sentence: \n",
+      "\n",
+      "H,WIJa-Wk&'XXFCIq?lbJCBT?'XtDf-kW-Grq&FLIRpPz'tjY3Tpc.jLUXk loOn\n",
+      "\n",
+      "\n",
+      "Iteration   100000\n",
+      "------------------\n",
+      "Loss: 2.171994386892817\n",
+      "Random sentence: \n",
+      "\n",
+      "And myumbed VI:\n",
+      "WhourtR;Az'xMpU?gzXs\n",
+      "God, yoused\n",
+      "iNA:\n",
+      "HER:\n",
+      "O ban\n",
+      "\n",
+      "\n",
+      "Iteration   200000\n",
+      "------------------\n",
+      "Loss: 2.116169136143967\n",
+      "Random sentence: \n",
+      "\n",
+      "LEONTER:\n",
+      "POld come muchat Pome\n",
+      "ROKE Ennow\n",
+      "Wheithey ord's fie ast\n",
+      "\n",
+      "\n",
+      "Iteration   300000\n",
+      "------------------\n",
+      "Loss: 2.0921984053229843\n",
+      "Random sentence: \n",
+      "\n",
+      "so peaver, bod\n",
+      "What fore,\n",
+      "Nown,\n",
+      "That rall daun he berce wor the \n",
+      "\n",
+      "\n",
+      "Iteration   400000\n",
+      "------------------\n",
+      "Loss: 2.085089184457391\n",
+      "Random sentence: \n",
+      "\n",
+      "Ex&\n",
+      "$MKFRIANNENTES:\n",
+      "Shat shose,\n",
+      "That,\n",
+      "My se her ass of smadeectu\n",
+      "\n",
+      "\n",
+      "Iteration   500000\n",
+      "------------------\n",
+      "Loss: 2.0779833496975297\n",
+      "Random sentence: \n",
+      "\n",
+      "Bol, wou wour eato clovirseend.\n",
+      "\n",
+      "My oad to my bet your beam not \n",
+      "\n",
+      "\n",
+      "Iteration   600000\n",
+      "------------------\n",
+      "Loss: 2.0715481547062122\n",
+      "Random sentence: \n",
+      "\n",
+      "CLAUNTISsHlp were enap?yl's fordievall.\n",
+      "Younst be thave we beye\n",
+      "\n",
+      "\n",
+      "\n",
+      "Iteration   700000\n",
+      "------------------\n",
+      "Loss: 2.0705078753230732\n",
+      "Random sentence: \n",
+      "\n",
+      "Beforth sit I helovers, we him and shas cow nou whou.\n",
+      "\n",
+      "JULIETERS\n",
+      "\n",
+      "\n",
+      "Iteration   800000\n",
+      "------------------\n",
+      "Loss: 2.0671958447674292\n",
+      "Random sentence: \n",
+      "\n",
+      "KING EdLsquis cromendess oustesordessay les, mut to like, got in\n",
+      "\n",
+      "\n",
+      "Iteration   900000\n",
+      "------------------\n",
+      "Loss: 2.0646833826588025\n",
+      "Random sentence: \n",
+      "\n",
+      "QUEENVOLIO:\n",
+      "I alty of tway, Sect tre reads.\n",
+      "\n",
+      "My ch ther. Whing.\n",
+      "\n",
+      "\n",
+      "\n",
+      "Iteration  1000000\n",
+      "------------------\n",
+      "Loss: 2.064905884655797\n",
+      "Random sentence: \n",
+      "\n",
+      "HES:\n",
+      "Vold, and he sonens!\n",
+      "\n",
+      "Butely thy:\n",
+      "By my les;\n",
+      "All, theased f\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "opt = mpp.SGD(lr=1.0)\n",
+    "\n",
+    "n = 0\n",
+    "while True:\n",
+    "    if n % (NUM_ITERS // 10) == 0:\n",
+    "        with mpp.eval(), mpp.no_grad():\n",
+    "            print(f\"\"\"\n",
+    "Iteration {n:8d}\n",
+    "------------------\n",
+    "Loss: {loss(embedding).value.item()}\n",
+    "Random sentence: {generate_sentence(embedding)}\n",
+    "\"\"\")\n",
+    "\n",
+    "    if n >= NUM_ITERS:\n",
+    "        break\n",
+    "\n",
+    "    loss(embedding=embedding, val=False).backward(opt=opt)\n",
+    "    opt.step()\n",
+    "\n",
+    "    n += 1"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.4"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}