blob: d52ecaa3b32362d94a4ab7b4cb02537679c4d1d5 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "actions/ranker.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
#endif
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
namespace libtextclassifier3 {
namespace {
// Sorts `actions` in place: descending by `score`; when scores tie, ascending
// by `type`. std::sort is not stable, so actions equal in both keys may end
// up in any relative order.
void SortByScoreAndType(std::vector<ActionSuggestion>* actions) {
  const auto ranks_before = [](const ActionSuggestion& lhs,
                               const ActionSuggestion& rhs) {
    if (lhs.score > rhs.score) {
      return true;
    }
    if (!(lhs.score >= rhs.score)) {
      return false;
    }
    // Scores tie: fall back to the type ordering.
    return lhs.type < rhs.type;
  };
  std::sort(actions->begin(), actions->end(), ranks_before);
}
// Generic three-way comparison built on `operator<` and `operator>`.
// Returns -1 if left < right, 1 if left > right, and 0 otherwise (including
// when neither relation holds, e.g. for unordered values).
template <typename T>
int Compare(const T& left, const T& right) {
  return (left < right) ? -1 : ((left > right) ? 1 : 0);
}
// Three-way string comparison. Delegates to std::string::compare, so the
// result is negative, zero or positive (not necessarily -1/0/1); the callers
// in this file only propagate it and ultimately test its sign or zero-ness.
template <>
int Compare(const std::string& left, const std::string& right) {
  const int ordering = left.compare(right);
  return ordering;
}
// Orders message text spans lexicographically by `message_index`, then
// `span.first`, then `span.second`. Returns the first non-zero component
// comparison, or 0 when all three are equal.
template <>
int Compare(const MessageTextSpan& span, const MessageTextSpan& other) {
  int result = Compare(span.message_index, other.message_index);
  if (result == 0) {
    result = Compare(span.span.first, other.span.first);
  }
  if (result == 0) {
    result = Compare(span.span.second, other.span.second);
  }
  return result;
}
// True iff both spans have the same message index and identical span bounds.
bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) {
  return !Compare(span, other);
}
// True iff both spans refer to the same message and their spans overlap
// according to SpansOverlap.
bool TextSpansIntersect(const MessageTextSpan& span,
                        const MessageTextSpan& other) {
  if (span.message_index != other.message_index) {
    // Spans in different messages can never intersect.
    return false;
  }
  return SpansOverlap(span.span, other.span);
}
// Orders annotations by text span, then by name, then by entity collection.
// Returns the first non-zero component comparison, or 0 if all are equal.
template <>
int Compare(const ActionSuggestionAnnotation& annotation,
            const ActionSuggestionAnnotation& other) {
  int result = Compare(annotation.span, other.span);
  if (result == 0) {
    result = Compare(annotation.name, other.name);
  }
  if (result == 0) {
    result = Compare(annotation.entity.collection, other.entity.collection);
  }
  return result;
}
// Two annotations are equivalent when they share span, name and entity
// collection (i.e. the annotation Compare yields 0).
bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation,
                                  const ActionSuggestionAnnotation& other) {
  return !Compare(annotation, other);
}
// Compares two actions by their annotations only: first by the number of
// annotations, then element-wise in order using the annotation Compare.
// Returns negative, zero or positive; zero means both actions carry
// equivalent annotation lists.
int CompareAnnotationsOnly(const ActionSuggestion& action,
                           const ActionSuggestion& other) {
  if (const int value =
          Compare(action.annotations.size(), other.annotations.size())) {
    return value;
  }
  // Sizes are equal past this point, so indexing `other` is in bounds.
  // The index is size_t to avoid a signed/unsigned comparison with size().
  for (size_t i = 0; i < action.annotations.size(); ++i) {
    if (const int value =
            Compare(action.annotations[i], other.annotations[i])) {
      return value;
    }
  }
  return 0;
}
// True iff both actions carry the same number of annotations and each pair,
// position by position, is equivalent.
bool HaveEquivalentAnnotations(const ActionSuggestion& action,
                               const ActionSuggestion& other) {
  return !CompareAnnotationsOnly(action, other);
}
// Orders actions by type, then response text, then serialized entity data,
// and finally by their annotations (see CompareAnnotationsOnly). Scores are
// deliberately not part of the ordering.
template <>
int Compare(const ActionSuggestion& action, const ActionSuggestion& other) {
  int result = Compare(action.type, other.type);
  if (result == 0) {
    result = Compare(action.response_text, other.response_text);
  }
  if (result == 0) {
    result = Compare(action.serialized_entity_data,
                     other.serialized_entity_data);
  }
  if (result == 0) {
    result = CompareAnnotationsOnly(action, other);
  }
  return result;
}
// Two suggestions are equivalent when they agree on type, response text,
// serialized entity data and annotations (scores are ignored).
bool IsEquivalentActionSuggestion(const ActionSuggestion& action,
                                  const ActionSuggestion& other) {
  return !Compare(action, other);
}
// Checks whether any action is equivalent to the given one.
bool IsAnyActionEquivalent(const ActionSuggestion& action,
const std::vector<ActionSuggestion>& actions) {
for (const ActionSuggestion& other : actions) {
if (IsEquivalentActionSuggestion(action, other)) {
return true;
}
}
return false;
}
// Two annotations conflict when they are not equivalent yet point at
// overlapping text in the same message of the conversation.
bool IsConflicting(const ActionSuggestionAnnotation& annotation,
                   const ActionSuggestionAnnotation& other) {
  if (IsEquivalentActionAnnotation(annotation, other)) {
    return false;
  }
  return TextSpansIntersect(annotation.span, other.span);
}
// Checks whether two action suggestions can be considered conflicting.
bool IsConflictingActionSuggestion(const ActionSuggestion& action,
const ActionSuggestion& other) {
// Actions are considered conflicting, iff they refer to the same text span,
// but were not generated from the same annotation.
if (action.annotations.empty() || other.annotations.empty()) {
return false;
}
for (const ActionSuggestionAnnotation& annotation : action.annotations) {
for (const ActionSuggestionAnnotation& other_annotation :
other.annotations) {
if (IsConflicting(annotation, other_annotation)) {
return true;
}
}
}
return false;
}
// Checks whether any action is considered conflicting with the given one.
bool IsAnyActionConflicting(const ActionSuggestion& action,
const std::vector<ActionSuggestion>& actions) {
for (const ActionSuggestion& other : actions) {
if (IsConflictingActionSuggestion(action, other)) {
return true;
}
}
return false;
}
} // namespace
std::unique_ptr<ActionsSuggestionsRanker>
ActionsSuggestionsRanker::CreateActionsSuggestionsRanker(
const RankingOptions* options, ZlibDecompressor* decompressor,
const std::string& smart_reply_action_type) {
auto ranker = std::unique_ptr<ActionsSuggestionsRanker>(
new ActionsSuggestionsRanker(options, smart_reply_action_type));
if (!ranker->InitializeAndValidate(decompressor)) {
TC3_LOG(ERROR) << "Could not initialize action ranker.";
return nullptr;
}
return ranker;
}
// Checks that ranking options were supplied and, unless lua support is
// compiled out (TC3_DISABLE_LUA), precompiles the possibly compressed lua
// ranking script into `lua_bytecode_`. If the script cannot be obtained
// (GetUncompressedString returns false) or is empty, `lua_bytecode_` is left
// untouched and this is not treated as an error.
// Returns false if `options_` is null or the script fails to compile.
bool ActionsSuggestionsRanker::InitializeAndValidate(
    ZlibDecompressor* decompressor) {
  if (options_ == nullptr) {
    TC3_LOG(ERROR) << "No ranking options specified.";
    return false;
  }
#if !defined(TC3_DISABLE_LUA)
  std::string lua_ranking_script;
  const bool has_script =
      GetUncompressedString(options_->lua_ranking_script(),
                            options_->compressed_lua_ranking_script(),
                            decompressor, &lua_ranking_script) &&
      !lua_ranking_script.empty();
  if (has_script && !Compile(lua_ranking_script, &lua_bytecode_)) {
    TC3_LOG(ERROR) << "Could not precompile lua ranking snippet.";
    return false;
  }
#endif
  return true;
}
// Ranks `response->actions` in place according to `options_`.
//
// Steps, in order:
//   1. If `deduplicate_suggestions` or `deduplicate_suggestions_by_span` is
//      set, sorts by descending priority score (then descending score) and
//      drops, respectively, actions equivalent to an earlier kept one and
//      actions whose annotations conflict with an earlier kept one's.
//   2. If `suppress_smart_replies_with_actions` is set and at least one
//      action of a type other than `smart_reply_action_type_` is present,
//      removes all actions of type `smart_reply_action_type_`.
//   3. Sorts by descending score (ties: ascending type). With
//      `group_by_annotations`, actions sharing the same annotations are kept
//      together, each group sorted internally and groups ordered by their
//      best member; actions without annotations form singleton groups.
//   4. Unless built with TC3_DISABLE_LUA, runs the precompiled lua ranking
//      snippet if one exists; `conversation` and both schemas are only
//      consumed by that step.
//
// Returns false if the lua ranking snippet could not be created or run,
// true otherwise.
bool ActionsSuggestionsRanker::RankActions(
    const Conversation& conversation, ActionsSuggestionsResponse* response,
    const reflection::Schema* entity_data_schema,
    const reflection::Schema* annotations_entity_data_schema) const {
  if (options_->deduplicate_suggestions() ||
      options_->deduplicate_suggestions_by_span()) {
    // First order suggestions by priority score for deduplication.
    std::sort(
        response->actions.begin(), response->actions.end(),
        [](const ActionSuggestion& a, const ActionSuggestion& b) {
          return a.priority_score > b.priority_score ||
                 (a.priority_score >= b.priority_score && a.score > b.score);
        });

    // Keeps each action unless `is_redundant` rejects it against the actions
    // kept so far, so the first (best ranked) action of each redundant set
    // survives. Candidates are taken by non-const reference so std::move
    // really moves; moving from a const reference silently copies.
    using RedundancyCheck = bool (*)(const ActionSuggestion&,
                                     const std::vector<ActionSuggestion>&);
    const auto keep_non_redundant = [response](RedundancyCheck is_redundant) {
      std::vector<ActionSuggestion> kept;
      kept.reserve(response->actions.size());
      for (ActionSuggestion& candidate : response->actions) {
        if (!is_redundant(candidate, kept)) {
          kept.push_back(std::move(candidate));
        }
      }
      response->actions = std::move(kept);
    };

    // Deduplicate equivalent actions, keeping the higher ranked ones.
    if (options_->deduplicate_suggestions()) {
      keep_non_redundant(&IsAnyActionEquivalent);
    }

    // Resolve conflicts between conflicting actions referring to the same
    // text span.
    if (options_->deduplicate_suggestions_by_span()) {
      keep_non_redundant(&IsAnyActionConflicting);
    }
  }

  // Suppress smart replies if actions are present.
  if (options_->suppress_smart_replies_with_actions()) {
    std::vector<ActionSuggestion> non_smart_reply_actions;
    for (ActionSuggestion& action : response->actions) {
      if (action.type != smart_reply_action_type_) {
        non_smart_reply_actions.push_back(std::move(action));
      }
    }
    // Only replace the list when a real action survived. Otherwise the smart
    // replies are the only suggestions and are kept; nothing was moved out of
    // `response->actions` in that case, so it is still intact.
    if (!non_smart_reply_actions.empty()) {
      response->actions = std::move(non_smart_reply_actions);
    }
  }

  // Group by annotation if specified.
  if (options_->group_by_annotations()) {
    using ActionSuggestionGroup = std::vector<ActionSuggestion>;
    // Maps an annotation set (keys compare on annotations only) to the index
    // of its group in `groups`.
    auto group_id = std::map<
        ActionSuggestion, size_t,
        std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{
        [](const ActionSuggestion& action, const ActionSuggestion& other) {
          return (CompareAnnotationsOnly(action, other) < 0);
        }};
    std::vector<ActionSuggestionGroup> groups;

    // Group actions by the annotation set they are based on.
    for (const ActionSuggestion& action : response->actions) {
      // Treat actions with no annotations independently.
      if (action.annotations.empty()) {
        groups.emplace_back(1, action);
        continue;
      }
      auto it = group_id.find(action);
      if (it != group_id.end()) {
        groups[it->second].push_back(action);
      } else {
        group_id.emplace(action, groups.size());
        groups.emplace_back(1, action);
      }
    }

    // Sort within each group by score.
    for (ActionSuggestionGroup& group : groups) {
      SortByScoreAndType(&group);
    }

    // Sort groups by maximum score. Every group holds at least one action,
    // so front() is always valid.
    std::sort(groups.begin(), groups.end(),
              [](const ActionSuggestionGroup& a,
                 const ActionSuggestionGroup& b) {
                return a.front().score > b.front().score ||
                       (a.front().score >= b.front().score &&
                        a.front().type < b.front().type);
              });

    // Flatten result, moving (not copying) the grouped actions back.
    const size_t num_actions = response->actions.size();
    response->actions.clear();
    response->actions.reserve(num_actions);
    for (ActionSuggestionGroup& group : groups) {
      response->actions.insert(response->actions.end(),
                               std::make_move_iterator(group.begin()),
                               std::make_move_iterator(group.end()));
    }
  } else {
    // Order suggestions independently by score.
    SortByScoreAndType(&response->actions);
  }

#if !defined(TC3_DISABLE_LUA)
  // Run lua ranking snippet, if provided.
  if (!lua_bytecode_.empty()) {
    auto lua_ranker = ActionsSuggestionsLuaRanker::Create(
        conversation, lua_bytecode_, entity_data_schema,
        annotations_entity_data_schema, response);
    if (lua_ranker == nullptr || !lua_ranker->RankActions()) {
      TC3_LOG(ERROR) << "Could not run lua ranking snippet.";
      return false;
    }
  }
#endif

  return true;
}
} // namespace libtextclassifier3