| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
#include "actions/ranker.h"

#include <algorithm>
#include <functional>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "actions/actions_model_generated.h"

#if !defined(TC3_DISABLE_LUA)
#include "actions/lua-ranker.h"
#endif
#include "actions/zlib-utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#if !defined(TC3_DISABLE_LUA)
#include "utils/lua-utils.h"
#endif
| |
| namespace libtextclassifier3 { |
| namespace { |
| |
| void SortByScoreAndType(std::vector<ActionSuggestion>* actions) { |
| std::stable_sort(actions->begin(), actions->end(), |
| [](const ActionSuggestion& a, const ActionSuggestion& b) { |
| return a.score > b.score || |
| (a.score >= b.score && a.type < b.type); |
| }); |
| } |
| |
| void SortByPriorityAndScoreAndType(std::vector<ActionSuggestion>* actions) { |
| std::stable_sort( |
| actions->begin(), actions->end(), |
| [](const ActionSuggestion& a, const ActionSuggestion& b) { |
| return a.priority_score > b.priority_score || |
| (a.priority_score >= b.priority_score && a.score > b.score) || |
| (a.priority_score >= b.priority_score && a.score >= b.score && |
| a.type < b.type); |
| }); |
| } |
| |
// Three-way comparison built on operator< and operator>.
// Returns -1 if `left` orders before `right`, 1 if it orders after, and 0
// otherwise.
template <typename T>
int Compare(const T& left, const T& right) {
  if (left < right) {
    return -1;
  }
  return (left > right) ? 1 : 0;
}
| |
| template <> |
| int Compare(const std::string& left, const std::string& right) { |
| return left.compare(right); |
| } |
| |
| template <> |
| int Compare(const MessageTextSpan& span, const MessageTextSpan& other) { |
| if (const int value = Compare(span.message_index, other.message_index)) { |
| return value; |
| } |
| if (const int value = Compare(span.span.first, other.span.first)) { |
| return value; |
| } |
| if (const int value = Compare(span.span.second, other.span.second)) { |
| return value; |
| } |
| return 0; |
| } |
| |
| bool IsSameSpan(const MessageTextSpan& span, const MessageTextSpan& other) { |
| return Compare(span, other) == 0; |
| } |
| |
| bool TextSpansIntersect(const MessageTextSpan& span, |
| const MessageTextSpan& other) { |
| return span.message_index == other.message_index && |
| SpansOverlap(span.span, other.span); |
| } |
| |
| template <> |
| int Compare(const ActionSuggestionAnnotation& annotation, |
| const ActionSuggestionAnnotation& other) { |
| if (const int value = Compare(annotation.span, other.span)) { |
| return value; |
| } |
| if (const int value = Compare(annotation.name, other.name)) { |
| return value; |
| } |
| if (const int value = |
| Compare(annotation.entity.collection, other.entity.collection)) { |
| return value; |
| } |
| return 0; |
| } |
| |
| // Checks whether two annotations can be considered equivalent. |
| bool IsEquivalentActionAnnotation(const ActionSuggestionAnnotation& annotation, |
| const ActionSuggestionAnnotation& other) { |
| return Compare(annotation, other) == 0; |
| } |
| |
| // Compares actions based on annotations. |
| int CompareAnnotationsOnly(const ActionSuggestion& action, |
| const ActionSuggestion& other) { |
| if (const int value = |
| Compare(action.annotations.size(), other.annotations.size())) { |
| return value; |
| } |
| for (int i = 0; i < action.annotations.size(); i++) { |
| if (const int value = |
| Compare(action.annotations[i], other.annotations[i])) { |
| return value; |
| } |
| } |
| return 0; |
| } |
| |
| // Checks whether two actions have the same annotations. |
| bool HaveEquivalentAnnotations(const ActionSuggestion& action, |
| const ActionSuggestion& other) { |
| return CompareAnnotationsOnly(action, other) == 0; |
| } |
| |
| template <> |
| int Compare(const ActionSuggestion& action, const ActionSuggestion& other) { |
| if (const int value = Compare(action.type, other.type)) { |
| return value; |
| } |
| if (const int value = Compare(action.response_text, other.response_text)) { |
| return value; |
| } |
| if (const int value = Compare(action.serialized_entity_data, |
| other.serialized_entity_data)) { |
| return value; |
| } |
| return CompareAnnotationsOnly(action, other); |
| } |
| |
| // Checks whether two action suggestions can be considered equivalent. |
| bool IsEquivalentActionSuggestion(const ActionSuggestion& action, |
| const ActionSuggestion& other) { |
| return Compare(action, other) == 0; |
| } |
| |
| // Checks whether any action is equivalent to the given one. |
| bool IsAnyActionEquivalent(const ActionSuggestion& action, |
| const std::vector<ActionSuggestion>& actions) { |
| for (const ActionSuggestion& other : actions) { |
| if (IsEquivalentActionSuggestion(action, other)) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| bool IsConflicting(const ActionSuggestionAnnotation& annotation, |
| const ActionSuggestionAnnotation& other) { |
| // Two annotations are conflicting if they are different but refer to |
| // overlapping spans in the conversation. |
| return (!IsEquivalentActionAnnotation(annotation, other) && |
| TextSpansIntersect(annotation.span, other.span)); |
| } |
| |
| // Checks whether two action suggestions can be considered conflicting. |
| bool IsConflictingActionSuggestion(const ActionSuggestion& action, |
| const ActionSuggestion& other) { |
| // Actions are considered conflicting, iff they refer to the same text span, |
| // but were not generated from the same annotation. |
| if (action.annotations.empty() || other.annotations.empty()) { |
| return false; |
| } |
| for (const ActionSuggestionAnnotation& annotation : action.annotations) { |
| for (const ActionSuggestionAnnotation& other_annotation : |
| other.annotations) { |
| if (IsConflicting(annotation, other_annotation)) { |
| return true; |
| } |
| } |
| } |
| return false; |
| } |
| |
| // Checks whether any action is considered conflicting with the given one. |
| bool IsAnyActionConflicting(const ActionSuggestion& action, |
| const std::vector<ActionSuggestion>& actions) { |
| for (const ActionSuggestion& other : actions) { |
| if (IsConflictingActionSuggestion(action, other)) { |
| return true; |
| } |
| } |
| return false; |
| } |
| |
| } // namespace |
| |
| std::unique_ptr<ActionsSuggestionsRanker> |
| ActionsSuggestionsRanker::CreateActionsSuggestionsRanker( |
| const RankingOptions* options, ZlibDecompressor* decompressor, |
| const std::string& smart_reply_action_type) { |
| auto ranker = std::unique_ptr<ActionsSuggestionsRanker>( |
| new ActionsSuggestionsRanker(options, smart_reply_action_type)); |
| |
| if (!ranker->InitializeAndValidate(decompressor)) { |
| TC3_LOG(ERROR) << "Could not initialize action ranker."; |
| return nullptr; |
| } |
| |
| return ranker; |
| } |
| |
| bool ActionsSuggestionsRanker::InitializeAndValidate( |
| ZlibDecompressor* decompressor) { |
| if (options_ == nullptr) { |
| TC3_LOG(ERROR) << "No ranking options specified."; |
| return false; |
| } |
| |
| #if !defined(TC3_DISABLE_LUA) |
| std::string lua_ranking_script; |
| if (GetUncompressedString(options_->lua_ranking_script(), |
| options_->compressed_lua_ranking_script(), |
| decompressor, &lua_ranking_script) && |
| !lua_ranking_script.empty()) { |
| if (!Compile(lua_ranking_script, &lua_bytecode_)) { |
| TC3_LOG(ERROR) << "Could not precompile lua ranking snippet."; |
| return false; |
| } |
| } |
| #endif |
| |
| return true; |
| } |
| |
| bool ActionsSuggestionsRanker::RankActions( |
| const Conversation& conversation, ActionsSuggestionsResponse* response, |
| const reflection::Schema* entity_data_schema, |
| const reflection::Schema* annotations_entity_data_schema) const { |
| if (options_->deduplicate_suggestions() || |
| options_->deduplicate_suggestions_by_span()) { |
| // Order suggestions by [priority score -> score] for deduplication |
| SortByPriorityAndScoreAndType(&response->actions); |
| |
| // Deduplicate, keeping the higher score actions. |
| if (options_->deduplicate_suggestions()) { |
| std::vector<ActionSuggestion> deduplicated_actions; |
| for (const ActionSuggestion& candidate : response->actions) { |
| // Check whether we already have an equivalent action. |
| if (!IsAnyActionEquivalent(candidate, deduplicated_actions)) { |
| deduplicated_actions.push_back(std::move(candidate)); |
| } |
| } |
| response->actions = std::move(deduplicated_actions); |
| } |
| |
| // Resolve conflicts between conflicting actions referring to the same |
| // text span. |
| if (options_->deduplicate_suggestions_by_span()) { |
| std::vector<ActionSuggestion> deduplicated_actions; |
| for (const ActionSuggestion& candidate : response->actions) { |
| // Check whether we already have a conflicting action. |
| if (!IsAnyActionConflicting(candidate, deduplicated_actions)) { |
| deduplicated_actions.push_back(std::move(candidate)); |
| } |
| } |
| response->actions = std::move(deduplicated_actions); |
| } |
| } |
| |
| bool sort_by_priority = |
| options_->sort_type() == RankingOptionsSortType_SORT_TYPE_PRIORITY_SCORE; |
| // Suppress smart replies if actions are present. |
| if (options_->suppress_smart_replies_with_actions()) { |
| std::vector<ActionSuggestion> non_smart_reply_actions; |
| for (const ActionSuggestion& action : response->actions) { |
| if (action.type != smart_reply_action_type_) { |
| non_smart_reply_actions.push_back(std::move(action)); |
| } |
| } |
| response->actions = std::move(non_smart_reply_actions); |
| } |
| |
| // Group by annotation if specified. |
| if (options_->group_by_annotations()) { |
| auto group_id = std::map< |
| ActionSuggestion, int, |
| std::function<bool(const ActionSuggestion&, const ActionSuggestion&)>>{ |
| [](const ActionSuggestion& action, const ActionSuggestion& other) { |
| return (CompareAnnotationsOnly(action, other) < 0); |
| }}; |
| typedef std::vector<ActionSuggestion> ActionSuggestionGroup; |
| std::vector<ActionSuggestionGroup> groups; |
| |
| // Group actions by the annotation set they are based of. |
| for (const ActionSuggestion& action : response->actions) { |
| // Treat actions with no annotations idependently. |
| if (action.annotations.empty()) { |
| groups.emplace_back(1, action); |
| continue; |
| } |
| |
| auto it = group_id.find(action); |
| if (it != group_id.end()) { |
| groups[it->second].push_back(action); |
| } else { |
| group_id[action] = groups.size(); |
| groups.emplace_back(1, action); |
| } |
| } |
| |
| // Sort within each group by score. |
| for (std::vector<ActionSuggestion>& group : groups) { |
| if (sort_by_priority) { |
| SortByPriorityAndScoreAndType(&group); |
| } else { |
| SortByScoreAndType(&group); |
| } |
| } |
| |
| // Sort groups by maximum score or priority score. |
| if (sort_by_priority) { |
| std::stable_sort( |
| groups.begin(), groups.end(), |
| [](const std::vector<ActionSuggestion>& a, |
| const std::vector<ActionSuggestion>& b) { |
| return (a.begin()->priority_score > b.begin()->priority_score) || |
| (a.begin()->priority_score >= b.begin()->priority_score && |
| a.begin()->score > b.begin()->score) || |
| (a.begin()->priority_score >= b.begin()->priority_score && |
| a.begin()->score >= b.begin()->score && |
| a.begin()->type < b.begin()->type); |
| }); |
| } else { |
| std::stable_sort(groups.begin(), groups.end(), |
| [](const std::vector<ActionSuggestion>& a, |
| const std::vector<ActionSuggestion>& b) { |
| return a.begin()->score > b.begin()->score || |
| (a.begin()->score >= b.begin()->score && |
| a.begin()->type < b.begin()->type); |
| }); |
| } |
| |
| // Flatten result. |
| const size_t num_actions = response->actions.size(); |
| response->actions.clear(); |
| response->actions.reserve(num_actions); |
| for (const std::vector<ActionSuggestion>& actions : groups) { |
| response->actions.insert(response->actions.end(), actions.begin(), |
| actions.end()); |
| } |
| } else if (sort_by_priority) { |
| SortByPriorityAndScoreAndType(&response->actions); |
| } else { |
| SortByScoreAndType(&response->actions); |
| } |
| |
| #if !defined(TC3_DISABLE_LUA) |
| // Run lua ranking snippet, if provided. |
| if (!lua_bytecode_.empty()) { |
| auto lua_ranker = ActionsSuggestionsLuaRanker::Create( |
| conversation, lua_bytecode_, entity_data_schema, |
| annotations_entity_data_schema, response); |
| if (lua_ranker == nullptr || !lua_ranker->RankActions()) { |
| TC3_LOG(ERROR) << "Could not run lua ranking snippet."; |
| return false; |
| } |
| } |
| #endif |
| |
| return true; |
| } |
| |
| } // namespace libtextclassifier3 |