Skip to content

Commit

Permalink
Move tensorflow lite python calls to ai-edge-litert.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 694612834
  • Loading branch information
pak-laura authored and copybara-github committed Nov 8, 2024
1 parent d349af4 commit b68584a
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 78 deletions.
4 changes: 1 addition & 3 deletions chirp/projects/zoo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@
import tensorflow.compat.v1 as tf1
import tensorflow_hub as hub

from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


@dataclasses.dataclass
class SeparateEmbedModel(zoo_interface.EmbeddingModel):
Expand Down Expand Up @@ -328,7 +326,7 @@ def from_config(cls, config: config_dict.ConfigDict) -> 'BirdNet':
with tempfile.NamedTemporaryFile() as tmpf:
model_file = epath.Path(config.model_path)
model_file.copy(tmpf.name, overwrite=True)
model = tfl_interpreter.Interpreter(
model = tf.lite.Interpreter(
tmpf.name, num_threads=config.num_tflite_threads
)
model.allocate_tensors()
Expand Down
5 changes: 2 additions & 3 deletions chirp/train_tests/frontend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from absl.testing import absltest
from absl.testing import parameterized
from ai_edge_litert import interpreter as tfl_interpreter # pylint: disable=g-direct-tensorflow-import


class FrontendTest(parameterized.TestCase):
Expand Down Expand Up @@ -201,7 +200,7 @@ def test_tflite_stft_export(
tflite_float_model = converter.convert()

# Use the converted TFLite model.
interpreter = tfl_interpreter.Interpreter(model_content=tflite_float_model)
interpreter = tf.lite.Interpreter(model_content=tflite_float_model)
interpreter.allocate_tensors()
input_tensor = interpreter.get_input_details()[0]
output_tensor = interpreter.get_output_details()[0]
Expand Down Expand Up @@ -248,7 +247,7 @@ def test_simple_melspec(self):
tflite_float_model = converter.convert()

# Use the converted TFLite model.
interpreter = tfl_interpreter.Interpreter(model_content=tflite_float_model)
interpreter = tf.lite.Interpreter(model_content=tflite_float_model)
interpreter.allocate_tensors()
input_tensor = interpreter.get_input_details()[0]
output_tensor = interpreter.get_output_details()[0]
Expand Down
121 changes: 50 additions & 71 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ license = "Apache 2.0"
# See: https://python-poetry.org/docs/managing-dependencies/#dependency-groups
python = ">=3.10,<3.12"
absl-py = "^1.4.0"
ai-edge-litert = "^1.0.0"
apache-beam = {version = "^2.50.0", extras = ["gcp"]}
clu = "^0.0.9"
flax = "^0.8.1"
Expand Down

0 comments on commit b68584a

Please sign in to comment.