blob: c85ba507475bb598503d67c85eccf5a192dbdbb7 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "smartselect/token-feature-extractor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier {
namespace {
// Test-only subclass of TokenFeatureExtractor.
// It inherits all base constructors and makes HashToken() publicly callable,
// so the tests can compute the expected sparse feature ids directly.
// NOTE(review): HashToken is presumably non-public in the base class; this
// using-declaration is what exposes it.
struct TestingTokenFeatureExtractor : public TokenFeatureExtractor {
  using TokenFeatureExtractor::HashToken;
  using TokenFeatureExtractor::TokenFeatureExtractor;
};
// Non-Unicode-aware extraction.
// Checks the 1-, 2- and 3-character n-grams, including the "^" / "$"
// word-boundary markers, and the dense features for two tokens:
//  - an in-span capitalized token, and
//  - an out-of-span lowercase token.
TEST(TokenFeatureExtractorTest, ExtractAscii) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.extract_case_feature = true;
  options.unicode_aware_features = false;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options);

  // Translates the expected character n-grams, in order, into feature ids.
  auto hashes_of = [&extractor](std::initializer_list<const char*> grams) {
    std::vector<int> ids;
    ids.reserve(grams.size());
    for (const char* gram : grams) {
      ids.push_back(extractor.HashToken(gram));
    }
    return ids;
  };

  std::vector<int> sparse;
  std::vector<float> dense;

  // In-span token whose first letter is uppercase.
  extractor.Extract(Token{"Hello", 0, 5}, true, &sparse, &dense);
  // clang-format off
  EXPECT_THAT(sparse, testing::ElementsAreArray(hashes_of({
      "H", "e", "l", "l", "o",
      "^H", "He", "el", "ll", "lo", "o$",
      "^He", "Hel", "ell", "llo", "lo$"})));
  // clang-format on
  EXPECT_THAT(dense, testing::ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();

  // Out-of-span lowercase token. In this (ASCII) mode the second dense value
  // is expected to be 0.0 rather than -1.0.
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);
  // clang-format off
  EXPECT_THAT(sparse, testing::ElementsAreArray(hashes_of({
      "w", "o", "r", "l", "d", "!",
      "^w", "wo", "or", "rl", "ld", "d!", "!$",
      "^wo", "wor", "orl", "rld", "ld!", "d!$"})));
  // clang-format on
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, 0.0}));
}
// Unicode-aware extraction.
// Each n-gram position is a code point, not a byte, so "ě" and "ó" count as
// single characters. The test also checks the dense features for an in-span
// capitalized token and an out-of-span lowercase token.
TEST(TokenFeatureExtractorTest, ExtractUnicode) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options);

  // Translates the expected character n-grams, in order, into feature ids.
  auto hashes_of = [&extractor](std::initializer_list<const char*> grams) {
    std::vector<int> ids;
    ids.reserve(grams.size());
    for (const char* gram : grams) {
      ids.push_back(extractor.HashToken(gram));
    }
    return ids;
  };

  std::vector<int> sparse;
  std::vector<float> dense;

  // In-span token whose first letter is uppercase.
  extractor.Extract(Token{"Hělló", 0, 5}, true, &sparse, &dense);
  // clang-format off
  EXPECT_THAT(sparse, testing::ElementsAreArray(hashes_of({
      "H", "ě", "l", "l", "ó",
      "^H", "Hě", "ěl", "ll", "ló", "ó$",
      "^Hě", "Hěl", "ěll", "lló", "ló$"})));
  // clang-format on
  EXPECT_THAT(dense, testing::ElementsAreArray({1.0, 1.0}));

  sparse.clear();
  dense.clear();

  // Out-of-span lowercase token. Unlike the ASCII mode, the second dense
  // value is expected to be -1.0 here.
  extractor.Extract(Token{"world!", 23, 29}, false, &sparse, &dense);
  // clang-format off
  EXPECT_THAT(sparse, testing::ElementsAreArray(hashes_of({
      "w", "o", "r", "l", "d", "!",
      "^w", "wo", "or", "rl", "ld", "d!", "!$",
      "^wo", "wor", "orl", "rld", "ld!", "d!$"})));
  // clang-format on
  EXPECT_THAT(dense, testing::ElementsAreArray({-1.0, -1.0}));
}
// Case feature in Unicode-aware mode.
// A token starting with an uppercase letter (including non-ASCII letters
// such as "Ř") should yield 1.0; a lowercase start should yield -1.0.
// The selection mask is disabled, so the case feature is the only dense
// value.
TEST(TokenFeatureExtractorTest, ICUCaseFeature) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = false;
  TokenFeatureExtractor extractor(options);

  // Runs one extraction into fresh buffers and returns the dense features.
  auto dense_of = [&extractor](const Token& token, bool is_in_span) {
    std::vector<int> sparse;
    std::vector<float> dense;
    extractor.Extract(token, is_in_span, &sparse, &dense);
    return dense;
  };

  EXPECT_THAT(dense_of(Token{"Hělló", 0, 5}, true),
              testing::ElementsAreArray({1.0}));
  EXPECT_THAT(dense_of(Token{"world!", 23, 29}, false),
              testing::ElementsAreArray({-1.0}));
  EXPECT_THAT(dense_of(Token{"Ř", 23, 29}, false),
              testing::ElementsAreArray({1.0}));
  EXPECT_THAT(dense_of(Token{"ř", 23, 29}, false),
              testing::ElementsAreArray({-1.0}));
}
// Digit remapping in ASCII mode.
//  - Tokens that differ only in which digits they contain ("9:30am" vs
//    "5:32am") must produce identical sparse features.
//  - A token with a different shape ("10:32am") must not.
TEST(TokenFeatureExtractorTest, DigitRemapping) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = true;
  options.unicode_aware_features = false;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  dense_features.clear();
  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Reset the output buffers so the negative check below compares only the
  // features of "10:32am". Otherwise it could pass merely because stale
  // features from the previous extraction were left in the vector.
  sparse_features2.clear();
  dense_features.clear();
  extractor.Extract(Token{"10:32am", 0, 7}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features,
              testing::Not(testing::ElementsAreArray(sparse_features2)));
}
// Digit remapping in Unicode-aware mode.
//  - Tokens that differ only in which digits they contain ("9:30am" vs
//    "5:32am") must produce identical sparse features.
//  - A token with a different shape ("10:32am") must not.
TEST(TokenFeatureExtractorTest, DigitRemappingUnicode) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = true;
  options.unicode_aware_features = true;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"9:30am", 0, 6}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  dense_features.clear();
  extractor.Extract(Token{"5:32am", 0, 6}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Reset the output buffers so the negative check below compares only the
  // features of "10:32am". Otherwise it could pass merely because stale
  // features from the previous extraction were left in the vector.
  sparse_features2.clear();
  dense_features.clear();
  extractor.Extract(Token{"10:32am", 0, 7}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features,
              testing::Not(testing::ElementsAreArray(sparse_features2)));
}
// Lowercasing in ASCII mode.
// Tokens that differ only in letter case ("AABB", "aaBB", "aAbB") must
// produce identical sparse features.
TEST(TokenFeatureExtractorTest, LowercaseAscii) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.lowercase_tokens = true;
  options.unicode_aware_features = false;
  TokenFeatureExtractor extractor(options);

  std::vector<int> sparse_features;
  std::vector<float> dense_features;
  extractor.Extract(Token{"AABB", 0, 4}, true, &sparse_features,
                    &dense_features);

  std::vector<int> sparse_features2;
  dense_features.clear();
  extractor.Extract(Token{"aaBB", 0, 4}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));

  // Reset the reused buffers so the comparison below sees only the features
  // of "aAbB", regardless of whether Extract() appends or overwrites.
  sparse_features2.clear();
  dense_features.clear();
  extractor.Extract(Token{"aAbB", 0, 4}, true, &sparse_features2,
                    &dense_features);
  EXPECT_THAT(sparse_features, testing::ElementsAreArray(sparse_features2));
}
// Lowercasing in Unicode-aware mode.
// A token of uppercase non-ASCII letters ("ŘŘ") must produce the same sparse
// features as its lowercase form ("řř").
TEST(TokenFeatureExtractorTest, LowercaseUnicode) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.lowercase_tokens = true;
  options.unicode_aware_features = true;
  TokenFeatureExtractor extractor(options);

  // Dense features are not checked here; one buffer is shared by both calls.
  std::vector<float> dense_features;

  std::vector<int> upper_sparse;
  extractor.Extract(Token{"ŘŘ", 0, 6}, true, &upper_sparse, &dense_features);

  std::vector<int> lower_sparse;
  extractor.Extract(Token{"řř", 0, 6}, true, &lower_sparse, &dense_features);

  EXPECT_THAT(upper_sparse, testing::ElementsAreArray(lower_sparse));
}
// Regex features.
// Each configured regular expression contributes one dense feature, in
// configuration order: 1.0 when the token matches, -1.0 otherwise.
TEST(TokenFeatureExtractorTest, RegexFeatures) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.remap_digits = false;
  options.unicode_aware_features = false;
  options.regexp_features.push_back("^[a-z]+$");  // all lower case.
  options.regexp_features.push_back("^[0-9]+$");  // all digits.
  TokenFeatureExtractor extractor(options);

  // Runs each extraction in fresh buffers, so that neither sparse nor dense
  // features leak from one case into the next.
  auto dense_of = [&extractor](const Token& token) {
    std::vector<int> sparse_features;
    std::vector<float> dense_features;
    extractor.Extract(token, true, &sparse_features, &dense_features);
    return dense_features;
  };

  EXPECT_THAT(dense_of(Token{"abCde", 0, 5}),
              testing::ElementsAreArray({-1.0, -1.0}));
  EXPECT_THAT(dense_of(Token{"abcde", 0, 5}),
              testing::ElementsAreArray({1.0, -1.0}));
  EXPECT_THAT(dense_of(Token{"12c45", 0, 5}),
              testing::ElementsAreArray({-1.0, -1.0}));
  EXPECT_THAT(dense_of(Token{"12345", 0, 5}),
              testing::ElementsAreArray({-1.0, 1.0}));
}
// Over-long tokens must be handled without memory errors; ASAN builds catch
// any out-of-bounds access.
// The expected n-grams show the 26-code-point word shortened to:
//  - its first ten code points,
//  - a single '\1' placeholder, and
//  - its last ten code points.
// Only two order-22 n-grams therefore exist once the "^" / "$" markers are
// added.
TEST(TokenFeatureExtractorTest, ExtractTooLongWord) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{22};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options);

  const Token long_token{"abcdefghijklmnopqřstuvwxyz", 0, 0};
  std::vector<int> sparse;
  std::vector<float> dense;
  extractor.Extract(long_token, true, &sparse, &dense);

  const int leading_gram = extractor.HashToken("^abcdefghij\1qřstuvwxyz");
  const int trailing_gram = extractor.HashToken("abcdefghij\1qřstuvwxyz$");
  EXPECT_THAT(sparse,
              testing::ElementsAreArray({leading_gram, trailing_gram}));
}
// For pure-ASCII inputs, the Unicode-aware and the ASCII-only extractors
// must produce exactly the same sparse and dense features.
TEST(TokenFeatureExtractorTest, ExtractAsciiUnicodeMatches) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2, 3, 4, 5};
  options.extract_case_feature = true;
  options.unicode_aware_features = true;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor_unicode(options);
  // The first extractor is constructed before this flag flips, so the two
  // extractors differ only in unicode_aware_features.
  options.unicode_aware_features = false;
  TestingTokenFeatureExtractor extractor_ascii(options);

  const std::vector<std::string> inputs = {
      "https://www.abcdefgh.com/in/xxxkkkvayio",
      "https://www.fjsidofj.om/xx/abadfy/xxxx/?xfjiis=ffffiijiihil",
      "asdfhasdofjiasdofj#%()*%#*(aisdojfaosdifjiaofjdsiofjdi_fdis3w",
      "abcd",
      "x",
      "Hello",
      "Hey,",
      "Hi",
      ""};

  for (const std::string& input : inputs) {
    std::vector<int> unicode_sparse;
    std::vector<float> unicode_dense;
    extractor_unicode.Extract(Token{input, 0, 0}, true, &unicode_sparse,
                              &unicode_dense);

    std::vector<int> ascii_sparse;
    std::vector<float> ascii_dense;
    extractor_ascii.Extract(Token{input, 0, 0}, true, &ascii_sparse,
                            &ascii_dense);

    EXPECT_EQ(unicode_sparse, ascii_sparse) << input;
    EXPECT_EQ(unicode_dense, ascii_dense) << input;
  }
}
// Padding token.
// A default-constructed Token should yield the single sparse feature for
// "<PAD>". The expected dense features are {-1.0, 0.0}, the same values
// expected for an out-of-span lowercase token in ASCII mode.
TEST(TokenFeatureExtractorTest, ExtractForPadToken) {
  TokenFeatureExtractorOptions options;
  options.num_buckets = 1000;
  options.chargram_orders = std::vector<int>{1, 2};
  options.extract_case_feature = true;
  options.unicode_aware_features = false;
  options.extract_selection_mask_feature = true;
  TestingTokenFeatureExtractor extractor(options);

  std::vector<int> pad_sparse;
  std::vector<float> pad_dense;
  const Token pad_token;
  extractor.Extract(pad_token, false, &pad_sparse, &pad_dense);

  const int pad_id = extractor.HashToken("<PAD>");
  EXPECT_THAT(pad_sparse, testing::ElementsAre(pad_id));
  EXPECT_THAT(pad_dense, testing::ElementsAreArray({-1.0, 0.0}));
}
} // namespace
} // namespace libtextclassifier