blob: f51b6e5838732212b4010e7844457649322e4f92 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
#define NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_
#include <memory>
#include <string>
#include <vector>
#include "lang_id/common/fel/feature-extractor.h"
#include "lang_id/common/fel/task-context.h"
#include "lang_id/common/fel/workspace.h"
#include "lang_id/common/lite_base/attributes.h"
namespace libtextclassifier3 {
namespace mobile {
// An EmbeddingFeatureExtractor manages the extraction of features for
// embedding-based models. It wraps a sequence of underlying classes of feature
// extractors, along with associated predicate maps. Each class of feature
// extractors is associated with a name, e.g., "words", "labels", "tags".
//
// The class is split between a generic abstract version,
// GenericEmbeddingFeatureExtractor (that can be initialized without knowing the
// signature of the ExtractFeatures method) and a typed version.
//
// The predicate maps must be initialized before use: they can be loaded using
// Read() or updated via UpdateMapsForExample.
class GenericEmbeddingFeatureExtractor {
public:
// Constructs this GenericEmbeddingFeatureExtractor.
//
// |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
// avoid name clashes. See GetParamName().
explicit GenericEmbeddingFeatureExtractor(const string &arg_prefix)
: arg_prefix_(arg_prefix) {}
virtual ~GenericEmbeddingFeatureExtractor() {}
// Sets/inits up predicate maps and embedding space names that are common for
// all embedding based feature extractors.
//
// Returns true on success, false otherwise.
SAFTM_MUST_USE_RESULT virtual bool Setup(TaskContext *context);
SAFTM_MUST_USE_RESULT virtual bool Init(TaskContext *context);
// Requests workspace for the underlying feature extractors. This is
// implemented in the typed class.
virtual void RequestWorkspaces(WorkspaceRegistry *registry) = 0;
// Returns number of embedding spaces.
int NumEmbeddings() const { return embedding_dims_.size(); }
const std::vector<string> &embedding_fml() const { return embedding_fml_; }
// Get parameter name by concatenating the prefix and the original name.
string GetParamName(const string &param_name) const {
string full_name = arg_prefix_;
full_name.push_back('_');
full_name.append(param_name);
return full_name;
}
private:
// Prefix for TaskContext parameters.
const string arg_prefix_;
// Embedding space names for parameter sharing.
std::vector<string> embedding_names_;
// FML strings for each feature extractor.
std::vector<string> embedding_fml_;
// Size of each of the embedding spaces (maximum predicate id).
std::vector<int> embedding_sizes_;
// Embedding dimensions of the embedding spaces (i.e. 32, 64 etc.)
std::vector<int> embedding_dims_;
};
// Templated, object-specific implementation of the
// EmbeddingFeatureExtractor. EXTRACTOR should be a FeatureExtractor<OBJ,
// ARGS...> class that has the appropriate FeatureTraits() to ensure that
// locator type features work.
//
// Note: for backwards compatibility purposes, this always reads the FML spec
// from "<prefix>_features".
template <class EXTRACTOR, class OBJ, class... ARGS>
class EmbeddingFeatureExtractor : public GenericEmbeddingFeatureExtractor {
public:
// Constructs this EmbeddingFeatureExtractor.
//
// |arg_prefix| is a string prefix for the relevant TaskContext parameters, to
// avoid name clashes. See GetParamName().
explicit EmbeddingFeatureExtractor(const string &arg_prefix)
: GenericEmbeddingFeatureExtractor(arg_prefix) {}
// Sets up all predicate maps, feature extractors, and flags.
SAFTM_MUST_USE_RESULT bool Setup(TaskContext *context) override {
if (!GenericEmbeddingFeatureExtractor::Setup(context)) {
return false;
}
feature_extractors_.resize(embedding_fml().size());
for (int i = 0; i < embedding_fml().size(); ++i) {
feature_extractors_[i].reset(new EXTRACTOR());
if (!feature_extractors_[i]->Parse(embedding_fml()[i])) return false;
if (!feature_extractors_[i]->Setup(context)) return false;
}
return true;
}
// Initializes resources needed by the feature extractors.
SAFTM_MUST_USE_RESULT bool Init(TaskContext *context) override {
if (!GenericEmbeddingFeatureExtractor::Init(context)) return false;
for (auto &feature_extractor : feature_extractors_) {
if (!feature_extractor->Init(context)) return false;
}
return true;
}
// Requests workspaces from the registry. Must be called after Init(), and
// before Preprocess().
void RequestWorkspaces(WorkspaceRegistry *registry) override {
for (auto &feature_extractor : feature_extractors_) {
feature_extractor->RequestWorkspaces(registry);
}
}
// Must be called on the object one state for each sentence, before any
// feature extraction (e.g., UpdateMapsForExample, ExtractFeatures).
void Preprocess(WorkspaceSet *workspaces, OBJ *obj) const {
for (auto &feature_extractor : feature_extractors_) {
feature_extractor->Preprocess(workspaces, obj);
}
}
// Extracts features using the extractors. Note that features must already
// be initialized to the correct number of feature extractors. No predicate
// mapping is applied.
void ExtractFeatures(const WorkspaceSet &workspaces, const OBJ &obj,
ARGS... args,
std::vector<FeatureVector> *features) const {
// DCHECK(features != nullptr);
// DCHECK_EQ(features->size(), feature_extractors_.size());
for (int i = 0; i < feature_extractors_.size(); ++i) {
(*features)[i].clear();
feature_extractors_[i]->ExtractFeatures(workspaces, obj, args...,
&(*features)[i]);
}
}
private:
// Templated feature extractor class.
std::vector<std::unique_ptr<EXTRACTOR>> feature_extractors_;
};
} // namespace mobile
} // namespace nlp_saft
#endif // NLP_SAFT_COMPONENTS_COMMON_MOBILE_EMBEDDING_FEATURE_EXTRACTOR_H_