diff --git a/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb b/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb
index 2e4e4d2d89..a93db1a026 100644
--- a/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb
+++ b/examples/generative/ipynb/text_generation_with_miniature_gpt.ipynb
@@ -58,14 +58,30 @@
},
"outputs": [],
"source": [
- "import tensorflow as tf\n",
- "from tensorflow import keras\n",
- "from tensorflow.keras import layers\n",
- "from tensorflow.keras.layers import TextVectorization\n",
+    "# We set the backend to TensorFlow. The code works with\n",
+    "# both `tensorflow` and `torch`. It does not work with JAX\n",
+    "# due to the behavior of `jax.numpy.tile` in a jit scope\n",
+    "# (used in `causal_attention_mask()`): in JAX, `tile` does\n",
+    "# not support a dynamic `reps` argument.\n",
+    "# You can make the code work in JAX by wrapping the body\n",
+    "# of the `causal_attention_mask` function in a\n",
+    "# `with jax.ensure_compile_time_eval():` block to keep it\n",
+    "# out of jit compilation.\n",
+ "import os\n",
+ "\n",
+ "os.environ[\"KERAS_BACKEND\"] = \"tensorflow\"\n",
+ "\n",
+ "import keras\n",
+ "from keras import layers\n",
+ "from keras import ops\n",
+ "from keras.layers import TextVectorization\n",
"import numpy as np\n",
"import os\n",
"import string\n",
"import random\n",
+ "import tensorflow\n",
+ "import tensorflow.data as tf_data\n",
+ "import tensorflow.strings as tf_strings\n",
""
]
},
@@ -93,15 +109,15 @@
" This prevents flow of information from future tokens to current token.\n",
" 1's in the lower triangle, counting from the lower right corner.\n",
" \"\"\"\n",
- " i = tf.range(n_dest)[:, None]\n",
- " j = tf.range(n_src)\n",
+ " i = ops.arange(n_dest)[:, None]\n",
+ " j = ops.arange(n_src)\n",
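+    "    # True where destination position i can attend to source position j\n",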
" m = i >= j - n_src + n_dest\n",
- " mask = tf.cast(m, dtype)\n",
- " mask = tf.reshape(mask, [1, n_dest, n_src])\n",
- " mult = tf.concat(\n",
- " [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0\n",
+ " mask = ops.cast(m, dtype)\n",
+ " mask = ops.reshape(mask, [1, n_dest, n_src])\n",
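+    "    # Tile the [1, n_dest, n_src] mask across the batch dimension\n",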
+ " mult = ops.concatenate(\n",
+ " [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0\n",
" )\n",
- " return tf.tile(mask, mult)\n",
+ " return ops.tile(mask, mult)\n",
"\n",
"\n",
"class TransformerBlock(layers.Layer):\n",
@@ -109,7 +125,10 @@
" super().__init__()\n",
" self.att = layers.MultiHeadAttention(num_heads, embed_dim)\n",
" self.ffn = keras.Sequential(\n",
- " [layers.Dense(ff_dim, activation=\"relu\"), layers.Dense(embed_dim),]\n",
+ " [\n",
+ " layers.Dense(ff_dim, activation=\"relu\"),\n",
+ " layers.Dense(embed_dim),\n",
+ " ]\n",
" )\n",
" self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)\n",
" self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)\n",
@@ -117,10 +136,10 @@
" self.dropout2 = layers.Dropout(rate)\n",
"\n",
" def call(self, inputs):\n",
- " input_shape = tf.shape(inputs)\n",
+ " input_shape = ops.shape(inputs)\n",
" batch_size = input_shape[0]\n",
" seq_len = input_shape[1]\n",
- " causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)\n",
+ " causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, \"bool\")\n",
" attention_output = self.att(inputs, inputs, attention_mask=causal_mask)\n",
" attention_output = self.dropout1(attention_output)\n",
" out1 = self.layernorm1(inputs + attention_output)\n",
@@ -158,8 +177,8 @@
" self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)\n",
"\n",
" def call(self, x):\n",
- " maxlen = tf.shape(x)[-1]\n",
- " positions = tf.range(start=0, limit=maxlen, delta=1)\n",
+ " maxlen = ops.shape(x)[-1]\n",
+ " positions = ops.arange(0, maxlen, 1)\n",
" positions = self.pos_emb(positions)\n",
" x = self.token_emb(x)\n",
" return x + positions\n",
@@ -191,16 +210,17 @@
"\n",
"\n",
"def create_model():\n",
- " inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)\n",
+ " inputs = layers.Input(shape=(maxlen,), dtype=\"int32\")\n",
" embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)\n",
" x = embedding_layer(inputs)\n",
" transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)\n",
" x = transformer_block(x)\n",
" outputs = layers.Dense(vocab_size)(x)\n",
" model = keras.Model(inputs=inputs, outputs=[outputs, x])\n",
- " loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
+ " loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)\n",
" model.compile(\n",
- " \"adam\", loss=[loss_fn, None],\n",
+ " \"adam\",\n",
+ " loss=[loss_fn, None],\n",
" ) # No loss and optimization based on word embeddings from transformer block\n",
" return model\n",
""
@@ -259,16 +279,16 @@
"\n",
"# Create a dataset from text files\n",
"random.shuffle(filenames)\n",
- "text_ds = tf.data.TextLineDataset(filenames)\n",
+ "text_ds = tf_data.TextLineDataset(filenames)\n",
"text_ds = text_ds.shuffle(buffer_size=256)\n",
"text_ds = text_ds.batch(batch_size)\n",
"\n",
"\n",
"def custom_standardization(input_string):\n",
- " \"\"\" Remove html line-break tags and handle punctuation \"\"\"\n",
- " lowercased = tf.strings.lower(input_string)\n",
-    "    stripped_html = tf.strings.regex_replace(lowercased, \"<br />\", \" \")\n",
- " return tf.strings.regex_replace(stripped_html, f\"([{string.punctuation}])\", r\" \\1\")\n",
+ " \"\"\"Remove html line-break tags and handle punctuation\"\"\"\n",
+ " lowercased = tf_strings.lower(input_string)\n",
+    "    stripped_html = tf_strings.regex_replace(lowercased, \"<br />\", \" \")\n",
+ " return tf_strings.regex_replace(stripped_html, f\"([{string.punctuation}])\", r\" \\1\")\n",
"\n",
"\n",
"# Create a vectorization layer and adapt it to the text\n",
@@ -288,15 +308,15 @@
" word at position (i+1). The model will use all words up till position (i)\n",
" to predict the next word.\n",
" \"\"\"\n",
- " text = tf.expand_dims(text, -1)\n",
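+    "    # Runs inside a tf.data pipeline, so keep this a TensorFlow op, not keras.ops\n",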
+ " text = tensorflow.expand_dims(text, -1)\n",
" tokenized_sentences = vectorize_layer(text)\n",
" x = tokenized_sentences[:, :-1]\n",
" y = tokenized_sentences[:, 1:]\n",
" return x, y\n",
"\n",
"\n",
- "text_ds = text_ds.map(prepare_lm_inputs_labels)\n",
- "text_ds = text_ds.prefetch(tf.data.AUTOTUNE)\n",
+ "text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)\n",
+ "text_ds = text_ds.prefetch(tf_data.AUTOTUNE)\n",
""
]
},
@@ -342,9 +362,9 @@
" self.k = top_k\n",
"\n",
" def sample_from(self, logits):\n",
- " logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)\n",
+ " logits, indices = ops.top_k(logits, k=self.k, sorted=True)\n",
" indices = np.asarray(indices).astype(\"int32\")\n",
- " preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]\n",
+ " preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]\n",
" preds = np.asarray(preds).astype(\"float32\")\n",
" return np.random.choice(indices, p=preds)\n",
"\n",
@@ -368,7 +388,7 @@
" else:\n",
" x = start_tokens\n",
" x = np.array([x])\n",
- " y, _ = self.model.predict(x)\n",
+ " y, _ = self.model.predict(x, verbose=0)\n",
" sample_token = self.sample_from(y[0][sample_index])\n",
" tokens_generated.append(sample_token)\n",
" start_tokens.append(sample_token)\n",
@@ -445,4 +465,4 @@
},
"nbformat": 4,
"nbformat_minor": 0
-}
+}
\ No newline at end of file
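For reference, the JAX workaround mentioned in the new header comment would look roughly like the sketch below. This is an untested illustration, not part of the patch; it assumes `KERAS_BACKEND=jax` and relies on the `jax.ensure_compile_time_eval()` context manager, which evaluates the wrapped ops eagerly at trace time so that `tile` receives a concrete `reps` value:

```python
# Sketch only: a JAX-compatible causal_attention_mask(), per the header comment.
import jax
from keras import ops


def causal_attention_mask(batch_size, n_dest, n_src, dtype):
    # Evaluate eagerly at trace time so `ops.tile` sees concrete shapes.
    with jax.ensure_compile_time_eval():
        i = ops.arange(n_dest)[:, None]
        j = ops.arange(n_src)
        mask = ops.cast(i >= j - n_src + n_dest, dtype)
        mask = ops.reshape(mask, [1, n_dest, n_src])
        mult = ops.concatenate(
            [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
        )
        return ops.tile(mask, mult)
```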
diff --git a/examples/generative/md/text_generation_with_miniature_gpt.md b/examples/generative/md/text_generation_with_miniature_gpt.md
index 9425664cfc..78688a9c0b 100644
--- a/examples/generative/md/text_generation_with_miniature_gpt.md
+++ b/examples/generative/md/text_generation_with_miniature_gpt.md
@@ -36,14 +36,30 @@ with TensorFlow 2.3 or higher.
```python
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-from tensorflow.keras.layers import TextVectorization
+# We set the backend to TensorFlow. The code works with
+# both `tensorflow` and `torch`. It does not work with JAX
+# due to the behavior of `jax.numpy.tile` in a jit scope
+# (used in `causal_attention_mask()`): in JAX, `tile` does
+# not support a dynamic `reps` argument.
+# You can make the code work in JAX by wrapping the body
+# of the `causal_attention_mask` function in a
+# `with jax.ensure_compile_time_eval():` block to keep it
+# out of jit compilation.
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+from keras import ops
+from keras.layers import TextVectorization
import numpy as np
-import os
import string
import random
+import tensorflow
+import tensorflow.data as tf_data
+import tensorflow.strings as tf_strings
```
@@ -59,15 +75,15 @@ def causal_attention_mask(batch_size, n_dest, n_src, dtype):
This prevents flow of information from future tokens to current token.
1's in the lower triangle, counting from the lower right corner.
"""
- i = tf.range(n_dest)[:, None]
- j = tf.range(n_src)
+ i = ops.arange(n_dest)[:, None]
+ j = ops.arange(n_src)
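+    # True where destination position i can attend to source position j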
m = i >= j - n_src + n_dest
- mask = tf.cast(m, dtype)
- mask = tf.reshape(mask, [1, n_dest, n_src])
- mult = tf.concat(
- [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
+ mask = ops.cast(m, dtype)
+ mask = ops.reshape(mask, [1, n_dest, n_src])
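+    # Tile the [1, n_dest, n_src] mask across the batch dimension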
+ mult = ops.concatenate(
+ [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
)
- return tf.tile(mask, mult)
+ return ops.tile(mask, mult)
class TransformerBlock(layers.Layer):
@@ -75,7 +91,10 @@ class TransformerBlock(layers.Layer):
super().__init__()
self.att = layers.MultiHeadAttention(num_heads, embed_dim)
self.ffn = keras.Sequential(
- [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
+ [
+ layers.Dense(ff_dim, activation="relu"),
+ layers.Dense(embed_dim),
+ ]
)
self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
@@ -83,10 +102,10 @@ class TransformerBlock(layers.Layer):
self.dropout2 = layers.Dropout(rate)
def call(self, inputs):
- input_shape = tf.shape(inputs)
+ input_shape = ops.shape(inputs)
batch_size = input_shape[0]
seq_len = input_shape[1]
- causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
+ causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
attention_output = self.dropout1(attention_output)
out1 = self.layernorm1(inputs + attention_output)
@@ -112,8 +131,8 @@ class TokenAndPositionEmbedding(layers.Layer):
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
def call(self, x):
- maxlen = tf.shape(x)[-1]
- positions = tf.range(start=0, limit=maxlen, delta=1)
+ maxlen = ops.shape(x)[-1]
+ positions = ops.arange(0, maxlen, 1)
positions = self.pos_emb(positions)
x = self.token_emb(x)
return x + positions
@@ -133,16 +152,17 @@ feed_forward_dim = 256 # Hidden layer size in feed forward network inside trans
def create_model():
- inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)
+ inputs = layers.Input(shape=(maxlen,), dtype="int32")
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
x = transformer_block(x)
outputs = layers.Dense(vocab_size)(x)
model = keras.Model(inputs=inputs, outputs=[outputs, x])
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
- "adam", loss=[loss_fn, None],
+ "adam",
+ loss=[loss_fn, None],
-    ) # No loss and optimization based on word embeddings from transformer block
+    )  # No loss is applied to the second output (the transformer block's embeddings)
return model
@@ -182,16 +202,16 @@ print(f"{len(filenames)} files")
# Create a dataset from text files
random.shuffle(filenames)
-text_ds = tf.data.TextLineDataset(filenames)
+text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)
def custom_standardization(input_string):
- """ Remove html line-break tags and handle punctuation """
- lowercased = tf.strings.lower(input_string)
-    stripped_html = tf.strings.regex_replace(lowercased, "<br />", " ")
- return tf.strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")
+ """Remove html line-break tags and handle punctuation"""
+ lowercased = tf_strings.lower(input_string)
+    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
+ return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")
# Create a vectorization layer and adapt it to the text
@@ -211,22 +231,22 @@ def prepare_lm_inputs_labels(text):
word at position (i+1). The model will use all words up till position (i)
to predict the next word.
"""
- text = tf.expand_dims(text, -1)
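+    # Runs inside a tf.data pipeline, so keep this a TensorFlow op, not keras.ops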
+ text = tensorflow.expand_dims(text, -1)
tokenized_sentences = vectorize_layer(text)
x = tokenized_sentences[:, :-1]
y = tokenized_sentences[:, 1:]
return x, y
-text_ds = text_ds.map(prepare_lm_inputs_labels)
-text_ds = text_ds.prefetch(tf.data.AUTOTUNE)
+text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
+text_ds = text_ds.prefetch(tf_data.AUTOTUNE)
```
```
% Total % Received % Xferd Average Speed Time Time Time Current
Dload Upload Total Spent Left Speed
-100 80.2M 100 80.2M 0 0 24.2M 0 0:00:03 0:00:03 --:--:-- 24.2M
+100 80.2M 100 80.2M 0 0 7926k 0 0:00:10 0:00:10 --:--:-- 7661k
50000 files
@@ -262,9 +282,9 @@ class TextGenerator(keras.callbacks.Callback):
self.k = top_k
def sample_from(self, logits):
- logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)
+ logits, indices = ops.top_k(logits, k=self.k, sorted=True)
indices = np.asarray(indices).astype("int32")
- preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]
+ preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
preds = np.asarray(preds).astype("float32")
return np.random.choice(indices, p=preds)
@@ -288,7 +308,7 @@ class TextGenerator(keras.callbacks.Callback):
else:
x = start_tokens
x = np.array([x])
- y, _ = self.model.predict(x)
+ y, _ = self.model.predict(x, verbose=0)
sample_token = self.sample_from(y[0][sample_index])
tokens_generated.append(sample_token)
start_tokens.append(sample_token)
@@ -326,236 +346,238 @@ model.fit(text_ds, verbose=2, epochs=25, callbacks=[text_gen_callback])
```
Epoch 1/25
-391/391 - 135s - loss: 5.5949 - dense_2_loss: 5.5949
+
+WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
+I0000 00:00:1699499022.078758 633491 device_compiler.h:187] Compiled cluster using XLA! This line is logged at most once for the lifetime of the process.
+/home/mattdangerw/miniconda3/envs/keras-tensorflow/lib/python3.10/contextlib.py:153: UserWarning: Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches. You may need to use the `.repeat()` function when building your dataset.
+ self.gen.throw(typ, value, traceback)
+
generated text:
-this movie is a great movie . the film is so many other comments . the plot and some people were [UNK] to do . i think the story is about that it is not a good movie . there are very good actors
+this movie is a good example of the [UNK] " movies , and the movie was pretty well written , i had to say that the movie made me of the [UNK] " and was very well done . i 've seen a few
```
```
+391/391 - 33s - 84ms/step - loss: 5.4696
Epoch 2/25
-391/391 - 135s - loss: 4.7108 - dense_2_loss: 4.7108
generated text:
-this movie is one of the worst movies i have ever seen . i have no doubt the better movies of this one 's worst movies i have ever seen . i don 't know what the hell , and i 'm not going
+this movie is so far the worst movies i have ever seen . it is that it just a bit of a movie but i really don 't think it is a very bad movie . it is a lot and the characters in
```
```
+391/391 - 16s - 42ms/step - loss: 4.7016
Epoch 3/25
-391/391 - 135s - loss: 4.4620 - dense_2_loss: 4.4620
generated text:
-this movie is a very good movie , i think i am not a kid . the story is a great movie . the director who is a great director who likes the director 's film . this was not funny and the director
+this movie is a classic and has a good cast in a good story . the movie itself is good at best . the acting is superb , but the story is a little bit slow , the music hall , and music is
```
```
+391/391 - 16s - 42ms/step - loss: 4.4533
Epoch 4/25
-391/391 - 136s - loss: 4.3047 - dense_2_loss: 4.3047
generated text:
-this movie is a very good story and very well . this movie is one of the worst movies i have ever seen , and there are some good actors and actresses in the movie , it is not the worst . the script
+this movie is a good , and is not the greatest movie ever since , the director has a lot of [UNK] , but it 's just a bit of the original and the plot has some decent acting , it has a bit
```
```
+391/391 - 16s - 42ms/step - loss: 4.2985
Epoch 5/25
-391/391 - 135s - loss: 4.1840 - dense_2_loss: 4.1840
generated text:
-this movie is a very good movie . it is the best thing about it 's a very good movie . it 's not funny , very , it 's so bad that it 's so funny , it 's like most romantic movie
+this movie is really bad , the acting in this movie is bad and bad . it 's not bad it . it 's a bad story about a bad film but i can 't imagine if it 's a bad ending . the
```
```
+391/391 - 17s - 42ms/step - loss: 4.1787
Epoch 6/25
-391/391 - 135s - loss: 4.0834 - dense_2_loss: 4.0834
generated text:
-this movie is the worst . the acting is awful . i have to admit that you 're just watching this film as i have to say that it is a [UNK] with [UNK] [UNK] " in the last ten years . i think
+this movie is so bad , the bad acting , everything is awful , the script is bad , and the only one that i just saw in the original [UNK] . i was hoping it could make up the sequel . it wasn
```
```
+391/391 - 17s - 42ms/step - loss: 4.0807
Epoch 7/25
-391/391 - 135s - loss: 3.9987 - dense_2_loss: 3.9987
generated text:
-this movie is really about the acting is good and the script . i don 't think this is just a waste of movie . it was so terrible that it wasn 't funny , but that 's what it was made in movies
+this movie is one of the best kung fu movies i 've ever seen , i have seen in my life that 's not for anyone who has to think of it , or maybe even though , i can 't find it funny
```
```
+391/391 - 16s - 42ms/step - loss: 3.9978
Epoch 8/25
-391/391 - 134s - loss: 3.9242 - dense_2_loss: 3.9242
generated text:
-this movie is so bad . the story itself is about a family guy named jack , who is told by a father , who is trying to get to help him to commit . he has the same problem and the [UNK] .
+this movie is just plain boring . . . . . . . . . . . . . . . . . [UNK] , the movie [UNK] . . . [UNK] . . . . . . [UNK] is a bad , it
```
```
+391/391 - 17s - 42ms/step - loss: 3.9236
Epoch 9/25
-391/391 - 135s - loss: 3.8579 - dense_2_loss: 3.8579
generated text:
-this movie is not bad , it does not deserve one . i can say that i was able to sit at , relax [UNK] . i was wrong , and i think i was able to buy the dvd , i would say
+this movie is the only good movie i think i 've never seen it again . but it 's the only thing i feel about it . the story was about the fact that it was a very good movie . the movie has
```
```
+391/391 - 17s - 42ms/step - loss: 3.8586
Epoch 10/25
-391/391 - 134s - loss: 3.7989 - dense_2_loss: 3.7989
generated text:
-this movie is very funny ! its very funny . a touching movie about three women who don 't know who is not to go on with a movie that has a lot of fun to watch . it is funny . the main
+this movie is very well written and directed . it contains some of the best [UNK] in the genre . the story is about a group of actors , especially jamie harris and danny glover who are the only good guys that is really
```
```
+391/391 - 17s - 42ms/step - loss: 3.8002
Epoch 11/25
-391/391 - 134s - loss: 3.7459 - dense_2_loss: 3.7459
generated text:
-this movie is not the best movie i 've seen in a long time . this movie was just about a guy who gets killed for one . . i saw this movie at a time when i first saw it in the movie
+this movie is so terrible . i think that the movie isn 't as bad as you should go and watch it again . there were so many clichés that it 's a very bad movie in itself . there is no story line
```
```
+391/391 - 17s - 42ms/step - loss: 3.7478
Epoch 12/25
-391/391 - 134s - loss: 3.6974 - dense_2_loss: 3.6974
generated text:
-this movie is a good example of how many films have seen and many films , that are often overlooked , in the seventies , in fact it is more enjoyable than the average viewer has some interesting parallels . this movie is based
+this movie is a total waste of money and money . i am surprised to find it very funny , very enjoyable . the plot is totally unbelievable , the acting is horrible . the story is awful , it 's not scary at
```
```
+391/391 - 17s - 42ms/step - loss: 3.6993
Epoch 13/25
-391/391 - 134s - loss: 3.6534 - dense_2_loss: 3.6534
generated text:
-this movie is so bad ! i think this is one . i really didn 't think anybody who gets the impression that the people who is trying to find themselves to be funny . . there 's the humor is no punchline ?
+this movie is so bad and not very good as it goes . it 's a nice movie and it 's so bad that it takes you back on your tv . i don 't really know how bad this movie is . you
```
```
+391/391 - 17s - 42ms/step - loss: 3.6546
Epoch 14/25
-391/391 - 134s - loss: 3.6123 - dense_2_loss: 3.6123
generated text:
-this movie is really bad . the actors are good ,the acting is great . a must see [UNK] the worst in history of all time . the plot is so bad that you can 't even make a bad movie about the bad
+this movie is a great fun story , with lots of action , and romance . if you like the action and the story is really bad . it doesn 't get the idea , but you have your heart of darkness . the
```
```
+391/391 - 17s - 42ms/step - loss: 3.6147
Epoch 15/25
-391/391 - 134s - loss: 3.5745 - dense_2_loss: 3.5745
generated text:
-this movie is one of the worst movies i 've ever had . the acting and direction are terrible . what i 've seen , i 've watched it several times , and i can 't really believe how to make a movie about
+this movie is a little more than a horror film . it 's not really a great deal , i can honestly say , a story about a group of teens that are all over the place . but this is still a fun
```
```
+391/391 - 17s - 42ms/step - loss: 3.5769
Epoch 16/25
-391/391 - 134s - loss: 3.5404 - dense_2_loss: 3.5404
generated text:
-this movie is so bad it is . that it is supposed to be a comedy . the script , which is just as bad as some movies are bad . if you 're looking for it , if you 're in the mood
+this movie is just about a guy who is supposed to be a girl in the [UNK] of a movie that doesn 't make sense . the humor is not to watch it all the way the movie is . you can 't know
```
```
+391/391 - 17s - 42ms/step - loss: 3.5425
Epoch 17/25
-391/391 - 134s - loss: 3.5083 - dense_2_loss: 3.5083
generated text:
-this movie is one of all bad movies i have a fan ever seen . i have seen a good movies , this isn 't the worst . i 've seen in a long time . the story involves twins , a priest and
+this movie is one of the best movies i 've ever seen . i was really surprised when renting it and it wasn 't even better in it , it was not even funny and i really don 't really know what i was
```
```
+391/391 - 17s - 42ms/step - loss: 3.5099
Epoch 18/25
-391/391 - 134s - loss: 3.4789 - dense_2_loss: 3.4789
generated text:
-this movie is a great movie . it 's a shame that it was hard to see that it was . this movie is a good movie . the movie itself is a complete waste of time and time you have a bad rant
+this movie is so bad . i think it 's a bit overrated . i have a lot of bad movies . i have to say that this movie was just bad . i was hoping the [UNK] . the [UNK] is good "
```
```
+391/391 - 17s - 43ms/step - loss: 3.4800
Epoch 19/25
-391/391 - 134s - loss: 3.4513 - dense_2_loss: 3.4513
generated text:
-this movie is not one of the most moving movies i have ever seen . the story is about the plot is just so ridiculous that i could have done it with the actors . the actors are great and the acting is great
+this movie is one of the best kung fu movies i 've ever seen . it was a great movie , and for the music . the graphics are really cool . it 's like a lot more than the action scenes and action
```
```
+391/391 - 17s - 42ms/step - loss: 3.4520
Epoch 20/25
-391/391 - 134s - loss: 3.4251 - dense_2_loss: 3.4251
generated text:
-this movie is about a man named todd . it is a funny movie that has a lot of nerve on screen . it is not just the right ingredients and a movie . it is a great film , and it is a
+this movie is just plain awful and stupid .i cant get the movie . i cant believe people have ever spent money and money on the [UNK] . i swear i was so embarrassed to say that i had a few lines that are
```
```
+391/391 - 17s - 42ms/step - loss: 3.4260
Epoch 21/25
-391/391 - 134s - loss: 3.4011 - dense_2_loss: 3.4011
generated text:
-this movie is not only funny , but i have never seen it before . the other comments i am not kidding or have been [UNK] and the worst movie i have to be . . there is something that is no where else
+this movie is one of those movies that i 've ever seen , and you must know that i can say that i was not impressed with this one . i found it to be an interesting one . the story of the first
```
```
+391/391 - 17s - 42ms/step - loss: 3.4014
Epoch 22/25
-391/391 - 134s - loss: 3.3787 - dense_2_loss: 3.3787
generated text:
-this movie is a very entertaining , very funny , and very funny , very well written and very nicely directed movie . this was done , very well done , with very good acting and a wonderful script , a very good movie
+this movie is about a man 's life and it is a very good film and it takes a look at some sort of movie . this movie is one of the most amazing movie you have ever had to go in , so
```
```
+391/391 - 17s - 42ms/step - loss: 3.3783
Epoch 23/25
-391/391 - 133s - loss: 3.3575 - dense_2_loss: 3.3575
generated text:
-this movie is the kind of movie you will not be disappointed . it 's like an [UNK] [UNK] , who is a movie . it 's a great story and the characters are great , the actors are good , their [UNK] ,
+this movie is a great , good thing about this movie , even the worst i 've ever seen ! it doesn 't mean anything terribly , the acting and the directing is terrible . the script is bad , the plot and the
```
```
+391/391 - 17s - 42ms/step - loss: 3.3564
Epoch 24/25
-391/391 - 134s - loss: 3.3372 - dense_2_loss: 3.3372
generated text:
-this movie is a classic 80s horror movie . this has a great premise and the characters is a bit too typical [UNK] and [UNK] " with the [UNK] " . it 's all that makes sense . the characters were shallow and unrealistic
+this movie is one of the best movies ever . [UNK] [UNK] ' is about the main character and a nobleman named fallon ; is stranded on an eccentric , falls in love when her island escaped . when , meanwhile , the escaped
```
```
+391/391 - 17s - 42ms/step - loss: 3.3362
Epoch 25/25
-391/391 - 134s - loss: 3.3182 - dense_2_loss: 3.3182
generated text:
-this movie is not the worst movie i have ever seen . it 's a movie where i 've never seen it before and i 've seen it again and again , again , i can 't believe it was made in a theatre
+this movie is very good . the acting , especially the whole movie itself - a total of the worst . this movie had a lot to recommend it to anyone . it is not funny . the story is so lame ! the
```
-
-
-
-
-
```
-
+391/391 - 17s - 42ms/step - loss: 3.3170
+
+
```
-
+
\ No newline at end of file
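As a side note on the `sample_from()` port above, the top-k sampling path can be exercised standalone as below. The logits are made up for illustration; only `ops.top_k`, `keras.activations.softmax`, and `np.random.choice` come from the patched code:

```python
# Top-k sampling as in TextGenerator.sample_from(), with made-up logits.
import numpy as np
import keras
from keras import ops

logits = ops.convert_to_tensor([1.0, 3.0, 0.5, 2.0])  # scores over a 4-token vocab
values, indices = ops.top_k(logits, k=2, sorted=True)  # keeps 3.0 and 2.0
indices = np.asarray(indices).astype("int32")  # [1, 3]
preds = keras.activations.softmax(ops.expand_dims(values, 0))[0]
preds = np.asarray(preds).astype("float32")  # ~[0.73, 0.27]
next_token = np.random.choice(indices, p=preds)  # token id 1 w.p. ~0.73
```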
diff --git a/examples/generative/text_generation_with_miniature_gpt.py b/examples/generative/text_generation_with_miniature_gpt.py
index b36b527725..7287731dd8 100644
--- a/examples/generative/text_generation_with_miniature_gpt.py
+++ b/examples/generative/text_generation_with_miniature_gpt.py
@@ -30,14 +30,30 @@
"""
## Setup
"""
-import tensorflow as tf
-from tensorflow import keras
-from tensorflow.keras import layers
-from tensorflow.keras.layers import TextVectorization
+# We set the backend to TensorFlow. The code works with
+# both `tensorflow` and `torch`. It does not work with JAX
+# due to the behavior of `jax.numpy.tile` in a jit scope
+# (used in `causal_attention_mask()`): in JAX, `tile` does
+# not support a dynamic `reps` argument.
+# You can make the code work in JAX by wrapping the body
+# of the `causal_attention_mask` function in a
+# `with jax.ensure_compile_time_eval():` block to keep it
+# out of jit compilation.
+import os
+
+os.environ["KERAS_BACKEND"] = "tensorflow"
+
+import keras
+from keras import layers
+from keras import ops
+from keras.layers import TextVectorization
import numpy as np
-import os
import string
import random
+import tensorflow
+import tensorflow.data as tf_data
+import tensorflow.strings as tf_strings
"""
@@ -51,15 +67,15 @@ def causal_attention_mask(batch_size, n_dest, n_src, dtype):
This prevents flow of information from future tokens to current token.
1's in the lower triangle, counting from the lower right corner.
"""
- i = tf.range(n_dest)[:, None]
- j = tf.range(n_src)
+ i = ops.arange(n_dest)[:, None]
+ j = ops.arange(n_src)
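+    # True where destination position i can attend to source position j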
m = i >= j - n_src + n_dest
- mask = tf.cast(m, dtype)
- mask = tf.reshape(mask, [1, n_dest, n_src])
- mult = tf.concat(
- [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
+ mask = ops.cast(m, dtype)
+ mask = ops.reshape(mask, [1, n_dest, n_src])
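+    # Tile the [1, n_dest, n_src] mask across the batch dimension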
+ mult = ops.concatenate(
+ [ops.expand_dims(batch_size, -1), ops.convert_to_tensor([1, 1])], 0
)
- return tf.tile(mask, mult)
+ return ops.tile(mask, mult)
class TransformerBlock(layers.Layer):
@@ -78,10 +94,10 @@ def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
self.dropout2 = layers.Dropout(rate)
def call(self, inputs):
- input_shape = tf.shape(inputs)
+ input_shape = ops.shape(inputs)
batch_size = input_shape[0]
seq_len = input_shape[1]
- causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, tf.bool)
+ causal_mask = causal_attention_mask(batch_size, seq_len, seq_len, "bool")
attention_output = self.att(inputs, inputs, attention_mask=causal_mask)
attention_output = self.dropout1(attention_output)
out1 = self.layernorm1(inputs + attention_output)
@@ -105,8 +121,8 @@ def __init__(self, maxlen, vocab_size, embed_dim):
self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)
def call(self, x):
- maxlen = tf.shape(x)[-1]
- positions = tf.range(start=0, limit=maxlen, delta=1)
+ maxlen = ops.shape(x)[-1]
+ positions = ops.arange(0, maxlen, 1)
positions = self.pos_emb(positions)
x = self.token_emb(x)
return x + positions
@@ -123,14 +139,14 @@ def call(self, x):
def create_model():
- inputs = layers.Input(shape=(maxlen,), dtype=tf.int32)
+ inputs = layers.Input(shape=(maxlen,), dtype="int32")
embedding_layer = TokenAndPositionEmbedding(maxlen, vocab_size, embed_dim)
x = embedding_layer(inputs)
transformer_block = TransformerBlock(embed_dim, num_heads, feed_forward_dim)
x = transformer_block(x)
outputs = layers.Dense(vocab_size)(x)
model = keras.Model(inputs=inputs, outputs=[outputs, x])
- loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
+ loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
"adam",
loss=[loss_fn, None],
@@ -171,16 +187,16 @@ def create_model():
# Create a dataset from text files
random.shuffle(filenames)
-text_ds = tf.data.TextLineDataset(filenames)
+text_ds = tf_data.TextLineDataset(filenames)
text_ds = text_ds.shuffle(buffer_size=256)
text_ds = text_ds.batch(batch_size)
def custom_standardization(input_string):
"""Remove html line-break tags and handle punctuation"""
- lowercased = tf.strings.lower(input_string)
-    stripped_html = tf.strings.regex_replace(lowercased, "<br />", " ")
- return tf.strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")
+ lowercased = tf_strings.lower(input_string)
+    stripped_html = tf_strings.regex_replace(lowercased, "<br />", " ")
+ return tf_strings.regex_replace(stripped_html, f"([{string.punctuation}])", r" \1")
# Create a vectorization layer and adapt it to the text
@@ -200,15 +216,15 @@ def prepare_lm_inputs_labels(text):
word at position (i+1). The model will use all words up till position (i)
to predict the next word.
"""
- text = tf.expand_dims(text, -1)
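+    # Runs inside a tf.data pipeline, so keep this a TensorFlow op, not keras.ops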
+ text = tensorflow.expand_dims(text, -1)
tokenized_sentences = vectorize_layer(text)
x = tokenized_sentences[:, :-1]
y = tokenized_sentences[:, 1:]
return x, y
-text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf.data.AUTOTUNE)
-text_ds = text_ds.prefetch(tf.data.AUTOTUNE)
+text_ds = text_ds.map(prepare_lm_inputs_labels, num_parallel_calls=tf_data.AUTOTUNE)
+text_ds = text_ds.prefetch(tf_data.AUTOTUNE)
"""
@@ -240,9 +256,9 @@ def __init__(
self.k = top_k
def sample_from(self, logits):
- logits, indices = tf.math.top_k(logits, k=self.k, sorted=True)
+ logits, indices = ops.top_k(logits, k=self.k, sorted=True)
indices = np.asarray(indices).astype("int32")
- preds = keras.activations.softmax(tf.expand_dims(logits, 0))[0]
+ preds = keras.activations.softmax(ops.expand_dims(logits, 0))[0]
preds = np.asarray(preds).astype("float32")
return np.random.choice(indices, p=preds)
@@ -266,7 +282,7 @@ def on_epoch_end(self, epoch, logs=None):
else:
x = start_tokens
x = np.array([x])
- y, _ = self.model.predict(x)
+ y, _ = self.model.predict(x, verbose=0)
sample_token = self.sample_from(y[0][sample_index])
tokens_generated.append(sample_token)
start_tokens.append(sample_token)
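Finally, a quick eager-mode sanity check of the ported mask (TensorFlow backend, with the patched `causal_attention_mask()` in scope; the expected values below are derived by hand from the `i >= j - n_src + n_dest` condition):

```python
from keras import ops

# batch_size=1, square 3x3 mask, integer dtype for readable output.
mask = causal_attention_mask(1, 3, 3, "int32")
print(ops.convert_to_numpy(mask))
# [[[1 0 0]
#   [1 1 0]
#   [1 1 1]]]  -> 1's in the lower triangle, as the docstring says
```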