/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_
#define LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_

#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "common/embedding-network-package.pb.h"
#include "common/embedding-network-params.h"
#include "common/embedding-network.pb.h"
#include "common/float16.h"
#include "common/little-endian-data.h"
#include "common/task-context.h"
#include "common/task-spec.pb.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"

namespace libtextclassifier {
namespace nlp_core {

// A wrapper class that owns and exposes an EmbeddingNetworkProto message via
// the EmbeddingNetworkParams interface.
//
// The EmbeddingNetworkParams interface encapsulates the weight matrices of
// the embedding, hidden, and softmax layers as transposed versions of their
// counterparts in the original EmbeddingNetworkProto.  The matrices in the
// proto passed to this class's constructor must likewise already have been
// transposed.  See embedding-network-params.h for details.
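//
// Example usage (a minimal sketch; `serialized` stands for the serialized
// proto bytes, obtained elsewhere by the caller, e.g., from a model file):
//
//   auto proto = std::make_unique<EmbeddingNetworkProto>();
//   if (!proto->ParseFromString(serialized)) {
//     ...  // Handle the parsing error.
//   }
//   EmbeddingNetworkParamsFromProto params(std::move(proto));
//   if (!params.is_valid()) {
//     ...  // Handle corrupted quantized embedding data.
//   }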
class EmbeddingNetworkParamsFromProto : public EmbeddingNetworkParams {
public:
  // Constructor that takes ownership of the provided proto.  See the class
  // comment for the requirements that certain weight matrices must satisfy.
explicit EmbeddingNetworkParamsFromProto(
std::unique_ptr<EmbeddingNetworkProto> proto)
: proto_(std::move(proto)) {
valid_ = true;
    // Initialize these vectors to have the required number of elements
    // regardless of quantization status.  This supports the unlikely case
    // where only some embeddings are quantized, and matches the fact that
    // the EmbeddingNetworkParams interface accesses them by index.
embeddings_quant_scales_.resize(proto_->embeddings_size());
embeddings_quant_weights_.resize(proto_->embeddings_size());
for (int i = 0; i < proto_->embeddings_size(); ++i) {
MatrixParams *embedding = proto_->mutable_embeddings()->Mutable(i);
if (!embedding->is_quantized()) {
continue;
}
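      // The quantized weights arrive as a raw little-endian bytes field;
      // decode them into one uint8 value per matrix element.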
bool success = FillVectorFromDataBytesInLittleEndian(
embedding->bytes_for_quantized_values(),
embedding->rows() * embedding->cols(),
&(embeddings_quant_weights_[i]));
if (!success) {
TC_LOG(ERROR) << "Problem decoding quant_weights for embeddings #" << i;
valid_ = false;
}
// The repeated field bytes_for_quantized_values uses a lot of memory.
// Since it's no longer necessary (and we own the proto), we clear it.
embedding->clear_bytes_for_quantized_values();
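      // Decode the quantization scales: one float16 per row of the (already
      // transposed) matrix.  These rows correspond to the columns of the
      // original matrix, hence the "col_scales" field name.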
success = FillVectorFromDataBytesInLittleEndian(
embedding->bytes_for_col_scales(),
embedding->rows(),
&(embeddings_quant_scales_[i]));
if (!success) {
TC_LOG(ERROR) << "Problem decoding col_scales for embeddings #" << i;
valid_ = false;
}
// See comments for clear_bytes_for_quantized_values().
embedding->clear_bytes_for_col_scales();
}
}

  const TaskSpec *GetTaskSpec() override {
if (!proto_) {
return nullptr;
}
auto extension_id = task_spec_in_embedding_network_proto;
if (proto_->HasExtension(extension_id)) {
return &(proto_->GetExtension(extension_id));
} else {
TC_LOG(ERROR) << "Unable to get TaskSpec from EmbeddingNetworkProto";
return nullptr;
}
}

  // Returns true if these params are valid, false otherwise (e.g., if the
  // original proto data was corrupted).
  bool is_valid() const { return valid_; }

 protected:
int embeddings_size() const override { return proto_->embeddings_size(); }
int embeddings_num_rows(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
return proto_->embeddings(i).rows();
}
int embeddings_num_cols(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
return proto_->embeddings(i).cols();
}
const void *embeddings_weights(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
if (proto_->embeddings(i).is_quantized()) {
return static_cast<const void *>(embeddings_quant_weights_.at(i).data());
} else {
return static_cast<const void *>(proto_->embeddings(i).value().data());
}
}
QuantizationType embeddings_quant_type(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
return proto_->embeddings(i).is_quantized() ? QuantizationType::UINT8
: QuantizationType::NONE;
}
const float16 *embeddings_quant_scales(int i) const override {
TC_DCHECK(InRange(i, embeddings_size()));
return proto_->embeddings(i).is_quantized()
? embeddings_quant_scales_.at(i).data()
: nullptr;
}

  int hidden_size() const override { return proto_->hidden_size(); }
int hidden_num_rows(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
return proto_->hidden(i).rows();
}
int hidden_num_cols(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
return proto_->hidden(i).cols();
}
const void *hidden_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_size()));
return proto_->hidden(i).value().data();
}

  int hidden_bias_size() const override { return proto_->hidden_bias_size(); }
int hidden_bias_num_rows(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
return proto_->hidden_bias(i).rows();
}
int hidden_bias_num_cols(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
return proto_->hidden_bias(i).cols();
}
const void *hidden_bias_weights(int i) const override {
TC_DCHECK(InRange(i, hidden_bias_size()));
return proto_->hidden_bias(i).value().data();
}

  int softmax_size() const override { return proto_->has_softmax() ? 1 : 0; }
int softmax_num_rows(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
return proto_->has_softmax() ? proto_->softmax().rows() : 0;
}
int softmax_num_cols(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
return proto_->has_softmax() ? proto_->softmax().cols() : 0;
}
const void *softmax_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_size()));
return proto_->has_softmax() ? proto_->softmax().value().data() : nullptr;
}

  int softmax_bias_size() const override {
return proto_->has_softmax_bias() ? 1 : 0;
}
int softmax_bias_num_rows(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
return proto_->has_softmax_bias() ? proto_->softmax_bias().rows() : 0;
}
int softmax_bias_num_cols(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
return proto_->has_softmax_bias() ? proto_->softmax_bias().cols() : 0;
}
const void *softmax_bias_weights(int i) const override {
TC_DCHECK(InRange(i, softmax_bias_size()));
return proto_->has_softmax_bias() ? proto_->softmax_bias().value().data()
: nullptr;
}

  int embedding_num_features_size() const override {
return proto_->embedding_num_features_size();
}
int embedding_num_features(int i) const override {
TC_DCHECK(InRange(i, embedding_num_features_size()));
return proto_->embedding_num_features(i);
}

 private:
std::unique_ptr<EmbeddingNetworkProto> proto_;

  // True if these params are valid.  May be false if the original proto was
  // corrupted.  We prefer setting this flag to false over CHECK-failing.
bool valid_;

  // When the embeddings are quantized, these members store their numeric
  // values using the types the rest of this class expects.  The proto cannot
  // store those values directly (protocol buffers have no uint8 or float16
  // scalar types), so it serializes them as little-endian bytes fields.
std::vector<std::vector<float16>> embeddings_quant_scales_;
std::vector<std::vector<uint8>> embeddings_quant_weights_;
};

}  // namespace nlp_core
}  // namespace libtextclassifier

#endif  // LIBTEXTCLASSIFIER_COMMON_EMBEDDING_NETWORK_PARAMS_FROM_PROTO_H_