// blob: 2a9573a13cde8e5dcc19e3f7be8d613150a5c7dd [file] [log] [blame]
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
#define TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_
#include <stddef.h>
#include <string.h>
#include <memory>
#include <string>
#include <vector>
#include "absl/status/status.h"
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/core/api/op_resolver.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/string_type.h"
#include "tensorflow_lite_support/cc/common.h"
#include "tensorflow_lite_support/cc/port/statusor.h"
#include "tensorflow_lite_support/cc/task/core/base_task_api.h"
#include "tensorflow_lite_support/cc/task/core/category.h"
#include "tensorflow_lite_support/cc/text/tokenizers/regex_tokenizer.h"
namespace tflite {
namespace task {
namespace text {
namespace nlclassifier {
// Settings that tell the classifier which model tensors carry the input text,
// the output scores and (optionally) the output labels. Every tensor can be
// addressed by name or by index; when both are usable the name is tried first.
struct NLClassifierOptions {
  // Position of the input tensor among the model inputs.
  int input_tensor_index{0};
  // Position of the score tensor among the model outputs.
  int output_score_tensor_index{0};
  // Position of the label tensor among the model outputs. The default of -1
  // means "no output label tensor": labels may instead come from a label file
  // attached to the output score tensor metadata.
  int output_label_tensor_index{-1};
  // Name used to look up the input tensor.
  std::string input_tensor_name{"INPUT"};
  // Name used to look up the output score tensor.
  std::string output_score_tensor_name{"OUTPUT_SCORE"};
  // Name used to look up the output label tensor.
  std::string output_label_tensor_name{"OUTPUT_LABEL"};
};
// Classifier API for NLClassification tasks; categorizes a string into
// different classes.
//
// The API expects a TFLite model with the following input/output tensor:
// Input tensor:
// (kTfLiteString) - input of the model, accepts a string.
// or
// (kTfLiteInt32) - input of the model, accepts the tokenized
// indices of a string input. A RegexTokenizer needs to be set up in the input
// tensor's metadata.
// Output score tensor:
// (kTfLiteUInt8/kTfLiteInt8/kTfLiteInt16/kTfLiteFloat32/
// kTfLiteFloat64/kTfLiteBool)
// - output scores for each class. If the type is one of the integer types,
// the scores are dequantized to double; if the type is kTfLiteBool, the
// values are converted to 0.0 and 1.0 respectively.
// - can have an optional associated file in metadata for labels, the file
// should be a plain text file with one label per line, the number of
// labels should match the number of categories the model outputs.
// Output label tensor: optional
// (kTfLiteString/kTfLiteInt32)
// - output class name for each class; should have the same length as the
// scores. If this tensor is not present, the API uses score indices as
// class names.
// - will be ignored if output score tensor already has an associated label
// file.
//
// By default the API tries to find the input/output tensors with default
// configurations in NLClassifierOptions, with tensor name prioritized over
// tensor index. The option is configurable for different TFLite models.
class NLClassifier : public core::BaseTaskApi<std::vector<core::Category>,
                                              const std::string&> {
 public:
  // Inherit the BaseTaskApi constructors; instances are normally obtained
  // through one of the CreateFrom*AndOptions factories below.
  using BaseTaskApi::BaseTaskApi;

  // Creates an NLClassifier from a TFLite model buffer.
  //   model_buffer_data, model_buffer_size: the model buffer and its size.
  //   options: how to locate the input/output tensors (defaults to a
  //     default-constructed NLClassifierOptions).
  //   resolver: op resolver for the model (defaults to a BuiltinOpResolver).
  // Returns the classifier, or an error status on failure.
  static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  CreateFromBufferAndOptions(
      const char* model_buffer_data, size_t model_buffer_size,
      const NLClassifierOptions& options = {},
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // Creates an NLClassifier from a TFLite model file.
  //   path_to_model: path of the model file.
  //   options, resolver: same meaning as in CreateFromBufferAndOptions.
  // Returns the classifier, or an error status on failure.
  static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  CreateFromFileAndOptions(
      const std::string& path_to_model, const NLClassifierOptions& options = {},
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // Creates an NLClassifier from a TFLite model file descriptor.
  //   fd: file descriptor of the model file.
  //   options, resolver: same meaning as in CreateFromBufferAndOptions.
  // Returns the classifier, or an error status on failure.
  static tflite::support::StatusOr<std::unique_ptr<NLClassifier>>
  CreateFromFdAndOptions(
      int fd, const NLClassifierOptions& options = {},
      std::unique_ptr<tflite::OpResolver> resolver =
          absl::make_unique<tflite::ops::builtin::BuiltinOpResolver>());

  // Performs classification on a string input, returns classified results.
  std::vector<core::Category> Classify(const std::string& text);

 protected:
  // NOTE(review): presumably the default positions of the output tensor and of
  // its associated label file in the metadata; confirm against the .cc file.
  static constexpr int kOutputTensorIndex = 0;
  static constexpr int kOutputTensorLabelFileIndex = 0;

  // Initializes the classifier with `options`. Defined outside this header.
  absl::Status Initialize(const NLClassifierOptions& options);

  // Accessor for the classifier options. Defined outside this header.
  const NLClassifierOptions& GetOptions() const;

  // Try to extract attached label file from metadata and initialize
  // labels_vector_, return error if metadata type is incorrect or no label file
  // is attached in metadata.
  absl::Status TrySetLabelFromMetadata(const TensorMetadata* metadata);

  // Pass through the input text into model's input tensor.
  absl::Status Preprocess(const std::vector<TfLiteTensor*>& input_tensors,
                          const std::string& input) override;

  // Extract model output and create results with output label tensor or label
  // file attached in metadata. If no output label tensor or label file is
  // found, use output score index as labels.
  tflite::support::StatusOr<std::vector<core::Category>> Postprocess(
      const std::vector<const TfLiteTensor*>& output_tensors,
      const std::string& input) override;

  // Builds the result categories from the score tensor and the label tensor.
  // NOTE(review): `labels` is presumably allowed to be nullptr when the model
  // has no label tensor; confirm against the .cc file.
  std::vector<core::Category> BuildResults(const TfLiteTensor* scores,
                                           const TfLiteTensor* labels);

  // Gets the tensor from a vector of tensors by checking tensor name first and
  // tensor index second, returns nullptr if no tensor is found.
  //
  // Lookup order:
  //   1. names stored in `metadata_array`, used only when the array is
  //      non-null and has exactly one entry per tensor;
  //   2. the `name` field of each tensor;
  //   3. `index`, if it is a valid position in `tensors`.
  // Metadata entries without a name and tensors with a null name never match
  // by name; they are skipped instead of being dereferenced.
  template <typename TensorType>
  static TensorType* FindTensorWithNameOrIndex(
      const std::vector<TensorType*>& tensors,
      const flatbuffers::Vector<flatbuffers::Offset<TensorMetadata>>*
          metadata_array,
      const std::string& name, int index) {
    if (metadata_array != nullptr && metadata_array->size() == tensors.size()) {
      // Both containers have the same size here, so one unsigned index safely
      // addresses both.
      for (size_t i = 0; i < tensors.size(); ++i) {
        const auto* tensor_metadata =
            metadata_array->Get(static_cast<flatbuffers::uoffset_t>(i));
        // FlatBuffers string accessors return nullptr when the field is
        // absent, so guard before calling c_str().
        if (tensor_metadata == nullptr || tensor_metadata->name() == nullptr) {
          continue;
        }
        if (strcmp(name.c_str(), tensor_metadata->name()->c_str()) == 0) {
          return tensors[i];
        }
      }
    }
    for (TensorType* tensor : tensors) {
      // Comparing a null C string against std::string is undefined behaviour;
      // treat a nameless tensor as "no match".
      if (tensor != nullptr && tensor->name != nullptr &&
          tensor->name == name) {
        return tensor;
      }
    }
    // Fall back to the positional lookup; negative indices mean "not set".
    if (index < 0 || static_cast<size_t>(index) >= tensors.size()) {
      return nullptr;
    }
    return tensors[static_cast<size_t>(index)];
  }

 private:
  // NOTE(review): judging by the name, reports whether the input tensor
  // metadata configures a RegexTokenizer; confirm against the .cc file.
  bool HasRegexTokenizerMetadata();

  // NOTE(review): presumably creates tokenizer_ from the model metadata;
  // confirm against the .cc file.
  absl::Status SetupRegexTokenizer();

  NLClassifierOptions options_;
  // labels vector initialized from output tensor's associated file, if one
  // exists.
  std::unique_ptr<std::vector<std::string>> labels_vector_;
  // Regex tokenizer for int32 input tensors (see the class comment);
  // NOTE(review): presumably null when the model takes string input.
  std::unique_ptr<tflite::support::text::tokenizer::RegexTokenizer> tokenizer_;
};
} // namespace nlclassifier
} // namespace text
} // namespace task
} // namespace tflite
#endif // TENSORFLOW_LITE_SUPPORT_CC_TASK_TEXT_NLCLASSIFIER_NL_CLASSIFIER_H_