/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/actions-suggestions.h"
#include <fstream>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "actions/actions_model_generated.h"
#include "actions/test-utils.h"
#include "actions/zlib-utils.h"
#include "annotator/collections.h"
#include "annotator/types.h"
#include "utils/flatbuffers/flatbuffers.h"
#include "utils/flatbuffers/flatbuffers_generated.h"
#include "utils/flatbuffers/mutable.h"
#include "utils/grammar/utils/locale-shard-map.h"
#include "utils/grammar/utils/rules.h"
#include "utils/hash/farmhash.h"
#include "utils/jvm-test-utils.h"
#include "utils/test-data-test-utils.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/reflection.h"
namespace libtextclassifier3 {
namespace {
using ::testing::ElementsAre;
using ::testing::FloatEq;
using ::testing::IsEmpty;
using ::testing::NotNull;
using ::testing::SizeIs;
constexpr char kModelFileName[] = "actions_suggestions_test.model";
constexpr char kModelGrammarFileName[] =
"actions_suggestions_grammar_test.model";
constexpr char kMultiTaskTF2TestModelFileName[] =
"actions_suggestions_test.multi_task_tf2_test.model";
constexpr char kMultiTaskModelFileName[] =
"actions_suggestions_test.multi_task_9heads.model";
constexpr char kHashGramModelFileName[] =
"actions_suggestions_test.hashgram.model";
constexpr char kMultiTaskSrP13nModelFileName[] =
"actions_suggestions_test.multi_task_sr_p13n.model";
constexpr char kMultiTaskSrEmojiModelFileName[] =
"actions_suggestions_test.multi_task_sr_emoji.model";
constexpr char kSensitiveTFliteModelFileName[] =
"actions_suggestions_test.sensitive_tflite.model";
constexpr char kLiveRelayTFLiteModelFileName[] =
"actions_suggestions_test.live_relay.model";
std::string ReadFile(const std::string& file_name) {
std::ifstream file_stream(file_name);
return std::string(std::istreambuf_iterator<char>(file_stream), {});
}
std::string GetModelPath() { return GetTestDataPath("actions/test_data/"); }
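// Test fixture that owns a UniLib instance and provides helpers for loading
// the various test models from the test data directory.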
class ActionsSuggestionsTest : public testing::Test {
protected:
explicit ActionsSuggestionsTest() : unilib_(CreateUniLibForTesting()) {}
std::unique_ptr<ActionsSuggestions> LoadTestModel(
const std::string model_file_name) {
return ActionsSuggestions::FromPath(GetModelPath() + model_file_name,
unilib_.get());
}
std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() {
return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName,
unilib_.get());
}
std::unique_ptr<ActionsSuggestions> LoadMultiTaskTestModel() {
return ActionsSuggestions::FromPath(
GetModelPath() + kMultiTaskModelFileName, unilib_.get());
}
std::unique_ptr<ActionsSuggestions> LoadMultiTaskSrP13nTestModel() {
return ActionsSuggestions::FromPath(
GetModelPath() + kMultiTaskSrP13nModelFileName, unilib_.get());
}
std::unique_ptr<UniLib> unilib_;
};
TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) {
EXPECT_THAT(LoadTestModel(kModelFileName), NotNull());
}
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidInput) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?\xf0\x9f",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_THAT(response.actions, IsEmpty());
}
TEST_F(ActionsSuggestionsTest, ProducesEmptyResponseOnInvalidUtf8) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1,
"(857) 225-3556 \xed\xa0\x80\xed\xa0\x80\xed\xa0\x80\xed\xa0\x80",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_THAT(response.actions, IsEmpty());
}
TEST_F(ActionsSuggestionsTest, SuggestsActions) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
  EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies */);
}
TEST_F(ActionsSuggestionsTest, SuggestsNoActionsForUnknownLocale) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"zz"}}});
EXPECT_THAT(response.actions, testing::IsEmpty());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotations) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {ClassificationResult("address", 1.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "are you at home?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions.front().type, "view_map");
EXPECT_EQ(response.actions.front().score, 1.0);
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromAnnotationsWithEntityData) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
SetTestEntityDataSchema(actions_model.get());
// Set custom actions from annotations config.
actions_model->annotation_actions_spec->annotation_mapping.clear();
actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
new AnnotationActionsSpec_::AnnotationMappingT);
AnnotationActionsSpec_::AnnotationMappingT* mapping =
actions_model->annotation_actions_spec->annotation_mapping.back().get();
mapping->annotation_collection = "address";
mapping->action.reset(new ActionSuggestionSpecT);
mapping->action->type = "save_location";
mapping->action->score = 1.0;
mapping->action->priority_score = 2.0;
mapping->entity_field.reset(new FlatbufferFieldPathT);
mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
mapping->entity_field->field.back()->field_name = "location";
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {ClassificationResult("address", 1.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "are you at home?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions.front().type, "save_location");
EXPECT_EQ(response.actions.front().score, 1.0);
// Check that the `location` entity field holds the text from the address
// annotation.
const flatbuffers::Table* entity =
flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
response.actions.front().serialized_entity_data.data()));
EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
"home");
}
TEST_F(ActionsSuggestionsTest,
SuggestsActionsFromAnnotationsWithNormalization) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
SetTestEntityDataSchema(actions_model.get());
// Set custom actions from annotations config.
actions_model->annotation_actions_spec->annotation_mapping.clear();
actions_model->annotation_actions_spec->annotation_mapping.emplace_back(
new AnnotationActionsSpec_::AnnotationMappingT);
AnnotationActionsSpec_::AnnotationMappingT* mapping =
actions_model->annotation_actions_spec->annotation_mapping.back().get();
mapping->annotation_collection = "address";
mapping->action.reset(new ActionSuggestionSpecT);
mapping->action->type = "save_location";
mapping->action->score = 1.0;
mapping->action->priority_score = 2.0;
mapping->entity_field.reset(new FlatbufferFieldPathT);
mapping->entity_field->field.emplace_back(new FlatbufferFieldT);
mapping->entity_field->field.back()->field_name = "location";
mapping->normalization_options.reset(new NormalizationOptionsT);
mapping->normalization_options->codepointwise_normalization =
NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {ClassificationResult("address", 1.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "are you at home?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions.front().type, "save_location");
EXPECT_EQ(response.actions.front().score, 1.0);
// Check that the `location` entity field holds the normalized text of the
// annotation.
const flatbuffers::Table* entity =
flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
response.actions.front().serialized_entity_data.data()));
EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
"HOME");
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromDuplicatedAnnotations) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
AnnotatedSpan flight_annotation;
flight_annotation.span = {11, 15};
flight_annotation.classification = {ClassificationResult("flight", 2.5)};
AnnotatedSpan flight_annotation2;
flight_annotation2.span = {35, 39};
flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
AnnotatedSpan email_annotation;
email_annotation.span = {43, 56};
email_annotation.classification = {ClassificationResult("email", 2.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1,
"call me at LX38 or send message to LX38 or test@test.com.",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/
{flight_annotation, flight_annotation2, email_annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 2);
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[0].score, 3.0);
EXPECT_EQ(response.actions[1].type, "send_email");
EXPECT_EQ(response.actions[1].score, 2.0);
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsAnnotationsWithNoDeduplication) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
// Disable deduplication.
actions_model->annotation_actions_spec->deduplicate_annotations = false;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
AnnotatedSpan flight_annotation;
flight_annotation.span = {11, 15};
flight_annotation.classification = {ClassificationResult("flight", 2.5)};
AnnotatedSpan flight_annotation2;
flight_annotation2.span = {35, 39};
flight_annotation2.classification = {ClassificationResult("flight", 3.0)};
AnnotatedSpan email_annotation;
email_annotation.span = {43, 56};
email_annotation.classification = {ClassificationResult("email", 2.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1,
"call me at LX38 or send message to LX38 or test@test.com.",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/
{flight_annotation, flight_annotation2, email_annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 3);
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[0].score, 3.0);
EXPECT_EQ(response.actions[1].type, "track_flight");
EXPECT_EQ(response.actions[1].score, 2.5);
EXPECT_EQ(response.actions[2].type, "send_email");
EXPECT_EQ(response.actions[2].score, 2.0);
}
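// Helper that applies `set_config_fn` to the test model, disables smart reply
// and requests suggestions for a four-message conversation: three email
// messages (the first from the local user) followed by a flight message from
// user 1.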
ActionsSuggestionsResponse TestSuggestActionsFromAnnotations(
const std::function<void(ActionsModelT*)>& set_config_fn,
const UniLib* unilib = nullptr) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
// Set custom config.
set_config_fn(actions_model.get());
// Disable smart reply for easier testing.
actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib);
AnnotatedSpan flight_annotation;
flight_annotation.span = {15, 19};
flight_annotation.classification = {ClassificationResult("flight", 2.0)};
AnnotatedSpan email_annotation;
email_annotation.span = {0, 16};
email_annotation.classification = {ClassificationResult("email", 1.0)};
return actions_suggestions->SuggestActions(
{{{/*user_id=*/ActionsSuggestions::kLocalUserId,
"hehe@android.com",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/
{email_annotation},
/*locales=*/"en"},
{/*user_id=*/2,
"yoyo@android.com",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/
{email_annotation},
/*locales=*/"en"},
{/*user_id=*/1,
"test@android.com",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/
{email_annotation},
/*locales=*/"en"},
{/*user_id=*/1,
"I am on flight LX38.",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/
{flight_annotation},
/*locales=*/"en"}}});
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastMessage) {
const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
[](ActionsModelT* actions_model) {
actions_model->annotation_actions_spec->include_local_user_messages =
false;
actions_model->annotation_actions_spec->only_until_last_sent = true;
actions_model->annotation_actions_spec->max_history_from_any_person = 1;
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
unilib_.get());
EXPECT_THAT(response.actions, SizeIs(1));
EXPECT_EQ(response.actions[0].type, "track_flight");
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsOnlyLastPerson) {
const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
[](ActionsModelT* actions_model) {
actions_model->annotation_actions_spec->include_local_user_messages =
false;
actions_model->annotation_actions_spec->only_until_last_sent = true;
actions_model->annotation_actions_spec->max_history_from_any_person = 1;
actions_model->annotation_actions_spec->max_history_from_last_person =
3;
},
unilib_.get());
EXPECT_THAT(response.actions, SizeIs(2));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithAnnotationsFromAny) {
const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
[](ActionsModelT* actions_model) {
actions_model->annotation_actions_spec->include_local_user_messages =
false;
actions_model->annotation_actions_spec->only_until_last_sent = true;
actions_model->annotation_actions_spec->max_history_from_any_person = 2;
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
unilib_.get());
EXPECT_THAT(response.actions, SizeIs(2));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
}
TEST_F(ActionsSuggestionsTest,
SuggestsActionsWithAnnotationsFromAnyManyMessages) {
const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
[](ActionsModelT* actions_model) {
actions_model->annotation_actions_spec->include_local_user_messages =
false;
actions_model->annotation_actions_spec->only_until_last_sent = true;
actions_model->annotation_actions_spec->max_history_from_any_person = 3;
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
unilib_.get());
EXPECT_THAT(response.actions, SizeIs(3));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
EXPECT_EQ(response.actions[2].type, "send_email");
}
TEST_F(ActionsSuggestionsTest,
SuggestsActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) {
const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
[](ActionsModelT* actions_model) {
actions_model->annotation_actions_spec->include_local_user_messages =
false;
actions_model->annotation_actions_spec->only_until_last_sent = true;
actions_model->annotation_actions_spec->max_history_from_any_person = 5;
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
unilib_.get());
EXPECT_THAT(response.actions, SizeIs(3));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
EXPECT_EQ(response.actions[2].type, "send_email");
}
TEST_F(ActionsSuggestionsTest,
SuggestsActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) {
const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations(
[](ActionsModelT* actions_model) {
actions_model->annotation_actions_spec->include_local_user_messages =
true;
actions_model->annotation_actions_spec->only_until_last_sent = false;
actions_model->annotation_actions_spec->max_history_from_any_person = 5;
actions_model->annotation_actions_spec->max_history_from_last_person =
1;
},
unilib_.get());
EXPECT_THAT(response.actions, SizeIs(4));
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[1].type, "send_email");
EXPECT_EQ(response.actions[2].type, "send_email");
EXPECT_EQ(response.actions[3].type, "send_email");
}
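// Helper that applies `set_value_fn` to the test model (optionally together
// with a serialized preconditions overwrite) and verifies that at most
// `expected_size` actions are suggested for a fixed test message.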
void TestSuggestActionsWithThreshold(
const std::function<void(ActionsModelT*)>& set_value_fn,
const UniLib* unilib = nullptr, const int expected_size = 0,
const std::string& preconditions_overwrite = "") {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
set_value_fn(actions_model.get());
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib, preconditions_overwrite);
ASSERT_TRUE(actions_suggestions);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "I have the low-ground. Where are you?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_LE(response.actions.size(), expected_size);
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithTriggeringScore) {
TestSuggestActionsWithThreshold(
[](ActionsModelT* actions_model) {
actions_model->preconditions->min_smart_reply_triggering_score = 1.0;
},
unilib_.get(),
/*expected_size=*/1 /*no smart reply, only actions*/
);
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinReplyScore) {
TestSuggestActionsWithThreshold(
[](ActionsModelT* actions_model) {
actions_model->preconditions->min_reply_score_threshold = 1.0;
},
unilib_.get(),
/*expected_size=*/1 /*no smart reply, only actions*/
);
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithSensitiveTopicScore) {
TestSuggestActionsWithThreshold(
[](ActionsModelT* actions_model) {
actions_model->preconditions->max_sensitive_topic_score = 0.0;
},
unilib_.get(),
/*expected_size=*/4 /* no sensitive prediction in test model*/);
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMaxInputLength) {
TestSuggestActionsWithThreshold(
[](ActionsModelT* actions_model) {
actions_model->preconditions->max_input_length = 0;
},
unilib_.get());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithMinInputLength) {
TestSuggestActionsWithThreshold(
[](ActionsModelT* actions_model) {
actions_model->preconditions->min_input_length = 100;
},
unilib_.get());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithPreconditionsOverwrite) {
TriggeringPreconditionsT preconditions_overwrite;
preconditions_overwrite.max_input_length = 0;
flatbuffers::FlatBufferBuilder builder;
builder.Finish(
TriggeringPreconditions::Pack(builder, &preconditions_overwrite));
TestSuggestActionsWithThreshold(
// Keep model untouched.
[](ActionsModelT* actions_model) {}, unilib_.get(),
/*expected_size=*/0,
std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()),
builder.GetSize()));
}
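// The regex-rule based tests below are only compiled with ICU support.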
#ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidence) {
TestSuggestActionsWithThreshold(
[](ActionsModelT* actions_model) {
actions_model->preconditions->suppress_on_low_confidence_input = true;
actions_model->low_confidence_rules.reset(new RulesModelT);
actions_model->low_confidence_rules->regex_rule.emplace_back(
new RulesModel_::RegexRuleT);
actions_model->low_confidence_rules->regex_rule.back()->pattern =
"low-ground";
},
unilib_.get());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsLowConfidenceInputOutput) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
// Add custom triggering rule.
actions_model->rules.reset(new RulesModelT());
actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
rule->pattern = "^(?i:hello\\s(there))$";
{
std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
new RulesModel_::RuleActionSpecT);
rule_action->action.reset(new ActionSuggestionSpecT);
rule_action->action->type = "text_reply";
rule_action->action->response_text = "General Desaster!";
rule_action->action->score = 1.0f;
rule_action->action->priority_score = 1.0f;
rule->actions.push_back(std::move(rule_action));
}
{
std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
new RulesModel_::RuleActionSpecT);
rule_action->action.reset(new ActionSuggestionSpecT);
rule_action->action->type = "text_reply";
rule_action->action->response_text = "General Kenobi!";
rule_action->action->score = 1.0f;
rule_action->action->priority_score = 1.0f;
rule->actions.push_back(std::move(rule_action));
}
// Add input-output low confidence rule.
actions_model->preconditions->suppress_on_low_confidence_input = true;
actions_model->low_confidence_rules.reset(new RulesModelT);
actions_model->low_confidence_rules->regex_rule.emplace_back(
new RulesModel_::RegexRuleT);
actions_model->low_confidence_rules->regex_rule.back()->pattern = "hello";
actions_model->low_confidence_rules->regex_rule.back()->output_pattern =
"(?i:desaster)";
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
ASSERT_TRUE(actions_suggestions);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello there",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
TEST_F(ActionsSuggestionsTest,
SuggestsActionsLowConfidenceInputOutputOverwrite) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
actions_model->low_confidence_rules.reset();
// Add custom triggering rule.
actions_model->rules.reset(new RulesModelT());
actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
rule->pattern = "^(?i:hello\\s(there))$";
{
std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
new RulesModel_::RuleActionSpecT);
rule_action->action.reset(new ActionSuggestionSpecT);
rule_action->action->type = "text_reply";
rule_action->action->response_text = "General Desaster!";
rule_action->action->score = 1.0f;
rule_action->action->priority_score = 1.0f;
rule->actions.push_back(std::move(rule_action));
}
{
std::unique_ptr<RulesModel_::RuleActionSpecT> rule_action(
new RulesModel_::RuleActionSpecT);
rule_action->action.reset(new ActionSuggestionSpecT);
rule_action->action->type = "text_reply";
rule_action->action->response_text = "General Kenobi!";
rule_action->action->score = 1.0f;
rule_action->action->priority_score = 1.0f;
rule->actions.push_back(std::move(rule_action));
}
// Add custom triggering rule via overwrite.
actions_model->preconditions->low_confidence_rules.reset();
TriggeringPreconditionsT preconditions;
preconditions.suppress_on_low_confidence_input = true;
preconditions.low_confidence_rules.reset(new RulesModelT);
preconditions.low_confidence_rules->regex_rule.emplace_back(
new RulesModel_::RegexRuleT);
preconditions.low_confidence_rules->regex_rule.back()->pattern = "hello";
preconditions.low_confidence_rules->regex_rule.back()->output_pattern =
"(?i:desaster)";
flatbuffers::FlatBufferBuilder preconditions_builder;
preconditions_builder.Finish(
TriggeringPreconditions::Pack(preconditions_builder, &preconditions));
std::string serialize_preconditions = std::string(
reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()),
preconditions_builder.GetSize());
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get(), serialize_preconditions);
ASSERT_TRUE(actions_suggestions);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello there",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
}
#endif
TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
  // Skip the test if the model does not produce a sensitivity score.
if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) {
return;
}
actions_model->preconditions->max_sensitive_topic_score = 0.0;
actions_model->preconditions->suppress_on_sensitive_topic = true;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {
ClassificationResult(Collections::Address(), 1.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "are you at home?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
EXPECT_THAT(response.actions, testing::IsEmpty());
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithLongerConversation) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
// Allow a larger conversation context.
actions_model->max_conversation_history_length = 10;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {
ClassificationResult(Collections::Address(), 1.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?",
/*reference_time_ms_utc=*/10000,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"},
{/*user_id=*/1, "good! are you at home?",
/*reference_time_ms_utc=*/15000,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].type, "view_map");
EXPECT_EQ(response.actions[0].score, 1.0);
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromTF2MultiTaskModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kMultiTaskTF2TestModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Hello how are you",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 4);
EXPECT_EQ(response.actions[0].response_text, "Okay");
EXPECT_EQ(response.actions[0].type, "REPLY_SUGGESTION");
EXPECT_EQ(response.actions[3].type, "TEST_CLASSIFIER_INTENT");
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromPhoneGrammarAnnotations) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelGrammarFileName);
AnnotatedSpan annotation;
annotation.span = {11, 15};
annotation.classification = {ClassificationResult("phone", 0.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Contact us at: *1234",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions.front().type, "call_phone");
EXPECT_EQ(response.actions.front().score, 0.0);
EXPECT_EQ(response.actions.front().priority_score, 0.0);
EXPECT_EQ(response.actions.front().annotations.size(), 1);
EXPECT_EQ(response.actions.front().annotations.front().span.span.first, 15);
EXPECT_EQ(response.actions.front().annotations.front().span.span.second, 20);
}
TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
AnnotatedSpan annotation;
annotation.span = {8, 12};
annotation.classification = {
ClassificationResult(Collections::Flight(), 1.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "I'm on LX38?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
ASSERT_GE(response.actions.size(), 2);
EXPECT_EQ(response.actions[0].type, "track_flight");
EXPECT_EQ(response.actions[0].score, 1.0);
EXPECT_THAT(response.actions[0].annotations, SizeIs(1));
EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0);
EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span);
}
#ifdef TC3_UNILIB_ICU
TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
actions_model->rules.reset(new RulesModelT());
actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
rule->pattern = "^(?i:hello\\s(there))$";
rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
rule->actions.back()->action.reset(new ActionSuggestionSpecT);
ActionSuggestionSpecT* action = rule->actions.back()->action.get();
action->type = "text_reply";
action->response_text = "General Kenobi!";
action->score = 1.0f;
action->priority_score = 1.0f;
// Set capturing groups for entity data.
rule->actions.back()->capturing_group.emplace_back(
new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
rule->actions.back()->capturing_group.back().get();
greeting_group->group_id = 0;
greeting_group->entity_field.reset(new FlatbufferFieldPathT);
greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
greeting_group->entity_field->field.back()->field_name = "greeting";
rule->actions.back()->capturing_group.emplace_back(
new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
RulesModel_::RuleActionSpec_::RuleCapturingGroupT* location_group =
rule->actions.back()->capturing_group.back().get();
location_group->group_id = 1;
location_group->entity_field.reset(new FlatbufferFieldPathT);
location_group->entity_field->field.emplace_back(new FlatbufferFieldT);
location_group->entity_field->field.back()->field_name = "location";
// Set test entity data schema.
SetTestEntityDataSchema(actions_model.get());
  // Use metadata to generate custom serialized entity data.
MutableFlatbufferBuilder entity_data_builder(
flatbuffers::GetRoot<reflection::Schema>(
actions_model->actions_entity_data_schema.data()));
std::unique_ptr<MutableFlatbuffer> entity_data =
entity_data_builder.NewRoot();
entity_data->Set("person", "Kenobi");
action->serialized_entity_data = entity_data->Serialize();
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
// Check entity data.
const flatbuffers::Table* entity =
flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
response.actions[0].serialized_entity_data.data()));
EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
"hello there");
EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(),
"there");
EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(),
"Kenobi");
}
TEST_F(ActionsSuggestionsTest, CreateActionsFromRulesWithNormalization) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
actions_model->rules.reset(new RulesModelT());
actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
rule->pattern = "^(?i:hello\\sthere)$";
rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
rule->actions.back()->action.reset(new ActionSuggestionSpecT);
ActionSuggestionSpecT* action = rule->actions.back()->action.get();
action->type = "text_reply";
action->response_text = "General Kenobi!";
action->score = 1.0f;
action->priority_score = 1.0f;
// Set capturing groups for entity data.
rule->actions.back()->capturing_group.emplace_back(
new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
RulesModel_::RuleActionSpec_::RuleCapturingGroupT* greeting_group =
rule->actions.back()->capturing_group.back().get();
greeting_group->group_id = 0;
greeting_group->entity_field.reset(new FlatbufferFieldPathT);
greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT);
greeting_group->entity_field->field.back()->field_name = "greeting";
greeting_group->normalization_options.reset(new NormalizationOptionsT);
greeting_group->normalization_options->codepointwise_normalization =
NormalizationOptions_::CodepointwiseNormalizationOp_DROP_WHITESPACE |
NormalizationOptions_::CodepointwiseNormalizationOp_UPPERCASE;
// Set test entity data schema.
SetTestEntityDataSchema(actions_model.get());
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].response_text, "General Kenobi!");
// Check entity data.
const flatbuffers::Table* entity =
flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
response.actions[0].serialized_entity_data.data()));
EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(),
"HELLOTHERE");
}
TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
actions_model->rules.reset(new RulesModelT());
actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )";
rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
// Set capturing groups for entity data.
rule->actions.back()->capturing_group.emplace_back(
new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
rule->actions.back()->capturing_group.back().get();
code_group->group_id = 1;
code_group->text_reply.reset(new ActionSuggestionSpecT);
code_group->text_reply->score = 1.0f;
code_group->text_reply->priority_score = 1.0f;
code_group->normalization_options.reset(new NormalizationOptionsT);
code_group->normalization_options->codepointwise_normalization =
NormalizationOptions_::CodepointwiseNormalizationOp_LOWERCASE;
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1,
"visit test.com or reply STOP to cancel your subscription",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].response_text, "stop");
}
TEST_F(ActionsSuggestionsTest, CreatesActionsFromGrammarRules) {
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
// Set tokenizer options.
RulesModel_::GrammarRulesT* action_grammar_rules =
actions_model->rules->grammar_rules.get();
action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
false;
  // Set up test rules.
action_grammar_rules->rules.reset(new grammar::RulesSetT);
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
grammar::Rules rules(locale_shard_map);
rules.Add(
"<knock>", {"<^>", "ventura", "!?", "<$>"},
/*callback=*/
static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/0);
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules->rules.get());
action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
RulesModel_::RuleActionSpecT* actions_spec =
action_grammar_rules->actions.back().get();
actions_spec->action.reset(new ActionSuggestionSpecT);
actions_spec->action->response_text = "Yes, Satan?";
actions_spec->action->priority_score = 1.0;
actions_spec->action->score = 1.0;
actions_spec->action->type = "text_reply";
action_grammar_rules->rule_match.emplace_back(
new RulesModel_::GrammarRules_::RuleMatchT);
action_grammar_rules->rule_match.back()->action_id.push_back(0);
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Ventura!",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_THAT(response.actions, ElementsAre(IsSmartReply("Yes, Satan?")));
}
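// The following test additionally requires datetime annotation support and an
// Annotator model ("en.fb") from the test data directory.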
#if defined(TC3_UNILIB_ICU) && !defined(TEST_NO_DATETIME)
TEST_F(ActionsSuggestionsTest, CreatesActionsWithAnnotationsFromGrammarRules) {
std::unique_ptr<Annotator> annotator =
Annotator::FromPath(GetModelPath() + "en.fb", unilib_.get());
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
actions_model->rules->grammar_rules.reset(new RulesModel_::GrammarRulesT);
// Set tokenizer options.
RulesModel_::GrammarRulesT* action_grammar_rules =
actions_model->rules->grammar_rules.get();
action_grammar_rules->tokenizer_options.reset(new ActionsTokenizerOptionsT);
action_grammar_rules->tokenizer_options->type = TokenizationType_ICU;
action_grammar_rules->tokenizer_options->icu_preserve_whitespace_tokens =
false;
  // Set up test rules.
action_grammar_rules->rules.reset(new grammar::RulesSetT);
grammar::LocaleShardMap locale_shard_map =
grammar::LocaleShardMap::CreateLocaleShardMap({""});
grammar::Rules rules(locale_shard_map);
rules.Add(
"<event>", {"it", "is", "at", "<time>"},
/*callback=*/
static_cast<grammar::CallbackId>(grammar::DefaultCallback::kRootRule),
/*callback_param=*/0);
rules.BindAnnotation("<time>", "time");
rules.AddAnnotation("datetime");
rules.Finalize().Serialize(/*include_debug_information=*/false,
action_grammar_rules->rules.get());
action_grammar_rules->actions.emplace_back(new RulesModel_::RuleActionSpecT);
RulesModel_::RuleActionSpecT* actions_spec =
action_grammar_rules->actions.back().get();
actions_spec->action.reset(new ActionSuggestionSpecT);
actions_spec->action->priority_score = 1.0;
actions_spec->action->score = 1.0;
actions_spec->action->type = "create_event";
action_grammar_rules->rule_match.emplace_back(
new RulesModel_::GrammarRules_::RuleMatchT);
action_grammar_rules->rule_match.back()->action_id.push_back(0);
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
std::unique_ptr<ActionsSuggestions> actions_suggestions =
ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "it is at 10:30",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}},
annotator.get());
EXPECT_THAT(response.actions, ElementsAre(IsActionOfType("create_event")));
}
#endif
TEST_F(ActionsSuggestionsTest, DeduplicateActions) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
// Check that the location sharing model triggered.
bool has_location_sharing_action = false;
for (const ActionSuggestion& action : response.actions) {
if (action.type == ActionsSuggestionsTypes::ShareLocation()) {
has_location_sharing_action = true;
break;
}
}
EXPECT_TRUE(has_location_sharing_action);
const int num_actions = response.actions.size();
// Add custom rule for location sharing.
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
actions_model->rules.reset(new RulesModelT());
actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
actions_model->rules->regex_rule.back()->pattern =
"^(?i:where are you[.?]?)$";
actions_model->rules->regex_rule.back()->actions.emplace_back(
new RulesModel_::RuleActionSpecT);
actions_model->rules->regex_rule.back()->actions.back()->action.reset(
new ActionSuggestionSpecT);
ActionSuggestionSpecT* action =
actions_model->rules->regex_rule.back()->actions.back()->action.get();
action->score = 1.0f;
action->type = ActionsSuggestionsTypes::ShareLocation();
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
response = actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_THAT(response.actions, SizeIs(num_actions));
}
TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
AnnotatedSpan annotation;
annotation.span = {7, 11};
annotation.classification = {
ClassificationResult(Collections::Flight(), 1.0)};
ActionsSuggestionsResponse response = actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "I'm on LX38",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
  // Check that the flight tracking action is present.
EXPECT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].type, "track_flight");
// Add custom rule.
const std::string actions_model_string =
ReadFile(GetModelPath() + kModelFileName);
std::unique_ptr<ActionsModelT> actions_model =
UnPackActionsModel(actions_model_string.c_str());
ASSERT_TRUE(DecompressActionsModel(actions_model.get()));
actions_model->rules.reset(new RulesModelT());
actions_model->rules->regex_rule.emplace_back(new RulesModel_::RegexRuleT);
RulesModel_::RegexRuleT* rule = actions_model->rules->regex_rule.back().get();
rule->pattern = "^(?i:I'm on ([a-z0-9]+))$";
rule->actions.emplace_back(new RulesModel_::RuleActionSpecT);
rule->actions.back()->action.reset(new ActionSuggestionSpecT);
ActionSuggestionSpecT* action = rule->actions.back()->action.get();
action->score = 1.0f;
action->priority_score = 2.0f;
action->type = "test_code";
rule->actions.back()->capturing_group.emplace_back(
new RulesModel_::RuleActionSpec_::RuleCapturingGroupT);
RulesModel_::RuleActionSpec_::RuleCapturingGroupT* code_group =
rule->actions.back()->capturing_group.back().get();
code_group->group_id = 1;
code_group->annotation_name = "code";
code_group->annotation_type = "code";
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder,
ActionsModel::Pack(builder, actions_model.get()));
actions_suggestions = ActionsSuggestions::FromUnownedBuffer(
reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()),
builder.GetSize(), unilib_.get());
response = actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "I'm on LX38",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{annotation},
/*locales=*/"en"}}});
EXPECT_GE(response.actions.size(), 1);
EXPECT_EQ(response.actions[0].type, "test_code");
}
#endif
TEST_F(ActionsSuggestionsTest, RanksActions) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kModelFileName);
std::vector<AnnotatedSpan> annotations(2);
annotations[0].span = {11, 15};
annotations[0].classification = {ClassificationResult("address", 1.0)};
annotations[1].span = {19, 23};
annotations[1].classification = {ClassificationResult("address", 2.0)};
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "are you at home or work?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/annotations,
/*locales=*/"en"}}});
EXPECT_GE(response.actions.size(), 2);
EXPECT_EQ(response.actions[0].type, "view_map");
EXPECT_EQ(response.actions[0].score, 2.0);
EXPECT_EQ(response.actions[1].type, "view_map");
EXPECT_EQ(response.actions[1].score, 1.0);
}
TEST_F(ActionsSuggestionsTest, VisitActionsModel) {
  EXPECT_TRUE(VisitActionsModel<bool>(
      GetModelPath() + kModelFileName,
      [](const ActionsModel* model) { return model != nullptr; }));
  EXPECT_FALSE(VisitActionsModel<bool>(
      GetModelPath() + "non_existing_model.fb",
      [](const ActionsModel* model) { return model != nullptr; }));
}
TEST_F(ActionsSuggestionsTest, SuggestsActionsWithHashGramModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadHashGramTestModel();
ASSERT_TRUE(actions_suggestions != nullptr);
{
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_THAT(response.actions, testing::IsEmpty());
}
{
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "where are you",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_THAT(
response.actions,
ElementsAre(testing::Field(&ActionSuggestion::type, "share_location")));
}
{
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "do you know johns number",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_THAT(
response.actions,
ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact")));
}
}
// Helper class that exposes the token embedding methods of ActionsSuggestions
// for testing.
class TestingMessageEmbedder : private ActionsSuggestions {
public:
explicit TestingMessageEmbedder(const ActionsModel* model);
using ActionsSuggestions::EmbedAndFlattenTokens;
using ActionsSuggestions::EmbedTokensPerMessage;
protected:
  // Fake EmbeddingExecutor that returns the id of the single sparse feature as
  // the embedding value.
class FakeEmbeddingExecutor : public EmbeddingExecutor {
public:
bool AddEmbedding(const TensorView<int>& sparse_features, float* dest,
const int dest_size) const override {
TC3_CHECK_GE(dest_size, 1);
EXPECT_EQ(sparse_features.size(), 1);
dest[0] = sparse_features.data()[0];
return true;
}
};
std::unique_ptr<UniLib> unilib_;
};
TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model)
: unilib_(CreateUniLibForTesting()) {
model_ = model;
const ActionsTokenFeatureProcessorOptions* options =
model->feature_processor_options();
feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_.get()));
embedding_executor_.reset(new FakeEmbeddingExecutor());
EXPECT_TRUE(
EmbedTokenId(options->padding_token_id(), &embedded_padding_token_));
EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_));
EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_));
token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize();
EXPECT_EQ(token_embedding_size_, 1);
}
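// Fixture that configures a minimal feature processor (a single char-gram
// order, 1000 buckets, embedding size 1) so that token embeddings can be
// compared directly against hashed token ids in the tests below.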
class EmbeddingTest : public testing::Test {
protected:
explicit EmbeddingTest() {
model_.feature_processor_options.reset(
new ActionsTokenFeatureProcessorOptionsT);
options_ = model_.feature_processor_options.get();
options_->chargram_orders = {1};
options_->num_buckets = 1000;
options_->embedding_size = 1;
options_->start_token_id = 0;
options_->end_token_id = 1;
options_->padding_token_id = 2;
options_->tokenizer_options.reset(new ActionsTokenizerOptionsT);
}
TestingMessageEmbedder CreateTestingMessageEmbedder() {
flatbuffers::FlatBufferBuilder builder;
FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_));
buffer_ = builder.Release();
return TestingMessageEmbedder(
flatbuffers::GetRoot<ActionsModel>(buffer_.data()));
}
flatbuffers::DetachedBuffer buffer_;
ActionsModelT model_;
ActionsTokenFeatureProcessorOptionsT* options_;
};
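// With the fake executor, each token embeds to its hashed char-gram id, i.e.
// tc3farmhash::Fingerprint64(token, 1) % num_buckets, padded or truncated
// according to the configured per-message and total token bounds.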
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) {
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
std::vector<float> embeddings;
int max_num_tokens_per_message = 0;
EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
&max_num_tokens_per_message));
EXPECT_EQ(max_num_tokens_per_message, 3);
EXPECT_EQ(embeddings.size(), 3);
EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
}
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) {
options_->min_num_tokens_per_message = 5;
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
std::vector<float> embeddings;
int max_num_tokens_per_message = 0;
EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
&max_num_tokens_per_message));
EXPECT_EQ(max_num_tokens_per_message, 5);
EXPECT_EQ(embeddings.size(), 5);
EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[3], FloatEq(options_->padding_token_id));
EXPECT_THAT(embeddings[4], FloatEq(options_->padding_token_id));
}
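// Messages longer than max_num_tokens_per_message are truncated at the
// beginning, keeping the most recent tokens.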
TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) {
options_->max_num_tokens_per_message = 2;
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
std::vector<float> embeddings;
int max_num_tokens_per_message = 0;
EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
&max_num_tokens_per_message));
EXPECT_EQ(max_num_tokens_per_message, 2);
EXPECT_EQ(embeddings.size(), 2);
EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
}
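// With multiple messages, shorter messages are padded to the length of the
// longest one; here the two-token second message gets one padding token.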
TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) {
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
{Token("d", 0, 1), Token("e", 2, 3)}};
std::vector<float> embeddings;
int max_num_tokens_per_message = 0;
EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings,
&max_num_tokens_per_message));
EXPECT_EQ(max_num_tokens_per_message, 3);
EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("d", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[4], FloatEq(tc3farmhash::Fingerprint64("e", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
}
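// EmbedAndFlattenTokens wraps each message in start/end tokens and
// concatenates all messages into a single flat sequence.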
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) {
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
std::vector<float> embeddings;
int total_token_count = 0;
EXPECT_TRUE(
embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
EXPECT_EQ(total_token_count, 5);
EXPECT_EQ(embeddings.size(), 5);
EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
}
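// If the flattened sequence is shorter than min_num_total_tokens, it is
// padded at the end with the padding token id.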
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) {
options_->min_num_total_tokens = 7;
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
std::vector<float> embeddings;
int total_token_count = 0;
EXPECT_TRUE(
embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
EXPECT_EQ(total_token_count, 7);
EXPECT_EQ(embeddings.size(), 7);
EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
EXPECT_THAT(embeddings[5], FloatEq(options_->padding_token_id));
EXPECT_THAT(embeddings[6], FloatEq(options_->padding_token_id));
}
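// If the flattened sequence exceeds max_num_total_tokens, tokens are dropped
// from the beginning; note that the leading start token is dropped as well.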
TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) {
options_->max_num_total_tokens = 3;
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}};
std::vector<float> embeddings;
int total_token_count = 0;
EXPECT_TRUE(
embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
EXPECT_EQ(total_token_count, 3);
EXPECT_EQ(embeddings.size(), 3);
EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[2], FloatEq(options_->end_token_id));
}
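// Each message contributes its own start and end markers to the flattened
// sequence.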
TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) {
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
{Token("d", 0, 1), Token("e", 2, 3)}};
std::vector<float> embeddings;
int total_token_count = 0;
EXPECT_TRUE(
embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
EXPECT_EQ(total_token_count, 9);
EXPECT_EQ(embeddings.size(), 9);
EXPECT_THAT(embeddings[0], FloatEq(options_->start_token_id));
EXPECT_THAT(embeddings[1], FloatEq(tc3farmhash::Fingerprint64("a", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[2], FloatEq(tc3farmhash::Fingerprint64("b", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[4], FloatEq(options_->end_token_id));
EXPECT_THAT(embeddings[5], FloatEq(options_->start_token_id));
EXPECT_THAT(embeddings[6], FloatEq(tc3farmhash::Fingerprint64("d", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[7], FloatEq(tc3farmhash::Fingerprint64("e", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[8], FloatEq(options_->end_token_id));
}
TEST_F(EmbeddingTest,
EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) {
options_->max_num_total_tokens = 7;
const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder();
std::vector<std::vector<Token>> tokens = {
{Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)},
{Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}};
std::vector<float> embeddings;
int total_token_count = 0;
EXPECT_TRUE(
embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count));
EXPECT_EQ(total_token_count, 7);
EXPECT_EQ(embeddings.size(), 7);
EXPECT_THAT(embeddings[0], FloatEq(tc3farmhash::Fingerprint64("c", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[1], FloatEq(options_->end_token_id));
EXPECT_THAT(embeddings[2], FloatEq(options_->start_token_id));
EXPECT_THAT(embeddings[3], FloatEq(tc3farmhash::Fingerprint64("d", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[4], FloatEq(tc3farmhash::Fingerprint64("e", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[5], FloatEq(tc3farmhash::Fingerprint64("f", 1) %
options_->num_buckets));
EXPECT_THAT(embeddings[6], FloatEq(options_->end_token_id));
}
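// The multi-task test model has several binary classification heads in
// addition to the smart reply head; by default all of them may contribute
// actions.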
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsDefault) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadMultiTaskTestModel();
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(),
11 /* 8 binary classification + 3 smart replies*/);
}
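// Confidence thresholds for the individual classification heads of the
// multi-task model. A threshold of 2.0 can never be reached (the head
// confidences are assumed to lie in [0, 1]), so it effectively disables the
// corresponding head.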
constexpr float kDisableThresholdVal = 2.0f;
constexpr char kSpamThreshold[] = "spam_confidence_threshold";
constexpr char kLocationThreshold[] = "location_confidence_threshold";
constexpr char kPhoneThreshold[] = "phone_confidence_threshold";
constexpr char kWeatherThreshold[] = "weather_confidence_threshold";
constexpr char kRestaurantsThreshold[] = "restaurants_confidence_threshold";
constexpr char kMoviesThreshold[] = "movies_confidence_threshold";
constexpr char kTtrThreshold[] = "time_to_reply_binary_threshold";
constexpr char kReminderThreshold[] = "reminder_intent_confidence_threshold";
constexpr char kDiversificationParm[] = "diversification_distance_threshold";
constexpr char kEmpiricalProbFactor[] = "empirical_probability_factor";
ActionSuggestionOptions GetOptionsToDisableAllClassification() {
ActionSuggestionOptions options;
// Disable all classification heads.
options.model_parameters.insert(
{kSpamThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
options.model_parameters.insert(
{kLocationThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
options.model_parameters.insert(
{kPhoneThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
options.model_parameters.insert(
{kWeatherThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
options.model_parameters.insert(
{kRestaurantsThreshold,
libtextclassifier3::Variant(kDisableThresholdVal)});
options.model_parameters.insert(
{kMoviesThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
options.model_parameters.insert(
{kTtrThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
options.model_parameters.insert(
{kReminderThreshold, libtextclassifier3::Variant(kDisableThresholdVal)});
return options;
}
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyOnly) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadMultiTaskTestModel();
const ActionSuggestionOptions options =
GetOptionsToDisableAllClassification();
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}},
/*annotator=*/nullptr, options);
EXPECT_THAT(response.actions,
ElementsAre(IsSmartReply("Here"), IsSmartReply("I'm here"),
IsSmartReply("I'm home")));
EXPECT_EQ(response.actions.size(), 3 /*3 smart replies*/);
}
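// The personalized smart reply model takes the user profile as additional
// model parameters: a vector of token indices and a vector of per-token
// weights. The options below use dummy values (all indices 1, all weights
// 0.1).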
const int kUserProfileSize = 1000;
constexpr char kUserProfileTokenIndex[] = "user_profile_token_index";
constexpr char kUserProfileTokenWeight[] = "user_profile_token_weight";
ActionSuggestionOptions GetOptionsForSmartReplyP13nModel() {
ActionSuggestionOptions options;
const std::vector<int> user_profile_token_indexes(kUserProfileSize, 1);
const std::vector<float> user_profile_token_weights(kUserProfileSize, 0.1f);
options.model_parameters.insert(
{kUserProfileTokenIndex,
libtextclassifier3::Variant(user_profile_token_indexes)});
options.model_parameters.insert(
{kUserProfileTokenWeight,
libtextclassifier3::Variant(user_profile_token_weights)});
return options;
}
TEST_F(ActionsSuggestionsTest, MultiTaskSuggestActionsSmartReplyP13n) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadMultiTaskSrP13nTestModel();
const ActionSuggestionOptions options = GetOptionsForSmartReplyP13nModel();
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "How are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}},
/*annotator=*/nullptr, options);
EXPECT_EQ(response.actions.size(), 3 /*3 smart replies*/);
}
TEST_F(ActionsSuggestionsTest,
MultiTaskSuggestActionsDiversifiedSmartReplyAndLocation) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadMultiTaskTestModel();
ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
options.model_parameters[kLocationThreshold] =
libtextclassifier3::Variant(0.35f);
options.model_parameters.insert(
{kDiversificationParm, libtextclassifier3::Variant(0.5f)});
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}},
/*annotator=*/nullptr, options);
EXPECT_THAT(
response.actions,
ElementsAre(IsActionOfType("LOCATION_SHARE"), IsSmartReply("Here"),
IsSmartReply("Yes"), IsSmartReply("😟")));
EXPECT_EQ(response.actions.size(), 4 /*1 location share + 3 smart replies*/);
}
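// Boosting with the empirical probability factor is expected to re-rank the
// suggestions; a different emoji surfaces than in the diversification test
// above.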
TEST_F(ActionsSuggestionsTest,
MultiTaskSuggestActionsEmProBoostedSmartReplyAndLocationAndReminder) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadMultiTaskTestModel();
ActionSuggestionOptions options = GetOptionsToDisableAllClassification();
options.model_parameters[kLocationThreshold] =
libtextclassifier3::Variant(0.35f);
// The reminder head always triggers since its threshold is zero.
options.model_parameters[kReminderThreshold] =
libtextclassifier3::Variant(0.0f);
options.model_parameters.insert(
{kEmpiricalProbFactor, libtextclassifier3::Variant(2.0f)});
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{}, /*locales=*/"en"}}},
/*annotator=*/nullptr, options);
EXPECT_THAT(
response.actions,
ElementsAre(IsSmartReply("Okay"), IsActionOfType("LOCATION_SHARE"),
IsSmartReply("Yes"),
/*Different emoji than previous test*/ IsSmartReply("😊"),
IsActionOfType("REMINDER_INTENT")));
EXPECT_EQ(response.actions.size(),
          5 /*1 location share + 3 smart replies + 1 reminder*/);
}
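// The smart reply emoji model mixes emoji and plain text replies.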
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromMultiTaskSrEmojiModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kMultiTaskSrEmojiModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "hello?",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 5);
EXPECT_EQ(response.actions[0].response_text, "😁");
EXPECT_EQ(response.actions[0].type, "text_reply");
EXPECT_EQ(response.actions[1].response_text, "👋");
EXPECT_EQ(response.actions[1].type, "text_reply");
EXPECT_EQ(response.actions[2].response_text, "Yes");
EXPECT_EQ(response.actions[2].type, "text_reply");
}
TEST_F(ActionsSuggestionsTest, MultiTaskSrEmojiModelRemovesTextHeadEmoji) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kMultiTaskSrEmojiModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "a pleasure chatting",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 3);
EXPECT_EQ(response.actions[0].response_text, "😁");
EXPECT_EQ(response.actions[0].type, "text_reply");
EXPECT_EQ(response.actions[1].response_text, "😘");
EXPECT_EQ(response.actions[1].type, "text_reply");
EXPECT_EQ(response.actions[2].response_text, "Okay");
EXPECT_EQ(response.actions[2].type, "text_reply");
}
TEST_F(ActionsSuggestionsTest, LiveRelayModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kLiveRelayTFLiteModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "Hi",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 3);
EXPECT_EQ(response.actions[0].response_text, "Hi how are you doing");
EXPECT_EQ(response.actions[0].type, "text_reply");
EXPECT_EQ(response.actions[1].response_text, "Hi whats up");
EXPECT_EQ(response.actions[1].type, "text_reply");
}
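// For sensitive content the model is expected to suppress all suggestions and
// flag the conversation as sensitive rather than as low confidence.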
TEST_F(ActionsSuggestionsTest, SuggestsActionsFromSensitiveTfLiteModel) {
std::unique_ptr<ActionsSuggestions> actions_suggestions =
LoadTestModel(kSensitiveTFliteModelFileName);
const ActionsSuggestionsResponse response =
actions_suggestions->SuggestActions(
{{{/*user_id=*/1, "I want to kill myself",
/*reference_time_ms_utc=*/0,
/*reference_timezone=*/"Europe/Zurich",
/*annotations=*/{},
/*locales=*/"en"}}});
EXPECT_EQ(response.actions.size(), 0);
EXPECT_TRUE(response.is_sensitive);
EXPECT_FALSE(response.output_filtered_low_confidence);
}
} // namespace
} // namespace libtextclassifier3