// blob: a7b0cac7a42383e34846e06f4ab0a54d1778246c [file] [log] [blame] [edit]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "utils/wordpiece_tokenizer.h"
#include "utils/utf8/unicodetext.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
namespace {
// Queries the vocabulary for the bytes [byte_start, byte_end) of `token`.
//
// A piece that does not start at the beginning of the token (byte_start > 0)
// is looked up with `suffix_indicator` prepended; a piece starting at byte 0
// is looked up verbatim.
//
// `in_vocab` is handed straight to `vocab_map->Contains`, whose status is
// returned unchanged. No bounds checking is performed on the byte range.
LookupStatus Lookup(int byte_start, int byte_end, const absl::string_view token,
                    const std::string& suffix_indicator,
                    const WordpieceVocab* vocab_map, bool* in_vocab) {
  const int piece_len = byte_end - byte_start;
  const absl::string_view piece(token.data() + byte_start, piece_len);
  const bool is_continuation = byte_start > 0;
  std::string key = is_continuation
                        ? absl::StrCat(suffix_indicator, piece)
                        : std::string(piece.begin(), piece.end());
  return vocab_map->Contains(key, in_vocab);
}
// Finds the longest piece of `token` that starts at `byte_start`, ends on a
// UTF-8 character boundary, and:
//   1) is in the vocab (looked up via Lookup(), i.e. with `suffix_indicator`
//      prepended when byte_start > 0), OR
//   2) if `split_unknown_characters` is true, is the single UTF-8 character
//      at `byte_start`.
//
// Inputs:
//   max_chars_per_subtoken: when > 0, candidates longer than this many
//     characters (not bytes) are never considered.
// Outputs:
//   *found_match: true iff a piece was selected.
//   *byte_end: end offset (exclusive, in bytes from the start of `token`) of
//     the selected piece. Left untouched when *found_match is false.
//   *match_is_unknown_character: true iff the piece was selected through
//     rule 2 rather than by a vocab hit.
//
// Returns the failing status as soon as a vocab lookup reports failure;
// otherwise returns LookupStatus::OK(), including when nothing matched.
LookupStatus LongestMatchStartingAt(
    int byte_start, const absl::string_view token,
    const std::string& suffix_indicator, const int max_chars_per_subtoken,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    int* byte_end, bool* found_match, bool* match_is_unknown_character) {
  *match_is_unknown_character = false;
  *found_match = false;
  const UnicodeText unicode_token =
      UTF8ToUnicodeText(token.substr(byte_start), /*do_copy=*/false);

  // Collect the end offset of every candidate piece, shortest first. Each
  // entry is the byte offset just past one more character.
  std::vector<int32_t> byte_ends;
  int32_t codepoint_offset = byte_start;
  for (auto it = unicode_token.begin(); it != unicode_token.end(); ++it) {
    codepoint_offset += it.utf8_length();
    byte_ends.push_back(codepoint_offset);
    // If the max number of characters of a subtoken is known, do not search
    // beyond that length. The count is compared as int to avoid a
    // signed/unsigned comparison.
    if (max_chars_per_subtoken > 0 &&
        static_cast<int>(byte_ends.size()) == max_chars_per_subtoken) {
      break;
    }
  }

  // Try candidates from longest to shortest so the first hit is the longest.
  const int n = static_cast<int>(byte_ends.size());
  for (int i = n - 1; i >= 0; --i) {
    // Initialized so that a lookup which reports success without writing its
    // output is treated as "not in vocab" instead of reading an indeterminate
    // value.
    bool in_vocab = false;
    auto status = Lookup(byte_start, byte_ends[i], token, suffix_indicator,
                         vocab_map, &in_vocab);
    if (!status.success) return status;
    if (in_vocab) {
      *byte_end = byte_ends[i];
      *found_match = true;
      return LookupStatus::OK();
    }
    // Nothing in the vocab, not even the single leading character: fall back
    // to emitting that character on its own when the caller allows it.
    if (i == 0 && split_unknown_characters) {
      *byte_end = byte_ends[0];
      *found_match = true;
      *match_is_unknown_character = true;
      return LookupStatus::OK();
    }
  }
  return LookupStatus::OK();
}
// Records the fallback result used when `token` could not be broken into
// wordpieces: exactly one piece covering bytes [0, token.length()) is
// appended to 'subwords', 'begin_offset' and 'end_offset', and
// 'num_word_pieces' is incremented by one.
//
// The emitted piece is `unknown_token` when `use_unknown_token` is true, and
// the original token text otherwise. Always returns LookupStatus::OK().
LookupStatus NoTokenFound(const absl::string_view token, bool use_unknown_token,
                          const std::string& unknown_token,
                          std::vector<std::string>* subwords,
                          std::vector<int>* begin_offset,
                          std::vector<int>* end_offset, int* num_word_pieces) {
  if (use_unknown_token) {
    subwords->push_back(unknown_token);
  } else {
    subwords->emplace_back(token.data(), token.length());
  }
  // The offsets span the whole token regardless of which text was emitted.
  begin_offset->push_back(0);
  end_offset->push_back(token.length());
  *num_word_pieces += 1;
  return LookupStatus::OK();
}
// Appends the subword made of bytes [byte_start, byte_end) of `token` to
// 'subwords' and records its byte range in 'begin_offset' / 'end_offset'.
// A subword that does not start the token (byte_start > 0) is stored with
// `suffix_indicator` prepended.
void AddWord(const absl::string_view token, int byte_start, int byte_end,
             const std::string& suffix_indicator,
             std::vector<std::string>* subwords, std::vector<int>* begin_offset,
             std::vector<int>* end_offset) {
  const int piece_len = byte_end - byte_start;
  if (byte_start <= 0) {
    // Word-initial piece: stored as-is, read from the start of the token.
    subwords->emplace_back(token.data(), piece_len);
  } else {
    // Word-internal piece: mark it as a continuation.
    const absl::string_view piece(token.data() + byte_start, piece_len);
    subwords->push_back(::absl::StrCat(suffix_indicator, piece));
  }
  begin_offset->push_back(byte_start);
  end_offset->push_back(byte_end);
}
// Appends one out-of-vocabulary character (bytes [byte_start, byte_end) of
// `token`) as its own subword; used when split_unknown_characters is true.
//
// The stored text is `unknown_token` when `use_unknown_token` is true, and the
// raw character bytes otherwise. In both cases `suffix_indicator` is prepended
// when the character is not at the start of the token (byte_start > 0). The
// recorded offsets always describe the character's byte range in `token`.
void AddUnknownCharacter(const absl::string_view token, int byte_start,
                         int byte_end, const std::string& suffix_indicator,
                         bool use_unknown_token,
                         const std::string& unknown_token,
                         std::vector<std::string>* subwords,
                         std::vector<int>* begin_offset,
                         std::vector<int>* end_offset) {
  begin_offset->push_back(byte_start);
  end_offset->push_back(byte_end);

  const int char_len = byte_end - byte_start;
  const bool within_word = byte_start > 0;
  if (within_word) {
    // Continuation piece: choose the text, then mark it with the indicator.
    const absl::string_view text =
        use_unknown_token
            ? absl::string_view(unknown_token)
            : absl::string_view(token.data() + byte_start, char_len);
    subwords->push_back(::absl::StrCat(suffix_indicator, text));
  } else if (use_unknown_token) {
    subwords->push_back(unknown_token);
  } else {
    subwords->emplace_back(token.data(), char_len);
  }
}
// Splits `token` into wordpieces left to right, taking at each position the
// longest piece reported by LongestMatchStartingAt().
//
// Pieces are staged in local buffers and only appended to the output vectors
// once the whole token has been consumed. If any position yields no match,
// the staged pieces are discarded and NoTokenFound() emits a single fallback
// piece instead. If a vocab lookup fails, that status is returned and the
// outputs are left untouched.
//
// On success 'num_word_pieces' is increased by the number of pieces appended.
// Note: `max_bytes_per_token` is not used by this function.
LookupStatus TokenizeL2RGreedy(
    const absl::string_view token, const int max_bytes_per_token,
    const int max_chars_per_subtoken, const std::string& suffix_indicator,
    bool use_unknown_token, const std::string& unknown_token,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  std::vector<std::string> pending_subwords;
  std::vector<int> pending_begins;
  std::vector<int> pending_ends;

  const int total_bytes = token.length();
  int cursor = 0;
  while (cursor < total_bytes) {
    int match_end;
    bool matched;
    bool is_unknown_char;
    auto status = LongestMatchStartingAt(
        cursor, token, suffix_indicator, max_chars_per_subtoken,
        split_unknown_characters, vocab_map, &match_end, &matched,
        &is_unknown_char);
    if (!status.success) return status;

    if (!matched) {
      // One unmatched position invalidates the whole split.
      return NoTokenFound(token, use_unknown_token, unknown_token, subwords,
                          begin_offset, end_offset, num_word_pieces);
    }

    if (is_unknown_char) {
      AddUnknownCharacter(token, cursor, match_end, suffix_indicator,
                          use_unknown_token, unknown_token, &pending_subwords,
                          &pending_begins, &pending_ends);
    } else {
      AddWord(token, cursor, match_end, suffix_indicator, &pending_subwords,
              &pending_begins, &pending_ends);
    }
    cursor = match_end;
  }

  // The whole token was consumed: publish the staged pieces.
  subwords->insert(subwords->end(), pending_subwords.begin(),
                   pending_subwords.end());
  begin_offset->insert(begin_offset->end(), pending_begins.begin(),
                       pending_begins.end());
  end_offset->insert(end_offset->end(), pending_ends.begin(),
                     pending_ends.end());
  *num_word_pieces += pending_subwords.size();
  return LookupStatus::OK();
}
} // namespace
// Tokenizes `token` into wordpieces, appending the pieces to 'subwords' and
// their byte ranges within `token` to 'begin_offset' / 'end_offset'.
// 'num_word_pieces' is increased by the number of pieces appended, so the
// caller must initialize it before the call.
//
// A token longer than `max_bytes_per_token` bytes is not split: it yields a
// single piece spanning [0, token.size()), whose text is `unknown_token` when
// `use_unknown_token` is true and the token itself otherwise. All other
// tokens are handled by TokenizeL2RGreedy(), whose status is returned.
LookupStatus WordpieceTokenize(
    const absl::string_view token, const int max_bytes_per_token,
    const int max_chars_per_subtoken, const std::string& suffix_indicator,
    bool use_unknown_token, const std::string& unknown_token,
    bool split_unknown_characters, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  const int token_len = token.size();
  if (token_len > max_bytes_per_token) {
    if (use_unknown_token) {
      subwords->emplace_back(unknown_token);
    } else {
      subwords->emplace_back(token);
    }
    begin_offset->push_back(0);
    end_offset->push_back(token.size());
    // Increment rather than overwrite: every other path in this file adds to
    // the running count, matching how the output vectors are appended to.
    // Identical to the previous `= 1` whenever the caller passes in zero.
    ++(*num_word_pieces);
    return LookupStatus::OK();
  }
  return TokenizeL2RGreedy(token, max_bytes_per_token, max_chars_per_subtoken,
                           suffix_indicator, use_unknown_token, unknown_token,
                           split_unknown_characters, vocab_map, subwords,
                           begin_offset, end_offset, num_word_pieces);
}
// Convenience overload of WordpieceTokenize() that applies no per-subtoken
// character limit and never splits unknown characters. All other arguments
// are forwarded unchanged and the resulting status is returned as-is.
LookupStatus WordpieceTokenize(
    const absl::string_view token, const int max_bytes_per_token,
    const std::string& suffix_indicator, bool use_unknown_token,
    const std::string& unknown_token, const WordpieceVocab* vocab_map,
    std::vector<std::string>* subwords, std::vector<int>* begin_offset,
    std::vector<int>* end_offset, int* num_word_pieces) {
  // A limit of 0 disables the max-characters-per-subtoken cutoff.
  constexpr int kNoSubtokenCharLimit = 0;
  constexpr bool kSplitUnknownCharacters = false;
  return WordpieceTokenize(token, max_bytes_per_token, kNoSubtokenCharLimit,
                           suffix_indicator, use_unknown_token, unknown_token,
                           kSplitUnknownCharacters, vocab_map, subwords,
                           begin_offset, end_offset, num_word_pieces);
}
} // namespace libtextclassifier3