/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
 */

#ifndef LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_
#define LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_

#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/memory_image/memory-image-reader.h"
#include "util/base/integral_types.h"
#include "util/strings/stringpiece.h"
namespace libtextclassifier {
namespace nlp_core {

// EmbeddingNetworkParams backed by a memory image.
//
// In this context, a memory image is like an EmbeddingNetworkProto, but with
// all repeated weights (>99% of the size) directly usable (with no parsing
// required).
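//
// Illustrative usage (a sketch only; `image_start` and `image_size` are
// hypothetical names for a pointer/size pair covering a memory image, e.g.
// a memory-mapped model file):
//
//   EmbeddingNetworkParamsFromImage params(image_start, image_size);
//   const TaskSpec *task_spec = params.GetTaskSpec();  // May be nullptr.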
class EmbeddingNetworkParamsFromImage : public EmbeddingNetworkParams {
public:
// Constructs an EmbeddingNetworkParamsFromImage, using the memory image that
// starts at address start and contains num_bytes bytes.
EmbeddingNetworkParamsFromImage(const void *start, uint64 num_bytes)
: memory_reader_(start, num_bytes),
trimmed_proto_(memory_reader_.trimmed_proto()) {
embeddings_blob_offset_ = 0;
hidden_blob_offset_ = embeddings_blob_offset_ + embeddings_size();
if (trimmed_proto_.embeddings_size() &&
trimmed_proto_.embeddings(0).is_quantized()) {
// Adjust for quantization: each quantized matrix takes two blobs (instead
// of one): one for the quantized values and one for the scales.
hidden_blob_offset_ += embeddings_size();
}
hidden_bias_blob_offset_ = hidden_blob_offset_ + hidden_size();
softmax_blob_offset_ = hidden_bias_blob_offset_ + hidden_bias_size();
softmax_bias_blob_offset_ = softmax_blob_offset_ + softmax_size();
}
~EmbeddingNetworkParamsFromImage() override {}
const TaskSpec *GetTaskSpec() override {
auto extension_id = task_spec_in_embedding_network_proto;
if (trimmed_proto_.HasExtension(extension_id)) {
return &(trimmed_proto_.GetExtension(extension_id));
} else {
return nullptr;
}
  }

 protected:
int embeddings_size() const override {
return trimmed_proto_.embeddings_size();
}
int embeddings_num_rows(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
return trimmed_proto_.embeddings(i).rows();
}
int embeddings_num_cols(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
return trimmed_proto_.embeddings(i).cols();
}
const void *embeddings_weights(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
const int blob_index = trimmed_proto_.embeddings(i).is_quantized()
? (embeddings_blob_offset_ + 2 * i)
: (embeddings_blob_offset_ + i);
StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
return data_blob_view.data();
}
QuantizationType embeddings_quant_type(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
if (trimmed_proto_.embeddings(i).is_quantized()) {
return QuantizationType::UINT8;
} else {
return QuantizationType::NONE;
}
}
const float16 *embeddings_quant_scales(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
if (trimmed_proto_.embeddings(i).is_quantized()) {
      // Each embedding matrix has two attached data blobs (hence the "2 * i"):
// one blob with the quantized values and (immediately after it, hence the
// "+ 1") one blob with the scales.
int blob_index = embeddings_blob_offset_ + 2 * i + 1;
StringPiece data_blob_view = memory_reader_.data_blob_view(blob_index);
return reinterpret_cast<const float16 *>(data_blob_view.data());
} else {
return nullptr;
}
}
int hidden_size() const override { return trimmed_proto_.hidden_size(); }
int hidden_num_rows(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
return trimmed_proto_.hidden(i).rows();
}
int hidden_num_cols(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
return trimmed_proto_.hidden(i).cols();
}
const void *hidden_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
StringPiece data_blob_view =
memory_reader_.data_blob_view(hidden_blob_offset_ + i);
return data_blob_view.data();
}
int hidden_bias_size() const override {
return trimmed_proto_.hidden_bias_size();
}
int hidden_bias_num_rows(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
return trimmed_proto_.hidden_bias(i).rows();
}
int hidden_bias_num_cols(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
return trimmed_proto_.hidden_bias(i).cols();
}
const void *hidden_bias_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
StringPiece data_blob_view =
memory_reader_.data_blob_view(hidden_bias_blob_offset_ + i);
return data_blob_view.data();
}
int softmax_size() const override {
return trimmed_proto_.has_softmax() ? 1 : 0;
}
int softmax_num_rows(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
return trimmed_proto_.softmax().rows();
}
int softmax_num_cols(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
return trimmed_proto_.softmax().cols();
}
const void *softmax_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
StringPiece data_blob_view =
memory_reader_.data_blob_view(softmax_blob_offset_ + i);
return data_blob_view.data();
}
int softmax_bias_size() const override {
return trimmed_proto_.has_softmax_bias() ? 1 : 0;
}
int softmax_bias_num_rows(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
return trimmed_proto_.softmax_bias().rows();
}
int softmax_bias_num_cols(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
return trimmed_proto_.softmax_bias().cols();
}
const void *softmax_bias_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
StringPiece data_blob_view =
memory_reader_.data_blob_view(softmax_bias_blob_offset_ + i);
return data_blob_view.data();
}
int embedding_num_features_size() const override {
return trimmed_proto_.embedding_num_features_size();
}
int embedding_num_features(int i) const override {
TC_DCHECK(InRange(i, embedding_num_features_size()));
return trimmed_proto_.embedding_num_features(i);
  }

 private:
MemoryImageReader<EmbeddingNetworkProto> memory_reader_;
const EmbeddingNetworkProto &trimmed_proto_;
// 0-based offsets in the list of data blobs for the different MatrixParams
// fields. E.g., the 1st hidden MatrixParams has its weights stored in the
// data blob number hidden_blob_offset_, the 2nd one in hidden_blob_offset_ +
// 1, and so on.
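  //
  // A worked example (illustrative sizes, not from a real model): with two
  // quantized embedding matrices, one hidden matrix, one hidden bias, a
  // softmax matrix and a softmax bias, the constructor above computes
  // embeddings_blob_offset_ = 0, hidden_blob_offset_ = 4 (each quantized
  // embedding matrix uses two blobs: values + scales),
  // hidden_bias_blob_offset_ = 5, softmax_blob_offset_ = 6, and
  // softmax_bias_blob_offset_ = 7.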
int embeddings_blob_offset_;
int hidden_blob_offset_;
int hidden_bias_blob_offset_;
int softmax_blob_offset_;
int softmax_bias_blob_offset_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_MEMORY_IMAGE_EMBEDDING_NETWORK_PARAMS_FROM_IMAGE_H_