| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "utils/sentencepiece/normalizer.h" |
| |
| #include "utils/base/logging.h" |
| #include "utils/strings/utf8.h" |
| |
| namespace libtextclassifier3 { |
| |
| bool SentencePieceNormalizer::Normalize(StringPiece input, |
| std::string* normalized_input) const { |
| // Ignores heading space. |
| if (remove_extra_whitespaces_) { |
| while (!input.empty()) { |
| std::pair<StringPiece, int> suffix_and_length; |
| if (!NormalizePrefix(input, &suffix_and_length)) { |
| TC3_LOG(ERROR) << "Couldn't find match in normalization table."; |
| return false; |
| } |
| if (suffix_and_length.second <= 0) { |
| TC3_LOG(ERROR) << "Consumed string is empty."; |
| return false; |
| } |
| if (suffix_and_length.first.size() != 1 || |
| suffix_and_length.first[0] != ' ') { |
| break; |
| } |
| input.RemovePrefix(suffix_and_length.second); |
| } |
| } |
| |
| if (input.empty()) { |
| *normalized_input = ""; |
| return true; |
| } |
| |
| // Reserves the output buffer to avoid re-allocations. |
| const int kReservedSize = input.size() * 3; |
| normalized_input->reserve(kReservedSize); |
| |
| // Replaces white space with U+2581 (LOWER ONE EIGHT BLOCK) |
| // if escape_whitespaces() is set (default = true). |
| const StringPiece kSpaceSymbol = "\xe2\x96\x81"; |
| |
| // Adds a space symbol as a prefix (default is true) |
| // With this prefix, "world" and "hello world" are converted into |
| // "_world" and "_hello_world", which help the trainer to extract |
| // "_world" as one symbol. |
| if (add_dummy_prefix_) { |
| if (escape_whitespaces_) { |
| normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size()); |
| } else { |
| normalized_input->append(" "); |
| } |
| } |
| |
| bool is_prev_space = remove_extra_whitespaces_; |
| while (!input.empty()) { |
| std::pair<StringPiece, int> p; |
| if (!NormalizePrefix(input, &p)) { |
| TC3_LOG(ERROR) << "Couldn't normalize string."; |
| return false; |
| } |
| if (p.second <= 0) { |
| TC3_LOG(ERROR) << "Consumed string is empty."; |
| return false; |
| } |
| |
| StringPiece sp = p.first; |
| |
| // Removes heading spaces in sentence piece, |
| // if the previous sentence piece ends with whitespace. |
| while (is_prev_space && ConsumePrefix(&sp, " ")) { |
| } |
| |
| if (!sp.empty()) { |
| const char* data = sp.data(); |
| for (int n = 0; n < sp.size(); ++n) { |
| if (escape_whitespaces_ && data[n] == ' ') { |
| normalized_input->append(kSpaceSymbol.data(), kSpaceSymbol.size()); |
| } else { |
| *normalized_input += data[n]; |
| } |
| } |
| // Checks whether the last character of sp is whitespace. |
| is_prev_space = EndsWith(sp, " "); |
| } |
| input.RemovePrefix(p.second); |
| is_prev_space = is_prev_space && remove_extra_whitespaces_; |
| } |
| |
| // Ignores tailing space. |
| if (remove_extra_whitespaces_) { |
| const StringPiece space = escape_whitespaces_ ? kSpaceSymbol : " "; |
| while (EndsWith(*normalized_input, space)) { |
| const int length = normalized_input->size() - space.size(); |
| normalized_input->resize(length); |
| } |
| } |
| return true; |
| } |
| |
| bool SentencePieceNormalizer::NormalizePrefix( |
| StringPiece input, std::pair<StringPiece, int>* prefix) const { |
| if (input.empty()) return true; |
| TrieMatch match; |
| if (!charsmap_trie_.LongestPrefixMatch(input, &match)) { |
| TC3_LOG(ERROR) << "Couldn't find match in normalization table."; |
| return false; |
| } |
| const bool no_match = match.match_length <= 0; |
| if (no_match) { |
| const int char_length = ValidUTF8CharLength(input.data(), input.size()); |
| if (char_length <= 0) { |
| // Found a malformed utf8. |
| // The rune is set to be 0xFFFD (REPLACEMENT CHARACTER), |
| // which is a valid Unicode of three bytes in utf8, |
| // but here we only consume one byte. |
| static const char kReplacementChar[] = "\xEF\xBF\xBD"; |
| prefix->first = StringPiece(kReplacementChar, 3); |
| prefix->second = 1; // Consumes 1 byte, buts emit 0xFFFD. |
| } else { |
| prefix->first = StringPiece(input.data(), char_length); |
| prefix->second = char_length; |
| } |
| } else { |
| if (match.id < 0 || match.id >= charsmap_normalized_.size()) { |
| TC3_LOG(ERROR) << "Invalid entry in normalization table."; |
| return false; |
| } |
| prefix->first = StringPiece(&charsmap_normalized_.data()[match.id]); |
| prefix->second = match.match_length; |
| } |
| return true; |
| } |
| |
| } // namespace libtextclassifier3 |