/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow_lite_support/cc/task/text/qa/bert_question_answerer.h"
#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "tensorflow_lite_support/cc/port/status_macros.h"
#include "tensorflow_lite_support/cc/task/core/task_utils.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/tokenizer_utils.h"
#include "tensorflow_lite_support/metadata/metadata_schema_generated.h"
namespace tflite {
namespace task {
namespace text {
namespace qa {
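// Tensor names used to look up the model's input and output tensors through
// the metadata; they must match the names stored in the model metadata.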
constexpr char kIdsTensorName[] = "ids";
constexpr char kMaskTensorName[] = "mask";
constexpr char kSegmentIdsTensorName[] = "segment_ids";
constexpr char kEndLogitsTensorName[] = "end_logits";
constexpr char kStartLogitsTensorName[] = "start_logits";
using ::tflite::support::CreateStatusWithPayload;
using ::tflite::support::StatusOr;
using ::tflite::support::TfLiteSupportStatus;
using ::tflite::support::text::tokenizer::BertTokenizer;
using ::tflite::support::text::tokenizer::CreateTokenizerFromProcessUnit;
using ::tflite::support::text::tokenizer::SentencePieceTokenizer;
using ::tflite::support::text::tokenizer::TokenizerResult;
using ::tflite::task::core::FindTensorByName;
using ::tflite::task::core::PopulateTensor;
using ::tflite::task::core::PopulateVector;
using ::tflite::task::core::ReverseSortIndices;
namespace {
constexpr int kTokenizerProcessUnitIndex = 0;
}  // namespace
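// Creates a BertQuestionAnswerer from a model file whose metadata bundles the
// tokenizer (BertTokenizer or SentencePieceTokenizer); the tokenizer is set
// up in InitializeFromMetadata(). A minimal usage sketch (the model path is
// hypothetical):
//
//   auto answerer_or =
//       BertQuestionAnswerer::CreateFromFile("path/to/model.tflite");
//   if (answerer_or.ok()) {
//     std::vector<QaAnswer> answers =
//         answerer_or.value()->Answer(context, question);
//   }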
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateFromFile(
const std::string& path_to_model_with_metadata) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
path_to_model_with_metadata,
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
kNumLiteThreads));
RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
return api_to_init;
}
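// Same as CreateFromFile, but reads the model (with metadata) from an
// in-memory buffer.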
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateFromBuffer(
const char* model_with_metadata_buffer_data,
size_t model_with_metadata_buffer_size) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
model_with_metadata_buffer_data, model_with_metadata_buffer_size,
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
kNumLiteThreads));
RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
return api_to_init;
}
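// Same as CreateFromFile, but reads the model (with metadata) from a file
// descriptor.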
StatusOr<std::unique_ptr<QuestionAnswerer>> BertQuestionAnswerer::CreateFromFd(
int fd) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
core::TaskAPIFactory::CreateFromFileDescriptor<BertQuestionAnswerer>(
fd, absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
kNumLiteThreads));
RETURN_IF_ERROR(api_to_init->InitializeFromMetadata());
return api_to_init;
}
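// Creates a BertQuestionAnswerer from a model without metadata plus a
// WordPiece vocabulary file; the BertTokenizer is built directly from the
// vocabulary instead of from metadata.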
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
const std::string& path_to_model, const std::string& path_to_vocab) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
path_to_model,
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
kNumLiteThreads));
api_to_init->InitializeBertTokenizer(path_to_vocab);
return api_to_init;
}
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateBertQuestionAnswererFromBuffer(
const char* model_buffer_data, size_t model_buffer_size,
const char* vocab_buffer_data, size_t vocab_buffer_size) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
model_buffer_data, model_buffer_size,
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
kNumLiteThreads));
api_to_init->InitializeBertTokenizerFromBinary(vocab_buffer_data,
vocab_buffer_size);
return api_to_init;
}
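// The Albert variants mirror the Bert ones, but take a SentencePiece model
// (as a file or a buffer) instead of a WordPiece vocabulary.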
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
const std::string& path_to_model, const std::string& path_to_spmodel) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
core::TaskAPIFactory::CreateFromFile<BertQuestionAnswerer>(
path_to_model,
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
kNumLiteThreads));
api_to_init->InitializeSentencepieceTokenizer(path_to_spmodel);
return api_to_init;
}
StatusOr<std::unique_ptr<QuestionAnswerer>>
BertQuestionAnswerer::CreateAlbertQuestionAnswererFromBuffer(
const char* model_buffer_data, size_t model_buffer_size,
const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
std::unique_ptr<BertQuestionAnswerer> api_to_init;
ASSIGN_OR_RETURN(
api_to_init,
core::TaskAPIFactory::CreateFromBuffer<BertQuestionAnswerer>(
model_buffer_data, model_buffer_size,
absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>(),
kNumLiteThreads));
api_to_init->InitializeSentencepieceTokenizerFromBinary(spmodel_buffer_data,
spmodel_buffer_size);
return api_to_init;
}
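// Runs end-to-end inference: Preprocess() packs the question and context into
// the input tensors, and Postprocess() decodes the best answer spans from the
// output logits.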
std::vector<QaAnswer> BertQuestionAnswerer::Answer(
const std::string& context, const std::string& question) {
  // The BertQuestionAnswerer implementations of Preprocess() and
  // Postprocess() never return errors, so it is safe to call value() directly.
return Infer(context, question).value();
}
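// Builds the BERT-style input sequence
//   [CLS] <query tokens> [SEP] <context subword tokens> [SEP] <zero padding>
// with segment ids 0 for the query part and 1 for the context part, then
// populates the ids, mask, and segment_ids input tensors.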
absl::Status BertQuestionAnswerer::Preprocess(
const std::vector<TfLiteTensor*>& input_tensors, const std::string& context,
const std::string& query) {
auto* input_tensor_metadatas =
GetMetadataExtractor()->GetInputTensorMetadata();
TfLiteTensor* ids_tensor =
input_tensor_metadatas
? FindTensorByName(input_tensors, input_tensor_metadatas,
kIdsTensorName)
: input_tensors[0];
TfLiteTensor* mask_tensor =
input_tensor_metadatas
? FindTensorByName(input_tensors, input_tensor_metadatas,
kMaskTensorName)
: input_tensors[1];
TfLiteTensor* segment_ids_tensor =
input_tensor_metadatas
? FindTensorByName(input_tensors, input_tensor_metadatas,
kSegmentIdsTensorName)
: input_tensors[2];
token_to_orig_map_.clear();
  // orig_tokens_ is used to recover the answer string from token indices,
  // while processed_tokens is lower-cased and used to generate the model
  // input.
orig_tokens_ = absl::StrSplit(context, absl::ByChar(' '), absl::SkipEmpty());
std::vector<std::string> processed_tokens(orig_tokens_);
std::string processed_query = query;
if (kUseLowerCase) {
for (auto& token : processed_tokens) {
absl::AsciiStrToLower(&token);
}
absl::AsciiStrToLower(&processed_query);
}
  TokenizerResult query_tokenize_results =
      tokenizer_->Tokenize(processed_query);
std::vector<std::string> query_tokens = query_tokenize_results.subwords;
  if (query_tokens.size() > static_cast<size_t>(kMaxQueryLen)) {
query_tokens.resize(kMaxQueryLen);
}
// Example:
// context: tokenize me please
// all_doc_tokens: token ##ize me plea ##se
// token_to_orig_index: [0, 0, 1, 2, 2]
std::vector<std::string> all_doc_tokens;
std::vector<int> token_to_orig_index;
for (size_t i = 0; i < processed_tokens.size(); i++) {
const std::string& token = processed_tokens[i];
std::vector<std::string> sub_tokens = tokenizer_->Tokenize(token).subwords;
for (const std::string& sub_token : sub_tokens) {
token_to_orig_index.emplace_back(i);
all_doc_tokens.emplace_back(sub_token);
}
}
// -3 accounts for [CLS], [SEP] and [SEP].
int max_context_len = kMaxSeqLen - query_tokens.size() - 3;
  if (all_doc_tokens.size() > static_cast<size_t>(max_context_len)) {
all_doc_tokens.resize(max_context_len);
}
std::vector<std::string> tokens;
tokens.reserve(3 + query_tokens.size() + all_doc_tokens.size());
std::vector<int> segment_ids;
segment_ids.reserve(kMaxSeqLen);
  // Start generating the features.
tokens.emplace_back("[CLS]");
segment_ids.emplace_back(0);
// For query input.
for (const auto& query_token : query_tokens) {
tokens.emplace_back(query_token);
segment_ids.emplace_back(0);
}
  // For separation.
tokens.emplace_back("[SEP]");
segment_ids.emplace_back(0);
  // For text input.
  for (size_t i = 0; i < all_doc_tokens.size(); i++) {
auto& doc_token = all_doc_tokens[i];
tokens.emplace_back(doc_token);
segment_ids.emplace_back(1);
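    // Key the map by the token's position in the final sequence. Note that
    // tokens.size() here is the just-added token's index plus one; Postprocess
    // compensates with kOutputOffset when looking up logit indices.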
token_to_orig_map_[tokens.size()] = token_to_orig_index[i];
}
// For ending mark.
tokens.emplace_back("[SEP]");
segment_ids.emplace_back(1);
std::vector<int> input_ids(tokens.size());
input_ids.reserve(kMaxSeqLen);
  // Convert tokens back into vocabulary ids.
  for (size_t i = 0; i < tokens.size(); i++) {
auto& token = tokens[i];
tokenizer_->LookupId(token, &input_ids[i]);
}
std::vector<int> input_mask;
input_mask.reserve(kMaxSeqLen);
input_mask.insert(input_mask.end(), tokens.size(), 1);
int zeros_to_pad = kMaxSeqLen - input_ids.size();
input_ids.insert(input_ids.end(), zeros_to_pad, 0);
input_mask.insert(input_mask.end(), zeros_to_pad, 0);
segment_ids.insert(segment_ids.end(), zeros_to_pad, 0);
// input_ids INT32[1, 384]
PopulateTensor(input_ids, ids_tensor);
// input_mask INT32[1, 384]
PopulateTensor(input_mask, mask_tensor);
// segment_ids INT32[1, 384]
PopulateTensor(segment_ids, segment_ids_tensor);
return absl::OkStatus();
}
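// Decodes answers from the start/end logits: considers the kPredictAnsNum
// highest-scoring start and end indices, keeps valid spans (start <= end,
// length at most kMaxAnsLen, both positions mapped back to original tokens),
// ranks them by summed logits, and converts the top spans back to text.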
StatusOr<std::vector<QaAnswer>> BertQuestionAnswerer::Postprocess(
const std::vector<const TfLiteTensor*>& output_tensors,
const std::string& /*lowercased_context*/,
const std::string& /*lowercased_query*/) {
auto* output_tensor_metadatas =
GetMetadataExtractor()->GetOutputTensorMetadata();
const TfLiteTensor* end_logits_tensor =
output_tensor_metadatas
? FindTensorByName(output_tensors, output_tensor_metadatas,
kEndLogitsTensorName)
: output_tensors[0];
const TfLiteTensor* start_logits_tensor =
output_tensor_metadatas
? FindTensorByName(output_tensors, output_tensor_metadatas,
kStartLogitsTensorName)
: output_tensors[1];
std::vector<float> end_logits;
std::vector<float> start_logits;
// end_logits FLOAT[1, 384]
PopulateVector(end_logits_tensor, &end_logits);
// start_logits FLOAT[1, 384]
PopulateVector(start_logits_tensor, &start_logits);
auto start_indices = ReverseSortIndices(start_logits);
auto end_indices = ReverseSortIndices(end_logits);
std::vector<QaAnswer::Pos> orig_results;
for (int start_index = 0; start_index < kPredictAnsNum; start_index++) {
for (int end_index = 0; end_index < kPredictAnsNum; end_index++) {
int start = start_indices[start_index];
int end = end_indices[end_index];
if (!token_to_orig_map_.contains(start + kOutputOffset) ||
!token_to_orig_map_.contains(end + kOutputOffset) || end < start ||
(end - start + 1) > kMaxAnsLen) {
continue;
}
orig_results.emplace_back(
QaAnswer::Pos(start, end, start_logits[start] + end_logits[end]));
}
}
std::sort(orig_results.begin(), orig_results.end());
std::vector<QaAnswer> answers;
for (int i = 0; i < orig_results.size() && i < kPredictAnsNum; i++) {
auto orig_pos = orig_results[i];
answers.emplace_back(
orig_pos.start > 0 ? ConvertIndexToString(orig_pos.start, orig_pos.end)
: "",
orig_pos);
}
return answers;
}
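// Maps a [start, end] subword span back to the original whitespace-delimited
// tokens and joins them into the answer string.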
std::string BertQuestionAnswerer::ConvertIndexToString(int start, int end) {
int start_index = token_to_orig_map_[start + kOutputOffset];
int end_index = token_to_orig_map_[end + kOutputOffset];
return absl::StrJoin(orig_tokens_.begin() + start_index,
orig_tokens_.begin() + end_index + 1, " ");
}
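// Creates the tokenizer from the model metadata's input process unit; fails
// if the model carries no tokenizer metadata.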
absl::Status BertQuestionAnswerer::InitializeFromMetadata() {
const ProcessUnit* tokenizer_process_unit =
GetMetadataExtractor()->GetInputProcessUnit(kTokenizerProcessUnitIndex);
if (tokenizer_process_unit == nullptr) {
return CreateStatusWithPayload(
absl::StatusCode::kInvalidArgument,
"No input process unit found from metadata.",
TfLiteSupportStatus::kMetadataInvalidTokenizerError);
}
ASSIGN_OR_RETURN(tokenizer_,
CreateTokenizerFromProcessUnit(tokenizer_process_unit,
GetMetadataExtractor()));
return absl::OkStatus();
}
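// The initializers below build the tokenizer directly and are used by the
// factories for models created without metadata.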
void BertQuestionAnswerer::InitializeBertTokenizer(
const std::string& path_to_vocab) {
tokenizer_ = absl::make_unique<BertTokenizer>(path_to_vocab);
}
void BertQuestionAnswerer::InitializeBertTokenizerFromBinary(
const char* vocab_buffer_data, size_t vocab_buffer_size) {
tokenizer_ =
absl::make_unique<BertTokenizer>(vocab_buffer_data, vocab_buffer_size);
}
void BertQuestionAnswerer::InitializeSentencepieceTokenizer(
const std::string& path_to_spmodel) {
tokenizer_ = absl::make_unique<SentencePieceTokenizer>(path_to_spmodel);
}
void BertQuestionAnswerer::InitializeSentencepieceTokenizerFromBinary(
const char* spmodel_buffer_data, size_t spmodel_buffer_size) {
tokenizer_ = absl::make_unique<SentencePieceTokenizer>(spmodel_buffer_data,
spmodel_buffer_size);
}
} // namespace qa
} // namespace text
} // namespace task
} // namespace tflite