Skip to content

Commit

Permalink
Merge pull request #24 from google-research/marinazh/notebooks
Browse files Browse the repository at this point in the history
TF Lite Native Support + Colab
  • Loading branch information
ebursztein authored Oct 12, 2023
2 parents b381c8c + ad901d1 commit 2ce44b7
Show file tree
Hide file tree
Showing 17 changed files with 868 additions and 651 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests-tensorflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: [3.7, 3.8, 3.9]
python-version: ['3.8', '3.9', '3.10']

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ Detailed example colabs for RETVec can be found at under [notebooks](notebooks/)
We have the following example colabs:

- Training RETVec-based models using TensorFlow: [train_hello_world_tf.ipynb](notebooks/train_hello_world_tf.ipynb) for GPU/CPU training, and [train_tpu.ipynb](notebooks/train_tpu.ipynb) for a TPU-compatible training example.
- (Coming soon!) Converting RETVec models into TF Lite models to run on-device.
- Converting RETVec models into TF Lite models to run on-device: [tf_lite_retvec.ipynb](notebooks/tf_lite_retvec.ipynb)
- (Coming soon!) Using RETVec JS to deploy RETVec models in the web using TensorFlow.js

## Citing
Expand Down
2 changes: 1 addition & 1 deletion notebooks/demo_models/emotion_model/fingerprint.pb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
�ɫ߀ٰ�*���°�����ە���� �à����N(֜�׸���}2
������͘�ũ��������ە���� ���՗����(����ؚ��b2
48 changes: 24 additions & 24 deletions notebooks/demo_models/emotion_model/keras_metadata.pb

Large diffs are not rendered by default.

Binary file modified notebooks/demo_models/emotion_model/saved_model.pb
Binary file not shown.
Binary file not shown.
Binary file modified notebooks/demo_models/emotion_model/variables/variables.index
Binary file not shown.
409 changes: 409 additions & 0 deletions notebooks/tf_lite_retvec.ipynb

Large diffs are not rendered by default.

245 changes: 108 additions & 137 deletions notebooks/train_retvec_model_tf.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion retvec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@
limitations under the License.
"""

__version__ = "1.0.1"
__version__ = "1.0.2"
130 changes: 101 additions & 29 deletions retvec/tf/layers/binarizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,42 @@
limitations under the License.
"""

import logging
import re
from typing import Any, Dict, List, Union

import tensorflow as tf
from tensorflow import Tensor, TensorShape

try:
from tensorflow_text import utf8_binarize
except ImportError:
utf8_binarize = None

from .integerizer import RETVecIntegerizer


def _reshape_embeddings(
embeddings: tf.Tensor,
batch_size: int,
sequence_length: int,
word_length: int,
encoding_size: int,
) -> tf.Tensor:
if sequence_length > 1:
return tf.reshape(
embeddings,
(
batch_size,
sequence_length,
word_length,
encoding_size,
),
)
else:
return tf.reshape(embeddings, (batch_size, word_length, encoding_size))


@tf.keras.utils.register_keras_serializable(package="retvec")
class RETVecIntToBinary(tf.keras.layers.Layer):
"""Convert Unicode integer code points to their float binary encoding."""
Expand Down Expand Up @@ -80,22 +108,13 @@ def call(self, inputs: Tensor) -> Tensor:
embeddings = tf.cast(embeddings, dtype="float32")

# reshape back to correct shape
if self.sequence_length > 1:
embeddings = tf.reshape(
embeddings,
(
batch_size,
self.sequence_length,
self.word_length,
self.encoding_size,
),
)
else:
embeddings = tf.reshape(
embeddings, (batch_size, self.word_length, self.encoding_size)
)

return embeddings
return _reshape_embeddings(
embeddings,
batch_size=batch_size,
sequence_length=self.sequence_length,
word_length=self.word_length,
encoding_size=self.encoding_size,
)

def _project(self, chars: Tensor, masks: Tensor) -> Tensor:
"""Project chars in subspace"""
Expand Down Expand Up @@ -133,6 +152,7 @@ def __init__(
encoding_size: int = 24,
encoding_type: str = "UTF-8",
replacement_char: int = 65533,
use_native_tf_ops: bool = False,
**kwargs
) -> None:
"""Initialize a RETVec binarizer.
Expand Down Expand Up @@ -163,20 +183,43 @@ def __init__(
replacement_char: The replacement Unicode integer codepoint to be
used in place of invalid substrings in the input.
use_native_tf_ops: A boolean indicating whether to use
`tensorflow_text.utf8_binarize` whenever possible
(limited by its availability and constraints).
**kwargs: Additional keyword args passed to the base Layer class.
"""
super().__init__(**kwargs)
self.word_length = word_length
self.encoding_size = encoding_size
self.encoding_type = encoding_type
self.replacement_char = replacement_char
self.use_native_tf_ops = use_native_tf_ops

# Check if the native `utf8_binarize` op is available for use.
is_utf8_encoding = re.match("^utf-?8$", encoding_type, re.IGNORECASE)
self._native_mode = (
use_native_tf_ops
and is_utf8_encoding
and utf8_binarize is not None
)
if use_native_tf_ops and not self._native_mode:
logging.warning(
"Native support for `RETVecBinarizer` unavailable. "
"Check `tensorflow_text.utf8_binarize` availability"
" and its parameter contraints."
)

# Set to True when 'binarize()' is called in eager mode
self.eager = False
self._integerizer = RETVecIntegerizer(
word_length=self.word_length,
encoding_type=self.encoding_type,
replacement_char=self.replacement_char,
self._integerizer = (
None
if self._native_mode
else RETVecIntegerizer(
word_length=self.word_length,
encoding_type=self.encoding_type,
replacement_char=self.replacement_char,
)
)

def build(
Expand All @@ -186,22 +229,49 @@ def build(

# Initialize int binarizer layer here since we know sequence_length
# only once we known the input_shape
self._int_to_binary = RETVecIntToBinary(
word_length=self.word_length,
sequence_length=self.sequence_length,
encoding_size=self.encoding_size,
self._int_to_binary = (
None
if self._native_mode
else RETVecIntToBinary(
word_length=self.word_length,
sequence_length=self.sequence_length,
encoding_size=self.encoding_size,
)
)

def call(self, inputs: Tensor) -> Tensor:
char_encodings = self._integerizer(inputs)
embeddings = self._int_to_binary(char_encodings)
return embeddings
if self._native_mode:
embeddings = utf8_binarize(
inputs,
word_length=self.word_length,
bits_per_char=self.encoding_size,
replacement_char=self.replacement_char,
)
batch_size = tf.shape(inputs)[0]
embeddings = _reshape_embeddings(
embeddings,
batch_size=batch_size,
sequence_length=self.sequence_length,
word_length=self.word_length,
encoding_size=self.encoding_size,
)
# TODO (marinazh): little vs big-endian order mismatch
return tf.reverse(embeddings, axis=[-1])

else:
assert self._integerizer is not None
char_encodings = self._integerizer(inputs)

assert self._int_to_binary is not None
embeddings = self._int_to_binary(char_encodings)

return embeddings

def binarize(self, inputs: Tensor) -> Tensor:
"""Return binary encodings for a word or a list of words.
Args:
inputs: A single word or list of words to encode.
inputs: Tensor of a single word or list of words to encode.
Returns:
RETVec binary encodings for the input words(s).
Expand All @@ -212,7 +282,8 @@ def binarize(self, inputs: Tensor) -> Tensor:

# set layers to eager mode
self.eager = True
self._integerizer.eager = True
if self._integerizer is not None:
self._integerizer.eager = True

# apply binarization
embeddings = self(inputs)
Expand All @@ -233,6 +304,7 @@ def get_config(self) -> Dict[str, Any]:
"encoding_size": self.encoding_size,
"encoding_type": self.encoding_type,
"replacement_char": self.replacement_char,
"use_native_tf_ops": self.use_native_tf_ops,
}
)
return config
Loading

0 comments on commit 2ce44b7

Please sign in to comment.