blob: 57d59c5413f3a7438fcd7e5b91f72832e2f3693e [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include "lang_id/common/embedding-network-params.h"
#include "lang_id/common/flatbuffers/embedding-network_generated.h"
#include "lang_id/common/lite_base/float16.h"
#include "lang_id/common/lite_base/logging.h"
#include "lang_id/common/lite_strings/stringpiece.h"
namespace libtextclassifier3 {
namespace mobile {
// EmbeddingNetworkParams implementation backed by a flatbuffer.
//
// For info on our flatbuffer schema, see embedding-network.fbs.
class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams {
 public:
  // Builds an instance that reads its parameters from the flatbuffer stored in
  // |bytes|.
  //
  // IMPORTANT #1: |bytes| is not copied (to avoid overhead), so the caller must
  // keep it alive for the whole lifetime of this instance.
  //
  // IMPORTANT #2: right after construction, we suggest calling is_valid() on
  // the new object; if it returns false, do not call any other method.
  explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes);

  // Unsupported operation: this class exposes only the parameters of the
  // Neurosis neural network, not the overall TaskContext.  Logs at DFATAL
  // severity and returns false.
  bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override {
    SAFTM_LOG(DFATAL) << "Not supported";
    return false;
  }

  // Returns the validity flag stored in valid_.
  bool is_valid() const override { return valid_; }

  // ---- Embedding matrices: one per input chunk. ----

  int embeddings_size() const override { return SafeGetNumInputChunks(); }

  int embeddings_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetEmbeddingMatrix(i));
  }

  int embeddings_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetEmbeddingMatrix(i));
  }

  const void *embeddings_weights(int i) const override {
    return SafeGetValuesOfMatrix(SafeGetEmbeddingMatrix(i));
  }

  QuantizationType embeddings_quant_type(int i) const override {
    return SafeGetQuantizationType(SafeGetEmbeddingMatrix(i));
  }

  const float16 *embeddings_quant_scales(int i) const override {
    return SafeGetScales(SafeGetEmbeddingMatrix(i));
  }

  // ---- Hidden layers: weights. ----

  // Number of hidden layers: all layers except the last one, which is always
  // the softmax layer.  Never negative.
  int hidden_size() const override {
    const int num_non_softmax_layers = SafeGetNumLayers() - 1;
    return std::max(num_non_softmax_layers, 0);
  }

  int hidden_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetLayerWeights(i));
  }

  int hidden_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetLayerWeights(i));
  }

  QuantizationType hidden_weights_quant_type(int i) const override {
    return SafeGetQuantizationType(SafeGetLayerWeights(i));
  }

  const void *hidden_weights(int i) const override {
    return SafeGetValuesOfMatrix(SafeGetLayerWeights(i));
  }

  // ---- Hidden layers: biases. ----

  // Same count as hidden_size().
  int hidden_bias_size() const override { return hidden_size(); }

  int hidden_bias_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetLayerBias(i));
  }

  int hidden_bias_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetLayerBias(i));
  }

  // Bias values are read through SafeGetValues(), i.e., as floats.
  const void *hidden_bias_weights(int i) const override {
    return SafeGetValues(SafeGetLayerBias(i));
  }

  // ---- Softmax layer: weights. ----
  //
  // The softmax accessors below ignore their index argument |i|: they always
  // report on the last layer of the network.

  // 1 if the network has at least one layer, 0 otherwise.
  int softmax_size() const override {
    if (SafeGetNumLayers() > 0) {
      return 1;
    }
    return 0;
  }

  int softmax_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetSoftmaxWeights());
  }

  int softmax_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetSoftmaxWeights());
  }

  QuantizationType softmax_weights_quant_type(int i) const override {
    return SafeGetQuantizationType(SafeGetSoftmaxWeights());
  }

  const void *softmax_weights(int i) const override {
    return SafeGetValuesOfMatrix(SafeGetSoftmaxWeights());
  }

  // ---- Softmax layer: bias. ----

  // Same count as softmax_size().
  int softmax_bias_size() const override { return softmax_size(); }

  int softmax_bias_num_rows(int i) const override {
    return SafeGetNumRows(SafeGetSoftmaxBias());
  }

  int softmax_bias_num_cols(int i) const override {
    return SafeGetNumCols(SafeGetSoftmaxBias());
  }

  // Bias values are read through SafeGetValues(), i.e., as floats.
  const void *softmax_bias_weights(int i) const override {
    return SafeGetValues(SafeGetSoftmaxBias());
  }

  // ---- Per-chunk feature counts. ----

  int embedding_num_features_size() const override {
    return SafeGetNumInputChunks();
  }

  // Returns num_features() of input chunk |i|, or 0 if |i| is out of range or
  // the chunk cannot be retrieved.
  int embedding_num_features(int i) const override {
    if (InRangeIndex(i, embedding_num_features_size(),
                     "embedding num features")) {
      const saft_fbs::InputChunk *chunk = SafeGetInputChunk(i);
      if (chunk != nullptr) {
        return chunk->num_features();
      }
    }
    return 0;
  }

  // This implementation always reports both of these as false.
  bool has_is_precomputed() const override { return false; }
  bool is_precomputed() const override { return false; }

 private:
  // Returns true if and only if index is in [0, limit).  info should point to
  // a zero-terminated array of chars (ideally a string literal, e.g., "layer")
  // saying what the index refers to; it is used to make log messages more
  // informative.
  static bool InRangeIndex(int index, int limit, const char *info);

  // Returns network_->input_chunks()->size() if every dereference is safe
  // (i.e., no nullptr); otherwise, returns 0.
  int SafeGetNumInputChunks() const;

  // Returns network_->input_chunks()->Get(i) if every dereference is safe
  // (i.e., no nullptr); otherwise, returns nullptr.
  const saft_fbs::InputChunk *SafeGetInputChunk(int i) const;

  // Returns network_->input_chunks()->Get(i)->embedding() if every dereference
  // is safe (i.e., no nullptr); otherwise, returns nullptr.
  const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const;

  // Returns network_->layers()->size() if every dereference is safe (i.e., no
  // nullptr); otherwise, returns 0.
  int SafeGetNumLayers() const;

  // Returns network_->layers()->Get(i) if every dereference is safe (i.e., no
  // nullptr); otherwise, returns nullptr.
  const saft_fbs::NeuralLayer *SafeGetLayer(int i) const;

  // Returns network_->layers()->Get(i)->weights() if every dereference is safe
  // (i.e., no nullptr); otherwise, returns nullptr.
  const saft_fbs::Matrix *SafeGetLayerWeights(int i) const;

  // Returns network_->layers()->Get(i)->bias() if every dereference is safe
  // (i.e., no nullptr); otherwise, returns nullptr.
  const saft_fbs::Matrix *SafeGetLayerBias(int i) const;

  // Returns matrix->rows(), or 0 if |matrix| is nullptr.
  static int SafeGetNumRows(const saft_fbs::Matrix *matrix) {
    if (matrix == nullptr) {
      return 0;
    }
    return matrix->rows();
  }

  // Returns matrix->cols(), or 0 if |matrix| is nullptr.
  static int SafeGetNumCols(const saft_fbs::Matrix *matrix) {
    if (matrix == nullptr) {
      return 0;
    }
    return matrix->cols();
  }

  // Returns matrix->values()->data() if every dereference is safe (i.e., no
  // nullptr); otherwise, returns nullptr.
  static const float *SafeGetValues(const saft_fbs::Matrix *matrix);

  // Returns matrix->quantized_values()->data() if every dereference is safe
  // (i.e., no nullptr); otherwise, returns nullptr.
  static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix);

  // Returns matrix->scales()->data() if every dereference is safe (i.e., no
  // nullptr); otherwise, returns nullptr.
  static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix);

  // Returns network_->layers()->Get(last_index), where last_index =
  // SafeGetNumLayers() - 1, if every dereference is safe (i.e., no nullptr)
  // and there is at least one layer; otherwise, returns nullptr.
  const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const;

  // Returns the weights matrix of the softmax layer, or nullptr if there is no
  // softmax layer.
  const saft_fbs::Matrix *SafeGetSoftmaxWeights() const {
    const saft_fbs::NeuralLayer *softmax_layer = SafeGetSoftmaxLayer();
    if (softmax_layer == nullptr) {
      return nullptr;
    }
    return softmax_layer->weights();
  }

  // Returns the bias matrix of the softmax layer, or nullptr if there is no
  // softmax layer.
  const saft_fbs::Matrix *SafeGetSoftmaxBias() const {
    const saft_fbs::NeuralLayer *softmax_layer = SafeGetSoftmaxLayer();
    if (softmax_layer == nullptr) {
      return nullptr;
    }
    return softmax_layer->bias();
  }

  // Returns the quantization type of |matrix|.  Returns NONE in case of
  // problems (e.g., |matrix| is nullptr, or unknown quantization type).
  QuantizationType SafeGetQuantizationType(
      const saft_fbs::Matrix *matrix) const;

  // Returns a pointer to the values of |matrix| (float, uint8, or float16,
  // depending on the quantization), in row-major order.  Returns nullptr in
  // case of a problem.
  const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const;

  // Runs some validity checks, e.g., that the dimensions of the network layers
  // match.  Also checks that every pointer we return lies inside the |bytes|
  // passed to the constructor, so that clients reading through those pointers
  // do not run into trouble.
  bool ValidityChecking(StringPiece bytes) const;

  // Whether these params are valid.  May be false if the original data was
  // corrupted: we prefer setting this flag to false over CHECK-failing.
  bool valid_ = false;

  // EmbeddingNetwork flatbuffer from the bytes passed to the constructor; see
  // the constructor doc.
  const saft_fbs::EmbeddingNetwork *network_ = nullptr;
};
} // namespace mobile
} // namespace libtextclassifier3
#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_