blob: c8ced76166ead0f4cbc077f09a49adca6ca42ff5 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "text-classifier.h"
#include <fstream>
#include <iostream>
#include <memory>
#include <string>
#include "model_generated.h"
#include "types-test-util.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier2 {
namespace {
using testing::ElementsAreArray;
using testing::IsEmpty;
using testing::Pair;
using testing::Values;
std::string FirstResult(const std::vector<ClassificationResult>& results) {
if (results.empty()) {
return "<INVALID RESULTS>";
}
return results[0].collection;
}
// Matches an annotation whose span equals (start, end) and whose best
// (first) classification collection equals best_class.
MATCHER_P3(IsAnnotatedSpan, start, end, best_class, "") {
  const bool span_matches = testing::Matches(Pair(start, end))(arg.span);
  if (!span_matches) {
    return false;
  }
  return testing::Matches(best_class)(FirstResult(arg.classification));
}
// Returns the full contents of |file_name| as a byte string, or an empty string
// if the file cannot be opened.
//
// The stream is opened in binary mode because the callers feed the result to
// UnPackModel / FromUnownedBuffer as a serialized model. In text mode,
// platforms that translate line endings would silently alter those bytes.
std::string ReadFile(const std::string& file_name) {
  std::ifstream file_stream(file_name, std::ios::binary);
  return std::string(std::istreambuf_iterator<char>(file_stream),
                     std::istreambuf_iterator<char>());
}
// Returns the directory holding the test model files, taken from the
// LIBTEXTCLASSIFIER_TEST_DATA_DIR build macro. Callers append file names
// directly, so the value presumably ends with a path separator (confirm in the
// build configuration).
std::string GetModelPath() {
  const std::string model_dir = LIBTEXTCLASSIFIER_TEST_DATA_DIR;
  return model_dir;
}
// Loading the "wrong_embeddings.fb" model must fail, so no classifier is
// returned.
TEST(TextClassifierTest, EmbeddingExecutorLoadingFails) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string model_path = GetModelPath() + "wrong_embeddings.fb";
  const std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(model_path, &unilib);
  EXPECT_FALSE(classifier);
}
// Value-parameterized fixture. Every TEST_P below runs once for each model
// file name supplied by the instantiations that follow.
class TextClassifierTest : public ::testing::TestWithParam<const char*> {};

// Click-context model variant.
INSTANTIATE_TEST_CASE_P(ClickContext, TextClassifierTest,
                        ::testing::Values("test_model_cc.fb"));

// Bounds-sensitive model variant.
INSTANTIATE_TEST_CASE_P(BoundsSensitive, TextClassifierTest,
                        ::testing::Values("test_model.fb"));
// Checks the top classification for typical selections, for selections inside
// a longer '|'-separated context, for single words, and for degenerate or
// malformed input (which must yield no results).
TEST_P(TextClassifierTest, ClassifyText) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string kNoResults = "<INVALID RESULTS>";
  const std::string kMultiLineContext =
      "this afternoon Barack Obama gave a speech at|Visit "
      "www.google.com every today!|Call me at (800) 123-456 today.";

  // Typical single-line selections.
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "this afternoon Barack Obama gave a speech at", {15, 27})),
            "other");
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "Call me at (800) 123-456 today", {11, 24})),
            "phone");

  // Selections inside the longer context.
  EXPECT_EQ(FirstResult(classifier->ClassifyText(kMultiLineContext, {15, 27})),
            "other");
  EXPECT_EQ(
      FirstResult(classifier->ClassifyText(kMultiLineContext, {90, 103})),
      "phone");

  // Single word; an empty span yields no results.
  EXPECT_EQ(FirstResult(classifier->ClassifyText("obama", {0, 5})), "other");
  EXPECT_EQ(FirstResult(classifier->ClassifyText("asdf", {0, 4})), "other");
  EXPECT_EQ(FirstResult(classifier->ClassifyText("asdf", {0, 0})), kNoResults);

  // Junk input.
  EXPECT_EQ(FirstResult(classifier->ClassifyText("", {0, 0})), kNoResults);
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "a\n\n\n\nx x x\n\n\n\n\n\n", {1, 5})),
            kNoResults);

  // Invalid UTF-8 input.
  EXPECT_EQ(FirstResult(
                classifier->ClassifyText("\xf0\x9f\x98\x8b\x8b", {0, 0})),
            kNoResults);
}
// Dropping the classification model while only selection is triggered must
// make loading fail.
TEST_P(TextClassifierTest, ClassifyTextDisabledFail) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->classification_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);

  // The classification model is still needed for selection scores, so the
  // stripped model must be rejected.
  ASSERT_FALSE(classifier);
}
// When classification is not among the triggering modes, ClassifyText must
// return no results even for an obvious phone number.
TEST_P(TextClassifierTest, ClassifyTextDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  ModelTriggeringOptionsT& triggering = *model->triggering_options;
  triggering.enabled_modes = ModeFlag_ANNOTATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  const std::string phone_text = "Call me at (800) 123-456 today";
  EXPECT_THAT(classifier->ClassifyText(phone_text, {11, 24}), IsEmpty());
}
// Filtering "phone" via output_options must suppress it in ClassifyText, while
// other collections (here: address) are unaffected.
TEST_P(TextClassifierTest, ClassifyTextFilteredCollections) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ("phone", FirstResult(classifier->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);
  // Disable phone classification.
  unpacked_model->output_options->filtered_collections_classification.push_back(
      "phone");
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // FromUnownedBuffer does not own |builder|'s buffer. This classifier is
  // therefore declared after |builder|: locals are destroyed in reverse order,
  // so it is released before the buffer it references. Reassigning
  // |classifier| (declared before |builder|) would let it outlive the buffer.
  std::unique_ptr<TextClassifier> filtered_classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(filtered_classifier);
  EXPECT_EQ("other", FirstResult(filtered_classifier->ClassifyText(
                         "Call me at (800) 123-456 today", {11, 24})));
  // Check that the address classification still passes.
  EXPECT_EQ("address", FirstResult(filtered_classifier->ClassifyText(
                           "350 Third Street, Cambridge", {0, 27})));
}
// Builds a regex pattern entry for |collection_name| that matches |pattern|.
// Each enabled_for_* flag adds the corresponding ModeFlag bit. |score| is used
// as both the target classification score and the priority score.
std::unique_ptr<RegexModel_::PatternT> MakePattern(
    const std::string& collection_name, const std::string& pattern,
    const bool enabled_for_classification, const bool enabled_for_selection,
    const bool enabled_for_annotation, const float score) {
  // The ModeFlag bits are accumulated in a plain int because |= cannot be
  // applied to the flag enum directly.
  int mode_bits = ModeFlag_NONE;
  mode_bits |= enabled_for_annotation ? ModeFlag_ANNOTATION : 0;
  mode_bits |= enabled_for_classification ? ModeFlag_CLASSIFICATION : 0;
  mode_bits |= enabled_for_selection ? ModeFlag_SELECTION : 0;

  std::unique_ptr<RegexModel_::PatternT> entry(new RegexModel_::PatternT);
  entry->collection_name = collection_name;
  entry->pattern = pattern;
  entry->enabled_modes = static_cast<ModeFlag>(mode_bits);
  entry->target_classification_score = score;
  entry->priority_score = score;
  return entry;
}
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Regex patterns enabled for classification ("person", "flight") must be used
// by ClassifyText alongside the model's own collections.
TEST_P(TextClassifierTest, ClassifyTextRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models (classification only).
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", "Barack Obama", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "[a-zA-Z]{2}\\d{2,4}", /*enabled_for_classification=*/true,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/false, 0.5));

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "Your flight LX373 is delayed by 3 hours.", {12, 17})),
            "flight");
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "this afternoon Barack Obama gave a speech at", {15, 27})),
            "person");
  EXPECT_EQ(FirstResult(classifier->ClassifyText("you@android.com", {0, 15})),
            "email");
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "Contact me at you@android.com", {14, 29})),
            "email");
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "Visit www.google.com every today!", {6, 20})),
            "url");
  EXPECT_EQ(FirstResult(classifier->ClassifyText("LX 37", {0, 5})), "flight");
  EXPECT_EQ(
      FirstResult(classifier->ClassifyText("flight LX 37 abcd", {7, 12})),
      "flight");

  // Selection inside a longer '|'-separated context.
  const std::string kMultiLineContext =
      "this afternoon Barack Obama gave a speech at|Visit "
      "www.google.com every today!|Call me at (800) 123-456 today.";
  EXPECT_EQ(FirstResult(classifier->ClassifyText(kMultiLineContext, {51, 65})),
            "url");
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Selection-enabled regex patterns must expand a partial click to the regex
// match.
TEST_P(TextClassifierTest, SuggestSelectionRegularExpression) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models (selection only).
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  patterns.back()->priority_score = 1.1;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  // Check regular expression selection.
  EXPECT_EQ(std::make_pair(12, 19),
            classifier->SuggestSelection(
                "Your flight MA 0123 is delayed by 3 hours.", {12, 14}));
  EXPECT_EQ(std::make_pair(15, 27),
            classifier->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// With the "flight" pattern's priority_score lowered to 0.5, a click on "MA"
// is expected to return the model's (26, 62) selection rather than the regex
// match.
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsModelWins) {
  // Use the test UniLib explicitly, as every other test in this file does.
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.back()->priority_score = 0.5;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(26, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// With the "flight" pattern's priority_score raised to 1.1, a click on "MA"
// is expected to return the regex match (55, 62) rather than the model's
// selection.
TEST_P(TextClassifierTest,
       SuggestSelectionRegularExpressionConflictsRegexWins) {
  // Use the test UniLib explicitly, as every other test in this file does.
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());

  // Add test regex models.
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/true, /*enabled_for_annotation=*/false, 1.0));
  unpacked_model->regex_model->patterns.back()->priority_score = 1.1;

  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  // Check conflict resolution.
  EXPECT_EQ(
      classifier->SuggestSelection(
          "saw Barack Obama today .. 350 Third Street, Cambridge, MA 0123",
          {55, 57}),
      std::make_pair(55, 62));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Annotation-enabled regex patterns must contribute annotations ("person")
// alongside the model's own date, address and phone annotations.
TEST_P(TextClassifierTest, AnnotateRegex) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test regex models (annotation only).
  auto& patterns = model->regex_model->patterns;
  patterns.push_back(MakePattern(
      "person", " (Barack Obama) ", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 1.0));
  patterns.push_back(MakePattern(
      "flight", "([a-zA-Z]{2} ?\\d{2,4})", /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 0.5));

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  const auto annotations = classifier->Annotate(test_string);
  EXPECT_THAT(annotations, ElementsAreArray({
                               IsAnnotatedSpan(6, 18, "person"),
                               IsAnnotatedSpan(19, 24, "date"),
                               IsAnnotatedSpan(28, 55, "address"),
                               IsAnnotatedSpan(79, 91, "phone"),
                           }));
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
// A phone-shaped selection classifies as "phone". Extending the selection over
// the whole comma-joined digit run classifies as "other".
TEST_P(TextClassifierTest, PhoneFiltering) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string short_number = "phone: (123) 456 789";
  const std::string long_number = "phone: (123) 456 789,0001112";
  EXPECT_EQ(FirstResult(classifier->ClassifyText(short_number, {7, 20})),
            "phone");
  EXPECT_EQ(FirstResult(classifier->ClassifyText(long_number, {7, 25})),
            "phone");
  EXPECT_EQ(FirstResult(classifier->ClassifyText(long_number, {7, 28})),
            "other");
}
// Checks SuggestSelection expansion and pass-through behavior, including the
// stripping of unpaired brackets around a selection.
TEST_P(TextClassifierTest, SuggestSelection) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(std::make_pair(15, 21),
            classifier->SuggestSelection(
                "this afternoon Barack Obama gave a speech at", {15, 21}));

  // Whole string passed: when more than one token is selected, the input span
  // comes back unchanged.
  EXPECT_EQ(std::make_pair(0, 27),
            classifier->SuggestSelection("350 Third Street, Cambridge",
                                         {0, 27}));

  // Single letter, then single word.
  EXPECT_EQ(std::make_pair(0, 1), classifier->SuggestSelection("a", {0, 1}));
  EXPECT_EQ(std::make_pair(0, 4), classifier->SuggestSelection("asdf", {0, 4}));

  EXPECT_EQ(std::make_pair(11, 23),
            classifier->SuggestSelection("call me at 857 225 3556 today",
                                         {11, 14}));

  // Unpaired bracket stripping.
  EXPECT_EQ(std::make_pair(11, 25),
            classifier->SuggestSelection("call me at (857) 225 3556 today",
                                         {11, 16}));
  EXPECT_EQ(std::make_pair(12, 15),
            classifier->SuggestSelection("call me at (857 today", {11, 15}));
  EXPECT_EQ(std::make_pair(11, 15),
            classifier->SuggestSelection("call me at 3556) today", {11, 16}));
  EXPECT_EQ(std::make_pair(12, 15),
            classifier->SuggestSelection("call me at )857( today", {11, 16}));

  // If the resulting selection would be empty, the original span is returned.
  EXPECT_EQ(std::make_pair(11, 13),
            classifier->SuggestSelection("call me at )( today", {11, 13}));
  EXPECT_EQ(std::make_pair(11, 12),
            classifier->SuggestSelection("call me at ( today", {11, 12}));
  EXPECT_EQ(std::make_pair(11, 12),
            classifier->SuggestSelection("call me at ) today", {11, 12}));
}
// Dropping the selection model while annotation is triggered must make loading
// fail.
TEST_P(TextClassifierTest, SuggestSelectionDisabledFail) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Disable the selection model but keep annotation triggering on.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_ANNOTATION;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  const std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);

  // The selection model needs to be present for annotation.
  ASSERT_FALSE(classifier);
}
// With only classification enabled and no selection model, SuggestSelection
// returns the click span unchanged and Annotate returns nothing. ClassifyText
// still works.
TEST_P(TextClassifierTest, SuggestSelectionDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Disable the selection model and restrict everything to classification.
  model->selection_model.clear();
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  model->triggering_options->enabled_modes = ModeFlag_CLASSIFICATION;
  model->enabled_modes = ModeFlag_CLASSIFICATION;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(std::make_pair(11, 14),
            classifier->SuggestSelection("call me at 857 225 3556 today",
                                         {11, 14}));

  const std::string phone_text = "call me at (800) 123-456 today";
  EXPECT_EQ(FirstResult(classifier->ClassifyText(phone_text, {11, 24})),
            "phone");
  EXPECT_THAT(classifier->Annotate(phone_text), IsEmpty());
}
// Filtering "phone" for selection must stop SuggestSelection from expanding to
// phone numbers, while address selection still expands.
TEST_P(TextClassifierTest, SuggestSelectionFilteredCollections) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(
      classifier->SuggestSelection("call me at 857 225 3556 today", {11, 14}),
      std::make_pair(11, 23));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);
  // Disable phone selection.
  unpacked_model->output_options->filtered_collections_selection.push_back(
      "phone");
  // We need to force this for filtering.
  unpacked_model->selection_options->always_classify_suggested_selection = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // FromUnownedBuffer does not own |builder|'s buffer. This classifier is
  // therefore declared after |builder|: locals are destroyed in reverse order,
  // so it is released before the buffer it references. Reassigning
  // |classifier| (declared before |builder|) would let it outlive the buffer.
  std::unique_ptr<TextClassifier> filtered_classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(filtered_classifier);
  EXPECT_EQ(filtered_classifier->SuggestSelection(
                "call me at 857 225 3556 today", {11, 14}),
            std::make_pair(11, 14));
  // Address selection should still work.
  EXPECT_EQ(filtered_classifier->SuggestSelection(
                "350 Third Street, Cambridge", {4, 9}),
            std::make_pair(0, 27));
}
// Clicking any token of the address must yield the same full-address
// selection, including when the address follows several short lines.
TEST_P(TextClassifierTest, SuggestSelectionsAreSymmetric) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string address = "350 Third Street, Cambridge";
  EXPECT_EQ(std::make_pair(0, 27),
            classifier->SuggestSelection(address, {0, 3}));
  EXPECT_EQ(std::make_pair(0, 27),
            classifier->SuggestSelection(address, {4, 9}));
  EXPECT_EQ(std::make_pair(0, 27),
            classifier->SuggestSelection(address, {10, 16}));

  EXPECT_EQ(std::make_pair(6, 33),
            classifier->SuggestSelection(
                "a\nb\nc\n350 Third Street, Cambridge", {16, 22}));
}
// Selections must not expand across a newline: the phone number is selected
// only up to the line break.
TEST_P(TextClassifierTest, SuggestSelectionWithNewLine) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ(std::make_pair(4, 16),
            classifier->SuggestSelection("abc\n857 225 3556", {4, 7}));
  EXPECT_EQ(std::make_pair(0, 12),
            classifier->SuggestSelection("857 225 3556\nabc", {0, 3}));

  // Same behavior through the overload taking explicit (default) options.
  const SelectionOptions options;
  EXPECT_EQ(std::make_pair(0, 7),
            classifier->SuggestSelection("857 225\n3556\nabc", {0, 3},
                                         options));
}
// Punctuation adjacent to the clicked word (on either side, single or
// repeated) must not be pulled into the suggested selection.
TEST_P(TextClassifierTest, SuggestSelectionWithPunctuation) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // Punctuation on the right.
  EXPECT_EQ(std::make_pair(15, 26),
            classifier->SuggestSelection(
                "this afternoon BarackObama, gave a speech at", {15, 26}));
  // Several punctuation characters on the right.
  EXPECT_EQ(std::make_pair(15, 26),
            classifier->SuggestSelection(
                "this afternoon BarackObama,.,.,, gave a speech at", {15, 26}));
  // Several punctuation characters on the left.
  EXPECT_EQ(std::make_pair(21, 32),
            classifier->SuggestSelection(
                "this afternoon ,.,.,,BarackObama gave a speech at", {21, 32}));
  // Punctuation on both sides.
  EXPECT_EQ(std::make_pair(16, 27),
            classifier->SuggestSelection(
                "this afternoon !BarackObama,- gave a speech at", {16, 27}));
}
// Out-of-range, inverted, or negative spans, and invalid UTF-8, must not crash.
// The input span is returned unchanged.
TEST_P(TextClassifierTest, SuggestSelectionNoCrashWithJunk) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  // Invalid selections on an empty string.
  EXPECT_EQ(std::make_pair(0, 27), classifier->SuggestSelection("", {0, 27}));
  EXPECT_EQ(std::make_pair(-10, 27),
            classifier->SuggestSelection("", {-10, 27}));

  // Invalid selections on a short sentence.
  const std::string sentence = "Word 1 2 3 hello!";
  EXPECT_EQ(std::make_pair(0, 27),
            classifier->SuggestSelection(sentence, {0, 27}));
  EXPECT_EQ(std::make_pair(-30, 300),
            classifier->SuggestSelection(sentence, {-30, 300}));
  EXPECT_EQ(std::make_pair(-10, -1),
            classifier->SuggestSelection(sentence, {-10, -1}));
  EXPECT_EQ(std::make_pair(100, 17),
            classifier->SuggestSelection(sentence, {100, 17}));

  // Invalid UTF-8.
  EXPECT_EQ(std::make_pair(-1, -1),
            classifier->SuggestSelection("\xf0\x9f\x98\x8b\x8b", {-1, -1}));
}
// Clicking on whitespace must expand only when that whitespace sits inside an
// entity. Otherwise (or when the text is all whitespace) the original span is
// returned.
TEST_P(TextClassifierTest, SuggestSelectionSelectSpace) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string phone_text = "call me at 857 225 3556 today";
  const std::string phone_comma_text = "call me at 857 225 3556, today";

  EXPECT_EQ(std::make_pair(11, 23),
            classifier->SuggestSelection(phone_text, {14, 15}));
  EXPECT_EQ(std::make_pair(10, 11),
            classifier->SuggestSelection(phone_text, {10, 11}));
  EXPECT_EQ(std::make_pair(23, 24),
            classifier->SuggestSelection(phone_text, {23, 24}));
  EXPECT_EQ(std::make_pair(23, 24),
            classifier->SuggestSelection(phone_comma_text, {23, 24}));
  EXPECT_EQ(std::make_pair(11, 25),
            classifier->SuggestSelection(phone_comma_text, {14, 17}));
  EXPECT_EQ(std::make_pair(11, 23),
            classifier->SuggestSelection("call me at 857-225 3556, today",
                                         {14, 17}));
  EXPECT_EQ(std::make_pair(14, 40),
            classifier->SuggestSelection(
                "let's meet at 350 Third Street Cambridge and go there",
                {30, 31}));

  const std::string plain_text = "call me today";
  EXPECT_EQ(std::make_pair(4, 5),
            classifier->SuggestSelection(plain_text, {4, 5}));
  EXPECT_EQ(std::make_pair(7, 8),
            classifier->SuggestSelection(plain_text, {7, 8}));

  // With punctuation around the selected whitespace.
  EXPECT_EQ(std::make_pair(14, 41),
            classifier->SuggestSelection(
                "let's meet at 350 Third Street, Cambridge and go there",
                {31, 32}));

  // When the text is all whitespace, the original indices are returned.
  EXPECT_EQ(std::make_pair(0, 1), classifier->SuggestSelection(" ", {0, 1}));
  EXPECT_EQ(std::make_pair(0, 3), classifier->SuggestSelection(" ", {0, 3}));
  EXPECT_EQ(std::make_pair(2, 3), classifier->SuggestSelection(" ", {2, 3}));
  EXPECT_EQ(std::make_pair(5, 6), classifier->SuggestSelection(" ", {5, 6}));
}
// internal::SnapLeftIfWhitespaceSelection moves a single-whitespace selection
// one codepoint to the left when there is something to snap to. Otherwise the
// span is returned unchanged. The texts are string literals, so do_copy=false
// is safe.
TEST(TextClassifierTest, SnapLeftIfWhitespaceSelection) {
  CREATE_UNILIB_FOR_TESTING;
  UnicodeText text;

  // Whitespace preceded by a word snaps onto the word's last character.
  text = UTF8ToUnicodeText("abcd efgh", /*do_copy=*/false);
  EXPECT_EQ(std::make_pair(3, 4),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib));
  text = UTF8ToUnicodeText("abcd ", /*do_copy=*/false);
  EXPECT_EQ(std::make_pair(3, 4),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib));

  // Nothing on the left.
  text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
  EXPECT_EQ(std::make_pair(4, 5),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib));
  text = UTF8ToUnicodeText(" efgh", /*do_copy=*/false);
  EXPECT_EQ(std::make_pair(0, 1),
            internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib));

  // Whitespace only.
  text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
  EXPECT_EQ(std::make_pair(2, 3),
            internal::SnapLeftIfWhitespaceSelection({2, 3}, text, unilib));
  text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
  EXPECT_EQ(std::make_pair(4, 5),
            internal::SnapLeftIfWhitespaceSelection({4, 5}, text, unilib));
  text = UTF8ToUnicodeText(" ", /*do_copy=*/false);
  EXPECT_EQ(std::make_pair(0, 1),
            internal::SnapLeftIfWhitespaceSelection({0, 1}, text, unilib));
}
// Annotates a mixed sentence with the stock model and checks the reported
// spans. Also checks that a newline inside a phone number, or invalid UTF-8
// input, produces no annotations.
TEST_P(TextClassifierTest, Annotate) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);
  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // Preprocessor directives inside macro arguments are undefined behavior and
  // not portable, so the ICU-dependent expectation is selected outside
  // EXPECT_THAT. The "date" annotation is only produced in ICU builds.
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
#else
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
#endif

  AnnotationOptions options;
  EXPECT_THAT(classifier->Annotate("853 225 3556", options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
  // Try passing invalid utf8.
  EXPECT_TRUE(
      classifier->Annotate("853 225 3556\n\xf0\x9f\x98\x8b\x8b", options)
          .empty());
}
// With a tiny selection batch size (4), annotation must produce the same spans
// as with the stock model.
TEST_P(TextClassifierTest, AnnotateSmallBatches) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Set the batch size.
  unpacked_model->selection_options->batch_size = 4;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);
  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // Preprocessor directives inside macro arguments are undefined behavior and
  // not portable, so the ICU-dependent expectation is selected outside
  // EXPECT_THAT. The "date" annotation is only produced in ICU builds.
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
#else
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
#endif

  AnnotationOptions options;
  EXPECT_THAT(classifier->Annotate("853 225 3556", options),
              ElementsAreArray({IsAnnotatedSpan(0, 12, "phone")}));
  EXPECT_TRUE(classifier->Annotate("853 225\n3556", options).empty());
}
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Sets min_annotate_confidence to 2 and checks how many annotations survive.
TEST_P(TextClassifierTest, AnnotateFilteringDiscardAll) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  // Test threshold intended to discard every result it applies to.
  model->triggering_options->min_annotate_confidence = 2.f;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // NOTE(review): one annotation still survives despite the threshold. This is
  // presumably the ICU-only "date" result (this test builds only under
  // LIBTEXTCLASSIFIER_UNILIB_ICU) not being subject to min_annotate_confidence.
  // Confirm.
  EXPECT_EQ(classifier->Annotate(test_string).size(), 1);
}
#endif
// With a zero confidence threshold and all modes enabled, every annotation the
// model produces is kept.
TEST_P(TextClassifierTest, AnnotateFilteringKeepAll) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());

  // Add test thresholds: 0 keeps all results.
  model->triggering_options.reset(new ModelTriggeringOptionsT);
  ModelTriggeringOptionsT& triggering = *model->triggering_options;
  triggering.min_annotate_confidence = 0.f;
  triggering.enabled_modes = ModeFlag_ALL;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
  const int kExpectedAnnotations = 3;
#else
  // In non-ICU mode there is no "date" result.
  const int kExpectedAnnotations = 2;
#endif
  EXPECT_EQ(classifier->Annotate(test_string).size(), kExpectedAnnotations);
}
// When the model's enabled_modes exclude annotation, Annotate returns nothing.
TEST_P(TextClassifierTest, AnnotateDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string serialized_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> model = UnPackModel(serialized_model.c_str());
  // Disable the model for annotation.
  model->enabled_modes = ModeFlag_CLASSIFICATION_AND_SELECTION;

  flatbuffers::FlatBufferBuilder fbb;
  fbb.Finish(Model::Pack(fbb, model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(fbb.GetBufferPointer()), fbb.GetSize(),
          &unilib);
  ASSERT_TRUE(classifier);

  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  const auto annotations = classifier->Annotate(test_string);
  EXPECT_THAT(annotations, IsEmpty());
}
// Filtering "phone" for annotation must drop the phone span from Annotate's
// output while keeping the other annotations.
TEST_P(TextClassifierTest, AnnotateFilteredCollections) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);
  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // Preprocessor directives inside macro arguments are undefined behavior and
  // not portable, so the ICU-dependent expectations are selected outside
  // EXPECT_THAT. The "date" annotation is only produced in ICU builds.
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
#else
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));
#endif

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);
  // Disable phone annotation.
  unpacked_model->output_options->filtered_collections_annotation.push_back(
      "phone");
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // FromUnownedBuffer does not own |builder|'s buffer. This classifier is
  // therefore declared after |builder|: locals are destroyed in reverse order,
  // so it is released before the buffer it references. Reassigning
  // |classifier| (declared before |builder|) would let it outlive the buffer.
  std::unique_ptr<TextClassifier> filtered_classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(filtered_classifier);
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
  EXPECT_THAT(filtered_classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
              }));
#else
  EXPECT_THAT(filtered_classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(28, 55, "address"),
              }));
#endif
}
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// A filtered custom regex collection ("suppress") that outscores the phone
// annotation must remove the phone span from the output without showing up
// itself.
TEST_P(TextClassifierTest, AnnotateFilteredCollectionsSuppress) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(test_model.c_str(), test_model.size(),
                                        &unilib);
  ASSERT_TRUE(classifier);
  const std::string test_string =
      "& saw Barack Obama today .. 350 Third Street, Cambridge\nand my phone "
      "number is 853 225 3556";
  // The whole test is already ICU-only, so the "date" expectation needs no
  // extra preprocessor guard (which would also sit inside macro arguments).
  EXPECT_THAT(classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
                  IsAnnotatedSpan(79, 91, "phone"),
              }));

  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  unpacked_model->output_options.reset(new OutputOptionsT);
  // We add a custom annotator that wins against the phone classification
  // below and that we subsequently suppress.
  unpacked_model->output_options->filtered_collections_annotation.push_back(
      "suppress");
  unpacked_model->regex_model->patterns.push_back(MakePattern(
      "suppress", "(\\d{3} ?\\d{4})",
      /*enabled_for_classification=*/false,
      /*enabled_for_selection=*/false, /*enabled_for_annotation=*/true, 2.0));
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));

  // FromUnownedBuffer does not own |builder|'s buffer. This classifier is
  // therefore declared after |builder|: locals are destroyed in reverse order,
  // so it is released before the buffer it references. Reassigning
  // |classifier| (declared before |builder|) would let it outlive the buffer.
  std::unique_ptr<TextClassifier> filtered_classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(filtered_classifier);
  EXPECT_THAT(filtered_classifier->Annotate(test_string),
              ElementsAreArray({
                  IsAnnotatedSpan(19, 24, "date"),
                  IsAnnotatedSpan(28, 55, "address"),
              }));
}
#endif
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// Classifies several date/datetime strings and checks the parsed UTC timestamp
// and granularity. Both depend on options.reference_timezone.
TEST_P(TextClassifierTest, ClassifyTextDate) {
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam());
  // ASSERT rather than EXPECT: every check below dereferences `classifier`.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;

  // 2017-01-01 00:00 Europe/Zurich (UTC+1) == 2016-12-31T23:00:00Z.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText("january 1, 2017", {0, 15}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // 2017-03-01 00:00 America/Los_Angeles (UTC-8) == 2017-03-01T08:00:00Z.
  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("march 1, 2017", {0, 13}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1488355200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // 2018-01-01 10:30:20 America/Los_Angeles (UTC-8) == 2018-01-01T18:30:20Z.
  options.reference_timezone = "America/Los_Angeles";
  result = classifier->ClassifyText("2018/01/01 10:30:20", {0, 19}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1514831420000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_SECOND);

  // Date on another line: the first line plus '\n' is 35 characters long.
  options.reference_timezone = "Europe/Zurich";
  result = classifier->ClassifyText(
      "hello world this is the first line\n"
      "january 1, 2017",
      {35, 50}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 1483225200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// The ambiguous date "03.05.1970" must be interpreted according to
// options.locales. It is read month-first for "en-US" (March 5) and day-first
// for "de" (May 3).
TEST_P(TextClassifierTest, ClassifyTextDatePriorities) {
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam());
  // ASSERT rather than EXPECT: every check below dereferences `classifier`.
  ASSERT_TRUE(classifier);

  std::vector<ClassificationResult> result;
  ClassificationOptions options;
  options.reference_timezone = "Europe/Zurich";

  // en-US: 1970-03-05 00:00 Europe/Zurich (UTC+1) == 1970-03-04T23:00:00Z.
  options.locales = "en-US";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 5439600000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);

  // de: 1970-05-03 00:00 Europe/Zurich (UTC+1) == 1970-05-02T23:00:00Z.
  options.locales = "de";
  result = classifier->ClassifyText("03.05.1970", {0, 10}, options);
  ASSERT_EQ(result.size(), 1u);
  EXPECT_EQ(result[0].collection, "date");
  EXPECT_EQ(result[0].datetime_parse_result.time_ms_utc, 10537200000);
  EXPECT_EQ(result[0].datetime_parse_result.granularity,
            DatetimeGranularity::GRANULARITY_DAY);
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
#ifdef LIBTEXTCLASSIFIER_CALENDAR_ICU
// With every datetime pattern restricted to annotation and classification,
// dates are still classified and annotated. SuggestSelection, however, is
// expected to leave a partial date selection unchanged.
TEST_P(TextClassifierTest, SuggestTextDateDisabled) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  // Disable the patterns for selection (keep annotation and classification).
  // Range-for avoids the signed/unsigned index comparison against size().
  for (auto& pattern : unpacked_model->datetime_model->patterns) {
    pattern->enabled_modes = ModeFlag_ANNOTATION_AND_CLASSIFICATION;
  }
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(Model::Pack(builder, unpacked_model.get()));
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromUnownedBuffer(
          reinterpret_cast<const char*>(builder.GetBufferPointer()),
          builder.GetSize(), &unilib);
  ASSERT_TRUE(classifier);

  EXPECT_EQ("date",
            FirstResult(classifier->ClassifyText("january 1, 2017", {0, 15})));
  EXPECT_EQ(classifier->SuggestSelection("january 1, 2017", {0, 7}),
            std::make_pair(0, 7));
  EXPECT_THAT(classifier->Annotate("january 1, 2017"),
              ElementsAreArray({IsAnnotatedSpan(0, 15, "date")}));
}
#endif  // LIBTEXTCLASSIFIER_CALENDAR_ICU
// Test-only subclass that re-exports TextClassifier::ResolveConflicts as a
// public member, so conflict resolution can be exercised directly.
class TestingTextClassifier : public TextClassifier {
 public:
  using TextClassifier::ResolveConflicts;

  // Builds the classifier from ViewModel over the bytes of `model_buffer`.
  // NOTE(review): ViewModel presumably only views the buffer (no copy), so the
  // buffer should outlive this object. The tests below pass "", whose
  // temporary string dies after construction; confirm that an empty model is
  // never dereferenced.
  TestingTextClassifier(const std::string& model_buffer, const UniLib* unilib)
      : TextClassifier(ViewModel(model_buffer.data(), model_buffer.size()),
                       unilib) {}
};
// Builds an AnnotatedSpan over `span` that carries exactly one classification
// result with the given `collection` and `score`.
AnnotatedSpan MakeAnnotatedSpan(CodepointSpan span,
                                const std::string& collection,
                                const float score) {
  ClassificationResult only_result = {collection, score};
  AnnotatedSpan annotated;
  annotated.span = span;
  annotated.classification.push_back(only_result);
  return annotated;
}
// A lone candidate has nothing to conflict with and is expected to be chosen.
TEST(TextClassifierTest, ResolveConflictsTrivial) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates;
  candidates.push_back(MakeAnnotatedSpan({0, 1}, "phone", 1.0));

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0}));
}
// Five adjacent, non-overlapping spans ([0,1), [1,2), ... [4,5)) do not
// conflict, so all of them are expected to be chosen.
TEST(TextClassifierTest, ResolveConflictsSequence) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates;
  for (int begin = 0; begin < 5; ++begin) {
    candidates.push_back(MakeAnnotatedSpan({begin, begin + 1}, "phone", 1.0));
  }

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0, 1, 2, 3, 4}));
}
// The low-scoring middle span overlaps both neighbours. The two outer spans
// (which only touch at offset 3) are expected to win.
TEST(TextClassifierTest, ResolveConflictsThreeSpans) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 1.0),
      MakeAnnotatedSpan({1, 5}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 1.0),
  };

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0, 2}));
}
// Mirror of ResolveConflictsThreeSpans: the high-scoring middle span beats
// both overlapping neighbours.
TEST(TextClassifierTest, ResolveConflictsThreeSpansReversed) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),  // Loser!
      MakeAnnotatedSpan({1, 5}, "phone", 1.0),
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),  // Loser!
  };

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({1}));
}
// Two independent conflict clusters. In the first, the "other" span loses
// despite its top score. NOTE(review): presumably "other" ranks below concrete
// collections; confirm against ResolveConflicts. In the second, the
// higher-scoring [11,15) beats [8,12).
TEST(TextClassifierTest, ResolveConflictsFiveSpans) {
  CREATE_UNILIB_FOR_TESTING;
  TestingTextClassifier classifier("", &unilib);

  std::vector<AnnotatedSpan> candidates = {
      MakeAnnotatedSpan({0, 3}, "phone", 0.5),
      MakeAnnotatedSpan({1, 5}, "other", 1.0),  // Loser!
      MakeAnnotatedSpan({3, 7}, "phone", 0.6),
      MakeAnnotatedSpan({8, 12}, "phone", 0.6),  // Loser!
      MakeAnnotatedSpan({11, 15}, "phone", 0.9),
  };

  std::vector<int> chosen;
  classifier.ResolveConflicts(candidates, /*context=*/"", /*cached_tokens=*/{},
                              /*interpreter_manager=*/nullptr, &chosen);
  EXPECT_THAT(chosen, ElementsAreArray({0, 2, 4}));
}
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Places one example of each collection between two 50000-space paddings. It
// then checks that Annotate, SuggestSelection and ClassifyText all locate
// exactly that entity.
TEST_P(TextClassifierTest, LongInput) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string padding(50000, ' ');
  const std::vector<std::pair<std::string, std::string>> collection_and_text = {
      {"address", "350 Third Street, Cambridge"},
      {"phone", "123 456-7890"},
      {"url", "www.google.com"},
      {"email", "someone@gmail.com"},
      {"flight", "LX 38"},
      {"date", "September 1, 2018"}};

  for (const auto& entry : collection_and_text) {
    const std::string& collection = entry.first;
    const std::string& text = entry.second;
    const std::string input_100k = padding + text + padding;
    const int begin = 50000;
    const int end = begin + static_cast<int>(text.size());

    EXPECT_THAT(classifier->Annotate(input_100k),
                ElementsAreArray({IsAnnotatedSpan(begin, end, collection)}));
    EXPECT_EQ(classifier->SuggestSelection(input_100k, {begin, begin + 1}),
              std::make_pair(begin, end));
    EXPECT_EQ(collection,
              FirstResult(classifier->ClassifyText(input_100k, {begin, end})));
  }
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Coarse smoke test. Results are deliberately not checked; the test only makes
// sure that processing a ~100k-character input finishes in a reasonable amount
// of time.
TEST_P(TextClassifierTest, LongInputNoResultCheck) {
  CREATE_UNILIB_FOR_TESTING;
  std::unique_ptr<TextClassifier> classifier =
      TextClassifier::FromPath(GetModelPath() + GetParam(), &unilib);
  ASSERT_TRUE(classifier);

  const std::string padding(50000, ' ');
  const std::vector<std::string> payloads = {
      "http://www.aaaaaaaaaaaaaaaaaaaa.com "};
  for (const std::string& payload : payloads) {
    const std::string input_100k = padding + payload + padding;
    const int payload_begin = 50000;
    const int payload_end = payload_begin + static_cast<int>(payload.size());
    classifier->Annotate(input_100k);
    classifier->SuggestSelection(input_100k, {payload_begin, payload_begin + 1});
    classifier->ClassifyText(input_100k, {payload_begin, payload_end});
  }
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Checks classification_options.max_num_tokens. With no limit (-1) the span is
// classified as an address. With a limit of 3, the same span is expected to
// fall back to "other".
TEST_P(TextClassifierTest, MaxTokenLength) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  std::unique_ptr<TextClassifier> classifier;

  // Packs the current state of `unpacked_model` into `builder` and creates a
  // classifier over the packed bytes. FromUnownedBuffer presumably only
  // references those bytes (as its name says), so `builder` must stay alive
  // while the classifier is in use.
  auto pack_and_load = [&](flatbuffers::FlatBufferBuilder* builder)
      -> std::unique_ptr<TextClassifier> {
    builder->Finish(Model::Pack(*builder, unpacked_model.get()));
    return TextClassifier::FromUnownedBuffer(
        reinterpret_cast<const char*>(builder->GetBufferPointer()),
        builder->GetSize(), &unilib);
  };

  // With an unrestricted number of tokens (-1) it should behave normally.
  unpacked_model->classification_options->max_num_tokens = -1;
  flatbuffers::FlatBufferBuilder builder;
  classifier = pack_and_load(&builder);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "address");

  // Lower the maximum number of tokens (from unlimited to 3) to suppress the
  // classification of this span.
  unpacked_model->classification_options->max_num_tokens = 3;
  flatbuffers::FlatBufferBuilder builder2;
  classifier = pack_and_load(&builder2);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(FirstResult(classifier->ClassifyText(
                "I live at 350 Third Street, Cambridge.", {10, 37})),
            "other");
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
#ifdef LIBTEXTCLASSIFIER_UNILIB_ICU
// Checks classification_options.address_min_num_tokens. With no minimum (0) the
// span is classified as an address. Raising the minimum to 5 makes the same
// span fall back to "other".
TEST_P(TextClassifierTest, MinAddressTokenLength) {
  CREATE_UNILIB_FOR_TESTING;
  const std::string test_model = ReadFile(GetModelPath() + GetParam());
  std::unique_ptr<ModelT> unpacked_model = UnPackModel(test_model.c_str());
  std::unique_ptr<TextClassifier> classifier;

  // Serializes the current `unpacked_model` into `fbb` and loads a classifier
  // from those bytes. Keep `fbb` alive while the classifier is used (the
  // buffer is passed as "unowned").
  auto load_from = [&](flatbuffers::FlatBufferBuilder* fbb)
      -> std::unique_ptr<TextClassifier> {
    fbb->Finish(Model::Pack(*fbb, unpacked_model.get()));
    return TextClassifier::FromUnownedBuffer(
        reinterpret_cast<const char*>(fbb->GetBufferPointer()),
        fbb->GetSize(), &unilib);
  };

  const std::string text = "I live at 350 Third Street, Cambridge.";

  // No minimum number of address tokens: normal behavior.
  unpacked_model->classification_options->address_min_num_tokens = 0;
  flatbuffers::FlatBufferBuilder builder;
  classifier = load_from(&builder);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(FirstResult(classifier->ClassifyText(text, {10, 37})), "address");

  // A minimum of 5 address tokens suppresses the address classification.
  unpacked_model->classification_options->address_min_num_tokens = 5;
  flatbuffers::FlatBufferBuilder builder2;
  classifier = load_from(&builder2);
  ASSERT_TRUE(classifier);
  EXPECT_EQ(FirstResult(classifier->ClassifyText(text, {10, 37})), "other");
}
#endif  // LIBTEXTCLASSIFIER_UNILIB_ICU
} // namespace
} // namespace libtextclassifier2