| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "actions/actions-suggestions.h" |
| |
| #include <fstream> |
| #include <iterator> |
| #include <memory> |
| |
| #include "actions/actions_model_generated.h" |
| #include "actions/test_utils.h" |
| #include "actions/zlib-utils.h" |
| #include "annotator/collections.h" |
| #include "annotator/types.h" |
| #include "utils/flatbuffers.h" |
| #include "utils/flatbuffers_generated.h" |
| #include "utils/hash/farmhash.h" |
| #include "gmock/gmock.h" |
| #include "gtest/gtest.h" |
| #include "flatbuffers/flatbuffers.h" |
| #include "flatbuffers/reflection.h" |
| |
| namespace libtextclassifier3 { |
| namespace { |
| using testing::_; |
| |
| constexpr char kModelFileName[] = "actions_suggestions_test.model"; |
| constexpr char kHashGramModelFileName[] = |
| "actions_suggestions_test.hashgram.model"; |
| |
| std::string ReadFile(const std::string& file_name) { |
| std::ifstream file_stream(file_name); |
| return std::string(std::istreambuf_iterator<char>(file_stream), {}); |
| } |
| |
| std::string GetModelPath() { |
| return ""; |
| } |
| |
| class ActionsSuggestionsTest : public testing::Test { |
| protected: |
| ActionsSuggestionsTest() : INIT_UNILIB_FOR_TESTING(unilib_) {} |
| std::unique_ptr<ActionsSuggestions> LoadTestModel() { |
| return ActionsSuggestions::FromPath(GetModelPath() + kModelFileName, |
| &unilib_); |
| } |
| std::unique_ptr<ActionsSuggestions> LoadHashGramTestModel() { |
| return ActionsSuggestions::FromPath(GetModelPath() + kHashGramModelFileName, |
| &unilib_); |
| } |
| UniLib unilib_; |
| }; |
| |
| TEST_F(ActionsSuggestionsTest, InstantiateActionSuggestions) { |
| EXPECT_THAT(LoadTestModel(), testing::NotNull()); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActions) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| EXPECT_EQ(response.actions.size(), 3 /* share_location + 2 smart replies*/); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestNoActionsForUnknownLocale) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"zz"}}}); |
| EXPECT_THAT(response.actions, testing::IsEmpty()); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotations) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| AnnotatedSpan annotation; |
| annotation.span = {11, 15}; |
| annotation.classification = {ClassificationResult("address", 1.0)}; |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "are you at home?", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{annotation}, |
| /*locales=*/"en"}}}); |
| ASSERT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions.front().type, "view_map"); |
| EXPECT_EQ(response.actions.front().score, 1.0); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsFromAnnotationsWithEntityData) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| SetTestEntityDataSchema(actions_model.get()); |
| |
| // Set custom actions from annotations config. |
| actions_model->annotation_actions_spec->annotation_mapping.clear(); |
| actions_model->annotation_actions_spec->annotation_mapping.emplace_back( |
| new AnnotationActionsSpec_::AnnotationMappingT); |
| AnnotationActionsSpec_::AnnotationMappingT* mapping = |
| actions_model->annotation_actions_spec->annotation_mapping.back().get(); |
| mapping->annotation_collection = "address"; |
| mapping->action.reset(new ActionSuggestionSpecT); |
| mapping->action->type = "save_location"; |
| mapping->action->score = 1.0; |
| mapping->action->priority_score = 2.0; |
| mapping->entity_field.reset(new FlatbufferFieldPathT); |
| mapping->entity_field->field.emplace_back(new FlatbufferFieldT); |
| mapping->entity_field->field.back()->field_name = "location"; |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| |
| AnnotatedSpan annotation; |
| annotation.span = {11, 15}; |
| annotation.classification = {ClassificationResult("address", 1.0)}; |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "are you at home?", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{annotation}, |
| /*locales=*/"en"}}}); |
| ASSERT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions.front().type, "save_location"); |
| EXPECT_EQ(response.actions.front().score, 1.0); |
| |
| // Check that the `location` entity field holds the text from the address |
| // annotation. |
| const flatbuffers::Table* entity = |
| flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>( |
| response.actions.front().serialized_entity_data.data())); |
| EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(), |
| "home"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsFromDuplicatedAnnotations) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| AnnotatedSpan flight_annotation; |
| flight_annotation.span = {11, 15}; |
| flight_annotation.classification = {ClassificationResult("flight", 2.5)}; |
| AnnotatedSpan flight_annotation2; |
| flight_annotation2.span = {35, 39}; |
| flight_annotation2.classification = {ClassificationResult("flight", 3.0)}; |
| AnnotatedSpan email_annotation; |
| email_annotation.span = {55, 68}; |
| email_annotation.classification = {ClassificationResult("email", 2.0)}; |
| |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, |
| "call me at LX38 or send message to LX38 or test@test.com.", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/ |
| {flight_annotation, flight_annotation2, email_annotation}, |
| /*locales=*/"en"}}}); |
| |
| ASSERT_GE(response.actions.size(), 2); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[0].score, 3.0); |
| EXPECT_EQ(response.actions[1].type, "send_email"); |
| EXPECT_EQ(response.actions[1].score, 2.0); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsAnnotationsNoDeduplication) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| // Disable deduplication. |
| actions_model->annotation_actions_spec->deduplicate_annotations = false; |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| AnnotatedSpan flight_annotation; |
| flight_annotation.span = {11, 15}; |
| flight_annotation.classification = {ClassificationResult("flight", 2.5)}; |
| AnnotatedSpan flight_annotation2; |
| flight_annotation2.span = {35, 39}; |
| flight_annotation2.classification = {ClassificationResult("flight", 3.0)}; |
| AnnotatedSpan email_annotation; |
| email_annotation.span = {55, 68}; |
| email_annotation.classification = {ClassificationResult("email", 2.0)}; |
| |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, |
| "call me at LX38 or send message to LX38 or test@test.com.", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/ |
| {flight_annotation, flight_annotation2, email_annotation}, |
| /*locales=*/"en"}}}); |
| |
| ASSERT_GE(response.actions.size(), 3); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[0].score, 3.0); |
| EXPECT_EQ(response.actions[1].type, "track_flight"); |
| EXPECT_EQ(response.actions[1].score, 2.5); |
| EXPECT_EQ(response.actions[2].type, "send_email"); |
| EXPECT_EQ(response.actions[2].score, 2.0); |
| } |
| |
| ActionsSuggestionsResponse TestSuggestActionsFromAnnotations( |
| const std::function<void(ActionsModelT*)>& set_config_fn, |
| const UniLib* unilib = nullptr) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| |
| // Set custom config. |
| set_config_fn(actions_model.get()); |
| |
| // Disable smart reply for easier testing. |
| actions_model->preconditions->min_smart_reply_triggering_score = 1.0; |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), unilib); |
| |
| AnnotatedSpan flight_annotation; |
| flight_annotation.span = {15, 19}; |
| flight_annotation.classification = {ClassificationResult("flight", 2.0)}; |
| AnnotatedSpan email_annotation; |
| email_annotation.span = {0, 16}; |
| email_annotation.classification = {ClassificationResult("email", 1.0)}; |
| |
| return actions_suggestions->SuggestActions( |
| {{{/*user_id=*/ActionsSuggestions::kLocalUserId, |
| "hehe@android.com", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/ |
| {email_annotation}, |
| /*locales=*/"en"}, |
| {/*user_id=*/2, |
| "yoyo@android.com", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/ |
| {email_annotation}, |
| /*locales=*/"en"}, |
| {/*user_id=*/1, |
| "test@android.com", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/ |
| {email_annotation}, |
| /*locales=*/"en"}, |
| {/*user_id=*/1, |
| "I am on flight LX38.", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/ |
| {flight_annotation}, |
| /*locales=*/"en"}}}); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastMessage) { |
| const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( |
| [](ActionsModelT* actions_model) { |
| actions_model->annotation_actions_spec->include_local_user_messages = |
| false; |
| actions_model->annotation_actions_spec->only_until_last_sent = true; |
| actions_model->annotation_actions_spec->max_history_from_any_person = 1; |
| actions_model->annotation_actions_spec->max_history_from_last_person = |
| 1; |
| }, |
| &unilib_); |
| EXPECT_EQ(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsOnlyLastPerson) { |
| const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( |
| [](ActionsModelT* actions_model) { |
| actions_model->annotation_actions_spec->include_local_user_messages = |
| false; |
| actions_model->annotation_actions_spec->only_until_last_sent = true; |
| actions_model->annotation_actions_spec->max_history_from_any_person = 1; |
| actions_model->annotation_actions_spec->max_history_from_last_person = |
| 3; |
| }, |
| &unilib_); |
| EXPECT_EQ(response.actions.size(), 2); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[1].type, "send_email"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithAnnotationsFromAny) { |
| const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( |
| [](ActionsModelT* actions_model) { |
| actions_model->annotation_actions_spec->include_local_user_messages = |
| false; |
| actions_model->annotation_actions_spec->only_until_last_sent = true; |
| actions_model->annotation_actions_spec->max_history_from_any_person = 2; |
| actions_model->annotation_actions_spec->max_history_from_last_person = |
| 1; |
| }, |
| &unilib_); |
| EXPECT_EQ(response.actions.size(), 2); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[1].type, "send_email"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, |
| SuggestActionsWithAnnotationsFromAnyManyMessages) { |
| const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( |
| [](ActionsModelT* actions_model) { |
| actions_model->annotation_actions_spec->include_local_user_messages = |
| false; |
| actions_model->annotation_actions_spec->only_until_last_sent = true; |
| actions_model->annotation_actions_spec->max_history_from_any_person = 3; |
| actions_model->annotation_actions_spec->max_history_from_last_person = |
| 1; |
| }, |
| &unilib_); |
| EXPECT_EQ(response.actions.size(), 3); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[1].type, "send_email"); |
| EXPECT_EQ(response.actions[2].type, "send_email"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, |
| SuggestActionsWithAnnotationsFromAnyManyMessagesButNotLocalUser) { |
| const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( |
| [](ActionsModelT* actions_model) { |
| actions_model->annotation_actions_spec->include_local_user_messages = |
| false; |
| actions_model->annotation_actions_spec->only_until_last_sent = true; |
| actions_model->annotation_actions_spec->max_history_from_any_person = 5; |
| actions_model->annotation_actions_spec->max_history_from_last_person = |
| 1; |
| }, |
| &unilib_); |
| EXPECT_EQ(response.actions.size(), 3); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[1].type, "send_email"); |
| EXPECT_EQ(response.actions[2].type, "send_email"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, |
| SuggestActionsWithAnnotationsFromAnyManyMessagesAlsoFromLocalUser) { |
| const ActionsSuggestionsResponse response = TestSuggestActionsFromAnnotations( |
| [](ActionsModelT* actions_model) { |
| actions_model->annotation_actions_spec->include_local_user_messages = |
| true; |
| actions_model->annotation_actions_spec->only_until_last_sent = false; |
| actions_model->annotation_actions_spec->max_history_from_any_person = 5; |
| actions_model->annotation_actions_spec->max_history_from_last_person = |
| 1; |
| }, |
| &unilib_); |
| EXPECT_EQ(response.actions.size(), 4); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[1].type, "send_email"); |
| EXPECT_EQ(response.actions[2].type, "send_email"); |
| EXPECT_EQ(response.actions[3].type, "send_email"); |
| } |
| |
| void TestSuggestActionsWithThreshold( |
| const std::function<void(ActionsModelT*)>& set_value_fn, |
| const UniLib* unilib = nullptr, const int expected_size = 0, |
| const std::string& preconditions_overwrite = "") { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| set_value_fn(actions_model.get()); |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), unilib, preconditions_overwrite); |
| ASSERT_TRUE(actions_suggestions); |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "I have the low-ground. Where are you?", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| EXPECT_LE(response.actions.size(), expected_size); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithTriggeringScore) { |
| TestSuggestActionsWithThreshold( |
| [](ActionsModelT* actions_model) { |
| actions_model->preconditions->min_smart_reply_triggering_score = 1.0; |
| }, |
| &unilib_, |
| /*expected_size=*/1 /*no smart reply, only actions*/ |
| ); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinReplyScore) { |
| TestSuggestActionsWithThreshold( |
| [](ActionsModelT* actions_model) { |
| actions_model->preconditions->min_reply_score_threshold = 1.0; |
| }, |
| &unilib_, |
| /*expected_size=*/1 /*no smart reply, only actions*/ |
| ); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithSensitiveTopicScore) { |
| TestSuggestActionsWithThreshold( |
| [](ActionsModelT* actions_model) { |
| actions_model->preconditions->max_sensitive_topic_score = 0.0; |
| }, |
| &unilib_, |
| /*expected_size=*/4 /* no sensitive prediction in test model*/); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithMaxInputLength) { |
| TestSuggestActionsWithThreshold( |
| [](ActionsModelT* actions_model) { |
| actions_model->preconditions->max_input_length = 0; |
| }, |
| &unilib_); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithMinInputLength) { |
| TestSuggestActionsWithThreshold( |
| [](ActionsModelT* actions_model) { |
| actions_model->preconditions->min_input_length = 100; |
| }, |
| &unilib_); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithPreconditionsOverwrite) { |
| TriggeringPreconditionsT preconditions_overwrite; |
| preconditions_overwrite.max_input_length = 0; |
| flatbuffers::FlatBufferBuilder builder; |
| builder.Finish( |
| TriggeringPreconditions::Pack(builder, &preconditions_overwrite)); |
| TestSuggestActionsWithThreshold( |
| // Keep model untouched. |
| [](ActionsModelT* actions_model) {}, &unilib_, |
| /*expected_size=*/0, |
| std::string(reinterpret_cast<const char*>(builder.GetBufferPointer()), |
| builder.GetSize())); |
| } |
| |
| #ifdef TC3_UNILIB_ICU |
| TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidence) { |
| TestSuggestActionsWithThreshold( |
| [](ActionsModelT* actions_model) { |
| actions_model->preconditions->suppress_on_low_confidence_input = true; |
| actions_model->low_confidence_rules.reset(new RulesModelT); |
| actions_model->low_confidence_rules->rule.emplace_back( |
| new RulesModel_::RuleT); |
| actions_model->low_confidence_rules->rule.back()->pattern = |
| "low-ground"; |
| }, |
| &unilib_); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsLowConfidenceInputOutput) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| // Add custom triggering rule. |
| actions_model->rules.reset(new RulesModelT()); |
| actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); |
| RulesModel_::RuleT* rule = actions_model->rules->rule.back().get(); |
| rule->pattern = "^(?i:hello\\s(there))$"; |
| { |
| std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action( |
| new RulesModel_::Rule_::RuleActionSpecT); |
| rule_action->action.reset(new ActionSuggestionSpecT); |
| rule_action->action->type = "text_reply"; |
| rule_action->action->response_text = "General Desaster!"; |
| rule_action->action->score = 1.0f; |
| rule_action->action->priority_score = 1.0f; |
| rule->actions.push_back(std::move(rule_action)); |
| } |
| { |
| std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action( |
| new RulesModel_::Rule_::RuleActionSpecT); |
| rule_action->action.reset(new ActionSuggestionSpecT); |
| rule_action->action->type = "text_reply"; |
| rule_action->action->response_text = "General Kenobi!"; |
| rule_action->action->score = 1.0f; |
| rule_action->action->priority_score = 1.0f; |
| rule->actions.push_back(std::move(rule_action)); |
| } |
| |
| // Add input-output low confidence rule. |
| actions_model->preconditions->suppress_on_low_confidence_input = true; |
| actions_model->low_confidence_rules.reset(new RulesModelT); |
| actions_model->low_confidence_rules->rule.emplace_back( |
| new RulesModel_::RuleT); |
| actions_model->low_confidence_rules->rule.back()->pattern = "hello"; |
| actions_model->low_confidence_rules->rule.back()->output_pattern = |
| "(?i:desaster)"; |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| ASSERT_TRUE(actions_suggestions); |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "hello there", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| ASSERT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].response_text, "General Kenobi!"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, |
| SuggestActionsLowConfidenceInputOutputOverwrite) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| actions_model->low_confidence_rules.reset(); |
| |
| // Add custom triggering rule. |
| actions_model->rules.reset(new RulesModelT()); |
| actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); |
| RulesModel_::RuleT* rule = actions_model->rules->rule.back().get(); |
| rule->pattern = "^(?i:hello\\s(there))$"; |
| { |
| std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action( |
| new RulesModel_::Rule_::RuleActionSpecT); |
| rule_action->action.reset(new ActionSuggestionSpecT); |
| rule_action->action->type = "text_reply"; |
| rule_action->action->response_text = "General Desaster!"; |
| rule_action->action->score = 1.0f; |
| rule_action->action->priority_score = 1.0f; |
| rule->actions.push_back(std::move(rule_action)); |
| } |
| { |
| std::unique_ptr<RulesModel_::Rule_::RuleActionSpecT> rule_action( |
| new RulesModel_::Rule_::RuleActionSpecT); |
| rule_action->action.reset(new ActionSuggestionSpecT); |
| rule_action->action->type = "text_reply"; |
| rule_action->action->response_text = "General Kenobi!"; |
| rule_action->action->score = 1.0f; |
| rule_action->action->priority_score = 1.0f; |
| rule->actions.push_back(std::move(rule_action)); |
| } |
| |
| // Add custom triggering rule via overwrite. |
| actions_model->preconditions->low_confidence_rules.reset(); |
| TriggeringPreconditionsT preconditions; |
| preconditions.suppress_on_low_confidence_input = true; |
| preconditions.low_confidence_rules.reset(new RulesModelT); |
| preconditions.low_confidence_rules->rule.emplace_back(new RulesModel_::RuleT); |
| preconditions.low_confidence_rules->rule.back()->pattern = "hello"; |
| preconditions.low_confidence_rules->rule.back()->output_pattern = |
| "(?i:desaster)"; |
| flatbuffers::FlatBufferBuilder preconditions_builder; |
| preconditions_builder.Finish( |
| TriggeringPreconditions::Pack(preconditions_builder, &preconditions)); |
| std::string serialize_preconditions = std::string( |
| reinterpret_cast<const char*>(preconditions_builder.GetBufferPointer()), |
| preconditions_builder.GetSize()); |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_, serialize_preconditions); |
| |
| ASSERT_TRUE(actions_suggestions); |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "hello there", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| ASSERT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].response_text, "General Kenobi!"); |
| } |
| #endif |
| |
| TEST_F(ActionsSuggestionsTest, SuppressActionsFromAnnotationsOnSensitiveTopic) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| |
| // Don't test if no sensitivity score is produced |
| if (actions_model->tflite_model_spec->output_sensitive_topic_score < 0) { |
| return; |
| } |
| |
| actions_model->preconditions->max_sensitive_topic_score = 0.0; |
| actions_model->preconditions->suppress_on_sensitive_topic = true; |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| AnnotatedSpan annotation; |
| annotation.span = {11, 15}; |
| annotation.classification = { |
| ClassificationResult(Collections::Address(), 1.0)}; |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "are you at home?", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{annotation}, |
| /*locales=*/"en"}}}); |
| EXPECT_THAT(response.actions, testing::IsEmpty()); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithLongerConversation) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| |
| // Allow a larger conversation context. |
| actions_model->max_conversation_history_length = 10; |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| AnnotatedSpan annotation; |
| annotation.span = {11, 15}; |
| annotation.classification = { |
| ClassificationResult(Collections::Address(), 1.0)}; |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/ActionsSuggestions::kLocalUserId, "hi, how are you?", |
| /*reference_time_ms_utc=*/10000, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}, |
| {/*user_id=*/1, "good! are you at home?", |
| /*reference_time_ms_utc=*/15000, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{annotation}, |
| /*locales=*/"en"}}}); |
| ASSERT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].type, "view_map"); |
| EXPECT_EQ(response.actions[0].score, 1.0); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, CreateActionsFromClassificationResult) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| AnnotatedSpan annotation; |
| annotation.span = {8, 12}; |
| annotation.classification = { |
| ClassificationResult(Collections::Flight(), 1.0)}; |
| |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "I'm on LX38?", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{annotation}, |
| /*locales=*/"en"}}}); |
| |
| ASSERT_GE(response.actions.size(), 2); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| EXPECT_EQ(response.actions[0].score, 1.0); |
| EXPECT_EQ(response.actions[0].annotations.size(), 1); |
| EXPECT_EQ(response.actions[0].annotations[0].span.message_index, 0); |
| EXPECT_EQ(response.actions[0].annotations[0].span.span, annotation.span); |
| } |
| |
| #ifdef TC3_UNILIB_ICU |
| TEST_F(ActionsSuggestionsTest, CreateActionsFromRules) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| ASSERT_TRUE(DecompressActionsModel(actions_model.get())); |
| |
| actions_model->rules.reset(new RulesModelT()); |
| actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); |
| RulesModel_::RuleT* rule = actions_model->rules->rule.back().get(); |
| rule->pattern = "^(?i:hello\\s(there))$"; |
| rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT); |
| rule->actions.back()->action.reset(new ActionSuggestionSpecT); |
| ActionSuggestionSpecT* action = rule->actions.back()->action.get(); |
| action->type = "text_reply"; |
| action->response_text = "General Kenobi!"; |
| action->score = 1.0f; |
| action->priority_score = 1.0f; |
| |
| // Set capturing groups for entity data. |
| rule->actions.back()->capturing_group.emplace_back( |
| new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); |
| RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* greeting_group = |
| rule->actions.back()->capturing_group.back().get(); |
| greeting_group->group_id = 0; |
| greeting_group->entity_field.reset(new FlatbufferFieldPathT); |
| greeting_group->entity_field->field.emplace_back(new FlatbufferFieldT); |
| greeting_group->entity_field->field.back()->field_name = "greeting"; |
| rule->actions.back()->capturing_group.emplace_back( |
| new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); |
| RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* location_group = |
| rule->actions.back()->capturing_group.back().get(); |
| location_group->group_id = 1; |
| location_group->entity_field.reset(new FlatbufferFieldPathT); |
| location_group->entity_field->field.emplace_back(new FlatbufferFieldT); |
| location_group->entity_field->field.back()->field_name = "location"; |
| |
| // Set test entity data schema. |
| SetTestEntityDataSchema(actions_model.get()); |
| |
| // Use meta data to generate custom serialized entity data. |
| ReflectiveFlatbufferBuilder entity_data_builder( |
| flatbuffers::GetRoot<reflection::Schema>( |
| actions_model->actions_entity_data_schema.data())); |
| std::unique_ptr<ReflectiveFlatbuffer> entity_data = |
| entity_data_builder.NewRoot(); |
| entity_data->Set("person", "Kenobi"); |
| action->serialized_entity_data = entity_data->Serialize(); |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "hello there", /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| EXPECT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].response_text, "General Kenobi!"); |
| |
| // Check entity data. |
| const flatbuffers::Table* entity = |
| flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>( |
| response.actions[0].serialized_entity_data.data())); |
| EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/4)->str(), |
| "hello there"); |
| EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/6)->str(), |
| "there"); |
| EXPECT_EQ(entity->GetPointer<const flatbuffers::String*>(/*field=*/8)->str(), |
| "Kenobi"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, CreatesTextRepliesFromRules) { |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| ASSERT_TRUE(DecompressActionsModel(actions_model.get())); |
| |
| actions_model->rules.reset(new RulesModelT()); |
| actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); |
| RulesModel_::RuleT* rule = actions_model->rules->rule.back().get(); |
| rule->pattern = "(?i:reply (stop|quit|end) (?:to|for) )"; |
| rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT); |
| |
| // Set capturing groups for entity data. |
| rule->actions.back()->capturing_group.emplace_back( |
| new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); |
| RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group = |
| rule->actions.back()->capturing_group.back().get(); |
| code_group->group_id = 1; |
| code_group->text_reply.reset(new ActionSuggestionSpecT); |
| code_group->text_reply->score = 1.0f; |
| code_group->text_reply->priority_score = 1.0f; |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, |
| "visit test.com or reply STOP to cancel your subscription", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| EXPECT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].response_text, "STOP"); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, DeduplicateActions) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| |
| // Check that the location sharing model triggered. |
| bool has_location_sharing_action = false; |
| for (const ActionSuggestion action : response.actions) { |
| if (action.type == ActionsSuggestions::kShareLocation) { |
| has_location_sharing_action = true; |
| break; |
| } |
| } |
| EXPECT_TRUE(has_location_sharing_action); |
| const int num_actions = response.actions.size(); |
| |
| // Add custom rule for location sharing. |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| ASSERT_TRUE(DecompressActionsModel(actions_model.get())); |
| |
| actions_model->rules.reset(new RulesModelT()); |
| actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); |
| actions_model->rules->rule.back()->pattern = "^(?i:where are you[.?]?)$"; |
| actions_model->rules->rule.back()->actions.emplace_back( |
| new RulesModel_::Rule_::RuleActionSpecT); |
| actions_model->rules->rule.back()->actions.back()->action.reset( |
| new ActionSuggestionSpecT); |
| ActionSuggestionSpecT* action = |
| actions_model->rules->rule.back()->actions.back()->action.get(); |
| action->score = 1.0f; |
| action->type = ActionsSuggestions::kShareLocation; |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| actions_suggestions = ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| |
| response = actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "Where are you?", /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, /*locales=*/"en"}}}); |
| EXPECT_EQ(response.actions.size(), num_actions); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, DeduplicateConflictingActions) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| AnnotatedSpan annotation; |
| annotation.span = {7, 11}; |
| annotation.classification = { |
| ClassificationResult(Collections::Flight(), 1.0)}; |
| ActionsSuggestionsResponse response = actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "I'm on LX38", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{annotation}, |
| /*locales=*/"en"}}}); |
| |
| // Check that the phone actions are present. |
| EXPECT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].type, "track_flight"); |
| |
| // Add custom rule. |
| const std::string actions_model_string = |
| ReadFile(GetModelPath() + kModelFileName); |
| std::unique_ptr<ActionsModelT> actions_model = |
| UnPackActionsModel(actions_model_string.c_str()); |
| ASSERT_TRUE(DecompressActionsModel(actions_model.get())); |
| |
| actions_model->rules.reset(new RulesModelT()); |
| actions_model->rules->rule.emplace_back(new RulesModel_::RuleT); |
| RulesModel_::RuleT* rule = actions_model->rules->rule.back().get(); |
| rule->pattern = "^(?i:I'm on ([a-z0-9]+))$"; |
| rule->actions.emplace_back(new RulesModel_::Rule_::RuleActionSpecT); |
| rule->actions.back()->action.reset(new ActionSuggestionSpecT); |
| ActionSuggestionSpecT* action = rule->actions.back()->action.get(); |
| action->score = 1.0f; |
| action->priority_score = 2.0f; |
| action->type = "test_code"; |
| rule->actions.back()->capturing_group.emplace_back( |
| new RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT); |
| RulesModel_::Rule_::RuleActionSpec_::RuleCapturingGroupT* code_group = |
| rule->actions.back()->capturing_group.back().get(); |
| code_group->group_id = 1; |
| code_group->annotation_name = "code"; |
| code_group->annotation_type = "code"; |
| |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, |
| ActionsModel::Pack(builder, actions_model.get())); |
| actions_suggestions = ActionsSuggestions::FromUnownedBuffer( |
| reinterpret_cast<const uint8_t*>(builder.GetBufferPointer()), |
| builder.GetSize(), &unilib_); |
| |
| response = actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "I'm on LX38", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{annotation}, |
| /*locales=*/"en"}}}); |
| EXPECT_GE(response.actions.size(), 1); |
| EXPECT_EQ(response.actions[0].type, "test_code"); |
| } |
| #endif |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsRanking) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = LoadTestModel(); |
| std::vector<AnnotatedSpan> annotations(2); |
| annotations[0].span = {11, 15}; |
| annotations[0].classification = {ClassificationResult("address", 1.0)}; |
| annotations[1].span = {19, 23}; |
| annotations[1].classification = {ClassificationResult("address", 2.0)}; |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "are you at home or work?", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/annotations, |
| /*locales=*/"en"}}}); |
| EXPECT_GE(response.actions.size(), 2); |
| EXPECT_EQ(response.actions[0].type, "view_map"); |
| EXPECT_EQ(response.actions[0].score, 2.0); |
| EXPECT_EQ(response.actions[1].type, "view_map"); |
| EXPECT_EQ(response.actions[1].score, 1.0); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, VisitActionsModel) { |
| EXPECT_TRUE(VisitActionsModel<bool>(GetModelPath() + kModelFileName, |
| [](const ActionsModel* model) { |
| if (model == nullptr) { |
| return false; |
| } |
| return true; |
| })); |
| EXPECT_FALSE(VisitActionsModel<bool>(GetModelPath() + "non_existing_model.fb", |
| [](const ActionsModel* model) { |
| if (model == nullptr) { |
| return false; |
| } |
| return true; |
| })); |
| } |
| |
| TEST_F(ActionsSuggestionsTest, SuggestActionsWithHashGramModel) { |
| std::unique_ptr<ActionsSuggestions> actions_suggestions = |
| LoadHashGramTestModel(); |
| ASSERT_TRUE(actions_suggestions != nullptr); |
| { |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "hello", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, |
| /*locales=*/"en"}}}); |
| EXPECT_THAT(response.actions, testing::IsEmpty()); |
| } |
| { |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "where are you", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, |
| /*locales=*/"en"}}}); |
| EXPECT_THAT( |
| response.actions, |
| ElementsAre(testing::Field(&ActionSuggestion::type, "share_location"))); |
| } |
| { |
| const ActionsSuggestionsResponse& response = |
| actions_suggestions->SuggestActions( |
| {{{/*user_id=*/1, "do you know johns number", |
| /*reference_time_ms_utc=*/0, |
| /*reference_timezone=*/"Europe/Zurich", |
| /*annotations=*/{}, |
| /*locales=*/"en"}}}); |
| EXPECT_THAT( |
| response.actions, |
| ElementsAre(testing::Field(&ActionSuggestion::type, "share_contact"))); |
| } |
| } |
| |
| // Test class to expose token embedding methods for testing. |
| class TestingMessageEmbedder : private ActionsSuggestions { |
| public: |
| explicit TestingMessageEmbedder(const ActionsModel* model); |
| |
| using ActionsSuggestions::EmbedAndFlattenTokens; |
| using ActionsSuggestions::EmbedTokensPerMessage; |
| |
| protected: |
| // EmbeddingExecutor that always returns features based on |
| // the id of the sparse features. |
| class FakeEmbeddingExecutor : public EmbeddingExecutor { |
| public: |
| bool AddEmbedding(const TensorView<int>& sparse_features, float* dest, |
| const int dest_size) const override { |
| TC3_CHECK_GE(dest_size, 1); |
| EXPECT_EQ(sparse_features.size(), 1); |
| dest[0] = sparse_features.data()[0]; |
| return true; |
| } |
| }; |
| }; |
| |
| TestingMessageEmbedder::TestingMessageEmbedder(const ActionsModel* model) { |
| model_ = model; |
| const ActionsTokenFeatureProcessorOptions* options = |
| model->feature_processor_options(); |
| feature_processor_.reset( |
| new ActionsFeatureProcessor(options, /*unilib=*/nullptr)); |
| embedding_executor_.reset(new FakeEmbeddingExecutor()); |
| EXPECT_TRUE( |
| EmbedTokenId(options->padding_token_id(), &embedded_padding_token_)); |
| EXPECT_TRUE(EmbedTokenId(options->start_token_id(), &embedded_start_token_)); |
| EXPECT_TRUE(EmbedTokenId(options->end_token_id(), &embedded_end_token_)); |
| token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize(); |
| EXPECT_EQ(token_embedding_size_, 1); |
| } |
| |
| class EmbeddingTest : public testing::Test { |
| protected: |
| EmbeddingTest() { |
| model_.feature_processor_options.reset( |
| new ActionsTokenFeatureProcessorOptionsT); |
| options_ = model_.feature_processor_options.get(); |
| options_->chargram_orders = {1}; |
| options_->num_buckets = 1000; |
| options_->embedding_size = 1; |
| options_->start_token_id = 0; |
| options_->end_token_id = 1; |
| options_->padding_token_id = 2; |
| options_->tokenizer_options.reset(new ActionsTokenizerOptionsT); |
| } |
| |
| TestingMessageEmbedder CreateTestingMessageEmbedder() { |
| flatbuffers::FlatBufferBuilder builder; |
| FinishActionsModelBuffer(builder, ActionsModel::Pack(builder, &model_)); |
| buffer_ = builder.ReleaseBufferPointer(); |
| return TestingMessageEmbedder( |
| flatbuffers::GetRoot<ActionsModel>(buffer_.data())); |
| } |
| |
| flatbuffers::DetachedBuffer buffer_; |
| ActionsModelT model_; |
| ActionsTokenFeatureProcessorOptionsT* options_; |
| }; |
| |
| TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithNoBounds) { |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}}; |
| std::vector<float> embeddings; |
| int max_num_tokens_per_message = 0; |
| |
| EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings, |
| &max_num_tokens_per_message)); |
| |
| EXPECT_EQ(max_num_tokens_per_message, 3); |
| EXPECT_EQ(embeddings.size(), 3); |
| EXPECT_THAT(embeddings[0], |
| testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[2], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| } |
| |
| TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithPadding) { |
| options_->min_num_tokens_per_message = 5; |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}}; |
| std::vector<float> embeddings; |
| int max_num_tokens_per_message = 0; |
| |
| EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings, |
| &max_num_tokens_per_message)); |
| |
| EXPECT_EQ(max_num_tokens_per_message, 5); |
| EXPECT_EQ(embeddings.size(), 5); |
| EXPECT_THAT(embeddings[0], |
| testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[2], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[3], testing::FloatEq(options_->padding_token_id)); |
| EXPECT_THAT(embeddings[4], testing::FloatEq(options_->padding_token_id)); |
| } |
| |
| TEST_F(EmbeddingTest, EmbedsTokensPerMessageDropsAtBeginning) { |
| options_->max_num_tokens_per_message = 2; |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}}; |
| std::vector<float> embeddings; |
| int max_num_tokens_per_message = 0; |
| |
| EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings, |
| &max_num_tokens_per_message)); |
| |
| EXPECT_EQ(max_num_tokens_per_message, 2); |
| EXPECT_EQ(embeddings.size(), 2); |
| EXPECT_THAT(embeddings[0], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| } |
| |
| TEST_F(EmbeddingTest, EmbedsTokensPerMessageWithMultipleMessagesNoBounds) { |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}, |
| {Token("d", 0, 1), Token("e", 2, 3)}}; |
| std::vector<float> embeddings; |
| int max_num_tokens_per_message = 0; |
| |
| EXPECT_TRUE(embedder.EmbedTokensPerMessage(tokens, &embeddings, |
| &max_num_tokens_per_message)); |
| |
| EXPECT_EQ(max_num_tokens_per_message, 3); |
| EXPECT_THAT(embeddings[0], |
| testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[2], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[3], |
| testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[4], |
| testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id)); |
| } |
| |
| TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithNoBounds) { |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}}; |
| std::vector<float> embeddings; |
| int total_token_count = 0; |
| |
| EXPECT_TRUE( |
| embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count)); |
| |
| EXPECT_EQ(total_token_count, 5); |
| EXPECT_EQ(embeddings.size(), 5); |
| EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[2], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[3], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id)); |
| } |
| |
| TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithPadding) { |
| options_->min_num_total_tokens = 7; |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}}; |
| std::vector<float> embeddings; |
| int total_token_count = 0; |
| |
| EXPECT_TRUE( |
| embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count)); |
| |
| EXPECT_EQ(total_token_count, 7); |
| EXPECT_EQ(embeddings.size(), 7); |
| EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[2], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[3], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id)); |
| EXPECT_THAT(embeddings[5], testing::FloatEq(options_->padding_token_id)); |
| EXPECT_THAT(embeddings[6], testing::FloatEq(options_->padding_token_id)); |
| } |
| |
| TEST_F(EmbeddingTest, EmbedsFlattenedTokensDropsAtBeginning) { |
| options_->max_num_total_tokens = 3; |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}}; |
| std::vector<float> embeddings; |
| int total_token_count = 0; |
| |
| EXPECT_TRUE( |
| embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count)); |
| |
| EXPECT_EQ(total_token_count, 3); |
| EXPECT_EQ(embeddings.size(), 3); |
| EXPECT_THAT(embeddings[0], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[2], testing::FloatEq(options_->end_token_id)); |
| } |
| |
| TEST_F(EmbeddingTest, EmbedsFlattenedTokensWithMultipleMessagesNoBounds) { |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}, |
| {Token("d", 0, 1), Token("e", 2, 3)}}; |
| std::vector<float> embeddings; |
| int total_token_count = 0; |
| |
| EXPECT_TRUE( |
| embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count)); |
| |
| EXPECT_EQ(total_token_count, 9); |
| EXPECT_EQ(embeddings.size(), 9); |
| EXPECT_THAT(embeddings[0], testing::FloatEq(options_->start_token_id)); |
| EXPECT_THAT(embeddings[1], |
| testing::FloatEq(tc3farmhash::Fingerprint64("a", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[2], |
| testing::FloatEq(tc3farmhash::Fingerprint64("b", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[3], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[4], testing::FloatEq(options_->end_token_id)); |
| EXPECT_THAT(embeddings[5], testing::FloatEq(options_->start_token_id)); |
| EXPECT_THAT(embeddings[6], |
| testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[7], |
| testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[8], testing::FloatEq(options_->end_token_id)); |
| } |
| |
| TEST_F(EmbeddingTest, |
| EmbedsFlattenedTokensWithMultipleMessagesDropsAtBeginning) { |
| options_->max_num_total_tokens = 7; |
| const TestingMessageEmbedder embedder = CreateTestingMessageEmbedder(); |
| std::vector<std::vector<Token>> tokens = { |
| {Token("a", 0, 1), Token("b", 2, 3), Token("c", 4, 5)}, |
| {Token("d", 0, 1), Token("e", 2, 3), Token("f", 4, 5)}}; |
| std::vector<float> embeddings; |
| int total_token_count = 0; |
| |
| EXPECT_TRUE( |
| embedder.EmbedAndFlattenTokens(tokens, &embeddings, &total_token_count)); |
| |
| EXPECT_EQ(total_token_count, 7); |
| EXPECT_EQ(embeddings.size(), 7); |
| EXPECT_THAT(embeddings[0], |
| testing::FloatEq(tc3farmhash::Fingerprint64("c", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[1], testing::FloatEq(options_->end_token_id)); |
| EXPECT_THAT(embeddings[2], testing::FloatEq(options_->start_token_id)); |
| EXPECT_THAT(embeddings[3], |
| testing::FloatEq(tc3farmhash::Fingerprint64("d", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[4], |
| testing::FloatEq(tc3farmhash::Fingerprint64("e", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[5], |
| testing::FloatEq(tc3farmhash::Fingerprint64("f", 1) % |
| options_->num_buckets)); |
| EXPECT_THAT(embeddings[6], testing::FloatEq(options_->end_token_id)); |
| } |
| |
| } // namespace |
| } // namespace libtextclassifier3 |