From f9aebfad32a1d8eba62348f9cdf4ffdf91bbf19b Mon Sep 17 00:00:00 2001
From: Matt Watson
Date: Sat, 4 Nov 2023 00:57:53 +0000
Subject: [PATCH] Add multi-backend
 examples/generative/text_generation_with_miniature_gpt

---
 .../text_generation_with_miniature_gpt.ipynb  |  82 +++++---
 .../md/text_generation_with_miniature_gpt.md  | 198 ++++++++++--------
 .../text_generation_with_miniature_gpt.py     |  70 ++++---
 3 files changed, 204 insertions(+), 146 deletions(-)

diff --git a/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb b/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb
index 2e4e4d2d89..a93db1a026 100644
--- a/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb
+++ b/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb
@@ -58,14 +58,30 @@
    },
    "outputs": [],
    "source": [
-    "import tensorflow as tf\n",
-    "from tensorflow import keras\n",
-    "from tensorflow.keras import layers\n",
-    "from tensorflow.keras.layers import TextVectorization\n",
+    "# We set the backend to TensorFlow. The code works with\n",
+    "# both `tensorflow` and `torch`. It does not work with JAX\n",
+    "# due to the behavior of `jax.numpy.tile` in a jit scope\n",
+    "# (used in `causal_attention_mask()`): `tile` in JAX does\n",
+    "# not support a dynamic `reps` argument.\n",
+    "# You can make the code work in JAX by wrapping the\n",
+    "# body of the `causal_attention_mask` function in a\n",
+    "# context manager that forces compile-time evaluation:\n",
+    "# `with jax.ensure_compile_time_eval():`.\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+    "\n",
+    "import keras\n",
+    "from keras import layers\n",
+    "from keras import ops\n",
+    "from keras.layers import TextVectorization\n",
     "import numpy as np\n",
     "import os\n",
     "import string\n",
     "import random\n",
+    "import tensorflow\n",
+    "import tensorflow.data as tf_data\n",
+    "import tensorflow.strings as tf_strings\n",
     ""
    ]
   },
@@ -93,15 +109,15 @@
     "    This prevents flow of information from future tokens to current token.\n",
     "    1's in the lower triangle, counting from the lower right corner.\n",
     "    \"\"\"\n",
-    "    i = tf.range(n_dest)[:, None]\n",
-    "    j = tf.range(n_src)\n",
+    "    i = ops.arange(n_dest)[:, None]\n",
+    "    j = ops.arange(n_src)\n",
     "    m = i >= j - n_src + n_dest\n",
-    "    mask = tf.cast(m, dtype)\n",
-    "    mask = tf.reshape(mask, [1, n_dest, n_src])\n",
-    "    mult = tf.concat(\n",
-    "        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0\n",
+    "    mask = ops.cast(m, dtype)\n",
+    "    mask = ops.reshape(mask, [1, n_dest, n_src])\n",
+    "    mult = ops.concatenate(\n",
+    "        [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0\n",
     "    )\n",
-    "    return tf.tile(mask, mult)\n",
+    "    return ops.tile(mask, mult)\n",
     "\n",
     "\n",
     "class TransformerBlock(layers.Layer):\n",
@@ -109,7 +125,10 @@
     "        super().__init__()\n",
     "        self.att = layers.MultiHeadAttention(num_heads, embed_dim)\n",
     "        self.ffn = keras.Sequential(\n",
-    "            [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
+    "            [\n",
+    "                layers.Dense(ff_dim, activation=\"relu\"),\n",
+    "                layers.Dense(embed_dim),\n",
+    "            ]\n",
     "        )\n",
     "        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
     "        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
@@ -117,10 +136,10 @@
     "        self.dropout2 = layers.Dropout(rate)\n",
     "\n",
     "    def call(self, inputs):\n",
-    "        input_shape = tf.shape(inputs)\n",
+    "        input_shape = ops.shape(inputs)\n",
     "        batch_size = input_shape[0]\n",
     "        seq_len = input_shape[1]\n",
-    "        causal_mask = 
causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)\n", + " causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, \"bool\")\n", " attention_output = self.att(inputs, inputs, attention_mask=causal_mask)\n", " attention_output = self.dropout1(attention_output)\n", " out1 = self.layernorm1(inputs + attention_output)\n", @@ -158,8 +177,8 @@ " self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n", "\n", " def call(self, x):\n", - " maxlen = tf.shape(x)[-1]\n", - " positions = tf.range(start=0, limit=maxlen, delta=1)\n", + " maxlen = ops.shape(x)[-1]\n", + " positions = ops.arange(0, maxlen, 1)\n", " positions = self.pos_emb(positions)\n", " x = self.token_emb(x)\n", " return x + positions\n", @@ -191,16 +210,17 @@ "\n", "\n", "def create_model():\n", - " inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)\n", + " inputs = layers.Input(shape=(maxlen,), dtype=\"int32\")\n", " embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)\n", " x = embedding_layer(inputs)\n", " transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)\n", " x = transformer_block(x)\n", " outputs = layers.Dense(vocab_size)(x)\n", " model = keras.Model(inputs=inputs, outputs=[outputs, x])\n", - " loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", + " loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n", " model.compile(\n", - " \"adam\", loss=[loss_fn, None],\n", + " \"adam\",\n", + " loss=[loss_fn, None],\n", " ) # No loss and optimization based on word embeddings from transformer block\n", " return model\n", "" @@ -259,16 +279,16 @@ "\n", "# Create a dataset from text files\n", "random.shuffle(filenames)\n", - "text_ds = tf.data.TextLineDataset(filenames)\n", + "text_ds = tf_data.TextLineDataset(filenames)\n", "text_ds = text_ds.shuffle(buffer_size=256)\n", "text_ds = text_ds.batch(batch_size)\n", "\n", "\n", "def custom_standardization(input_string):\n", - " \"\"\" Remove html line-break tags and handle punctuation \"\"\"\n", - " lowercased = tf.strings.lower(input_string)\n", - " stripped_html = tf.strings.regex_replace(lowercased, \"
\", \" \")\n", - " return tf.strings.regex_replace(stripped_html, f\"([{string.punctuation}])\", r\" \\1\")\n", + " \"\"\"Remove html line-break tags and handle punctuation\"\"\"\n", + " lowercased = tf_strings.lower(input_string)\n", + " stripped_html = tf_strings.regex_replace(lowercased, \"
\", \" \")\n", + " return tf_strings.regex_replace(stripped_html, f\"([{string.punctuation}])\", r\" \\1\")\n", "\n", "\n", "# Create a vectorization layer and adapt it to the text\n", @@ -288,15 +308,15 @@ " word at position (i+1). The model will use all words up till position (i)\n", " to predict the next word.\n", " \"\"\"\n", - " text = tf.expand_dims(text, -1)\n", + " text = tensorflow.expand_dims(text, -1)\n", " tokenized_sentences = vectorize_layer(text)\n", " x = tokenized_sentences[:, :-1]\n", " y = tokenized_sentences[:, 1:]\n", " return x, y\n", "\n", "\n", - "text_ds = text_ds.map(prepare_lm_inputs_labels)\n", - "text_ds = text_ds.prefetch(tf.data.AUTOTUNE)\n", + "text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)\n", + "text_ds = text_ds.prefetch(tf_data.AUTOTUNE)\n", "" ] }, @@ -342,9 +362,9 @@ " self.k = top_k\n", "\n", " def sample_from(self, logits):\n", - " logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)\n", + " logits, indices = ops.top_k(logits, k=self.k, sorted=True)\n", " indices = np.asarray(indices).astype(\"int32\")\n", - " preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]\n", + " preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]\n", " preds = np.asarray(preds).astype(\"float32\")\n", " return np.random.choice(indices, p=preds)\n", "\n", @@ -368,7 +388,7 @@ " else:\n", " x = start_tokens\n", " x = np.array([x])\n", - " y, _ = self.model.predict(x)\n", + " y, _ = self.model.predict(x, verbose=0)\n", " sample_token = self.sample_from(y[0][sample_index])\n", " tokens_generated.append(sample_token)\n", " start_tokens.append(sample_token)\n", @@ -445,4 +465,4 @@ }, "nbformat": 4, "nbformat_minor": 0 -} +} \ No newline at end of file diff --git a/examples/generative/md/text_generation_with_miniature_gpt.md b/examples/generative/md/text_generation_with_miniature_gpt.md index 9425664cfc..78688a9c0b 100644 --- a/examples/generative/md/text_generation_with_miniature_gpt.md +++ b/examples/generative/md/text_generation_with_miniature_gpt.md @@ -36,14 +36,30 @@ with TensorFlow 2.3 or higher. ```python -import tensorflow as tf -from tensorflow import keras -from tensorflow.keras import layers -from tensorflow.keras.layers import TextVectorization +# We set the backend to TensorFlow. The code works with +# both `tensorflow` and `torch`. It does not work with JAX +# due to the behavior of `jax.numpy.tile` in a jit scope +# (used in `causal_attention_mask()`: `tile` in JAX does +# not support a dynamic `reps` argument. +# You can make the code work in JAX by wrapping the +# inside of the `causal_attention_mask` function in +# a decorator to prevent jit compilation: +# `with jax.ensure_compile_time_eval():`. +import os + +os.environ["KERAS_BACKEND"] = "tensorflow" + +import keras +from keras import layers +from keras import ops +from keras.layers import TextVectorization import numpy as np import os import string import random +import tensorflow +import tensorflow.data as tf_data +import tensorflow.strings as tf_strings ``` @@ -59,15 +75,15 @@ def causal_attention_mask(batch_size, n_dest, n_src, dtype): This prevents flow of information from future tokens to current token. 1's in the lower triangle, counting from the lower right corner. 
""" - i = tf.range(n_dest)[:, None] - j = tf.range(n_src) + i = ops.arange(n_dest)[:, None] + j = ops.arange(n_src) m = i >= j - n_src + n_dest - mask = tf.cast(m, dtype) - mask = tf.reshape(mask, [1, n_dest, n_src]) - mult = tf.concat( - [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0 + mask = ops.cast(m, dtype) + mask = ops.reshape(mask, [1, n_dest, n_src]) + mult = ops.concatenate( + [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0 ) - return tf.tile(mask, mult) + return ops.tile(mask, mult) class TransformerBlock(layers.Layer): @@ -75,7 +91,10 @@ class TransformerBlock(layers.Layer): super().__init__() self.att = layers.MultiHeadAttention(num_heads, embed_dim) self.ffn = keras.Sequential( - [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),] + [ + layers.Dense(ff_dim, activation="relu"), + layers.Dense(embed_dim), + ] ) self.layernorm1 = layers.LayerNormalization(epsilon=1e-6) self.layernorm2 = layers.LayerNormalization(epsilon=1e-6) @@ -83,10 +102,10 @@ class TransformerBlock(layers.Layer): self.dropout2 = layers.Dropout(rate) def call(self, inputs): - input_shape = tf.shape(inputs) + input_shape = ops.shape(inputs) batch_size = input_shape[0] seq_len = input_shape[1] - causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool) + causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool") attention_output = self.att(inputs, inputs, attention_mask=causal_mask) attention_output = self.dropout1(attention_output) out1 = self.layernorm1(inputs + attention_output) @@ -112,8 +131,8 @@ class TokenAndPositionEmbedding(layers.Layer): self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim) def call(self, x): - maxlen = tf.shape(x)[-1] - positions = tf.range(start=0, limit=maxlen, delta=1) + maxlen = ops.shape(x)[-1] + positions = ops.arange(0, maxlen, 1) positions = self.pos_emb(positions) x = self.token_emb(x) return x + positions @@ -133,16 +152,17 @@ feed_forward_dim = 256 # Hidden layer size in feed forward network inside trans def create_model(): - inputs = layers.Input(shape=(maxlen,), dtype=tf.int32) + inputs = layers.Input(shape=(maxlen,), dtype="int32") embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim) x = embedding_layer(inputs) transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim) x = transformer_block(x) outputs = layers.Dense(vocab_size)(x) model = keras.Model(inputs=inputs, outputs=[outputs, x]) - loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) model.compile( - "adam", loss=[loss_fn, None], + "adam", + loss=[loss_fn, None], ) # No loss and optimization based on word embeddings from transformer block return model @@ -182,16 +202,16 @@ print(f"{len(filenames)} files") # Create a dataset from text files random.shuffle(filenames) -text_ds = tf.data.TextLineDataset(filenames) +text_ds = tf_data.TextLineDataset(filenames) text_ds = text_ds.shuffle(buffer_size=256) text_ds = text_ds.batch(batch_size) def custom_standardization(input_string): - """ Remove html line-break tags and handle punctuation """ - lowercased = tf.strings.lower(input_string) - stripped_html = tf.strings.regex_replace(lowercased, "
", " ") - return tf.strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1") + """Remove html line-break tags and handle punctuation""" + lowercased = tf_strings.lower(input_string) + stripped_html = tf_strings.regex_replace(lowercased, "
", " ") + return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1") # Create a vectorization layer and adapt it to the text @@ -211,22 +231,22 @@ def prepare_lm_inputs_labels(text): word at position (i+1). The model will use all words up till position (i) to predict the next word. """ - text = tf.expand_dims(text, -1) + text = tensorflow.expand_dims(text, -1) tokenized_sentences = vectorize_layer(text) x = tokenized_sentences[:, :-1] y = tokenized_sentences[:, 1:] return x, y -text_ds = text_ds.map(prepare_lm_inputs_labels) -text_ds = text_ds.prefetch(tf.data.AUTOTUNE) +text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE) +text_ds = text_ds.prefetch(tf_data.AUTOTUNE) ```
``` % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed -100 80.2M 100 80.2M 0 0 24.2M 0 0:00:03 0:00:03 --:--:-- 24.2M +100 80.2M 100 80.2M 0 0 7926k 0 0:00:10 0:00:10 --:--:-- 7661k 50000 files @@ -262,9 +282,9 @@ class TextGenerator(keras.callbacks.Callback): self.k = top_k def sample_from(self, logits): - logits, indices = tf.math.top_k(logits, k=self.k, sorted=True) + logits, indices = ops.top_k(logits, k=self.k, sorted=True) indices = np.asarray(indices).astype("int32") - preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0] + preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0] preds = np.asarray(preds).astype("float32") return np.random.choice(indices, p=preds) @@ -288,7 +308,7 @@ class TextGenerator(keras.callbacks.Callback): else: x = start_tokens x = np.array([x]) - y, _ = self.model.predict(x) + y, _ = self.model.predict(x, verbose=0) sample_token = self.sample_from(y[0][sample_index]) tokens_generated.append(sample_token) start_tokens.append(sample_token) @@ -326,236 +346,238 @@ model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
``` Epoch 1/25 -391/391 - 135s - loss: 5.5949 - dense_2_loss: 5.5949 + +WARNING: All log messages before absl::InitializeLog() is called are written to STDERR +I0000 00:00:1699499022.078758 633491 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process. +/home/mattdangerw/miniconda3/envs/keras-tensorflow/lib/python3.10/contextlib.py:153: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset. + self.gen.throw(typ, value, traceback) + generated text: -this movie is a great movie . the film is so many other comments . the plot and some people were [UNK] to do . i think the story is about that it is not a good movie . there are very good actors +this movie is a good example of the [UNK] " movies , and the movie was pretty well written , i had to say that the movie made me of the [UNK] " and was very well done . i 've seen a few ```
``` +391/391 - 33s - 84ms/step - loss: 5.4696 Epoch 2/25 -391/391 - 135s - loss: 4.7108 - dense_2_loss: 4.7108 generated text: -this movie is one of the worst movies i have ever seen . i have no doubt the better movies of this one 's worst movies i have ever seen . i don 't know what the hell , and i 'm not going +this movie is so far the worst movies i have ever seen . it is that it just a bit of a movie but i really don 't think it is a very bad movie . it is a lot and the characters in ```
``` +391/391 - 16s - 42ms/step - loss: 4.7016 Epoch 3/25 -391/391 - 135s - loss: 4.4620 - dense_2_loss: 4.4620 generated text: -this movie is a very good movie , i think i am not a kid . the story is a great movie . the director who is a great director who likes the director 's film . this was not funny and the director +this movie is a classic and has a good cast in a good story . the movie itself is good at best . the acting is superb , but the story is a little bit slow , the music hall , and music is ```
``` +391/391 - 16s - 42ms/step - loss: 4.4533 Epoch 4/25 -391/391 - 136s - loss: 4.3047 - dense_2_loss: 4.3047 generated text: -this movie is a very good story and very well . this movie is one of the worst movies i have ever seen , and there are some good actors and actresses in the movie , it is not the worst . the script +this movie is a good , and is not the greatest movie ever since , the director has a lot of [UNK] , but it 's just a bit of the original and the plot has some decent acting , it has a bit ```
``` +391/391 - 16s - 42ms/step - loss: 4.2985 Epoch 5/25 -391/391 - 135s - loss: 4.1840 - dense_2_loss: 4.1840 generated text: -this movie is a very good movie . it is the best thing about it 's a very good movie . it 's not funny , very , it 's so bad that it 's so funny , it 's like most romantic movie +this movie is really bad , the acting in this movie is bad and bad . it 's not bad it . it 's a bad story about a bad film but i can 't imagine if it 's a bad ending . the ```
``` +391/391 - 17s - 42ms/step - loss: 4.1787 Epoch 6/25 -391/391 - 135s - loss: 4.0834 - dense_2_loss: 4.0834 generated text: -this movie is the worst . the acting is awful . i have to admit that you 're just watching this film as i have to say that it is a [UNK] with [UNK] [UNK] " in the last ten years . i think +this movie is so bad , the bad acting , everything is awful , the script is bad , and the only one that i just saw in the original [UNK] . i was hoping it could make up the sequel . it wasn ```
``` +391/391 - 17s - 42ms/step - loss: 4.0807 Epoch 7/25 -391/391 - 135s - loss: 3.9987 - dense_2_loss: 3.9987 generated text: -this movie is really about the acting is good and the script . i don 't think this is just a waste of movie . it was so terrible that it wasn 't funny , but that 's what it was made in movies +this movie is one of the best kung fu movies i 've ever seen , i have seen in my life that 's not for anyone who has to think of it , or maybe even though , i can 't find it funny ```
``` +391/391 - 16s - 42ms/step - loss: 3.9978 Epoch 8/25 -391/391 - 134s - loss: 3.9242 - dense_2_loss: 3.9242 generated text: -this movie is so bad . the story itself is about a family guy named jack , who is told by a father , who is trying to get to help him to commit . he has the same problem and the [UNK] . +this movie is just plain boring . . . . . . . . . . . . . . . . . [UNK] , the movie [UNK] . . . [UNK] . . . . . . [UNK] is a bad , it ```
``` +391/391 - 17s - 42ms/step - loss: 3.9236 Epoch 9/25 -391/391 - 135s - loss: 3.8579 - dense_2_loss: 3.8579 generated text: -this movie is not bad , it does not deserve one . i can say that i was able to sit at , relax [UNK] . i was wrong , and i think i was able to buy the dvd , i would say +this movie is the only good movie i think i 've never seen it again . but it 's the only thing i feel about it . the story was about the fact that it was a very good movie . the movie has ```
``` +391/391 - 17s - 42ms/step - loss: 3.8586 Epoch 10/25 -391/391 - 134s - loss: 3.7989 - dense_2_loss: 3.7989 generated text: -this movie is very funny ! its very funny . a touching movie about three women who don 't know who is not to go on with a movie that has a lot of fun to watch . it is funny . the main +this movie is very well written and directed . it contains some of the best [UNK] in the genre . the story is about a group of actors , especially jamie harris and danny glover who are the only good guys that is really ```
``` +391/391 - 17s - 42ms/step - loss: 3.8002 Epoch 11/25 -391/391 - 134s - loss: 3.7459 - dense_2_loss: 3.7459 generated text: -this movie is not the best movie i 've seen in a long time . this movie was just about a guy who gets killed for one . . i saw this movie at a time when i first saw it in the movie +this movie is so terrible . i think that the movie isn 't as bad as you should go and watch it again . there were so many clichés that it 's a very bad movie in itself . there is no story line ```
``` +391/391 - 17s - 42ms/step - loss: 3.7478 Epoch 12/25 -391/391 - 134s - loss: 3.6974 - dense_2_loss: 3.6974 generated text: -this movie is a good example of how many films have seen and many films , that are often overlooked , in the seventies , in fact it is more enjoyable than the average viewer has some interesting parallels . this movie is based +this movie is a total waste of money and money . i am surprised to find it very funny , very enjoyable . the plot is totally unbelievable , the acting is horrible . the story is awful , it 's not scary at ```
``` +391/391 - 17s - 42ms/step - loss: 3.6993 Epoch 13/25 -391/391 - 134s - loss: 3.6534 - dense_2_loss: 3.6534 generated text: -this movie is so bad ! i think this is one . i really didn 't think anybody who gets the impression that the people who is trying to find themselves to be funny . . there 's the humor is no punchline ? +this movie is so bad and not very good as it goes . it 's a nice movie and it 's so bad that it takes you back on your tv . i don 't really know how bad this movie is . you ```
``` +391/391 - 17s - 42ms/step - loss: 3.6546 Epoch 14/25 -391/391 - 134s - loss: 3.6123 - dense_2_loss: 3.6123 generated text: -this movie is really bad . the actors are good ,the acting is great . a must see [UNK] the worst in history of all time . the plot is so bad that you can 't even make a bad movie about the bad +this movie is a great fun story , with lots of action , and romance . if you like the action and the story is really bad . it doesn 't get the idea , but you have your heart of darkness . the ```
``` +391/391 - 17s - 42ms/step - loss: 3.6147 Epoch 15/25 -391/391 - 134s - loss: 3.5745 - dense_2_loss: 3.5745 generated text: -this movie is one of the worst movies i 've ever had . the acting and direction are terrible . what i 've seen , i 've watched it several times , and i can 't really believe how to make a movie about +this movie is a little more than a horror film . it 's not really a great deal , i can honestly say , a story about a group of teens that are all over the place . but this is still a fun ```
``` +391/391 - 17s - 42ms/step - loss: 3.5769 Epoch 16/25 -391/391 - 134s - loss: 3.5404 - dense_2_loss: 3.5404 generated text: -this movie is so bad it is . that it is supposed to be a comedy . the script , which is just as bad as some movies are bad . if you 're looking for it , if you 're in the mood +this movie is just about a guy who is supposed to be a girl in the [UNK] of a movie that doesn 't make sense . the humor is not to watch it all the way the movie is . you can 't know ```
``` +391/391 - 17s - 42ms/step - loss: 3.5425 Epoch 17/25 -391/391 - 134s - loss: 3.5083 - dense_2_loss: 3.5083 generated text: -this movie is one of all bad movies i have a fan ever seen . i have seen a good movies , this isn 't the worst . i 've seen in a long time . the story involves twins , a priest and +this movie is one of the best movies i 've ever seen . i was really surprised when renting it and it wasn 't even better in it , it was not even funny and i really don 't really know what i was ```
``` +391/391 - 17s - 42ms/step - loss: 3.5099 Epoch 18/25 -391/391 - 134s - loss: 3.4789 - dense_2_loss: 3.4789 generated text: -this movie is a great movie . it 's a shame that it was hard to see that it was . this movie is a good movie . the movie itself is a complete waste of time and time you have a bad rant +this movie is so bad . i think it 's a bit overrated . i have a lot of bad movies . i have to say that this movie was just bad . i was hoping the [UNK] . the [UNK] is good " ```
``` +391/391 - 17s - 43ms/step - loss: 3.4800 Epoch 19/25 -391/391 - 134s - loss: 3.4513 - dense_2_loss: 3.4513 generated text: -this movie is not one of the most moving movies i have ever seen . the story is about the plot is just so ridiculous that i could have done it with the actors . the actors are great and the acting is great +this movie is one of the best kung fu movies i 've ever seen . it was a great movie , and for the music . the graphics are really cool . it 's like a lot more than the action scenes and action ```
``` +391/391 - 17s - 42ms/step - loss: 3.4520 Epoch 20/25 -391/391 - 134s - loss: 3.4251 - dense_2_loss: 3.4251 generated text: -this movie is about a man named todd . it is a funny movie that has a lot of nerve on screen . it is not just the right ingredients and a movie . it is a great film , and it is a +this movie is just plain awful and stupid .i cant get the movie . i cant believe people have ever spent money and money on the [UNK] . i swear i was so embarrassed to say that i had a few lines that are ```
``` +391/391 - 17s - 42ms/step - loss: 3.4260 Epoch 21/25 -391/391 - 134s - loss: 3.4011 - dense_2_loss: 3.4011 generated text: -this movie is not only funny , but i have never seen it before . the other comments i am not kidding or have been [UNK] and the worst movie i have to be . . there is something that is no where else +this movie is one of those movies that i 've ever seen , and you must know that i can say that i was not impressed with this one . i found it to be an interesting one . the story of the first ```
``` +391/391 - 17s - 42ms/step - loss: 3.4014 Epoch 22/25 -391/391 - 134s - loss: 3.3787 - dense_2_loss: 3.3787 generated text: -this movie is a very entertaining , very funny , and very funny , very well written and very nicely directed movie . this was done , very well done , with very good acting and a wonderful script , a very good movie +this movie is about a man 's life and it is a very good film and it takes a look at some sort of movie . this movie is one of the most amazing movie you have ever had to go in , so ```
``` +391/391 - 17s - 42ms/step - loss: 3.3783 Epoch 23/25 -391/391 - 133s - loss: 3.3575 - dense_2_loss: 3.3575 generated text: -this movie is the kind of movie you will not be disappointed . it 's like an [UNK] [UNK] , who is a movie . it 's a great story and the characters are great , the actors are good , their [UNK] , +this movie is a great , good thing about this movie , even the worst i 've ever seen ! it doesn 't mean anything terribly , the acting and the directing is terrible . the script is bad , the plot and the ```
``` +391/391 - 17s - 42ms/step - loss: 3.3564 Epoch 24/25 -391/391 - 134s - loss: 3.3372 - dense_2_loss: 3.3372 generated text: -this movie is a classic 80s horror movie . this has a great premise and the characters is a bit too typical [UNK] and [UNK] " with the [UNK] " . it 's all that makes sense . the characters were shallow and unrealistic +this movie is one of the best movies ever . [UNK] [UNK] ' is about the main character and a nobleman named fallon ; is stranded on an eccentric , falls in love when her island escaped . when , meanwhile , the escaped ```
``` +391/391 - 17s - 42ms/step - loss: 3.3362 Epoch 25/25 -391/391 - 134s - loss: 3.3182 - dense_2_loss: 3.3182 generated text: -this movie is not the worst movie i have ever seen . it 's a movie where i 've never seen it before and i 've seen it again and again , again , i can 't believe it was made in a theatre +this movie is very good . the acting , especially the whole movie itself - a total of the worst . this movie had a lot to recommend it to anyone . it is not funny . the story is so lame ! the ```
- - - - -
``` - +391/391 - 17s - 42ms/step - loss: 3.3170 + + ``` -
+
\ No newline at end of file
diff --git a/examples/generative/text_generation_with_miniature_gpt.py b/examples/generative/text_generation_with_miniature_gpt.py
index b36b527725..7287731dd8 100644
--- a/examples/generative/text_generation_with_miniature_gpt.py
+++ b/examples/generative/text_generation_with_miniature_gpt.py
@@ -30,14 +30,30 @@
 """
 ## Setup
 """
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-from tensorflow.keras.layers import TextVectorization
+# We set the backend to TensorFlow. The code works with
+# both `tensorflow` and `torch`. It does not work with JAX
+# due to the behavior of `jax.numpy.tile` in a jit scope
+# (used in `causal_attention_mask()`): `tile` in JAX does
+# not support a dynamic `reps` argument.
+# You can make the code work in JAX by wrapping the
+# body of the `causal_attention_mask` function in a
+# context manager that forces compile-time evaluation:
+# `with jax.ensure_compile_time_eval():`.
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+from keras import ops
+from keras.layers import TextVectorization
 import numpy as np
 import os
 import string
 import random
+import tensorflow
+import tensorflow.data as tf_data
+import tensorflow.strings as tf_strings
 
 
 """
@@ -51,15 +67,15 @@ def causal_attention_mask(batch_size, n_dest, n_src, dtype):
     This prevents flow of information from future tokens to current token.
     1's in the lower triangle, counting from the lower right corner.
     """
-    i = tf.range(n_dest)[:, None]
-    j = tf.range(n_src)
+    i = ops.arange(n_dest)[:, None]
+    j = ops.arange(n_src)
     m = i >= j - n_src + n_dest
-    mask = tf.cast(m, dtype)
-    mask = tf.reshape(mask, [1, n_dest, n_src])
-    mult = tf.concat(
-        [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
+    mask = ops.cast(m, dtype)
+    mask = ops.reshape(mask, [1, n_dest, n_src])
+    mult = ops.concatenate(
+        [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
     )
-    return tf.tile(mask, mult)
+    return ops.tile(mask, mult)
 
 
 class TransformerBlock(layers.Layer):
@@ -78,10 +94,10 @@ def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
         self.dropout2 = layers.Dropout(rate)
 
     def call(self, inputs):
-        input_shape = tf.shape(inputs)
+        input_shape = ops.shape(inputs)
         batch_size = input_shape[0]
         seq_len = input_shape[1]
-        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
+        causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
         attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
         attention_output = self.dropout1(attention_output)
         out1 = self.layernorm1(inputs + attention_output)
@@ -105,8 +121,8 @@ def __init__(self, maxlen, vocab_size, embed_dim):
         self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
 
     def call(self, x):
-        maxlen = tf.shape(x)[-1]
-        positions = tf.range(start=0, limit=maxlen, delta=1)
+        maxlen = ops.shape(x)[-1]
+        positions = ops.arange(0, maxlen, 1)
         positions = self.pos_emb(positions)
         x = self.token_emb(x)
         return x + positions
@@ -123,14 +139,14 @@ def call(self, x):
 
 
 def create_model():
-    inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)
+    inputs = layers.Input(shape=(maxlen,), dtype="int32")
     embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
     x = embedding_layer(inputs)
    transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
     x = transformer_block(x)
     outputs = layers.Dense(vocab_size)(x)
     model = keras.Model(inputs=inputs, outputs=[outputs, x])
-    
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True) model.compile( "adam", loss=[loss_fn, None], @@ -171,16 +187,16 @@ def create_model(): # Create a dataset from text files random.shuffle(filenames) -text_ds = tf.data.TextLineDataset(filenames) +text_ds = tf_data.TextLineDataset(filenames) text_ds = text_ds.shuffle(buffer_size=256) text_ds = text_ds.batch(batch_size) def custom_standardization(input_string): """Remove html line-break tags and handle punctuation""" - lowercased = tf.strings.lower(input_string) - stripped_html = tf.strings.regex_replace(lowercased, "
", " ") - return tf.strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1") + lowercased = tf_strings.lower(input_string) + stripped_html = tf_strings.regex_replace(lowercased, "
", " ") + return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1") # Create a vectorization layer and adapt it to the text @@ -200,15 +216,15 @@ def prepare_lm_inputs_labels(text): word at position (i+1). The model will use all words up till position (i) to predict the next word. """ - text = tf.expand_dims(text, -1) + text = tensorflow.expand_dims(text, -1) tokenized_sentences = vectorize_layer(text) x = tokenized_sentences[:, :-1] y = tokenized_sentences[:, 1:] return x, y -text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf.data.AUTOTUNE) -text_ds = text_ds.prefetch(tf.data.AUTOTUNE) +text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE) +text_ds = text_ds.prefetch(tf_data.AUTOTUNE) """ @@ -240,9 +256,9 @@ def __init__( self.k = top_k def sample_from(self, logits): - logits, indices = tf.math.top_k(logits, k=self.k, sorted=True) + logits, indices = ops.top_k(logits, k=self.k, sorted=True) indices = np.asarray(indices).astype("int32") - preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0] + preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0] preds = np.asarray(preds).astype("float32") return np.random.choice(indices, p=preds) @@ -266,7 +282,7 @@ def on_epoch_end(self, epoch, logs=None): else: x = start_tokens x = np.array([x]) - y, _ = self.model.predict(x) + y, _ = self.model.predict(x, verbose=0) sample_token = self.sample_from(y[0][sample_index]) tokens_generated.append(sample_token) start_tokens.append(sample_token)