Skip to content

Commit

Permalink
Move custom text OpResolver from examples folder to utils.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452041108
  • Loading branch information
tensorflower-gardener authored and tflite-support-robot committed May 31, 2022
1 parent fba267e commit 9aa3b1a
Show file tree
Hide file tree
Showing 16 changed files with 51 additions and 53 deletions.
16 changes: 16 additions & 0 deletions tensorflow_lite_support/cc/task/text/utils/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,19 @@ cc_library_with_tflite(
"@com_google_absl//absl/strings:str_format",
],
)

cc_library_with_tflite(
name = "text_op_resolver",
srcs = ["text_op_resolver.cc"],
hdrs = ["text_op_resolver.h"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [
"//tensorflow_lite_support/custom_ops/kernel/ragged:ragged_tensor_to_tensor_tflite", # fixdeps: keep
"//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_tokenizer_tflite", # fixdeps: keep
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite/core/api:op_resolver",
],
)
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"

#include "absl/memory/memory.h" // from @com_google_absl
#include "tensorflow/lite/core/shims/cc/kernels/register.h"
Expand All @@ -30,8 +30,7 @@ namespace tflite {
namespace task {
namespace text {

// Creates custom op resolver for USE QA task.
std::unique_ptr<tflite::OpResolver> CreateQACustomOpResolver() {
std::unique_ptr<tflite::OpResolver> CreateTextOpResolver() {
auto resolver =
absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>();
resolver->AddCustom(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,16 @@ limitations under the License.

#include <memory>

#include "tensorflow/lite/op_resolver.h"
#include "tensorflow/lite/core/api/op_resolver.h"

namespace tflite {
namespace task {
namespace text {

std::unique_ptr<tflite::OpResolver> CreateQACustomOpResolver();
// Creates a custom OpResolver containing the additional SENTENCEPIECE_TOKENIZER
// and RAGGED_TENSOR_TO_TENSOR ops needed by some text embedders such as
// universal sentence encoder-based models.
std::unique_ptr<tflite::OpResolver> CreateTextOpResolver();

} // namespace text
} // namespace task
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_lite_support/cc/test/task/text/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ cc_test(
"//tensorflow_lite_support/cc/task/core:task_utils",
"//tensorflow_lite_support/cc/task/text:universal_sentence_encoder_qa",
"//tensorflow_lite_support/cc/task/text/proto:retrieval_cc_proto",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"//tensorflow_lite_support/cc/test:test_utils",
"//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
],
)
Expand All @@ -76,8 +76,8 @@ cc_test(
"//tensorflow_lite_support/cc/task/core/proto:base_options_proto_inc",
"//tensorflow_lite_support/cc/task/processor/proto:embedding_options_cc_proto",
"//tensorflow_lite_support/cc/task/text:text_embedder",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"//tensorflow_lite_support/cc/test:test_utils",
"//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite/core/shims:cc_shims_test_util",
Expand All @@ -104,8 +104,8 @@ cc_test(
"//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto",
"//tensorflow_lite_support/cc/task/text:text_searcher",
"//tensorflow_lite_support/cc/task/text/proto:text_searcher_options_cc_proto",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"//tensorflow_lite_support/cc/test:test_utils",
"//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ limitations under the License.
#include "tensorflow_lite_support/cc/port/status_matchers.h"
#include "tensorflow_lite_support/cc/task/core/proto/base_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/processor/proto/embedding_options.pb.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"
#include "tensorflow_lite_support/cc/test/test_utils.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"

namespace tflite {
namespace task {
Expand Down Expand Up @@ -195,7 +195,7 @@ TEST(EmbedTest, SucceedsWithUniversalSentenceEncoder) {
// No Embedding options means all head get a default option.
SUPPORT_ASSERT_OK_AND_ASSIGN(
std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::CreateFromOptions(options, CreateQACustomOpResolver()));
TextEmbedder::CreateFromOptions(options, CreateTextOpResolver()));

SUPPORT_ASSERT_OK_AND_ASSIGN(
auto result0,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,8 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/processor/proto/search_options.pb.h"
#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h"
#include "tensorflow_lite_support/cc/task/text/proto/text_searcher_options.pb.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"
#include "tensorflow_lite_support/cc/test/test_utils.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"

namespace tflite {
namespace task {
namespace text {
Expand Down Expand Up @@ -84,7 +83,7 @@ void ExpectApproximatelyEqual(const SearchResult& actual,
std::unique_ptr<tflite::OpResolver> GetOpResolver(
bool is_universal_sentence_encoder) {
if (is_universal_sentence_encoder) {
return CreateQACustomOpResolver();
return CreateTextOpResolver();
} else {
return absl::make_unique<tflite_shims::ops::builtin::BuiltinOpResolver>();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@ limitations under the License.
#include "tensorflow_lite_support/cc/port/status_matchers.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/task/text/proto/retrieval.pb.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"
#include "tensorflow_lite_support/cc/test/message_matchers.h"
#include "tensorflow_lite_support/cc/test/test_utils.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"

namespace tflite {
namespace task {
namespace text {
Expand Down Expand Up @@ -83,7 +82,7 @@ class UniversalSentenceEncoderQATest : public tflite_shims::testing::Test {
options.mutable_base_options()->mutable_model_file()->set_file_name(
filename);
auto status = UniversalSentenceEncoderQA::CreateFromOption(
options, CreateQACustomOpResolver());
options, CreateTextOpResolver());
if (status.ok()) {
qa_client_ = std::move(status.value());
}
Expand Down
24 changes: 3 additions & 21 deletions tensorflow_lite_support/examples/task/text/desktop/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
load("@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl", "cc_library_with_tflite")

package(
default_visibility = [
"//tensorflow_lite_support:internal",
Expand Down Expand Up @@ -87,22 +85,6 @@ cc_binary(
}),
)

cc_library_with_tflite(
name = "universal_sentence_encoder_qa_op_resolver",
srcs = ["universal_sentence_encoder_qa_op_resolver.cc"],
hdrs = ["universal_sentence_encoder_qa_op_resolver.h"],
tflite_deps = [
"@org_tensorflow//tensorflow/lite/core/shims:builtin_ops",
],
deps = [
"//tensorflow_lite_support/custom_ops/kernel/ragged:ragged_tensor_to_tensor_tflite", # fixdeps: keep
"//tensorflow_lite_support/custom_ops/kernel/sentencepiece:sentencepiece_tokenizer_tflite", # fixdeps: keep
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@org_tensorflow//tensorflow/lite:op_resolver",
],
)

# Example usage:
# bazel run -c opt \
# tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_main \
Expand All @@ -114,8 +96,8 @@ cc_binary(
"universal_sentence_encoder_qa_demo.cc",
],
deps = [
":universal_sentence_encoder_qa_op_resolver",
"//tensorflow_lite_support/cc/task/text:universal_sentence_encoder_qa",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/status",
Expand All @@ -134,7 +116,6 @@ cc_binary(
name = "text_searcher_demo",
srcs = ["text_searcher_demo.cc"],
deps = [
":universal_sentence_encoder_qa_op_resolver",
"//tensorflow_lite_support/cc/port:configuration_proto_inc",
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/task/core/proto:base_options_cc_proto",
Expand All @@ -144,6 +125,7 @@ cc_binary(
"//tensorflow_lite_support/cc/task/processor/proto:search_result_cc_proto",
"//tensorflow_lite_support/cc/task/text:text_searcher",
"//tensorflow_lite_support/cc/task/text/proto:text_searcher_options_cc_proto",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/status",
Expand All @@ -162,13 +144,13 @@ cc_binary(
name = "text_embedder_demo",
srcs = ["text_embedder_demo.cc"],
deps = [
":universal_sentence_encoder_qa_op_resolver",
"//tensorflow_lite_support/cc/port:configuration_proto_inc",
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/task/core/proto:base_options_cc_proto",
"//tensorflow_lite_support/cc/task/processor/proto:embedding_cc_proto",
"//tensorflow_lite_support/cc/task/text:text_embedder",
"//tensorflow_lite_support/cc/task/text/proto:text_embedder_options_cc_proto",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"@com_google_absl//absl/flags:flag",
"@com_google_absl//absl/flags:parse",
"@com_google_absl//absl/status",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/processor/proto/embedding.pb.h"
#include "tensorflow_lite_support/cc/task/text/proto/text_embedder_options.pb.h"
#include "tensorflow_lite_support/cc/task/text/text_embedder.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"

ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' text embedder model.");
Expand Down Expand Up @@ -82,7 +82,7 @@ absl::Status ComputeCosineSimilarity() {
const TextEmbedderOptions options = BuildOptions();
ASSIGN_OR_RETURN(
std::unique_ptr<TextEmbedder> text_embedder,
TextEmbedder::CreateFromOptions(options, CreateQACustomOpResolver()));
TextEmbedder::CreateFromOptions(options, CreateTextOpResolver()));

// Run search and display results.
auto start_embed = steady_clock::now();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ limitations under the License.
#include "tensorflow_lite_support/cc/task/processor/proto/search_result.pb.h"
#include "tensorflow_lite_support/cc/task/text/proto/text_searcher_options.pb.h"
#include "tensorflow_lite_support/cc/task/text/text_searcher.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"

ABSL_FLAG(std::string, model_path, "",
"Absolute path to the '.tflite' text embedder model.");
Expand Down Expand Up @@ -103,7 +103,7 @@ absl::Status Search() {
const TextSearcherOptions options = BuildOptions();
ASSIGN_OR_RETURN(
std::unique_ptr<TextSearcher> text_searcher,
TextSearcher::CreateFromOptions(options, CreateQACustomOpResolver()));
TextSearcher::CreateFromOptions(options, CreateTextOpResolver()));

// Run search and display results.
auto start_search = steady_clock::now();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ limitations under the License.
#include "absl/status/status.h" // from @com_google_absl
#include "absl/strings/str_split.h" // from @com_google_absl
#include "tensorflow_lite_support/cc/task/text/universal_sentence_encoder_qa.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"

namespace {
using tflite::task::text::CreateQACustomOpResolver;
using tflite::task::text::CreateTextOpResolver;
using tflite::task::text::RetrievalInput;
using tflite::task::text::RetrievalOptions;
using tflite::task::text::RetrievalOutput;
Expand Down Expand Up @@ -58,7 +58,7 @@ int main(int argc, char** argv) {
options.mutable_base_options()->mutable_model_file()->set_file_name(
absl::GetFlag(FLAGS_model_path));
auto status = UniversalSentenceEncoderQA::CreateFromOption(
options, CreateQACustomOpResolver());
options, CreateTextOpResolver());
if (!status.ok()) {
std::cerr << "Retrieve failed: " << status.status().message() << std::endl;
return 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ cc_library_with_tflite(
"universal_sentence_encoder_qa_op_register.cc",
],
tflite_deps = [
"//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
],
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"

namespace tflite {
namespace task {
// Provides a custom OpResolver for TextSearcher Java API.
std::unique_ptr<OpResolver> CreateOpResolver() {
return tflite::task::text::CreateQACustomOpResolver();
return tflite::task::text::CreateTextOpResolver();
}

} // namespace task
Expand Down
4 changes: 2 additions & 2 deletions tensorflow_lite_support/python/task/text/pybinds/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pybind_extension_may_pack_coral(
module_name = "_pywrap_text_embedder",
deps = [
"//tensorflow_lite_support/cc/task/text:text_embedder",
"//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"//tensorflow_lite_support/python/task/core/pybinds:task_utils",
"@pybind11",
"@pybind11_abseil//pybind11_abseil:status_casters",
Expand All @@ -31,7 +31,7 @@ pybind_extension_may_pack_coral(
module_name = "_pywrap_text_searcher",
deps = [
"//tensorflow_lite_support/cc/task/text:text_searcher",
"//tensorflow_lite_support/examples/task/text/desktop:universal_sentence_encoder_qa_op_resolver",
"//tensorflow_lite_support/cc/task/text/utils:text_op_resolver",
"//tensorflow_lite_support/python/task/core/pybinds:task_utils",
"@pybind11",
"@pybind11_protobuf//pybind11_protobuf:native_proto_caster",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
#include "tensorflow_lite_support/cc/task/text/text_embedder.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"
#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"

namespace tflite {
Expand Down Expand Up @@ -45,7 +45,7 @@ PYBIND11_MODULE(_pywrap_text_embedder, m) {
options.set_allocated_base_options(cpp_base_options.release());
options.add_embedding_options()->CopyFrom(embedding_options);
auto embedder = TextEmbedder::CreateFromOptions(
options, CreateQACustomOpResolver());
options, CreateTextOpResolver());
return core::get_value(embedder);
})
.def("embed",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ limitations under the License.
#include "pybind11/pybind11.h"
#include "pybind11_protobuf/native_proto_caster.h" // from @pybind11_protobuf
#include "tensorflow_lite_support/cc/task/text/text_searcher.h"
#include "tensorflow_lite_support/examples/task/text/desktop/universal_sentence_encoder_qa_op_resolver.h"
#include "tensorflow_lite_support/cc/task/text/utils/text_op_resolver.h"
#include "tensorflow_lite_support/python/task/core/pybinds/task_utils.h"

namespace tflite {
Expand Down Expand Up @@ -59,7 +59,7 @@ PYBIND11_MODULE(_pywrap_text_searcher, m) {
options.set_allocated_search_options(cpp_search_options.release());

auto searcher = TextSearcher::CreateFromOptions(
options, CreateQACustomOpResolver());
options, CreateTextOpResolver());
return core::get_value(searcher);
})
.def("search",
Expand Down

0 comments on commit 9aa3b1a

Please sign in to comment.