/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_

#include <memory>
#include <string>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/status/status.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/task_api_factory.h"
#include "tensorflow_lite_support/cc/task/core/tflite_engine.h"
#include "tensorflow_lite_support/cc/task/text/qa/question_answerer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h"
#include "tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h"

namespace tflite {
namespace task {
namespace text {
namespace qa {

// BertQA task API: performs tokenization for models (BERT, Albert, etc.) in
// preprocessing and returns the most likely answers.
//
// In particular, BERT models use a WordPiece tokenizer, and Albert models use
// a SentencePiece tokenizer.
//
// Factory methods:
//   CreateFromFile(path_to_model_with_metadata)
//   CreateFromBuffer(model_with_metadata_buffer_data,
//                    model_with_metadata_buffer_size)
//   CreateFromFd(file_descriptor_to_model_with_metadata)
//     Generic APIs to create a QuestionAnswerer from a BERT model with
//     populated metadata; see the usage sketch below. The APIs expect a
//     BERT-based TFLite model with metadata containing the following
//     information:
//       - input_process_units for the Wordpiece/Sentencepiece Tokenizer. A
//         Wordpiece Tokenizer can be used for a MobileBert[0] model, and a
//         Sentencepiece Tokenizer can be used for an Albert[1] model.
//       - 3 input tensors with names "ids", "mask" and "segment_ids".
//       - 2 output tensors with names "end_logits" and "start_logits".
//   [0]: https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
//   [1]: https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
//
//   CreateBertQuestionAnswererFromFile(path_to_model, path_to_vocab)
//     Creates a BertQuestionAnswerer from a TFLite model file and a vocab
//     file for the WordPiece tokenizer. Used in C++ environments.
//     One suitable model is:
//     https://tfhub.dev/tensorflow/lite-model/mobilebert/1/default/1
//
//   CreateBertQuestionAnswererFromBuffer(model_buffer_data, model_buffer_size,
//                                        vocab_buffer_data, vocab_buffer_size)
//     Creates a BertQuestionAnswerer from a TFLite model buffer and a vocab
//     file buffer for the WordPiece tokenizer. Used in Java (JNI)
//     environments.
//
//   CreateAlbertQuestionAnswererFromFile(path_to_model, path_to_spmodel)
//     Creates an AlbertQuestionAnswerer from a TFLite model file and a
//     SentencePiece model file. Used in C++ environments.
//     One suitable model is:
//     https://tfhub.dev/tensorflow/lite-model/albert_lite_base/squadv1/1
//
//   CreateAlbertQuestionAnswererFromBuffer(model_buffer_data,
//                                          model_buffer_size,
//                                          spmodel_buffer_data,
//                                          spmodel_buffer_size)
//     Creates an AlbertQuestionAnswerer from a TFLite model buffer and a
//     SentencePiece model file buffer. Used in Java (JNI) environments.
//
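// Usage example (a minimal sketch rather than a drop-in snippet: the model
// path, context, and question are hypothetical placeholders, and error
// handling is reduced to an ok() check):
//
//   std::string context = "...";   // Passage that may contain the answer.
//   std::string question = "...";  // Question to ask about the passage.
//   tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>> answerer =
//       BertQuestionAnswerer::CreateFromFile("/path/to/model.tflite");
//   if (answerer.ok()) {
//     std::vector<QaAnswer> answers =
//         answerer.value()->Answer(context, question);
//   }
//
// For models without metadata, the tokenizer files are passed explicitly
// instead, e.g. CreateBertQuestionAnswererFromFile(path_to_model,
// path_to_vocab).
//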
class BertQuestionAnswerer : public QuestionAnswerer {
 public:
  // TODO(b/150904655): add support to parameterize.
  static constexpr int kMaxQueryLen = 64;
  static constexpr int kMaxSeqLen = 384;
  static constexpr int kPredictAnsNum = 5;
  static constexpr int kMaxAnsLen = 32;
  // TODO(b/151954803): clarify the offset usage
  static constexpr int kOutputOffset = 1;
  static constexpr int kNumLiteThreads = 4;
  static constexpr bool kUseLowerCase = true;

  static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
  CreateFromFile(const std::string& path_to_model_with_metadata);

  static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
  CreateFromBuffer(const char* model_with_metadata_buffer_data,
                   size_t model_with_metadata_buffer_size);

  static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
  CreateFromFd(int fd);

  static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
  CreateBertQuestionAnswererFromFile(const std::string& path_to_model,
                                     const std::string& path_to_vocab);

  static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
  CreateBertQuestionAnswererFromBuffer(const char* model_buffer_data,
                                       size_t model_buffer_size,
                                       const char* vocab_buffer_data,
                                       size_t vocab_buffer_size);

  static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
  CreateAlbertQuestionAnswererFromFile(const std::string& path_to_model,
                                       const std::string& path_to_spmodel);

  static tflite::support::StatusOr<std::unique_ptr<QuestionAnswerer>>
  CreateAlbertQuestionAnswererFromBuffer(const char* model_buffer_data,
                                         size_t model_buffer_size,
                                         const char* spmodel_buffer_data,
                                         size_t spmodel_buffer_size);

  explicit BertQuestionAnswerer(std::unique_ptr<core::TfLiteEngine> engine)
      : QuestionAnswerer(std::move(engine)) {}

  // Answers the question based on the context. The returned list may be empty
  // if no answer was found in the given context.
  std::vector<QaAnswer> Answer(const std::string& context,
                               const std::string& question) override;
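  //
  // A minimal sketch of consuming the results (assumes the QaAnswer struct
  // declared in question_answerer.h exposes the answer span as `text`):
  //
  //   for (const QaAnswer& answer : answers) {
  //     // answer.text holds the answer span extracted from the context.
  //   }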

 private:
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          const std::string& lowercased_context,
                          const std::string& lowercased_query) override;

  tflite::support::StatusOr<std::vector<QaAnswer>> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      const std::string& lowercased_context,
      const std::string& lowercased_query) override;

  // Initializes the API with a BertTokenizer from the vocabulary file.
  void InitializeBertTokenizer(const std::string& path_to_vocab);
  // Initializes the API with a BertTokenizer from the vocabulary buffer.
  void InitializeBertTokenizerFromBinary(const char* vocab_buffer_data,
                                         size_t vocab_buffer_size);
  // Initializes the API with a SentencepieceTokenizer from the model file.
  void InitializeSentencepieceTokenizer(const std::string& path_to_spmodel);
  // Initializes the API with a SentencepieceTokenizer from the model buffer.
  void InitializeSentencepieceTokenizerFromBinary(
      const char* spmodel_buffer_data, size_t spmodel_buffer_size);
  // Initializes the API with the tokenizer set in the metadata.
  absl::Status InitializeFromMetadata();

  // Converts an answer's token index range back to a string of the original
  // context.
  std::string ConvertIndexToString(int start, int end);

  std::unique_ptr<tflite::support::text::tokenizer::Tokenizer> tokenizer_;
  // Maps the index of an input token to the index of the untokenized word in
  // the original input.
  absl::flat_hash_map<size_t, size_t> token_to_orig_map_;
  // Original tokens of the context.
  std::vector<std::string> orig_tokens_;
};

}  // namespace qa
}  // namespace text
}  // namespace task
}  // namespace tflite

#endif  // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_QA_BERT_QUESTION_ANSWERER_H_