Skip to content

Commit

Permalink
Switches bert_clu_annotator build rules to cc_library_with_tflite.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 452161028
  • Loading branch information
tensorflower-gardener authored and tflite-support-robot committed May 31, 2022
1 parent 9aa3b1a commit ecb6ea4
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 23 deletions.
16 changes: 10 additions & 6 deletions tensorflow_lite_support/cc/task/text/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -138,32 +138,36 @@ cc_library_with_tflite(
],
)

cc_library(
cc_library_with_tflite(
name = "clu_annotator",
hdrs = [
"clu_annotator.h",
],
deps = [
tflite_deps = [
"//tensorflow_lite_support/cc/task/core:base_task_api",
"//tensorflow_lite_support/cc/task/core:tflite_engine",
],
deps = [
"//tensorflow_lite_support/cc/task/text/proto:clu_proto_inc",
],
)

cc_library(
cc_library_with_tflite(
name = "bert_clu_annotator",
srcs = [
"bert_clu_annotator.cc",
],
hdrs = [
"bert_clu_annotator.h",
],
deps = [
tflite_deps = [
":clu_annotator",
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/task/core:task_api_factory",
"//tensorflow_lite_support/cc/task/core:task_utils",
"//tensorflow_lite_support/cc/task/text/clu_lib:tflite_modules",
],
deps = [
"//tensorflow_lite_support/cc/port:status_macros",
"//tensorflow_lite_support/cc/task/core:task_utils",
"//tensorflow_lite_support/cc/task/text/proto:bert_clu_annotator_options_proto_inc",
"//tensorflow_lite_support/cc/text/tokenizers:bert_tokenizer",
"//tensorflow_lite_support/cc/text/tokenizers:tokenizer_utils",
Expand Down
10 changes: 9 additions & 1 deletion tensorflow_lite_support/cc/task/text/clu_lib/BUILD
Original file line number Diff line number Diff line change
@@ -1,12 +1,20 @@
load(
"@org_tensorflow//tensorflow/lite/core/shims:cc_library_with_tflite.bzl",
"cc_library_with_tflite",
)

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)

cc_library(
cc_library_with_tflite(
name = "tflite_modules",
srcs = ["tflite_modules.cc"],
hdrs = ["tflite_modules.h"],
tflite_deps = [
"//tensorflow_lite_support/cc/task/core:tflite_engine",
],
deps = [
":bert_utils",
":constants",
Expand Down
23 changes: 14 additions & 9 deletions tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ absl::Status PopulateInputTextTensorForBERT(
const CluRequest& request, int token_id_tensor_idx,
int token_mask_tensor_idx, int token_type_id_tensor_idx,
const tflite::support::text::tokenizer::BertTokenizer* tokenizer,
size_t max_seq_len, int max_history_turns, tflite::Interpreter* interpreter,
Artifacts* artifacts) {
size_t max_seq_len, int max_history_turns,
core::TfLiteEngine::Interpreter* interpreter, Artifacts* artifacts) {
size_t seq_len;
int64_t* tokens_tensor =
interpreter->typed_input_tensor<int64_t>(token_id_tensor_idx);
Expand Down Expand Up @@ -116,8 +116,9 @@ absl::Status PopulateInputTextTensorForBERT(
return absl::OkStatus();
}

absl::StatusOr<int> GetInputSeqDimSize(const size_t input_idx,
const tflite::Interpreter* interpreter) {
absl::StatusOr<int> GetInputSeqDimSize(
const size_t input_idx,
const core::TfLiteEngine::Interpreter* interpreter) {
if (input_idx >= interpreter->inputs().size()) {
return absl::InternalError(absl::StrCat(
"input_idx should be less than interpreter input numbers. ", input_idx,
Expand All @@ -132,14 +133,15 @@ absl::StatusOr<int> GetInputSeqDimSize(const size_t input_idx,
return tflite::SizeOfDimension(tensor, 1);
}

absl::Status AbstractModule::Init(tflite::Interpreter* interpreter,
absl::Status AbstractModule::Init(core::TfLiteEngine::Interpreter* interpreter,
const BertCluAnnotatorOptions* options) {
interpreter_ = interpreter;
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<AbstractModule>> UtteranceSeqModule::Create(
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options,
const tflite::support::text::tokenizer::BertTokenizer* tokenizer) {
auto out = std::make_unique<UtteranceSeqModule>();
Expand Down Expand Up @@ -187,7 +189,8 @@ AbstractModule::NamesAndConfidencesFromOutput(int names_tensor_idx,
}

absl::StatusOr<std::unique_ptr<AbstractModule>> DomainModule::Create(
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options) {
auto out = std::make_unique<DomainModule>();
out->tensor_index_map_ = tensor_index_map;
Expand Down Expand Up @@ -215,7 +218,8 @@ absl::Status DomainModule::Postprocess(Artifacts* artifacts,
}

absl::StatusOr<std::unique_ptr<AbstractModule>> IntentModule::Create(
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options) {
auto out = std::make_unique<IntentModule>();
out->tensor_index_map_ = tensor_index_map;
Expand Down Expand Up @@ -261,7 +265,8 @@ absl::Status IntentModule::Postprocess(Artifacts* artifacts,
}

absl::StatusOr<std::unique_ptr<AbstractModule>> SlotModule::Create(
tflite::Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options) {
auto out = std::make_unique<SlotModule>();
out->tensor_index_map_ = tensor_index_map;
Expand Down
18 changes: 11 additions & 7 deletions tensorflow_lite_support/cc/task/text/clu_lib/tflite_modules.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License.

#include "absl/status/statusor.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "tensorflow/lite/interpreter.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/text/proto/bert_clu_annotator_options_proto_inc.h"
#include "tensorflow_lite_support/cc/task/text/proto/clu_proto_inc.h"
#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
Expand Down Expand Up @@ -76,7 +76,7 @@ class AbstractModule {
protected:
AbstractModule() = default;

absl::Status Init(Interpreter* interpreter,
absl::Status Init(core::TfLiteEngine::Interpreter* interpreter,
const BertCluAnnotatorOptions* options);

using NamesAndConfidences =
Expand All @@ -88,7 +88,7 @@ class AbstractModule {
int names_tensor_idx, int scores_tensor_idx) const;

// TFLite interpreter
Interpreter* interpreter_ = nullptr;
core::TfLiteEngine::Interpreter* interpreter_ = nullptr;

const TensorIndexMap* tensor_index_map_ = nullptr;
};
Expand All @@ -98,7 +98,8 @@ class AbstractModule {
class UtteranceSeqModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options,
const tflite::support::text::tokenizer::BertTokenizer* tokenizer);

Expand All @@ -116,7 +117,8 @@ class UtteranceSeqModule : public AbstractModule {
class DomainModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options);

absl::Status Postprocess(Artifacts* artifacts,
Expand All @@ -130,7 +132,8 @@ class DomainModule : public AbstractModule {
class IntentModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options);

absl::Status Postprocess(Artifacts* artifacts,
Expand All @@ -145,7 +148,8 @@ class IntentModule : public AbstractModule {
class SlotModule : public AbstractModule {
public:
static absl::StatusOr<std::unique_ptr<AbstractModule>> Create(
Interpreter* interpreter, const TensorIndexMap* tensor_index_map,
core::TfLiteEngine::Interpreter* interpreter,
const TensorIndexMap* tensor_index_map,
const BertCluAnnotatorOptions* options);

absl::Status Postprocess(Artifacts* artifacts,
Expand Down

0 comments on commit ecb6ea4

Please sign in to comment.