Add n-gram example
parsiad committed Aug 17, 2024
1 parent 162ad67 commit ab3738e
Showing 1 changed file with 317 additions and 0 deletions: examples/n-gram.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "c56226d9-be44-4db2-ae09-04f355dced98",
"metadata": {},
"source": [
"# n-gram language model\n",
"\n",
"An [n-gram language model](https://en.wikipedia.org/wiki/Word_n-gram_language_model) is a statistical model of language that models the distribution of the $k$-th token $X_k$ on the previous $n$ tokens $X_{k - 1}, \\ldots, X_{k - n}$.\n",
"An efficient way to compute an estimator for the probabilities of this model is by [counting n-grams](https://en.wikipedia.org/wiki/Word_n-gram_language_model#Approximation_method).\n",
"\n",
"This estimator can also be approximated by performing gradient descent on the cross entropy loss, as is done below.\n",
"\n",
"Although this is not an efficient way to compute the estimator, it is useful as it demonstrates how a language model without a closed form solution (e.g., [recurrent neural networks](https://en.wikipedia.org/wiki/Recurrent_neural_network) and [large language models](https://en.wikipedia.org/wiki/Large_language_model)) can be learned."
]
},
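{
"cell_type": "markdown",
"id": "counting-estimator-note",
"metadata": {},
"source": [
"As a point of reference, the counting estimator mentioned above can be written down directly.\n",
"The next cell is an illustrative sketch rather than part of the original example (the helper name `bigram_probs` is made up): for a bigram model ($n = 2$), it counts consecutive token pairs and normalizes each row of the count matrix to estimate $P(X_k | X_{k - 1})$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "counting-estimator-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative sketch (not from the original notebook): the closed-form counting\n",
"# estimator for a bigram model. The gradient descent procedure below approximates\n",
"# the analogous estimator for a wider context.\n",
"import numpy as np\n",
"\n",
"def bigram_probs(tokens: np.ndarray, vocab_size: int) -> np.ndarray:\n",
"    # Count occurrences of each (previous token, next token) pair.\n",
"    counts = np.zeros((vocab_size, vocab_size))\n",
"    np.add.at(counts, (tokens[:-1], tokens[1:]), 1.0)\n",
"    # Normalize each row to obtain conditional probabilities P(next | previous).\n",
"    return counts / counts.sum(axis=1, keepdims=True).clip(min=1.0)\n",
"\n",
"# Tiny usage example on a made-up token sequence over a vocabulary of size 2:\n",
"# row 0 is [0, 1] (token 0 is always followed by 1) and row 1 is [0.5, 0.5].\n",
"bigram_probs(np.array([0, 1, 0, 1, 1]), vocab_size=2)"
]
},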
{
"cell_type": "code",
"execution_count": 1,
"id": "31c06100-f075-403d-91e2-60bc4bd142a1",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import micrograd_pp as mpp\n",
"import numpy as np\n",
"import numpy.typing as npt\n",
"import scipy.special"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "9e15b90d-c80a-4bc5-8cca-f2fd4655fb3d",
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 32\n",
"CONTEXT_WIDTH = 2 # Trigram\n",
"NUM_ITERS = 1_000_000\n",
"TRAIN_FRAC = 0.9"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4e23ae66-1627-438f-8de7-d09524bf6b36",
"metadata": {},
"outputs": [],
"source": [
"text = mpp.datasets.load_tiny_shakespeare()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "7e70b1a3-195f-4996-823c-00cd12a18699",
"metadata": {},
"outputs": [],
"source": [
"vocab = sorted(set(text))\n",
"\n",
"char2token = {char: token for token, char in enumerate(vocab)}\n",
"all_tokens = np.array([char2token[char] for char in text], dtype=np.int32)\n",
"\n",
"first_val_index = int(TRAIN_FRAC * all_tokens.size)\n",
"train_tokens = all_tokens[:first_val_index]\n",
"val_tokens = all_tokens[first_val_index:]\n",
"\n",
"def base_expand(context: npt.NDArray) -> npt.NDArray:\n",
" \"\"\"Convert a context window of tokens into a single token.\"\"\"\n",
" c = len(vocab)**np.arange(CONTEXT_WIDTH)\n",
" return (c * context).sum(axis=-1)\n",
"\n",
"def loss(embedding: mpp.Embedding, val: bool = True) -> mpp.Expr:\n",
" \"\"\"Compute loss on a random training batch or the validation set.\"\"\"\n",
" if val:\n",
" user_data = val_tokens\n",
" indices = np.arange(start=CONTEXT_WIDTH, stop=val_tokens.size)\n",
" else:\n",
" user_data = train_tokens\n",
" indices = np.random.randint(low=CONTEXT_WIDTH, high=train_tokens.size, size=(BATCH_SIZE,))\n",
" x = np.stack([user_data[index - CONTEXT_WIDTH:index] for index in indices]) # (B, C)\n",
" y = user_data[indices] # (B,)\n",
" logits = embedding(base_expand(x))\n",
" return mpp.cross_entropy_loss(logits, y)\n",
"\n",
"def generate_sentence(embedding: mpp.Embedding, init: npt.NDArray | None = None, length: int = 64) -> str:\n",
" \"\"\"Use a learned embedding to generate a sentence.\"\"\"\n",
" if init is None:\n",
" init = np.zeros((CONTEXT_WIDTH,), dtype=np.int32)\n",
" context = init\n",
" tokens = context.tolist()\n",
" for _ in range(length):\n",
" logits = embedding(base_expand(context[np.newaxis, ...]))\n",
" pvals = scipy.special.softmax(logits.value.squeeze())\n",
" token = np.random.multinomial(n=1, pvals=pvals).argmax().item()\n",
" context[:-1] = context[1:]\n",
" context[-1] = token\n",
" tokens.append(token)\n",
" return ''.join(vocab[token] for token in tokens)"
]
},
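{
"cell_type": "markdown",
"id": "base-expand-note",
"metadata": {},
"source": [
"The `base_expand` helper above flattens a length-`CONTEXT_WIDTH` context into a single index by reading the context as a base-`len(vocab)` number, so every distinct context selects a distinct row of the `len(vocab)**CONTEXT_WIDTH`-row embedding table.\n",
"The next cell is a small illustrative check (the example tokens 3 and 7 are arbitrary, not taken from the original notebook)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "base-expand-check",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check (arbitrary example tokens): with CONTEXT_WIDTH = 2, the context\n",
"# [t0, t1] maps to the single index t0 + len(vocab) * t1.\n",
"example_context = np.array([3, 7], dtype=np.int32)\n",
"assert base_expand(example_context) == 3 + len(vocab) * 7"
]
},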
{
"cell_type": "code",
"execution_count": 5,
"id": "78207a31-d5a7-4fc3-b86f-ee93b6804bb0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Uninitialized Embedding\n",
"-----------------------\n",
"Loss: 4.618406097403782\n",
"Random sentence: \n",
"\n",
"? lNtAp.'ZcS-Um\n",
"US:I!X.DC&VTej:XdX'QVMw3IK Fkv?rvkLnVqFZC\n",
"lB&TC$\n",
"\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"embedding = mpp.Embedding(num_embeddings=len(vocab)**CONTEXT_WIDTH, embedding_dim=len(vocab))\n",
"\n",
"with mpp.eval(), mpp.no_grad():\n",
" print(f\"\"\"\n",
"Uninitialized Embedding\n",
"-----------------------\n",
"Loss: {loss(embedding).value.item()}\n",
"Random sentence: {generate_sentence(embedding)}\n",
"\"\"\")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8bb757b6-e04e-48f5-b864-b14e09916e77",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"Iteration 0\n",
"------------------\n",
"Loss: 4.618406097403782\n",
"Random sentence: \n",
"\n",
"H,WIJa-Wk&'XXFCIq?lbJCBT?'XtDf-kW-Grq&FLIRpPz'tjY3Tpc.jLUXk loOn\n",
"\n",
"\n",
"Iteration 100000\n",
"------------------\n",
"Loss: 2.171994386892817\n",
"Random sentence: \n",
"\n",
"And myumbed VI:\n",
"WhourtR;Az'xMpU?gzXs\n",
"God, yoused\n",
"iNA:\n",
"HER:\n",
"O ban\n",
"\n",
"\n",
"Iteration 200000\n",
"------------------\n",
"Loss: 2.116169136143967\n",
"Random sentence: \n",
"\n",
"LEONTER:\n",
"POld come muchat Pome\n",
"ROKE Ennow\n",
"Wheithey ord's fie ast\n",
"\n",
"\n",
"Iteration 300000\n",
"------------------\n",
"Loss: 2.0921984053229843\n",
"Random sentence: \n",
"\n",
"so peaver, bod\n",
"What fore,\n",
"Nown,\n",
"That rall daun he berce wor the \n",
"\n",
"\n",
"Iteration 400000\n",
"------------------\n",
"Loss: 2.085089184457391\n",
"Random sentence: \n",
"\n",
"Ex&\n",
"$MKFRIANNENTES:\n",
"Shat shose,\n",
"That,\n",
"My se her ass of smadeectu\n",
"\n",
"\n",
"Iteration 500000\n",
"------------------\n",
"Loss: 2.0779833496975297\n",
"Random sentence: \n",
"\n",
"Bol, wou wour eato clovirseend.\n",
"\n",
"My oad to my bet your beam not \n",
"\n",
"\n",
"Iteration 600000\n",
"------------------\n",
"Loss: 2.0715481547062122\n",
"Random sentence: \n",
"\n",
"CLAUNTISsHlp were enap?yl's fordievall.\n",
"Younst be thave we beye\n",
"\n",
"\n",
"\n",
"Iteration 700000\n",
"------------------\n",
"Loss: 2.0705078753230732\n",
"Random sentence: \n",
"\n",
"Beforth sit I helovers, we him and shas cow nou whou.\n",
"\n",
"JULIETERS\n",
"\n",
"\n",
"Iteration 800000\n",
"------------------\n",
"Loss: 2.0671958447674292\n",
"Random sentence: \n",
"\n",
"KING EdLsquis cromendess oustesordessay les, mut to like, got in\n",
"\n",
"\n",
"Iteration 900000\n",
"------------------\n",
"Loss: 2.0646833826588025\n",
"Random sentence: \n",
"\n",
"QUEENVOLIO:\n",
"I alty of tway, Sect tre reads.\n",
"\n",
"My ch ther. Whing.\n",
"\n",
"\n",
"\n",
"Iteration 1000000\n",
"------------------\n",
"Loss: 2.064905884655797\n",
"Random sentence: \n",
"\n",
"HES:\n",
"Vold, and he sonens!\n",
"\n",
"Butely thy:\n",
"By my les;\n",
"All, theased f\n",
"\n"
]
}
],
"source": [
"opt = mpp.SGD(lr=1.0)\n",
"\n",
"n = 0\n",
"while True:\n",
" if n % (NUM_ITERS // 10) == 0:\n",
" with mpp.eval(), mpp.no_grad():\n",
" print(f\"\"\"\n",
"Iteration {n:8d}\n",
"------------------\n",
"Loss: {loss(embedding).value.item()}\n",
"Random sentence: {generate_sentence(embedding)}\n",
"\"\"\")\n",
"\n",
" if n >= NUM_ITERS:\n",
" break\n",
"\n",
" loss(embedding=embedding, val=False).backward(opt=opt)\n",
" opt.step()\n",
"\n",
" n += 1"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
