Sync with google3. am: 044fd1a035
am: e75e62f0a6
Change-Id: Icf85c0d2cc356c9ff15c9003677f167746ff51e7
diff --git a/common/embedding-network.cc b/common/embedding-network.cc
index 7522576..30919e8 100644
--- a/common/embedding-network.cc
+++ b/common/embedding-network.cc
@@ -249,31 +249,32 @@
template <typename ScaleAdderClass>
bool EmbeddingNetwork::FinishComputeFinalScoresInternal(
const VectorSpan<float> &input, Vector *scores) const {
- Vector h0(hidden_bias_[0].size());
- bool success = SparseReluProductPlusBias<ScaleAdderClass>(
- false, hidden_weights_[0], hidden_bias_[0], input, &h0);
- if (!success) return false;
+ // This vector serves as alternating storage for the activations of the
+ // different layers. We can't use just one vector here because all of the
+ // activations of the previous layer are needed to compute the activations
+ // of the next one.
+ std::vector<Vector> h_storage(2);
- if (hidden_weights_.size() == 1) { // 1 hidden layer
- success = SparseReluProductPlusBias<ScaleAdderClass>(
- true, softmax_weights_, softmax_bias_, h0, scores);
- if (!success) return false;
- } else if (hidden_weights_.size() == 2) { // 2 hidden layers
- Vector h1(hidden_bias_[1].size());
- success = SparseReluProductPlusBias<ScaleAdderClass>(
- true, hidden_weights_[1], hidden_bias_[1], h0, &h1);
- if (!success) return false;
- success = SparseReluProductPlusBias<ScaleAdderClass>(
- true, softmax_weights_, softmax_bias_, h1, scores);
- if (!success) return false;
- } else {
- // This should never happen: the EmbeddingNetwork() constructor marks the
- // object invalid if #hidden layers is not 1 or 2. Even if a client uses an
- // invalid EmbeddingNetwork, ComputeFinalScores() (the only caller to this
- // method) returns immediately if !is_valid(). Still, just in case, we log
- // an error, but don't crash.
- TC_LOG(ERROR) << hidden_weights_.size();
+ // Compute pre-logits activations.
+ VectorSpan<float> h_in(input);
+ Vector *h_out;
+ for (int i = 0; i < hidden_weights_.size(); ++i) {
+ const bool apply_relu = i > 0;
+ h_out = &(h_storage[i % 2]);
+ h_out->resize(hidden_bias_[i].size());
+ if (!SparseReluProductPlusBias<ScaleAdderClass>(
+ apply_relu, hidden_weights_[i], hidden_bias_[i], h_in, h_out)) {
+ return false;
+ }
+ h_in = VectorSpan<float>(*h_out);
}
+
+ // Compute logit scores.
+ if (!SparseReluProductPlusBias<ScaleAdderClass>(
+ true, softmax_weights_, softmax_bias_, h_in, scores)) {
+ return false;
+ }
+
return true;
}
@@ -336,8 +337,8 @@
// Invariant 2 (trivial by the code above).
TC_DCHECK_EQ(concat_offset_.size(), embedding_matrices_.size());
- int num_hidden_layers = model->GetNumHiddenLayers();
- if ((num_hidden_layers != 1) && (num_hidden_layers != 2)) {
+ const int num_hidden_layers = model->GetNumHiddenLayers();
+ if (num_hidden_layers < 1) {
TC_LOG(ERROR) << num_hidden_layers;
return;
}
diff --git a/common/embedding-network.h b/common/embedding-network.h
index 594f34c..a02c6ea 100644
--- a/common/embedding-network.h
+++ b/common/embedding-network.h
@@ -198,7 +198,7 @@
// Returns the size (the number of columns) of the embedding space es_index.
int EmbeddingSize(int es_index) const;
- private:
+ protected:
// Builds an embedding for given feature vector, and places it from
// concat_offset to the concat vector.
bool GetEmbeddingInternal(const FeatureVector &feature_vector,
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
index aa9ef28..8383d33 100644
--- a/lang_id/lang-id.cc
+++ b/lang_id/lang-id.cc
@@ -87,17 +87,31 @@
// Using mmap as a fast way to read the model bytes.
ScopedMmap scoped_mmap(filename);
MmapHandle mmap_handle = scoped_mmap.handle();
- Initialize(mmap_handle);
+ if (!mmap_handle.ok()) {
+ TC_LOG(ERROR) << "Unable to read model bytes.";
+ return;
+ }
+
+ Initialize(mmap_handle.to_stringpiece());
}
explicit LangIdImpl(int fd) {
// Using mmap as a fast way to read the model bytes.
ScopedMmap scoped_mmap(fd);
MmapHandle mmap_handle = scoped_mmap.handle();
- Initialize(mmap_handle);
+ if (!mmap_handle.ok()) {
+ TC_LOG(ERROR) << "Unable to read model bytes.";
+ return;
+ }
+
+ Initialize(mmap_handle.to_stringpiece());
}
- void Initialize(const MmapHandle &mmap_handle) {
+ LangIdImpl(const char *ptr, size_t length) {
+ Initialize(StringPiece(ptr, length));
+ }
+
+ void Initialize(StringPiece model_bytes) {
// Will set valid_ to true only on successful initialization.
valid_ = false;
@@ -105,12 +119,6 @@
ContinuousBagOfNgramsFunction::RegisterClass();
RelevantScriptFeature::RegisterClass();
- if (!mmap_handle.ok()) {
- TC_LOG(ERROR) << "Unable to read model bytes.";
- return;
- }
- StringPiece model_bytes = mmap_handle.to_stringpiece();
-
// NOTE(salcianu): code below relies on the fact that the current features
// do not rely on data from a TaskInput. Otherwise, one would have to use
// the more complex model registration mechanism, which requires more code.
@@ -357,6 +365,15 @@
}
}
+LangId::LangId(const char *ptr, size_t length)
+ : pimpl_(new LangIdImpl(ptr, length)) {
+ if (!pimpl_->is_valid()) {
+ TC_LOG(ERROR) << "Unable to construct a valid LangId based "
+ << "on the memory region; nothing should crash, "
+ << "but accuracy will be bad.";
+ }
+}
+
LangId::~LangId() = default;
void LangId::SetProbabilityThreshold(float threshold) {
diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h
index b6b6b1c..7653dde 100644
--- a/lang_id/lang-id.h
+++ b/lang_id/lang-id.h
@@ -53,6 +53,9 @@
// Same as above but uses a file descriptor.
explicit LangId(int fd);
+ // Same as above but uses an already mapped memory region.
+ explicit LangId(const char *ptr, size_t length);
+
virtual ~LangId();
// Sets probability threshold for predictions. If our likeliest prediction is
diff --git a/tests/embedding-network_test.cc b/tests/embedding-network_test.cc
new file mode 100644
index 0000000..026ec17
--- /dev/null
+++ b/tests/embedding-network_test.cc
@@ -0,0 +1,87 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/embedding-network.h"
+#include "common/embedding-network-params-from-proto.h"
+#include "common/embedding-network.pb.h"
+#include "common/simple-adder.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace nlp_core {
+namespace {
+
+using testing::ElementsAreArray;
+
+class TestingEmbeddingNetwork : public EmbeddingNetwork {
+ public:
+ using EmbeddingNetwork::EmbeddingNetwork;
+ using EmbeddingNetwork::FinishComputeFinalScoresInternal;
+};
+
+void DiagonalAndBias3x3(int diagonal_value, int bias_value,
+ MatrixParams* weights, MatrixParams* bias) {
+ weights->set_rows(3);
+ weights->set_cols(3);
+ weights->add_value(diagonal_value);
+ weights->add_value(0);
+ weights->add_value(0);
+ weights->add_value(0);
+ weights->add_value(diagonal_value);
+ weights->add_value(0);
+ weights->add_value(0);
+ weights->add_value(0);
+ weights->add_value(diagonal_value);
+
+ bias->set_rows(3);
+ bias->set_cols(1);
+ bias->add_value(bias_value);
+ bias->add_value(bias_value);
+ bias->add_value(bias_value);
+}
+
+TEST(EmbeddingNetworkTest, IdentityThroughMultipleLayers) {
+ std::unique_ptr<EmbeddingNetworkProto> proto;
+ proto.reset(new EmbeddingNetworkProto);
+
+ // These layers should be an identity with bias.
+ DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/1,
+ proto->add_hidden(), proto->add_hidden_bias());
+ DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/2,
+ proto->add_hidden(), proto->add_hidden_bias());
+ DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/3,
+ proto->add_hidden(), proto->add_hidden_bias());
+ DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/4,
+ proto->add_hidden(), proto->add_hidden_bias());
+ DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/5,
+ proto->mutable_softmax(), proto->mutable_softmax_bias());
+
+ EmbeddingNetworkParamsFromProto params(std::move(proto));
+ TestingEmbeddingNetwork network(¶ms);
+
+ std::vector<float> input({-2, -1, 0});
+ std::vector<float> output;
+ network.FinishComputeFinalScoresInternal<SimpleAdder>(
+ VectorSpan<float>(input), &output);
+
+ EXPECT_THAT(output, ElementsAreArray({14, 14, 15}));
+}
+
+} // namespace
+} // namespace nlp_core
+} // namespace libtextclassifier