blob: 5c4a63f72019aa0a2b5b3fcdfba89f1204123e22 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/translate/translate.h"
#include <memory>
#include "annotator/model_generated.h"
#include "utils/test-data-test-utils.h"
#include "lang_id/fb_model/lang-id-from-fb.h"
#include "lang_id/lang-id.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
namespace {
using testing::AllOf;
using testing::Field;
const TranslateAnnotatorOptions* TestingTranslateAnnotatorOptions() {
static const flatbuffers::DetachedBuffer* options_data = []() {
TranslateAnnotatorOptionsT options;
options.enabled = true;
options.algorithm = TranslateAnnotatorOptions_::Algorithm_BACKOFF;
options.backoff_options.reset(
new TranslateAnnotatorOptions_::BackoffOptionsT());
flatbuffers::FlatBufferBuilder builder;
builder.Finish(TranslateAnnotatorOptions::Pack(builder, &options));
return new flatbuffers::DetachedBuffer(builder.Release());
}();
return flatbuffers::GetRoot<TranslateAnnotatorOptions>(options_data->data());
}
class TestingTranslateAnnotator : public TranslateAnnotator {
public:
// Make these protected members public for tests.
using TranslateAnnotator::BackoffDetectLanguages;
using TranslateAnnotator::FindIndexOfNextWhitespaceOrPunctuation;
using TranslateAnnotator::TokenAlignedSubstringAroundSpan;
using TranslateAnnotator::TranslateAnnotator;
};
std::string GetModelPath() { return GetTestDataPath("annotator/test_data/"); }
class TranslateAnnotatorTest : public ::testing::Test {
protected:
TranslateAnnotatorTest()
: INIT_UNILIB_FOR_TESTING(unilib_),
langid_model_(libtextclassifier3::mobile::lang_id::GetLangIdFromFlatbufferFile(
GetModelPath() + "lang_id.smfb")),
translate_annotator_(TestingTranslateAnnotatorOptions(),
langid_model_.get(), &unilib_) {}
UniLib unilib_;
std::unique_ptr<libtextclassifier3::mobile::lang_id::LangId> langid_model_;
TestingTranslateAnnotator translate_annotator_;
};
TEST_F(TranslateAnnotatorTest, WhenSpeaksEnglishGetsTranslateActionForCzech) {
ClassificationResult classification;
EXPECT_TRUE(translate_annotator_.ClassifyText(
UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {18, 28},
"en", &classification));
EXPECT_THAT(classification,
AllOf(Field(&ClassificationResult::collection, "translate")));
const EntityData* entity_data =
GetEntityData(classification.serialized_entity_data.data());
const auto predictions =
entity_data->translate()->language_prediction_results();
EXPECT_EQ(predictions->size(), 1);
EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "cs");
EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
}
TEST_F(TranslateAnnotatorTest, EntityDataIsSet) {
ClassificationResult classification;
EXPECT_TRUE(translate_annotator_.ClassifyText(UTF8ToUnicodeText("学校"),
{0, 2}, "en", &classification));
EXPECT_THAT(classification,
AllOf(Field(&ClassificationResult::collection, "translate")));
const EntityData* entity_data =
GetEntityData(classification.serialized_entity_data.data());
const auto predictions =
entity_data->translate()->language_prediction_results();
EXPECT_EQ(predictions->size(), 2);
EXPECT_EQ(predictions->Get(0)->language_tag()->str(), "zh");
EXPECT_GT(predictions->Get(0)->confidence_score(), 0);
EXPECT_LE(predictions->Get(0)->confidence_score(), 1);
EXPECT_EQ(predictions->Get(1)->language_tag()->str(), "ja");
EXPECT_TRUE(predictions->Get(0)->confidence_score() >=
predictions->Get(1)->confidence_score());
}
TEST_F(TranslateAnnotatorTest,
WhenSpeaksEnglishDoesntGetTranslateActionForEnglish) {
ClassificationResult classification;
EXPECT_FALSE(translate_annotator_.ClassifyText(
UTF8ToUnicodeText("This is utterly unutterable."), {8, 15}, "en",
&classification));
}
TEST_F(TranslateAnnotatorTest,
WhenSpeaksMultipleAndNotCzechGetsTranslateActionForCzech) {
ClassificationResult classification;
EXPECT_TRUE(translate_annotator_.ClassifyText(
UTF8ToUnicodeText("Třista třicet tři stříbrných stříkaček."), {8, 15},
"de,en,ja", &classification));
EXPECT_THAT(classification,
AllOf(Field(&ClassificationResult::collection, "translate")));
}
TEST_F(TranslateAnnotatorTest,
WhenSpeaksMultipleAndEnglishDoesntGetTranslateActionForEnglish) {
ClassificationResult classification;
EXPECT_FALSE(translate_annotator_.ClassifyText(
UTF8ToUnicodeText("This is utterly unutterable."), {8, 15}, "cs,en,de,ja",
&classification));
}
TEST_F(TranslateAnnotatorTest, FindIndexOfNextWhitespaceOrPunctuation) {
const UnicodeText text =
UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
EXPECT_EQ(
translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 0, -1),
text.begin());
EXPECT_EQ(
translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 35, 1),
text.end());
EXPECT_EQ(
translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, -1),
std::next(text.begin(), 6));
EXPECT_EQ(
translate_annotator_.FindIndexOfNextWhitespaceOrPunctuation(text, 10, 1),
std::next(text.begin(), 13));
}
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringAroundSpan) {
const UnicodeText text =
UTF8ToUnicodeText("Třista třicet, tři stříbrných stříkaček");
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {35, 37}, /*minimum_length=*/100),
text);
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {35, 37}, /*minimum_length=*/0),
UTF8ToUnicodeText("ač"));
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {35, 37}, /*minimum_length=*/3),
UTF8ToUnicodeText("stříkaček"));
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {35, 37}, /*minimum_length=*/10),
UTF8ToUnicodeText("stříkaček"));
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {35, 37}, /*minimum_length=*/11),
UTF8ToUnicodeText("stříbrných stříkaček"));
const UnicodeText text_no_whitespace =
UTF8ToUnicodeText("reallyreallylongstring");
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text_no_whitespace, {10, 11}, /*minimum_length=*/2),
UTF8ToUnicodeText("reallyreallylongstring"));
}
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringWhitespaceText) {
const UnicodeText text = UTF8ToUnicodeText(" ");
// Shouldn't modify the selection in case it's all whitespace.
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {5, 7}, /*minimum_length=*/3),
UTF8ToUnicodeText(" "));
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {5, 5}, /*minimum_length=*/1),
UTF8ToUnicodeText(""));
}
TEST_F(TranslateAnnotatorTest, TokenAlignedSubstringMostlyWhitespaceText) {
const UnicodeText text = UTF8ToUnicodeText("a a");
// Should still select the whole text even if pointing to whitespace
// initially.
EXPECT_EQ(translate_annotator_.TokenAlignedSubstringAroundSpan(
text, {5, 7}, /*minimum_length=*/11),
UTF8ToUnicodeText("a a"));
}
} // namespace
} // namespace libtextclassifier3