| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_ |
| #define NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_ |
| |
| #include <algorithm> |
| #include <memory> |
| #include <string> |
| #include <utility> |
| |
| #include "lang_id/common/embedding-network-params.h" |
| #include "lang_id/common/flatbuffers/embedding-network_generated.h" |
| #include "lang_id/common/lite_base/float16.h" |
| #include "lang_id/common/lite_base/logging.h" |
| #include "lang_id/common/lite_strings/stringpiece.h" |
| |
| namespace libtextclassifier3 { |
| namespace mobile { |
| |
| // EmbeddingNetworkParams implementation backed by a flatbuffer. |
| // |
| // For info on our flatbuffer schema, see embedding-network.fbs. |
| class EmbeddingNetworkParamsFromFlatbuffer : public EmbeddingNetworkParams { |
| public: |
| // Constructs an EmbeddingNetworkParamsFromFlatbuffer instance, using the |
| // flatbuffer from |bytes|. |
| // |
| // IMPORTANT #1: caller should make sure |bytes| are alive during the lifetime |
| // of this EmbeddingNetworkParamsFromFlatbuffer instance. To avoid overhead, |
| // this constructor does not copy |bytes|. |
| // |
| // IMPORTANT #2: immediately after this constructor returns, we suggest you |
| // call is_valid() on the newly-constructed object and do not call any other |
| // method if the answer is negative (false). |
| explicit EmbeddingNetworkParamsFromFlatbuffer(StringPiece bytes); |
| |
| bool UpdateTaskContextParameters(mobile::TaskContext *task_context) override { |
| // This class does not provide access to the overall TaskContext. It |
| // provides only parameters for the Neurosis neural network. |
| SAFTM_LOG(DFATAL) << "Not supported"; |
| return false; |
| } |
| |
| bool is_valid() const override { return valid_; } |
| |
| int embeddings_size() const override { return SafeGetNumInputChunks(); } |
| |
| int embeddings_num_rows(int i) const override { |
| const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); |
| return SafeGetNumRows(matrix); |
| } |
| |
| int embeddings_num_cols(int i) const override { |
| const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); |
| return SafeGetNumCols(matrix); |
| } |
| |
| const void *embeddings_weights(int i) const override { |
| const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); |
| return SafeGetValuesOfMatrix(matrix); |
| } |
| |
| QuantizationType embeddings_quant_type(int i) const override { |
| const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); |
| return SafeGetQuantizationType(matrix); |
| } |
| |
| const float16 *embeddings_quant_scales(int i) const override { |
| const saft_fbs::Matrix *matrix = SafeGetEmbeddingMatrix(i); |
| return SafeGetScales(matrix); |
| } |
| |
| int hidden_size() const override { |
| // -1 because last layer is always the softmax layer. |
| return std::max(SafeGetNumLayers() - 1, 0); |
| } |
| |
| int hidden_num_rows(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); |
| return SafeGetNumRows(weights); |
| } |
| |
| int hidden_num_cols(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); |
| return SafeGetNumCols(weights); |
| } |
| |
| QuantizationType hidden_weights_quant_type(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); |
| return SafeGetQuantizationType(weights); |
| } |
| |
| const void *hidden_weights(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetLayerWeights(i); |
| return SafeGetValuesOfMatrix(weights); |
| } |
| |
| int hidden_bias_size() const override { return hidden_size(); } |
| |
| int hidden_bias_num_rows(int i) const override { |
| const saft_fbs::Matrix *bias = SafeGetLayerBias(i); |
| return SafeGetNumRows(bias); |
| } |
| |
| int hidden_bias_num_cols(int i) const override { |
| const saft_fbs::Matrix *bias = SafeGetLayerBias(i); |
| return SafeGetNumCols(bias); |
| } |
| |
| const void *hidden_bias_weights(int i) const override { |
| const saft_fbs::Matrix *bias = SafeGetLayerBias(i); |
| return SafeGetValues(bias); |
| } |
| |
| int softmax_size() const override { return (SafeGetNumLayers() > 0) ? 1 : 0; } |
| |
| int softmax_num_rows(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); |
| return SafeGetNumRows(weights); |
| } |
| |
| int softmax_num_cols(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); |
| return SafeGetNumCols(weights); |
| } |
| |
| QuantizationType softmax_weights_quant_type(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); |
| return SafeGetQuantizationType(weights); |
| } |
| |
| const void *softmax_weights(int i) const override { |
| const saft_fbs::Matrix *weights = SafeGetSoftmaxWeights(); |
| return SafeGetValuesOfMatrix(weights); |
| } |
| |
| int softmax_bias_size() const override { return softmax_size(); } |
| |
| int softmax_bias_num_rows(int i) const override { |
| const saft_fbs::Matrix *bias = SafeGetSoftmaxBias(); |
| return SafeGetNumRows(bias); |
| } |
| |
| int softmax_bias_num_cols(int i) const override { |
| const saft_fbs::Matrix *bias = SafeGetSoftmaxBias(); |
| return SafeGetNumCols(bias); |
| } |
| |
| const void *softmax_bias_weights(int i) const override { |
| const saft_fbs::Matrix *bias = SafeGetSoftmaxBias(); |
| return SafeGetValues(bias); |
| } |
| |
| int embedding_num_features_size() const override { |
| return SafeGetNumInputChunks(); |
| } |
| |
| int embedding_num_features(int i) const override { |
| if (!InRangeIndex(i, embedding_num_features_size(), |
| "embedding num features")) { |
| return 0; |
| } |
| const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i); |
| if (input_chunk == nullptr) { |
| return 0; |
| } |
| return input_chunk->num_features(); |
| } |
| |
| bool has_is_precomputed() const override { return false; } |
| bool is_precomputed() const override { return false; } |
| |
| private: |
| // Returns true if and only if index is in [0, limit). info should be a |
| // pointer to a zero-terminated array of chars (ideally a literal string, |
| // e.g. "layer") indicating what the index refers to; info is used to make log |
| // messages more informative. |
| static bool InRangeIndex(int index, int limit, const char *info); |
| |
| // Returns network_->input_chunks()->size(), if all dereferences are safe |
| // (i.e., no nullptr); otherwise, returns 0. |
| int SafeGetNumInputChunks() const; |
| |
| // Returns network_->input_chunks()->Get(i), if all dereferences are safe |
| // (i.e., no nullptr) otherwise, returns nullptr. |
| const saft_fbs::InputChunk *SafeGetInputChunk(int i) const; |
| |
| // Returns network_->input_chunks()->Get(i)->embedding(), if all dereferences |
| // are safe (i.e., no nullptr); otherwise, returns nullptr. |
| const saft_fbs::Matrix *SafeGetEmbeddingMatrix(int i) const; |
| |
| // Returns network_->layers()->size(), if all dereferences are safe (i.e., no |
| // nullptr); otherwise, returns 0. |
| int SafeGetNumLayers() const; |
| |
| // Returns network_->layers()->Get(i), if all dereferences are safe |
| // (i.e., no nullptr); otherwise, returns nullptr. |
| const saft_fbs::NeuralLayer *SafeGetLayer(int i) const; |
| |
| // Returns network_->layers()->Get(i)->weights(), if all dereferences are safe |
| // (i.e., no nullptr); otherwise, returns nullptr. |
| const saft_fbs::Matrix *SafeGetLayerWeights(int i) const; |
| |
| // Returns network_->layers()->Get(i)->bias(), if all dereferences are safe |
| // (i.e., no nullptr); otherwise, returns nullptr. |
| const saft_fbs::Matrix *SafeGetLayerBias(int i) const; |
| |
| static int SafeGetNumRows(const saft_fbs::Matrix *matrix) { |
| return (matrix == nullptr) ? 0 : matrix->rows(); |
| } |
| |
| static int SafeGetNumCols(const saft_fbs::Matrix *matrix) { |
| return (matrix == nullptr) ? 0 : matrix->cols(); |
| } |
| |
| // Returns matrix->values()->data() if all dereferences are safe (i.e., no |
| // nullptr); otherwise, returns nullptr. |
| static const float *SafeGetValues(const saft_fbs::Matrix *matrix); |
| |
| // Returns matrix->quantized_values()->data() if all dereferences are safe |
| // (i.e., no nullptr); otherwise, returns nullptr. |
| static const uint8_t *SafeGetQuantizedValues(const saft_fbs::Matrix *matrix); |
| |
| // Returns matrix->scales()->data() if all dereferences are safe (i.e., no |
| // nullptr); otherwise, returns nullptr. |
| static const float16 *SafeGetScales(const saft_fbs::Matrix *matrix); |
| |
| // Returns network_->layers()->Get(last_index) with last_index = |
| // SafeGetNumLayers() - 1, if all dereferences are safe (i.e., no nullptr) and |
| // there exists at least one layer; otherwise, returns nullptr. |
| const saft_fbs::NeuralLayer *SafeGetSoftmaxLayer() const; |
| |
| const saft_fbs::Matrix *SafeGetSoftmaxWeights() const { |
| const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer(); |
| return (layer == nullptr) ? nullptr : layer->weights(); |
| } |
| |
| const saft_fbs::Matrix *SafeGetSoftmaxBias() const { |
| const saft_fbs::NeuralLayer *layer = SafeGetSoftmaxLayer(); |
| return (layer == nullptr) ? nullptr : layer->bias(); |
| } |
| |
| // Returns the quantization type for |matrix|. Returns NONE in case of |
| // problems (e.g., matrix is nullptr or unknown quantization type). |
| QuantizationType SafeGetQuantizationType( |
| const saft_fbs::Matrix *matrix) const; |
| |
| // Returns a pointer to the values (float, uint8, or float16, depending on |
| // quantization) from |matrix|, in row-major order. Returns nullptr in case |
| // of a problem. |
| const void *SafeGetValuesOfMatrix(const saft_fbs::Matrix *matrix) const; |
| |
| // Performs some validity checks. E.g., check that dimensions of the network |
| // layers match. Also checks that all pointers we return are inside the |
| // |bytes| passed to the constructor, such that client that reads from those |
| // pointers will not run into troubles. |
| bool ValidityChecking(StringPiece bytes) const; |
| |
| // True if these params are valid. May be false if the original proto was |
| // corrupted. We prefer to set this to false to CHECK-failing. |
| bool valid_ = false; |
| |
| // EmbeddingNetwork flatbuffer from the bytes passed as parameter to the |
| // constructor; see constructor doc. |
| const saft_fbs::EmbeddingNetwork *network_ = nullptr; |
| }; |
| |
| } // namespace mobile |
| } // namespace nlp_saft |
| |
| #endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_FLATBUFFERS_EMBEDDING_NETWORK_PARAMS_FROM_FLATBUFFER_H_ |