| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "lang_id/lang-id.h" |
| |
| #include <stdio.h> |
| |
| #include <algorithm> |
| #include <limits> |
| #include <memory> |
| #include <string> |
| #include <vector> |
| |
| #include "common/algorithm.h" |
| #include "common/embedding-network-params-from-proto.h" |
| #include "common/embedding-network.pb.h" |
| #include "common/embedding-network.h" |
| #include "common/feature-extractor.h" |
| #include "common/file-utils.h" |
| #include "common/list-of-strings.pb.h" |
| #include "common/memory_image/in-memory-model-data.h" |
| #include "common/mmap.h" |
| #include "common/softmax.h" |
| #include "common/task-context.h" |
| #include "lang_id/custom-tokenizer.h" |
| #include "lang_id/lang-id-brain-interface.h" |
| #include "lang_id/language-identifier-features.h" |
| #include "lang_id/light-sentence-features.h" |
| #include "lang_id/light-sentence.h" |
| #include "lang_id/relevant-script-feature.h" |
| #include "util/base/logging.h" |
| #include "util/base/macros.h" |
| |
| using ::libtextclassifier::nlp_core::file_utils::ParseProtoFromMemory; |
| |
| namespace libtextclassifier { |
| namespace nlp_core { |
| namespace lang_id { |
| |
| namespace { |
// Probability threshold used until a model or a caller overrides it; see
// comments for LangId::SetProbabilityThreshold().
constexpr float kDefaultProbabilityThreshold = 0.50f;

// Minimum amount of "real" text (in bytes) the model needs in order to
// provide a meaningful prediction; used unless the model overrides it.
constexpr int kDefaultMinTextSizeInBytes = 20;

// Language reported by LangId::FindLanguage() when no prediction can be made,
// until LangId::SetDefaultLanguage() changes it for an individual LangId
// object.
constexpr char kInitialDefaultLanguage[] = "";
| |
| // Returns total number of bytes of the words from sentence, without the ^ |
| // (start-of-word) and $ (end-of-word) markers. Note: "real text" means that |
| // this ignores whitespace and punctuation characters from the original text. |
| int GetRealTextSize(const LightSentence &sentence) { |
| int total = 0; |
| for (int i = 0; i < sentence.num_words(); ++i) { |
| TC_DCHECK(!sentence.word(i).empty()); |
| TC_DCHECK_EQ('^', sentence.word(i).front()); |
| TC_DCHECK_EQ('$', sentence.word(i).back()); |
| total += sentence.word(i).size() - 2; |
| } |
| return total; |
| } |
| |
| } // namespace |
| |
| // Class that performs all work behind LangId. |
| class LangIdImpl { |
| public: |
| explicit LangIdImpl(const std::string &filename) { |
| // Using mmap as a fast way to read the model bytes. |
| ScopedMmap scoped_mmap(filename); |
| MmapHandle mmap_handle = scoped_mmap.handle(); |
| if (!mmap_handle.ok()) { |
| TC_LOG(ERROR) << "Unable to read model bytes."; |
| return; |
| } |
| |
| Initialize(mmap_handle.to_stringpiece()); |
| } |
| |
| explicit LangIdImpl(int fd) { |
| // Using mmap as a fast way to read the model bytes. |
| ScopedMmap scoped_mmap(fd); |
| MmapHandle mmap_handle = scoped_mmap.handle(); |
| if (!mmap_handle.ok()) { |
| TC_LOG(ERROR) << "Unable to read model bytes."; |
| return; |
| } |
| |
| Initialize(mmap_handle.to_stringpiece()); |
| } |
| |
| LangIdImpl(const char *ptr, size_t length) { |
| Initialize(StringPiece(ptr, length)); |
| } |
| |
| void Initialize(StringPiece model_bytes) { |
| // Will set valid_ to true only on successful initialization. |
| valid_ = false; |
| |
| // Make sure all relevant features are registered: |
| ContinuousBagOfNgramsFunction::RegisterClass(); |
| RelevantScriptFeature::RegisterClass(); |
| |
| // NOTE(salcianu): code below relies on the fact that the current features |
| // do not rely on data from a TaskInput. Otherwise, one would have to use |
| // the more complex model registration mechanism, which requires more code. |
| InMemoryModelData model_data(model_bytes); |
| TaskContext context; |
| if (!model_data.GetTaskSpec(context.mutable_spec())) { |
| TC_LOG(ERROR) << "Unable to get model TaskSpec"; |
| return; |
| } |
| |
| if (!ParseNetworkParams(model_data, &context)) { |
| return; |
| } |
| if (!ParseListOfKnownLanguages(model_data, &context)) { |
| return; |
| } |
| |
| network_.reset(new EmbeddingNetwork(network_params_.get())); |
| if (!network_->is_valid()) { |
| return; |
| } |
| |
| probability_threshold_ = |
| context.Get("reliability_thresh", kDefaultProbabilityThreshold); |
| min_text_size_in_bytes_ = |
| context.Get("min_text_size_in_bytes", kDefaultMinTextSizeInBytes); |
| version_ = context.Get("version", 0); |
| |
| if (!lang_id_brain_interface_.Init(&context)) { |
| return; |
| } |
| valid_ = true; |
| } |
| |
| void SetProbabilityThreshold(float threshold) { |
| probability_threshold_ = threshold; |
| } |
| |
| void SetDefaultLanguage(const std::string &lang) { default_language_ = lang; } |
| |
| std::string FindLanguage(const std::string &text) const { |
| std::vector<float> scores = ScoreLanguages(text); |
| if (scores.empty()) { |
| return default_language_; |
| } |
| |
| // Softmax label with max score. |
| int label = GetArgMax(scores); |
| float probability = scores[label]; |
| if (probability < probability_threshold_) { |
| return default_language_; |
| } |
| return GetLanguageForSoftmaxLabel(label); |
| } |
| |
| std::vector<std::pair<std::string, float>> FindLanguages( |
| const std::string &text) const { |
| std::vector<float> scores = ScoreLanguages(text); |
| |
| std::vector<std::pair<std::string, float>> result; |
| for (int i = 0; i < scores.size(); i++) { |
| result.push_back({GetLanguageForSoftmaxLabel(i), scores[i]}); |
| } |
| |
| // To avoid crashing clients that always expect at least one predicted |
| // language, we promised (see doc for this method) that the result always |
| // contains at least one element. |
| if (result.empty()) { |
| // We use a tiny probability, such that any client that uses a meaningful |
| // probability threshold ignores this prediction. We don't use 0.0f, to |
| // avoid crashing clients that normalize the probabilities we return here. |
| result.push_back({default_language_, 0.001f}); |
| } |
| return result; |
| } |
| |
| std::vector<float> ScoreLanguages(const std::string &text) const { |
| if (!is_valid()) { |
| return {}; |
| } |
| |
| // Create a Sentence storing the input text. |
| LightSentence sentence; |
| TokenizeTextForLangId(text, &sentence); |
| |
| if (GetRealTextSize(sentence) < min_text_size_in_bytes_) { |
| return {}; |
| } |
| |
| // TODO(salcianu): reuse vector<FeatureVector>. |
| std::vector<FeatureVector> features( |
| lang_id_brain_interface_.NumEmbeddings()); |
| lang_id_brain_interface_.GetFeatures(&sentence, &features); |
| |
| // Predict language. |
| EmbeddingNetwork::Vector scores; |
| network_->ComputeFinalScores(features, &scores); |
| |
| return ComputeSoftmax(scores); |
| } |
| |
| bool is_valid() const { return valid_; } |
| |
| int version() const { return version_; } |
| |
| private: |
| // Returns name of the (in-memory) file for the indicated TaskInput from |
| // context. |
| static std::string GetInMemoryFileNameForTaskInput( |
| const std::string &input_name, TaskContext *context) { |
| TaskInput *task_input = context->GetInput(input_name); |
| if (task_input->part_size() != 1) { |
| TC_LOG(ERROR) << "TaskInput " << input_name << " has " |
| << task_input->part_size() << " parts"; |
| return ""; |
| } |
| return task_input->part(0).file_pattern(); |
| } |
| |
| bool ParseNetworkParams(const InMemoryModelData &model_data, |
| TaskContext *context) { |
| const std::string input_name = "language-identifier-network"; |
| const std::string input_file_name = |
| GetInMemoryFileNameForTaskInput(input_name, context); |
| if (input_file_name.empty()) { |
| TC_LOG(ERROR) << "No input file name for TaskInput " << input_name; |
| return false; |
| } |
| StringPiece bytes = model_data.GetBytesForInputFile(input_file_name); |
| if (bytes.data() == nullptr) { |
| TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name; |
| return false; |
| } |
| std::unique_ptr<EmbeddingNetworkProto> proto(new EmbeddingNetworkProto()); |
| if (!ParseProtoFromMemory(bytes, proto.get())) { |
| TC_LOG(ERROR) << "Unable to parse EmbeddingNetworkProto"; |
| return false; |
| } |
| network_params_.reset( |
| new EmbeddingNetworkParamsFromProto(std::move(proto))); |
| if (!network_params_->is_valid()) { |
| TC_LOG(ERROR) << "EmbeddingNetworkParamsFromProto not valid"; |
| return false; |
| } |
| return true; |
| } |
| |
| // Parses dictionary with known languages (i.e., field languages_) from a |
| // TaskInput of context. Note: that TaskInput should be a ListOfStrings proto |
| // with a single element, the serialized form of a ListOfStrings. |
| // |
| bool ParseListOfKnownLanguages(const InMemoryModelData &model_data, |
| TaskContext *context) { |
| const std::string input_name = "language-name-id-map"; |
| const std::string input_file_name = |
| GetInMemoryFileNameForTaskInput(input_name, context); |
| if (input_file_name.empty()) { |
| TC_LOG(ERROR) << "No input file name for TaskInput " << input_name; |
| return false; |
| } |
| StringPiece bytes = model_data.GetBytesForInputFile(input_file_name); |
| if (bytes.data() == nullptr) { |
| TC_LOG(ERROR) << "Unable to get bytes for TaskInput " << input_name; |
| return false; |
| } |
| ListOfStrings records; |
| if (!ParseProtoFromMemory(bytes, &records)) { |
| TC_LOG(ERROR) << "Unable to parse ListOfStrings from TaskInput " |
| << input_name; |
| return false; |
| } |
| if (records.element_size() != 1) { |
| TC_LOG(ERROR) << "Wrong number of records in TaskInput " << input_name |
| << " : " << records.element_size(); |
| return false; |
| } |
| if (!ParseProtoFromMemory(std::string(records.element(0)), &languages_)) { |
| TC_LOG(ERROR) << "Unable to parse dictionary with known languages"; |
| return false; |
| } |
| return true; |
| } |
| |
| // Returns language code for a softmax label. See comments for languages_ |
| // field. If label is out of range, returns default_language_. |
| std::string GetLanguageForSoftmaxLabel(int label) const { |
| if ((label >= 0) && (label < languages_.element_size())) { |
| return languages_.element(label); |
| } else { |
| TC_LOG(ERROR) << "Softmax label " << label << " outside range [0, " |
| << languages_.element_size() << ")"; |
| return default_language_; |
| } |
| } |
| |
| LangIdBrainInterface lang_id_brain_interface_; |
| |
| // Parameters for the neural network network_ (see below). |
| std::unique_ptr<EmbeddingNetworkParamsFromProto> network_params_; |
| |
| // Neural network to use for scoring. |
| std::unique_ptr<EmbeddingNetwork> network_; |
| |
| // True if this object is ready to perform language predictions. |
| bool valid_; |
| |
| // Only predictions with a probability (confidence) above this threshold are |
| // reported. Otherwise, we report default_language_. |
| float probability_threshold_ = kDefaultProbabilityThreshold; |
| |
| // Min size of the input text for our predictions to be meaningful. Below |
| // this threshold, the underlying model may report a wrong language and a high |
| // confidence score. |
| int min_text_size_in_bytes_ = kDefaultMinTextSizeInBytes; |
| |
| // Version of the model. |
| int version_ = -1; |
| |
| // Known languages: softmax label i (an integer) means languages_.element(i) |
| // (something like "en", "fr", "ru", etc). |
| ListOfStrings languages_; |
| |
| // Language code to return in case of errors. |
| std::string default_language_ = kInitialDefaultLanguage; |
| |
| TC_DISALLOW_COPY_AND_ASSIGN(LangIdImpl); |
| }; |
| |
| LangId::LangId(const std::string &filename) : pimpl_(new LangIdImpl(filename)) { |
| if (!pimpl_->is_valid()) { |
| TC_LOG(ERROR) << "Unable to construct a valid LangId based " |
| << "on the data from " << filename |
| << "; nothing should crash, but " |
| << "accuracy will be bad."; |
| } |
| } |
| |
| LangId::LangId(int fd) : pimpl_(new LangIdImpl(fd)) { |
| if (!pimpl_->is_valid()) { |
| TC_LOG(ERROR) << "Unable to construct a valid LangId based " |
| << "on the data from descriptor " << fd |
| << "; nothing should crash, " |
| << "but accuracy will be bad."; |
| } |
| } |
| |
| LangId::LangId(const char *ptr, size_t length) |
| : pimpl_(new LangIdImpl(ptr, length)) { |
| if (!pimpl_->is_valid()) { |
| TC_LOG(ERROR) << "Unable to construct a valid LangId based " |
| << "on the memory region; nothing should crash, " |
| << "but accuracy will be bad."; |
| } |
| } |
| |
// Defaulted here, in the .cc file, after LangIdImpl has been fully defined.
// NOTE(review): presumably required because pimpl_ is a smart pointer to a
// type that is only forward-declared in lang-id.h -- confirm against header.
LangId::~LangId() = default;
| |
| void LangId::SetProbabilityThreshold(float threshold) { |
| pimpl_->SetProbabilityThreshold(threshold); |
| } |
| |
| void LangId::SetDefaultLanguage(const std::string &lang) { |
| pimpl_->SetDefaultLanguage(lang); |
| } |
| |
| std::string LangId::FindLanguage(const std::string &text) const { |
| return pimpl_->FindLanguage(text); |
| } |
| |
| std::vector<std::pair<std::string, float>> LangId::FindLanguages( |
| const std::string &text) const { |
| return pimpl_->FindLanguages(text); |
| } |
| |
| bool LangId::is_valid() const { return pimpl_->is_valid(); } |
| |
| int LangId::version() const { return pimpl_->version(); } |
| |
| } // namespace lang_id |
| } // namespace nlp_core |
| } // namespace libtextclassifier |