Makes inference efficient. am: 6bb39a8ec5
am: 49574bd039

Change-Id: I4085bf36469571115a0566024cfe8e9a98b603ba
diff --git a/common/embedding-network.cc b/common/embedding-network.cc
index 61f7323..a17e082 100644
--- a/common/embedding-network.cc
+++ b/common/embedding-network.cc
@@ -48,8 +48,8 @@
   // Before we access the weights as floats, we need to check that they are
   // really floats, i.e., no quantization is used.
   if (!CheckNoQuantization(source_matrix)) return false;
-  const float *weights = reinterpret_cast<const float *>(
-      source_matrix.elements);
+  const float *weights =
+      reinterpret_cast<const float *>(source_matrix.elements);
   for (int r = 0; r < source_matrix.rows; ++r) {
     (*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols);
     weights += source_matrix.cols;
@@ -86,7 +86,7 @@
 bool SparseReluProductPlusBias(bool apply_relu,
                                const EmbeddingNetwork::Matrix &weights,
                                const EmbeddingNetwork::VectorWrapper &b,
-                               const EmbeddingNetwork::Vector &x,
+                               const VectorSpan<float> &x,
                                EmbeddingNetwork::Vector *y) {
   // Check that dimensions match.
   if ((x.size() != weights.size()) || weights.empty()) {
@@ -131,87 +131,127 @@
 
   // "es_index" stands for "embedding space index".
   for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
-    // Access is safe by es_index loop bounds, Invariant 2, and Invariant 3.
-    const int concat_offset = concat_offset_[es_index];
-
     // Access is safe by es_index loop bounds and Invariant 3.
-    const EmbeddingMatrix *embedding_matrix =
+    EmbeddingMatrix *const embedding_matrix =
         embedding_matrices_[es_index].get();
     if (embedding_matrix == nullptr) {
       // Should not happen, hence our terse log error message.
       TC_LOG(ERROR) << es_index;
       return false;
     }
-    const int embedding_dim = embedding_matrix->dim();
-    const bool is_quantized =
-        embedding_matrix->quant_type() != QuantizationType::NONE;
 
     // Access is safe due to es_index loop bounds.
     const FeatureVector &feature_vector = feature_vectors[es_index];
-    const int num_features = feature_vector.size();
-    for (int fi = 0; fi < num_features; ++fi) {
-      // Both accesses below are safe due to loop bounds for fi.
-      const FeatureType *feature_type = feature_vector.type(fi);
-      const FeatureValue feature_value = feature_vector.value(fi);
-      const int feature_offset =
-          concat_offset + feature_type->base() * embedding_dim;
 
-      // Code below updates max(0, embedding_dim) elements from concat, starting
-      // with index feature_offset.  Check below ensures these updates are safe.
-      if ((feature_offset < 0) ||
-          (feature_offset + embedding_dim > concat->size())) {
-        TC_LOG(ERROR) << es_index << "," << fi << ": " << feature_offset << " "
-                      << embedding_dim << " " << concat->size();
-        return false;
+    // Access is safe by es_index loop bounds, Invariant 2, and Invariant 3.
+    const int concat_offset = concat_offset_[es_index];
+
+    if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset,
+                              concat->data(), concat->size())) {
+      TC_LOG(ERROR) << es_index;
+      return false;
+    }
+  }
+  return true;
+}
+
+bool EmbeddingNetwork::GetEmbedding(const FeatureVector &feature_vector,
+                                    int es_index, float *embedding) const {
+  EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get();
+  if (embedding_matrix == nullptr) {
+    // Should not happen, hence our terse log error message.
+    TC_LOG(ERROR) << es_index;
+    return false;
+  }
+  return GetEmbeddingInternal(feature_vector, embedding_matrix, 0, embedding,
+                              embedding_matrices_[es_index]->dim());
+}
+
+bool EmbeddingNetwork::GetEmbeddingInternal(
+    const FeatureVector &feature_vector,
+    EmbeddingMatrix *const embedding_matrix, const int concat_offset,
+    float *concat, int concat_size) const {
+  const int embedding_dim = embedding_matrix->dim();
+  const bool is_quantized =
+      embedding_matrix->quant_type() != QuantizationType::NONE;
+  const int num_features = feature_vector.size();
+  for (int fi = 0; fi < num_features; ++fi) {
+    // Both accesses below are safe due to loop bounds for fi.
+    const FeatureType *feature_type = feature_vector.type(fi);
+    const FeatureValue feature_value = feature_vector.value(fi);
+    const int feature_offset =
+        concat_offset + feature_type->base() * embedding_dim;
+
+    // Code below updates max(0, embedding_dim) elements from concat, starting
+    // with index feature_offset.  Check below ensures these updates are safe.
+    if ((feature_offset < 0) ||
+        (feature_offset + embedding_dim > concat_size)) {
+      TC_LOG(ERROR) << fi << ": " << feature_offset << " " << embedding_dim
+                    << " " << concat_size;
+      return false;
+    }
+
+    // Pointer to float / uint8 weights for relevant embedding.
+    const void *embedding_data;
+
+    // Multiplier for each embedding weight.
+    float multiplier;
+
+    if (feature_type->is_continuous()) {
+      // Continuous features (encoded as FloatFeatureValue).
+      FloatFeatureValue float_feature_value(feature_value);
+      const int id = float_feature_value.id;
+      embedding_matrix->get_embedding(id, &embedding_data, &multiplier);
+      multiplier *= float_feature_value.weight;
+    } else {
+      // Discrete features: every present feature has implicit value 1.0.
+      // Hence, after we grab the multiplier below, we don't multiply it by
+      // any weight.
+      embedding_matrix->get_embedding(feature_value, &embedding_data,
+                                      &multiplier);
+    }
+
+    // Weighted embeddings will be added starting from this address.
+    float *concat_ptr = concat + feature_offset;
+
+    if (is_quantized) {
+      const uint8 *quant_weights =
+          reinterpret_cast<const uint8 *>(embedding_data);
+      for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
+        // 128 is bias for UINT8 quantization, only one we currently support.
+        *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
       }
-
-      // Pointer to float / uint8 weights for relevant embedding.
-      const void *embedding_data;
-
-      // Multiplier for each embedding weight.
-      float multiplier;
-
-      if (feature_type->is_continuous()) {
-        // Continuous features (encoded as FloatFeatureValue).
-        FloatFeatureValue float_feature_value(feature_value);
-        const int id = float_feature_value.id;
-        embedding_matrix->get_embedding(id, &embedding_data, &multiplier);
-        multiplier *= float_feature_value.weight;
-      } else {
-        // Discrete features: every present feature has implicit value 1.0.
-        // Hence, after we grab the multiplier below, we don't multiply it by
-        // any weight.
-        embedding_matrix->get_embedding(feature_value, &embedding_data,
-                                        &multiplier);
-      }
-
-      // Weighted embeddings will be added starting from this address.
-      float *concat_ptr = concat->data() + feature_offset;
-
-      if (is_quantized) {
-        const uint8 *quant_weights =
-            reinterpret_cast<const uint8 *>(embedding_data);
-        for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
-          // 128 is bias for UINT8 quantization, only one we currently support.
-          *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
-        }
-      } else {
-        const float *weights = reinterpret_cast<const float *>(embedding_data);
-        for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
-          *concat_ptr += *weights * multiplier;
-        }
+    } else {
+      const float *weights = reinterpret_cast<const float *>(embedding_data);
+      for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
+        *concat_ptr += *weights * multiplier;
       }
     }
   }
   return true;
 }
 
+bool EmbeddingNetwork::ComputeLogits(const VectorSpan<float> &input,
+                                     Vector *scores) const {
+  return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
+}
+
+bool EmbeddingNetwork::ComputeLogits(const Vector &input,
+                                     Vector *scores) const {
+  return EmbeddingNetwork::ComputeLogitsInternal(input, scores);
+}
+
+bool EmbeddingNetwork::ComputeLogitsInternal(const VectorSpan<float> &input,
+                                             Vector *scores) const {
+  return FinishComputeFinalScoresInternal<SimpleAdder>(input, scores);
+}
+
 template <typename ScaleAdderClass>
-bool EmbeddingNetwork::FinishComputeFinalScores(const Vector &concat,
-                                                Vector *scores) const {
+bool EmbeddingNetwork::FinishComputeFinalScoresInternal(
+    const VectorSpan<float> &input, Vector *scores) const {
   Vector h0(hidden_bias_[0].size());
   bool success = SparseReluProductPlusBias<ScaleAdderClass>(
-      false, hidden_weights_[0], hidden_bias_[0], concat, &h0);
+      false, hidden_weights_[0], hidden_bias_[0], input, &h0);
   if (!success) return false;
 
   if (hidden_weights_.size() == 1) {  // 1 hidden layer
@@ -259,7 +299,7 @@
   }
 
   scores->resize(softmax_bias_.size());
-  return FinishComputeFinalScores<SimpleAdder>(concat, scores);
+  return ComputeLogits(concat, scores);
 }
 
 EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model) {
@@ -331,5 +371,9 @@
   valid_ = true;
 }
 
+int EmbeddingNetwork::EmbeddingSize(int es_index) const {
+  return embedding_matrices_[es_index]->dim();
+}
+
 }  // namespace nlp_core
 }  // namespace libtextclassifier
diff --git a/common/embedding-network.h b/common/embedding-network.h
index 95b4d58..594f34c 100644
--- a/common/embedding-network.h
+++ b/common/embedding-network.h
@@ -22,6 +22,7 @@
 
 #include "common/embedding-network-params.h"
 #include "common/feature-extractor.h"
+#include "common/vector-span.h"
 #include "util/base/integral_types.h"
 #include "util/base/logging.h"
 #include "util/base/macros.h"
@@ -54,8 +55,7 @@
           quant_type_(source_matrix.quant_type),
           data_(source_matrix.elements),
           row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)),
-          quant_scales_(source_matrix.quant_scales) {
-    }
+          quant_scales_(source_matrix.quant_scales) {}
 
     // Returns vocabulary size; one embedding for each vocabulary element.
     int size() const { return rows_; }
@@ -133,8 +133,7 @@
     // at address data.  Note: the underlying data should be alive for at least
     // the lifetime of this VectorWrapper object.  That's trivially true if data
     // points to statically allocated data :)
-    VectorWrapper(const float *data, int size)
-        : data_(data), size_(size) {}
+    VectorWrapper(const float *data, int size) : data_(data), size_(size) {}
 
     int size() const { return size_; }
 
@@ -176,17 +175,48 @@
                           const std::vector<float> extra_inputs,
                           Vector *scores) const;
 
- private:
-  // Computes the softmax scores (prior to normalization) from the concatenated
-  // representation.  Returns true on success, false on error.
-  template <typename ScaleAdderClass>
-  bool FinishComputeFinalScores(const Vector &concat, Vector *scores) const;
-
   // Constructs the concatenated input embedding vector in place in output
   // vector concat.  Returns true on success, false on error.
   bool ConcatEmbeddings(const std::vector<FeatureVector> &features,
                         Vector *concat) const;
 
+  // Sums embeddings for all features from |feature_vector| and adds the result
+  // to the values in the array pointed to by |embedding|.  Embeddings for
+  // continuous features are weighted by the feature weight.
+  //
+  // NOTE: |embedding| must point to an array of EmbeddingSize(es_index) floats.
+  bool GetEmbedding(const FeatureVector &feature_vector, int es_index,
+                    float *embedding) const;
+
+  // Runs the feed-forward neural network for |input| and computes logits for
+  // softmax layer.
+  bool ComputeLogits(const Vector &input, Vector *scores) const;
+
+  // Same as above but uses a view of the feature vector.
+  bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const;
+
+  // Returns the size (the number of columns) of the embedding space es_index.
+  int EmbeddingSize(int es_index) const;
+
+ private:
+  // Adds the embeddings for |feature_vector| into the |concat| array, at
+  // offsets relative to |concat_offset|.  Returns false if out of range.
+  bool GetEmbeddingInternal(const FeatureVector &feature_vector,
+                            EmbeddingMatrix *embedding_matrix,
+                            int concat_offset, float *concat,
+                            int embedding_size) const;
+
+  // Computes the logit scores given the concatenated input embeddings.  This is
+  // the non-template helper shared by both public ComputeLogits() overloads.
+  bool ComputeLogitsInternal(const VectorSpan<float> &concat,
+                             Vector *scores) const;
+
+  // Computes the softmax scores (prior to normalization) from the concatenated
+  // representation.  Returns true on success, false on error.
+  template <typename ScaleAdderClass>
+  bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat,
+                                        Vector *scores) const;
+
   // Set to true on successful construction, false otherwise.
   bool valid_ = false;
 
diff --git a/common/vector-span.h b/common/vector-span.h
new file mode 100644
index 0000000..d7fbfe9
--- /dev/null
+++ b/common/vector-span.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
+#define LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
+
+#include <vector>
+
+namespace libtextclassifier {
+
+// StringPiece analogue for std::vector<T>.
+template <class T>
+class VectorSpan {
+ public:
+  VectorSpan() : begin_(), end_() {}
+  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
+      : begin_(v.begin()), end_(v.end()) {}
+  VectorSpan(typename std::vector<T>::const_iterator begin,
+             typename std::vector<T>::const_iterator end)
+      : begin_(begin), end_(end) {}
+
+  const T& operator[](typename std::vector<T>::size_type i) const {
+    return *(begin_ + i);
+  }
+
+  int size() const { return end_ - begin_; }
+  typename std::vector<T>::const_iterator begin() const { return begin_; }
+  typename std::vector<T>::const_iterator end() const { return end_; }
+
+ private:
+  typename std::vector<T>::const_iterator begin_;
+  typename std::vector<T>::const_iterator end_;
+};
+
+}  // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
diff --git a/smartselect/cached-features.cc b/smartselect/cached-features.cc
new file mode 100644
index 0000000..c249db9
--- /dev/null
+++ b/smartselect/cached-features.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/cached-features.h"
+#include "util/base/logging.h"
+
+namespace libtextclassifier {
+
+void CachedFeatures::Extract(
+    const std::vector<std::vector<int>>& sparse_features,
+    const std::vector<std::vector<float>>& dense_features,
+    const std::function<bool(const std::vector<int>&, const std::vector<float>&,
+                             float*)>& feature_vector_fn) {
+  features_.resize(feature_vector_size_ * tokens_.size());
+  for (int i = 0; i < tokens_.size(); ++i) {
+    feature_vector_fn(sparse_features[i], dense_features[i],
+                      features_.data() + i * feature_vector_size_);
+  }
+}
+
+bool CachedFeatures::Get(int click_pos, VectorSpan<float>* features,
+                         VectorSpan<Token>* output_tokens) {
+  const int token_start = click_pos - context_size_;
+  const int token_end = click_pos + context_size_ + 1;
+  if (token_start < 0 || token_end > tokens_.size()) {
+    TC_LOG(ERROR) << "Tokens out of range: " << token_start << " " << token_end;
+    return false;
+  }
+
+  *features =
+      VectorSpan<float>(features_.begin() + token_start * feature_vector_size_,
+                        features_.begin() + token_end * feature_vector_size_);
+  *output_tokens = VectorSpan<Token>(tokens_.begin() + token_start,
+                                     tokens_.begin() + token_end);
+  if (remap_v0_feature_vector_) {
+    RemapV0FeatureVector(features);
+  }
+
+  return true;
+}
+
+void CachedFeatures::RemapV0FeatureVector(VectorSpan<float>* features) {
+  if (!remap_v0_feature_vector_) {
+    return;
+  }
+
+  auto it = features->begin();
+  int num_suffix_features =
+      feature_vector_size_ - remap_v0_chargram_embedding_size_;
+  int num_tokens = context_size_ * 2 + 1;
+  for (int t = 0; t < num_tokens; ++t) {
+    for (int i = 0; i < remap_v0_chargram_embedding_size_; ++i) {
+      v0_feature_storage_[t * remap_v0_chargram_embedding_size_ + i] = *it;
+      ++it;
+    }
+    // The remaining features are dense features, which go at the end.
+    for (int i = 0; i < num_suffix_features; ++i) {
+      // clang-format off
+      v0_feature_storage_[num_tokens * remap_v0_chargram_embedding_size_
+                      + t * num_suffix_features
+                      + i] = *it;
+      // clang-format on
+      ++it;
+    }
+  }
+  *features = VectorSpan<float>(v0_feature_storage_);
+}
+
+}  // namespace libtextclassifier
diff --git a/smartselect/cached-features.h b/smartselect/cached-features.h
new file mode 100644
index 0000000..6490748
--- /dev/null
+++ b/smartselect/cached-features.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
+
+#include <memory>
+#include <vector>
+
+#include "base.h"
+#include "common/vector-span.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
+// Holds state for extracting features across multiple calls and reusing them.
+// Assumes that features for each Token are independent.
+class CachedFeatures {
+ public:
+  // Extracts the features for the given sequence of tokens.
+  //  - context_size: Number of tokens the context spans to the left of the
+  //                   click, and likewise to the right of it.
+  //  - sparse_features, dense_features: Extracted features for each token.
+  //  - feature_vector_fn: Writes features for given Token to the specified
+  //                       storage.
+  //                       NOTE: The function can assume that the underlying
+  //                       storage is initialized to all zeros.
+  //  - feature_vector_size: Size of a feature vector for one Token.
+  CachedFeatures(VectorSpan<Token> tokens, int context_size,
+                 const std::vector<std::vector<int>>& sparse_features,
+                 const std::vector<std::vector<float>>& dense_features,
+                 const std::function<bool(const std::vector<int>&,
+                                          const std::vector<float>&, float*)>&
+                     feature_vector_fn,
+                 int feature_vector_size)
+      : tokens_(tokens),
+        context_size_(context_size),
+        feature_vector_size_(feature_vector_size),
+        remap_v0_feature_vector_(false),
+        remap_v0_chargram_embedding_size_(-1) {
+    Extract(sparse_features, dense_features, feature_vector_fn);
+  }
+
+  // Gets a VectorSpan with the features for given click position.
+  bool Get(int click_pos, VectorSpan<float>* features,
+           VectorSpan<Token>* output_tokens);
+
+  // Turns on a compatibility mode, which re-maps the extracted features to the
+  // v0 feature format (where the dense features were at the end).
+  // WARNING: Internally v0_feature_storage_ is used as a backing buffer for
+  // VectorSpan<float>, so the output of Get is valid only until the next call
+  // to Get or destruction of the current CachedFeatures object.
+  // TODO(zilka): Remove when we'll have retrained models.
+  void SetV0FeatureMode(int chargram_embedding_size) {
+    remap_v0_feature_vector_ = true;
+    remap_v0_chargram_embedding_size_ = chargram_embedding_size;
+    v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1));
+  }
+
+ protected:
+  // Extracts features for all tokens and stores them for later retrieval.
+  void Extract(const std::vector<std::vector<int>>& sparse_features,
+               const std::vector<std::vector<float>>& dense_features,
+               const std::function<bool(const std::vector<int>&,
+                                        const std::vector<float>&, float*)>&
+                   feature_vector_fn);
+
+  // Remaps extracted features to V0 feature format. The mapping is using
+  // the v0_feature_storage_ as the backing storage for the mapped features.
+  // For each token the features consist of:
+  //  - chargram embeddings
+  //  - dense features
+  // They are concatenated together as [chargram embeddings; dense features]
+  // for each token independently.
+  // The V0 features require that the chargram embeddings for tokens are
+  // concatenated first together, and at the end, the dense features for the
+  // tokens are concatenated to it.
+  void RemapV0FeatureVector(VectorSpan<float>* features);
+
+ private:
+  const VectorSpan<Token> tokens_;
+  const int context_size_;
+  const int feature_vector_size_;
+  bool remap_v0_feature_vector_;
+  int remap_v0_chargram_embedding_size_;
+
+  std::vector<float> features_;
+  std::vector<float> v0_feature_storage_;
+};
+
+}  // namespace libtextclassifier
+
+#endif  // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 0ac7bd3..0ba25ca 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -93,8 +93,7 @@
       for (const auto& split_point : split_points) {
         Token new_token(token_word.UTF8Substring(last_start, split_point),
                         current_pos,
-                        current_pos + std::distance(last_start, split_point),
-                        /*is_in_span=*/false);
+                        current_pos + std::distance(last_start, split_point));
 
         last_start = split_point;
         current_pos = new_token.end;
@@ -169,7 +168,13 @@
 
 }  // namespace internal
 
-const char* const FeatureProcessor::kFeatureTypeName = "chargram_continuous";
+std::string FeatureProcessor::GetDefaultCollection() const {
+  if (options_.default_collection() >= options_.collections_size()) {
+    TC_LOG(ERROR) << "No collections specified. Returning empty string.";
+    return "";
+  }
+  return options_.collections(options_.default_collection());
+}
 
 std::vector<Token> FeatureProcessor::Tokenize(
     const std::string& utf8_text) const {
@@ -177,7 +182,7 @@
 }
 
 bool FeatureProcessor::LabelToSpan(
-    const int label, const std::vector<Token>& tokens,
+    const int label, const VectorSpan<Token>& tokens,
     std::pair<CodepointIndex, CodepointIndex>* span) const {
   if (tokens.size() != GetNumContextTokens()) {
     return false;
@@ -283,7 +288,8 @@
   TokenIndex end_token = kInvalidIndex;
   for (int i = 0; i < selectable_tokens.size(); ++i) {
     if (codepoint_start <= selectable_tokens[i].start &&
-        codepoint_end >= selectable_tokens[i].end) {
+        codepoint_end >= selectable_tokens[i].end &&
+        !selectable_tokens[i].is_padding) {
       if (start_token == kInvalidIndex) {
         start_token = i;
       }
@@ -386,194 +392,16 @@
   }
 }
 
-bool FeatureProcessor::GetFeatures(
-    const std::string& context, CodepointSpan input_span,
-    std::vector<nlp_core::FeatureVector>* features,
-    std::vector<float>* extra_features,
+bool FeatureProcessor::SelectionLabelSpans(
+    const VectorSpan<Token> tokens,
     std::vector<CodepointSpan>* selection_label_spans) const {
-  return FeatureProcessor::GetFeaturesAndLabels(
-      context, input_span, {kInvalidIndex, kInvalidIndex}, "", features,
-      extra_features, selection_label_spans, /*selection_label=*/nullptr,
-      /*selection_codepoint_label=*/nullptr, /*classification_label=*/nullptr);
-}
-
-bool FeatureProcessor::GetFeaturesAndLabels(
-    const std::string& context, CodepointSpan input_span,
-    CodepointSpan label_span, const std::string& label_collection,
-    std::vector<nlp_core::FeatureVector>* features,
-    std::vector<float>* extra_features,
-    std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
-    CodepointSpan* selection_codepoint_label, int* classification_label) const {
-  if (features == nullptr) {
-    return false;
-  }
-  *features =
-      std::vector<nlp_core::FeatureVector>(options_.context_size() * 2 + 1);
-
-  std::vector<Token> input_tokens = Tokenize(context);
-
-  if (options_.split_tokens_on_selection_boundaries()) {
-    internal::SplitTokensOnSelectionBoundaries(input_span, &input_tokens);
-  }
-
-  if (options_.only_use_line_with_click()) {
-    internal::StripTokensFromOtherLines(context, input_span, &input_tokens);
-  }
-
-  const int click_pos = FindCenterToken(input_span, input_tokens);
-  if (click_pos == kInvalidIndex) {
-    TC_LOG(ERROR) << "Could not extract click position.";
-    return false;
-  }
-
-  if (options_.min_supported_codepoint_ratio() > 0) {
-    const float supported_codepoint_ratio =
-        SupportedCodepointsRatio(click_pos, input_tokens);
-    if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
-      TC_LOG(INFO) << "Not enough supported codepoints in the context: "
-                   << supported_codepoint_ratio;
+  for (int i = 0; i < label_to_selection_.size(); ++i) {
+    CodepointSpan span;
+    if (!LabelToSpan(i, tokens, &span)) {
+      TC_LOG(ERROR) << "Could not convert label to span: " << i;
       return false;
     }
-  }
-
-  std::vector<Token> output_tokens;
-  bool status = ComputeFeatures(click_pos, input_tokens, input_span, features,
-                                extra_features, &output_tokens);
-  if (!status) {
-    TC_LOG(ERROR) << "Feature computation failed.";
-    return false;
-  }
-
-  if (selection_label != nullptr) {
-    status = SpanToLabel(label_span, output_tokens, selection_label);
-    if (!status) {
-      TC_LOG(ERROR) << "Could not convert selection span to label.";
-      return false;
-    }
-  }
-
-  if (selection_codepoint_label != nullptr) {
-    *selection_codepoint_label = label_span;
-  }
-
-  if (selection_label_spans != nullptr) {
-    for (int i = 0; i < label_to_selection_.size(); ++i) {
-      CodepointSpan span;
-      status = LabelToSpan(i, output_tokens, &span);
-      if (!status) {
-        TC_LOG(ERROR) << "Could not convert label to span: " << i;
-        return false;
-      }
-      selection_label_spans->push_back(span);
-    }
-  }
-
-  if (classification_label != nullptr) {
-    *classification_label = CollectionToLabel(label_collection);
-  }
-
-  return true;
-}
-
-bool FeatureProcessor::GetFeaturesAndLabels(
-    const std::string& context, CodepointSpan input_span,
-    CodepointSpan label_span, const std::string& label_collection,
-    std::vector<std::vector<std::pair<int, float>>>* features,
-    std::vector<float>* extra_features,
-    std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
-    CodepointSpan* selection_codepoint_label, int* classification_label) const {
-  if (features == nullptr) {
-    return false;
-  }
-  if (extra_features == nullptr) {
-    return false;
-  }
-
-  std::vector<nlp_core::FeatureVector> feature_vectors;
-  bool result = GetFeaturesAndLabels(
-      context, input_span, label_span, label_collection, &feature_vectors,
-      extra_features, selection_label_spans, selection_label,
-      selection_codepoint_label, classification_label);
-
-  if (!result) {
-    return false;
-  }
-
-  features->clear();
-  for (int i = 0; i < feature_vectors.size(); ++i) {
-    features->emplace_back();
-    for (int j = 0; j < feature_vectors[i].size(); ++j) {
-      nlp_core::FloatFeatureValue feature_value(feature_vectors[i].value(j));
-      (*features)[i].push_back({feature_value.id, feature_value.weight});
-    }
-  }
-
-  return true;
-}
-
-bool FeatureProcessor::ComputeFeatures(
-    int click_pos, const std::vector<Token>& tokens,
-    CodepointSpan selected_span, std::vector<nlp_core::FeatureVector>* features,
-    std::vector<float>* extra_features,
-    std::vector<Token>* output_tokens) const {
-  int dropout_left = 0;
-  int dropout_right = 0;
-  if (options_.context_dropout_probability() > 0.0) {
-    // Determine how much context to drop.
-    bool status = GetContextDropoutRange(&dropout_left, &dropout_right);
-    if (!status) {
-      return false;
-    }
-  }
-
-  int feature_index = 0;
-  output_tokens->reserve(options_.context_size() * 2 + 1);
-  const int num_extra_features =
-      static_cast<int>(options_.extract_case_feature()) +
-      static_cast<int>(options_.extract_selection_mask_feature());
-  extra_features->reserve((options_.context_size() * 2 + 1) *
-                          num_extra_features);
-  for (int i = click_pos - options_.context_size();
-       i <= click_pos + options_.context_size(); ++i, ++feature_index) {
-    std::vector<int> sparse_features;
-    std::vector<float> dense_features;
-
-    const bool is_valid_token = i >= 0 && i < tokens.size();
-
-    bool is_dropped = false;
-    if (options_.context_dropout_probability() > 0.0) {
-      if (i < click_pos - options_.context_size() + dropout_left) {
-        is_dropped = true;
-      } else if (i > click_pos + options_.context_size() - dropout_right) {
-        is_dropped = true;
-      }
-    }
-
-    if (is_valid_token && !is_dropped) {
-      Token token(tokens[i]);
-      token.is_in_span = token.start >= selected_span.first &&
-                         token.end <= selected_span.second;
-      feature_extractor_.Extract(token, &sparse_features, &dense_features);
-      output_tokens->push_back(tokens[i]);
-    } else {
-      feature_extractor_.Extract(Token(), &sparse_features, &dense_features);
-      // This adds an empty string for each missing context token to exactly
-      // match the input tokens to the network.
-      output_tokens->emplace_back();
-    }
-
-    for (int feature_id : sparse_features) {
-      const int64 feature_value =
-          nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
-              .discrete_value;
-      (*features)[feature_index].add(
-          const_cast<nlp_core::NumericFeatureType*>(&feature_type_),
-          feature_value);
-    }
-
-    for (float value : dense_features) {
-      extra_features->push_back(value);
-    }
+    selection_label_spans->push_back(span);
   }
   return true;
 }
@@ -680,25 +508,129 @@
   }
 }
 
-bool FeatureProcessor::GetContextDropoutRange(int* dropout_left,
-                                              int* dropout_right) const {
-  std::uniform_real_distribution<> uniform01_draw(0, 1);
-  if (uniform01_draw(*random_) < options_.context_dropout_probability()) {
-    if (options_.use_variable_context_dropout()) {
-      std::uniform_int_distribution<> uniform_context_draw(
-          0, options_.context_size());
-      // Select how much to drop in the range: [zero; context size]
-      *dropout_left = uniform_context_draw(*random_);
-      *dropout_right = uniform_context_draw(*random_);
-    } else {
-      // Drop all context.
+void FeatureProcessor::TokenizeAndFindClick(const std::string& context,
+                                            CodepointSpan input_span,
+                                            std::vector<Token>* tokens,
+                                            int* click_pos) const {
+  TC_CHECK(tokens != nullptr);
+  *tokens = Tokenize(context);
+
+  if (options_.split_tokens_on_selection_boundaries()) {
+    internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
+  }
+
+  if (options_.only_use_line_with_click()) {
+    internal::StripTokensFromOtherLines(context, input_span, tokens);
+  }
+
+  int local_click_pos;
+  if (click_pos == nullptr) {
+    click_pos = &local_click_pos;
+  }
+  *click_pos = FindCenterToken(input_span, *tokens);
+}
+
+namespace internal {
+
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+                      std::vector<Token>* tokens, int* click_pos) {
+  int right_context_needed = relative_click_span.second + context_size;
+  if (*click_pos + right_context_needed + 1 >= tokens->size()) {
+    // Pad by at most the context size.
+    const int num_pad_tokens = std::min(
+        context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
+                                       tokens->size()));
+    std::vector<Token> pad_tokens(num_pad_tokens);
+    tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
+  } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
+    // Strip unused tokens.
+    auto it = tokens->begin();
+    std::advance(it, *click_pos + right_context_needed + 1);
+    tokens->erase(it, tokens->end());
+  }
+
+  int left_context_needed = relative_click_span.first + context_size;
+  if (*click_pos < left_context_needed) {
+    // Pad by at most the context size.
+    const int num_pad_tokens =
+        std::min(context_size, left_context_needed - *click_pos);
+    std::vector<Token> pad_tokens(num_pad_tokens);
+    tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
+    *click_pos += num_pad_tokens;
+  } else if (*click_pos > left_context_needed) {
+    // Strip unused tokens.
+    auto it = tokens->begin();
+    std::advance(it, *click_pos - left_context_needed);
+    *click_pos -= it - tokens->begin();
+    tokens->erase(tokens->begin(), it);
+  }
+}
+
+}  // namespace internal
+
+bool FeatureProcessor::ExtractFeatures(
+    const std::string& context, CodepointSpan input_span,
+    TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn,
+    int feature_vector_size, std::vector<Token>* tokens, int* click_pos,
+    std::unique_ptr<CachedFeatures>* cached_features) const {
+  TokenizeAndFindClick(context, input_span, tokens, click_pos);
+
+  // If the default click method fails, let's try to do sub-token
+  // matching before we fail.
+  if (*click_pos == kInvalidIndex) {
+    *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+    if (*click_pos == kInvalidIndex) {
       return false;
     }
-  } else {
-    *dropout_left = 0;
-    *dropout_right = 0;
   }
+
+  internal::StripOrPadTokens(relative_click_span, options_.context_size(),
+                             tokens, click_pos);
+
+  if (options_.min_supported_codepoint_ratio() > 0) {
+    const float supported_codepoint_ratio =
+        SupportedCodepointsRatio(*click_pos, *tokens);
+    if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
+      TC_LOG(INFO) << "Not enough supported codepoints in the context: "
+                   << supported_codepoint_ratio;
+      return false;
+    }
+  }
+
+  std::vector<std::vector<int>> sparse_features(tokens->size());
+  std::vector<std::vector<float>> dense_features(tokens->size());
+  for (int i = 0; i < tokens->size(); ++i) {
+    const Token& token = (*tokens)[i];
+    if (!feature_extractor_.Extract(token, token.IsContainedInSpan(input_span),
+                                    &(sparse_features[i]),
+                                    &(dense_features[i]))) {
+      TC_LOG(ERROR) << "Could not extract token's features: " << token;
+      return false;
+    }
+  }
+
+  cached_features->reset(new CachedFeatures(
+      *tokens, options_.context_size(), sparse_features, dense_features,
+      feature_vector_fn, feature_vector_size));
+
+  if (*cached_features == nullptr) {
+    return false;
+  }
+
+  if (options_.feature_version() == 0) {
+    (*cached_features)
+        ->SetV0FeatureMode(feature_vector_size -
+                           feature_extractor_.DenseFeaturesCount());
+  }
+
   return true;
 }
 
+int FeatureProcessor::PadContext(std::vector<Token>* tokens) const {
+  std::vector<Token> pad_tokens(options_.context_size());
+  tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
+  tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
+  return options_.context_size();
+}
+
 }  // namespace libtextclassifier
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 85043e3..2f1e530 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -24,16 +24,25 @@
 #include <string>
 #include <vector>
 
-#include "common/feature-extractor.h"
+#include "smartselect/cached-features.h"
 #include "smartselect/text-classification-model.pb.h"
 #include "smartselect/token-feature-extractor.h"
 #include "smartselect/tokenizer.h"
 #include "smartselect/types.h"
+#include "util/base/logging.h"
 
 namespace libtextclassifier {
 
 constexpr int kInvalidLabel = -1;
 
+// Maps a vector of sparse features and a vector of dense features to a vector
+// of features that combines both.
+// The output is written to the memory location pointed to by the last float*
+// argument.
+// Returns true on success, false on failure.
+using FeatureVectorFn = std::function<bool(const std::vector<int>&,
+                                           const std::vector<float>&, float*)>;
+
 namespace internal {
 
 // Parses the serialized protocol buffer.
@@ -61,23 +70,27 @@
 int CenterTokenFromMiddleOfSelection(
     CodepointSpan span, const std::vector<Token>& selectable_tokens);
 
+// Strips the tokens from the tokens vector that are not used for feature
+// extraction because they are out of scope, or pads them so that there are
+// enough tokens in the required context_size for all inferences with a click
+// in relative_click_span.
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+                      std::vector<Token>* tokens, int* click_pos);
+
 }  // namespace internal
 
 TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
                                    CodepointSpan codepoint_span);
 
-// Takes care of preparing features for the FFModel.
+// Takes care of preparing features for the span prediction model.
 class FeatureProcessor {
  public:
   explicit FeatureProcessor(const FeatureProcessorOptions& options)
-      : options_(options),
-        feature_extractor_(
+      : feature_extractor_(
             internal::BuildTokenFeatureExtractorOptions(options)),
-        feature_type_(FeatureProcessor::kFeatureTypeName,
-                      options.num_buckets()),
+        options_(options),
         tokenizer_({options.tokenization_codepoint_config().begin(),
-                    options.tokenization_codepoint_config().end()}),
-        random_(new std::mt19937(std::random_device()())) {
+                    options.tokenization_codepoint_config().end()}) {
     MakeLabelMaps();
     PrepareSupportedCodepointRanges(
         {options.supported_codepoint_ranges().begin(),
@@ -91,38 +104,12 @@
   // Tokenizes the input string using the selected tokenization method.
   std::vector<Token> Tokenize(const std::string& utf8_text) const;
 
-  bool GetFeatures(const std::string& context, CodepointSpan input_span,
-                   std::vector<nlp_core::FeatureVector>* features,
-                   std::vector<float>* extra_features,
-                   std::vector<CodepointSpan>* selection_label_spans) const;
-
-  // NOTE: If dropout is on, subsequent calls of this function with the same
-  // arguments might return different results.
-  bool GetFeaturesAndLabels(const std::string& context,
-                            CodepointSpan input_span, CodepointSpan label_span,
-                            const std::string& label_collection,
-                            std::vector<nlp_core::FeatureVector>* features,
-                            std::vector<float>* extra_features,
-                            std::vector<CodepointSpan>* selection_label_spans,
-                            int* selection_label,
-                            CodepointSpan* selection_codepoint_label,
-                            int* classification_label) const;
-
-  // Same as above but uses std::vector instead of FeatureVector.
-  // NOTE: If dropout is on, subsequent calls of this function with the same
-  // arguments might return different results.
-  bool GetFeaturesAndLabels(
-      const std::string& context, CodepointSpan input_span,
-      CodepointSpan label_span, const std::string& label_collection,
-      std::vector<std::vector<std::pair<int, float>>>* features,
-      std::vector<float>* extra_features,
-      std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
-      CodepointSpan* selection_codepoint_label,
-      int* classification_label) const;
-
   // Converts a label into a token span.
   bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
 
+  // Gets the total number of selection labels.
+  int GetSelectionLabelCount() const { return label_to_selection_.size(); }
+
   // Gets the string value for given collection label.
   std::string LabelToCollection(int label) const;
 
@@ -130,17 +117,35 @@
   int NumCollections() const { return collection_to_label_.size(); }
 
   // Gets the name of the default collection.
-  std::string GetDefaultCollection() const {
-    return options_.collections(options_.default_collection());
+  std::string GetDefaultCollection() const;
+
+  const FeatureProcessorOptions& GetOptions() const { return options_; }
+
+  // Tokenizes the context and input span, and finds the click position.
+  void TokenizeAndFindClick(const std::string& context,
+                            CodepointSpan input_span,
+                            std::vector<Token>* tokens, int* click_pos) const;
+
+  // Extracts features as a CachedFeatures object that can be used for repeated
+  // inference over token spans in the given context.
+  bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
+                       TokenSpan relative_click_span,
+                       const FeatureVectorFn& feature_vector_fn,
+                       int feature_vector_size, std::vector<Token>* tokens,
+                       int* click_pos,
+                       std::unique_ptr<CachedFeatures>* cached_features) const;
+
+  // Fills selection_label_spans with CodepointSpans that correspond to the
+  // selection labels. The CodepointSpans are based on the codepoint ranges of
+  // given tokens.
+  bool SelectionLabelSpans(
+      VectorSpan<Token> tokens,
+      std::vector<CodepointSpan>* selection_label_spans) const;
+
+  int DenseFeaturesCount() const {
+    return feature_extractor_.DenseFeaturesCount();
   }
 
-  FeatureProcessorOptions GetOptions() const { return options_; }
-
-  int GetSelectionLabelCount() const { return label_to_selection_.size(); }
-
-  // Sets the source of randomness.
-  void SetRandom(std::mt19937* new_random) { random_.reset(new_random); }
-
  protected:
   // Represents a codepoint range [start, end).
   struct CodepointRange {
@@ -151,23 +156,6 @@
         : start(arg_start), end(arg_end) {}
   };
 
-  // Extracts features for given word.
-  std::vector<int> GetWordFeatures(const std::string& word) const;
-
-  // NOTE: If dropout is on, subsequent calls of this function with the same
-  // arguments might return different results.
-  bool ComputeFeatures(int click_pos,
-                       const std::vector<Token>& selectable_tokens,
-                       CodepointSpan selected_span,
-                       std::vector<nlp_core::FeatureVector>* features,
-                       std::vector<float>* extra_features,
-                       std::vector<Token>* output_tokens) const;
-
-  // Helper function that computes how much left context and how much right
-  // context should be dropped. Uses a mutable random_ member as a source of
-  // randomness.
-  bool GetContextDropoutRange(int* dropout_left, int* dropout_right) const;
-
   // Returns the class id corresponding to the given string collection
   // identifier. There is a catch-all class id that the function returns for
   // unknown collections.
@@ -185,7 +173,7 @@
 
   // Converts a label into a span of codepoint indices corresponding to it
   // given output_tokens.
-  bool LabelToSpan(int label, const std::vector<Token>& output_tokens,
+  bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
                    CodepointSpan* span) const;
 
   // Converts a span to the corresponding label given output_tokens.
@@ -195,11 +183,6 @@
   // Converts a token span to the corresponding label.
   int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
 
-  // Finds the center token index in tokens vector, using the method defined
-  // in options_.
-  int FindCenterToken(CodepointSpan span,
-                      const std::vector<Token>& tokens) const;
-
   void PrepareSupportedCodepointRanges(
       const std::vector<FeatureProcessorOptions::CodepointRange>&
           codepoint_range_configs);
@@ -212,14 +195,18 @@
   // Returns true if given codepoint is supported.
   bool IsCodepointSupported(int codepoint) const;
 
+  // Finds the center token index in tokens vector, using the method defined
+  // in options_.
+  int FindCenterToken(CodepointSpan span,
+                      const std::vector<Token>& tokens) const;
+
+  // Pads tokens with options.context_size() padding tokens on both sides.
+  int PadContext(std::vector<Token>* tokens) const;
+
+  const TokenFeatureExtractor feature_extractor_;
+
  private:
-  FeatureProcessorOptions options_;
-
-  TokenFeatureExtractor feature_extractor_;
-
-  static const char* const kFeatureTypeName;
-
-  nlp_core::NumericFeatureType feature_type_;
+  const FeatureProcessorOptions options_;
 
   // Mapping between token selection spans and labels ids.
   std::map<TokenSpan, int> selection_to_label_;
@@ -233,9 +220,6 @@
   // Codepoint ranges that define what codepoints are supported by the model.
   // NOTE: Must be sorted.
   std::vector<CodepointRange> supported_codepoint_ranges_;
-
-  // Source of randomness.
-  mutable std::unique_ptr<std::mt19937> random_;
 };
 
 }  // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index b21614d..ca0484f 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -91,6 +91,49 @@
   sharing_options_ = selection_params_->GetSharingModelOptions();
 }
 
+namespace {
+
+// Converts sparse features vector to nlp_core::FeatureVector.
+void SparseFeaturesToFeatureVector(
+    const std::vector<int> sparse_features,
+    const nlp_core::NumericFeatureType& feature_type,
+    nlp_core::FeatureVector* result) {
+  for (int feature_id : sparse_features) {
+    const int64 feature_value =
+        nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
+            .discrete_value;
+    result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type),
+                feature_value);
+  }
+}
+
+// Returns a function that can be used for mapping sparse and dense features
+// to a float feature vector.
+// NOTE: The network object needs to be available at the time when the returned
+// function object is used.
+FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network,
+                                      int sparse_embedding_size) {
+  const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0);
+  return [&network, sparse_embedding_size, feature_type](
+             const std::vector<int>& sparse_features,
+             const std::vector<float>& dense_features, float* embedding) {
+    nlp_core::FeatureVector feature_vector;
+    SparseFeaturesToFeatureVector(sparse_features, feature_type,
+                                  &feature_vector);
+
+    if (network.GetEmbedding(feature_vector, 0, embedding)) {
+      for (int i = 0; i < dense_features.size(); i++) {
+        embedding[sparse_embedding_size + i] = dense_features[i];
+      }
+      return true;
+    } else {
+      return false;
+    }
+  };
+}
+
+}  // namespace
+
 bool TextClassificationModel::LoadModels(int fd) {
   MmapHandle mmap_handle = MmapFile(fd);
   if (!mmap_handle.ok()) {
@@ -111,6 +154,8 @@
   selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
   selection_feature_processor_.reset(
       new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
+  selection_feature_fn_ = CreateFeatureVectorFn(
+      *selection_network_, selection_network_->EmbeddingSize(0));
 
   model_data += selection_model_length;
   uint32 sharing_model_length =
@@ -125,25 +170,39 @@
   sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
   sharing_feature_processor_.reset(
       new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
+  sharing_feature_fn_ = CreateFeatureVectorFn(
+      *sharing_network_, sharing_network_->EmbeddingSize(0));
 
   return true;
 }
 
 EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
     const std::string& context, CodepointSpan span,
-    const FeatureProcessor& feature_processor, const EmbeddingNetwork* network,
+    const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
+    const FeatureVectorFn& feature_vector_fn,
     std::vector<CodepointSpan>* selection_label_spans) const {
-  std::vector<FeatureVector> features;
-  std::vector<float> extra_features;
-  const bool features_computed = feature_processor.GetFeatures(
-      context, span, &features, &extra_features, selection_label_spans);
-
-  EmbeddingNetwork::Vector scores;
-  if (!features_computed) {
-    TC_LOG(ERROR) << "Features not computed";
-    return scores;
+  std::vector<Token> tokens;
+  int click_pos;
+  std::unique_ptr<CachedFeatures> cached_features;
+  int embedding_size = network.EmbeddingSize(0);
+  if (!feature_processor.ExtractFeatures(
+          context, span, /*relative_click_span=*/{0, 0},
+          CreateFeatureVectorFn(network, embedding_size),
+          embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
+          &click_pos, &cached_features)) {
+    TC_LOG(ERROR) << "Could not extract features.";
+    return {};
   }
-  network->ComputeFinalScores(features, extra_features, &scores);
+
+  VectorSpan<float> features;
+  VectorSpan<Token> output_tokens;
+  if (!cached_features->Get(click_pos, &features, &output_tokens)) {
+    TC_LOG(ERROR) << "Could not extract features.";
+    return {};
+  }
+
+  std::vector<float> scores;
+  network.ComputeLogits(features, &scores);
   return scores;
 }
 
@@ -161,6 +220,15 @@
     return click_indices;
   }
 
+  const UnicodeText context_unicode =
+      UTF8ToUnicodeText(context, /*do_copy=*/false);
+  const int context_length =
+      std::distance(context_unicode.begin(), context_unicode.end());
+  if (std::get<0>(click_indices) >= context_length ||
+      std::get<1>(click_indices) > context_length) {
+    return click_indices;
+  }
+
   CodepointSpan result;
   if (selection_options_.enforce_symmetry()) {
     result = SuggestSelectionSymmetrical(context, click_indices);
@@ -176,21 +244,12 @@
   return result;
 }
 
-std::pair<CodepointSpan, float>
-TextClassificationModel::SuggestSelectionInternal(
-    const std::string& context, CodepointSpan click_indices) const {
-  if (!initialized_) {
-    TC_LOG(ERROR) << "Not initialized";
-    return {click_indices, -1.0};
-  }
+namespace {
 
-  std::vector<CodepointSpan> selection_label_spans;
-  EmbeddingNetwork::Vector scores =
-      InferInternal(context, click_indices, *selection_feature_processor_,
-                    selection_network_.get(), &selection_label_spans);
-
+std::pair<CodepointSpan, float> BestSelectionSpan(
+    CodepointSpan original_click_indices, const std::vector<float>& scores,
+    const std::vector<CodepointSpan>& selection_label_spans) {
   if (!scores.empty()) {
-    scores = nlp_core::ComputeSoftmax(scores);
     const int prediction =
         std::max_element(scores.begin(), scores.end()) - scores.begin();
     std::pair<CodepointIndex, CodepointIndex> selection =
@@ -200,15 +259,34 @@
       TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
                     << prediction << " " << selection.first << " "
                     << selection.second;
-      return {click_indices, -1.0};
+      return {original_click_indices, -1.0};
     }
 
     return {{selection.first, selection.second}, scores[prediction]};
   } else {
     TC_LOG(ERROR) << "Returning default selection: scores.size() = "
                   << scores.size();
+    return {original_click_indices, -1.0};
+  }
+}
+
+}  // namespace
+
+std::pair<CodepointSpan, float>
+TextClassificationModel::SuggestSelectionInternal(
+    const std::string& context, CodepointSpan click_indices) const {
+  if (!initialized_) {
+    TC_LOG(ERROR) << "Not initialized";
     return {click_indices, -1.0};
   }
+
+  std::vector<CodepointSpan> selection_label_spans;
+  EmbeddingNetwork::Vector scores = InferInternal(
+      context, click_indices, *selection_feature_processor_,
+      *selection_network_, selection_feature_fn_, &selection_label_spans);
+  scores = nlp_core::ComputeSoftmax(scores);
+
+  return BestSelectionSpan(click_indices, scores, selection_label_spans);
 }
 
 namespace {
@@ -245,27 +323,49 @@
 // selection.
 CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
     const std::string& context, CodepointSpan click_indices) const {
-  std::vector<Token> tokens = selection_feature_processor_->Tokenize(context);
-  internal::StripTokensFromOtherLines(context, click_indices, &tokens);
-
-  // const int click_index = GetClickTokenIndex(tokens, click_indices);
-  const int click_index = internal::CenterTokenFromClick(click_indices, tokens);
-  if (click_index == kInvalidIndex) {
+  const int symmetry_context_size = selection_options_.symmetry_context_size();
+  std::vector<Token> tokens;
+  std::unique_ptr<CachedFeatures> cached_features;
+  int click_index;
+  int embedding_size = selection_network_->EmbeddingSize(0);
+  if (!selection_feature_processor_->ExtractFeatures(
+          context, click_indices, /*relative_click_span=*/
+          {symmetry_context_size, symmetry_context_size + 1},
+          selection_feature_fn_,
+          embedding_size + selection_feature_processor_->DenseFeaturesCount(),
+          &tokens, &click_index, &cached_features)) {
+    TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
     return click_indices;
   }
 
-  const int symmetry_context_size = selection_options_.symmetry_context_size();
-
   // Scan in the symmetry context for selection span proposals.
   std::vector<std::pair<CodepointSpan, float>> proposals;
+
   for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
     const int token_index = click_index + i;
-    if (token_index >= 0 && token_index < tokens.size()) {
+    if (token_index >= 0 && token_index < tokens.size() &&
+        !tokens[token_index].is_padding) {
       float score;
+      VectorSpan<float> features;
+      VectorSpan<Token> output_tokens;
+
       CodepointSpan span;
-      std::tie(span, score) = SuggestSelectionInternal(
-          context, {tokens[token_index].start, tokens[token_index].end});
-      proposals.push_back({span, score});
+      if (cached_features->Get(token_index, &features, &output_tokens)) {
+        std::vector<float> scores;
+        selection_network_->ComputeLogits(features, &scores);
+
+        std::vector<CodepointSpan> selection_label_spans;
+        if (selection_feature_processor_->SelectionLabelSpans(
+                output_tokens, &selection_label_spans)) {
+          scores = nlp_core::ComputeSoftmax(scores);
+          std::tie(span, score) =
+              BestSelectionSpan(click_indices, scores, selection_label_spans);
+          if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
+              score >= 0) {
+            proposals.push_back({span, score});
+          }
+        }
+      }
     }
   }
 
@@ -315,6 +415,13 @@
     return {};
   }
 
+  if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
+    TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
+                  << std::get<0>(selection_indices) << " "
+                  << std::get<1>(selection_indices);
+    return {};
+  }
+
   if (hint_flags & SELECTION_IS_URL &&
       sharing_options_.always_accept_url_hint()) {
     return {{kUrlHintCollection, 1.0}};
@@ -327,7 +434,7 @@
 
   EmbeddingNetwork::Vector scores =
       InferInternal(context, selection_indices, *sharing_feature_processor_,
-                    sharing_network_.get(), nullptr);
+                    *sharing_network_, sharing_feature_fn_, nullptr);
   if (scores.empty()) {
     TC_LOG(ERROR) << "Using default class";
     return {};
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index 9f17ef5..ae2049b 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -88,7 +88,8 @@
   nlp_core::EmbeddingNetwork::Vector InferInternal(
       const std::string& context, CodepointSpan span,
       const FeatureProcessor& feature_processor,
-      const nlp_core::EmbeddingNetwork* network,
+      const nlp_core::EmbeddingNetwork& network,
+      const FeatureVectorFn& feature_vector_fn,
       std::vector<CodepointSpan>* selection_label_spans) const;
 
   // Returns a selection suggestion with a score.
@@ -104,9 +105,11 @@
   std::unique_ptr<ModelParams> selection_params_;
   std::unique_ptr<FeatureProcessor> selection_feature_processor_;
   std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
+  FeatureVectorFn selection_feature_fn_;
   std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
   std::unique_ptr<ModelParams> sharing_params_;
   std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
+  FeatureVectorFn sharing_feature_fn_;
 
   std::set<int> punctuation_to_strip_;
 };
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index ecf0fc6..7a4c9f1 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -159,6 +159,16 @@
   // Minimum ratio of supported codepoints in the input context. If the ratio
   // is lower than this, the feature computation will fail.
   optional float min_supported_codepoint_ratio = 24 [default = 0.0];
+
+  // Used for versioning the format of features the model expects.
+  //  - feature_version == 0:
+  //      For each token the features consist of:
+  //       - chargram embeddings
+  //       - dense features
+  //      Chargram embeddings for tokens are concatenated first together,
+  //      and at the end, the dense features for the tokens are concatenated
+  //      to it. So the resulting feature vector has two regions.
+  optional int32 feature_version = 25 [default = 0];
 };
 
 extend nlp_core.EmbeddingNetworkProto {
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
index 1413385..6013ef3 100644
--- a/smartselect/token-feature-extractor.cc
+++ b/smartselect/token-feature-extractor.cc
@@ -207,7 +207,7 @@
   return result;
 }
 
-bool TokenFeatureExtractor::Extract(const Token& token,
+bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
                                     std::vector<int>* sparse_features,
                                     std::vector<float>* dense_features) const {
   if (sparse_features == nullptr || dense_features == nullptr) {
@@ -235,7 +235,7 @@
   }
 
   if (options_.extract_selection_mask_feature) {
-    if (token.is_in_span) {
+    if (is_in_span) {
       dense_features->push_back(1.0);
     } else {
       if (options_.unicode_aware_features) {
@@ -270,23 +270,4 @@
   return true;
 }
 
-bool TokenFeatureExtractor::Extract(
-    const std::vector<Token>& tokens,
-    std::vector<std::vector<int>>* sparse_features,
-    std::vector<std::vector<float>>* dense_features) const {
-  if (sparse_features == nullptr || dense_features == nullptr) {
-    return false;
-  }
-
-  sparse_features->resize(tokens.size());
-  dense_features->resize(tokens.size());
-  for (size_t i = 0; i < tokens.size(); i++) {
-    if (!Extract(tokens[i], &((*sparse_features)[i]),
-                 &((*dense_features)[i]))) {
-      return false;
-    }
-  }
-  return true;
-}
-
 }  // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
index e606e91..8502199 100644
--- a/smartselect/token-feature-extractor.h
+++ b/smartselect/token-feature-extractor.h
@@ -59,17 +59,20 @@
   explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options);
 
   // Extracts features from a token.
+  //  - is_in_span is a bool indicator whether the token is a part of the
+  //      selection span (true) or not (false).
   //  - sparse_features are indices into a sparse feature vector of size
   //    options.num_buckets which are set to 1.0 (others are implicitly 0.0).
   //  - dense_features are values of a dense feature vector of size 0-2
   //    (depending on the options) for the token
-  bool Extract(const Token& token, std::vector<int>* sparse_features,
+  bool Extract(const Token& token, bool is_in_span,
+               std::vector<int>* sparse_features,
                std::vector<float>* dense_features) const;
 
-  // Convenience method that sequentially applies Extract to each Token.
-  bool Extract(const std::vector<Token>& tokens,
-               std::vector<std::vector<int>>* sparse_features,
-               std::vector<std::vector<float>>* dense_features) const;
+  int DenseFeaturesCount() const {
+    return options_.extract_case_feature +
+           options_.extract_selection_mask_feature + regex_patterns_.size();
+  }
 
  protected:
   // Hashes given token to given number of buckets.
diff --git a/smartselect/types.h b/smartselect/types.h
index 9f07f91..443e3ac 100644
--- a/smartselect/types.h
+++ b/smartselect/types.h
@@ -48,43 +48,31 @@
   CodepointIndex start;
   CodepointIndex end;
 
-  // Whether the token was in the input span.
-  bool is_in_span;
-
   // Whether the token is a padding token.
   bool is_padding;
 
   // Default constructor constructs the padding-token.
   Token()
-      : value(""),
-        start(kInvalidIndex),
-        end(kInvalidIndex),
-        is_in_span(false),
-        is_padding(true) {}
+      : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
 
   Token(const std::string& arg_value, CodepointIndex arg_start,
         CodepointIndex arg_end)
-      : Token(arg_value, arg_start, arg_end, false) {}
-
-  Token(const std::string& arg_value, CodepointIndex arg_start,
-        CodepointIndex arg_end, bool is_in_span)
-      : value(arg_value),
-        start(arg_start),
-        end(arg_end),
-        is_in_span(is_in_span),
-        is_padding(false) {}
+      : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
 
   bool operator==(const Token& other) const {
     return value == other.value && start == other.start && end == other.end &&
-           is_in_span == other.is_in_span && is_padding == other.is_padding;
+           is_padding == other.is_padding;
+  }
+
+  bool IsContainedInSpan(CodepointSpan span) const {
+    return start >= span.first && end <= span.second;
   }
 };
 
 // Pretty-printing function for Token.
 inline std::ostream& operator<<(std::ostream& os, const Token& token) {
   return os << "Token(\"" << token.value << "\", " << token.start << ", "
-            << token.end << ", is_in_span=" << token.is_in_span
-            << ", is_padding=" << token.is_padding << ")";
+            << token.end << ", is_padding=" << token.is_padding << ")";
 }
 
 }  // namespace libtextclassifier
diff --git a/tests/cached-features_test.cc b/tests/cached-features_test.cc
new file mode 100644
index 0000000..b456816
--- /dev/null
+++ b/tests/cached-features_test.cc
@@ -0,0 +1,149 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/cached-features.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace {
+
+class TestingCachedFeatures : public CachedFeatures {
+ public:
+  using CachedFeatures::CachedFeatures;
+  using CachedFeatures::RemapV0FeatureVector;
+};
+
+TEST(CachedFeaturesTest, Simple) {
+  std::vector<Token> tokens;
+  tokens.push_back(Token());
+  tokens.push_back(Token());
+  tokens.push_back(Token("Hello", 0, 1));
+  tokens.push_back(Token("World", 1, 2));
+  tokens.push_back(Token("today!", 2, 3));
+  tokens.push_back(Token());
+  tokens.push_back(Token());
+
+  std::vector<std::vector<int>> sparse_features(tokens.size());
+  for (int i = 0; i < sparse_features.size(); ++i) {
+    sparse_features[i].push_back(i);
+  }
+  std::vector<std::vector<float>> dense_features(tokens.size());
+  for (int i = 0; i < dense_features.size(); ++i) {
+    dense_features[i].push_back(-i);
+  }
+
+  TestingCachedFeatures feature_extractor(
+      tokens, /*context_size=*/2, sparse_features, dense_features,
+      [](const std::vector<int>& sparse_features,
+         const std::vector<float>& dense_features, float* features) {
+        features[0] = sparse_features[0];
+        features[1] = sparse_features[0];
+        features[2] = dense_features[0];
+        features[3] = dense_features[0];
+        features[4] = 123;
+        return true;
+      },
+      5);
+
+  VectorSpan<float> features;
+  VectorSpan<Token> output_tokens;
+  EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
+  for (int i = 0; i < 5; i++) {
+    EXPECT_EQ(features[i * 5 + 0], i) << "Feature " << i;
+    EXPECT_EQ(features[i * 5 + 1], i) << "Feature " << i;
+    EXPECT_EQ(features[i * 5 + 2], -i) << "Feature " << i;
+    EXPECT_EQ(features[i * 5 + 3], -i) << "Feature " << i;
+    EXPECT_EQ(features[i * 5 + 4], 123) << "Feature " << i;
+  }
+}
+
+TEST(CachedFeaturesTest, InvalidInput) {
+  std::vector<Token> tokens;
+  tokens.push_back(Token());
+  tokens.push_back(Token());
+  tokens.push_back(Token("Hello", 0, 1));
+  tokens.push_back(Token("World", 1, 2));
+  tokens.push_back(Token("today!", 2, 3));
+  tokens.push_back(Token());
+  tokens.push_back(Token());
+
+  std::vector<std::vector<int>> sparse_features(tokens.size());
+  std::vector<std::vector<float>> dense_features(tokens.size());
+
+  TestingCachedFeatures feature_extractor(
+      tokens, /*context_size=*/2, sparse_features, dense_features,
+      [](const std::vector<int>& sparse_features,
+         const std::vector<float>& dense_features,
+         float* features) { return true; },
+      /*feature_vector_size=*/5);
+
+  VectorSpan<float> features;
+  VectorSpan<Token> output_tokens;
+  EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens));
+  EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens));
+  EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens));
+  EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
+  EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens));
+  EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens));
+  EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens));
+}
+
+TEST(CachedFeaturesTest, RemapV0FeatureVector) {
+  std::vector<Token> tokens;
+  tokens.push_back(Token());
+  tokens.push_back(Token());
+  tokens.push_back(Token("Hello", 0, 1));
+  tokens.push_back(Token("World", 1, 2));
+  tokens.push_back(Token("today!", 2, 3));
+  tokens.push_back(Token());
+  tokens.push_back(Token());
+
+  std::vector<std::vector<int>> sparse_features(tokens.size());
+  std::vector<std::vector<float>> dense_features(tokens.size());
+
+  TestingCachedFeatures feature_extractor(
+      tokens, /*context_size=*/2, sparse_features, dense_features,
+      [](const std::vector<int>& sparse_features,
+         const std::vector<float>& dense_features,
+         float* features) { return true; },
+      /*feature_vector_size=*/5);
+
+  std::vector<float> features_orig(5 * 5);
+  for (int i = 0; i < features_orig.size(); i++) {
+    features_orig[i] = i;
+  }
+  VectorSpan<float> features;
+
+  feature_extractor.SetV0FeatureMode(0);
+  features = VectorSpan<float>(features_orig);
+  feature_extractor.RemapV0FeatureVector(&features);
+  EXPECT_EQ(
+      std::vector<float>({0,  1,  2,  3,  4,  5,  6,  7,  8,  9,  10, 11, 12,
+                          13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}),
+      std::vector<float>(features.begin(), features.end()));
+
+  feature_extractor.SetV0FeatureMode(2);
+  features = VectorSpan<float>(features_orig);
+  feature_extractor.RemapV0FeatureVector(&features);
+  EXPECT_EQ(std::vector<float>({0, 1, 5, 6,  10, 11, 15, 16, 20, 21, 2,  3, 4,
+                                7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}),
+            std::vector<float>(features.begin(), features.end()));
+}
+
+}  // namespace
+}  // namespace libtextclassifier
diff --git a/tests/feature-processor_test.cc b/tests/feature-processor_test.cc
index 88a93f3..e3a39e3 100644
--- a/tests/feature-processor_test.cc
+++ b/tests/feature-processor_test.cc
@@ -26,83 +26,83 @@
 using testing::FloatEq;
 
 TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
-  std::vector<Token> tokens{Token("Hělló", 0, 5, false),
-                            Token("fěěbař@google.com", 6, 23, false),
-                            Token("heře!", 24, 29, false)};
+  std::vector<Token> tokens{Token("Hělló", 0, 5),
+                            Token("fěěbař@google.com", 6, 23),
+                            Token("heře!", 24, 29)};
 
   internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
 
   // clang-format off
   EXPECT_THAT(tokens, ElementsAreArray(
-                          {Token("Hělló", 0, 5, false),
-                           Token("fěě", 6, 9, false),
-                           Token("bař", 9, 12, false),
-                           Token("@google.com", 12, 23, false),
-                           Token("heře!", 24, 29, false)}));
+                          {Token("Hělló", 0, 5),
+                           Token("fěě", 6, 9),
+                           Token("bař", 9, 12),
+                           Token("@google.com", 12, 23),
+                           Token("heře!", 24, 29)}));
   // clang-format on
 }
 
 TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
-  std::vector<Token> tokens{Token("Hělló", 0, 5, false),
-                            Token("fěěbař@google.com", 6, 23, false),
-                            Token("heře!", 24, 29, false)};
+  std::vector<Token> tokens{Token("Hělló", 0, 5),
+                            Token("fěěbař@google.com", 6, 23),
+                            Token("heře!", 24, 29)};
 
   internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
 
   // clang-format off
   EXPECT_THAT(tokens, ElementsAreArray(
-                          {Token("Hělló", 0, 5, false),
-                           Token("fěěbař", 6, 12, false),
-                           Token("@google.com", 12, 23, false),
-                           Token("heře!", 24, 29, false)}));
+                          {Token("Hělló", 0, 5),
+                           Token("fěěbař", 6, 12),
+                           Token("@google.com", 12, 23),
+                           Token("heře!", 24, 29)}));
   // clang-format on
 }
 
 TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
-  std::vector<Token> tokens{Token("Hělló", 0, 5, false),
-                            Token("fěěbař@google.com", 6, 23, false),
-                            Token("heře!", 24, 29, false)};
+  std::vector<Token> tokens{Token("Hělló", 0, 5),
+                            Token("fěěbař@google.com", 6, 23),
+                            Token("heře!", 24, 29)};
 
   internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
 
   // clang-format off
   EXPECT_THAT(tokens, ElementsAreArray(
-                          {Token("Hělló", 0, 5, false),
-                           Token("fěě", 6, 9, false),
-                           Token("bař@google.com", 9, 23, false),
-                           Token("heře!", 24, 29, false)}));
+                          {Token("Hělló", 0, 5),
+                           Token("fěě", 6, 9),
+                           Token("bař@google.com", 9, 23),
+                           Token("heře!", 24, 29)}));
   // clang-format on
 }
 
 TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
-  std::vector<Token> tokens{Token("Hělló", 0, 5, false),
-                            Token("fěěbař@google.com", 6, 23, false),
-                            Token("heře!", 24, 29, false)};
+  std::vector<Token> tokens{Token("Hělló", 0, 5),
+                            Token("fěěbař@google.com", 6, 23),
+                            Token("heře!", 24, 29)};
 
   internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
 
   // clang-format off
   EXPECT_THAT(tokens, ElementsAreArray(
-                          {Token("Hělló", 0, 5, false),
-                           Token("fěěbař@google.com", 6, 23, false),
-                           Token("heře!", 24, 29, false)}));
+                          {Token("Hělló", 0, 5),
+                           Token("fěěbař@google.com", 6, 23),
+                           Token("heře!", 24, 29)}));
   // clang-format on
 }
 
 TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
-  std::vector<Token> tokens{Token("Hělló", 0, 5, false),
-                            Token("fěěbař@google.com", 6, 23, false),
-                            Token("heře!", 24, 29, false)};
+  std::vector<Token> tokens{Token("Hělló", 0, 5),
+                            Token("fěěbař@google.com", 6, 23),
+                            Token("heře!", 24, 29)};
 
   internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
 
   // clang-format off
   EXPECT_THAT(tokens, ElementsAreArray(
-                          {Token("Hě", 0, 2, false),
-                           Token("lló", 2, 5, false),
-                           Token("fěě", 6, 9, false),
-                           Token("bař@google.com", 9, 23, false),
-                           Token("heře!", 24, 29, false)}));
+                          {Token("Hě", 0, 2),
+                           Token("lló", 2, 5),
+                           Token("fěě", 6, 9),
+                           Token("bař@google.com", 9, 23),
+                           Token("heře!", 24, 29)}));
   // clang-format on
 }
 
@@ -269,98 +269,25 @@
   EXPECT_EQ(label2, label3);
 }
 
-TEST(FeatureProcessorTest, GetFeaturesWithContextDropout) {
-  FeatureProcessorOptions options;
-  options.set_num_buckets(10);
-  options.set_context_size(7);
-  options.set_max_selection_span(7);
-  options.add_chargram_orders(1);
-  options.set_tokenize_on_space(true);
-  options.set_context_dropout_probability(0.5);
-  options.set_use_variable_context_dropout(true);
-  TokenizationCodepointRange* config =
-      options.add_tokenization_codepoint_config();
-  config->set_start(32);
-  config->set_end(33);
-  config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
-  FeatureProcessor feature_processor(options);
-
-  // Test that two subsequent runs with random context dropout produce
-  // different features.
-  feature_processor.SetRandom(new std::mt19937);
-
-  std::vector<std::vector<std::pair<int, float>>> features;
-  std::vector<std::vector<std::pair<int, float>>> features2;
-  std::vector<float> extra_features;
-  std::vector<CodepointSpan> selection_label_spans;
-  int selection_label;
-  CodepointSpan selection_codepoint_label;
-  int classification_label;
-  EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
-      &features, &extra_features, &selection_label_spans, &selection_label,
-      &selection_codepoint_label, &classification_label));
-  EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
-      &features2, &extra_features, &selection_label_spans, &selection_label,
-      &selection_codepoint_label, &classification_label));
-
-  EXPECT_NE(features, features2);
-}
-
-TEST(FeatureProcessorTest, GetFeaturesWithLongerContext) {
-  FeatureProcessorOptions options;
-  options.set_num_buckets(10);
-  options.set_context_size(9);
-  options.set_max_selection_span(7);
-  options.add_chargram_orders(1);
-  options.set_tokenize_on_space(true);
-  TokenizationCodepointRange* config =
-      options.add_tokenization_codepoint_config();
-  config->set_start(32);
-  config->set_end(33);
-  config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
-  FeatureProcessor feature_processor(options);
-
-  std::vector<std::vector<std::pair<int, float>>> features;
-  std::vector<float> extra_features;
-  std::vector<CodepointSpan> selection_label_spans;
-  int selection_label;
-  CodepointSpan selection_codepoint_label;
-  int classification_label;
-  EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
-      &features, &extra_features, &selection_label_spans, &selection_label,
-      &selection_codepoint_label, &classification_label));
-  EXPECT_EQ(19, features.size());
-
-  // Should pad the string.
-  EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      "X", {0, 1}, {0, 1}, "", &features, &extra_features,
-      &selection_label_spans, &selection_label, &selection_codepoint_label,
-      &classification_label));
-  EXPECT_EQ(19, features.size());
-}
-
 TEST(FeatureProcessorTest, CenterTokenFromClick) {
   int token_index;
 
   // Exactly aligned indices.
   token_index = internal::CenterTokenFromClick(
-      {6, 11}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
-                Token("heře!", 12, 17, false)});
+      {6, 11},
+      {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
   EXPECT_EQ(token_index, 1);
 
   // Click is contained in a token.
   token_index = internal::CenterTokenFromClick(
-      {13, 17}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
-                 Token("heře!", 12, 17, false)});
+      {13, 17},
+      {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
   EXPECT_EQ(token_index, 2);
 
   // Click spans two tokens.
   token_index = internal::CenterTokenFromClick(
-      {6, 17}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
-                Token("heře!", 12, 17, false)});
+      {6, 17},
+      {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
   EXPECT_EQ(token_index, kInvalidIndex);
 }
 
@@ -369,37 +296,37 @@
 
   // Selection of length 3. Exactly aligned indices.
   token_index = internal::CenterTokenFromMiddleOfSelection(
-      {7, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
-                Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
-                Token("Token5", 28, 34, false)});
+      {7, 27},
+      {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+       Token("Token4", 21, 27), Token("Token5", 28, 34)});
   EXPECT_EQ(token_index, 2);
 
   // Selection of length 1 token. Exactly aligned indices.
   token_index = internal::CenterTokenFromMiddleOfSelection(
-      {21, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
-                 Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
-                 Token("Token5", 28, 34, false)});
+      {21, 27},
+      {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+       Token("Token4", 21, 27), Token("Token5", 28, 34)});
   EXPECT_EQ(token_index, 3);
 
   // Selection marks sub-token range, with no tokens in it.
   token_index = internal::CenterTokenFromMiddleOfSelection(
-      {29, 33}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
-                 Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
-                 Token("Token5", 28, 34, false)});
+      {29, 33},
+      {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+       Token("Token4", 21, 27), Token("Token5", 28, 34)});
   EXPECT_EQ(token_index, kInvalidIndex);
 
   // Selection of length 2. Sub-token indices.
   token_index = internal::CenterTokenFromMiddleOfSelection(
-      {3, 25}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
-                Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
-                Token("Token5", 28, 34, false)});
+      {3, 25},
+      {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+       Token("Token4", 21, 27), Token("Token5", 28, 34)});
   EXPECT_EQ(token_index, 1);
 
   // Selection of length 1. Sub-token indices.
   token_index = internal::CenterTokenFromMiddleOfSelection(
-      {22, 34}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
-                 Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
-                 Token("Token5", 28, 34, false)});
+      {22, 34},
+      {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+       Token("Token4", 21, 27), Token("Token5", 28, 34)});
   EXPECT_EQ(token_index, 4);
 
   // Some invalid ones.
@@ -407,42 +334,6 @@
   EXPECT_EQ(token_index, -1);
 }
 
-TEST(FeatureProcessorTest, GetFeaturesForSharing) {
-  FeatureProcessorOptions options;
-  options.set_num_buckets(10);
-  options.set_context_size(9);
-  options.set_max_selection_span(7);
-  options.add_chargram_orders(1);
-  options.set_tokenize_on_space(true);
-  options.set_center_token_selection_method(
-      FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION);
-  options.set_only_use_line_with_click(true);
-  options.set_split_tokens_on_selection_boundaries(true);
-  options.set_extract_selection_mask_feature(true);
-  TokenizationCodepointRange* config =
-      options.add_tokenization_codepoint_config();
-  config->set_start(32);
-  config->set_end(33);
-  config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
-  config = options.add_tokenization_codepoint_config();
-  config->set_start(10);
-  config->set_end(11);
-  config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
-  FeatureProcessor feature_processor(options);
-
-  std::vector<std::vector<std::pair<int, float>>> features;
-  std::vector<float> extra_features;
-  std::vector<CodepointSpan> selection_label_spans;
-  int selection_label;
-  CodepointSpan selection_codepoint_label;
-  int classification_label;
-  EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
-      "line 1\nline2\nsome entity\n line 4", {13, 24}, {13, 24}, "", &features,
-      &extra_features, &selection_label_spans, &selection_label,
-      &selection_codepoint_label, &classification_label));
-  EXPECT_EQ(19, features.size());
-}
-
 TEST(FeatureProcessorTest, SupportedCodepointsRatio) {
   FeatureProcessorOptions options;
   options.set_context_size(2);
@@ -488,26 +379,144 @@
   EXPECT_FALSE(feature_processor.IsCodepointSupported(10001));
   EXPECT_TRUE(feature_processor.IsCodepointSupported(25000));
 
-  std::vector<nlp_core::FeatureVector> features;
+  std::vector<Token> tokens;
+  int click_pos;
   std::vector<float> extra_features;
+  std::unique_ptr<CachedFeatures> cached_features;
+
+  auto feature_fn = [](const std::vector<int>& sparse_features,
+                       const std::vector<float>& dense_features,
+                       float* embedding) { return true; };
 
   options.set_min_supported_codepoint_ratio(0.0);
-  feature_processor = TestingFeatureProcessor(options);
-  EXPECT_TRUE(feature_processor.GetFeatures("ěěě řřř eee", {4, 7}, &features,
-                                            &extra_features,
-                                            /*selection_label_spans=*/nullptr));
+  TestingFeatureProcessor feature_processor2(options);
+  EXPECT_TRUE(feature_processor2.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
+                                                 feature_fn, 2, &tokens,
+                                                 &click_pos, &cached_features));
 
   options.set_min_supported_codepoint_ratio(0.2);
-  feature_processor = TestingFeatureProcessor(options);
-  EXPECT_TRUE(feature_processor.GetFeatures("ěěě řřř eee", {4, 7}, &features,
-                                            &extra_features,
-                                            /*selection_label_spans=*/nullptr));
+  TestingFeatureProcessor feature_processor3(options);
+  EXPECT_TRUE(feature_processor3.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
+                                                 feature_fn, 2, &tokens,
+                                                 &click_pos, &cached_features));
 
   options.set_min_supported_codepoint_ratio(0.5);
-  feature_processor = TestingFeatureProcessor(options);
-  EXPECT_FALSE(feature_processor.GetFeatures(
-      "ěěě řřř eee", {4, 7}, &features, &extra_features,
-      /*selection_label_spans=*/nullptr));
+  TestingFeatureProcessor feature_processor4(options);
+  EXPECT_FALSE(feature_processor4.ExtractFeatures(
+      "ěěě řřř eee", {4, 7}, {0, 0}, feature_fn, 2, &tokens, &click_pos,
+      &cached_features));
+}
+
+TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {
+  std::vector<Token> tokens_orig{
+      Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0),  Token("3", 0, 0),
+      Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0),  Token("7", 0, 0),
+      Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+      Token("12", 0, 0)};
+
+  std::vector<Token> tokens;
+  int click_index;
+
+  // Try to click first token and see if it gets padded from left.
+  tokens = tokens_orig;
+  click_index = 0;
+  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token(),
+                                        Token(),
+                                        Token("0", 0, 0),
+                                        Token("1", 0, 0),
+                                        Token("2", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+
+  // When we click the third token (index 2) nothing should get padded.
+  tokens = tokens_orig;
+  click_index = 2;
+  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
+                                        Token("1", 0, 0),
+                                        Token("2", 0, 0),
+                                        Token("3", 0, 0),
+                                        Token("4", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+
+  // When we click the last token tokens should get padded from the right.
+  tokens = tokens_orig;
+  click_index = 12;
+  internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
+                                        Token("11", 0, 0),
+                                        Token("12", 0, 0),
+                                        Token(),
+                                        Token()}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+}
+
+TEST(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {
+  std::vector<Token> tokens_orig{
+      Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0),  Token("3", 0, 0),
+      Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0),  Token("7", 0, 0),
+      Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+      Token("12", 0, 0)};
+
+  std::vector<Token> tokens;
+  int click_index;
+
+  // Try to click first token and see if it gets padded from left to maximum
+  // context_size.
+  tokens = tokens_orig;
+  click_index = 0;
+  internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token(),
+                                        Token(),
+                                        Token("0", 0, 0),
+                                        Token("1", 0, 0),
+                                        Token("2", 0, 0),
+                                        Token("3", 0, 0),
+                                        Token("4", 0, 0),
+                                        Token("5", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 2);
+
+  // Clicking in the middle with enough context should not produce any padding.
+  tokens = tokens_orig;
+  click_index = 6;
+  internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
+                                        Token("2", 0, 0),
+                                        Token("3", 0, 0),
+                                        Token("4", 0, 0),
+                                        Token("5", 0, 0),
+                                        Token("6", 0, 0),
+                                        Token("7", 0, 0),
+                                        Token("8", 0, 0),
+                                        Token("9", 0, 0)}));
+  // clang-format on
+  EXPECT_EQ(click_index, 5);
+
+  // Clicking near the end should pad right to maximum context_size.
+  tokens = tokens_orig;
+  click_index = 11;
+  internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+  // clang-format off
+  EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
+                                        Token("7", 0, 0),
+                                        Token("8", 0, 0),
+                                        Token("9", 0, 0),
+                                        Token("10", 0, 0),
+                                        Token("11", 0, 0),
+                                        Token("12", 0, 0),
+                                        Token(),
+                                        Token()}));
+  // clang-format on
+  EXPECT_EQ(click_index, 5);
 }
 
 }  // namespace
diff --git a/tests/token-feature-extractor_test.cc b/tests/token-feature-extractor_test.cc
index 55a5228..277549e 100644
--- a/tests/token-feature-extractor_test.cc
+++ b/tests/token-feature-extractor_test.cc
@@ -40,7 +40,7 @@
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
 
-  extractor.Extract(Token{"Hello", 0, 5, true}, &sparse_features,
+  extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
                     &dense_features);
 
   EXPECT_THAT(sparse_features,
@@ -68,7 +68,7 @@
 
   sparse_features.clear();
   dense_features.clear();
-  extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
                     &dense_features);
 
   EXPECT_THAT(sparse_features,
@@ -110,7 +110,7 @@
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
 
-  extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
+  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
                     &dense_features);
 
   EXPECT_THAT(sparse_features,
@@ -138,7 +138,7 @@
 
   sparse_features.clear();
   dense_features.clear();
-  extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
                     &dense_features);
 
   EXPECT_THAT(sparse_features,
@@ -179,25 +179,25 @@
 
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
-  extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
+  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
 
   sparse_features.clear();
   dense_features.clear();
-  extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+  extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
 
   sparse_features.clear();
   dense_features.clear();
-  extractor.Extract(Token{"Ř", 23, 29, false}, &sparse_features,
+  extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
 
   sparse_features.clear();
   dense_features.clear();
-  extractor.Extract(Token{"ř", 23, 29, false}, &sparse_features,
+  extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
 }
@@ -212,15 +212,15 @@
 
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
-  extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
+  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                     &dense_features);
 
   std::vector<int> sparse_features2;
-  extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
+  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                     &dense_features);
   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
 
-  extractor.Extract(Token{"10:32am", 0, 6, true}, &sparse_features2,
+  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
                     &dense_features);
   EXPECT_THAT(sparse_features,
               testing::Not(testing::ElementsAreArray(sparse_features2)));
@@ -236,15 +236,15 @@
 
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
-  extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
+  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                     &dense_features);
 
   std::vector<int> sparse_features2;
-  extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
+  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                     &dense_features);
   EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
 
-  extractor.Extract(Token{"10:32am", 0, 6, true}, &sparse_features2,
+  extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
                     &dense_features);
   EXPECT_THAT(sparse_features,
               testing::Not(testing::ElementsAreArray(sparse_features2)));
@@ -262,22 +262,22 @@
 
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
-  extractor.Extract(Token{"abCde", 0, 6, true}, &sparse_features,
+  extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
 
   dense_features.clear();
-  extractor.Extract(Token{"abcde", 0, 6, true}, &sparse_features,
+  extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
 
   dense_features.clear();
-  extractor.Extract(Token{"12c45", 0, 6, true}, &sparse_features,
+  extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
 
   dense_features.clear();
-  extractor.Extract(Token{"12345", 0, 6, true}, &sparse_features,
+  extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
                     &dense_features);
   EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
 }
@@ -294,7 +294,7 @@
   // Test that this runs. ASAN should catch problems.
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
-  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0, true},
+  extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
                     &sparse_features, &dense_features);
 
   EXPECT_THAT(sparse_features,
@@ -325,13 +325,13 @@
         "x", "Hello", "Hey,", "Hi", ""}) {
     std::vector<int> sparse_features_unicode;
     std::vector<float> dense_features_unicode;
-    extractor_unicode.Extract(Token{input, 0, 0, true},
+    extractor_unicode.Extract(Token{input, 0, 0}, true,
                               &sparse_features_unicode,
                               &dense_features_unicode);
 
     std::vector<int> sparse_features_ascii;
     std::vector<float> dense_features_ascii;
-    extractor_ascii.Extract(Token{input, 0, 0, true}, &sparse_features_ascii,
+    extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
                             &dense_features_ascii);
 
     EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
@@ -352,7 +352,7 @@
   std::vector<int> sparse_features;
   std::vector<float> dense_features;
 
-  extractor.Extract(Token(), &sparse_features, &dense_features);
+  extractor.Extract(Token(), false, &sparse_features, &dense_features);
 
   EXPECT_THAT(sparse_features,
               testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
diff --git a/tests/tokenizer_test.cc b/tests/tokenizer_test.cc
index f0034bb..cdb90a9 100644
--- a/tests/tokenizer_test.cc
+++ b/tests/tokenizer_test.cc
@@ -104,8 +104,8 @@
   TestingTokenizer tokenizer(configs);
   std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
 
-  EXPECT_THAT(tokens, ElementsAreArray({Token("Hello", 0, 5, false),
-                                        Token("world!", 6, 12, false)}));
+  EXPECT_THAT(tokens,
+              ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
 }
 
 TEST(TokenizerTest, TokenizeComplex) {
@@ -243,17 +243,17 @@
   // clang-format off
   EXPECT_THAT(
       tokens,
-      ElementsAreArray({Token("問", 0, 1, false),
-                        Token("少", 1, 2, false),
-                        Token("目", 2, 3, false),
-                        Token("hello", 4, 9, false),
-                        Token("木", 10, 11, false),
-                        Token("輸", 11, 12, false),
-                        Token("ย", 12, 13, false),
-                        Token("า", 13, 14, false),
-                        Token("ม", 14, 15, false),
-                        Token("き", 15, 16, false),
-                        Token("ゃ", 16, 17, false)}));
+      ElementsAreArray({Token("問", 0, 1),
+                        Token("少", 1, 2),
+                        Token("目", 2, 3),
+                        Token("hello", 4, 9),
+                        Token("木", 10, 11),
+                        Token("輸", 11, 12),
+                        Token("ย", 12, 13),
+                        Token("า", 13, 14),
+                        Token("ม", 14, 15),
+                        Token("き", 15, 16),
+                        Token("ゃ", 16, 17)}));
   // clang-format on
 }