// blob: 72cae2c2b25091c60baeab63d73496dfac3f2fab [file] [log] [blame] [edit]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/lua-actions.h"

#include <map>
#include <string>
#include <vector>

#include "actions/test-utils.h"
#include "actions/types.h"
#include "utils/tflite-model-executor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
namespace {
using testing::ElementsAre;
// Runs a snippet that unconditionally returns one action table against an
// empty conversation and expects exactly one suggestion matching
// IsActionOfType("test_action").
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
return {{ type = "test_action" }}
)";
  std::vector<ActionSuggestion> actions;
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Abort this test with a failure rather than dereferencing a null result
  // should the factory be unable to build the suggestions object.
  ASSERT_TRUE(suggestions != nullptr);
  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
// Feeds a two-message conversation to a snippet that scans adjacent message
// pairs and emits a "text_reply" action for each known exchange. Only the
// "hello there!" / "general kenobi!" pair is present, so exactly one reply is
// expected.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
local actions = {}
for i, message in pairs(messages) do
if i < #messages then
if message.text == "hello there!" and
messages[i+1].text == "general kenobi!" then
table.insert(actions, {
type = "text_reply",
response_text = "you are a bold one!"
})
end
if message.text == "i am the senate!" and
messages[i+1].text == "not yet!" then
table.insert(actions, {
type = "text_reply",
response_text = "it's treason then"
})
end
end
end
return actions;
)";
  std::vector<ActionSuggestion> actions;
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Abort this test with a failure rather than dereferencing a null result
  // should the factory be unable to build the suggestions object.
  ASSERT_TRUE(suggestions != nullptr);
  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsSmartReply("you are a bold one!")));
}
// Checks that the snippet can read `model.actions_scores`: with no model
// executor, spec or interpreter supplied, the test expects that table to be
// empty, which makes the snippet return a single "test_action".
TEST(LuaActions, SimpleModelAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
if #model.actions_scores == 0 then
return {{ type = "test_action" }}
end
return {}
)";
  std::vector<ActionSuggestion> actions;
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Abort this test with a failure rather than dereferencing a null result
  // should the factory be unable to build the suggestions object.
  ASSERT_TRUE(suggestions != nullptr);
  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
// Checks that the snippet can read `model.reply`: with no model executor,
// spec or interpreter supplied, the test expects that table to be empty,
// which makes the snippet return a single "test_action".
TEST(LuaActions, SimpleModelRepliesAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
if #model.reply == 0 then
return {{ type = "test_action" }}
end
return {}
)";
  std::vector<ActionSuggestion> actions;
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Abort this test with a failure rather than dereferencing a null result
  // should the factory be unable to build the suggestions object.
  ASSERT_TRUE(suggestions != nullptr);
  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  EXPECT_THAT(actions, ElementsAre(IsActionOfType("test_action")));
}
// Builds a one-message conversation carrying an "address" annotation over
// characters [11, 15) of "are you at home?" and runs a snippet that turns
// every address annotation of the last message into a "text_reply" whose text
// embeds the annotated substring. Expects the reply "i am at home" and that
// the classification passed through as `entity` keeps its collection name.
TEST(LuaActions, AnnotationActions) {
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  Conversation conversation = {{{/*user_id=*/1, "are you at home?",
                                 /*reference_time_ms_utc=*/0,
                                 /*reference_timezone=*/"Europe/Zurich",
                                 /*annotations=*/{annotation},
                                 /*locales=*/"en"}}};
  const std::string test_snippet = R"(
local actions = {}
local last_message = messages[#messages]
for i, annotation in pairs(last_message.annotation) do
if #annotation.classification > 0 then
if annotation.classification[1].collection == "address" then
local text = string.sub(last_message.text,
annotation.span["begin"] + 1,
annotation.span["end"])
table.insert(actions, {
type = "text_reply",
response_text = "i am at " .. text,
annotation = {{
name = "location",
span = {
text = text
},
entity = annotation.classification[1]
}},
})
end
end
end
return actions;
)";
  std::vector<ActionSuggestion> actions;
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  // Abort this test with a failure rather than dereferencing a null result
  // should the factory be unable to build the suggestions object.
  ASSERT_TRUE(suggestions != nullptr);
  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  // Fatal on purpose: the checks below index into `actions`, which would be
  // out of bounds if no suggestion had been produced.
  ASSERT_THAT(actions, ElementsAre(IsSmartReply("i am at home")));
  ASSERT_FALSE(actions[0].annotations.empty());
  EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
}
// Runs a snippet that returns one action with an `entity` table while an
// actions entity-data schema is supplied, then reads the serialized entity
// data back as a flatbuffer table and checks the three string fields.
TEST(LuaActions, EntityData) {
  const std::string test_schema = TestEntityDataSchema();
  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
  const std::string test_snippet = R"(
return {{
type = "test",
entity = {
greeting = "hello",
location = "there",
person = "Kenobi",
},
}};
)";
  std::vector<ActionSuggestion> actions;
  const auto suggestions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/
      flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
      /*annotations_entity_data_schema=*/nullptr);
  // Abort this test with a failure rather than dereferencing a null result
  // should the factory be unable to build the suggestions object.
  ASSERT_TRUE(suggestions != nullptr);
  EXPECT_TRUE(suggestions->SuggestActions(&actions));
  // Fatal on purpose: everything below reads `actions.front()`.
  ASSERT_THAT(actions, testing::SizeIs(1));
  EXPECT_EQ("test", actions.front().type);
  // Parsing an empty buffer as a flatbuffer would read past its end.
  ASSERT_FALSE(actions.front().serialized_entity_data.empty());
  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          actions.front().serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);
  // 4, 6 and 8 are raw vtable offsets; presumably they address the schema's
  // first three fields (greeting, location, person) -- confirm against the
  // schema produced by TestEntityDataSchema(). Each pointer is checked before
  // use so that an absent field fails the test instead of crashing it.
  const flatbuffers::String* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_TRUE(greeting != nullptr);
  EXPECT_EQ(greeting->str(), "hello");
  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_TRUE(location != nullptr);
  EXPECT_EQ(location->str(), "there");
  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(person->str(), "Kenobi");
}
} // namespace
} // namespace libtextclassifier3