| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "lang_id/lang-id.h" |
| |
| #include <stdio.h> |
| |
| #include <memory> |
| #include <string> |
| #include <unordered_map> |
| #include <utility> |
| #include <vector> |
| |
| #include "lang_id/common/embedding-feature-interface.h" |
| #include "lang_id/common/embedding-network-params.h" |
| #include "lang_id/common/embedding-network.h" |
| #include "lang_id/common/fel/feature-extractor.h" |
| #include "lang_id/common/lite_base/logging.h" |
| #include "lang_id/common/lite_strings/numbers.h" |
| #include "lang_id/common/lite_strings/str-split.h" |
| #include "lang_id/common/lite_strings/stringpiece.h" |
| #include "lang_id/common/math/algorithm.h" |
| #include "lang_id/common/math/softmax.h" |
| #include "lang_id/custom-tokenizer.h" |
| #include "lang_id/features/light-sentence-features.h" |
| // The two features/ headers below are needed only for RegisterClass(). |
| #include "lang_id/features/char-ngram-feature.h" |
| #include "lang_id/features/relevant-script-feature.h" |
| #include "lang_id/light-sentence.h" |
| // The two script/ headers below are needed only for RegisterClass(). |
| #include "lang_id/script/approx-script.h" |
| #include "lang_id/script/tiny-script-detector.h" |
| |
| namespace libtextclassifier3 { |
| namespace mobile { |
| namespace lang_id { |
| |
| namespace { |
// Default value for the confidence threshold. If the confidence of the top
// prediction is below this threshold, then FindLanguage() returns
// LangId::kUnknownLanguageCode. Note: this is just a default value; if the
// TaskSpec from the model specifies a "reliability_thresh" parameter, we use
// that value instead. Note: for legacy reasons, our code and comments use
// the terms "confidence", "probability" and "reliability" interchangeably.
constexpr float kDefaultConfidenceThreshold = 0.50f;
| } // namespace |
| |
| // Class that performs all work behind LangId. |
| class LangIdImpl { |
| public: |
| explicit LangIdImpl(std::unique_ptr<ModelProvider> model_provider) |
| : model_provider_(std::move(model_provider)), |
| lang_id_brain_interface_("language_identifier") { |
    // Note: in the code below, we set valid_ to true only if all
    // initialization steps complete successfully. Otherwise, we return early,
    // leaving valid_ at its default value of false.
| if (!model_provider_ || !model_provider_->is_valid()) { |
| SAFTM_LOG(ERROR) << "Invalid model provider"; |
| return; |
| } |
| |
| auto *nn_params = model_provider_->GetNnParams(); |
| if (!nn_params) { |
| SAFTM_LOG(ERROR) << "No NN params"; |
| return; |
| } |
| network_.reset(new EmbeddingNetwork(nn_params)); |
| |
| languages_ = model_provider_->GetLanguages(); |
| if (languages_.empty()) { |
| SAFTM_LOG(ERROR) << "No known languages"; |
| return; |
| } |
| |
| TaskContext context = *model_provider_->GetTaskContext(); |
| if (!Setup(&context)) { |
| SAFTM_LOG(ERROR) << "Unable to Setup() LangId"; |
| return; |
| } |
| if (!Init(&context)) { |
| SAFTM_LOG(ERROR) << "Unable to Init() LangId"; |
| return; |
| } |
| valid_ = true; |
| } |
| |
| std::string FindLanguage(StringPiece text) const { |
| LangIdResult lang_id_result; |
| FindLanguages(text, &lang_id_result, /* max_results = */ 1); |
| if (lang_id_result.predictions.empty()) { |
| return LangId::kUnknownLanguageCode; |
| } |
| |
| const std::string &language = lang_id_result.predictions[0].first; |
| const float probability = lang_id_result.predictions[0].second; |
| SAFTM_DLOG(INFO) << "Predicted " << language |
| << " with prob: " << probability << " for \"" << text |
| << "\""; |
| |
| // Find confidence threshold for language. |
| float threshold = default_threshold_; |
| auto it = per_lang_thresholds_.find(language); |
| if (it != per_lang_thresholds_.end()) { |
| threshold = it->second; |
| } |
| if (probability < threshold) { |
| SAFTM_DLOG(INFO) << " below threshold => " |
| << LangId::kUnknownLanguageCode; |
| return LangId::kUnknownLanguageCode; |
| } |
| return language; |
| } |
| |
| void FindLanguages(StringPiece text, LangIdResult *result, |
| int max_results) const { |
| if (result == nullptr) return; |
| |
| if (max_results <= 0) { |
| max_results = languages_.size(); |
| } |
| result->predictions.clear(); |
| if (!is_valid() || (max_results == 0)) { |
| result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1); |
| return; |
| } |
| |
    // Tokenize the input text (this also performs some pre-processing, such
    // as removing ASCII digits and punctuation).
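    // For example (illustrative; the exact behavior is defined by
    // TokenizerForLangId), "Hello, world 42" may be tokenized into something
    // like "^hello$" and "^world$".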
| LightSentence sentence; |
| tokenizer_.Tokenize(text, &sentence); |
| |
    // Check the input size here, after pre-processing has removed irrelevant
    // chars.
| if (IsTooShort(sentence)) { |
| result->predictions.emplace_back(LangId::kUnknownLanguageCode, 1); |
| return; |
| } |
| |
| // Extract features from the tokenized text. |
| std::vector<FeatureVector> features = |
| lang_id_brain_interface_.GetFeaturesNoCaching(&sentence); |
| |
| // Run feed-forward neural network to compute scores (softmax logits). |
| std::vector<float> scores; |
| network_->ComputeFinalScores(features, &scores); |
| |
| if (max_results == 1) { |
| // Optimization for the case when the user wants only the top result. |
| // Computing argmax is faster than the general top-k code. |
| int prediction_id = GetArgMax(scores); |
| const std::string language = GetLanguageForSoftmaxLabel(prediction_id); |
| float probability = ComputeSoftmaxProbability(scores, prediction_id); |
| result->predictions.emplace_back(language, probability); |
| } else { |
      // Compute the softmax, sort the results in descending order by
      // probability, and convert label IDs to language code strings. When
      // probabilities are equal, we sort by language code string in
      // ascending order.
| const std::vector<float> softmax = ComputeSoftmax(scores); |
| const std::vector<int> indices = GetTopKIndices(max_results, softmax); |
| for (const int index : indices) { |
| result->predictions.emplace_back(GetLanguageForSoftmaxLabel(index), |
| softmax[index]); |
| } |
| } |
| } |
| |
| bool is_valid() const { return valid_; } |
| |
| int GetModelVersion() const { return model_version_; } |
| |
| // Returns a property stored in the model file. |
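  // For example, GetProperty<float, float>("reliability_thresh", 0.5f) would
  // read the "reliability_thresh" parameter that Setup() also consumes below.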
| template <typename T, typename R> |
| R GetProperty(const std::string &property, T default_value) const { |
| return model_provider_->GetTaskContext()->Get(property, default_value); |
| } |
| |
| // Perform any necessary static initialization. |
| // This function is thread-safe. |
| // It's also safe to call this function multiple times. |
| // |
| // We explicitly call RegisterClass() rather than relying on alwayslink=1 in |
| // the BUILD file, because the build process for some users of this code |
| // doesn't support any equivalent to alwayslink=1 (in particular the |
| // Firebase C++ SDK build uses a Kokoro-based CMake build). While it might |
| // be possible to add such support, avoiding the need for an equivalent to |
| // alwayslink=1 is preferable because it avoids unnecessarily bloating code |
| // size in apps that link against this code but don't use it. |
| static void RegisterClasses() { |
| static bool initialized = []() -> bool { |
| libtextclassifier3::mobile::ApproxScriptDetector::RegisterClass(); |
      libtextclassifier3::mobile::lang_id::ContinuousBagOfNgramsFunction::
          RegisterClass();
      libtextclassifier3::mobile::lang_id::TinyScriptDetector::RegisterClass();
      libtextclassifier3::mobile::lang_id::RelevantScriptFeature::
          RegisterClass();
| return true; |
| }(); |
| (void)initialized; // Variable used only for initializer's side effects. |
| } |
| |
| private: |
| bool Setup(TaskContext *context) { |
| tokenizer_.Setup(context); |
| if (!lang_id_brain_interface_.SetupForProcessing(context)) return false; |
| |
| min_text_size_in_bytes_ = context->Get("min_text_size_in_bytes", 0); |
| default_threshold_ = |
| context->Get("reliability_thresh", kDefaultConfidenceThreshold); |
| |
    // Parse the task parameter "per_lang_reliability_thresholds" and fill
    // per_lang_thresholds_.
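    // The expected format (as implied by the parsing below) is a
    // comma-separated list of <language>=<threshold> pairs, e.g.,
    // "en=0.9,fr=0.85" (illustrative values).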
| const std::string thresholds_str = |
| context->Get("per_lang_reliability_thresholds", ""); |
| std::vector<StringPiece> tokens = LiteStrSplit(thresholds_str, ','); |
| for (const auto &token : tokens) { |
| if (token.empty()) continue; |
| std::vector<StringPiece> parts = LiteStrSplit(token, '='); |
| float threshold = 0.0f; |
| if ((parts.size() == 2) && LiteAtof(parts[1], &threshold)) { |
| per_lang_thresholds_[std::string(parts[0])] = threshold; |
| } else { |
| SAFTM_LOG(ERROR) << "Broken token: \"" << token << "\""; |
| } |
| } |
| model_version_ = context->Get("model_version", model_version_); |
| return true; |
| } |
| |
| bool Init(TaskContext *context) { |
| return lang_id_brain_interface_.InitForProcessing(context); |
| } |
| |
| // Returns language code for a softmax label. See comments for languages_ |
| // field. If label is out of range, returns LangId::kUnknownLanguageCode. |
| std::string GetLanguageForSoftmaxLabel(int label) const { |
| if ((label >= 0) && (static_cast<size_t>(label) < languages_.size())) { |
| return languages_[label]; |
| } else { |
| SAFTM_LOG(ERROR) << "Softmax label " << label << " outside range [0, " |
| << languages_.size() << ")"; |
| return LangId::kUnknownLanguageCode; |
| } |
| } |
| |
| bool IsTooShort(const LightSentence &sentence) const { |
| int text_size = 0; |
| for (const std::string &token : sentence) { |
      // Each token has the form ^...$: we subtract 2 because we want to count
      // only the real text, not the boundary chars we added.
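      // For example, the token "^abc$" has size 5 and contributes 3 bytes.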
      // The cast avoids unsigned arithmetic on a malformed (too short) token.
      text_size += static_cast<int>(token.size()) - 2;
| } |
| return text_size < min_text_size_in_bytes_; |
| } |
| |
| std::unique_ptr<ModelProvider> model_provider_; |
| |
| TokenizerForLangId tokenizer_; |
| |
| EmbeddingFeatureInterface<LightSentenceExtractor, LightSentence> |
| lang_id_brain_interface_; |
| |
| // Neural network to use for scoring. |
| std::unique_ptr<EmbeddingNetwork> network_; |
| |
| // True if this object is ready to perform language predictions. |
| bool valid_ = false; |
| |
| // The model returns LangId::kUnknownLanguageCode for input text that has |
  // fewer than min_text_size_in_bytes_ bytes (excluding ASCII whitespace,
| // digits, and punctuation). |
| int min_text_size_in_bytes_ = 0; |
| |
| // Only predictions with a probability (confidence) above this threshold are |
| // reported. Otherwise, we report LangId::kUnknownLanguageCode. |
| float default_threshold_ = kDefaultConfidenceThreshold; |
| |
| std::unordered_map<std::string, float> per_lang_thresholds_; |
| |
| // Recognized languages: softmax label i means languages_[i] (something like |
| // "en", "fr", "ru", etc). |
| std::vector<std::string> languages_; |
| |
| // Version of the model used by this LangIdImpl object. Zero means that the |
| // model version could not be determined. |
| int model_version_ = 0; |
| }; |
| |
| const char LangId::kUnknownLanguageCode[] = "und"; |
| |
LangId::LangId(std::unique_ptr<ModelProvider> model_provider) {
  // Register the feature and script classes before constructing the
  // implementation: LangIdImpl's Setup() instantiates these classes by name.
  LangIdImpl::RegisterClasses();
  pimpl_.reset(new LangIdImpl(std::move(model_provider)));
}
| |
| LangId::~LangId() = default; |
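// Example usage (a sketch; GetModelProvider() is a hypothetical stand-in for
// code that obtains a concrete ModelProvider implementation):
//
//   std::unique_ptr<ModelProvider> provider = GetModelProvider();
//   LangId lang_id(std::move(provider));
//   if (lang_id.is_valid()) {
//     const std::string text = "Bonjour tout le monde";
//     const std::string code = lang_id.FindLanguage(text.data(), text.size());
//     // code is a language code like "fr", or LangId::kUnknownLanguageCode
//     // ("und") if the top prediction is not confident enough.
//   }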
| |
| std::string LangId::FindLanguage(const char *data, size_t num_bytes) const { |
| StringPiece text(data, num_bytes); |
| return pimpl_->FindLanguage(text); |
| } |
| |
| void LangId::FindLanguages(const char *data, size_t num_bytes, |
| LangIdResult *result, int max_results) const { |
| SAFTM_DCHECK(result) << "LangIdResult must not be null."; |
| StringPiece text(data, num_bytes); |
| pimpl_->FindLanguages(text, result, max_results); |
| } |
| |
| bool LangId::is_valid() const { return pimpl_->is_valid(); } |
| |
| int LangId::GetModelVersion() const { return pimpl_->GetModelVersion(); } |
| |
| float LangId::GetFloatProperty(const std::string &property, |
| float default_value) const { |
| return pimpl_->GetProperty<float, float>(property, default_value); |
| } |
| |
| } // namespace lang_id |
| } // namespace mobile |
}  // namespace libtextclassifier3