Skip to content

Commit

Permalink
Merge pull request #825 from khanhlvg:migrate-base-options-to-dataclass
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 451496850
  • Loading branch information
tflite-support-robot committed May 27, 2022
2 parents 47958bf + 04cd05e commit fba267e
Show file tree
Hide file tree
Showing 27 changed files with 159 additions and 67 deletions.
4 changes: 2 additions & 2 deletions tensorflow_lite_support/python/task/audio/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ py_library(
"//tensorflow_lite_support/python/task/audio/core:tensor_audio",
"//tensorflow_lite_support/python/task/audio/core/pybinds:_pywrap_audio_buffer",
"//tensorflow_lite_support/python/task/audio/pybinds:_pywrap_audio_embedder",
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:embedding_pb2",
],
Expand All @@ -31,7 +31,7 @@ py_library(
"//tensorflow_lite_support/python/task/audio/core:tensor_audio",
"//tensorflow_lite_support/python/task/audio/core/pybinds:_pywrap_audio_buffer",
"//tensorflow_lite_support/python/task/audio/pybinds:_pywrap_audio_classifier",
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:classifications_pb2",
],
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/python/task/audio/audio_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,15 @@
from tensorflow_lite_support.python.task.audio.core import tensor_audio
from tensorflow_lite_support.python.task.audio.core.pybinds import _pywrap_audio_buffer
from tensorflow_lite_support.python.task.audio.pybinds import _pywrap_audio_classifier
from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2
from tensorflow_lite_support.python.task.processor.proto import classifications_pb2

_CppAudioFormat = _pywrap_audio_buffer.AudioFormat
_CppAudioBuffer = _pywrap_audio_buffer.AudioBuffer
_CppAudioClassifier = _pywrap_audio_classifier.AudioClassifier
_ClassificationOptions = classification_options_pb2.ClassificationOptions
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions


@dataclasses.dataclass
Expand Down Expand Up @@ -83,7 +83,7 @@ def create_from_options(cls,
RuntimeError: If other types of error occurred.
"""
classifier = _CppAudioClassifier.create_from_options(
options.base_options, options.classification_options.to_pb2())
options.base_options.to_pb2(), options.classification_options.to_pb2())
return cls(options, classifier)

def create_input_tensor_audio(self) -> tensor_audio.TensorAudio:
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/python/task/audio/audio_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@
from tensorflow_lite_support.python.task.audio.core import tensor_audio
from tensorflow_lite_support.python.task.audio.core.pybinds import _pywrap_audio_buffer
from tensorflow_lite_support.python.task.audio.pybinds import _pywrap_audio_embedder
from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2
from tensorflow_lite_support.python.task.processor.proto import embedding_pb2

_CppAudioFormat = _pywrap_audio_buffer.AudioFormat
_CppAudioBuffer = _pywrap_audio_buffer.AudioBuffer
_CppAudioEmbedder = _pywrap_audio_embedder.AudioEmbedder
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions
_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions


Expand Down Expand Up @@ -82,7 +82,7 @@ def create_from_options(cls,
RuntimeError: If other types of error occurred.
"""
embedder = _CppAudioEmbedder.create_from_options(
options.base_options, options.embedding_options.to_pb2())
options.base_options.to_pb2(), options.embedding_options.to_pb2())
return cls(options, embedder)

def create_input_tensor_audio(self) -> tensor_audio.TensorAudio:
Expand Down
9 changes: 9 additions & 0 deletions tensorflow_lite_support/python/task/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,12 @@ py_library(
"optional_dependencies.py",
],
)

py_library(
name = "base_options",
srcs = ["base_options.py"],
deps = [
":optional_dependencies",
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
],
)
84 changes: 84 additions & 0 deletions tensorflow_lite_support/python/task/core/base_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base options for task APIs."""

import dataclasses
from typing import Any, Optional

from tensorflow_lite_support.python.task.core.optional_dependencies import doc_controls
from tensorflow_lite_support.python.task.core.proto import base_options_pb2

_BaseOptionsProto = base_options_pb2.BaseOptions


@dataclasses.dataclass
class BaseOptions:
"""Base options for TensorFlow Lite Task Library's Python APIs.
Represents external files used by the Task APIs (e.g. TF Lite FlatBuffer or
plain-text labels file). The files can be specified by one of the following
two ways:
(1) file contents loaded in `file_content`.
(2) file path in `file_name`.
If more than one field of these fields is provided, they are used in this
precedence order.
Attributes:
file_name: Path to the index.
file_content: The index file contents as bytes.
num_threads: Number of thread, the default value is -1 which means
Interpreter will decide what is the most appropriate `num_threads`.
use_coral: If true, inference will be delegated to a connected Coral Edge
TPU device.
"""

file_name: Optional[str] = None
file_content: Optional[bytes] = None
num_threads: Optional[int] = -1
use_coral: Optional[bool] = None

@doc_controls.do_not_generate_docs
def to_pb2(self) -> _BaseOptionsProto:
"""Generates a protobuf object to pass to the C++ layer."""
return _BaseOptionsProto(
file_name=self.file_name,
file_content=self.file_content,
num_threads=self.num_threads,
use_coral=self.use_coral)

@classmethod
@doc_controls.do_not_generate_docs
def create_from_pb2(cls, pb2_obj: _BaseOptionsProto) -> "BaseOptions":
"""Creates a `BaseOptions` object from the given protobuf object."""
return BaseOptions(
file_name=pb2_obj.file_name,
file_content=pb2_obj.file_content,
num_threads=pb2_obj.num_threads,
use_coral=pb2_obj.use_coral)

def __eq__(self, other: Any) -> bool:
"""Checks if this object is equal to the given object.
Args:
other: The object to be compared with.
Returns:
True if the objects are equal.
"""
if not isinstance(other, BaseOptions):
return False

return self.to_pb2().__eq__(other.to_pb2())
4 changes: 2 additions & 2 deletions tensorflow_lite_support/python/task/text/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ py_library(
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:embedding_pb2",
"//tensorflow_lite_support/python/task/text/pybinds:_pywrap_text_embedder",
Expand All @@ -26,7 +26,7 @@ py_library(
],
visibility = ["//visibility:public"],
deps = [
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:search_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:search_result_pb2",
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/python/task/text/text_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,13 @@

import dataclasses

from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2
from tensorflow_lite_support.python.task.processor.proto import embedding_pb2
from tensorflow_lite_support.python.task.text.pybinds import _pywrap_text_embedder

_CppTextEmbedder = _pywrap_text_embedder.TextEmbedder
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions
_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions


Expand Down Expand Up @@ -75,7 +75,7 @@ def create_from_options(cls, options: TextEmbedderOptions) -> "TextEmbedder":
RuntimeError: If other types of error occurred.
"""
embedder = _CppTextEmbedder.create_from_options(
options.base_options, options.embedding_options.to_pb2())
options.base_options.to_pb2(), options.embedding_options.to_pb2())
return cls(options, embedder)

def embed(self, text: str) -> embedding_pb2.EmbeddingResult:
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/python/task/text/text_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
import dataclasses
from typing import Optional

from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2
from tensorflow_lite_support.python.task.processor.proto import search_options_pb2
from tensorflow_lite_support.python.task.processor.proto import search_result_pb2
from tensorflow_lite_support.python.task.text.pybinds import _pywrap_text_searcher

_CppTextSearcher = _pywrap_text_searcher.TextSearcher
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions
_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions
_SearchOptions = search_options_pb2.SearchOptions

Expand Down Expand Up @@ -90,7 +90,7 @@ def create_from_options(cls, options: TextSearcherOptions) -> "TextSearcher":
RuntimeError: If other types of error occurred.
"""
searcher = _CppTextSearcher.create_from_options(
options.base_options, options.embedding_options.to_pb2(),
options.base_options.to_pb2(), options.embedding_options.to_pb2(),
options.search_options.to_pb2())
return cls(options, searcher)

Expand Down
11 changes: 5 additions & 6 deletions tensorflow_lite_support/python/task/vision/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ py_library(
"image_embedder.py",
],
deps = [
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2",
"//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:embedding_pb2",
Expand All @@ -27,7 +27,7 @@ py_library(
"image_classifier.py",
],
deps = [
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2",
"//tensorflow_lite_support/python/task/processor/proto:classification_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:classifications_pb2",
Expand All @@ -43,7 +43,7 @@ py_library(
"image_segmenter.py",
],
deps = [
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:segmentation_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:segmentations_pb2",
"//tensorflow_lite_support/python/task/vision/core:tensor_image",
Expand All @@ -58,7 +58,7 @@ py_library(
"image_searcher.py",
],
deps = [
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2",
"//tensorflow_lite_support/python/task/processor/proto:embedding_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:search_options_pb2",
Expand All @@ -75,8 +75,7 @@ py_library(
"object_detector.py",
],
deps = [
"//tensorflow_lite_support/python/task/core/proto:base_options_py_pb2",
"//tensorflow_lite_support/python/task/processor/proto:bounding_box_pb2",
"//tensorflow_lite_support/python/task/core:base_options",
"//tensorflow_lite_support/python/task/processor/proto:detection_options_pb2",
"//tensorflow_lite_support/python/task/processor/proto:detections_pb2",
"//tensorflow_lite_support/python/task/vision/core:tensor_image",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import dataclasses
from typing import Optional

from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2
from tensorflow_lite_support.python.task.processor.proto import classification_options_pb2
from tensorflow_lite_support.python.task.processor.proto import classifications_pb2
Expand All @@ -26,7 +26,7 @@

_CppImageClassifier = _pywrap_image_classifier.ImageClassifier
_ClassificationOptions = classification_options_pb2.ClassificationOptions
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions


@dataclasses.dataclass
Expand Down Expand Up @@ -80,7 +80,7 @@ def create_from_options(cls,
RuntimeError: If other types of error occurred.
"""
classifier = _CppImageClassifier.create_from_options(
options.base_options, options.classification_options.to_pb2())
options.base_options.to_pb2(), options.classification_options.to_pb2())
return cls(options, classifier)

def classify(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/python/task/vision/image_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import dataclasses
from typing import Optional

from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2
from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2
from tensorflow_lite_support.python.task.processor.proto import embedding_pb2
Expand All @@ -25,7 +25,7 @@
from tensorflow_lite_support.python.task.vision.pybinds import _pywrap_image_embedder

_CppImageEmbedder = _pywrap_image_embedder.ImageEmbedder
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions
_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions


Expand Down Expand Up @@ -82,7 +82,7 @@ def create_from_options(cls,
RuntimeError: If other types of error occurred.
"""
embedder = _CppImageEmbedder.create_from_options(
options.base_options, options.embedding_options.to_pb2())
options.base_options.to_pb2(), options.embedding_options.to_pb2())
return cls(options, embedder)

def embed(
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/python/task/vision/image_searcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import dataclasses
from typing import Optional

from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import bounding_box_pb2
from tensorflow_lite_support.python.task.processor.proto import embedding_options_pb2
from tensorflow_lite_support.python.task.processor.proto import search_options_pb2
Expand All @@ -26,7 +26,7 @@
from tensorflow_lite_support.python.task.vision.pybinds import _pywrap_image_searcher

_CppImageSearcher = _pywrap_image_searcher.ImageSearcher
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions
_EmbeddingOptions = embedding_options_pb2.EmbeddingOptions
_SearchOptions = search_options_pb2.SearchOptions

Expand Down Expand Up @@ -95,7 +95,7 @@ def create_from_options(cls,
RuntimeError: If other types of error occurred.
"""
searcher = _CppImageSearcher.create_from_options(
options.base_options, options.embedding_options.to_pb2(),
options.base_options.to_pb2(), options.embedding_options.to_pb2(),
options.search_options.to_pb2())
return cls(options, searcher)

Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/python/task/vision/image_segmenter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import dataclasses

from tensorflow_lite_support.python.task.core.proto import base_options_pb2
from tensorflow_lite_support.python.task.core import base_options as base_options_module
from tensorflow_lite_support.python.task.processor.proto import segmentation_options_pb2
from tensorflow_lite_support.python.task.processor.proto import segmentations_pb2
from tensorflow_lite_support.python.task.vision.core import tensor_image
Expand All @@ -24,7 +24,7 @@

_CppImageSegmenter = _pywrap_image_segmenter.ImageSegmenter
_SegmentationOptions = segmentation_options_pb2.SegmentationOptions
_BaseOptions = base_options_pb2.BaseOptions
_BaseOptions = base_options_module.BaseOptions


@dataclasses.dataclass
Expand Down Expand Up @@ -76,7 +76,7 @@ def create_from_options(cls,
RuntimeError: If other types of error occurred.
"""
segmenter = _CppImageSegmenter.create_from_options(
options.base_options, options.segmentation_options.to_pb2())
options.base_options.to_pb2(), options.segmentation_options.to_pb2())
return cls(options, segmenter)

def segment(
Expand Down
Loading

0 comments on commit fba267e

Please sign in to comment.