blob: 5eba45f194864723fa2ee650663fdaa0e7ca4523 [file] [log] [blame] [edit]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/ranker.h"
#include <string>
#include "actions/actions_model_generated.h"
#include "actions/types.h"
#include "utils/zlib/zlib.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
namespace {
// Matches an action whose `type`, `response_text` and `score` fields each
// match the corresponding matcher parameter. Fields are checked in that order
// and the first mismatch short-circuits the remaining checks.
MATCHER_P3(IsAction, type, response_text, score, "") {
  if (!testing::Value(arg.type, type)) {
    return false;
  }
  if (!testing::Value(arg.response_text, response_text)) {
    return false;
  }
  return testing::Value(arg.score, score);
}
// Matches an action whose `type` field matches `type`; every other field is
// ignored.
MATCHER_P(IsActionType, type, "") {
  const bool type_matches = testing::Value(arg.type, type);
  return type_matches;
}
// Two smart replies with identical text and type are expected to collapse
// into a single suggestion, keeping the entry with score 1.0.
TEST(RankingTest, DeduplicationSmartReply) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5}};
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0)}));
}
// Suggestions sharing text and type are deduplicated, but one carrying
// different serialized entity data is not considered a duplicate.
TEST(RankingTest, DeduplicationExtraData) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0, /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.5,
       /*priority_score=*/0.0},
      {/*response_text=*/"hello there", /*type=*/"text_reply", /*score=*/0.6,
       /*priority_score=*/0.0,
       /*annotations=*/{}, /*serialized_entity_data=*/"test"},
  };
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("text_reply", "hello there", 1.0),
                                 // Is kept as it has different entity data.
                                 IsAction("text_reply", "hello there", 0.6)}));
}
// Two view_map actions over the same annotated span are deduplicated down to
// one (the score-1.0 entry here). The call_phone action over a different span
// is kept.
TEST(RankingTest, DeduplicationAnnotations) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("view_map", "", 1.0),
                                         IsAction("call_phone", "", 0.5)}));
}
// When duplicate view_map actions disagree, the one with the higher
// priority_score is expected to be kept even though its score is lower.
TEST(RankingTest, DeduplicationAnnotationsByPriorityScore) {
  const Conversation conversation = {
      {{/*user_id=*/1, "742 Evergreen Terrace, the number is 1-800-TESTING"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/0.6,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{0, 21},
                       /*text=*/"742 Evergreen Terrace"};
    annotation.entity = ClassificationResult("address", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"view_map",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{37, 50},
                       /*text=*/"1-800-TESTING"};
    annotation.entity = ClassificationResult("phone", 0.5);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray(
          {IsAction("view_map", "",
                    0.6),  // lower score wins, as priority score is higher
           IsAction("call_phone", "", 0.5)}));
}
// The call_phone action over "911" overlaps the copy_code action over "A-911".
// Only copy_code, the one with the higher priority_score here, is expected to
// survive.
TEST(RankingTest, DeduplicatesConflictingActions) {
  const Conversation conversation = {{{/*user_id=*/1, "code A-911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{7, 10},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  {
    ActionSuggestionAnnotation annotation;
    annotation.span = {/*message_index=*/0, /*span=*/{5, 10},
                       /*text=*/"A-911"};
    annotation.entity = ClassificationResult("code", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"copy_code",
                                /*score=*/1.0,
                                /*priority_score=*/2.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("copy_code", "", 1.0)}));
}
// A zlib-compressed Lua ranking script is decompressed and applied. The script
// returns the ids of all non-"text_reply" actions, so only those are expected
// to remain.
TEST(RankingTest, HandlesCompressedLuaScript) {
  const Conversation conversation = {{{/*user_id=*/1, "hello hello"}}};
  ActionsSuggestionsResponse response;
  response.actions = {
      {/*response_text=*/"hello there", /*type=*/"text_reply",
       /*score=*/1.0},
      {/*response_text=*/"", /*type=*/"share_location", /*score=*/0.5},
      {/*response_text=*/"", /*type=*/"add_to_collection", /*score=*/0.1}};
  const std::string test_snippet = R"(
local result = {}
for id, action in pairs(actions) do
if action.type ~= "text_reply" then
table.insert(result, id)
end
end
return result
)";
  RankingOptionsT options;
  options.compressed_lua_ranking_script.reset(new CompressedBufferT);
  std::unique_ptr<ZlibCompressor> compressor = ZlibCompressor::Instance();
  // Fail this test cleanly instead of dereferencing a null compressor.
  ASSERT_NE(compressor, nullptr);
  compressor->Compress(test_snippet,
                       options.compressed_lua_ranking_script.get());
  options.deduplicate_suggestions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance();
  ASSERT_NE(decompressor, nullptr);
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      decompressor.get(), /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsActionType("share_location"),
                                         IsActionType("add_to_collection")}));
}
// With suppress_smart_replies_with_actions enabled, the text reply is expected
// to be dropped because a call_phone action is also present.
TEST(RankingTest, SuppressSmartRepliesWithAction) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply"});
  RankingOptionsT options;
  options.suppress_smart_replies_with_actions = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({IsAction("call_phone", "", 1.0)}));
}
// With group_by_annotations enabled, actions sharing the "911" annotation are
// expected to stay together ahead of the un-annotated text reply.
TEST(RankingTest, GroupsActionsByAnnotations) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  options.group_by_annotations = true;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  // The text reply should be last, even though it has a higher score than the
  // `add_contact` action.
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("add_contact", "", 0.0),
                                 IsAction("text_reply", "How are you?", 0.5)}));
}
// With grouping and priority-score sorting, the "911" group is expected to be
// ordered by priority_score and placed before the un-annotated text reply.
TEST(RankingTest, GroupsByAnnotationsSortedByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/2.0,
                              /*priority_score=*/0.0});
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact2",
                                /*score=*/0.5,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  RankingOptionsT options;
  options.group_by_annotations = true;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  // The text reply should be last, even though its score is higher than any
  // other score, because its priority_score is lower than the max priority
  // score of the actions with the 'phone' annotation.
  EXPECT_THAT(response.actions,
              testing::ElementsAreArray({
                  // Group 1 (Phone annotation)
                  IsAction("add_contact2", "", 0.5),  // priority_score=1.0
                  IsAction("add_contact", "", 0.0),   // priority_score=1.0
                  IsAction("call_phone", "", 1.0),    // priority_score=0.0
                  IsAction("text_reply", "How are you?", 2.0),  // Group 2
              }));
}
// With grouping disabled, actions are expected to be ordered purely by score,
// regardless of shared annotations.
TEST(RankingTest, SortsActionsByScore) {
  const Conversation conversation = {{{/*user_id=*/1, "should i call 911"}}};
  ActionsSuggestionsResponse response;
  {
    ActionSuggestionAnnotation annotation;
    // "911" occupies codepoints [14, 17) of "should i call 911".
    annotation.span = {/*message_index=*/0, /*span=*/{14, 17},
                       /*text=*/"911"};
    annotation.entity = ClassificationResult("phone", 1.0);
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"call_phone",
                                /*score=*/1.0,
                                /*priority_score=*/0.0,
                                /*annotations=*/{annotation}});
    response.actions.push_back({/*response_text=*/"",
                                /*type=*/"add_contact",
                                /*score=*/0.0,
                                /*priority_score=*/1.0,
                                /*annotations=*/{annotation}});
  }
  response.actions.push_back({/*response_text=*/"How are you?",
                              /*type=*/"text_reply",
                              /*score=*/0.5});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(
      response.actions,
      testing::ElementsAreArray({IsAction("call_phone", "", 1.0),
                                 IsAction("text_reply", "How are you?", 0.5),
                                 IsAction("add_contact", "", 0.0)}));
}
// With priority-score sorting and no grouping, higher priority_score actions
// are expected to come first, even ahead of an action with a higher score.
TEST(RankingTest, SortsActionsByPriority) {
  const Conversation conversation = {{{/*user_id=*/1, "hello?"}}};
  ActionsSuggestionsResponse response;
  // emoji replies given higher priority_score
  response.actions.push_back({/*response_text=*/"😁",
                              /*type=*/"text_reply",
                              /*score=*/0.5,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"👋",
                              /*type=*/"text_reply",
                              /*score=*/0.4,
                              /*priority_score=*/1.0});
  response.actions.push_back({/*response_text=*/"Yes",
                              /*type=*/"text_reply",
                              /*score=*/1.0,
                              /*priority_score=*/0.0});
  RankingOptionsT options;
  // Don't group by annotation.
  options.group_by_annotations = false;
  options.sort_type = RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE;
  flatbuffers::FlatBufferBuilder builder;
  builder.Finish(RankingOptions::Pack(builder, &options));
  auto ranker = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
      flatbuffers::GetRoot<RankingOptions>(builder.GetBufferPointer()),
      /*decompressor=*/nullptr, /*smart_reply_action_type=*/"text_reply");
  // Fail this test cleanly instead of crashing the binary if creation fails.
  ASSERT_NE(ranker, nullptr);
  ranker->RankActions(conversation, &response);
  EXPECT_THAT(response.actions, testing::ElementsAreArray(
                                    {IsAction("text_reply", "😁", 0.5),
                                     IsAction("text_reply", "👋", 0.4),
                                     // Ranked last because of priority score
                                     IsAction("text_reply", "Yes", 1.0)}));
}
} // namespace
} // namespace libtextclassifier3