diff --git a/tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc b/tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc
index 755d22f7d..76a7a534a 100644
--- a/tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc
+++ b/tensorflow_lite_support/cc/task/processor/bert_preprocessor.cc
@@ -52,32 +52,81 @@ StatusOr<std::unique_ptr<BertPreprocessor>> BertPreprocessor::Create(
   return processor;
 }
 
+// TODO(b/241507692) Add a unit test for a model with dynamic tensors.
 absl::Status BertPreprocessor::Init() {
-  // Try if RegexTokenzier can be found.
-  // BertTokenzier is packed in the processing unit SubgraphMetadata.
+  // Try if RegexTokenizer can be found.
+  // BertTokenizer is packed in the processing unit SubgraphMetadata.
   const tflite::ProcessUnit* tokenizer_metadata =
       GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
   ASSIGN_OR_RETURN(tokenizer_, CreateTokenizerFromProcessUnit(
                                    tokenizer_metadata, GetMetadataExtractor()));
 
-  // Sanity check and assign max sequence length.
-  if (GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
-          GetLastDimSize(tensor_indices_[kMaskTensorIndex]) ||
-      GetLastDimSize(tensor_indices_[kIdsTensorIndex]) !=
-          GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])) {
+  const auto& ids_tensor = *GetTensor(kIdsTensorIndex);
+  const auto& mask_tensor = *GetTensor(kMaskTensorIndex);
+  const auto& segment_ids_tensor = *GetTensor(kSegmentIdsTensorIndex);
+  if (ids_tensor.dims->size != 2 || mask_tensor.dims->size != 2 ||
+      segment_ids_tensor.dims->size != 2) {
     return CreateStatusWithPayload(
         absl::StatusCode::kInternal,
         absl::StrFormat(
-            "The three input tensors in Bert models are "
-            "expected to have same length, but got ids_tensor "
-            "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
-            GetLastDimSize(tensor_indices_[kIdsTensorIndex]),
-            GetLastDimSize(tensor_indices_[kMaskTensorIndex]),
-            GetLastDimSize(tensor_indices_[kSegmentIdsTensorIndex])),
-        TfLiteSupportStatus::kInvalidNumOutputTensorsError);
+            "The three input tensors in Bert models are expected to have dim "
+            "2, but got ids_tensor (%d), mask_tensor (%d), segment_ids_tensor "
+            "(%d).",
+            ids_tensor.dims->size, mask_tensor.dims->size,
+            segment_ids_tensor.dims->size),
+        TfLiteSupportStatus::kInvalidInputTensorDimensionsError);
+  }
+  if (ids_tensor.dims->data[0] != 1 || mask_tensor.dims->data[0] != 1 ||
+      segment_ids_tensor.dims->data[0] != 1) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        absl::StrFormat(
+            "The three input tensors in Bert models are expected to have same "
+            "batch size 1, but got ids_tensor (%d), mask_tensor (%d), "
+            "segment_ids_tensor (%d).",
+            ids_tensor.dims->data[0], mask_tensor.dims->data[0],
+            segment_ids_tensor.dims->data[0]),
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
+  }
+  if (ids_tensor.dims->data[1] != mask_tensor.dims->data[1] ||
+      ids_tensor.dims->data[1] != segment_ids_tensor.dims->data[1]) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        absl::StrFormat("The three input tensors in Bert models are "
+                        "expected to have same length, but got ids_tensor "
+                        "(%d), mask_tensor (%d), segment_ids_tensor (%d).",
+                        ids_tensor.dims->data[1], mask_tensor.dims->data[1],
+                        segment_ids_tensor.dims->data[1]),
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
   }
-  bert_max_seq_len_ = GetLastDimSize(tensor_indices_[kIdsTensorIndex]);
 
+  bool has_valid_dims_signature = ids_tensor.dims_signature->size == 2 &&
+                                  mask_tensor.dims_signature->size == 2 &&
+                                  segment_ids_tensor.dims_signature->size == 2;
+  if (has_valid_dims_signature && ids_tensor.dims_signature->data[1] == -1 &&
+      mask_tensor.dims_signature->data[1] == -1 &&
+      segment_ids_tensor.dims_signature->data[1] == -1) {
+    input_tensors_are_dynamic_ = true;
+  } else if (has_valid_dims_signature &&
+             (ids_tensor.dims_signature->data[1] == -1 ||
+              mask_tensor.dims_signature->data[1] == -1 ||
+              segment_ids_tensor.dims_signature->data[1] == -1)) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        "Input tensors contain a mix of static and dynamic tensors",
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
+  }
+
+  if (input_tensors_are_dynamic_) return absl::OkStatus();
+
+  bert_max_seq_len_ = ids_tensor.dims->data[1];
+  if (bert_max_seq_len_ < 2) {
+    return CreateStatusWithPayload(
+        absl::StatusCode::kInternal,
+        absl::StrFormat("bert_max_seq_len_ should be at least 2, got: (%d).",
+                        bert_max_seq_len_),
+        TfLiteSupportStatus::kInvalidInputTensorSizeError);
+  }
   return absl::OkStatus();
 }
@@ -92,48 +141,50 @@ absl::Status BertPreprocessor::Preprocess(const std::string& input_text) {
   TokenizerResult input_tokenize_results;
   input_tokenize_results = tokenizer_->Tokenize(processed_input);
 
-  // 2 accounts for [CLS], [SEP]
-  absl::Span<const std::string> query_tokens =
-      absl::MakeSpan(input_tokenize_results.subwords.data(),
-                     input_tokenize_results.subwords.data() +
-                         std::min(static_cast<size_t>(bert_max_seq_len_ - 2),
-                                  input_tokenize_results.subwords.size()));
-
-  std::vector<std::string> tokens;
-  tokens.reserve(2 + query_tokens.size());
-  // Start of generating the features.
-  tokens.push_back(kClassificationToken);
-  // For query input.
-  for (const auto& query_token : query_tokens) {
-    tokens.push_back(query_token);
+  // Offset by 2 to account for [CLS] and [SEP]
+  int input_tokens_size =
+      static_cast<int>(input_tokenize_results.subwords.size()) + 2;
+  int input_tensor_length = input_tokens_size;
+  if (!input_tensors_are_dynamic_) {
+    input_tokens_size = std::min(bert_max_seq_len_, input_tokens_size);
+    input_tensor_length = bert_max_seq_len_;
+  } else {
+    engine_->interpreter()->ResizeInputTensorStrict(kIdsTensorIndex,
+                                                    {1, input_tensor_length});
+    engine_->interpreter()->ResizeInputTensorStrict(kMaskTensorIndex,
+                                                    {1, input_tensor_length});
+    engine_->interpreter()->ResizeInputTensorStrict(kSegmentIdsTensorIndex,
+                                                    {1, input_tensor_length});
+    engine_->interpreter()->AllocateTensors();
   }
-  // For Separation.
-  tokens.push_back(kSeparator);
-  std::vector<int> input_ids(bert_max_seq_len_, 0);
-  std::vector<int> input_mask(bert_max_seq_len_, 0);
+  std::vector<std::string> input_tokens;
+  input_tokens.reserve(input_tokens_size);
+  input_tokens.push_back(std::string(kClassificationToken));
+  for (int i = 0; i < input_tokens_size - 2; ++i) {
+    input_tokens.push_back(std::move(input_tokenize_results.subwords[i]));
+  }
+  input_tokens.push_back(std::string(kSeparator));
+
+  std::vector<int> input_ids(input_tensor_length, 0);
+  std::vector<int> input_mask(input_tensor_length, 0);
   // Convert tokens back into ids and set mask
-  for (int i = 0; i < tokens.size(); ++i) {
-    tokenizer_->LookupId(tokens[i], &input_ids[i]);
+  for (int i = 0; i < input_tokens.size(); ++i) {
+    tokenizer_->LookupId(input_tokens[i], &input_ids[i]);
     input_mask[i] = 1;
   }
 
-  // |<--------bert_max_seq_len_--------->|
+  // |<--------input_tensor_length------->|
   // input_ids   [CLS] s1  s2...  sn [SEP]  0  0...  0
   // input_masks   1    1   1...   1    1   0  0...  0
   // segment_ids   0    0   0...   0    0   0  0...  0
 
   RETURN_IF_ERROR(PopulateTensor(input_ids, ids_tensor));
   RETURN_IF_ERROR(PopulateTensor(input_mask, mask_tensor));
-  RETURN_IF_ERROR(PopulateTensor(std::vector<int>(bert_max_seq_len_, 0),
+  RETURN_IF_ERROR(PopulateTensor(std::vector<int>(input_tensor_length, 0),
                                  segment_ids_tensor));
   return absl::OkStatus();
 }
 
-int BertPreprocessor::GetLastDimSize(int tensor_index) {
-  auto tensor = engine_->GetInput(engine_->interpreter(), tensor_index);
-  return tensor->dims->data[tensor->dims->size - 1];
-}
-
 }  // namespace processor
 }  // namespace task
 }  // namespace tflite
diff --git a/tensorflow_lite_support/cc/task/processor/bert_preprocessor.h b/tensorflow_lite_support/cc/task/processor/bert_preprocessor.h
index 02ece9a8b..eb288ff7c 100644
--- a/tensorflow_lite_support/cc/task/processor/bert_preprocessor.h
+++ b/tensorflow_lite_support/cc/task/processor/bert_preprocessor.h
@@ -46,10 +46,12 @@ class BertPreprocessor : public TextPreprocessor {
 
   absl::Status Init();
 
-  int GetLastDimSize(int tensor_index);
-
   std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
-  int bert_max_seq_len_;
+  // The maximum input sequence length the BERT model can accept. Used for
+  // static input tensors.
+  int bert_max_seq_len_ = 2;
+  // Whether the input tensors are dynamic instead of static.
+  bool input_tensors_are_dynamic_ = false;
 };
 
 }  // namespace processor
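
For reference, a minimal standalone sketch (not part of the patch, plain standard C++ only) of the packing and padding layout that Preprocess() now produces for static vs. dynamic input tensors. BuildBertInputs and LookupIdForDemo are hypothetical stand-ins for the preprocessor internals; the real code looks ids up with tokenizer_->LookupId() and writes straight into the TfLiteTensors.

// Standalone sketch of the [CLS]/[SEP] packing, truncation, and padding.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <string>
#include <vector>

struct BertInputs {
  std::vector<int> input_ids;
  std::vector<int> input_mask;
  std::vector<int> segment_ids;
};

// Fake id lookup for the demo only; the real ids come from the model vocab.
static int LookupIdForDemo(const std::string& token) {
  return static_cast<int>(std::hash<std::string>{}(token) % 30522);
}

// max_seq_len > 0 models the static case: truncate to max_seq_len and pad
// with zeros. max_seq_len == 0 models the dynamic case: size to fit.
static BertInputs BuildBertInputs(std::vector<std::string> subwords,
                                  int max_seq_len) {
  int input_tokens_size = static_cast<int>(subwords.size()) + 2;  // [CLS]/[SEP]
  int input_tensor_length = input_tokens_size;
  if (max_seq_len > 0) {
    input_tokens_size = std::min(max_seq_len, input_tokens_size);
    input_tensor_length = max_seq_len;
  }

  std::vector<std::string> tokens;
  tokens.reserve(input_tokens_size);
  tokens.push_back("[CLS]");
  for (int i = 0; i < input_tokens_size - 2; ++i) {
    tokens.push_back(std::move(subwords[i]));
  }
  tokens.push_back("[SEP]");

  BertInputs out;
  out.input_ids.assign(input_tensor_length, 0);
  out.input_mask.assign(input_tensor_length, 0);
  out.segment_ids.assign(input_tensor_length, 0);
  for (int i = 0; i < static_cast<int>(tokens.size()); ++i) {
    out.input_ids[i] = LookupIdForDemo(tokens[i]);
    out.input_mask[i] = 1;
  }
  return out;
}

int main() {
  // Static tensors of length 8: extra subwords are dropped, the rest padded.
  BertInputs fixed = BuildBertInputs({"it", "was", "a", "charming", "movie"}, 8);
  // Dynamic tensors: length tracks the subword count plus [CLS] and [SEP].
  BertInputs dynamic = BuildBertInputs({"it", "was", "a", "charming", "movie"}, 0);
  std::printf("static length: %zu, dynamic length: %zu\n",
              fixed.input_ids.size(), dynamic.input_ids.size());
  return 0;
}

With static tensors the sequence is truncated to bert_max_seq_len_ and zero-padded; with dynamic tensors the three inputs are sized to exactly the subword count plus the two special tokens, so no padding is wasted.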
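
A hedged sketch of how the dynamic-tensor path can be exercised directly against the TFLite C++ interpreter API. The model file name bert_model.tflite and the assumption that the first input is the ids tensor are illustrative only; the task library resolves the real tensor indices from the model metadata.

// Sketch: detect a dynamic sequence dimension via dims_signature and resize
// the inputs before allocation, mirroring what the preprocessor does.
#include <cstdio>
#include <memory>
#include <vector>

#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"

int main() {
  auto model = tflite::FlatBufferModel::BuildFromFile("bert_model.tflite");
  if (!model) return 1;

  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    return 1;
  }

  // A sequence dimension recorded as -1 in dims_signature means the tensor
  // can be resized at runtime, which is the "dynamic" case in this change.
  const int ids_index = interpreter->inputs()[0];
  const TfLiteTensor* ids = interpreter->tensor(ids_index);
  const bool is_dynamic = ids->dims_signature != nullptr &&
                          ids->dims_signature->size == 2 &&
                          ids->dims_signature->data[1] == -1;

  const int seq_len = 16;  // e.g. number of subwords + 2 for [CLS]/[SEP]
  if (is_dynamic) {
    for (int input_index : interpreter->inputs()) {
      if (interpreter->ResizeInputTensorStrict(input_index, {1, seq_len}) !=
          kTfLiteOk) {
        return 1;
      }
    }
  }
  if (interpreter->AllocateTensors() != kTfLiteOk) return 1;

  std::printf("input length after allocation: %d\n",
              interpreter->tensor(ids_index)->dims->data[1]);
  return 0;
}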