Sync with google3. am: 044fd1a035
am: e75e62f0a6

Change-Id: Icf85c0d2cc356c9ff15c9003677f167746ff51e7
diff --git a/common/embedding-network.cc b/common/embedding-network.cc
index 7522576..30919e8 100644
--- a/common/embedding-network.cc
+++ b/common/embedding-network.cc
@@ -249,31 +249,44 @@
 template <typename ScaleAdderClass>
 bool EmbeddingNetwork::FinishComputeFinalScoresInternal(
     const VectorSpan<float> &input, Vector *scores) const {
-  Vector h0(hidden_bias_[0].size());
-  bool success = SparseReluProductPlusBias<ScaleAdderClass>(
-      false, hidden_weights_[0], hidden_bias_[0], input, &h0);
-  if (!success) return false;
+  // This should never happen: the EmbeddingNetwork() constructor leaves the
+  // object invalid when there are no hidden layers.  Still, just in case, log
+  // an error and fail rather than feed the raw input to the softmax layer.
+  if (hidden_weights_.empty()) {
+    TC_LOG(ERROR) << "No hidden layers";
+    return false;
+  }
+
+  // This vector serves as an alternating storage for activations of the
+  // different layers. We can't use just one vector here because all of the
+  // activations of the previous layer are needed for computation of
+  // activations of the next one.
+  std::vector<Vector> h_storage(2);
 
-  if (hidden_weights_.size() == 1) {  // 1 hidden layer
-    success = SparseReluProductPlusBias<ScaleAdderClass>(
-        true, softmax_weights_, softmax_bias_, h0, scores);
-    if (!success) return false;
-  } else if (hidden_weights_.size() == 2) {  // 2 hidden layers
-    Vector h1(hidden_bias_[1].size());
-    success = SparseReluProductPlusBias<ScaleAdderClass>(
-        true, hidden_weights_[1], hidden_bias_[1], h0, &h1);
-    if (!success) return false;
-    success = SparseReluProductPlusBias<ScaleAdderClass>(
-        true, softmax_weights_, softmax_bias_, h1, scores);
-    if (!success) return false;
-  } else {
-    // This should never happen: the EmbeddingNetwork() constructor marks the
-    // object invalid if #hidden layers is not 1 or 2.  Even if a client uses an
-    // invalid EmbeddingNetwork, ComputeFinalScores() (the only caller to this
-    // method) returns immediately if !is_valid().  Still, just in case, we log
-    // an error, but don't crash.
-    TC_LOG(ERROR) << hidden_weights_.size();
+  // Compute pre-logits activations.  The first layer consumes the raw input;
+  // each later layer applies ReLU to the previous layer's activations.
+  VectorSpan<float> h_in(input);
+  for (size_t i = 0; i < hidden_weights_.size(); ++i) {
+    const bool apply_relu = i > 0;
+    Vector *h_out = &h_storage[i % 2];
+
+    // Zero-fill instead of a plain resize(): from the third layer on, this
+    // buffer still holds activations from two layers back, and a fresh zeroed
+    // vector avoids relying on SparseReluProductPlusBias() overwriting them.
+    h_out->assign(hidden_bias_[i].size(), 0.0f);
+    if (!SparseReluProductPlusBias<ScaleAdderClass>(
+            apply_relu, hidden_weights_[i], hidden_bias_[i], h_in, h_out)) {
+      return false;
+    }
+    h_in = VectorSpan<float>(*h_out);
   }
+
+  // Compute logit scores; ReLU is applied to the last hidden activations.
+  if (!SparseReluProductPlusBias<ScaleAdderClass>(
+          true, softmax_weights_, softmax_bias_, h_in, scores)) {
+    return false;
+  }
+
   return true;
 }
 
@@ -336,8 +349,8 @@
   // Invariant 2 (trivial by the code above).
   TC_DCHECK_EQ(concat_offset_.size(), embedding_matrices_.size());
 
-  int num_hidden_layers = model->GetNumHiddenLayers();
-  if ((num_hidden_layers != 1) && (num_hidden_layers != 2)) {
+  const int num_hidden_layers = model->GetNumHiddenLayers();
+  if (num_hidden_layers < 1) {
     TC_LOG(ERROR) << num_hidden_layers;
     return;
   }
diff --git a/common/embedding-network.h b/common/embedding-network.h
index 594f34c..a02c6ea 100644
--- a/common/embedding-network.h
+++ b/common/embedding-network.h
@@ -198,7 +198,7 @@
   // Returns the size (the number of columns) of the embedding space es_index.
   int EmbeddingSize(int es_index) const;
 
- private:
+ protected:  // Accessible to subclasses, e.g. tests.
   // Builds an embedding for given feature vector, and places it from
   // concat_offset to the concat vector.
   bool GetEmbeddingInternal(const FeatureVector &feature_vector,
diff --git a/lang_id/lang-id.cc b/lang_id/lang-id.cc
index aa9ef28..8383d33 100644
--- a/lang_id/lang-id.cc
+++ b/lang_id/lang-id.cc
@@ -87,17 +87,38 @@
     // Using mmap as a fast way to read the model bytes.
     ScopedMmap scoped_mmap(filename);
     MmapHandle mmap_handle = scoped_mmap.handle();
-    Initialize(mmap_handle);
+    if (!mmap_handle.ok()) {
+      TC_LOG(ERROR) << "Unable to read model bytes.";
+      // Initialize() is skipped, and it is what normally resets valid_.
+      valid_ = false;
+      return;
+    }
+
+    Initialize(mmap_handle.to_stringpiece());
   }
 
   explicit LangIdImpl(int fd) {
     // Using mmap as a fast way to read the model bytes.
     ScopedMmap scoped_mmap(fd);
     MmapHandle mmap_handle = scoped_mmap.handle();
-    Initialize(mmap_handle);
+    if (!mmap_handle.ok()) {
+      TC_LOG(ERROR) << "Unable to read model bytes.";
+      // Initialize() is skipped, and it is what normally resets valid_.
+      valid_ = false;
+      return;
+    }
+
+    Initialize(mmap_handle.to_stringpiece());
   }
 
-  void Initialize(const MmapHandle &mmap_handle) {
+  // Same as above, but reads the model bytes from the memory region
+  // [ptr, ptr + length).  Does not take ownership of that memory.
+  LangIdImpl(const char *ptr, size_t length) {
+    Initialize(StringPiece(ptr, length));
+  }
+
+  // Initializes this object from the raw model bytes.
+  void Initialize(StringPiece model_bytes) {
     // Will set valid_ to true only on successful initialization.
     valid_ = false;
 
@@ -105,12 +126,6 @@
     ContinuousBagOfNgramsFunction::RegisterClass();
     RelevantScriptFeature::RegisterClass();
 
-    if (!mmap_handle.ok()) {
-      TC_LOG(ERROR) << "Unable to read model bytes.";
-      return;
-    }
-    StringPiece model_bytes = mmap_handle.to_stringpiece();
-
     // NOTE(salcianu): code below relies on the fact that the current features
     // do not rely on data from a TaskInput.  Otherwise, one would have to use
     // the more complex model registration mechanism, which requires more code.
@@ -357,6 +372,15 @@
   }
 }
 
+LangId::LangId(const char *ptr, size_t length)
+    : pimpl_(new LangIdImpl(ptr, length)) {
+  if (!pimpl_->is_valid()) {
+    TC_LOG(ERROR) << "Unable to construct a valid LangId based "
+                  << "on the memory region; nothing should crash, "
+                  << "but accuracy will be bad.";
+  }
+}
+
 LangId::~LangId() = default;
 
 void LangId::SetProbabilityThreshold(float threshold) {
diff --git a/lang_id/lang-id.h b/lang_id/lang-id.h
index b6b6b1c..7653dde 100644
--- a/lang_id/lang-id.h
+++ b/lang_id/lang-id.h
@@ -53,6 +53,11 @@
   // Same as above but uses a file descriptor.
   explicit LangId(int fd);
 
+  // Same as above but reads the model bytes from the memory region
+  // [ptr, ptr + length), e.g. an already mmap-ed model file.  Does not take
+  // ownership of that memory.
+  explicit LangId(const char *ptr, size_t length);
+
   virtual ~LangId();
 
   // Sets probability threshold for predictions.  If our likeliest prediction is
diff --git a/tests/embedding-network_test.cc b/tests/embedding-network_test.cc
new file mode 100644
index 0000000..026ec17
--- /dev/null
+++ b/tests/embedding-network_test.cc
@@ -0,0 +1,117 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "common/embedding-network.h"
+
+#include <memory>
+#include <utility>
+#include <vector>
+
+#include "common/embedding-network-params-from-proto.h"
+#include "common/embedding-network.pb.h"
+#include "common/simple-adder.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace nlp_core {
+namespace {
+
+using testing::ElementsAreArray;
+
+// Inherits EmbeddingNetwork's constructors and makes the non-public
+// FinishComputeFinalScoresInternal() public so tests can call it directly.
+class TestingEmbeddingNetwork : public EmbeddingNetwork {
+ public:
+  using EmbeddingNetwork::EmbeddingNetwork;
+  using EmbeddingNetwork::FinishComputeFinalScoresInternal;
+};
+
+// Fills `weights` with a 3x3 matrix holding `diagonal_value` on the main
+// diagonal and zeros elsewhere, and `bias` with a 3x1 matrix whose entries are
+// all `bias_value`.
+void DiagonalAndBias3x3(int diagonal_value, int bias_value,
+                        MatrixParams *weights, MatrixParams *bias) {
+  constexpr int kSize = 3;
+  weights->set_rows(kSize);
+  weights->set_cols(kSize);
+  for (int row = 0; row < kSize; ++row) {
+    for (int col = 0; col < kSize; ++col) {
+      weights->add_value(row == col ? diagonal_value : 0);
+    }
+  }
+
+  bias->set_rows(kSize);
+  bias->set_cols(1);
+  for (int row = 0; row < kSize; ++row) {
+    bias->add_value(bias_value);
+  }
+}
+
+TEST(EmbeddingNetworkTest, IdentityThroughMultipleLayers) {
+  std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto);
+
+  // These layers should be an identity with bias.
+  DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/1,
+                     proto->add_hidden(), proto->add_hidden_bias());
+  DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/2,
+                     proto->add_hidden(), proto->add_hidden_bias());
+  DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/3,
+                     proto->add_hidden(), proto->add_hidden_bias());
+  DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/4,
+                     proto->add_hidden(), proto->add_hidden_bias());
+  DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/5,
+                     proto->mutable_softmax(), proto->mutable_softmax_bias());
+
+  EmbeddingNetworkParamsFromProto params(std::move(proto));
+  TestingEmbeddingNetwork network(&params);
+
+  // The first layer sees the raw input: {-2, -1, 0} + 1 = {-1, 0, 1}.
+  // Every later layer, softmax included, applies ReLU to its input first:
+  // {0, 0, 1} + 2 + 3 + 4 + 5 = {14, 14, 15}.
+  std::vector<float> input({-2, -1, 0});
+  std::vector<float> output;
+  ASSERT_TRUE(network.FinishComputeFinalScoresInternal<SimpleAdder>(
+      VectorSpan<float>(input), &output));
+
+  EXPECT_THAT(output, ElementsAreArray({14, 14, 15}));
+}
+
+TEST(EmbeddingNetworkTest, IdentityThroughSingleLayer) {
+  std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto);
+
+  // A single hidden layer, i.e. the smallest network the constructor accepts.
+  DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/1,
+                     proto->add_hidden(), proto->add_hidden_bias());
+  DiagonalAndBias3x3(/*diagonal_value=*/1, /*bias_value=*/5,
+                     proto->mutable_softmax(), proto->mutable_softmax_bias());
+
+  EmbeddingNetworkParamsFromProto params(std::move(proto));
+  TestingEmbeddingNetwork network(&params);
+
+  // {-2, -1, 0} + 1 = {-1, 0, 1}; ReLU gives {0, 0, 1}; + 5 = {5, 5, 6}.
+  std::vector<float> input({-2, -1, 0});
+  std::vector<float> output;
+  ASSERT_TRUE(network.FinishComputeFinalScoresInternal<SimpleAdder>(
+      VectorSpan<float>(input), &output));
+
+  EXPECT_THAT(output, ElementsAreArray({5, 5, 6}));
+}
+
+}  // namespace
+}  // namespace nlp_core
+}  // namespace libtextclassifier