blob: bcc318b4fa4654689248e060a05511834655e3b0 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Contains classes that can execute different models/parts of a model.
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_
#include <memory>
#include <utility>
#include <vector>
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/tensor-view.h"
#include "utils/tflite-model-executor.h"
namespace libtextclassifier3 {
// Executor for the text selection prediction and classification models.
class ModelExecutor : public TfLiteModelExecutor {
public:
static std::unique_ptr<ModelExecutor> FromModelSpec(
const tflite::Model* model_spec) {
auto model = TfLiteModelFromModelSpec(model_spec);
if (!model) {
return nullptr;
}
return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
}
static std::unique_ptr<ModelExecutor> FromBuffer(
const flatbuffers::Vector<uint8_t>* model_spec_buffer) {
auto model = TfLiteModelFromBuffer(model_spec_buffer);
if (!model) {
return nullptr;
}
return std::unique_ptr<ModelExecutor>(new ModelExecutor(std::move(model)));
}
TensorView<float> ComputeLogits(const TensorView<float>& features,
tflite::Interpreter* interpreter) const;
protected:
explicit ModelExecutor(std::unique_ptr<const tflite::FlatBufferModel> model)
: TfLiteModelExecutor(std::move(model)) {}
static const int kInputIndexFeatures = 0;
static const int kOutputIndexLogits = 0;
};
// Executor for embedding sparse features into a dense vector.
// Abstract interface for turning sparse feature ids into a dense embedding.
class EmbeddingExecutor {
 public:
  // Polymorphic base: deletion through EmbeddingExecutor* must reach the
  // derived destructor.
  virtual ~EmbeddingExecutor() = default;

  // Looks up an embedding for `sparse_features` and accumulates it
  // element-wise into `dest`, which has room for `dest_size` floats
  // (dest[i] += embedding[i]). Returns whether the operation succeeded.
  virtual bool AddEmbedding(const TensorView<int>& sparse_features,
                            float* dest, int dest_size) const = 0;

  // Reports whether the executor can be used. The base implementation
  // always answers true; subclasses may override.
  virtual bool IsReady() const { return true; }
};
// EmbeddingExecutor backed by an embedding table stored in a TFLite model,
// optionally quantized and optionally pruned by a bucket mask.
class TFLiteEmbeddingExecutor : public EmbeddingExecutor {
 public:
  // Creates an executor from the serialized TFLite model in
  // `model_spec_buffer`. `embedding_size` and `quantization_bits` describe
  // the embedding table; `embedding_pruning_mask` is optional (nullptr means
  // no pruning).
  // NOTE(review): defined out of line; presumably returns nullptr when the
  // model cannot be loaded or validated — confirm against the .cc file.
  static std::unique_ptr<TFLiteEmbeddingExecutor> FromBuffer(
      const flatbuffers::Vector<uint8_t>* model_spec_buffer, int embedding_size,
      int quantization_bits,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);

  // Embeds the sparse_features into a dense embedding and adds (+) it
  // element-wise to the dest vector, which holds `dest_size` floats.
  // `override` makes the compiler verify this matches the base signature.
  bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
                    int dest_size) const override;

  // Auxiliary function that fills prefix_counts_, which backs the efficient
  // mask-indexing data structure used by PruneBucketId.
  // NOTE(review): presumably derived from pruning_mask_ — confirm.
  void ComputePrefixCounts();

  // Maps `bucket_id` to the row of the (possibly pruned) embedding table
  // using the mask-indexing data structure.
  // NOTE(review): presumably returns pruned_row_bucket_id_ for buckets
  // removed by the mask — confirm against the .cc file.
  int PruneBucketId(int bucket_id) const;

 protected:
  explicit TFLiteEmbeddingExecutor(
      std::unique_ptr<TfLiteModelExecutor> executor, int quantization_bits,
      int num_buckets, int bytes_per_embedding, int output_embedding_size,
      const TfLiteTensor* scales, const TfLiteTensor* embeddings,
      std::unique_ptr<tflite::Interpreter> interpreter,
      const Model_::EmbeddingPruningMask* embedding_pruning_mask = nullptr);

  std::unique_ptr<TfLiteModelExecutor> executor_;

  // Defaulted like the other scalar members so it is never read
  // uninitialized; the constructor is expected to overwrite it.
  int quantization_bits_ = -1;
  int num_buckets_ = -1;
  int bytes_per_embedding_ = -1;
  int output_embedding_size_ = -1;

  // Non-owning views into tensors; presumably owned by interpreter_ — confirm.
  const TfLiteTensor* scales_ = nullptr;
  const TfLiteTensor* embeddings_ = nullptr;

  // NOTE: This interpreter is used in a read-only way (as a storage for the
  // model params), thus is still thread-safe.
  std::unique_ptr<tflite::Interpreter> interpreter_;

  // State of the optional bucket-pruning mask and its prefix-count index.
  std::vector<uint64> pruning_mask_;
  std::vector<uint16> prefix_counts_;
  // NOTE(review): presumably the bucket count before pruning — confirm.
  int full_num_buckets_ = -1;

  // Index of row of embedding table corresponding to all pruned buckets.
  int pruned_row_bucket_id_ = -1;
};
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_MODEL_EXECUTOR_H_