Skip to content

Commit

Permalink
Merge pull request #16 from google-research/marinazh/keras_applicatio…
Browse files Browse the repository at this point in the history
…ns_download_refactor

[Version 1.0] Support for automatically downloading RETVec embedding model from Keras Applications + small fixes in preperation for v1 release
  • Loading branch information
ebursztein committed Jun 30, 2023
2 parents cd9e4a3 + 7a8fb9b commit ce35112
Show file tree
Hide file tree
Showing 28 changed files with 152 additions and 97 deletions.
4 changes: 2 additions & 2 deletions retvec/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,4 +14,4 @@
limitations under the License.
"""

__version__ = "0.1.0"
__version__ = "1.0.0"
2 changes: 1 addition & 1 deletion retvec/tf/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/dataset/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/dataset/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/layers/binarizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
14 changes: 11 additions & 3 deletions retvec/tf/layers/embedding.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,8 @@
import tensorflow as tf
from tensorflow import Tensor, TensorShape

from ..utils import RETVEC_MODEL_URLS, download_retvec_saved_model


@tf.keras.utils.register_keras_serializable(package="retvec")
class RETVecEmbedding(tf.keras.layers.Layer):
Expand All @@ -36,7 +38,8 @@ def __init__(
Args:
model: Path to saved pretrained RETVec model, str or pathlib.Path
object.
object. 'retvec-v1' to use V1 of the pre-trained RETVec word
embedding model.
trainable: Whether to make the pretrained RETVec model trainable
or to freeze all weights.
Expand Down Expand Up @@ -93,11 +96,16 @@ def _load_model(
"""Load pretrained RETVec model.
Args:
path: Path to the saved REW* model.
model: Path to saved pretrained RETVec model. Either a pre-defined
RETVec model name, str or pathlib.Path.
Returns:
The pretrained RETVec model, trainable set to `self.trainable`.
"""
path_str = str(path)
if path_str in RETVEC_MODEL_URLS:
path = download_retvec_saved_model(path_str)

model = tf.keras.models.load_model(path)
model.trainable = self.trainable
model.compile("adam", "mse")
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/layers/integerizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
6 changes: 4 additions & 2 deletions retvec/tf/layers/tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -79,7 +79,9 @@ def __init__(
`sequence_length` words.
model: Path to saved pretrained RETVec model, str or pathlib.Path
object.
object. 'retvec-v1' to use V1 of the pre-trained RETVec word
embedding model, None to use the default RETVec character
encoding.
trainable: Whether to make the pretrained RETVec model trainable
or to freeze all weights.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/models/gau.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/models/layers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/models/outputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/models/positional_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/models/retvec_base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/models/retvec_large.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/optimizers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion retvec/tf/optimizers/warmup_cosine_decay.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
52 changes: 51 additions & 1 deletion retvec/tf/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand All @@ -14,9 +14,27 @@
limitations under the License.
"""

import os
from pathlib import Path
from typing import Optional

import tensorflow as tf

RETVEC_MODEL_URLS = {
"retvec-v1": "https://storage.googleapis.com/tensorflow/keras-applications/retvec-v1"
}

# TODO (marinazh): we should download RETVec model weights instead of SavedModel files
RETVEC_COMPONENTS_HASHES = {
"retvec-v1": {
"fingerprint.pb": "5c3991599c293ba653c55e8cceae8e10815eeedea6aff75a64905cd71587d4c1",
"keras_metadata.pb": "e87e8b660ef66f8a058c4c0aa8bfaa8b683bcd4669c21e4bf71055148f8c6afc",
"saved_model.pb": "337c8e91c92946513d127b256f2872a497545186c4d2c2c09afc7d76b55454b7",
"variables.data-00000-of-00001": "22d4760b452fe8110ef2fa96b3d84186372f5259b8f6c4041a05c3ab58d93d37",
"variables.index": "431d19b7426b939c9834bb7d55d515a4ee7d7a6cda78ef0bf7b8ba03e67e480b",
}
}


def tf_cap_memory():
"""Avoid TF to hog memory before needing it"""
Expand All @@ -38,3 +56,35 @@ def clone_initializer(initializer: tf.keras.initializers.Initializer):
):
return initializer.__class__.from_config(initializer.get_config())
return initializer


def download_retvec_saved_model(
model_name: str = "retvec-v1",
cache_dir: str = "~/.keras/",
model_cache_subdir: str = "retvec-v1",
):
if model_name not in RETVEC_MODEL_URLS:
raise ValueError(f"{model_name} is not a valid RETVec model name.")

model_url = RETVEC_MODEL_URLS[model_name]
model_cache_subdir_variables = f"{model_cache_subdir}/variables"

# download model components
retvec_components = RETVEC_COMPONENTS_HASHES[model_name]
for component_name in retvec_components.keys():
if "variables" in component_name:
origin = f"{model_url}/variables/{component_name}"
cache_subdir = model_cache_subdir_variables
else:
origin = f"{model_url}/{component_name}"
cache_subdir = model_cache_subdir

tf.keras.utils.get_file(
origin=origin,
extract=True,
cache_subdir=cache_subdir,
file_hash=retvec_components[component_name],
)

retvec_model_dir = cache_dir + model_cache_subdir
return Path(retvec_model_dir).expanduser()
2 changes: 0 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,6 @@ def get_version(rel_path):
classifiers=[
"Development Status :: 3 - Alpha",
"Environment :: Console",
"Framework :: TensorFlow",
"Framework :: Torch",
"License :: OSI Approved :: Apache Software License",
"Intended Audience :: Science/Research",
"Programming Language :: Python :: 3",
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
2 changes: 1 addition & 1 deletion tests/tf/layers/test_binarizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""
Copyright 2021 Google LLC
Copyright 2023 Google LLC
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit ce35112

Please sign in to comment.