Showing 1 changed file with 317 additions and 0 deletions.
@@ -0,0 +1,317 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "c56226d9-be44-4db2-ae09-04f355dced98", | ||
"metadata": {}, | ||
"source": [ | ||
"# n-gram language model\n", | ||
"\n", | ||
"An [n-gram language model](https://en.wikipedia.org/wiki/Word_n-gram_language_model) is a statistical model of language that models the distribution of the $k$-th token $X_k$ on the previous $n$ tokens $X_{k - 1}, \\ldots, X_{k - n}$.\n", | ||
"An efficient way to compute an estimator for the probabilities of this model is by [counting n-grams](https://en.wikipedia.org/wiki/Word_n-gram_language_model#Approximation_method).\n", | ||
"\n", | ||
"This estimator can also be approximated by performing gradient descent on the cross entropy loss, as is done below.\n", | ||
"\n", | ||
"Although this is not an efficient way to compute the estimator, it is useful as it demonstrates how a language model without a closed form solution (e.g., [recurrent neural networks](https://en.wikipedia.org/wiki/Recurrent_neural_network) and [large language models](https://en.wikipedia.org/wiki/Large_language_model)) can be learned." | ||
] | ||
}, | ||
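{ | ||
"cell_type": "markdown", | ||
"id": "d3a5b7c9-1e2f-4a6b-8c0d-9f1e2a3b4c5d", | ||
"metadata": {}, | ||
"source": [ | ||
"Concretely, for a trigram model over a corpus of $N$ tokens, the training loop below performs stochastic gradient descent on the average cross entropy\n", | ||
"\n", | ||
"$$\\mathcal{L} = -\\frac{1}{N - 2} \\sum_{k = 2}^{N - 1} \\log p_\\theta(X_k \\mid X_{k - 2}, X_{k - 1}),$$\n", | ||
"\n", | ||
"where $p_\\theta$ denotes the model's conditional distribution, here a separate softmax over an unconstrained logit vector for every possible context $(X_{k - 2}, X_{k - 1})$. Because each context has its own free logits, gradient descent on this loss drives each context's softmax toward its empirical conditional frequency $\\mathrm{count}(a, b, x) / \\mathrm{count}(a, b)$, which is exactly the counting estimate mentioned above." | ||
] | ||
}, | ||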
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "31c06100-f075-403d-91e2-60bc4bd142a1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"import micrograd_pp as mpp\n", | ||
"import numpy as np\n", | ||
"import numpy.typing as npt\n", | ||
"import scipy.special" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "9e15b90d-c80a-4bc5-8cca-f2fd4655fb3d", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"BATCH_SIZE = 32\n", | ||
"CONTEXT_WIDTH = 2 # Trigram\n", | ||
"NUM_ITERS = 1_000_000\n", | ||
"TRAIN_FRAC = 0.9" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "4e23ae66-1627-438f-8de7-d09524bf6b36", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"text = mpp.datasets.load_tiny_shakespeare()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "7e70b1a3-195f-4996-823c-00cd12a18699", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"vocab = sorted(set(text))\n", | ||
"\n", | ||
"char2token = {char: token for token, char in enumerate(vocab)}\n", | ||
"all_tokens = np.array([char2token[char] for char in text], dtype=np.int32)\n", | ||
"\n", | ||
"first_val_index = int(TRAIN_FRAC * all_tokens.size)\n", | ||
"train_tokens = all_tokens[:first_val_index]\n", | ||
"val_tokens = all_tokens[first_val_index:]\n", | ||
"\n", | ||
"def base_expand(context: npt.NDArray) -> npt.NDArray:\n", | ||
" \"\"\"Convert a context window of tokens into a single token.\"\"\"\n", | ||
" c = len(vocab)**np.arange(CONTEXT_WIDTH)\n", | ||
" return (c * context).sum(axis=-1)\n", | ||
"\n", | ||
"def loss(embedding: mpp.Embedding, val: bool = True) -> mpp.Expr:\n", | ||
" \"\"\"Compute loss on a random training batch or the validation set.\"\"\"\n", | ||
" if val:\n", | ||
" user_data = val_tokens\n", | ||
" indices = np.arange(start=CONTEXT_WIDTH, stop=val_tokens.size)\n", | ||
" else:\n", | ||
" user_data = train_tokens\n", | ||
" indices = np.random.randint(low=CONTEXT_WIDTH, high=train_tokens.size, size=(BATCH_SIZE,))\n", | ||
" x = np.stack([user_data[index - CONTEXT_WIDTH:index] for index in indices]) # (B, C)\n", | ||
" y = user_data[indices] # (B,)\n", | ||
" logits = embedding(base_expand(x))\n", | ||
" return mpp.cross_entropy_loss(logits, y)\n", | ||
"\n", | ||
"def generate_sentence(embedding: mpp.Embedding, init: npt.NDArray | None = None, length: int = 64) -> str:\n", | ||
" \"\"\"Use a learned embedding to generate a sentence.\"\"\"\n", | ||
" if init is None:\n", | ||
" init = np.zeros((CONTEXT_WIDTH,), dtype=np.int32)\n", | ||
" context = init\n", | ||
" tokens = context.tolist()\n", | ||
" for _ in range(length):\n", | ||
" logits = embedding(base_expand(context[np.newaxis, ...]))\n", | ||
" pvals = scipy.special.softmax(logits.value.squeeze())\n", | ||
" token = np.random.multinomial(n=1, pvals=pvals).argmax().item()\n", | ||
" context[:-1] = context[1:]\n", | ||
" context[-1] = token\n", | ||
" tokens.append(token)\n", | ||
" return ''.join(vocab[token] for token in tokens)" | ||
] | ||
}, | ||
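{ | ||
"cell_type": "markdown", | ||
"id": "e4b6c8d0-2f3a-4b7c-9d1e-0a2b3c4d5e6f", | ||
"metadata": {}, | ||
"source": [ | ||
"Here `base_expand` maps each context window to a unique row index by treating the window as a number in base `len(vocab)`: with vocabulary size $V$, the context `[a, b]` (oldest token first) maps to $a + bV$, so the embedding created below holds one row of $V$ logits for every possible context. A quick illustrative check:\n", | ||
"\n", | ||
"```python\n", | ||
"# e.g., the context [3, 7] maps to row 3 + 7 * len(vocab)\n", | ||
"assert base_expand(np.array([3, 7])) == 3 + 7 * len(vocab)\n", | ||
"```\n", | ||
"\n", | ||
"`generate_sentence` then turns the logits for the current context into probabilities with a softmax, samples the next token, and slides the context window forward one token at a time." | ||
] | ||
}, | ||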
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "78207a31-d5a7-4fc3-b86f-ee93b6804bb0", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"Uninitialized Embedding\n", | ||
"-----------------------\n", | ||
"Loss: 4.618406097403782\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"? lNtAp.'ZcS-Um\n", | ||
"US:I!X.DC&VTej:XdX'QVMw3IK Fkv?rvkLnVqFZC\n", | ||
"lB&TC$\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"np.random.seed(0)\n", | ||
"embedding = mpp.Embedding(num_embeddings=len(vocab)**CONTEXT_WIDTH, embedding_dim=len(vocab))\n", | ||
"\n", | ||
"with mpp.eval(), mpp.no_grad():\n", | ||
" print(f\"\"\"\n", | ||
"Uninitialized Embedding\n", | ||
"-----------------------\n", | ||
"Loss: {loss(embedding).value.item()}\n", | ||
"Random sentence: {generate_sentence(embedding)}\n", | ||
"\"\"\")" | ||
] | ||
}, | ||
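{ | ||
"cell_type": "markdown", | ||
"id": "f5c7d9e1-3a4b-4c8d-0e2f-1b3c4d5e6f70", | ||
"metadata": {}, | ||
"source": [ | ||
"For reference, a model that spread probability uniformly over the vocabulary would score $\\ln |V|$ nats (about $4.17$ if the vocabulary has 65 characters, as is typical for this dataset), so the randomly initialized table starts out worse than uniform.\n", | ||
"\n", | ||
"The counting estimator from the introduction gives a stronger reference point for the losses printed during training. A minimal sketch, assuming add-one smoothing and using only the helpers defined above (`count_probs` is introduced just for this illustration):\n", | ||
"\n", | ||
"```python\n", | ||
"def count_probs(tokens: npt.NDArray) -> npt.NDArray:\n", | ||
"    \"\"\"Count (context, next-token) pairs and normalize with add-one smoothing.\"\"\"\n", | ||
"    n = tokens.size - CONTEXT_WIDTH\n", | ||
"    ctx = base_expand(np.stack([tokens[i:i + n] for i in range(CONTEXT_WIDTH)], axis=-1))\n", | ||
"    counts = np.zeros((len(vocab)**CONTEXT_WIDTH, len(vocab)))\n", | ||
"    np.add.at(counts, (ctx, tokens[CONTEXT_WIDTH:]), 1)\n", | ||
"    return (counts + 1) / (counts + 1).sum(axis=-1, keepdims=True)\n", | ||
"\n", | ||
"probs = count_probs(train_tokens)\n", | ||
"n_val = val_tokens.size - CONTEXT_WIDTH\n", | ||
"val_ctx = base_expand(np.stack([val_tokens[i:i + n_val] for i in range(CONTEXT_WIDTH)], axis=-1))\n", | ||
"val_xent = -np.log(probs[val_ctx, val_tokens[CONTEXT_WIDTH:]]).mean()  # analogous to the validation loss used above\n", | ||
"print(val_xent)\n", | ||
"```" | ||
] | ||
}, | ||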
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"id": "8bb757b6-e04e-48f5-b864-b14e09916e77", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"Iteration 0\n", | ||
"------------------\n", | ||
"Loss: 4.618406097403782\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"H,WIJa-Wk&'XXFCIq?lbJCBT?'XtDf-kW-Grq&FLIRpPz'tjY3Tpc.jLUXk loOn\n", | ||
"\n", | ||
"\n", | ||
"Iteration 100000\n", | ||
"------------------\n", | ||
"Loss: 2.171994386892817\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"And myumbed VI:\n", | ||
"WhourtR;Az'xMpU?gzXs\n", | ||
"God, yoused\n", | ||
"iNA:\n", | ||
"HER:\n", | ||
"O ban\n", | ||
"\n", | ||
"\n", | ||
"Iteration 200000\n", | ||
"------------------\n", | ||
"Loss: 2.116169136143967\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"LEONTER:\n", | ||
"POld come muchat Pome\n", | ||
"ROKE Ennow\n", | ||
"Wheithey ord's fie ast\n", | ||
"\n", | ||
"\n", | ||
"Iteration 300000\n", | ||
"------------------\n", | ||
"Loss: 2.0921984053229843\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"so peaver, bod\n", | ||
"What fore,\n", | ||
"Nown,\n", | ||
"That rall daun he berce wor the \n", | ||
"\n", | ||
"\n", | ||
"Iteration 400000\n", | ||
"------------------\n", | ||
"Loss: 2.085089184457391\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"Ex&\n", | ||
"$MKFRIANNENTES:\n", | ||
"Shat shose,\n", | ||
"That,\n", | ||
"My se her ass of smadeectu\n", | ||
"\n", | ||
"\n", | ||
"Iteration 500000\n", | ||
"------------------\n", | ||
"Loss: 2.0779833496975297\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"Bol, wou wour eato clovirseend.\n", | ||
"\n", | ||
"My oad to my bet your beam not \n", | ||
"\n", | ||
"\n", | ||
"Iteration 600000\n", | ||
"------------------\n", | ||
"Loss: 2.0715481547062122\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"CLAUNTISsHlp were enap?yl's fordievall.\n", | ||
"Younst be thave we beye\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"Iteration 700000\n", | ||
"------------------\n", | ||
"Loss: 2.0705078753230732\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"Beforth sit I helovers, we him and shas cow nou whou.\n", | ||
"\n", | ||
"JULIETERS\n", | ||
"\n", | ||
"\n", | ||
"Iteration 800000\n", | ||
"------------------\n", | ||
"Loss: 2.0671958447674292\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"KING EdLsquis cromendess oustesordessay les, mut to like, got in\n", | ||
"\n", | ||
"\n", | ||
"Iteration 900000\n", | ||
"------------------\n", | ||
"Loss: 2.0646833826588025\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"QUEENVOLIO:\n", | ||
"I alty of tway, Sect tre reads.\n", | ||
"\n", | ||
"My ch ther. Whing.\n", | ||
"\n", | ||
"\n", | ||
"\n", | ||
"Iteration 1000000\n", | ||
"------------------\n", | ||
"Loss: 2.064905884655797\n", | ||
"Random sentence: \n", | ||
"\n", | ||
"HES:\n", | ||
"Vold, and he sonens!\n", | ||
"\n", | ||
"Butely thy:\n", | ||
"By my les;\n", | ||
"All, theased f\n", | ||
"\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"opt = mpp.SGD(lr=1.0)\n", | ||
"\n", | ||
"n = 0\n", | ||
"while True:\n", | ||
" if n % (NUM_ITERS // 10) == 0:\n", | ||
" with mpp.eval(), mpp.no_grad():\n", | ||
" print(f\"\"\"\n", | ||
"Iteration {n:8d}\n", | ||
"------------------\n", | ||
"Loss: {loss(embedding).value.item()}\n", | ||
"Random sentence: {generate_sentence(embedding)}\n", | ||
"\"\"\")\n", | ||
"\n", | ||
" if n >= NUM_ITERS:\n", | ||
" break\n", | ||
"\n", | ||
" loss(embedding=embedding, val=False).backward(opt=opt)\n", | ||
" opt.step()\n", | ||
"\n", | ||
" n += 1" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.4" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |