Makes inference efficient. am: 6bb39a8ec5
am: 49574bd039
Change-Id: I4085bf36469571115a0566024cfe8e9a98b603ba
diff --git a/common/embedding-network.cc b/common/embedding-network.cc
index 61f7323..a17e082 100644
--- a/common/embedding-network.cc
+++ b/common/embedding-network.cc
@@ -48,8 +48,8 @@
// Before we access the weights as floats, we need to check that they are
// really floats, i.e., no quantization is used.
if (!CheckNoQuantization(source_matrix)) return false;
- const float *weights = reinterpret_cast<const float *>(
- source_matrix.elements);
+ const float *weights =
+ reinterpret_cast<const float *>(source_matrix.elements);
for (int r = 0; r < source_matrix.rows; ++r) {
(*mat)[r] = EmbeddingNetwork::VectorWrapper(weights, source_matrix.cols);
weights += source_matrix.cols;
@@ -86,7 +86,7 @@
bool SparseReluProductPlusBias(bool apply_relu,
const EmbeddingNetwork::Matrix &weights,
const EmbeddingNetwork::VectorWrapper &b,
- const EmbeddingNetwork::Vector &x,
+ const VectorSpan<float> &x,
EmbeddingNetwork::Vector *y) {
// Check that dimensions match.
if ((x.size() != weights.size()) || weights.empty()) {
@@ -131,87 +131,127 @@
// "es_index" stands for "embedding space index".
for (int es_index = 0; es_index < feature_vectors.size(); ++es_index) {
- // Access is safe by es_index loop bounds, Invariant 2, and Invariant 3.
- const int concat_offset = concat_offset_[es_index];
-
// Access is safe by es_index loop bounds and Invariant 3.
- const EmbeddingMatrix *embedding_matrix =
+ EmbeddingMatrix *const embedding_matrix =
embedding_matrices_[es_index].get();
if (embedding_matrix == nullptr) {
// Should not happen, hence our terse log error message.
TC_LOG(ERROR) << es_index;
return false;
}
- const int embedding_dim = embedding_matrix->dim();
- const bool is_quantized =
- embedding_matrix->quant_type() != QuantizationType::NONE;
// Access is safe due to es_index loop bounds.
const FeatureVector &feature_vector = feature_vectors[es_index];
- const int num_features = feature_vector.size();
- for (int fi = 0; fi < num_features; ++fi) {
- // Both accesses below are safe due to loop bounds for fi.
- const FeatureType *feature_type = feature_vector.type(fi);
- const FeatureValue feature_value = feature_vector.value(fi);
- const int feature_offset =
- concat_offset + feature_type->base() * embedding_dim;
- // Code below updates max(0, embedding_dim) elements from concat, starting
- // with index feature_offset. Check below ensures these updates are safe.
- if ((feature_offset < 0) ||
- (feature_offset + embedding_dim > concat->size())) {
- TC_LOG(ERROR) << es_index << "," << fi << ": " << feature_offset << " "
- << embedding_dim << " " << concat->size();
- return false;
+ // Access is safe by es_index loop bounds, Invariant 2, and Invariant 3.
+ const int concat_offset = concat_offset_[es_index];
+
+ if (!GetEmbeddingInternal(feature_vector, embedding_matrix, concat_offset,
+ concat->data(), concat->size())) {
+ TC_LOG(ERROR) << es_index;
+ return false;
+ }
+ }
+ return true;
+}
+
+// Adds the embeddings of all features from |feature_vector| (embedding space
+// |es_index|) to the floats at |embedding|, which must have room for at least
+// as many floats as the dimension of that embedding space.  Returns true on
+// success, false on error (invalid |es_index|, null |embedding|, missing
+// matrix, or a feature that would write outside the output array).
+bool EmbeddingNetwork::GetEmbedding(const FeatureVector &feature_vector,
+                                    int es_index, float *embedding) const {
+  // This entry point receives |es_index| from the caller rather than from a
+  // bounded loop, so it has to be validated before indexing
+  // embedding_matrices_.
+  if ((es_index < 0) ||
+      (es_index >= static_cast<int>(embedding_matrices_.size())) ||
+      (embedding == nullptr)) {
+    TC_LOG(ERROR) << es_index;
+    return false;
+  }
+  EmbeddingMatrix *const embedding_matrix = embedding_matrices_[es_index].get();
+  if (embedding_matrix == nullptr) {
+    // Should not happen, hence our terse log error message.
+    TC_LOG(ERROR) << es_index;
+    return false;
+  }
+  // The output array holds exactly one embedding, so the write offset is 0 and
+  // the available size is the embedding dimension.
+  return GetEmbeddingInternal(feature_vector, embedding_matrix,
+                              /*concat_offset=*/0, embedding,
+                              embedding_matrix->dim());
+}
+
+bool EmbeddingNetwork::GetEmbeddingInternal(
+ const FeatureVector &feature_vector,
+ EmbeddingMatrix *const embedding_matrix, const int concat_offset,
+ float *concat, int concat_size) const {
+ const int embedding_dim = embedding_matrix->dim();
+ const bool is_quantized =
+ embedding_matrix->quant_type() != QuantizationType::NONE;
+ const int num_features = feature_vector.size();
+ for (int fi = 0; fi < num_features; ++fi) {
+ // Both accesses below are safe due to loop bounds for fi.
+ const FeatureType *feature_type = feature_vector.type(fi);
+ const FeatureValue feature_value = feature_vector.value(fi);
+ const int feature_offset =
+ concat_offset + feature_type->base() * embedding_dim;
+
+ // Code below updates max(0, embedding_dim) elements from concat, starting
+ // with index feature_offset. Check below ensures these updates are safe.
+ if ((feature_offset < 0) ||
+ (feature_offset + embedding_dim > concat_size)) {
+ TC_LOG(ERROR) << fi << ": " << feature_offset << " " << embedding_dim
+ << " " << concat_size;
+ return false;
+ }
+
+ // Pointer to float / uint8 weights for relevant embedding.
+ const void *embedding_data;
+
+ // Multiplier for each embedding weight.
+ float multiplier;
+
+ if (feature_type->is_continuous()) {
+ // Continuous features (encoded as FloatFeatureValue).
+ FloatFeatureValue float_feature_value(feature_value);
+ const int id = float_feature_value.id;
+ embedding_matrix->get_embedding(id, &embedding_data, &multiplier);
+ multiplier *= float_feature_value.weight;
+ } else {
+ // Discrete features: every present feature has implicit value 1.0.
+ // Hence, after we grab the multiplier below, we don't multiply it by
+ // any weight.
+ embedding_matrix->get_embedding(feature_value, &embedding_data,
+ &multiplier);
+ }
+
+ // Weighted embeddings will be added starting from this address.
+ float *concat_ptr = concat + feature_offset;
+
+ if (is_quantized) {
+ const uint8 *quant_weights =
+ reinterpret_cast<const uint8 *>(embedding_data);
+ for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
+ // 128 is bias for UINT8 quantization, only one we currently support.
+ *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
}
-
- // Pointer to float / uint8 weights for relevant embedding.
- const void *embedding_data;
-
- // Multiplier for each embedding weight.
- float multiplier;
-
- if (feature_type->is_continuous()) {
- // Continuous features (encoded as FloatFeatureValue).
- FloatFeatureValue float_feature_value(feature_value);
- const int id = float_feature_value.id;
- embedding_matrix->get_embedding(id, &embedding_data, &multiplier);
- multiplier *= float_feature_value.weight;
- } else {
- // Discrete features: every present feature has implicit value 1.0.
- // Hence, after we grab the multiplier below, we don't multiply it by
- // any weight.
- embedding_matrix->get_embedding(feature_value, &embedding_data,
- &multiplier);
- }
-
- // Weighted embeddings will be added starting from this address.
- float *concat_ptr = concat->data() + feature_offset;
-
- if (is_quantized) {
- const uint8 *quant_weights =
- reinterpret_cast<const uint8 *>(embedding_data);
- for (int i = 0; i < embedding_dim; ++i, ++quant_weights, ++concat_ptr) {
- // 128 is bias for UINT8 quantization, only one we currently support.
- *concat_ptr += (static_cast<int>(*quant_weights) - 128) * multiplier;
- }
- } else {
- const float *weights = reinterpret_cast<const float *>(embedding_data);
- for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
- *concat_ptr += *weights * multiplier;
- }
+ } else {
+ const float *weights = reinterpret_cast<const float *>(embedding_data);
+ for (int i = 0; i < embedding_dim; ++i, ++weights, ++concat_ptr) {
+ *concat_ptr += *weights * multiplier;
}
}
}
return true;
}
+// Computes the logits for |input|, given as a non-owning view, and stores them
+// in |scores|.  Returns whatever the shared implementation reports.
+bool EmbeddingNetwork::ComputeLogits(const VectorSpan<float> &input,
+                                     Vector *scores) const {
+  const bool success = ComputeLogitsInternal(input, scores);
+  return success;
+}
+
+// Computes the logits for |input|, given as a full vector, and stores them in
+// |scores|.
+bool EmbeddingNetwork::ComputeLogits(const Vector &input,
+                                     Vector *scores) const {
+  // Wrap the vector in a view so that both overloads share one implementation.
+  const VectorSpan<float> input_view(input);
+  return ComputeLogitsInternal(input_view, scores);
+}
+
+// Common implementation behind the ComputeLogits overloads: runs the scoring
+// template instantiated with SimpleAdder.
+bool EmbeddingNetwork::ComputeLogitsInternal(const VectorSpan<float> &input,
+                                             Vector *scores) const {
+  const bool success =
+      FinishComputeFinalScoresInternal<SimpleAdder>(input, scores);
+  return success;
+}
+
template <typename ScaleAdderClass>
-bool EmbeddingNetwork::FinishComputeFinalScores(const Vector &concat,
- Vector *scores) const {
+bool EmbeddingNetwork::FinishComputeFinalScoresInternal(
+ const VectorSpan<float> &input, Vector *scores) const {
Vector h0(hidden_bias_[0].size());
bool success = SparseReluProductPlusBias<ScaleAdderClass>(
- false, hidden_weights_[0], hidden_bias_[0], concat, &h0);
+ false, hidden_weights_[0], hidden_bias_[0], input, &h0);
if (!success) return false;
if (hidden_weights_.size() == 1) { // 1 hidden layer
@@ -259,7 +299,7 @@
}
scores->resize(softmax_bias_.size());
- return FinishComputeFinalScores<SimpleAdder>(concat, scores);
+ return ComputeLogits(concat, scores);
}
EmbeddingNetwork::EmbeddingNetwork(const EmbeddingNetworkParams *model) {
@@ -331,5 +371,9 @@
valid_ = true;
}
+// Returns the dimension of the embeddings in embedding space |es_index|, or 0
+// (after logging an error) if |es_index| does not name a loaded matrix.
+int EmbeddingNetwork::EmbeddingSize(int es_index) const {
+  // |es_index| comes straight from the caller; reject anything that would
+  // index outside embedding_matrices_ or dereference a missing matrix.
+  if ((es_index < 0) ||
+      (es_index >= static_cast<int>(embedding_matrices_.size())) ||
+      (embedding_matrices_[es_index] == nullptr)) {
+    TC_LOG(ERROR) << es_index;
+    return 0;
+  }
+  return embedding_matrices_[es_index]->dim();
+}
+
} // namespace nlp_core
} // namespace libtextclassifier
diff --git a/common/embedding-network.h b/common/embedding-network.h
index 95b4d58..594f34c 100644
--- a/common/embedding-network.h
+++ b/common/embedding-network.h
@@ -22,6 +22,7 @@
#include "common/embedding-network-params.h"
#include "common/feature-extractor.h"
+#include "common/vector-span.h"
#include "util/base/integral_types.h"
#include "util/base/logging.h"
#include "util/base/macros.h"
@@ -54,8 +55,7 @@
quant_type_(source_matrix.quant_type),
data_(source_matrix.elements),
row_size_in_bytes_(GetRowSizeInBytes(cols_, quant_type_)),
- quant_scales_(source_matrix.quant_scales) {
- }
+ quant_scales_(source_matrix.quant_scales) {}
// Returns vocabulary size; one embedding for each vocabulary element.
int size() const { return rows_; }
@@ -133,8 +133,7 @@
// at address data. Note: the underlying data should be alive for at least
// the lifetime of this VectorWrapper object. That's trivially true if data
// points to statically allocated data :)
- VectorWrapper(const float *data, int size)
- : data_(data), size_(size) {}
+ VectorWrapper(const float *data, int size) : data_(data), size_(size) {}
int size() const { return size_; }
@@ -176,17 +175,48 @@
const std::vector<float> extra_inputs,
Vector *scores) const;
- private:
- // Computes the softmax scores (prior to normalization) from the concatenated
- // representation. Returns true on success, false on error.
- template <typename ScaleAdderClass>
- bool FinishComputeFinalScores(const Vector &concat, Vector *scores) const;
-
// Constructs the concatenated input embedding vector in place in output
// vector concat. Returns true on success, false on error.
bool ConcatEmbeddings(const std::vector<FeatureVector> &features,
Vector *concat) const;
+  // Sums embeddings for all features from |feature_vector| and adds result
+  // to values from the array pointed-to by |embedding|. Embeddings for
+  // continuous features are weighted by the feature weight.
+  //
+  // NOTE: embedding should point to an array of EmbeddingSize(es_index) floats.
+ bool GetEmbedding(const FeatureVector &feature_vector, int es_index,
+ float *embedding) const;
+
+ // Runs the feed-forward neural network for |input| and computes logits for
+ // softmax layer.
+ bool ComputeLogits(const Vector &input, Vector *scores) const;
+
+ // Same as above but uses a view of the feature vector.
+ bool ComputeLogits(const VectorSpan<float> &input, Vector *scores) const;
+
+ // Returns the size (the number of columns) of the embedding space es_index.
+ int EmbeddingSize(int es_index) const;
+
+ private:
+  // Adds the embeddings for |feature_vector| into the |concat| array, starting
+  // at |concat_offset|; |embedding_size| is the number of floats in |concat|.
+ bool GetEmbeddingInternal(const FeatureVector &feature_vector,
+ EmbeddingMatrix *embedding_matrix,
+ int concat_offset, float *concat,
+ int embedding_size) const;
+
+  // Computes the logit scores given the concatenated input embeddings.
+  // Shared implementation behind both ComputeLogits overloads.
+ bool ComputeLogitsInternal(const VectorSpan<float> &concat,
+ Vector *scores) const;
+
+ // Computes the softmax scores (prior to normalization) from the concatenated
+ // representation. Returns true on success, false on error.
+ template <typename ScaleAdderClass>
+ bool FinishComputeFinalScoresInternal(const VectorSpan<float> &concat,
+ Vector *scores) const;
+
// Set to true on successful construction, false otherwise.
bool valid_ = false;
diff --git a/common/vector-span.h b/common/vector-span.h
new file mode 100644
index 0000000..d7fbfe9
--- /dev/null
+++ b/common/vector-span.h
@@ -0,0 +1,50 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
+#define LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
+
+#include <vector>
+
+namespace libtextclassifier {
+
+// Non-owning, read-only view over a range of std::vector<T> elements
+// (StringPiece analogue for std::vector<T>).  Only a pair of iterators is
+// stored, so the viewed vector has to stay valid while the view is in use.
+template <class T>
+class VectorSpan {
+  // Iterator type that delimits the viewed range.
+  using Iterator = typename std::vector<T>::const_iterator;
+
+ public:
+  // Creates a view with value-initialized bounds.
+  VectorSpan() : begin_(), end_() {}
+
+  // Views the whole of |v|; implicit so that vectors can be passed directly.
+  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
+      : begin_(v.begin()), end_(v.end()) {}
+
+  // Views the range [begin, end).
+  VectorSpan(Iterator begin, Iterator end) : begin_(begin), end_(end) {}
+
+  // Returns the i-th viewed element; |i| is not range-checked.
+  const T& operator[](typename std::vector<T>::size_type i) const {
+    return begin_[i];
+  }
+
+  // Number of viewed elements.
+  int size() const { return static_cast<int>(end_ - begin_); }
+
+  // Bounds of the viewed range.
+  Iterator begin() const { return begin_; }
+  Iterator end() const { return end_; }
+
+ private:
+  Iterator begin_;
+  Iterator end_;
+};
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_COMMON_VECTOR_SPAN_H_
diff --git a/smartselect/cached-features.cc b/smartselect/cached-features.cc
new file mode 100644
index 0000000..c249db9
--- /dev/null
+++ b/smartselect/cached-features.cc
@@ -0,0 +1,82 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/cached-features.h"
+#include "util/base/logging.h"
+
+namespace libtextclassifier {
+
+// Computes and caches the feature vector of every token.  Entry i of
+// |sparse_features| and |dense_features| is handed to |feature_vector_fn|
+// together with the start of token i's slice of features_.  Failures cannot be
+// returned from here, so they are logged instead.
+void CachedFeatures::Extract(
+    const std::vector<std::vector<int>>& sparse_features,
+    const std::vector<std::vector<float>>& dense_features,
+    const std::function<bool(const std::vector<int>&, const std::vector<float>&,
+                             float*)>& feature_vector_fn) {
+  const int num_tokens = tokens_.size();
+  // assign() rather than resize(): every slot is zeroed even if the storage
+  // already held values, which is what feature_vector_fn is allowed to assume.
+  features_.assign(feature_vector_size_ * num_tokens, 0.0f);
+
+  // Each token needs its own entry in both inputs; indexing past either of
+  // them below would read out of bounds.
+  const int num_sparse = static_cast<int>(sparse_features.size());
+  const int num_dense = static_cast<int>(dense_features.size());
+  if (num_sparse < num_tokens || num_dense < num_tokens) {
+    TC_LOG(ERROR) << "Not enough features for tokens: " << num_sparse << " "
+                  << num_dense << " " << num_tokens;
+    return;
+  }
+
+  for (int i = 0; i < num_tokens; ++i) {
+    float* const token_features = features_.data() + i * feature_vector_size_;
+    if (!feature_vector_fn(sparse_features[i], dense_features[i],
+                           token_features)) {
+      // Keep going so the remaining tokens still get their features, but
+      // leave a trace of the failure.
+      TC_LOG(ERROR) << "Feature extraction failed for token: " << i;
+    }
+  }
+}
+
+// Produces views of the cached features and of the tokens for the window of
+// context_size_ tokens on each side of |click_pos|.  Returns false (after
+// logging) when that window is not fully inside the token sequence.
+bool CachedFeatures::Get(int click_pos, VectorSpan<float>* features,
+                         VectorSpan<Token>* output_tokens) {
+  // Half-open token range [first_token, last_token) of the context window.
+  const int first_token = click_pos - context_size_;
+  const int last_token = click_pos + context_size_ + 1;
+  const bool window_in_range =
+      first_token >= 0 && last_token <= tokens_.size();
+  if (!window_in_range) {
+    TC_LOG(ERROR) << "Tokens out of range: " << first_token << " "
+                  << last_token;
+    return false;
+  }
+
+  const auto features_begin = features_.begin();
+  *features = VectorSpan<float>(
+      features_begin + first_token * feature_vector_size_,
+      features_begin + last_token * feature_vector_size_);
+
+  const auto tokens_begin = tokens_.begin();
+  *output_tokens =
+      VectorSpan<Token>(tokens_begin + first_token, tokens_begin + last_token);
+
+  // In V0 compatibility mode |features| is redirected to the re-ordered copy.
+  if (remap_v0_feature_vector_) {
+    RemapV0FeatureVector(features);
+  }
+  return true;
+}
+
+// Copies |features| from the per-token layout [chargram; dense] into the V0
+// layout, where the chargram embeddings of all tokens come first and the dense
+// features of all tokens follow.  The copy lives in v0_feature_storage_, and
+// |features| is re-pointed at it.  Does nothing unless V0 mode is enabled.
+void CachedFeatures::RemapV0FeatureVector(VectorSpan<float>* features) {
+  if (remap_v0_feature_vector_) {
+    const int chargram_size = remap_v0_chargram_embedding_size_;
+    const int dense_size = feature_vector_size_ - chargram_size;
+    const int token_count = context_size_ * 2 + 1;
+    // Index in v0_feature_storage_ at which the dense features begin.
+    const int dense_base = token_count * chargram_size;
+
+    auto source = features->begin();
+    for (int token = 0; token < token_count; ++token) {
+      const int chargram_start = token * chargram_size;
+      for (int k = 0; k < chargram_size; ++k, ++source) {
+        v0_feature_storage_[chargram_start + k] = *source;
+      }
+      // Rest of the features are the dense features that come to the end.
+      const int dense_start = dense_base + token * dense_size;
+      for (int k = 0; k < dense_size; ++k, ++source) {
+        v0_feature_storage_[dense_start + k] = *source;
+      }
+    }
+    *features = VectorSpan<float>(v0_feature_storage_);
+  }
+}
+
+} // namespace libtextclassifier
diff --git a/smartselect/cached-features.h b/smartselect/cached-features.h
new file mode 100644
index 0000000..6490748
--- /dev/null
+++ b/smartselect/cached-features.h
@@ -0,0 +1,106 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
+#define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
+
+#include <functional>
+#include <memory>
+#include <vector>
+
+#include "base.h"
+#include "common/vector-span.h"
+#include "smartselect/types.h"
+
+namespace libtextclassifier {
+
+// Caches per-token feature vectors so that the features for the context window
+// around a click position can be handed out without recomputing them.
+// Assumes that features for each Token are independent.
+class CachedFeatures {
+ public:
+  // Extracts and stores the features for the given sequence of tokens.
+  //   - tokens: View of the tokens.  Only the view is stored, so the viewed
+  //       tokens have to stay valid for as long as this object is used.
+  //   - context_size: Specifies how many tokens to the left, and how many
+  //       tokens to the right spans the context.
+  //   - sparse_features, dense_features: Extracted features for each token.
+  //   - feature_vector_fn: Writes features for given Token to the specified
+  //       storage.
+  //       NOTE: The function can assume that the underlying storage is
+  //       initialized to all zeros.
+  //   - feature_vector_size: Size of a feature vector for one Token.
+  CachedFeatures(VectorSpan<Token> tokens, int context_size,
+                 const std::vector<std::vector<int>>& sparse_features,
+                 const std::vector<std::vector<float>>& dense_features,
+                 const std::function<bool(const std::vector<int>&,
+                                          const std::vector<float>&, float*)>&
+                     feature_vector_fn,
+                 int feature_vector_size)
+      : tokens_(tokens),
+        context_size_(context_size),
+        feature_vector_size_(feature_vector_size) {
+    Extract(sparse_features, dense_features, feature_vector_fn);
+  }
+
+  // Gets a VectorSpan with the features for given click position, and a
+  // VectorSpan with the tokens of the same context window.
+  bool Get(int click_pos, VectorSpan<float>* features,
+           VectorSpan<Token>* output_tokens);
+
+  // Turns on a compatibility mode, which re-maps the extracted features to the
+  // v0 feature format (where the dense features were at the end).
+  // WARNING: Internally v0_feature_storage_ is used as a backing buffer for
+  // VectorSpan<float>, so the output of Get is valid only until the next call
+  // to Get or destruction of the current CachedFeatures object.
+  // TODO(zilka): Remove when we'll have retrained models.
+  void SetV0FeatureMode(int chargram_embedding_size) {
+    remap_v0_feature_vector_ = true;
+    remap_v0_chargram_embedding_size_ = chargram_embedding_size;
+    // One full feature vector for each token of the context window.
+    const int window_tokens = context_size_ * 2 + 1;
+    v0_feature_storage_.resize(feature_vector_size_ * window_tokens);
+  }
+
+ protected:
+  // Extracts features for all tokens and stores them for later retrieval.
+  void Extract(const std::vector<std::vector<int>>& sparse_features,
+               const std::vector<std::vector<float>>& dense_features,
+               const std::function<bool(const std::vector<int>&,
+                                        const std::vector<float>&, float*)>&
+                   feature_vector_fn);
+
+  // Remaps extracted features to V0 feature format. The mapping is using
+  // the v0_feature_storage_ as the backing storage for the mapped features.
+  // For each token the features consist of:
+  //  - chargram embeddings
+  //  - dense features
+  // They are concatenated together as [chargram embeddings; dense features]
+  // for each token independently.
+  // The V0 features require that the chargram embeddings for tokens are
+  // concatenated first together, and at the end, the dense features for the
+  // tokens are concatenated to it.
+  void RemapV0FeatureVector(VectorSpan<float>* features);
+
+ private:
+  // View of the tokens whose features are cached.
+  const VectorSpan<Token> tokens_;
+  // Number of context tokens on each side of the click.
+  const int context_size_;
+  // Number of floats in the feature vector of one token.
+  const int feature_vector_size_;
+  // V0 compatibility mode: off until SetV0FeatureMode() is called.
+  bool remap_v0_feature_vector_ = false;
+  int remap_v0_chargram_embedding_size_ = -1;
+
+  // Feature vectors of all tokens, stored back to back.
+  std::vector<float> features_;
+  // Backing buffer for the features handed out in V0 mode.
+  std::vector<float> v0_feature_storage_;
+};
+
+} // namespace libtextclassifier
+
+#endif // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
diff --git a/smartselect/feature-processor.cc b/smartselect/feature-processor.cc
index 0ac7bd3..0ba25ca 100644
--- a/smartselect/feature-processor.cc
+++ b/smartselect/feature-processor.cc
@@ -93,8 +93,7 @@
for (const auto& split_point : split_points) {
Token new_token(token_word.UTF8Substring(last_start, split_point),
current_pos,
- current_pos + std::distance(last_start, split_point),
- /*is_in_span=*/false);
+ current_pos + std::distance(last_start, split_point));
last_start = split_point;
current_pos = new_token.end;
@@ -169,7 +168,13 @@
} // namespace internal
-const char* const FeatureProcessor::kFeatureTypeName = "chargram_continuous";
+// Returns the name of the default collection, or an empty string (after
+// logging an error) when the configured default index does not refer to one
+// of the configured collections.
+std::string FeatureProcessor::GetDefaultCollection() const {
+  const int default_index = options_.default_collection();
+  // A negative index is just as invalid as one past the end: both would read
+  // outside the list of collections.
+  if (default_index < 0 || default_index >= options_.collections_size()) {
+    TC_LOG(ERROR) << "No collections specified. Returning empty string.";
+    return "";
+  }
+  return options_.collections(default_index);
+}
std::vector<Token> FeatureProcessor::Tokenize(
const std::string& utf8_text) const {
@@ -177,7 +182,7 @@
}
bool FeatureProcessor::LabelToSpan(
- const int label, const std::vector<Token>& tokens,
+ const int label, const VectorSpan<Token>& tokens,
std::pair<CodepointIndex, CodepointIndex>* span) const {
if (tokens.size() != GetNumContextTokens()) {
return false;
@@ -283,7 +288,8 @@
TokenIndex end_token = kInvalidIndex;
for (int i = 0; i < selectable_tokens.size(); ++i) {
if (codepoint_start <= selectable_tokens[i].start &&
- codepoint_end >= selectable_tokens[i].end) {
+ codepoint_end >= selectable_tokens[i].end &&
+ !selectable_tokens[i].is_padding) {
if (start_token == kInvalidIndex) {
start_token = i;
}
@@ -386,194 +392,16 @@
}
}
-bool FeatureProcessor::GetFeatures(
- const std::string& context, CodepointSpan input_span,
- std::vector<nlp_core::FeatureVector>* features,
- std::vector<float>* extra_features,
+bool FeatureProcessor::SelectionLabelSpans(
+ const VectorSpan<Token> tokens,
std::vector<CodepointSpan>* selection_label_spans) const {
- return FeatureProcessor::GetFeaturesAndLabels(
- context, input_span, {kInvalidIndex, kInvalidIndex}, "", features,
- extra_features, selection_label_spans, /*selection_label=*/nullptr,
- /*selection_codepoint_label=*/nullptr, /*classification_label=*/nullptr);
-}
-
-bool FeatureProcessor::GetFeaturesAndLabels(
- const std::string& context, CodepointSpan input_span,
- CodepointSpan label_span, const std::string& label_collection,
- std::vector<nlp_core::FeatureVector>* features,
- std::vector<float>* extra_features,
- std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
- CodepointSpan* selection_codepoint_label, int* classification_label) const {
- if (features == nullptr) {
- return false;
- }
- *features =
- std::vector<nlp_core::FeatureVector>(options_.context_size() * 2 + 1);
-
- std::vector<Token> input_tokens = Tokenize(context);
-
- if (options_.split_tokens_on_selection_boundaries()) {
- internal::SplitTokensOnSelectionBoundaries(input_span, &input_tokens);
- }
-
- if (options_.only_use_line_with_click()) {
- internal::StripTokensFromOtherLines(context, input_span, &input_tokens);
- }
-
- const int click_pos = FindCenterToken(input_span, input_tokens);
- if (click_pos == kInvalidIndex) {
- TC_LOG(ERROR) << "Could not extract click position.";
- return false;
- }
-
- if (options_.min_supported_codepoint_ratio() > 0) {
- const float supported_codepoint_ratio =
- SupportedCodepointsRatio(click_pos, input_tokens);
- if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
- TC_LOG(INFO) << "Not enough supported codepoints in the context: "
- << supported_codepoint_ratio;
+ for (int i = 0; i < label_to_selection_.size(); ++i) {
+ CodepointSpan span;
+ if (!LabelToSpan(i, tokens, &span)) {
+ TC_LOG(ERROR) << "Could not convert label to span: " << i;
return false;
}
- }
-
- std::vector<Token> output_tokens;
- bool status = ComputeFeatures(click_pos, input_tokens, input_span, features,
- extra_features, &output_tokens);
- if (!status) {
- TC_LOG(ERROR) << "Feature computation failed.";
- return false;
- }
-
- if (selection_label != nullptr) {
- status = SpanToLabel(label_span, output_tokens, selection_label);
- if (!status) {
- TC_LOG(ERROR) << "Could not convert selection span to label.";
- return false;
- }
- }
-
- if (selection_codepoint_label != nullptr) {
- *selection_codepoint_label = label_span;
- }
-
- if (selection_label_spans != nullptr) {
- for (int i = 0; i < label_to_selection_.size(); ++i) {
- CodepointSpan span;
- status = LabelToSpan(i, output_tokens, &span);
- if (!status) {
- TC_LOG(ERROR) << "Could not convert label to span: " << i;
- return false;
- }
- selection_label_spans->push_back(span);
- }
- }
-
- if (classification_label != nullptr) {
- *classification_label = CollectionToLabel(label_collection);
- }
-
- return true;
-}
-
-bool FeatureProcessor::GetFeaturesAndLabels(
- const std::string& context, CodepointSpan input_span,
- CodepointSpan label_span, const std::string& label_collection,
- std::vector<std::vector<std::pair<int, float>>>* features,
- std::vector<float>* extra_features,
- std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
- CodepointSpan* selection_codepoint_label, int* classification_label) const {
- if (features == nullptr) {
- return false;
- }
- if (extra_features == nullptr) {
- return false;
- }
-
- std::vector<nlp_core::FeatureVector> feature_vectors;
- bool result = GetFeaturesAndLabels(
- context, input_span, label_span, label_collection, &feature_vectors,
- extra_features, selection_label_spans, selection_label,
- selection_codepoint_label, classification_label);
-
- if (!result) {
- return false;
- }
-
- features->clear();
- for (int i = 0; i < feature_vectors.size(); ++i) {
- features->emplace_back();
- for (int j = 0; j < feature_vectors[i].size(); ++j) {
- nlp_core::FloatFeatureValue feature_value(feature_vectors[i].value(j));
- (*features)[i].push_back({feature_value.id, feature_value.weight});
- }
- }
-
- return true;
-}
-
-bool FeatureProcessor::ComputeFeatures(
- int click_pos, const std::vector<Token>& tokens,
- CodepointSpan selected_span, std::vector<nlp_core::FeatureVector>* features,
- std::vector<float>* extra_features,
- std::vector<Token>* output_tokens) const {
- int dropout_left = 0;
- int dropout_right = 0;
- if (options_.context_dropout_probability() > 0.0) {
- // Determine how much context to drop.
- bool status = GetContextDropoutRange(&dropout_left, &dropout_right);
- if (!status) {
- return false;
- }
- }
-
- int feature_index = 0;
- output_tokens->reserve(options_.context_size() * 2 + 1);
- const int num_extra_features =
- static_cast<int>(options_.extract_case_feature()) +
- static_cast<int>(options_.extract_selection_mask_feature());
- extra_features->reserve((options_.context_size() * 2 + 1) *
- num_extra_features);
- for (int i = click_pos - options_.context_size();
- i <= click_pos + options_.context_size(); ++i, ++feature_index) {
- std::vector<int> sparse_features;
- std::vector<float> dense_features;
-
- const bool is_valid_token = i >= 0 && i < tokens.size();
-
- bool is_dropped = false;
- if (options_.context_dropout_probability() > 0.0) {
- if (i < click_pos - options_.context_size() + dropout_left) {
- is_dropped = true;
- } else if (i > click_pos + options_.context_size() - dropout_right) {
- is_dropped = true;
- }
- }
-
- if (is_valid_token && !is_dropped) {
- Token token(tokens[i]);
- token.is_in_span = token.start >= selected_span.first &&
- token.end <= selected_span.second;
- feature_extractor_.Extract(token, &sparse_features, &dense_features);
- output_tokens->push_back(tokens[i]);
- } else {
- feature_extractor_.Extract(Token(), &sparse_features, &dense_features);
- // This adds an empty string for each missing context token to exactly
- // match the input tokens to the network.
- output_tokens->emplace_back();
- }
-
- for (int feature_id : sparse_features) {
- const int64 feature_value =
- nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
- .discrete_value;
- (*features)[feature_index].add(
- const_cast<nlp_core::NumericFeatureType*>(&feature_type_),
- feature_value);
- }
-
- for (float value : dense_features) {
- extra_features->push_back(value);
- }
+ selection_label_spans->push_back(span);
}
return true;
}
@@ -680,25 +508,129 @@
}
}
-bool FeatureProcessor::GetContextDropoutRange(int* dropout_left,
- int* dropout_right) const {
- std::uniform_real_distribution<> uniform01_draw(0, 1);
- if (uniform01_draw(*random_) < options_.context_dropout_probability()) {
- if (options_.use_variable_context_dropout()) {
- std::uniform_int_distribution<> uniform_context_draw(
- 0, options_.context_size());
- // Select how much to drop in the range: [zero; context size]
- *dropout_left = uniform_context_draw(*random_);
- *dropout_right = uniform_context_draw(*random_);
- } else {
- // Drop all context.
+void FeatureProcessor::TokenizeAndFindClick(const std::string& context,
+ CodepointSpan input_span,
+ std::vector<Token>* tokens,
+ int* click_pos) const {
+ TC_CHECK(tokens != nullptr);
+ *tokens = Tokenize(context);
+
+ if (options_.split_tokens_on_selection_boundaries()) {
+ internal::SplitTokensOnSelectionBoundaries(input_span, tokens);
+ }
+
+ if (options_.only_use_line_with_click()) {
+ internal::StripTokensFromOtherLines(context, input_span, tokens);
+ }
+
+ int local_click_pos;
+ if (click_pos == nullptr) {
+ click_pos = &local_click_pos;
+ }
+ *click_pos = FindCenterToken(input_span, *tokens);
+}
+
+namespace internal {
+
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+ std::vector<Token>* tokens, int* click_pos) {
+ int right_context_needed = relative_click_span.second + context_size;
+ if (*click_pos + right_context_needed + 1 >= tokens->size()) {
+ // Pad with at most context_size tokens.
+ const int num_pad_tokens = std::min(
+ context_size, static_cast<int>(*click_pos + right_context_needed + 1 -
+ tokens->size()));
+ std::vector<Token> pad_tokens(num_pad_tokens);
+ tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
+ } else if (*click_pos + right_context_needed + 1 < tokens->size() - 1) {
+ // Strip unused tokens.
+ auto it = tokens->begin();
+ std::advance(it, *click_pos + right_context_needed + 1);
+ tokens->erase(it, tokens->end());
+ }
+
+ int left_context_needed = relative_click_span.first + context_size;
+ if (*click_pos < left_context_needed) {
+ // Pad with at most context_size tokens.
+ const int num_pad_tokens =
+ std::min(context_size, left_context_needed - *click_pos);
+ std::vector<Token> pad_tokens(num_pad_tokens);
+ tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
+ *click_pos += num_pad_tokens;
+ } else if (*click_pos > left_context_needed) {
+ // Strip unused tokens.
+ auto it = tokens->begin();
+ std::advance(it, *click_pos - left_context_needed);
+ *click_pos -= it - tokens->begin();
+ tokens->erase(tokens->begin(), it);
+ }
+}
+
+} // namespace internal
+
+bool FeatureProcessor::ExtractFeatures(
+ const std::string& context, CodepointSpan input_span,
+ TokenSpan relative_click_span, const FeatureVectorFn& feature_vector_fn,
+ int feature_vector_size, std::vector<Token>* tokens, int* click_pos,
+ std::unique_ptr<CachedFeatures>* cached_features) const {
+ TokenizeAndFindClick(context, input_span, tokens, click_pos);
+
+ // If the default click method fails, let's try to do sub-token
+ // matching before we fail.
+ if (*click_pos == kInvalidIndex) {
+ *click_pos = internal::CenterTokenFromClick(input_span, *tokens);
+ if (*click_pos == kInvalidIndex) {
return false;
}
- } else {
- *dropout_left = 0;
- *dropout_right = 0;
}
+
+ internal::StripOrPadTokens(relative_click_span, options_.context_size(),
+ tokens, click_pos);
+
+ if (options_.min_supported_codepoint_ratio() > 0) {
+ const float supported_codepoint_ratio =
+ SupportedCodepointsRatio(*click_pos, *tokens);
+ if (supported_codepoint_ratio < options_.min_supported_codepoint_ratio()) {
+ TC_LOG(INFO) << "Not enough supported codepoints in the context: "
+ << supported_codepoint_ratio;
+ return false;
+ }
+ }
+
+ std::vector<std::vector<int>> sparse_features(tokens->size());
+ std::vector<std::vector<float>> dense_features(tokens->size());
+ for (int i = 0; i < tokens->size(); ++i) {
+ const Token& token = (*tokens)[i];
+ if (!feature_extractor_.Extract(token, token.IsContainedInSpan(input_span),
+ &(sparse_features[i]),
+ &(dense_features[i]))) {
+ TC_LOG(ERROR) << "Could not extract token's features: " << token;
+ return false;
+ }
+ }
+
+ cached_features->reset(new CachedFeatures(
+ *tokens, options_.context_size(), sparse_features, dense_features,
+ feature_vector_fn, feature_vector_size));
+
+ if (*cached_features == nullptr) {
+ return false;
+ }
+
+ if (options_.feature_version() == 0) {
+ (*cached_features)
+ ->SetV0FeatureMode(feature_vector_size -
+ feature_extractor_.DenseFeaturesCount());
+ }
+
return true;
}
+int FeatureProcessor::PadContext(std::vector<Token>* tokens) const {
+ std::vector<Token> pad_tokens(options_.context_size());
+ tokens->insert(tokens->begin(), pad_tokens.begin(), pad_tokens.end());
+ tokens->insert(tokens->end(), pad_tokens.begin(), pad_tokens.end());
+ return options_.context_size();
+}
+
} // namespace libtextclassifier
diff --git a/smartselect/feature-processor.h b/smartselect/feature-processor.h
index 85043e3..2f1e530 100644
--- a/smartselect/feature-processor.h
+++ b/smartselect/feature-processor.h
@@ -24,16 +24,25 @@
#include <string>
#include <vector>
-#include "common/feature-extractor.h"
+#include "smartselect/cached-features.h"
#include "smartselect/text-classification-model.pb.h"
#include "smartselect/token-feature-extractor.h"
#include "smartselect/tokenizer.h"
#include "smartselect/types.h"
+#include "util/base/logging.h"
namespace libtextclassifier {
constexpr int kInvalidLabel = -1;
+// Maps a vector of sparse features and a vector of dense features to a vector
+// of features that combines both.
+// The output is written to the memory location pointed to by the last float*
+// argument.
+// Returns true on success, false on failure.
+using FeatureVectorFn = std::function<bool(const std::vector<int>&,
+ const std::vector<float>&, float*)>;
+
namespace internal {
// Parses the serialized protocol buffer.
@@ -61,23 +70,27 @@
int CenterTokenFromMiddleOfSelection(
CodepointSpan span, const std::vector<Token>& selectable_tokens);
+// Strips the tokens from the tokens vector that are not used for feature
+// extraction because they are out of scope, or pads them so that there are
+// enough tokens in the required context_size for all inferences with a click
+// in relative_click_span.
+void StripOrPadTokens(TokenSpan relative_click_span, int context_size,
+ std::vector<Token>* tokens, int* click_pos);
+
} // namespace internal
TokenSpan CodepointSpanToTokenSpan(const std::vector<Token>& selectable_tokens,
CodepointSpan codepoint_span);
-// Takes care of preparing features for the FFModel.
+// Takes care of preparing features for the span prediction model.
class FeatureProcessor {
public:
explicit FeatureProcessor(const FeatureProcessorOptions& options)
- : options_(options),
- feature_extractor_(
+ : feature_extractor_(
internal::BuildTokenFeatureExtractorOptions(options)),
- feature_type_(FeatureProcessor::kFeatureTypeName,
- options.num_buckets()),
+ options_(options),
tokenizer_({options.tokenization_codepoint_config().begin(),
- options.tokenization_codepoint_config().end()}),
- random_(new std::mt19937(std::random_device()())) {
+ options.tokenization_codepoint_config().end()}) {
MakeLabelMaps();
PrepareSupportedCodepointRanges(
{options.supported_codepoint_ranges().begin(),
@@ -91,38 +104,12 @@
// Tokenizes the input string using the selected tokenization method.
std::vector<Token> Tokenize(const std::string& utf8_text) const;
- bool GetFeatures(const std::string& context, CodepointSpan input_span,
- std::vector<nlp_core::FeatureVector>* features,
- std::vector<float>* extra_features,
- std::vector<CodepointSpan>* selection_label_spans) const;
-
- // NOTE: If dropout is on, subsequent calls of this function with the same
- // arguments might return different results.
- bool GetFeaturesAndLabels(const std::string& context,
- CodepointSpan input_span, CodepointSpan label_span,
- const std::string& label_collection,
- std::vector<nlp_core::FeatureVector>* features,
- std::vector<float>* extra_features,
- std::vector<CodepointSpan>* selection_label_spans,
- int* selection_label,
- CodepointSpan* selection_codepoint_label,
- int* classification_label) const;
-
- // Same as above but uses std::vector instead of FeatureVector.
- // NOTE: If dropout is on, subsequent calls of this function with the same
- // arguments might return different results.
- bool GetFeaturesAndLabels(
- const std::string& context, CodepointSpan input_span,
- CodepointSpan label_span, const std::string& label_collection,
- std::vector<std::vector<std::pair<int, float>>>* features,
- std::vector<float>* extra_features,
- std::vector<CodepointSpan>* selection_label_spans, int* selection_label,
- CodepointSpan* selection_codepoint_label,
- int* classification_label) const;
-
// Converts a label into a token span.
bool LabelToTokenSpan(int label, TokenSpan* token_span) const;
+ // Gets the total number of selection labels.
+ int GetSelectionLabelCount() const { return label_to_selection_.size(); }
+
// Gets the string value for given collection label.
std::string LabelToCollection(int label) const;
@@ -130,17 +117,35 @@
int NumCollections() const { return collection_to_label_.size(); }
// Gets the name of the default collection.
- std::string GetDefaultCollection() const {
- return options_.collections(options_.default_collection());
+ std::string GetDefaultCollection() const;
+
+ const FeatureProcessorOptions& GetOptions() const { return options_; }
+
+ // Tokenizes the context and input span, and finds the click position.
+ void TokenizeAndFindClick(const std::string& context,
+ CodepointSpan input_span,
+ std::vector<Token>* tokens, int* click_pos) const;
+
+ // Extracts features as a CachedFeatures object that can be used for repeated
+ // inference over token spans in the given context.
+ bool ExtractFeatures(const std::string& context, CodepointSpan input_span,
+ TokenSpan relative_click_span,
+ const FeatureVectorFn& feature_vector_fn,
+ int feature_vector_size, std::vector<Token>* tokens,
+ int* click_pos,
+ std::unique_ptr<CachedFeatures>* cached_features) const;
+
+ // Fills selection_label_spans with CodepointSpans that correspond to the
+ // selection labels. The CodepointSpans are based on the codepoint ranges of
+ // given tokens.
+ bool SelectionLabelSpans(
+ VectorSpan<Token> tokens,
+ std::vector<CodepointSpan>* selection_label_spans) const;
+
+ int DenseFeaturesCount() const {
+ return feature_extractor_.DenseFeaturesCount();
}
- FeatureProcessorOptions GetOptions() const { return options_; }
-
- int GetSelectionLabelCount() const { return label_to_selection_.size(); }
-
- // Sets the source of randomness.
- void SetRandom(std::mt19937* new_random) { random_.reset(new_random); }
-
protected:
// Represents a codepoint range [start, end).
struct CodepointRange {
@@ -151,23 +156,6 @@
: start(arg_start), end(arg_end) {}
};
- // Extracts features for given word.
- std::vector<int> GetWordFeatures(const std::string& word) const;
-
- // NOTE: If dropout is on, subsequent calls of this function with the same
- // arguments might return different results.
- bool ComputeFeatures(int click_pos,
- const std::vector<Token>& selectable_tokens,
- CodepointSpan selected_span,
- std::vector<nlp_core::FeatureVector>* features,
- std::vector<float>* extra_features,
- std::vector<Token>* output_tokens) const;
-
- // Helper function that computes how much left context and how much right
- // context should be dropped. Uses a mutable random_ member as a source of
- // randomness.
- bool GetContextDropoutRange(int* dropout_left, int* dropout_right) const;
-
// Returns the class id corresponding to the given string collection
// identifier. There is a catch-all class id that the function returns for
// unknown collections.
@@ -185,7 +173,7 @@
// Converts a label into a span of codepoint indices corresponding to it
// given output_tokens.
- bool LabelToSpan(int label, const std::vector<Token>& output_tokens,
+ bool LabelToSpan(int label, const VectorSpan<Token>& output_tokens,
CodepointSpan* span) const;
// Converts a span to the corresponding label given output_tokens.
@@ -195,11 +183,6 @@
// Converts a token span to the corresponding label.
int TokenSpanToLabel(const std::pair<TokenIndex, TokenIndex>& span) const;
- // Finds the center token index in tokens vector, using the method defined
- // in options_.
- int FindCenterToken(CodepointSpan span,
- const std::vector<Token>& tokens) const;
-
void PrepareSupportedCodepointRanges(
const std::vector<FeatureProcessorOptions::CodepointRange>&
codepoint_range_configs);
@@ -212,14 +195,18 @@
// Returns true if given codepoint is supported.
bool IsCodepointSupported(int codepoint) const;
+ // Finds the center token index in tokens vector, using the method defined
+ // in options_.
+ int FindCenterToken(CodepointSpan span,
+ const std::vector<Token>& tokens) const;
+
+ // Pads tokens with options_.context_size() padding tokens on both sides.
+ int PadContext(std::vector<Token>* tokens) const;
+
+ const TokenFeatureExtractor feature_extractor_;
+
private:
- FeatureProcessorOptions options_;
-
- TokenFeatureExtractor feature_extractor_;
-
- static const char* const kFeatureTypeName;
-
- nlp_core::NumericFeatureType feature_type_;
+ const FeatureProcessorOptions options_;
// Mapping between token selection spans and labels ids.
std::map<TokenSpan, int> selection_to_label_;
@@ -233,9 +220,6 @@
// Codepoint ranges that define what codepoints are supported by the model.
// NOTE: Must be sorted.
std::vector<CodepointRange> supported_codepoint_ranges_;
-
- // Source of randomness.
- mutable std::unique_ptr<std::mt19937> random_;
};
} // namespace libtextclassifier
diff --git a/smartselect/text-classification-model.cc b/smartselect/text-classification-model.cc
index b21614d..ca0484f 100644
--- a/smartselect/text-classification-model.cc
+++ b/smartselect/text-classification-model.cc
@@ -91,6 +91,49 @@
sharing_options_ = selection_params_->GetSharingModelOptions();
}
+namespace {
+
+// Converts sparse features vector to nlp_core::FeatureVector.
+void SparseFeaturesToFeatureVector(
+ const std::vector<int> sparse_features,
+ const nlp_core::NumericFeatureType& feature_type,
+ nlp_core::FeatureVector* result) {
+ for (int feature_id : sparse_features) {
+ const int64 feature_value =
+ nlp_core::FloatFeatureValue(feature_id, 1.0 / sparse_features.size())
+ .discrete_value;
+ result->add(const_cast<nlp_core::NumericFeatureType*>(&feature_type),
+ feature_value);
+ }
+}
+
+// Returns a function that can be used for mapping sparse and dense features
+// to a float feature vector.
+// NOTE: The network object needs to be available at the time when the returned
+// function object is used.
+FeatureVectorFn CreateFeatureVectorFn(const EmbeddingNetwork& network,
+ int sparse_embedding_size) {
+ const nlp_core::NumericFeatureType feature_type("chargram_continuous", 0);
+ return [&network, sparse_embedding_size, feature_type](
+ const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features, float* embedding) {
+ nlp_core::FeatureVector feature_vector;
+ SparseFeaturesToFeatureVector(sparse_features, feature_type,
+ &feature_vector);
+
+ if (network.GetEmbedding(feature_vector, 0, embedding)) {
+ for (int i = 0; i < dense_features.size(); i++) {
+ embedding[sparse_embedding_size + i] = dense_features[i];
+ }
+ return true;
+ } else {
+ return false;
+ }
+ };
+}
+
+} // namespace
+
bool TextClassificationModel::LoadModels(int fd) {
MmapHandle mmap_handle = MmapFile(fd);
if (!mmap_handle.ok()) {
@@ -111,6 +154,8 @@
selection_network_.reset(new EmbeddingNetwork(selection_params_.get()));
selection_feature_processor_.reset(
new FeatureProcessor(selection_params_->GetFeatureProcessorOptions()));
+ selection_feature_fn_ = CreateFeatureVectorFn(
+ *selection_network_, selection_network_->EmbeddingSize(0));
model_data += selection_model_length;
uint32 sharing_model_length =
@@ -125,25 +170,39 @@
sharing_network_.reset(new EmbeddingNetwork(sharing_params_.get()));
sharing_feature_processor_.reset(
new FeatureProcessor(sharing_params_->GetFeatureProcessorOptions()));
+ sharing_feature_fn_ = CreateFeatureVectorFn(
+ *sharing_network_, sharing_network_->EmbeddingSize(0));
return true;
}
EmbeddingNetwork::Vector TextClassificationModel::InferInternal(
const std::string& context, CodepointSpan span,
- const FeatureProcessor& feature_processor, const EmbeddingNetwork* network,
+ const FeatureProcessor& feature_processor, const EmbeddingNetwork& network,
+ const FeatureVectorFn& feature_vector_fn,
std::vector<CodepointSpan>* selection_label_spans) const {
- std::vector<FeatureVector> features;
- std::vector<float> extra_features;
- const bool features_computed = feature_processor.GetFeatures(
- context, span, &features, &extra_features, selection_label_spans);
-
- EmbeddingNetwork::Vector scores;
- if (!features_computed) {
- TC_LOG(ERROR) << "Features not computed";
- return scores;
+ std::vector<Token> tokens;
+ int click_pos;
+ std::unique_ptr<CachedFeatures> cached_features;
+ int embedding_size = network.EmbeddingSize(0);
+ if (!feature_processor.ExtractFeatures(
+ context, span, /*relative_click_span=*/{0, 0},
+ CreateFeatureVectorFn(network, embedding_size),
+ embedding_size + feature_processor.DenseFeaturesCount(), &tokens,
+ &click_pos, &cached_features)) {
+ TC_LOG(ERROR) << "Could not extract features.";
+ return {};
}
- network->ComputeFinalScores(features, extra_features, &scores);
+
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+ if (!cached_features->Get(click_pos, &features, &output_tokens)) {
+ TC_LOG(ERROR) << "Could not extract features.";
+ return {};
+ }
+
+ std::vector<float> scores;
+ network.ComputeLogits(features, &scores);
return scores;
}
@@ -161,6 +220,15 @@
return click_indices;
}
+ const UnicodeText context_unicode =
+ UTF8ToUnicodeText(context, /*do_copy=*/false);
+ const int context_length =
+ std::distance(context_unicode.begin(), context_unicode.end());
+ if (std::get<0>(click_indices) >= context_length ||
+ std::get<1>(click_indices) > context_length) {
+ return click_indices;
+ }
+
CodepointSpan result;
if (selection_options_.enforce_symmetry()) {
result = SuggestSelectionSymmetrical(context, click_indices);
@@ -176,21 +244,12 @@
return result;
}
-std::pair<CodepointSpan, float>
-TextClassificationModel::SuggestSelectionInternal(
- const std::string& context, CodepointSpan click_indices) const {
- if (!initialized_) {
- TC_LOG(ERROR) << "Not initialized";
- return {click_indices, -1.0};
- }
+namespace {
- std::vector<CodepointSpan> selection_label_spans;
- EmbeddingNetwork::Vector scores =
- InferInternal(context, click_indices, *selection_feature_processor_,
- selection_network_.get(), &selection_label_spans);
-
+std::pair<CodepointSpan, float> BestSelectionSpan(
+ CodepointSpan original_click_indices, const std::vector<float>& scores,
+ const std::vector<CodepointSpan>& selection_label_spans) {
if (!scores.empty()) {
- scores = nlp_core::ComputeSoftmax(scores);
const int prediction =
std::max_element(scores.begin(), scores.end()) - scores.begin();
std::pair<CodepointIndex, CodepointIndex> selection =
@@ -200,15 +259,34 @@
TC_LOG(ERROR) << "Invalid indices predicted, returning input: "
<< prediction << " " << selection.first << " "
<< selection.second;
- return {click_indices, -1.0};
+ return {original_click_indices, -1.0};
}
return {{selection.first, selection.second}, scores[prediction]};
} else {
TC_LOG(ERROR) << "Returning default selection: scores.size() = "
<< scores.size();
+ return {original_click_indices, -1.0};
+ }
+}
+
+} // namespace
+
+std::pair<CodepointSpan, float>
+TextClassificationModel::SuggestSelectionInternal(
+ const std::string& context, CodepointSpan click_indices) const {
+ if (!initialized_) {
+ TC_LOG(ERROR) << "Not initialized";
return {click_indices, -1.0};
}
+
+ std::vector<CodepointSpan> selection_label_spans;
+ EmbeddingNetwork::Vector scores = InferInternal(
+ context, click_indices, *selection_feature_processor_,
+ *selection_network_, selection_feature_fn_, &selection_label_spans);
+ scores = nlp_core::ComputeSoftmax(scores);
+
+ return BestSelectionSpan(click_indices, scores, selection_label_spans);
}
namespace {
@@ -245,27 +323,49 @@
// selection.
CodepointSpan TextClassificationModel::SuggestSelectionSymmetrical(
const std::string& context, CodepointSpan click_indices) const {
- std::vector<Token> tokens = selection_feature_processor_->Tokenize(context);
- internal::StripTokensFromOtherLines(context, click_indices, &tokens);
-
- // const int click_index = GetClickTokenIndex(tokens, click_indices);
- const int click_index = internal::CenterTokenFromClick(click_indices, tokens);
- if (click_index == kInvalidIndex) {
+ const int symmetry_context_size = selection_options_.symmetry_context_size();
+ std::vector<Token> tokens;
+ std::unique_ptr<CachedFeatures> cached_features;
+ int click_index;
+ int embedding_size = selection_network_->EmbeddingSize(0);
+ if (!selection_feature_processor_->ExtractFeatures(
+ context, click_indices, /*relative_click_span=*/
+ {symmetry_context_size, symmetry_context_size + 1},
+ selection_feature_fn_,
+ embedding_size + selection_feature_processor_->DenseFeaturesCount(),
+ &tokens, &click_index, &cached_features)) {
+ TC_LOG(ERROR) << "Couldn't ExtractFeatures.";
return click_indices;
}
- const int symmetry_context_size = selection_options_.symmetry_context_size();
-
// Scan in the symmetry context for selection span proposals.
std::vector<std::pair<CodepointSpan, float>> proposals;
+
for (int i = -symmetry_context_size; i < symmetry_context_size + 1; ++i) {
const int token_index = click_index + i;
- if (token_index >= 0 && token_index < tokens.size()) {
+ if (token_index >= 0 && token_index < tokens.size() &&
+ !tokens[token_index].is_padding) {
float score;
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+
CodepointSpan span;
- std::tie(span, score) = SuggestSelectionInternal(
- context, {tokens[token_index].start, tokens[token_index].end});
- proposals.push_back({span, score});
+ if (cached_features->Get(token_index, &features, &output_tokens)) {
+ std::vector<float> scores;
+ selection_network_->ComputeLogits(features, &scores);
+
+ std::vector<CodepointSpan> selection_label_spans;
+ if (selection_feature_processor_->SelectionLabelSpans(
+ output_tokens, &selection_label_spans)) {
+ scores = nlp_core::ComputeSoftmax(scores);
+ std::tie(span, score) =
+ BestSelectionSpan(click_indices, scores, selection_label_spans);
+ if (span.first != kInvalidIndex && span.second != kInvalidIndex &&
+ score >= 0) {
+ proposals.push_back({span, score});
+ }
+ }
+ }
}
}
@@ -315,6 +415,13 @@
return {};
}
+ if (std::get<0>(selection_indices) >= std::get<1>(selection_indices)) {
+ TC_LOG(ERROR) << "Trying to run ClassifyText with invalid indices: "
+ << std::get<0>(selection_indices) << " "
+ << std::get<1>(selection_indices);
+ return {};
+ }
+
if (hint_flags & SELECTION_IS_URL &&
sharing_options_.always_accept_url_hint()) {
return {{kUrlHintCollection, 1.0}};
@@ -327,7 +434,7 @@
EmbeddingNetwork::Vector scores =
InferInternal(context, selection_indices, *sharing_feature_processor_,
- sharing_network_.get(), nullptr);
+ *sharing_network_, sharing_feature_fn_, nullptr);
if (scores.empty()) {
TC_LOG(ERROR) << "Using default class";
return {};
diff --git a/smartselect/text-classification-model.h b/smartselect/text-classification-model.h
index 9f17ef5..ae2049b 100644
--- a/smartselect/text-classification-model.h
+++ b/smartselect/text-classification-model.h
@@ -88,7 +88,8 @@
nlp_core::EmbeddingNetwork::Vector InferInternal(
const std::string& context, CodepointSpan span,
const FeatureProcessor& feature_processor,
- const nlp_core::EmbeddingNetwork* network,
+ const nlp_core::EmbeddingNetwork& network,
+ const FeatureVectorFn& feature_vector_fn,
std::vector<CodepointSpan>* selection_label_spans) const;
// Returns a selection suggestion with a score.
@@ -104,9 +105,11 @@
std::unique_ptr<ModelParams> selection_params_;
std::unique_ptr<FeatureProcessor> selection_feature_processor_;
std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
+ FeatureVectorFn selection_feature_fn_;
std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
std::unique_ptr<ModelParams> sharing_params_;
std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
+ FeatureVectorFn sharing_feature_fn_;
std::set<int> punctuation_to_strip_;
};
diff --git a/smartselect/text-classification-model.proto b/smartselect/text-classification-model.proto
index ecf0fc6..7a4c9f1 100644
--- a/smartselect/text-classification-model.proto
+++ b/smartselect/text-classification-model.proto
@@ -159,6 +159,16 @@
// Minimum ratio of supported codepoints in the input context. If the ratio
// is lower than this, the feature computation will fail.
optional float min_supported_codepoint_ratio = 24 [default = 0.0];
+
+ // Used for versioning the format of features the model expects.
+ // - feature_version == 0:
+ // For each token the features consist of:
+ // - chargram embeddings
+ // - dense features
+ // Chargram embeddings for tokens are concatenated together first,
+ // and at the end, the dense features for the tokens are concatenated
+ // to them. So the resulting feature vector has two regions.
+ optional int32 feature_version = 25 [default = 0];
};
extend nlp_core.EmbeddingNetworkProto {
diff --git a/smartselect/token-feature-extractor.cc b/smartselect/token-feature-extractor.cc
index 1413385..6013ef3 100644
--- a/smartselect/token-feature-extractor.cc
+++ b/smartselect/token-feature-extractor.cc
@@ -207,7 +207,7 @@
return result;
}
-bool TokenFeatureExtractor::Extract(const Token& token,
+bool TokenFeatureExtractor::Extract(const Token& token, bool is_in_span,
std::vector<int>* sparse_features,
std::vector<float>* dense_features) const {
if (sparse_features == nullptr || dense_features == nullptr) {
@@ -235,7 +235,7 @@
}
if (options_.extract_selection_mask_feature) {
- if (token.is_in_span) {
+ if (is_in_span) {
dense_features->push_back(1.0);
} else {
if (options_.unicode_aware_features) {
@@ -270,23 +270,4 @@
return true;
}
-bool TokenFeatureExtractor::Extract(
- const std::vector<Token>& tokens,
- std::vector<std::vector<int>>* sparse_features,
- std::vector<std::vector<float>>* dense_features) const {
- if (sparse_features == nullptr || dense_features == nullptr) {
- return false;
- }
-
- sparse_features->resize(tokens.size());
- dense_features->resize(tokens.size());
- for (size_t i = 0; i < tokens.size(); i++) {
- if (!Extract(tokens[i], &((*sparse_features)[i]),
- &((*dense_features)[i]))) {
- return false;
- }
- }
- return true;
-}
-
} // namespace libtextclassifier
diff --git a/smartselect/token-feature-extractor.h b/smartselect/token-feature-extractor.h
index e606e91..8502199 100644
--- a/smartselect/token-feature-extractor.h
+++ b/smartselect/token-feature-extractor.h
@@ -59,17 +59,20 @@
explicit TokenFeatureExtractor(const TokenFeatureExtractorOptions& options);
// Extracts features from a token.
+ // - is_in_span is a bool indicating whether the token is a part of the
+ // selection span (true) or not (false).
// - sparse_features are indices into a sparse feature vector of size
// options.num_buckets which are set to 1.0 (others are implicitly 0.0).
// - dense_features are values of a dense feature vector of size 0-2
// (depending on the options) for the token
- bool Extract(const Token& token, std::vector<int>* sparse_features,
+ bool Extract(const Token& token, bool is_in_span,
+ std::vector<int>* sparse_features,
std::vector<float>* dense_features) const;
- // Convenience method that sequentially applies Extract to each Token.
- bool Extract(const std::vector<Token>& tokens,
- std::vector<std::vector<int>>* sparse_features,
- std::vector<std::vector<float>>* dense_features) const;
+ int DenseFeaturesCount() const {
+ return options_.extract_case_feature +
+ options_.extract_selection_mask_feature + regex_patterns_.size();
+ }
protected:
// Hashes given token to given number of buckets.
diff --git a/smartselect/types.h b/smartselect/types.h
index 9f07f91..443e3ac 100644
--- a/smartselect/types.h
+++ b/smartselect/types.h
@@ -48,43 +48,31 @@
CodepointIndex start;
CodepointIndex end;
- // Whether the token was in the input span.
- bool is_in_span;
-
// Whether the token is a padding token.
bool is_padding;
// Default constructor constructs the padding-token.
Token()
- : value(""),
- start(kInvalidIndex),
- end(kInvalidIndex),
- is_in_span(false),
- is_padding(true) {}
+ : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {}
Token(const std::string& arg_value, CodepointIndex arg_start,
CodepointIndex arg_end)
- : Token(arg_value, arg_start, arg_end, false) {}
-
- Token(const std::string& arg_value, CodepointIndex arg_start,
- CodepointIndex arg_end, bool is_in_span)
- : value(arg_value),
- start(arg_start),
- end(arg_end),
- is_in_span(is_in_span),
- is_padding(false) {}
+ : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {}
bool operator==(const Token& other) const {
return value == other.value && start == other.start && end == other.end &&
- is_in_span == other.is_in_span && is_padding == other.is_padding;
+ is_padding == other.is_padding;
+ }
+
+ bool IsContainedInSpan(CodepointSpan span) const {
+ return start >= span.first && end <= span.second;
}
};
// Pretty-printing function for Token.
inline std::ostream& operator<<(std::ostream& os, const Token& token) {
return os << "Token(\"" << token.value << "\", " << token.start << ", "
- << token.end << ", is_in_span=" << token.is_in_span
- << ", is_padding=" << token.is_padding << ")";
+ << token.end << ", is_padding=" << token.is_padding << ")";
}
} // namespace libtextclassifier
diff --git a/tests/cached-features_test.cc b/tests/cached-features_test.cc
new file mode 100644
index 0000000..b456816
--- /dev/null
+++ b/tests/cached-features_test.cc
@@ -0,0 +1,149 @@
+/*
+ * Copyright (C) 2017 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "smartselect/cached-features.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace libtextclassifier {
+namespace {
+
+class TestingCachedFeatures : public CachedFeatures {  // Test-only subclass of CachedFeatures.
+ public:
+ using CachedFeatures::CachedFeatures;  // Inherit the base-class constructors.
+ using CachedFeatures::RemapV0FeatureVector;  // Re-declared public so the tests can call it (presumably non-public in the base; confirm).
+};
+
+TEST(CachedFeaturesTest, Simple) {  // Get(2) must yield the 5-float vectors of tokens 0..4, in order.
+ std::vector<Token> tokens;
+ tokens.push_back(Token());  // Default-constructed tokens are padding tokens.
+ tokens.push_back(Token());
+ tokens.push_back(Token("Hello", 0, 1));
+ tokens.push_back(Token("World", 1, 2));
+ tokens.push_back(Token("today!", 2, 3));
+ tokens.push_back(Token());  // Two more padding tokens after the real ones.
+ tokens.push_back(Token());
+
+ std::vector<std::vector<int>> sparse_features(tokens.size());
+ for (int i = 0; i < sparse_features.size(); ++i) {
+ sparse_features[i].push_back(i);  // Token i carries the single sparse feature i.
+ }
+ std::vector<std::vector<float>> dense_features(tokens.size());
+ for (int i = 0; i < dense_features.size(); ++i) {
+ dense_features[i].push_back(-i);  // Token i carries the single dense feature -i.
+ }
+
+ TestingCachedFeatures feature_extractor(
+ tokens, /*context_size=*/2, sparse_features, dense_features,
+ [](const std::vector<int>& sparse_features,  // Stub callback: writes 5 floats per token.
+ const std::vector<float>& dense_features, float* features) {
+ features[0] = sparse_features[0];
+ features[1] = sparse_features[0];
+ features[2] = dense_features[0];
+ features[3] = dense_features[0];
+ features[4] = 123;  // Constant marker, checked for every token below.
+ return true;
+ },
+ 5);  // Number of floats the callback writes per token.
+
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+ EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));
+ for (int i = 0; i < 5; i++) {  // i indexes the 5 consecutive 5-float chunks in the returned span.
+ EXPECT_EQ(features[i * 5 + 0], i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 1], i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 2], -i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 3], -i) << "Feature " << i;
+ EXPECT_EQ(features[i * 5 + 4], 123) << "Feature " << i;
+ }
+}
+
+TEST(CachedFeaturesTest, InvalidInput) {  // Get() must return false for indices it cannot serve.
+ std::vector<Token> tokens;
+ tokens.push_back(Token());  // Default-constructed tokens are padding tokens.
+ tokens.push_back(Token());
+ tokens.push_back(Token("Hello", 0, 1));
+ tokens.push_back(Token("World", 1, 2));
+ tokens.push_back(Token("today!", 2, 3));
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+
+ std::vector<std::vector<int>> sparse_features(tokens.size());  // Per-token feature lists are left empty.
+ std::vector<std::vector<float>> dense_features(tokens.size());
+
+ TestingCachedFeatures feature_extractor(
+ tokens, /*context_size=*/2, sparse_features, dense_features,
+ [](const std::vector<int>& sparse_features,  // Stub callback: writes nothing and reports success.
+ const std::vector<float>& dense_features,
+ float* features) { return true; },
+ /*feature_vector_size=*/5);
+
+ VectorSpan<float> features;
+ VectorSpan<Token> output_tokens;
+ EXPECT_FALSE(feature_extractor.Get(-1000, &features, &output_tokens));  // Negative indices are rejected.
+ EXPECT_FALSE(feature_extractor.Get(-1, &features, &output_tokens));
+ EXPECT_FALSE(feature_extractor.Get(0, &features, &output_tokens));  // Presumably too close to the start for context_size=2; confirm.
+ EXPECT_TRUE(feature_extractor.Get(2, &features, &output_tokens));  // 2 and 4 are the accepted indices probed here.
+ EXPECT_TRUE(feature_extractor.Get(4, &features, &output_tokens));
+ EXPECT_FALSE(feature_extractor.Get(5, &features, &output_tokens));  // Presumably too close to the end for context_size=2; confirm.
+ EXPECT_FALSE(feature_extractor.Get(500, &features, &output_tokens));  // Far beyond the 7 tokens.
+}
+
+TEST(CachedFeaturesTest, RemapV0FeatureVector) {  // Checks the element order produced by RemapV0FeatureVector in modes 0 and 2.
+ std::vector<Token> tokens;
+ tokens.push_back(Token());  // Default-constructed tokens are padding tokens.
+ tokens.push_back(Token());
+ tokens.push_back(Token("Hello", 0, 1));
+ tokens.push_back(Token("World", 1, 2));
+ tokens.push_back(Token("today!", 2, 3));
+ tokens.push_back(Token());
+ tokens.push_back(Token());
+
+ std::vector<std::vector<int>> sparse_features(tokens.size());  // Per-token feature lists are left empty.
+ std::vector<std::vector<float>> dense_features(tokens.size());
+
+ TestingCachedFeatures feature_extractor(
+ tokens, /*context_size=*/2, sparse_features, dense_features,
+ [](const std::vector<int>& sparse_features,  // Stub callback: writes nothing and reports success.
+ const std::vector<float>& dense_features,
+ float* features) { return true; },
+ /*feature_vector_size=*/5);
+
+ std::vector<float> features_orig(5 * 5);  // 25 floats: 5 chunks of 5.
+ for (int i = 0; i < features_orig.size(); i++) {
+ features_orig[i] = i;  // Each value equals its own index, so the expected vectors below show where elements moved.
+ }
+ VectorSpan<float> features;
+
+ feature_extractor.SetV0FeatureMode(0);  // Mode 0: the expected vector below is the input, unchanged.
+ features = VectorSpan<float>(features_orig);
+ feature_extractor.RemapV0FeatureVector(&features);
+ EXPECT_EQ(
+ std::vector<float>({0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12,
+ 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24}),
+ std::vector<float>(features.begin(), features.end()));
+
+ feature_extractor.SetV0FeatureMode(2);  // Mode 2: the first 2 values of every 5-float chunk come first,
+ features = VectorSpan<float>(features_orig);  // followed by the remaining 3 values of every chunk.
+ feature_extractor.RemapV0FeatureVector(&features);
+ EXPECT_EQ(std::vector<float>({0, 1, 5, 6, 10, 11, 15, 16, 20, 21, 2, 3, 4,
+ 7, 8, 9, 12, 13, 14, 17, 18, 19, 22, 23, 24}),
+ std::vector<float>(features.begin(), features.end()));
+}
+
+} // namespace
+} // namespace libtextclassifier
diff --git a/tests/feature-processor_test.cc b/tests/feature-processor_test.cc
index 88a93f3..e3a39e3 100644
--- a/tests/feature-processor_test.cc
+++ b/tests/feature-processor_test.cc
@@ -26,83 +26,83 @@
using testing::FloatEq;
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesMiddle) {
- std::vector<Token> tokens{Token("Hělló", 0, 5, false),
- Token("fěěbař@google.com", 6, 23, false),
- Token("heře!", 24, 29, false)};
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
internal::SplitTokensOnSelectionBoundaries({9, 12}, &tokens);
// clang-format off
EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5, false),
- Token("fěě", 6, 9, false),
- Token("bař", 9, 12, false),
- Token("@google.com", 12, 23, false),
- Token("heře!", 24, 29, false)}));
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař", 9, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
// clang-format on
}
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesBegin) {
- std::vector<Token> tokens{Token("Hělló", 0, 5, false),
- Token("fěěbař@google.com", 6, 23, false),
- Token("heře!", 24, 29, false)};
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
internal::SplitTokensOnSelectionBoundaries({6, 12}, &tokens);
// clang-format off
EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5, false),
- Token("fěěbař", 6, 12, false),
- Token("@google.com", 12, 23, false),
- Token("heře!", 24, 29, false)}));
+ {Token("Hělló", 0, 5),
+ Token("fěěbař", 6, 12),
+ Token("@google.com", 12, 23),
+ Token("heře!", 24, 29)}));
// clang-format on
}
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesEnd) {
- std::vector<Token> tokens{Token("Hělló", 0, 5, false),
- Token("fěěbař@google.com", 6, 23, false),
- Token("heře!", 24, 29, false)};
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
internal::SplitTokensOnSelectionBoundaries({9, 23}, &tokens);
// clang-format off
EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5, false),
- Token("fěě", 6, 9, false),
- Token("bař@google.com", 9, 23, false),
- Token("heře!", 24, 29, false)}));
+ {Token("Hělló", 0, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
// clang-format on
}
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesWhole) {
- std::vector<Token> tokens{Token("Hělló", 0, 5, false),
- Token("fěěbař@google.com", 6, 23, false),
- Token("heře!", 24, 29, false)};
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
internal::SplitTokensOnSelectionBoundaries({6, 23}, &tokens);
// clang-format off
EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hělló", 0, 5, false),
- Token("fěěbař@google.com", 6, 23, false),
- Token("heře!", 24, 29, false)}));
+ {Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)}));
// clang-format on
}
TEST(FeatureProcessorTest, SplitTokensOnSelectionBoundariesCrossToken) {
- std::vector<Token> tokens{Token("Hělló", 0, 5, false),
- Token("fěěbař@google.com", 6, 23, false),
- Token("heře!", 24, 29, false)};
+ std::vector<Token> tokens{Token("Hělló", 0, 5),
+ Token("fěěbař@google.com", 6, 23),
+ Token("heře!", 24, 29)};
internal::SplitTokensOnSelectionBoundaries({2, 9}, &tokens);
// clang-format off
EXPECT_THAT(tokens, ElementsAreArray(
- {Token("Hě", 0, 2, false),
- Token("lló", 2, 5, false),
- Token("fěě", 6, 9, false),
- Token("bař@google.com", 9, 23, false),
- Token("heře!", 24, 29, false)}));
+ {Token("Hě", 0, 2),
+ Token("lló", 2, 5),
+ Token("fěě", 6, 9),
+ Token("bař@google.com", 9, 23),
+ Token("heře!", 24, 29)}));
// clang-format on
}
@@ -269,98 +269,25 @@
EXPECT_EQ(label2, label3);
}
-TEST(FeatureProcessorTest, GetFeaturesWithContextDropout) {
- FeatureProcessorOptions options;
- options.set_num_buckets(10);
- options.set_context_size(7);
- options.set_max_selection_span(7);
- options.add_chargram_orders(1);
- options.set_tokenize_on_space(true);
- options.set_context_dropout_probability(0.5);
- options.set_use_variable_context_dropout(true);
- TokenizationCodepointRange* config =
- options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
- FeatureProcessor feature_processor(options);
-
- // Test that two subsequent runs with random context dropout produce
- // different features.
- feature_processor.SetRandom(new std::mt19937);
-
- std::vector<std::vector<std::pair<int, float>>> features;
- std::vector<std::vector<std::pair<int, float>>> features2;
- std::vector<float> extra_features;
- std::vector<CodepointSpan> selection_label_spans;
- int selection_label;
- CodepointSpan selection_codepoint_label;
- int classification_label;
- EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
- &features, &extra_features, &selection_label_spans, &selection_label,
- &selection_codepoint_label, &classification_label));
- EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
- &features2, &extra_features, &selection_label_spans, &selection_label,
- &selection_codepoint_label, &classification_label));
-
- EXPECT_NE(features, features2);
-}
-
-TEST(FeatureProcessorTest, GetFeaturesWithLongerContext) {
- FeatureProcessorOptions options;
- options.set_num_buckets(10);
- options.set_context_size(9);
- options.set_max_selection_span(7);
- options.add_chargram_orders(1);
- options.set_tokenize_on_space(true);
- TokenizationCodepointRange* config =
- options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
- FeatureProcessor feature_processor(options);
-
- std::vector<std::vector<std::pair<int, float>>> features;
- std::vector<float> extra_features;
- std::vector<CodepointSpan> selection_label_spans;
- int selection_label;
- CodepointSpan selection_codepoint_label;
- int classification_label;
- EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- "1 2 3 c o n t e x t X c o n t e x t 1 2 3", {20, 21}, {20, 21}, "",
- &features, &extra_features, &selection_label_spans, &selection_label,
- &selection_codepoint_label, &classification_label));
- EXPECT_EQ(19, features.size());
-
- // Should pad the string.
- EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- "X", {0, 1}, {0, 1}, "", &features, &extra_features,
- &selection_label_spans, &selection_label, &selection_codepoint_label,
- &classification_label));
- EXPECT_EQ(19, features.size());
-}
-
TEST(FeatureProcessorTest, CenterTokenFromClick) {
int token_index;
// Exactly aligned indices.
token_index = internal::CenterTokenFromClick(
- {6, 11}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
- Token("heře!", 12, 17, false)});
+ {6, 11},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
EXPECT_EQ(token_index, 1);
// Click is contained in a token.
token_index = internal::CenterTokenFromClick(
- {13, 17}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
- Token("heře!", 12, 17, false)});
+ {13, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
EXPECT_EQ(token_index, 2);
// Click spans two tokens.
token_index = internal::CenterTokenFromClick(
- {6, 17}, {Token("Hělló", 0, 5, false), Token("world", 6, 11, false),
- Token("heře!", 12, 17, false)});
+ {6, 17},
+ {Token("Hělló", 0, 5), Token("world", 6, 11), Token("heře!", 12, 17)});
EXPECT_EQ(token_index, kInvalidIndex);
}
@@ -369,37 +296,37 @@
// Selection of length 3. Exactly aligned indices.
token_index = internal::CenterTokenFromMiddleOfSelection(
- {7, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
- Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
- Token("Token5", 28, 34, false)});
+ {7, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
EXPECT_EQ(token_index, 2);
// Selection of length 1 token. Exactly aligned indices.
token_index = internal::CenterTokenFromMiddleOfSelection(
- {21, 27}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
- Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
- Token("Token5", 28, 34, false)});
+ {21, 27},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
EXPECT_EQ(token_index, 3);
// Selection marks sub-token range, with no tokens in it.
token_index = internal::CenterTokenFromMiddleOfSelection(
- {29, 33}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
- Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
- Token("Token5", 28, 34, false)});
+ {29, 33},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
EXPECT_EQ(token_index, kInvalidIndex);
// Selection of length 2. Sub-token indices.
token_index = internal::CenterTokenFromMiddleOfSelection(
- {3, 25}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
- Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
- Token("Token5", 28, 34, false)});
+ {3, 25},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
EXPECT_EQ(token_index, 1);
// Selection of length 1. Sub-token indices.
token_index = internal::CenterTokenFromMiddleOfSelection(
- {22, 34}, {Token("Token1", 0, 6, false), Token("Token2", 7, 13, false),
- Token("Token3", 14, 20, false), Token("Token4", 21, 27, false),
- Token("Token5", 28, 34, false)});
+ {22, 34},
+ {Token("Token1", 0, 6), Token("Token2", 7, 13), Token("Token3", 14, 20),
+ Token("Token4", 21, 27), Token("Token5", 28, 34)});
EXPECT_EQ(token_index, 4);
// Some invalid ones.
@@ -407,42 +334,6 @@
EXPECT_EQ(token_index, -1);
}
-TEST(FeatureProcessorTest, GetFeaturesForSharing) {
- FeatureProcessorOptions options;
- options.set_num_buckets(10);
- options.set_context_size(9);
- options.set_max_selection_span(7);
- options.add_chargram_orders(1);
- options.set_tokenize_on_space(true);
- options.set_center_token_selection_method(
- FeatureProcessorOptions::CENTER_TOKEN_MIDDLE_OF_SELECTION);
- options.set_only_use_line_with_click(true);
- options.set_split_tokens_on_selection_boundaries(true);
- options.set_extract_selection_mask_feature(true);
- TokenizationCodepointRange* config =
- options.add_tokenization_codepoint_config();
- config->set_start(32);
- config->set_end(33);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
- config = options.add_tokenization_codepoint_config();
- config->set_start(10);
- config->set_end(11);
- config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
- FeatureProcessor feature_processor(options);
-
- std::vector<std::vector<std::pair<int, float>>> features;
- std::vector<float> extra_features;
- std::vector<CodepointSpan> selection_label_spans;
- int selection_label;
- CodepointSpan selection_codepoint_label;
- int classification_label;
- EXPECT_TRUE(feature_processor.GetFeaturesAndLabels(
- "line 1\nline2\nsome entity\n line 4", {13, 24}, {13, 24}, "", &features,
- &extra_features, &selection_label_spans, &selection_label,
- &selection_codepoint_label, &classification_label));
- EXPECT_EQ(19, features.size());
-}
-
TEST(FeatureProcessorTest, SupportedCodepointsRatio) {
FeatureProcessorOptions options;
options.set_context_size(2);
@@ -488,26 +379,144 @@
EXPECT_FALSE(feature_processor.IsCodepointSupported(10001));
EXPECT_TRUE(feature_processor.IsCodepointSupported(25000));
- std::vector<nlp_core::FeatureVector> features;
+ std::vector<Token> tokens;
+ int click_pos;
std::vector<float> extra_features;
+ std::unique_ptr<CachedFeatures> cached_features;
+
+ auto feature_fn = [](const std::vector<int>& sparse_features,
+ const std::vector<float>& dense_features,
+ float* embedding) { return true; };
options.set_min_supported_codepoint_ratio(0.0);
- feature_processor = TestingFeatureProcessor(options);
- EXPECT_TRUE(feature_processor.GetFeatures("ěěě řřř eee", {4, 7}, &features,
- &extra_features,
- /*selection_label_spans=*/nullptr));
+ TestingFeatureProcessor feature_processor2(options);
+ EXPECT_TRUE(feature_processor2.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
+ feature_fn, 2, &tokens,
+ &click_pos, &cached_features));
options.set_min_supported_codepoint_ratio(0.2);
- feature_processor = TestingFeatureProcessor(options);
- EXPECT_TRUE(feature_processor.GetFeatures("ěěě řřř eee", {4, 7}, &features,
- &extra_features,
- /*selection_label_spans=*/nullptr));
+ TestingFeatureProcessor feature_processor3(options);
+ EXPECT_TRUE(feature_processor3.ExtractFeatures("ěěě řřř eee", {4, 7}, {0, 0},
+ feature_fn, 2, &tokens,
+ &click_pos, &cached_features));
options.set_min_supported_codepoint_ratio(0.5);
- feature_processor = TestingFeatureProcessor(options);
- EXPECT_FALSE(feature_processor.GetFeatures(
- "ěěě řřř eee", {4, 7}, &features, &extra_features,
- /*selection_label_spans=*/nullptr));
+ TestingFeatureProcessor feature_processor4(options);
+ EXPECT_FALSE(feature_processor4.ExtractFeatures(
+ "ěěě řřř eee", {4, 7}, {0, 0}, feature_fn, 2, &tokens, &click_pos,
+ &cached_features));
+}
+
+TEST(FeatureProcessorTest, StripUnusedTokensWithNoRelativeClick) {  // StripOrPadTokens with a {0, 0} span: output is always 5 tokens with the click at index 2.
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+
+ std::vector<Token> tokens;
+ int click_index;
+
+ // Click the first token (index 0) and check that two padding tokens are
+ tokens = tokens_orig;  // added on the left.
+ click_index = 0;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);  // Second argument is presumably the context size; confirm against the declaration.
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // When we click the third token (index 2), nothing should get padded.
+ tokens = tokens_orig;
+ click_index = 2;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // When we click the last token (index 12), padding is added on the right.
+ tokens = tokens_orig;
+ click_index = 12;
+ internal::StripOrPadTokens({0, 0}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);  // The click index is rewritten relative to the new token vector.
+}
+
+TEST(FeatureProcessorTest, StripUnusedTokensWithRelativeClick) {  // StripOrPadTokens with a non-zero span: the kept window grows by the span's two values.
+ std::vector<Token> tokens_orig{
+ Token("0", 0, 0), Token("1", 0, 0), Token("2", 0, 0), Token("3", 0, 0),
+ Token("4", 0, 0), Token("5", 0, 0), Token("6", 0, 0), Token("7", 0, 0),
+ Token("8", 0, 0), Token("9", 0, 0), Token("10", 0, 0), Token("11", 0, 0),
+ Token("12", 0, 0)};
+
+ std::vector<Token> tokens;
+ int click_index;
+
+ // Click the first token: the left side is padded with only 2 tokens, while
+ // the right side keeps 5 tokens (indices 1..5) after the click.
+ tokens = tokens_orig;
+ click_index = 0;
+ internal::StripOrPadTokens({2, 3}, 2, &tokens, &click_index);  // Span {2, 3} presumably means 2 tokens left / 3 right of the click; confirm.
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token(),
+ Token(),
+ Token("0", 0, 0),
+ Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 2);
+
+ // Clicking in the middle with enough context should not produce any padding.
+ tokens = tokens_orig;
+ click_index = 6;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);  // Keeps tokens 1..9: 5 before the click, 3 after.
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("1", 0, 0),
+ Token("2", 0, 0),
+ Token("3", 0, 0),
+ Token("4", 0, 0),
+ Token("5", 0, 0),
+ Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0)}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);
+
+ // Clicking near the end (index 11) should pad on the right with 2 tokens.
+ tokens = tokens_orig;
+ click_index = 11;
+ internal::StripOrPadTokens({3, 1}, 2, &tokens, &click_index);
+ // clang-format off
+ EXPECT_EQ(tokens, std::vector<Token>({Token("6", 0, 0),
+ Token("7", 0, 0),
+ Token("8", 0, 0),
+ Token("9", 0, 0),
+ Token("10", 0, 0),
+ Token("11", 0, 0),
+ Token("12", 0, 0),
+ Token(),
+ Token()}));
+ // clang-format on
+ EXPECT_EQ(click_index, 5);  // The click index is rewritten relative to the new token vector.
}
} // namespace
diff --git a/tests/token-feature-extractor_test.cc b/tests/token-feature-extractor_test.cc
index 55a5228..277549e 100644
--- a/tests/token-feature-extractor_test.cc
+++ b/tests/token-feature-extractor_test.cc
@@ -40,7 +40,7 @@
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token{"Hello", 0, 5, true}, &sparse_features,
+ extractor.Extract(Token{"Hello", 0, 5}, true, &sparse_features,
&dense_features);
EXPECT_THAT(sparse_features,
@@ -68,7 +68,7 @@
sparse_features.clear();
dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
&dense_features);
EXPECT_THAT(sparse_features,
@@ -110,7 +110,7 @@
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
&dense_features);
EXPECT_THAT(sparse_features,
@@ -138,7 +138,7 @@
sparse_features.clear();
dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
&dense_features);
EXPECT_THAT(sparse_features,
@@ -179,25 +179,25 @@
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token{"Hělló", 0, 5, true}, &sparse_features,
+ extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
sparse_features.clear();
dense_features.clear();
- extractor.Extract(Token{"world!", 23, 29, false}, &sparse_features,
+ extractor.Extract(Token{"world!", 23, 29}, false, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
sparse_features.clear();
dense_features.clear();
- extractor.Extract(Token{"Ř", 23, 29, false}, &sparse_features,
+ extractor.Extract(Token{"Ř", 23, 29}, false, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0}));
sparse_features.clear();
dense_features.clear();
- extractor.Extract(Token{"ř", 23, 29, false}, &sparse_features,
+ extractor.Extract(Token{"ř", 23, 29}, false, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0}));
}
@@ -212,15 +212,15 @@
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
+ extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
&dense_features);
std::vector<int> sparse_features2;
- extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
+ extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
&dense_features);
EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
- extractor.Extract(Token{"10:32am", 0, 6, true}, &sparse_features2,
+ extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
&dense_features);
EXPECT_THAT(sparse_features,
testing::Not(testing::ElementsAreArray(sparse_features2)));
@@ -236,15 +236,15 @@
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token{"9:30am", 0, 6, true}, &sparse_features,
+ extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
&dense_features);
std::vector<int> sparse_features2;
- extractor.Extract(Token{"5:32am", 0, 6, true}, &sparse_features2,
+ extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
&dense_features);
EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
- extractor.Extract(Token{"10:32am", 0, 6, true}, &sparse_features2,
+ extractor.Extract(Token{"10:32am", 0, 6}, true, &sparse_features2,
&dense_features);
EXPECT_THAT(sparse_features,
testing::Not(testing::ElementsAreArray(sparse_features2)));
@@ -262,22 +262,22 @@
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token{"abCde", 0, 6, true}, &sparse_features,
+ extractor.Extract(Token{"abCde", 0, 6}, true, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
dense_features.clear();
- extractor.Extract(Token{"abcde", 0, 6, true}, &sparse_features,
+ extractor.Extract(Token{"abcde", 0, 6}, true, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({1.0, -1.0}));
dense_features.clear();
- extractor.Extract(Token{"12c45", 0, 6, true}, &sparse_features,
+ extractor.Extract(Token{"12c45", 0, 6}, true, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, -1.0}));
dense_features.clear();
- extractor.Extract(Token{"12345", 0, 6, true}, &sparse_features,
+ extractor.Extract(Token{"12345", 0, 6}, true, &sparse_features,
&dense_features);
EXPECT_THAT(dense_features, testing::ElementsAreArray({-1.0, 1.0}));
}
@@ -294,7 +294,7 @@
// Test that this runs. ASAN should catch problems.
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0, true},
+ extractor.Extract(Token{"abcdefghijklmnopqřstuvwxyz", 0, 0}, true,
&sparse_features, &dense_features);
EXPECT_THAT(sparse_features,
@@ -325,13 +325,13 @@
"x", "Hello", "Hey,", "Hi", ""}) {
std::vector<int> sparse_features_unicode;
std::vector<float> dense_features_unicode;
- extractor_unicode.Extract(Token{input, 0, 0, true},
+ extractor_unicode.Extract(Token{input, 0, 0}, true,
&sparse_features_unicode,
&dense_features_unicode);
std::vector<int> sparse_features_ascii;
std::vector<float> dense_features_ascii;
- extractor_ascii.Extract(Token{input, 0, 0, true}, &sparse_features_ascii,
+ extractor_ascii.Extract(Token{input, 0, 0}, true, &sparse_features_ascii,
&dense_features_ascii);
EXPECT_THAT(sparse_features_unicode, sparse_features_ascii) << input;
@@ -352,7 +352,7 @@
std::vector<int> sparse_features;
std::vector<float> dense_features;
- extractor.Extract(Token(), &sparse_features, &dense_features);
+ extractor.Extract(Token(), false, &sparse_features, &dense_features);
EXPECT_THAT(sparse_features,
testing::ElementsAreArray({extractor.HashToken("<PAD>")}));
diff --git a/tests/tokenizer_test.cc b/tests/tokenizer_test.cc
index f0034bb..cdb90a9 100644
--- a/tests/tokenizer_test.cc
+++ b/tests/tokenizer_test.cc
@@ -104,8 +104,8 @@
TestingTokenizer tokenizer(configs);
std::vector<Token> tokens = tokenizer.Tokenize("Hello world!");
- EXPECT_THAT(tokens, ElementsAreArray({Token("Hello", 0, 5, false),
- Token("world!", 6, 12, false)}));
+ EXPECT_THAT(tokens,
+ ElementsAreArray({Token("Hello", 0, 5), Token("world!", 6, 12)}));
}
TEST(TokenizerTest, TokenizeComplex) {
@@ -243,17 +243,17 @@
// clang-format off
EXPECT_THAT(
tokens,
- ElementsAreArray({Token("問", 0, 1, false),
- Token("少", 1, 2, false),
- Token("目", 2, 3, false),
- Token("hello", 4, 9, false),
- Token("木", 10, 11, false),
- Token("輸", 11, 12, false),
- Token("ย", 12, 13, false),
- Token("า", 13, 14, false),
- Token("ม", 14, 15, false),
- Token("き", 15, 16, false),
- Token("ゃ", 16, 17, false)}));
+ ElementsAreArray({Token("問", 0, 1),
+ Token("少", 1, 2),
+ Token("目", 2, 3),
+ Token("hello", 4, 9),
+ Token("木", 10, 11),
+ Token("輸", 11, 12),
+ Token("ย", 12, 13),
+ Token("า", 13, 14),
+ Token("ม", 14, 15),
+ Token("き", 15, 16),
+ Token("ゃ", 16, 17)}));
// clang-format on
}