blob: 78548feab8a76041166a9ab8a6c845e5b97e6723 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/duration/duration.h"
#include <string>
#include <vector>
#include "annotator/collections.h"
#include "annotator/model_generated.h"
#include "annotator/types-test-util.h"
#include "annotator/types.h"
#include "utils/test-utils.h"
#include "utils/utf8/unicodetext.h"
#include "utils/utf8/unilib.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
namespace {
using testing::AllOf;
using testing::ElementsAre;
using testing::Field;
// Returns the DurationAnnotatorOptions used by the DurationAnnotatorTest
// fixture.
//
// The options are serialized into a flatbuffer on the first call only (the
// buffer lives in a function-local static) and the buffer is never deleted, so
// the returned pointer remains valid for the rest of the test binary's run.
const DurationAnnotatorOptions* TestingDurationAnnotatorOptions() {
  static const flatbuffers::DetachedBuffer* const kSerializedOptions = [] {
    DurationAnnotatorOptionsT unpacked;
    unpacked.enabled = true;

    // Granularity words, each accepted in singular and plural form.
    unpacked.week_expressions = {"week", "weeks"};
    unpacked.day_expressions = {"day", "days"};
    unpacked.hour_expressions = {"hour", "hours"};
    unpacked.minute_expressions = {"minute", "minutes"};
    unpacked.second_expressions = {"second", "seconds"};

    // Filler words and the "half" modifier.
    unpacked.filler_expressions = {"and", "a", "an", "one"};
    unpacked.half_expressions = {"half"};

    flatbuffers::FlatBufferBuilder fbb;
    fbb.Finish(DurationAnnotatorOptions::Pack(fbb, &unpacked));
    return new flatbuffers::DetachedBuffer(fbb.Release());
  }();
  return flatbuffers::GetRoot<DurationAnnotatorOptions>(
      kSerializedOptions->data());
}
// Builds the FeatureProcessor used by the DurationAnnotatorTest fixture.
//
// The configuration contains one tokenization codepoint range
// (start = 32, i.e. ' ', end = 33) with role WHITESPACE_SEPARATOR, and it
// registers ',' as an ignored span-boundary codepoint. The serialized options
// are created on the first call, cached in a function-local static and never
// freed, so the options pointer handed to the FeatureProcessor stays valid for
// the whole test run.
FeatureProcessor BuildFeatureProcessor(const UniLib* unilib) {
  static const flatbuffers::DetachedBuffer* const kSerializedOptions = [] {
    FeatureProcessorOptionsT unpacked;
    unpacked.context_size = 1;
    unpacked.max_selection_span = 1;
    unpacked.snap_label_span_boundaries_to_containing_tokens = false;
    unpacked.ignored_span_boundary_codepoints.push_back(',');

    unpacked.tokenization_codepoint_config.emplace_back(
        new TokenizationCodepointRangeT());
    auto& space_range = unpacked.tokenization_codepoint_config.back();
    space_range->start = 32;
    space_range->end = 33;
    space_range->role =
        TokenizationCodepointRange_::Role_WHITESPACE_SEPARATOR;

    flatbuffers::FlatBufferBuilder fbb;
    fbb.Finish(FeatureProcessorOptions::Pack(fbb, &unpacked));
    return new flatbuffers::DetachedBuffer(fbb.Release());
  }();
  return FeatureProcessor(
      flatbuffers::GetRoot<FeatureProcessorOptions>(kSerializedOptions->data()),
      unilib);
}
// Test fixture that connects a DurationAnnotator (configured with
// TestingDurationAnnotatorOptions()) to the FeatureProcessor returned by
// BuildFeatureProcessor().
class DurationAnnotatorTest : public ::testing::Test {
protected:
// Members are initialized in declaration order: unilib_, then
// feature_processor_ (given &unilib_), then duration_annotator_ (given
// &feature_processor_). Keep the member declarations below in this order so
// that each constructor receives a pointer to an already-constructed member.
DurationAnnotatorTest()
: INIT_UNILIB_FOR_TESTING(unilib_),
feature_processor_(BuildFeatureProcessor(&unilib_)),
duration_annotator_(TestingDurationAnnotatorOptions(),
&feature_processor_) {}
// Tokenizes `text` with the fixture's FeatureProcessor so tests can hand the
// same tokens to DurationAnnotator::FindAll.
std::vector<Token> Tokenize(const UnicodeText& text) {
return feature_processor_.Tokenize(text);
}
// NOTE(review): initialized via the INIT_UNILIB_FOR_TESTING macro above;
// presumably it expands to a member initializer for this field - confirm in
// utils/test-utils.h.
UniLib unilib_;
FeatureProcessor feature_processor_;
DurationAnnotator duration_annotator_;
};
TEST_F(DurationAnnotatorTest, ClassifiesSimpleDuration) {
  // The selection [14, 24) is exactly "15 minutes".
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
  constexpr int kExpectedDurationMs = 15 * 60 * 1000;

  ClassificationResult result;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      input, {14, 24}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
  EXPECT_THAT(result,
              AllOf(Field(&ClassificationResult::collection, "duration"),
                    Field(&ClassificationResult::duration_ms,
                          kExpectedDurationMs)));
}
TEST_F(DurationAnnotatorTest, ClassifiesWhenTokensDontAlignWithSelection) {
  // The selection [13, 23) is "15 minutes", but there is no space between
  // "in" and "15" or between "minutes" and "ok?". With the whitespace-only
  // tokenization configured for this fixture, the selection edges presumably
  // fall inside tokens.
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in15 minutesok?");
  constexpr int kExpectedDurationMs = 15 * 60 * 1000;

  ClassificationResult result;
  EXPECT_TRUE(duration_annotator_.ClassifyText(
      input, {13, 23}, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
  EXPECT_THAT(result,
              AllOf(Field(&ClassificationResult::collection, "duration"),
                    Field(&ClassificationResult::duration_ms,
                          kExpectedDurationMs)));
}
TEST_F(DurationAnnotatorTest, FindsSimpleDuration) {
  const UnicodeText input = UTF8ToUnicodeText("Wake me up in 15 minutes ok?");
  std::vector<Token> tokens = Tokenize(input);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // Exactly one span is expected: "15 minutes" at [14, 24).
  const auto is_fifteen_minutes =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms, 15 * 60 * 1000));
  EXPECT_THAT(spans, ElementsAre(AllOf(
                         Field(&AnnotatedSpan::span, CodepointSpan(14, 24)),
                         Field(&AnnotatedSpan::classification,
                               ElementsAre(is_fifteen_minutes)))));
}
TEST_F(DurationAnnotatorTest, FindsDurationWithHalfExpression) {
  const UnicodeText text =
      UTF8ToUnicodeText("Set a timer for 3 and half minutes ok?");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
  // "3 and half minutes" occupies [16, 34) and means 3.5 minutes. The expected
  // value is written in integer milliseconds (3 min + 30 s) instead of
  // `3.5 * 60 * 1000`. That expression is a double, which would make the
  // matcher compare duration_ms using floating-point ==.
  EXPECT_THAT(
      result,
      ElementsAre(
          AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 34)),
                Field(&AnnotatedSpan::classification,
                      ElementsAre(AllOf(
                          Field(&ClassificationResult::collection, "duration"),
                          Field(&ClassificationResult::duration_ms,
                                3 * 60 * 1000 + 30 * 1000)))))));
}
TEST_F(DurationAnnotatorTest, FindsComposedDuration) {
  const UnicodeText input =
      UTF8ToUnicodeText("Wake me up in 3 hours and 5 seconds ok?");
  std::vector<Token> tokens = Tokenize(input);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "3 hours and 5 seconds" ([14, 35)) is reported as one combined duration.
  const auto is_three_hours_five_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  3 * 60 * 60 * 1000 + 5 * 1000));
  EXPECT_THAT(spans, ElementsAre(AllOf(
                         Field(&AnnotatedSpan::span, CodepointSpan(14, 35)),
                         Field(&AnnotatedSpan::classification,
                               ElementsAre(is_three_hours_five_seconds)))));
}
TEST_F(DurationAnnotatorTest, FindsHalfAnHour) {
  const UnicodeText text = UTF8ToUnicodeText("Set a timer for half an hour");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
  // "half an hour" occupies [16, 28). The expected 30 minutes is written in
  // integer milliseconds rather than `0.5 * 60 * 60 * 1000`. That expression
  // is a double, which would make the matcher compare duration_ms using
  // floating-point ==.
  EXPECT_THAT(
      result,
      ElementsAre(
          AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 28)),
                Field(&AnnotatedSpan::classification,
                      ElementsAre(AllOf(
                          Field(&ClassificationResult::collection, "duration"),
                          Field(&ClassificationResult::duration_ms,
                                30 * 60 * 1000)))))));
}
TEST_F(DurationAnnotatorTest, FindsWhenHalfIsAfterGranularitySpecification) {
  const UnicodeText text =
      UTF8ToUnicodeText("Set a timer for 1 hour and a half");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
  // "1 hour and a half" occupies [16, 33). The expected 1.5 hours is written
  // in integer milliseconds (1 h + 30 min) rather than
  // `1.5 * 60 * 60 * 1000`, so duration_ms is not compared against a double.
  EXPECT_THAT(
      result,
      ElementsAre(
          AllOf(Field(&AnnotatedSpan::span, CodepointSpan(16, 33)),
                Field(&AnnotatedSpan::classification,
                      ElementsAre(AllOf(
                          Field(&ClassificationResult::collection, "duration"),
                          Field(&ClassificationResult::duration_ms,
                                1 * 60 * 60 * 1000 + 30 * 60 * 1000)))))));
}
TEST_F(DurationAnnotatorTest, FindsAnHourAndAHalf) {
  const UnicodeText text =
      UTF8ToUnicodeText("Set a timer for an hour and a half");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
  // The leading "an" (codepoints 16-17) is not part of the expected span,
  // which starts at "hour" (19) and ends after "half" (34). The expected
  // 1.5 hours is written in integer milliseconds (1 h + 30 min) rather than
  // `1.5 * 60 * 60 * 1000`, so duration_ms is not compared against a double.
  EXPECT_THAT(
      result,
      ElementsAre(
          AllOf(Field(&AnnotatedSpan::span, CodepointSpan(19, 34)),
                Field(&AnnotatedSpan::classification,
                      ElementsAre(AllOf(
                          Field(&ClassificationResult::collection, "duration"),
                          Field(&ClassificationResult::duration_ms,
                                1 * 60 * 60 * 1000 + 30 * 60 * 1000)))))));
}
TEST_F(DurationAnnotatorTest,
       FindsCorrectlyWhenSecondsComeSecondAndDontHaveNumber) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 minutes and a second ok?");
  std::vector<Token> tokens = Tokenize(input);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // "10 minutes and a second" ([16, 39)) is expected to total 10 min + 1 s.
  // The "a" in "a second" counts as one.
  const auto is_ten_minutes_one_second =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  10 * 60 * 1000 + 1 * 1000));
  EXPECT_THAT(spans, ElementsAre(AllOf(
                         Field(&AnnotatedSpan::span, CodepointSpan(16, 39)),
                         Field(&AnnotatedSpan::classification,
                               ElementsAre(is_ten_minutes_one_second)))));
}
TEST_F(DurationAnnotatorTest, DoesNotGreedilyTakeFillerWords) {
  const UnicodeText input = UTF8ToUnicodeText(
      "Set a timer for a a a 10 minutes and 2 seconds an and an ok?");
  std::vector<Token> tokens = Tokenize(input);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // The surrounding fillers ("a a a" before, "an and an" after) must stay
  // outside the span. The span covers only "10 minutes and 2 seconds",
  // i.e. [22, 46).
  const auto is_ten_minutes_two_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  10 * 60 * 1000 + 2 * 1000));
  EXPECT_THAT(spans, ElementsAre(AllOf(
                         Field(&AnnotatedSpan::span, CodepointSpan(22, 46)),
                         Field(&AnnotatedSpan::classification,
                               ElementsAre(is_ten_minutes_two_seconds)))));
}
TEST_F(DurationAnnotatorTest, DoesNotCrashWhenJustHalfIsSaid) {
  const UnicodeText text = UTF8ToUnicodeText("Set a timer for half ok?");
  std::vector<Token> tokens = Tokenize(text);
  std::vector<AnnotatedSpan> result;
  EXPECT_TRUE(duration_annotator_.FindAll(
      text, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &result));
  // A bare "half" with no granularity word must not yield an annotation.
  // IsEmpty() avoids comparing size() (unsigned) with the int literal 0,
  // which triggers -Wsign-compare inside gtest. It also prints any unexpected
  // spans when the check fails.
  EXPECT_THAT(result, testing::IsEmpty());
}
TEST_F(DurationAnnotatorTest, StripsPunctuationFromTokens) {
  const UnicodeText input =
      UTF8ToUnicodeText("Set a timer for 10 ,minutes, ,and, ,2, seconds, ok?");
  std::vector<Token> tokens = Tokenize(input);

  std::vector<AnnotatedSpan> spans;
  EXPECT_TRUE(duration_annotator_.FindAll(
      input, tokens, AnnotationUsecase_ANNOTATION_USECASE_RAW, &spans));

  // Commas attached to the tokens do not break the match. The span starts at
  // "10" (16) and ends right after "seconds" (exclusive end 46), so the comma
  // that follows "seconds" is left out.
  const auto is_ten_minutes_two_seconds =
      AllOf(Field(&ClassificationResult::collection, "duration"),
            Field(&ClassificationResult::duration_ms,
                  10 * 60 * 1000 + 2 * 1000));
  EXPECT_THAT(spans, ElementsAre(AllOf(
                         Field(&AnnotatedSpan::span, CodepointSpan(16, 46)),
                         Field(&AnnotatedSpan::classification,
                               ElementsAre(is_ten_minutes_two_seconds)))));
}
} // namespace
} // namespace libtextclassifier3