| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "annotator/number/number.h" |
| |
| #include <climits> |
| #include <cstdlib> |
| |
| #include "annotator/collections.h" |
| #include "utils/base/logging.h" |
| |
| namespace libtextclassifier3 { |
| |
| bool NumberAnnotator::ClassifyText( |
| const UnicodeText& context, CodepointSpan selection_indices, |
| AnnotationUsecase annotation_usecase, |
| ClassificationResult* classification_result) const { |
| int64 parsed_value; |
| int num_prefix_codepoints; |
| int num_suffix_codepoints; |
| if (ParseNumber(UnicodeText::Substring(context, selection_indices.first, |
| selection_indices.second), |
| &parsed_value, &num_prefix_codepoints, |
| &num_suffix_codepoints)) { |
| ClassificationResult classification{Collections::Number(), 1.0}; |
| TC3_CHECK(classification_result != nullptr); |
| classification_result->collection = Collections::Number(); |
| classification_result->score = options_->score(); |
| classification_result->priority_score = options_->priority_score(); |
| classification_result->numeric_value = parsed_value; |
| return true; |
| } |
| return false; |
| } |
| |
| bool NumberAnnotator::FindAll(const UnicodeText& context, |
| AnnotationUsecase annotation_usecase, |
| std::vector<AnnotatedSpan>* result) const { |
| if (!options_->enabled() || ((1 << annotation_usecase) & |
| options_->enabled_annotation_usecases()) == 0) { |
| return true; |
| } |
| |
| const std::vector<Token> tokens = feature_processor_->Tokenize(context); |
| for (const Token& token : tokens) { |
| const UnicodeText token_text = |
| UTF8ToUnicodeText(token.value, /*do_copy=*/false); |
| int64 parsed_value; |
| int num_prefix_codepoints; |
| int num_suffix_codepoints; |
| if (ParseNumber(token_text, &parsed_value, &num_prefix_codepoints, |
| &num_suffix_codepoints)) { |
| ClassificationResult classification{Collections::Number(), |
| options_->score()}; |
| classification.numeric_value = parsed_value; |
| classification.priority_score = options_->priority_score(); |
| |
| AnnotatedSpan annotated_span; |
| annotated_span.span = {token.start + num_prefix_codepoints, |
| token.end - num_suffix_codepoints}; |
| annotated_span.classification.push_back(classification); |
| |
| result->push_back(annotated_span); |
| } |
| } |
| |
| return true; |
| } |
| |
| std::unordered_set<int> NumberAnnotator::FlatbuffersVectorToSet( |
| const flatbuffers::Vector<int32_t>* codepoints) { |
| if (codepoints == nullptr) { |
| return std::unordered_set<int>{}; |
| } |
| |
| std::unordered_set<int> result; |
| for (const int codepoint : *codepoints) { |
| result.insert(codepoint); |
| } |
| return result; |
| } |
| |
| namespace { |
| UnicodeText::const_iterator ConsumeAndParseNumber( |
| const UnicodeText::const_iterator& it_begin, |
| const UnicodeText::const_iterator& it_end, int64* result) { |
| *result = 0; |
| |
| // See if there's a sign in the beginning of the number. |
| int sign = 1; |
| auto it = it_begin; |
| if (it != it_end) { |
| if (*it == '-') { |
| ++it; |
| sign = -1; |
| } else if (*it == '+') { |
| ++it; |
| sign = 1; |
| } |
| } |
| |
| while (it != it_end) { |
| if (*it >= '0' && *it <= '9') { |
| // When overflow is imminent we'll fail to parse the number. |
| if (*result > INT64_MAX / 10) { |
| return it_begin; |
| } |
| *result *= 10; |
| *result += *it - '0'; |
| } else { |
| *result *= sign; |
| return it; |
| } |
| |
| ++it; |
| } |
| |
| *result *= sign; |
| return it_end; |
| } |
| } // namespace |
| |
| bool NumberAnnotator::ParseNumber(const UnicodeText& text, int64* result, |
| int* num_prefix_codepoints, |
| int* num_suffix_codepoints) const { |
| TC3_CHECK(result != nullptr && num_prefix_codepoints != nullptr && |
| num_suffix_codepoints != nullptr); |
| auto it = text.begin(); |
| auto it_end = text.end(); |
| |
| // Strip boundary codepoints from both ends. |
| const CodepointSpan original_span{0, text.size_codepoints()}; |
| const CodepointSpan stripped_span = |
| feature_processor_->StripBoundaryCodepoints(text, original_span); |
| const int num_stripped_end = (original_span.second - stripped_span.second); |
| std::advance(it, stripped_span.first); |
| std::advance(it_end, -num_stripped_end); |
| |
| // Consume prefix codepoints. |
| *num_prefix_codepoints = stripped_span.first; |
| while (it != text.end()) { |
| if (allowed_prefix_codepoints_.find(*it) == |
| allowed_prefix_codepoints_.end()) { |
| break; |
| } |
| |
| ++it; |
| ++(*num_prefix_codepoints); |
| } |
| |
| auto it_start = it; |
| it = ConsumeAndParseNumber(it, text.end(), result); |
| if (it == it_start) { |
| return false; |
| } |
| |
| // Consume suffix codepoints. |
| bool valid_suffix = true; |
| *num_suffix_codepoints = 0; |
| while (it != it_end) { |
| if (allowed_suffix_codepoints_.find(*it) == |
| allowed_suffix_codepoints_.end()) { |
| valid_suffix = false; |
| break; |
| } |
| |
| ++it; |
| ++(*num_suffix_codepoints); |
| } |
| *num_suffix_codepoints += num_stripped_end; |
| return valid_suffix; |
| } |
| |
| } // namespace libtextclassifier3 |