From 64600e4abe54be3c85d47c16af80ccfd10ee5298 Mon Sep 17 00:00:00 2001 From: Marina Zhang <40069936+MarinaZhang@users.noreply.github.com> Date: Mon, 9 Oct 2023 15:32:10 -0700 Subject: [PATCH 1/4] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d1bf6c3..60c6ac6 100644 --- a/README.md +++ b/README.md @@ -62,7 +62,7 @@ Please cite this reference if you use RETVec in your research: ```bibtex @article{retvec2023, - title={RetVec: Resilient and Efficient Text Vectorizer}, + title={RETVec: Resilient and Efficient Text Vectorizer}, author={Elie Bursztein, Marina Zhang, Owen Vallis, Xinyu Jia, and Alexey Kurakin}, year={2023}, eprint={2302.09207} From e25d7c54ad447b47d7d5808ed3d34316be3d5412 Mon Sep 17 00:00:00 2001 From: Marina Zhang <40069936+MarinaZhang@users.noreply.github.com> Date: Mon, 9 Oct 2023 15:33:20 -0700 Subject: [PATCH 2/4] Update CONTRIBUTING.md --- CONTRIBUTING.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index be7bb48..2895328 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -1,5 +1,5 @@ # How to Contribute -Thanks for considering contributing to TF similarity! +Thanks for considering contributing to RETVec! Here is what you need to know to make a successful contribution. There are just a few small guidelines you need to follow. @@ -31,7 +31,7 @@ pull request: - Ideally one PR corespond to one feature or improvement to make it easier to review. So **try** to split your contribution in meaning logical units. - Your code **must** pass the unit-tests. We use `pytest` so simply run it at the root of the project. -- Your code **must** passs static analyis. We use `mypy` so simply run `mypy tensorflow_similarity/` from the root of the project. +- Your code **must** passs static analyis. We use `mypy` so simply run `mypy retvec/` from the root of the project. 
- Your code **must** comes with unit-tests to ensure long term quality - Your functions **must** be documented except obvious ones using the Google style. - Your functions **must** be typed. From 23d28c0ce49f9bfbdb75b8fca46c310f7fa1c55b Mon Sep 17 00:00:00 2001 From: Marina Zhang <40069936+MarinaZhang@users.noreply.github.com> Date: Mon, 9 Oct 2023 16:23:17 -0700 Subject: [PATCH 3/4] Delete notebooks/train_hello_world_tf.ipynb --- notebooks/train_hello_world_tf.ipynb | 547 --------------------------- 1 file changed, 547 deletions(-) delete mode 100644 notebooks/train_hello_world_tf.ipynb diff --git a/notebooks/train_hello_world_tf.ipynb b/notebooks/train_hello_world_tf.ipynb deleted file mode 100644 index 89b2021..0000000 --- a/notebooks/train_hello_world_tf.ipynb +++ /dev/null @@ -1,547 +0,0 @@ -{ - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Using RetVec to train an emotion classifier\n", - "\n", - "RetVec is a state of art text tokenizer that works directly out of raw strings to create resilient models. Model trained with RetVec acheive state of art classification performance\n", - "and exhibit strong resilience to adversarial attacks as reported in our [paper](https://arxiv.org/abs/2302.09207).\n", - "\n", - "\n", - "RetVec speed, low , and stateless nature makes it the perfect choice to train and deploy\n", - "small and efficient on-device models. 
It is natively supported in TFLite via custom ops implemented in tensorflow-text for ondevice models and we provide a Javascript implementation RetVecJS that allows to deploy web models via TFJS.\n", - "\n", - "\n", - "This notebook demonstrates how to quickly train and use a text emotion classifier.\n", - "This classifier can then be easily exported to run in a webpage as demonstrate in \n", - "this notebook\n", - "\n", - "Let's get started" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": { - "tags": [ - "hide-output" - ] - }, - "outputs": [], - "source": [ - "# installing needed dependencies\n", - "try:\n", - " import retvec\n", - "except:\n", - " !pip install retvec # is retvec installed?\n", - "try:\n", - " import datasets\n", - "except:\n", - " !pip install datasets # used to get the dataset\n", - "\n", - "try:\n", - " import matplotlib\n", - "except:\n", - " !pip install matplotlib\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/elieb/git/retvec/venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], - "source": [ - "import tensorflow as tf\n", - "import numpy as np\n", - "from tensorflow.keras import layers\n", - "from datasets import load_dataset\n", - "from matplotlib import pyplot as plt" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In this notebook we are using the `RETVecTokenizer()` layer which perform the binarization and embedding in a single step. This is the best approach for GPU training. 
For TPU training,\n", - "it is more efficient to split the two steps -- see our TPU training notbook for this.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [], - "source": [ - " # RetVec tokenizer layer.\n", - "from retvec.tf import RETVecTokenizer " - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Create dataset\n", - "\n", - "We are going to use the [Go Emotion dataset](https://huggingface.co/datasets/go_emotions) to create a mulit-class emotion classifier.\n", - "https://ai.googleblog.com/2021/10/goemotions-dataset-for-fine-grained.html" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [], - "source": [ - "# downloading\n", - "dataset = load_dataset('go_emotions')" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "num classes 28\n", - "['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']\n" - ] - } - ], - "source": [ - "# getting class name mapping and number of class\n", - "CLASSES = dataset['train'].features['labels'].feature.names\n", - "NUM_CLASSES = len(CLASSES)\n", - "print(f\"num classes {NUM_CLASSES}\")\n", - "print(CLASSES)" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": {}, - "outputs": [], - "source": [ - "# preparing data\n", - "x_train = tf.constant(dataset['train']['text'], dtype=tf.string)\n", - "\n", - "# the one-hot requires a little more due to the multi-class nature of the dataset.\n", - "y_train = np.zeros((len(x_train),NUM_CLASSES))\n", - "for idx, ex in enumerate(dataset['train']['labels']):\n", - " for val in 
ex:\n", - " y_train[idx][val] = 1\n", - "\n", - "# test data\n", - "x_test = tf.constant(dataset['test']['text'], dtype=tf.string)\n", - "y_test = np.zeros((len(x_test),NUM_CLASSES))\n", - "for idx, ex in enumerate(dataset['test']['labels']):\n", - " for val in ex:\n", - " y_test[idx][val] = 1\n" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Model\n", - "\n", - "A key strength of RetVec is that input to the model are \n", - "the raw datasets string with no pre-processing which greatly simplify the training and inference process. In particular for on-device models. \n", - "\n", - "Notes:\n", - "- Using strings directly as input requires to use a shape of `(1,)` and specify the type `tf.string`\n", - "\n", - "- We are using `RetVecTokenizer()` in its default configuration which is to truncate at `128` words and use a small pretained model for the embedding. You can experiment with shorter or longer length by changing the `sequence_length`\n", - "parameter." - ] - }, - { - "cell_type": "code", - "execution_count": 128, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. 
Compile it manually.\n", - "Model: \"model_12\"\n", - "_________________________________________________________________\n", - " Layer (type) Output Shape Param # \n", - "=================================================================\n", - " token (InputLayer) [(None, 1)] 0 \n", - " \n", - " ret_vec_tokenizer_14 (RETV (None, 128, 256) 230144 \n", - " ecTokenizer) \n", - " \n", - " batch_normalization_14 (Ba (None, 128, 256) 1024 \n", - " tchNormalization) \n", - " \n", - " spatial_dropout1d_16 (Spat (None, 128, 256) 0 \n", - " ialDropout1D) \n", - " \n", - " bidirectional_26 (Bidirect (None, 128, 128) 164352 \n", - " ional) \n", - " \n", - " bidirectional_27 (Bidirect (None, 128, 64) 41216 \n", - " ional) \n", - " \n", - " bidirectional_28 (Bidirect (None, 64) 24832 \n", - " ional) \n", - " \n", - " dense_11 (Dense) (None, 28) 1820 \n", - " \n", - "=================================================================\n", - "Total params: 463388 (1.77 MB)\n", - "Trainable params: 232732 (909.11 KB)\n", - "Non-trainable params: 230656 (901.00 KB)\n", - "_________________________________________________________________\n" - ] - } - ], - "source": [ - "\n", - "\n", - "# Using strings directely requires to put a shape of (1, ) and a dtype: tf.string\n", - "inputs = layers.Input(shape=(1, ), name=\"token\", dtype=tf.string)\n", - "\n", - "# we are using RetVec with it's default settings\n", - "x = RETVecTokenizer(model='retvec-v1')(inputs)\n", - "\n", - "# Adding a batch norm after RetVec usually help with the convergence\n", - "x = layers.BatchNormalization()(x)\n", - "\n", - "# standard three LSTM layers\n", - "x = layers.SpatialDropout1D(0.1)(x)\n", - "x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)\n", - "x = layers.Bidirectional(layers.LSTM(32, return_sequences=True))(x)\n", - "x = layers.Bidirectional(layers.LSTM(32))(x)\n", - "outputs = layers.Dense(NUM_CLASSES, activation='sigmoid')(x)\n", - "model = tf.keras.Model(inputs, outputs)\n", - 
"model.summary()\n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1/50\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-08-03 15:19:13.127544: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:14.165134: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:14.184875: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:14.439446: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:14.460276: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:14.691597: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:14.709305: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:15.189193: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:15.216765: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:15.649611: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:15.681321: I 
tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:16.115779: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:19:16.144417: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "340/340 [==============================] - ETA: 0s - loss: 0.1727 - acc: 0.2734" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "2023-08-03 15:20:08.448031: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:20:08.890213: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:20:08.903641: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:20:09.091102: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:20:09.103586: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:20:09.291403: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n", - "2023-08-03 15:20:09.304346: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "340/340 [==============================] - 67s 169ms/step - loss: 0.1727 - acc: 0.2734 - val_loss: 0.1479 - 
val_acc: 0.2959\n", - "Epoch 2/50\n", - "340/340 [==============================] - 49s 143ms/step - loss: 0.1490 - acc: 0.2954 - val_loss: 0.1464 - val_acc: 0.2959\n", - "Epoch 3/50\n", - "340/340 [==============================] - 49s 145ms/step - loss: 0.1455 - acc: 0.3104 - val_loss: 0.1394 - val_acc: 0.3455\n", - "Epoch 4/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.1381 - acc: 0.3532 - val_loss: 0.1330 - val_acc: 0.3650\n", - "Epoch 5/50\n", - "340/340 [==============================] - 49s 144ms/step - loss: 0.1339 - acc: 0.3682 - val_loss: 0.1295 - val_acc: 0.3932\n", - "Epoch 6/50\n", - "340/340 [==============================] - 49s 143ms/step - loss: 0.1309 - acc: 0.3921 - val_loss: 0.1261 - val_acc: 0.4166\n", - "Epoch 7/50\n", - "340/340 [==============================] - 49s 145ms/step - loss: 0.1279 - acc: 0.4106 - val_loss: 0.1238 - val_acc: 0.4229\n", - "Epoch 8/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.1251 - acc: 0.4246 - val_loss: 0.1224 - val_acc: 0.4295\n", - "Epoch 9/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.1225 - acc: 0.4344 - val_loss: 0.1199 - val_acc: 0.4420\n", - "Epoch 10/50\n", - "340/340 [==============================] - 48s 140ms/step - loss: 0.1206 - acc: 0.4425 - val_loss: 0.1165 - val_acc: 0.4570\n", - "Epoch 11/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.1186 - acc: 0.4501 - val_loss: 0.1152 - val_acc: 0.4603\n", - "Epoch 12/50\n", - "340/340 [==============================] - 49s 143ms/step - loss: 0.1165 - acc: 0.4589 - val_loss: 0.1134 - val_acc: 0.4638\n", - "Epoch 13/50\n", - "340/340 [==============================] - 49s 144ms/step - loss: 0.1148 - acc: 0.4635 - val_loss: 0.1121 - val_acc: 0.4702\n", - "Epoch 14/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.1133 - acc: 0.4661 - val_loss: 0.1111 - val_acc: 0.4721\n", - "Epoch 15/50\n", - "340/340 
[==============================] - 49s 144ms/step - loss: 0.1119 - acc: 0.4718 - val_loss: 0.1099 - val_acc: 0.4732\n", - "Epoch 16/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.1102 - acc: 0.4749 - val_loss: 0.1086 - val_acc: 0.4736\n", - "Epoch 17/50\n", - "340/340 [==============================] - 50s 146ms/step - loss: 0.1090 - acc: 0.4800 - val_loss: 0.1081 - val_acc: 0.4818\n", - "Epoch 18/50\n", - "340/340 [==============================] - 48s 140ms/step - loss: 0.1078 - acc: 0.4879 - val_loss: 0.1071 - val_acc: 0.4885\n", - "Epoch 19/50\n", - "340/340 [==============================] - 49s 145ms/step - loss: 0.1065 - acc: 0.4920 - val_loss: 0.1058 - val_acc: 0.4909\n", - "Epoch 20/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.1053 - acc: 0.4952 - val_loss: 0.1046 - val_acc: 0.4947\n", - "Epoch 21/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.1044 - acc: 0.4985 - val_loss: 0.1044 - val_acc: 0.4971\n", - "Epoch 22/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.1031 - acc: 0.5039 - val_loss: 0.1031 - val_acc: 0.5041\n", - "Epoch 23/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.1022 - acc: 0.5052 - val_loss: 0.1021 - val_acc: 0.5030\n", - "Epoch 24/50\n", - "340/340 [==============================] - 48s 140ms/step - loss: 0.1015 - acc: 0.5103 - val_loss: 0.1009 - val_acc: 0.5111\n", - "Epoch 25/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.1007 - acc: 0.5118 - val_loss: 0.1014 - val_acc: 0.5100\n", - "Epoch 26/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.0997 - acc: 0.5181 - val_loss: 0.1003 - val_acc: 0.5143\n", - "Epoch 27/50\n", - "340/340 [==============================] - 47s 139ms/step - loss: 0.0992 - acc: 0.5204 - val_loss: 0.1000 - val_acc: 0.5172\n", - "Epoch 28/50\n", - "340/340 [==============================] - 48s 142ms/step 
- loss: 0.0987 - acc: 0.5217 - val_loss: 0.0999 - val_acc: 0.5159\n", - "Epoch 29/50\n", - "340/340 [==============================] - 48s 140ms/step - loss: 0.0980 - acc: 0.5221 - val_loss: 0.0990 - val_acc: 0.5139\n", - "Epoch 30/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.0972 - acc: 0.5269 - val_loss: 0.0990 - val_acc: 0.5189\n", - "Epoch 31/50\n", - "340/340 [==============================] - 48s 140ms/step - loss: 0.0968 - acc: 0.5276 - val_loss: 0.0984 - val_acc: 0.5193\n", - "Epoch 32/50\n", - "340/340 [==============================] - 48s 142ms/step - loss: 0.0965 - acc: 0.5293 - val_loss: 0.0985 - val_acc: 0.5215\n", - "Epoch 33/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.0958 - acc: 0.5310 - val_loss: 0.0980 - val_acc: 0.5235\n", - "Epoch 34/50\n", - "340/340 [==============================] - 48s 140ms/step - loss: 0.0951 - acc: 0.5359 - val_loss: 0.0986 - val_acc: 0.5145\n", - "Epoch 35/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.0949 - acc: 0.5360 - val_loss: 0.0980 - val_acc: 0.5264\n", - "Epoch 36/50\n", - "340/340 [==============================] - 47s 139ms/step - loss: 0.0944 - acc: 0.5359 - val_loss: 0.0977 - val_acc: 0.5220\n", - "Epoch 37/50\n", - "340/340 [==============================] - 48s 140ms/step - loss: 0.0939 - acc: 0.5412 - val_loss: 0.0971 - val_acc: 0.5259\n", - "Epoch 38/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.0936 - acc: 0.5390 - val_loss: 0.0984 - val_acc: 0.5202\n", - "Epoch 39/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.0930 - acc: 0.5431 - val_loss: 0.0982 - val_acc: 0.5266\n", - "Epoch 40/50\n", - "340/340 [==============================] - 47s 139ms/step - loss: 0.0928 - acc: 0.5421 - val_loss: 0.0975 - val_acc: 0.5250\n", - "Epoch 41/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.0921 - acc: 0.5433 - val_loss: 0.0977 - 
val_acc: 0.5205\n", - "Epoch 42/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.0919 - acc: 0.5469 - val_loss: 0.0972 - val_acc: 0.5233\n", - "Epoch 43/50\n", - "340/340 [==============================] - 47s 138ms/step - loss: 0.0915 - acc: 0.5479 - val_loss: 0.0967 - val_acc: 0.5275\n", - "Epoch 44/50\n", - "340/340 [==============================] - 47s 138ms/step - loss: 0.0913 - acc: 0.5472 - val_loss: 0.0969 - val_acc: 0.5264\n", - "Epoch 45/50\n", - "340/340 [==============================] - 47s 139ms/step - loss: 0.0911 - acc: 0.5499 - val_loss: 0.0975 - val_acc: 0.5279\n", - "Epoch 46/50\n", - "340/340 [==============================] - 48s 141ms/step - loss: 0.0906 - acc: 0.5497 - val_loss: 0.0966 - val_acc: 0.5270\n", - "Epoch 47/50\n", - "340/340 [==============================] - 993s 3s/step - loss: 0.0902 - acc: 0.5534 - val_loss: 0.0963 - val_acc: 0.5298\n", - "Epoch 48/50\n", - "340/340 [==============================] - 967s 3s/step - loss: 0.0897 - acc: 0.5556 - val_loss: 0.0962 - val_acc: 0.5310\n", - "Epoch 49/50\n", - "340/340 [==============================] - 995s 3s/step - loss: 0.0893 - acc: 0.5559 - val_loss: 0.0966 - val_acc: 0.5310\n", - "Epoch 50/50\n", - "340/340 [==============================] - 1021s 3s/step - loss: 0.0890 - acc: 0.5569 - val_loss: 0.0967 - val_acc: 0.5246\n" - ] - } - ], - "source": [ - "batch_size = 128\n", - "epochs = 50\n", - "model.compile('adam', 'binary_crossentropy', ['acc'])\n", - "history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size, \n", - " validation_data=(x_test, y_test))" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "340/340 [==============================] - 69s 176ms/step - loss: 0.1677 - acc: 0.2777 - val_loss: 0.1479 - val_acc: 0.2959\n", - "Epoch 2/30\n", - "340/340 [==============================] - 51s 151ms/step - loss: 0.1492 - acc: 0.2954 - val_loss: 0.1471 - val_acc: 0.2959\n", - "Epoch 3/30\n", - "340/340 
[==============================] - 52s 154ms/step - loss: 0.1453 - acc: 0.3122 - val_loss: 0.1368 - val_acc: 0.3586\n", - "Epoch 4/30\n", - "340/340 [==============================] - 56s 164ms/step - loss: 0.1360 - acc: 0.3648 - val_loss: 0.1306 - val_acc: 0.3833\n", - "Epoch 5/30\n", - "340/340 [==============================] - 51s 150ms/step - loss: 0.1301 - acc: 0.3988 - val_loss: 0.1246 - val_acc: 0.4150" - ] - }, - { - "cell_type": "code", - "execution_count": 130, - "metadata": {}, - "outputs": [], - "source": [ - "# saving the model so we can use it on the web\n", - "model.save('demo_models/emotions.keras') # new saving format requires .keras" - ] - }, - { - "cell_type": "code", - "execution_count": 131, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "", - "text/plain": [ - "
" - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "plt.plot(history.history['acc'])\n", - "plt.plot(history.history['val_acc'])\n", - "plt.legend(['acc', 'val_acc'])\n", - "plt.title(f'Accuracy')\n", - "plt.show()" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## model testing\n", - "let's test how well the model perform" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "WARNING:tensorflow:No training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" - ] - } - ], - "source": [ - "# reload\n", - "model = tf.keras.models.load_model('demo_models/emotions.keras', compile=False)" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "def predict_emotions(txt, threshold=0.5):\n", - " # recall it is multi-class so we need to get all prediction above a threshold (0.5)\n", - " preds = model(tf.constant([txt]))[0]\n", - " out = 0\n", - " for i in range(NUM_CLASSES):\n", - " if preds[i] > threshold:\n", - " emotion_name = CLASSES[i]\n", - " emotion_prob = round(float(preds[i]) * 100, 1)\n", - " print(f\"{emotion_name} ({emotion_prob})%\")\n", - " out += 1\n", - " if not out:\n", - " print(\"neutral\") \n", - " " - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "joy (91.6)%\n" - ] - } - ], - "source": [ - "txt = \"I enjoy having a good icecream\"\n", - "predict_emotions(txt)" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": ".env", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - 
"pygments_lexer": "ipython3", - "version": "3.10.8" - }, - "orig_nbformat": 4, - "vscode": { - "interpreter": { - "hash": "bc86f0786348ca6b89e1e790af95528ab7136d16a08a53d7be83a40ce5119309" - } - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} From 3e419a8c01fd62d1c309f5d3da18a0a6a37299a9 Mon Sep 17 00:00:00 2001 From: Marina Zhang <40069936+MarinaZhang@users.noreply.github.com> Date: Mon, 9 Oct 2023 16:24:31 -0700 Subject: [PATCH 4/4] Add train_tpu.ipynb notebook --- notebooks/train_tpu.ipynb | 551 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 551 insertions(+) create mode 100644 notebooks/train_tpu.ipynb diff --git a/notebooks/train_tpu.ipynb b/notebooks/train_tpu.ipynb new file mode 100644 index 0000000..fda4caf --- /dev/null +++ b/notebooks/train_tpu.ipynb @@ -0,0 +1,551 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "TPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "### Use RETVec on TPU\n", + "\n", + "You can run this notebook in Google Colab, where you can request a TPU.\n", + "\n", + "RETVec requires a slightly different setup on TPU, because TPUs do not support string tensors. Thus, we will split the default RETVecTokenizer layer into two parts -- one that converts strings into an integer representation which runs on CPU, and the remaining components of RETVec including the word embedding model which runs on TPU (as well as the rest of the model).\n", + "\n", + "We will use the same example as in the `train_retvec_model_tf.ipynb` notebook, where we train an emotion classifier." 
+ ], + "metadata": { + "id": "KCdhdgLc-xCP" + } + }, + { + "cell_type": "code", + "source": [ + "# installing needed dependencies\n", + "try:\n", + " import retvec\n", + "except ImportError:\n", + " !pip install retvec # is retvec installed?\n", + "\n", + "try:\n", + " import datasets\n", + "except ImportError:\n", + " !pip install datasets # used to get the dataset\n", + "\n", + "try:\n", + " import matplotlib\n", + "except ImportError:\n", + " !pip install matplotlib\n", + "\n", + "try:\n", + " import tensorflow_text\n", + "except ImportError:\n", + " !pip install tensorflow_text" + ], + "metadata": { + "id": "niDxQ-1x-mx9" + }, + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '1' # silence TF INFO messages\n", + "import tensorflow as tf\n", + "import numpy as np\n", + "from tensorflow.keras import layers\n", + "import tensorflow_text as text\n", + "from datasets import load_dataset\n", + "from matplotlib import pyplot as plt" + ], + "metadata": { + "id": "2Ovv0fvB-psR" + }, + "execution_count": 2, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Import RETVec layers we need for TPU\n", + "\n", + "Note that we do not import the RETVecTokenizer layer for TPU here, since TPUs do not support tf.string inputs." 
+ ], + "metadata": { + "id": "6oLggdAk_a2B" + } + }, + { + "cell_type": "code", + "source": [ + "from retvec.tf.layers import RETVecEmbedding, RETVecIntToBinary, RETVecIntegerizer" + ], + "metadata": { + "id": "fWwzCVs7-rgT" + }, + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Initialize TPU and TPU Strategy" + ], + "metadata": { + "id": "pRdgWcpXrliX" + } + }, + { + "cell_type": "code", + "source": [ + "resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu='')\n", + "tf.config.experimental_connect_to_cluster(resolver)\n", + "# This is the TPU initialization code that has to be at the beginning.\n", + "tf.tpu.experimental.initialize_tpu_system(resolver)\n", + "print(\"All devices: \", tf.config.list_logical_devices('TPU'))\n", + "\n", + "strategy = tf.distribute.TPUStrategy(resolver)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Jrf9ViW6rkmE", + "outputId": "42c04ce9-b4de-47ae-d4d2-26c66cda78c5" + }, + "execution_count": 4, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "All devices: [LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:0', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:1', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:2', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:3', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:4', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:5', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:6', device_type='TPU'), LogicalDevice(name='/job:worker/replica:0/task:0/device:TPU:7', device_type='TPU')]\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Create dataset\n", + "\n", + "We are going to use the [Go Emotion](https://huggingface.co/datasets/go_emotions) dataset 
to create a multi-label emotion classifier. https://ai.googleblog.com/2021/10/goemotions-dataset-for-fine-grained.html" + ], + "metadata": { + "id": "F9Rvip4S_wCO" + } + }, + { + "cell_type": "code", + "source": [ + "# downloading dataset\n", + "dataset = load_dataset('go_emotions')" + ], + "metadata": { + "id": "jxYGjVIR_6S6" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# get class name mapping and number of classes\n", + "CLASSES = dataset['train'].features['labels'].feature.names\n", + "NUM_CLASSES = len(CLASSES)\n", + "print(f\"num classes {NUM_CLASSES}\")\n", + "print(CLASSES)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "CJDELoVj_6WE", + "outputId": "0d027bb4-398a-49be-b891-a43167650e52" + }, + "execution_count": 6, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "num classes 28\n", + "['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief', 'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization', 'relief', 'remorse', 'sadness', 'surprise', 'neutral']\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# preparing data\n", + "x_train = tf.constant(dataset['train']['text'], dtype=tf.string)\n", + "\n", + "# labels are multi-hot encoded because each example can carry several emotions.\n", + "y_train = np.zeros((len(x_train),NUM_CLASSES))\n", + "for idx, ex in enumerate(dataset['train']['labels']):\n", + " for val in ex:\n", + " y_train[idx][val] = 1\n", + "\n", + "# test data\n", + "x_test = tf.constant(dataset['test']['text'], dtype=tf.string)\n", + "y_test = np.zeros((len(x_test),NUM_CLASSES))\n", + "for idx, ex in enumerate(dataset['test']['labels']):\n", + " for val in ex:\n", + " y_test[idx][val] = 1" + ], + "metadata": { + "id": "UDAQ-jsy_6YZ" + }, + 
"execution_count": 14, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "### Pre-process dataset on CPU\n", + "\n", + "We use the `RETVecIntegerizer` class to convert the tf.strings input into integer tensors (on CPU). By default, the layer converts each word into it's UTF-8 codepoints (with max 16 characters per word).\n", + "\n", + "The `RETVecIntegerizer` has a method `integerize` which will encode a string tensor into its integer representation." + ], + "metadata": { + "id": "a7CZv1PUAIPg" + } + }, + { + "cell_type": "code", + "source": [ + "sequence_length = 128 # number of words per text, inputs will be padded or truncated to this length\n", + "word_length = 16 # max characters per word\n", + "\n", + "# initialize whitespace tokenizer\n", + "whitespace_tokenizer = text.WhitespaceTokenizer()\n", + "\n", + "# initialize RETVec integerizer\n", + "integerizer = RETVecIntegerizer()" + ], + "metadata": { + "id": "0lUZoLh1AV0Y" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# split text in dataset into words\n", + "x_train = whitespace_tokenizer.tokenize(x_train)\n", + "x_test = whitespace_tokenizer.tokenize(x_test)\n", + "\n", + "# convert from ragged tensor to tensor with pad/truncation to sequence_length\n", + "x_train = x_train.to_tensor(default_value=\"\", shape=(x_train.shape[0], sequence_length))\n", + "x_test = x_test.to_tensor(default_value=\"\", shape=(x_test.shape[0], sequence_length))\n", + "\n", + "# encode each word into their integer representation\n", + "x_train = integerizer.integerize(x_train)\n", + "x_test = integerizer.integerize(x_test)" + ], + "metadata": { + "id": "p4lhTH4vrhYC" + }, + "execution_count": 16, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print('train input shape:', x_train.shape)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "gw7QEnPKA0ZL", + "outputId": "b87bf75b-61f6-4d7b-afd8-1697b8873453" + }, + 
"execution_count": 17, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "train input shape: (43410, 128, 16)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "### Create Model\n", + "\n", + "Now, we can create the model with RETVec which will run on TPU. To do this, we use the layers `RETVecIntToBinary` which will binarize the UTF-8 codepoints, and `RETVecEmbedding` which will call the embedding model and produce 256-dim float embeddings for each word. Then, we can build the rest of the model like usual and train." + ], + "metadata": { + "id": "EnAfdnUb__pt" + } + }, + { + "cell_type": "code", + "source": [ + "retvec_model_dir = 'gs://tensorflow/keras-applications/retvec-v1/' # currently colab TPU only supports gcs paths\n", + "batch_size = 256\n", + "epochs = 25\n", + "\n", + "# use TPU\n", + "with strategy.scope():\n", + " # input is int32 tensor with shape (sequence_length, word_length)\n", + " inputs = tf.keras.layers.Input(shape=(sequence_length, word_length), dtype=tf.int32)\n", + " x = RETVecIntToBinary(sequence_length=sequence_length, word_length=word_length)(inputs)\n", + " x = RETVecEmbedding(model=retvec_model_dir)(x)\n", + "\n", + " # build the rest of the model\n", + " x = layers.Bidirectional(layers.LSTM(64, return_sequences=True))(x)\n", + " x = layers.Bidirectional(layers.LSTM(64))(x)\n", + " outputs = layers.Dense(NUM_CLASSES, activation='sigmoid')(x)\n", + " model = tf.keras.Model(inputs, outputs)\n", + " model.summary()\n", + "\n", + " model.compile('adam', 'binary_crossentropy', ['acc'])\n", + " history = model.fit(x_train, y_train, epochs=epochs, batch_size=batch_size,\n", + " validation_data=(x_test, y_test))" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "LUzjZXZ--e5H", + "outputId": "0e19b6e4-b88a-40d9-8444-3adf5cecaa6b" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "WARNING:tensorflow:No 
training configuration found in save file, so the model was *not* compiled. Compile it manually.\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Model: \"model_1\"\n", + "_________________________________________________________________\n", + " Layer (type) Output Shape Param # \n", + "=================================================================\n", + " input_2 (InputLayer) [(None, 128, 16)] 0 \n", + " \n", + " ret_vec_int_to_binary_1 (R (None, 128, 16, 24) 0 \n", + " ETVecIntToBinary) \n", + " \n", + " ret_vec_embedding_1 (RETVe (None, 128, 256) 230144 \n", + " cEmbedding) \n", + " \n", + " bidirectional_2 (Bidirecti (None, 128, 128) 164352 \n", + " onal) \n", + " \n", + " bidirectional_3 (Bidirecti (None, 128) 98816 \n", + " onal) \n", + " \n", + " dense_1 (Dense) (None, 28) 3612 \n", + " \n", + "=================================================================\n", + "Total params: 496924 (1.90 MB)\n", + "Trainable params: 266780 (1.02 MB)\n", + "Non-trainable params: 230144 (899.00 KB)\n", + "_________________________________________________________________\n", + "Epoch 1/25\n", + "170/170 [==============================] - 42s 121ms/step - loss: 0.1708 - acc: 0.2827 - val_loss: 0.1474 - val_acc: 0.2959\n", + "Epoch 2/25\n", + "170/170 [==============================] - 11s 64ms/step - loss: 0.1478 - acc: 0.2955 - val_loss: 0.1443 - val_acc: 0.2959\n", + "Epoch 3/25\n", + "170/170 [==============================] - 11s 68ms/step - loss: 0.1420 - acc: 0.3297 - val_loss: 0.1360 - val_acc: 0.3575\n", + "Epoch 4/25\n", + "170/170 [==============================] - 11s 62ms/step - loss: 0.1354 - acc: 0.3607 - val_loss: 0.1314 - val_acc: 0.3717\n", + "Epoch 5/25\n", + "170/170 [==============================] - 11s 63ms/step - loss: 0.1315 - acc: 0.3757 - val_loss: 0.1273 - val_acc: 0.3893\n", + "Epoch 6/25\n", + "170/170 [==============================] - 11s 67ms/step - loss: 0.1277 - acc: 0.3958 - val_loss: 0.1245 - val_acc: 
0.4024\n", + "Epoch 7/25\n", + "170/170 [==============================] - 11s 62ms/step - loss: 0.1243 - acc: 0.4169 - val_loss: 0.1205 - val_acc: 0.4315\n", + "Epoch 8/25\n", + "170/170 [==============================] - 11s 63ms/step - loss: 0.1206 - acc: 0.4334 - val_loss: 0.1178 - val_acc: 0.4468\n", + "Epoch 9/25\n", + "170/170 [==============================] - 11s 64ms/step - loss: 0.1179 - acc: 0.4447 - val_loss: 0.1156 - val_acc: 0.4516\n", + "Epoch 10/25\n", + "170/170 [==============================] - 11s 64ms/step - loss: 0.1157 - acc: 0.4528 - val_loss: 0.1136 - val_acc: 0.4655\n", + "Epoch 11/25\n", + "170/170 [==============================] - 11s 63ms/step - loss: 0.1138 - acc: 0.4591 - val_loss: 0.1126 - val_acc: 0.4608\n", + "Epoch 12/25\n", + "170/170 [==============================] - 12s 73ms/step - loss: 0.1115 - acc: 0.4691 - val_loss: 0.1102 - val_acc: 0.4691\n", + "Epoch 13/25\n", + "170/170 [==============================] - 11s 64ms/step - loss: 0.1097 - acc: 0.4768 - val_loss: 0.1098 - val_acc: 0.4728\n", + "Epoch 14/25\n", + "170/170 [==============================] - 11s 63ms/step - loss: 0.1081 - acc: 0.4831 - val_loss: 0.1085 - val_acc: 0.4806\n", + "Epoch 15/25\n", + "170/170 [==============================] - 11s 65ms/step - loss: 0.1064 - acc: 0.4896 - val_loss: 0.1068 - val_acc: 0.4854\n", + "Epoch 16/25\n", + "170/170 [==============================] - 11s 63ms/step - loss: 0.1052 - acc: 0.4951 - val_loss: 0.1067 - val_acc: 0.4824\n", + "Epoch 17/25\n", + "170/170 [==============================] - 11s 64ms/step - loss: 0.1037 - acc: 0.4998 - val_loss: 0.1065 - val_acc: 0.4822\n", + "Epoch 18/25\n", + "170/170 [==============================] - 11s 65ms/step - loss: 0.1022 - acc: 0.5053 - val_loss: 0.1046 - val_acc: 0.4918\n", + "Epoch 19/25\n", + "170/170 [==============================] - 11s 63ms/step - loss: 0.1010 - acc: 0.5102 - val_loss: 0.1047 - val_acc: 0.4911\n", + "Epoch 20/25\n", + "170/170 
[==============================] - 11s 66ms/step - loss: 0.0997 - acc: 0.5140 - val_loss: 0.1033 - val_acc: 0.4973\n", + "Epoch 21/25\n", + "170/170 [==============================] - 11s 65ms/step - loss: 0.0987 - acc: 0.5198 - val_loss: 0.1051 - val_acc: 0.4924\n", + "Epoch 22/25\n", + "170/170 [==============================] - 11s 62ms/step - loss: 0.0978 - acc: 0.5210 - val_loss: 0.1030 - val_acc: 0.5027\n", + "Epoch 23/25\n", + "170/170 [==============================] - 12s 69ms/step - loss: 0.0968 - acc: 0.5266 - val_loss: 0.1023 - val_acc: 0.5001\n", + "Epoch 24/25\n", + "170/170 [==============================] - 11s 67ms/step - loss: 0.0954 - acc: 0.5309 - val_loss: 0.1021 - val_acc: 0.5045\n", + "Epoch 25/25\n", + "170/170 [==============================] - 11s 62ms/step - loss: 0.0948 - acc: 0.5334 - val_loss: 0.1025 - val_acc: 0.5025\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# visualize the training curves\n", + "plt.plot(history.history['acc'])\n", + "plt.plot(history.history['val_acc'])\n", + "plt.legend(['acc', 'val_acc'])\n", + "plt.title(f'Accuracy')\n", + "plt.show()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 452 + }, + "id": "fLYBe3Yvsdqq", + "outputId": "9199edc4-7c8e-464b-9a6c-7a8e0610930e" + }, + "execution_count": 20, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Test Model\n", + "\n", + "Let's test our model on some examples, noting that we have to preprocess our inputs before feeding them to the model." + ], + "metadata": { + "id": "qRnwfkjatMOV" + } + }, + { + "cell_type": "code", + "source": [ + "def preprocess_text(text):\n", + " # we need to tokenize our inputs and convert them into integer codepoints\n", + " # before passing them to the model\n", + " text = tf.constant([txt])\n", + " text = whitespace_tokenizer.tokenize(text)\n", + " text = text.to_tensor(default_value=\"\", shape=(text.shape[0], sequence_length))\n", + " text = integerizer.integerize(text)\n", + " return text\n", + "\n", + "def predict_emotions(txt, threshold=0.5):\n", + " # recall it is multi-class so we need to get all prediction above a threshold (0.5)\n", + " preds = model(preprocess_text(text))[0]\n", + " out = 0\n", + " for i in range(NUM_CLASSES):\n", + " if preds[i] > threshold:\n", + " emotion_name = CLASSES[i]\n", + " emotion_prob = round(float(preds[i]) * 100, 1)\n", + " print(f\"{emotion_name} ({emotion_prob})%\")\n", + " out += 1\n", + " if not out:\n", + " print(\"neutral\")" + ], + "metadata": { + "id": "So39n-nKrd_K" + }, + "execution_count": 22, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "txt = \"I enjoy having a good icecream.\"\n", + "predict_emotions(txt)" + ], + "metadata": { + "id": "isWJPbWf-h4R", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "d7fe825c-0ee1-412f-9919-64f35ef9c718" + }, + "execution_count": 23, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "joy (84.5)%\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# the model works even with typos, substitutions, and emojis!\n", + "txt = \"I enjoy hving a g00d ic3cream!!! 
🍦\"\n", + "predict_emotions(txt)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "twyBz-_lvGyb", + "outputId": "e79754c3-e414-425d-be6b-7d68b2eaadb2" + }, + "execution_count": 24, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "joy (85.9)%\n" + ] + } + ] + } + ] +} \ No newline at end of file