/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/lua-actions.h"
#include <map>
#include <string>
#include "actions/test_utils.h"
#include "actions/types.h"
#include "utils/tflite-model-executor.h"
#include "gmock/gmock.h"
#include "gtest/gtest.h"
namespace libtextclassifier3 {
namespace {
// Matches an object whose `type` and `response_text` members both match the
// supplied values (or matchers).
MATCHER_P2(IsAction, type, response_text, "") {
  if (!testing::Value(arg.type, type)) {
    return false;
  }
  return testing::Value(arg.response_text, response_text);
}
// Matches an object by its `type` member alone; other members are ignored.
MATCHER_P(IsActionType, type, "") {
  const bool type_matches = testing::Value(arg.type, type);
  return type_matches;
}
// A snippet that returns one literal action table should produce exactly one
// suggestion carrying that action type.
TEST(LuaActions, SimpleAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
return {{ type = "test_action" }}
)";

  // Keep the factory result in a local and check it before dereferencing, so
  // that a failed creation fails the test instead of crashing the binary.
  // NOTE(review): assumes failure is signalled by a null result -- confirm
  // against the factory's declaration.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(lua_actions != nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions,
              testing::ElementsAreArray({IsActionType("test_action")}));
}
// The snippet walks consecutive message pairs and emits a text reply when a
// known pair is found. Only the "hello there!" / "general kenobi!" pair is
// present in the conversation, so exactly one reply is expected.
TEST(LuaActions, ConversationActions) {
  Conversation conversation;
  conversation.messages.push_back({/*user_id=*/0, "hello there!"});
  conversation.messages.push_back({/*user_id=*/1, "general kenobi!"});
  const std::string test_snippet = R"(
local actions = {}
for i, message in pairs(messages) do
if i < #messages then
if message.text == "hello there!" and
messages[i+1].text == "general kenobi!" then
table.insert(actions, {
type = "text_reply",
response_text = "you are a bold one!"
})
end
if message.text == "i am the senate!" and
messages[i+1].text == "not yet!" then
table.insert(actions, {
type = "text_reply",
response_text = "it's treason then"
})
end
end
end
return actions;
)";

  // Keep the factory result in a local and check it before dereferencing, so
  // that a failed creation fails the test instead of crashing the binary.
  // NOTE(review): assumes failure is signalled by a null result -- confirm
  // against the factory's declaration.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(lua_actions != nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions, testing::ElementsAreArray(
                           {IsAction("text_reply", "you are a bold one!")}));
}
// The snippet reads `model.actions_scores` and only emits an action when that
// list is empty. No model executor is supplied here, and the test expects the
// action to be produced.
TEST(LuaActions, SimpleModelAction) {
  Conversation conversation;
  const std::string test_snippet = R"(
if #model.actions_scores == 0 then
return {{ type = "test_action" }}
end
return {}
)";

  // Keep the factory result in a local and check it before dereferencing, so
  // that a failed creation fails the test instead of crashing the binary.
  // NOTE(review): assumes failure is signalled by a null result -- confirm
  // against the factory's declaration.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(lua_actions != nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));
  EXPECT_THAT(actions,
              testing::ElementsAreArray({IsActionType("test_action")}));
}
// The last message carries an "address" annotation over the span [11, 15) of
// "are you at home?", i.e. the word "home". The snippet turns that annotation
// into a text reply and attaches an annotation whose entity is the original
// classification.
TEST(LuaActions, AnnotationActions) {
  AnnotatedSpan annotation;
  annotation.span = {11, 15};
  annotation.classification = {ClassificationResult("address", 1.0)};
  Conversation conversation = {{{/*user_id=*/1, "are you at home?",
                                 /*reference_time_ms_utc=*/0,
                                 /*reference_timezone=*/"Europe/Zurich",
                                 /*annotations=*/{annotation},
                                 /*locales=*/"en"}}};
  const std::string test_snippet = R"(
local actions = {}
local last_message = messages[#messages]
for i, annotation in pairs(last_message.annotation) do
if #annotation.classification > 0 then
if annotation.classification[1].collection == "address" then
local text = string.sub(last_message.text,
annotation.span["begin"] + 1,
annotation.span["end"])
table.insert(actions, {
type = "text_reply",
response_text = "i am at " .. text,
annotation = {{
name = "location",
span = {
text = text
},
entity = annotation.classification[1]
}},
})
end
end
end
return actions;
)";

  // Keep the factory result in a local and check it before dereferencing, so
  // that a failed creation fails the test instead of crashing the binary.
  // NOTE(review): assumes failure is signalled by a null result -- confirm
  // against the factory's declaration.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/nullptr,
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(lua_actions != nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));

  // Fatal assertions: the indexing below is only valid when there is at least
  // one action and that action has at least one annotation.
  ASSERT_THAT(actions, testing::ElementsAreArray(
                           {IsAction("text_reply", "i am at home")}));
  ASSERT_FALSE(actions[0].annotations.empty());
  EXPECT_EQ("address", actions[0].annotations[0].entity.collection);
}
// The snippet returns an action with an `entity` table. With an actions
// entity-data schema supplied, the test expects the entity to be serialized
// into `serialized_entity_data` as a flatbuffer and reads three string fields
// back from it.
TEST(LuaActions, EntityData) {
  const std::string test_schema = TestEntityDataSchema();
  Conversation conversation = {{{/*user_id=*/1, "hello there"}}};
  const std::string test_snippet = R"(
return {{
type = "test",
entity = {
greeting = "hello",
location = "there",
person = "Kenobi",
},
}};
)";

  // Keep the factory result in a local and check it before dereferencing, so
  // that a failed creation fails the test instead of crashing the binary.
  // NOTE(review): assumes failure is signalled by a null result -- confirm
  // against the factory's declaration.
  auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions(
      test_snippet, conversation,
      /*model_executor=*/nullptr,
      /*model_spec=*/nullptr,
      /*interpreter=*/nullptr,
      /*actions_entity_data_schema=*/
      flatbuffers::GetRoot<reflection::Schema>(test_schema.data()),
      /*annotations_entity_data_schema=*/nullptr);
  ASSERT_TRUE(lua_actions != nullptr);

  std::vector<ActionSuggestion> actions;
  EXPECT_TRUE(lua_actions->SuggestActions(&actions));

  // Fatal assertions: `front()` and the buffer parsing below must not run on
  // an empty result or an empty serialized buffer.
  ASSERT_THAT(actions, testing::SizeIs(1));
  EXPECT_EQ("test", actions.front().type);
  ASSERT_FALSE(actions.front().serialized_entity_data.empty());

  const flatbuffers::Table* entity =
      flatbuffers::GetAnyRoot(reinterpret_cast<const unsigned char*>(
          actions.front().serialized_entity_data.data()));
  ASSERT_TRUE(entity != nullptr);

  // The numeric arguments are flatbuffer vtable offsets.
  // NOTE(review): 4/6/8 presumably address the greeting/location/person
  // fields of the test schema -- confirm against TestEntityDataSchema().
  // A missing field yields a null pointer, so check before calling str().
  const flatbuffers::String* greeting =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/4);
  ASSERT_TRUE(greeting != nullptr);
  EXPECT_EQ(greeting->str(), "hello");

  const flatbuffers::String* location =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/6);
  ASSERT_TRUE(location != nullptr);
  EXPECT_EQ(location->str(), "there");

  const flatbuffers::String* person =
      entity->GetPointer<const flatbuffers::String*>(/*field=*/8);
  ASSERT_TRUE(person != nullptr);
  EXPECT_EQ(person->str(), "Kenobi");
}
} // namespace
} // namespace libtextclassifier3