/*
 * Copyright (C) 2017 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| #ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ |
| #define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ |
| |
#include <functional>
#include <memory>
#include <vector>

#include "base.h"
#include "common/vector-span.h"
#include "smartselect/types.h"
| |
| namespace libtextclassifier { |
| |
| // Holds state for extracting features across multiple calls and reusing them. |
| // Assumes that features for each Token are independent. |
| class CachedFeatures { |
| public: |
| // Extracts the features for the given sequence of tokens. |
| // - context_size: Specifies how many tokens to the left, and how many |
| // tokens to the right spans the context. |
| // - sparse_features, dense_features: Extracted features for each token. |
| // - feature_vector_fn: Writes features for given Token to the specified |
| // storage. |
| // NOTE: The function can assume that the underlying |
| // storage is initialized to all zeros. |
| // - feature_vector_size: Size of a feature vector for one Token. |
| CachedFeatures(VectorSpan<Token> tokens, int context_size, |
| const std::vector<std::vector<int>>& sparse_features, |
| const std::vector<std::vector<float>>& dense_features, |
| const std::function<bool(const std::vector<int>&, |
| const std::vector<float>&, float*)>& |
| feature_vector_fn, |
| int feature_vector_size) |
| : tokens_(tokens), |
| context_size_(context_size), |
| feature_vector_size_(feature_vector_size), |
| remap_v0_feature_vector_(false), |
| remap_v0_chargram_embedding_size_(-1) { |
| Extract(sparse_features, dense_features, feature_vector_fn); |
| } |
| |
| // Gets a VectorSpan with the features for given click position. |
| bool Get(int click_pos, VectorSpan<float>* features, |
| VectorSpan<Token>* output_tokens); |
| |
| // Turns on a compatibility mode, which re-maps the extracted features to the |
| // v0 feature format (where the dense features were at the end). |
| // WARNING: Internally v0_feature_storage_ is used as a backing buffer for |
| // VectorSpan<float>, so the output of Extract is valid only until the next |
| // call or destruction of the current CachedFeatures object. |
| // TODO(zilka): Remove when we'll have retrained models. |
| void SetV0FeatureMode(int chargram_embedding_size) { |
| remap_v0_feature_vector_ = true; |
| remap_v0_chargram_embedding_size_ = chargram_embedding_size; |
| v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1)); |
| } |
| |
| protected: |
| // Extracts features for all tokens and stores them for later retrieval. |
| void Extract(const std::vector<std::vector<int>>& sparse_features, |
| const std::vector<std::vector<float>>& dense_features, |
| const std::function<bool(const std::vector<int>&, |
| const std::vector<float>&, float*)>& |
| feature_vector_fn); |
| |
| // Remaps extracted features to V0 feature format. The mapping is using |
| // the v0_feature_storage_ as the backing storage for the mapped features. |
| // For each token the features consist of: |
| // - chargram embeddings |
| // - dense features |
| // They are concatenated together as [chargram embeddings; dense features] |
| // for each token independently. |
| // The V0 features require that the chargram embeddings for tokens are |
| // concatenated first together, and at the end, the dense features for the |
| // tokens are concatenated to it. |
| void RemapV0FeatureVector(VectorSpan<float>* features); |
| |
| private: |
| const VectorSpan<Token> tokens_; |
| const int context_size_; |
| const int feature_vector_size_; |
| bool remap_v0_feature_vector_; |
| int remap_v0_chargram_embedding_size_; |
| |
| std::vector<float> features_; |
| std::vector<float> v0_feature_storage_; |
| }; |
| |
| } // namespace libtextclassifier |
| |
| #endif // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_ |