Skip to content

Commit

Permalink
Internal change
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 467457993
  • Loading branch information
tensorflower-gardener authored and tflite-support-robot committed Aug 14, 2022
1 parent dd36717 commit 527be55
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 45 deletions.
135 changes: 93 additions & 42 deletions tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -52,32 +52,81 @@ StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create(
return processor;
}

// TODO(b/241507692) Add a unit test for a model with dynamic tensors.
absl::Status BertPreprocessor::Init() {
// Try if RegexTokenzier can be found.
// BertTokenzier is packed in the processing unit SubgraphMetadata.
// Try if RegexTokenizer can be found.
// BertTokenizer is packed in the processing unit SubgraphMetadata.
const tflite::ProcessUnit* tokenizer_metadata =
GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit(
tokenizer_metadata, GetMetadataExtractor()));

// Sanity check and assign max sequence length.
if (GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
GetLastDimSize(tensor_indices_[kMaskTensorIndex]) ||
GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])) {
const auto& ids_tensor = *GetTensor(kIdsTensorIndex);
const auto& mask_tensor = *GetTensor(kMaskTensorIndex);
const auto& segment_ids_tensor = *GetTensor(kSegmentIdsTensorIndex);
if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
segment_ids_tensor.dims->size != 2) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
absl::StrFormat(
"The three input tensors in Bert models are "
"expected to have same length, but got ids_tensor "
"(%d), mask_tensor (%d), segment_ids_tensor (%d).",
GetLastDimSize(tensor_indices_[kIdsTensorIndex]),
GetLastDimSize(tensor_indices_[kMaskTensorIndex]),
GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])),
TfLiteSupportStatus::kInvalidNumOutputTensorsError);
"The three input tensors in Bert models are expected to have dim "
"2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
"(%d).",
ids_tensor.dims->size, mask_tensor.dims->size,
segment_ids_tensor.dims->size),
TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
}
if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
segment_ids_tensor.dims->data[0] != 1) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
absl::StrFormat(
"The three input tensors in Bert models are expected to have same "
"batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
"segment_ids_tensor (%d).",
ids_tensor.dims->data[0], mask_tensor.dims->data[0],
segment_ids_tensor.dims->data[0]),
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}
if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
absl::StrFormat("The three input tensors in Bert models are "
"expected to have same length, but got ids_tensor "
"(%d), mask_tensor (%d), segment_ids_tensor (%d).",
ids_tensor.dims->data[1], mask_tensor.dims->data[1],
segment_ids_tensor.dims->data[1]),
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}
bert_max_seq_len_ = GetLastDimSize(tensor_indices_[kIdsTensorIndex]);

bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 &&
mask_tensor.dims_signature->size == 2 &&
segment_ids_tensor.dims_signature->size == 2;
if (has_valid_dims_signature && ids_tensor.dims_signature->data[1] == -1 &&
mask_tensor.dims_signature->data[1] == -1 &&
segment_ids_tensor.dims_signature->data[1] == -1) {
input_tensors_are_dynamic_ = true;
} else if (has_valid_dims_signature &&
(ids_tensor.dims_signature->data[1] == -1 ||
mask_tensor.dims_signature->data[1] == -1 ||
segment_ids_tensor.dims_signature->data[1] == -1)) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
"Input tensors contain a mix of static and dynamic tensors",
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}

if (input_tensors_are_dynamic_) return absl::OkStatus();

bert_max_seq_len_ = ids_tensor.dims->data[1];
if (bert_max_seq_len_ < 2) {
return CreateStatusWithPayload(
absl::StatusCode::kInternal,
absl::StrFormat("bert_max_seq_len_ should be at least 2, got: (%d).",
bert_max_seq_len_),
TfLiteSupportStatus::kInvalidInputTensorSizeError);
}
return absl::OkStatus();
}

Expand All @@ -92,48 +141,50 @@ absl::Status BertPreprocessor::Preprocess(const std::string& input_text) {
TokenizerResult input_tokenize_results;
input_tokenize_results = tokenizer_->Tokenize(processed_input);

// 2 accounts for [CLS], [SEP]
absl::Span<const std::string> query_tokens =
absl::MakeSpan(input_tokenize_results.subwords.data(),
input_tokenize_results.subwords.data() +
std::min(static_cast<size_t>(bert_max_seq_len_ - 2),
input_tokenize_results.subwords.size()));

std::vector<std::string> tokens;
tokens.reserve(2 + query_tokens.size());
// Start of generating the features.
tokens.push_back(kClassificationToken);
// For query input.
for (const auto& query_token : query_tokens) {
tokens.push_back(query_token);
// Offset by 2 to account for [CLS] and [SEP]
int input_tokens_size =
static_cast<int>(input_tokenize_results.subwords.size()) + 2;
int input_tensor_length = input_tokens_size;
if (!input_tensors_are_dynamic_) {
input_tokens_size = std::min(bert_max_seq_len_, input_tokens_size);
input_tensor_length = bert_max_seq_len_;
} else {
engine_->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex,
{1, input_tensor_length});
engine_->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex,
{1, input_tensor_length});
engine_->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
{1, input_tensor_length});
engine_->interpreter()->AllocateTensors();
}
// For Separation.
tokens.push_back(kSeparator);

std::vector<int> input_ids(bert_max_seq_len_, 0);
std::vector<int> input_mask(bert_max_seq_len_, 0);
std::vector<std::string> input_tokens;
input_tokens.reserve(input_tokens_size);
input_tokens.push_back(std::string(kClassificationToken));
for (int i = 0; i < input_tokens_size - 2; ++i) {
input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
}
input_tokens.push_back(std::string(kSeparator));

std::vector<int> input_ids(input_tensor_length, 0);
std::vector<int> input_mask(input_tensor_length, 0);
// Convert tokens back into ids and set mask
for (int i = 0; i < tokens.size(); ++i) {
tokenizer_->LookupId(tokens[i], &input_ids[i]);
for (int i = 0; i < input_tokens.size(); ++i) {
tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
input_mask[i] = 1;
}
// |<--------bert_max_seq_len_--------->|
// |<--------input_tensor_length------->|
// input_ids [CLS] s1 s2... sn [SEP] 0 0... 0
// input_masks 1 1 1... 1 1 0 0... 0
// segment_ids 0 0 0... 0 0 0 0... 0

RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor));
RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor));
RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0),
RETURN_IF_ERROR(PopulateTensor(std::vector<int>(input_tensor_length, 0),
segment_ids_tensor));
return absl::OkStatus();
}

int BertPreprocessor::GetLastDimSize(int tensor_index) {
auto tensor = engine_->GetInput(engine_->interpreter(), tensor_index);
return tensor->dims->data[tensor->dims->size - 1];
}

} // namespace processor
} // namespace task
} // namespace tflite
8 changes: 5 additions & 3 deletions tensorflow_lite_support/cc/task/processor/bert_preprocessor.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ class BertPreprocessor : public TextPreprocessor {

absl::Status Init();

int GetLastDimSize(int tensor_index);

std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
int bert_max_seq_len_;
// The maximum input sequence length the BERT model can accept. Used for
// static input tensors.
int bert_max_seq_len_ = 2;
// Whether the input tensors are dynamic instead of static.
bool input_tensors_are_dynamic_ = false;
};

} // namespace processor
Expand Down

0 comments on commit 527be55

Please sign in to comment.