// blob: ce79497a05d59f41148f1bea4ae0c9efce6a67ea [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
#define LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_
#include <string>
#include <vector>
#include "common/embedding-feature-extractor.h"
#include "common/feature-extractor.h"
#include "common/task-context.h"
#include "common/workspace.h"
#include "lang_id/light-sentence-features.h"
#include "lang_id/light-sentence.h"
#include "util/base/macros.h"
namespace libtextclassifier {
namespace nlp_core {
namespace lang_id {
// Embedding feature extractor used for language identification: it binds the
// generic EmbeddingFeatureExtractor template to LightSentenceExtractor (the
// extractor type) and LightSentence (the object features are extracted from).
class LangIdEmbeddingFeatureExtractor
    : public EmbeddingFeatureExtractor<LightSentenceExtractor, LightSentence> {
 public:
  LangIdEmbeddingFeatureExtractor() {}

  // Overrides the base-class hook and returns the fixed argument prefix
  // "language_identifier".
  // NOTE(review): presumably the base class uses this prefix to look up this
  // extractor's parameters -- confirm against EmbeddingFeatureExtractor.
  const std::string ArgPrefix() const override {
    return std::string("language_identifier");
  }

  TC_DISALLOW_COPY_AND_ASSIGN(LangIdEmbeddingFeatureExtractor);
};
// Handles sentence -> numeric_features and numeric_prediction -> language
// conversions.
class LangIdBrainInterface {
public:
LangIdBrainInterface() {}
// Initializes resources and parameters.
bool Init(TaskContext *context) {
if (!feature_extractor_.Init(context)) {
return false;
}
feature_extractor_.RequestWorkspaces(&workspace_registry_);
return true;
}
// Extract features from sentence. On return, FeatureVector features[i]
// contains the features for the embedding space #i.
void GetFeatures(LightSentence *sentence,
std::vector<FeatureVector> *features) const {
WorkspaceSet workspace;
workspace.Reset(workspace_registry_);
feature_extractor_.Preprocess(&workspace, sentence);
return feature_extractor_.ExtractFeatures(workspace, *sentence, features);
}
int NumEmbeddings() const {
return feature_extractor_.NumEmbeddings();
}
private:
// Typed feature extractor for embeddings.
LangIdEmbeddingFeatureExtractor feature_extractor_;
// The registry of shared workspaces in the feature extractor.
WorkspaceRegistry workspace_registry_;
TC_DISALLOW_COPY_AND_ASSIGN(LangIdBrainInterface);
};
} // namespace lang_id
} // namespace nlp_core
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_LANG_ID_LANG_ID_BRAIN_INTERFACE_H_