Skip to content

Commit

Permalink
Enforce using Keras 2
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 582676829
  • Loading branch information
TensorFlow Hub Authors authored and copybara-github committed Nov 15, 2023
1 parent ed4b52f commit 543ba05
Show file tree
Hide file tree
Showing 6 changed files with 159 additions and 103 deletions.
16 changes: 12 additions & 4 deletions tensorflow_hub/feature_column_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@
import tensorflow as tf
import tensorflow_hub as hub

# pylint: disable=g-import-not-at-top
# Use Keras 2.
version_fn = getattr(tf.keras, "version", None)
if version_fn and version_fn().startswith("3."):
import tf_keras as keras
else:
keras = tf.keras

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.ops.lookup_ops import HashTable
Expand Down Expand Up @@ -130,7 +138,7 @@ def testDenseFeatures(self):
with tf.Graph().as_default():
# We want to test with dense_features_v2.DenseFeatures. This symbol was
# added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a.
feature_layer = tf.compat.v2.keras.layers.DenseFeatures(feature_columns)
feature_layer = keras.layers.DenseFeatures(feature_columns)
feature_layer_out = feature_layer(features)
with tf.compat.v1.train.MonitoredSession() as sess:
output = sess.run(feature_layer_out)
Expand All @@ -150,7 +158,7 @@ def testDenseFeatures_shareAcrossApplication(self):
with tf.Graph().as_default():
# We want to test with dense_features_v2.DenseFeatures. This symbol was
# added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a.
feature_layer = tf.compat.v2.keras.layers.DenseFeatures(feature_columns)
feature_layer = keras.layers.DenseFeatures(feature_columns)
feature_layer_out_1 = feature_layer(features)
feature_layer_out_2 = feature_layer(features)

Expand Down Expand Up @@ -311,7 +319,7 @@ def testDenseFeatures(self):
with tf.Graph().as_default():
# We want to test with dense_features_v2.DenseFeatures. This symbol was
# added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a.
feature_layer = tf.compat.v2.keras.layers.DenseFeatures(feature_columns)
feature_layer = keras.layers.DenseFeatures(feature_columns)
feature_layer_out = feature_layer(features)
with tf.compat.v1.train.MonitoredSession() as sess:
output = sess.run(feature_layer_out)
Expand All @@ -333,7 +341,7 @@ def testDenseFeatures_shareAcrossApplication(self):
with tf.Graph().as_default():
# We want to test with dense_features_v2.DenseFeatures. This symbol was
# added in https://github.com/tensorflow/tensorflow/commit/64586f18724f737393071125a91b19adf013cf8a.
feature_layer = tf.compat.v2.keras.layers.DenseFeatures(feature_columns)
feature_layer = keras.layers.DenseFeatures(feature_columns)
feature_layer_out_1 = feature_layer(features)
feature_layer_out_2 = feature_layer(features)

Expand Down
44 changes: 26 additions & 18 deletions tensorflow_hub/feature_column_v2_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,13 @@
import tensorflow.compat.v2 as tf
import tensorflow_hub as hub

# pylint: disable=g-import-not-at-top
# Use Keras 2.
version_fn = getattr(tf.keras, "version", None)
if version_fn and version_fn().startswith("3."):
import tf_keras as keras
else:
keras = tf.keras

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.feature_column import feature_column_v2
Expand Down Expand Up @@ -100,7 +107,7 @@ def testDenseFeaturesDirectly(self):
hub.text_embedding_column_v2("text_a", self.model, trainable=False),
hub.text_embedding_column_v2("text_b", self.model, trainable=False),
]
feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
feature_layer = keras.layers.DenseFeatures(feature_columns)
feature_layer_out = feature_layer(features)
self.assertAllEqual(feature_layer_out,
[[1, 2, 3, 4, 1, 2, 3, 4], [5, 5, 5, 5, 0, 0, 0, 0]])
Expand All @@ -114,12 +121,13 @@ def testDenseFeaturesInKeras(self):
hub.text_embedding_column_v2("text", self.model, trainable=True),
]
input_features = dict(
text=tf.keras.layers.Input(name="text", shape=[None], dtype=tf.string))
dense_features = tf.keras.layers.DenseFeatures(feature_columns)
text=keras.layers.Input(name="text", shape=[None], dtype=tf.string)
)
dense_features = keras.layers.DenseFeatures(feature_columns)
x = dense_features(input_features)
x = tf.keras.layers.Dense(16, activation="relu")(x)
logits = tf.keras.layers.Dense(1, activation="linear")(x)
model = tf.keras.Model(inputs=input_features, outputs=logits)
x = keras.layers.Dense(16, activation="relu")(x)
logits = keras.layers.Dense(1, activation="linear")(x)
model = keras.Model(inputs=input_features, outputs=logits)
model.compile(
optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])
model.fit(x=features, y=label, epochs=10)
Expand All @@ -135,13 +143,13 @@ def testLoadingDifferentFeatureColumnsFails(self):
]
# Build the first model.
input_features = dict(
text_1=tf.keras.layers.Input(
name="text_1", shape=[None], dtype=tf.string))
dense_features = tf.keras.layers.DenseFeatures(feature_columns)
text_1=keras.layers.Input(name="text_1", shape=[None], dtype=tf.string)
)
dense_features = keras.layers.DenseFeatures(feature_columns)
x = dense_features(input_features)
x = tf.keras.layers.Dense(16, activation="relu")(x)
logits = tf.keras.layers.Dense(1, activation="linear")(x)
model_1 = tf.keras.Model(inputs=input_features, outputs=logits)
x = keras.layers.Dense(16, activation="relu")(x)
logits = keras.layers.Dense(1, activation="linear")(x)
model_1 = keras.Model(inputs=input_features, outputs=logits)
model_1.compile(
optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])
model_1.fit(x=features, y=label, epochs=10)
Expand All @@ -155,13 +163,13 @@ def testLoadingDifferentFeatureColumnsFails(self):
hub.text_embedding_column_v2("text_2", self.model, trainable=True),
]
input_features = dict(
text_2=tf.keras.layers.Input(
name="text_2", shape=[None], dtype=tf.string))
dense_features = tf.keras.layers.DenseFeatures(feature_columns)
text_2=keras.layers.Input(name="text_2", shape=[None], dtype=tf.string)
)
dense_features = keras.layers.DenseFeatures(feature_columns)
x = dense_features(input_features)
x = tf.keras.layers.Dense(16, activation="relu")(x)
logits = tf.keras.layers.Dense(1, activation="linear")(x)
model_2 = tf.keras.Model(inputs=input_features, outputs=logits)
x = keras.layers.Dense(16, activation="relu")(x)
logits = keras.layers.Dense(1, activation="linear")(x)
model_2 = keras.Model(inputs=input_features, outputs=logits)
model_2.compile(
optimizer="rmsprop", loss="binary_crossentropy", metrics=["accuracy"])

Expand Down
22 changes: 15 additions & 7 deletions tensorflow_hub/keras_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,15 @@

from tensorflow_hub import module_v2

# pylint: disable=g-direct-tensorflow-import,g-import-not-at-top
# pylint: disable=g-import-not-at-top
# Use Keras 2.
version_fn = getattr(tf.keras, "version", None)
if version_fn and version_fn().startswith("3."):
import tf_keras as keras
else:
keras = tf.keras

# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import smart_cond
from tensorflow.python.util import tf_inspect

Expand All @@ -33,7 +41,7 @@
# pylint: enable=g-direct-tensorflow-import,g-import-not-at-top


class KerasLayer(tf.keras.layers.Layer):
class KerasLayer(keras.layers.Layer):
"""Wraps a SavedModel (or a legacy TF1 Hub format) as a Keras Layer.
This layer wraps a callable object for use as a Keras layer. The callable
Expand All @@ -51,7 +59,7 @@ class KerasLayer(tf.keras.layers.Layer):
or a nest of tensors containing the inputs to the layer. If the callable
accepts a `training` argument, a Python boolean is passed for it. It is True
if this layer is marked trainable *and* called for training, analogous to
tf.keras.layers.BatchNormalization. (By contrast, tf.keras.layers.Dropout
keras.layers.BatchNormalization. (By contrast, keras.layers.Dropout
ignores the trainable state and applies the training argument verbatim.)
If present, the following attributes of callable are understood to have
Expand Down Expand Up @@ -86,7 +94,7 @@ class KerasLayer(tf.keras.layers.Layer):
`tf.estimator.RunConfig`. (This option was experimental from TF1.14 to TF2.1.)
Note: The data types used by a saved model have been fixed at saving time.
Using tf.keras.mixed_precision etc. has no effect on the saved model
Using keras.mixed_precision etc. has no effect on the saved model
that gets loaded by a hub.KerasLayer.
Attributes:
Expand Down Expand Up @@ -227,15 +235,15 @@ def call(self, inputs, training=None):
f = functools.partial(self._callable, *args, **kwargs)
# ...but we may also have to pass a Python boolean for `training`, which
# is the logical "and" of this layer's trainability and what the surrounding
# model is doing (analogous to tf.keras.layers.BatchNormalization in TF2).
# model is doing (analogous to keras.layers.BatchNormalization in TF2).
# For the latter, we have to look in two places: the `training` argument,
# or else Keras' global `learning_phase`, which might actually be a tensor.
if not self._has_training_argument:
result = f()
else:
if self.trainable:
if training is None:
training = tf.keras.backend.learning_phase()
training = keras.backend.learning_phase()
else:
# Behave like BatchNormalization. (Dropout is different, b/181839368.)
training = False
Expand Down Expand Up @@ -383,7 +391,7 @@ def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer.
This relies on the `output_shape` provided during initialization, if any,
else falls back to the default behavior from `tf.keras.layers.Layer`.
else falls back to the default behavior from `keras.layers.Layer`.
Args:
input_shape: Shape tuple (tuple of integers) or list of shape tuples (one
Expand Down
Loading

0 comments on commit 543ba05

Please sign in to comment.