blob: 990233cd694325763cbe876e37893369f36270ab [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
#define LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_
#include <functional>
#include <memory>
#include <vector>
#include "common/vector-span.h"
#include "smartselect/types.h"
namespace libtextclassifier {

// Holds state for extracting features across multiple calls and reusing them.
// The features of all tokens are extracted once, in the constructor, and are
// later served per click position through Get().
// Assumes that features for each Token are independent.
class CachedFeatures {
 public:
  // Callback that writes the feature vector of a single Token, given its sparse
  // and dense features, to the storage pointed to by the last argument.
  // NOTE(review): the bool result is presumably a success flag — confirm
  // against the implementation of Extract().
  using FeatureVectorFn = std::function<bool(const std::vector<int>&,
                                             const std::vector<float>&, float*)>;

  // Extracts the features for the given sequence of tokens.
  // - tokens: Tokens to extract features for. NOTE(review): stored as a
  //   VectorSpan, presumably a non-owning view, so the underlying tokens likely
  //   must outlive this object — confirm against common/vector-span.h.
  // - context_size: Specifies how many tokens to the left, and how many
  //   tokens to the right, span the context.
  // - sparse_features, dense_features: Extracted features for each token.
  // - feature_vector_fn: Writes features for a given Token to the specified
  //   storage.
  //   NOTE: The function can assume that the underlying storage is
  //   initialized to all zeros.
  // - feature_vector_size: Size of a feature vector for one Token.
  CachedFeatures(VectorSpan<Token> tokens, int context_size,
                 const std::vector<std::vector<int>>& sparse_features,
                 const std::vector<std::vector<float>>& dense_features,
                 const FeatureVectorFn& feature_vector_fn,
                 int feature_vector_size)
      : tokens_(tokens),
        context_size_(context_size),
        feature_vector_size_(feature_vector_size),
        remap_v0_feature_vector_(false),
        remap_v0_chargram_embedding_size_(-1) {
    Extract(sparse_features, dense_features, feature_vector_fn);
  }

  // Gets a VectorSpan with the features for the given click position.
  // - click_pos: Click position; presumably an index into the tokens passed
  //   to the constructor — TODO confirm against the implementation.
  // - features: Output span receiving the features for the context window.
  // - output_tokens: Output span of tokens; presumably the tokens of that
  //   context window — TODO confirm.
  // Returns a bool whose meaning is defined by the implementation (presumably
  // false on failure, e.g. an out-of-range click_pos — verify).
  // NOTE: In v0 feature mode (SetV0FeatureMode) *features is backed by storage
  // owned by this object rather than by the caller.
  bool Get(int click_pos, VectorSpan<float>* features,
           VectorSpan<Token>* output_tokens);

  // Turns on a compatibility mode, which re-maps the extracted features to the
  // v0 feature format (where the dense features were at the end).
  // - chargram_embedding_size: Presumably the size of the chargram-embedding
  //   part of each token's feature vector — confirm in RemapV0FeatureVector().
  // WARNING: Internally v0_feature_storage_ is used as the backing buffer for
  // the VectorSpan<float> features produced by Get(), so that output is valid
  // only until the next call or the destruction of the current CachedFeatures
  // object.
  // TODO(zilka): Remove when we'll have retrained models.
  void SetV0FeatureMode(int chargram_embedding_size) {
    remap_v0_feature_vector_ = true;
    remap_v0_chargram_embedding_size_ = chargram_embedding_size;
    // One feature vector per token of the context window: context_size_
    // tokens on each side plus the clicked token itself.
    v0_feature_storage_.resize(feature_vector_size_ * (context_size_ * 2 + 1));
  }

 protected:
  // Extracts features for all tokens and stores them for later retrieval.
  void Extract(const std::vector<std::vector<int>>& sparse_features,
               const std::vector<std::vector<float>>& dense_features,
               const FeatureVectorFn& feature_vector_fn);

  // Remaps extracted features to V0 feature format. The mapping is using
  // the v0_feature_storage_ as the backing storage for the mapped features.
  // For each token the features consist of:
  // - chargram embeddings
  // - dense features
  // They are concatenated together as [chargram embeddings; dense features]
  // for each token independently.
  // The V0 features require that the chargram embeddings for tokens are
  // concatenated first together, and at the end, the dense features for the
  // tokens are concatenated to it.
  void RemapV0FeatureVector(VectorSpan<float>* features);

 private:
  // Tokens given to the constructor.
  const VectorSpan<Token> tokens_;
  // Number of context tokens on each side of the clicked token.
  const int context_size_;
  // Size of the feature vector of a single token.
  const int feature_vector_size_;
  // Whether the v0 compatibility mode is on (set by SetV0FeatureMode).
  bool remap_v0_feature_vector_;
  // Chargram embedding size for v0 remapping; -1 until SetV0FeatureMode.
  int remap_v0_chargram_embedding_size_;
  // Presumably the per-token feature storage filled by Extract() — confirm.
  std::vector<float> features_;
  // Backing buffer for the remapped v0 features of one context window.
  std::vector<float> v0_feature_storage_;
};

}  // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_CACHED_FEATURES_H_