| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "lang_id/common/flatbuffers/embedding-network-params-from-flatbuffer.h" |
| |
| #include "lang_id/common/lite_base/endian.h" |
| #include "lang_id/common/lite_base/logging.h" |
| #include "lang_id/common/lite_base/macros.h" |
| |
| namespace libtextclassifier3 { |
| namespace mobile { |
| |
| namespace { |
| // Returns true if and only if ptr points to a location inside allowed_range. |
| bool IsPointerInRange(const char *ptr, StringPiece allowed_range) { |
| return (ptr >= allowed_range.data()) && |
| (ptr < (allowed_range.data() + allowed_range.size())); |
| } |
| |
| // Returns true if and only if the memory range [start, start + |
| // range_size_in_bytes) is included inside allowed_range. |
| // |
| // Special case: if range_size_in_bytes == 0 (empty range) then we require that |
| // start is nullptr or in the allowed_range. |
| bool IsMemoryRangeValid(const void *start, int range_size_in_bytes, |
| StringPiece allowed_range) { |
| const char *begin = reinterpret_cast<const char *>(start); |
| if (range_size_in_bytes < 0) { |
| return false; |
| } |
| if (range_size_in_bytes == 0) { |
| return (start == nullptr) || IsPointerInRange(begin, allowed_range); |
| } |
| const char *inclusive_end = begin + (range_size_in_bytes - 1); |
| return (begin <= inclusive_end) && IsPointerInRange(begin, allowed_range) && |
| IsPointerInRange(inclusive_end, allowed_range); |
| } |
| |
| bool VerifyQuantizationScales(EmbeddingNetworkParams::Matrix matrix, |
| StringPiece bytes) { |
| if (matrix.quant_scales == nullptr) { |
| SAFTM_LOG(ERROR) << "Quantization type " |
| << static_cast<int>(matrix.quant_type) |
| << "; but no quantization scales"; |
| return false; |
| } |
| bool valid_scales = IsMemoryRangeValid(matrix.quant_scales, |
| matrix.rows * sizeof(float16), bytes); |
| if (!valid_scales) { |
| SAFTM_LOG(ERROR) << "quantization scales not fully inside bytes"; |
| return false; |
| } |
| return true; |
| } |
| |
| // Returns false if we detect a problem with |matrix|, true otherwise. E.g., we |
| // check that the array that starts at pointer matrix.elements is fully inside |
| // |bytes| (the range of bytes passed to the |
| // EmbeddingNetworkParamsFromFlatbuffer constructor). |
| bool VerifyMatrix(EmbeddingNetworkParams::Matrix matrix, StringPiece bytes) { |
| if ((matrix.rows < 0) || (matrix.cols < 0)) { |
| SAFTM_LOG(ERROR) << "Wrong matrix geometry: " << matrix.rows << " x " |
| << matrix.cols; |
| return false; |
| } |
| |
| const int num_elements = matrix.rows * matrix.cols; |
| |
| // Number of bytes occupied by the num_elements elements that start at address |
| // matrix.elements. |
| int element_range_size_in_bytes = 0; |
| switch (matrix.quant_type) { |
| case QuantizationType::NONE: |
| element_range_size_in_bytes = num_elements * sizeof(float); |
| break; |
| case QuantizationType::UINT8: { |
| element_range_size_in_bytes = num_elements; |
| if (!VerifyQuantizationScales(matrix, bytes)) { |
| return false; |
| } |
| break; |
| } |
| case QuantizationType::UINT4: { |
| if (matrix.cols % 2 != 0) { |
| SAFTM_LOG(ERROR) << "UINT4 doesn't work with odd #cols" << matrix.cols; |
| return false; |
| } |
| element_range_size_in_bytes = num_elements / 2; |
| if (!VerifyQuantizationScales(matrix, bytes)) { |
| return false; |
| } |
| break; |
| } |
| case QuantizationType::FLOAT16: { |
| element_range_size_in_bytes = num_elements * sizeof(float16); |
| |
| // No need to verify the scales: FLOAT16 quantization does not use scales. |
| break; |
| } |
| default: |
| SAFTM_LOG(ERROR) << "Unsupported quantization type " |
| << static_cast<int>(matrix.quant_type); |
| return false; |
| } |
| if (matrix.elements == nullptr) { |
| SAFTM_LOG(ERROR) << "matrix.elements == nullptr"; |
| return false; |
| } |
| bool valid = |
| IsMemoryRangeValid(matrix.elements, element_range_size_in_bytes, bytes); |
| if (!valid) { |
| SAFTM_LOG(ERROR) << "elements not fully inside bytes"; |
| return false; |
| } |
| return true; |
| } |
| |
| // Checks the geometry of the network layer represented by |weights| and |bias|, |
| // assuming the input to this layer has size |input_size|. Returns false if we |
| // detect any problem, true otherwise. |
| bool GoodLayerGeometry(int input_size, |
| const EmbeddingNetworkParams::Matrix &weights, |
| const EmbeddingNetworkParams::Matrix &bias) { |
| if (weights.rows != input_size) { |
| SAFTM_LOG(ERROR) << "#rows " << weights.rows << " != " << input_size; |
| return false; |
| } |
| if ((bias.rows != 1) && (bias.cols != 1)) { |
| SAFTM_LOG(ERROR) << "bad bias vector geometry: " << bias.rows << " x " |
| << bias.cols; |
| return false; |
| } |
| int bias_dimension = bias.rows * bias.cols; |
| if (weights.cols != bias_dimension) { |
| SAFTM_LOG(ERROR) << "#cols " << weights.cols << " != " << bias_dimension; |
| return false; |
| } |
| return true; |
| } |
| } // namespace |
| |
| EmbeddingNetworkParamsFromFlatbuffer::EmbeddingNetworkParamsFromFlatbuffer( |
| StringPiece bytes) { |
| // We expect valid_ to be initialized to false at this point. We set it to |
| // true only if we successfully complete all initialization. On error, we |
| // return early, leaving valid_ set to false. |
| SAFTM_DCHECK(!valid_); |
| |
| // NOTE: current EmbeddingNetworkParams API works only on little-endian |
| // machines. Fortunately, all modern devices are little-endian so, instead of |
| // a costly API change, we support only the little-endian case. |
| // |
| // Technical explanation: for each Matrix, our API provides a pointer to the |
| // matrix elements (see Matrix field |elements|). For unquantized matrices, |
| // that's a const float *pointer; the client code (e.g., Neurosis) uses those |
| // floats directly. That is correct if the EmbeddingNetworkParams come from a |
| // proto, where the proto parsing already handled the endianness differences. |
| // But in the flatbuffer case, that's a pointer to floats in little-endian |
| // format (flatbuffers always use little-endian). If our API provided access |
| // to only one element at a time, the accessor method could swap the bytes "on |
| // the fly", using temporary variables. Instead, our API provides a pointer |
| // to all elements: as their number is variable (and underlying data is |
| // immutable), we can't ensure the bytes of all those elements are swapped |
| // without extra memory allocation to store the swapped bytes (which is what |
| // using flatbuffers is supposed to prevent). |
| if (!LittleEndian::IsLittleEndian()) { |
| SAFTM_LOG(INFO) << "Not a little-endian machine"; |
| return; |
| } |
| |
| const uint8_t *start = reinterpret_cast<const uint8_t *>(bytes.data()); |
| if (start == nullptr) { |
| // Note: as |bytes| is expected to be a valid EmbeddingNetwork flatbuffer, |
| // it should contain the 4-char identifier "NS00" (or a later version). It |
| // can't be empty; hence StringPiece(nullptr, 0) is not legal here. |
| SAFTM_LOG(ERROR) << "nullptr bytes"; |
| return; |
| } |
| flatbuffers::Verifier verifier(start, bytes.size()); |
| if (!saft_fbs::VerifyEmbeddingNetworkBuffer(verifier)) { |
| SAFTM_LOG(ERROR) << "Not a valid EmbeddingNetwork flatbuffer"; |
| return; |
| } |
| network_ = saft_fbs::GetEmbeddingNetwork(start); |
| if (network_ == nullptr) { |
| SAFTM_LOG(ERROR) << "Unable to interpret bytes as a flatbuffer"; |
| return; |
| } |
| |
| // Perform a few extra checks before declaring this object valid. |
| valid_ = ValidityChecking(bytes); |
| } |
| |
| bool EmbeddingNetworkParamsFromFlatbuffer::ValidityChecking( |
| StringPiece bytes) const { |
| int input_size = 0; |
| for (int i = 0; i < embeddings_size(); ++i) { |
| Matrix embeddings = GetEmbeddingMatrix(i); |
| if (!VerifyMatrix(embeddings, bytes)) { |
| SAFTM_LOG(ERROR) << "Bad embedding matrix #" << i; |
| return false; |
| } |
| input_size += embedding_num_features(i) * embeddings.cols; |
| } |
| int current_size = input_size; |
| for (int i = 0; i < hidden_size(); ++i) { |
| Matrix weights = GetHiddenLayerMatrix(i); |
| if (!VerifyMatrix(weights, bytes)) { |
| SAFTM_LOG(ERROR) << "Bad weights matrix for hidden layer #" << i; |
| return false; |
| } |
| Matrix bias = GetHiddenLayerBias(i); |
| if (!VerifyMatrix(bias, bytes)) { |
| SAFTM_LOG(ERROR) << "Bad bias vector for hidden layer #" << i; |
| return false; |
| } |
| if (!GoodLayerGeometry(current_size, weights, bias)) { |
| SAFTM_LOG(ERROR) << "Bad geometry for hidden layer #" << i; |
| return false; |
| } |
| current_size = weights.cols; |
| } |
| |
| if (HasSoftmax()) { |
| Matrix weights = GetSoftmaxMatrix(); |
| if (!VerifyMatrix(weights, bytes)) { |
| SAFTM_LOG(ERROR) << "Bad weights matrix for softmax"; |
| return false; |
| } |
| Matrix bias = GetSoftmaxBias(); |
| if (!VerifyMatrix(bias, bytes)) { |
| SAFTM_LOG(ERROR) << "Bad bias vector for softmax"; |
| return false; |
| } |
| if (!GoodLayerGeometry(current_size, weights, bias)) { |
| SAFTM_LOG(ERROR) << "Bad geometry for softmax layer"; |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| // static |
| bool EmbeddingNetworkParamsFromFlatbuffer::InRangeIndex(int index, int limit, |
| const char *info) { |
| if ((index >= 0) && (index < limit)) { |
| return true; |
| } else { |
| SAFTM_LOG(ERROR) << info << " index " << index << " outside range [0, " |
| << limit << ")"; |
| return false; |
| } |
| } |
| |
| int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumInputChunks() const { |
| const auto *input_chunks = network_->input_chunks(); |
| if (input_chunks == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr input_chunks"; |
| return 0; |
| } |
| return input_chunks->size(); |
| } |
| |
| const saft_fbs::InputChunk * |
| EmbeddingNetworkParamsFromFlatbuffer::SafeGetInputChunk(int i) const { |
| if (!InRangeIndex(i, SafeGetNumInputChunks(), "input chunks")) { |
| return nullptr; |
| } |
| const auto *input_chunks = network_->input_chunks(); |
| if (input_chunks == nullptr) { |
| // Execution should not reach this point, due to how SafeGetNumInputChunks() |
| // is implemented. Still, just to be sure: |
| SAFTM_LOG(ERROR) << "nullptr input_chunks"; |
| return nullptr; |
| } |
| const saft_fbs::InputChunk *input_chunk = input_chunks->Get(i); |
| if (input_chunk == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr input chunk #" << i; |
| } |
| return input_chunk; |
| } |
| |
| const saft_fbs::Matrix * |
| EmbeddingNetworkParamsFromFlatbuffer::SafeGetEmbeddingMatrix(int i) const { |
| const saft_fbs::InputChunk *input_chunk = SafeGetInputChunk(i); |
| if (input_chunk == nullptr) return nullptr; |
| const saft_fbs::Matrix *matrix = input_chunk->embedding(); |
| if (matrix == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr embeding matrix #" << i; |
| } |
| return matrix; |
| } |
| |
| int EmbeddingNetworkParamsFromFlatbuffer::SafeGetNumLayers() const { |
| const auto *layers = network_->layers(); |
| if (layers == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr layers"; |
| return 0; |
| } |
| return layers->size(); |
| } |
| |
| const saft_fbs::NeuralLayer *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayer( |
| int i) const { |
| if (!InRangeIndex(i, SafeGetNumLayers(), "layer")) { |
| return nullptr; |
| } |
| const auto *layers = network_->layers(); |
| if (layers == nullptr) { |
| // Execution should not reach this point, due to how SafeGetNumLayers() |
| // is implemented. Still, just to be sure: |
| SAFTM_LOG(ERROR) << "nullptr layers"; |
| return nullptr; |
| } |
| const saft_fbs::NeuralLayer *layer = layers->Get(i); |
| if (layer == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr layer #" << i; |
| } |
| return layer; |
| } |
| |
| const saft_fbs::Matrix * |
| EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerWeights(int i) const { |
| const saft_fbs::NeuralLayer *layer = SafeGetLayer(i); |
| if (layer == nullptr) return nullptr; |
| const saft_fbs::Matrix *weights = layer->weights(); |
| if (weights == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr weights for layer #" << i; |
| } |
| return weights; |
| } |
| |
| const saft_fbs::Matrix *EmbeddingNetworkParamsFromFlatbuffer::SafeGetLayerBias( |
| int i) const { |
| const saft_fbs::NeuralLayer *layer = SafeGetLayer(i); |
| if (layer == nullptr) return nullptr; |
| const saft_fbs::Matrix *bias = layer->bias(); |
| if (bias == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr bias for layer #" << i; |
| } |
| return bias; |
| } |
| |
| // static |
| const float *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValues( |
| const saft_fbs::Matrix *matrix) { |
| if (matrix == nullptr) return nullptr; |
| const flatbuffers::Vector<float> *values = matrix->values(); |
| if (values == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr values"; |
| } |
| return values->data(); |
| } |
| |
| // static |
| const uint8_t *EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizedValues( |
| const saft_fbs::Matrix *matrix) { |
| if (matrix == nullptr) return nullptr; |
| const flatbuffers::Vector<uint8_t> *quantized_values = |
| matrix->quantized_values(); |
| if (quantized_values == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr quantized_values"; |
| } |
| return quantized_values->data(); |
| } |
| |
| // static |
| const float16 *EmbeddingNetworkParamsFromFlatbuffer::SafeGetScales( |
| const saft_fbs::Matrix *matrix) { |
| if (matrix == nullptr) return nullptr; |
| const flatbuffers::Vector<uint16_t> *scales = matrix->scales(); |
| if (scales == nullptr) { |
| SAFTM_LOG(ERROR) << "nullptr scales"; |
| } |
| return scales->data(); |
| } |
| |
| const saft_fbs::NeuralLayer * |
| EmbeddingNetworkParamsFromFlatbuffer::SafeGetSoftmaxLayer() const { |
| int num_layers = SafeGetNumLayers(); |
| if (num_layers <= 0) { |
| SAFTM_LOG(ERROR) << "No softmax layer"; |
| return nullptr; |
| } |
| return SafeGetLayer(num_layers - 1); |
| } |
| |
| QuantizationType EmbeddingNetworkParamsFromFlatbuffer::SafeGetQuantizationType( |
| const saft_fbs::Matrix *matrix) const { |
| if (matrix == nullptr) { |
| return QuantizationType::NONE; |
| } |
| saft_fbs::QuantizationType quantization_type = matrix->quantization_type(); |
| |
| // Conversion from nlp_saft::saft_fbs::QuantizationType to |
| // nlp_saft::QuantizationType (due to legacy reasons, we have both). |
| switch (quantization_type) { |
| case saft_fbs::QuantizationType_NONE: |
| return QuantizationType::NONE; |
| case saft_fbs::QuantizationType_UINT8: |
| return QuantizationType::UINT8; |
| case saft_fbs::QuantizationType_UINT4: |
| return QuantizationType::UINT4; |
| case saft_fbs::QuantizationType_FLOAT16: |
| return QuantizationType::FLOAT16; |
| default: |
| SAFTM_LOG(ERROR) << "Unsupported quantization type " |
| << static_cast<int>(quantization_type); |
| return QuantizationType::NONE; |
| } |
| } |
| |
| const void *EmbeddingNetworkParamsFromFlatbuffer::SafeGetValuesOfMatrix( |
| const saft_fbs::Matrix *matrix) const { |
| if (matrix == nullptr) { |
| return nullptr; |
| } |
| saft_fbs::QuantizationType quantization_type = matrix->quantization_type(); |
| switch (quantization_type) { |
| case saft_fbs::QuantizationType_NONE: |
| return SafeGetValues(matrix); |
| case saft_fbs::QuantizationType_UINT8: |
| SAFTM_FALLTHROUGH_INTENDED; |
| case saft_fbs::QuantizationType_UINT4: |
| SAFTM_FALLTHROUGH_INTENDED; |
| case saft_fbs::QuantizationType_FLOAT16: |
| return SafeGetQuantizedValues(matrix); |
| default: |
| SAFTM_LOG(ERROR) << "Unsupported quantization type " |
| << static_cast<int>(quantization_type); |
| return nullptr; |
| } |
| } |
| |
| } // namespace mobile |
| } // namespace nlp_saft |