| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "actions/actions-suggestions.h" |
| |
| #include <memory> |
| #include <string> |
| #include <vector> |
| |
| #include "utils/base/statusor.h" |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/random/random.h" |
| |
| #if !defined(TC3_DISABLE_LUA) |
| #include "actions/lua-actions.h" |
| #endif |
| #include "actions/ngram-model.h" |
| #include "actions/tflite-sensitive-model.h" |
| #include "actions/types.h" |
| #include "actions/utils.h" |
| #include "actions/zlib-utils.h" |
| #include "annotator/collections.h" |
| #include "utils/base/logging.h" |
| #if !defined(TC3_DISABLE_LUA) |
| #include "utils/lua-utils.h" |
| #endif |
| #include "utils/normalization.h" |
| #include "utils/optional.h" |
| #include "utils/strings/split.h" |
| #include "utils/strings/stringpiece.h" |
| #include "utils/strings/utf8.h" |
| #include "utils/utf8/unicodetext.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/random/distributions.h" |
| #include "tensorflow/lite/string_util.h" |
| |
| namespace libtextclassifier3 { |
| |
| constexpr float kDefaultFloat = 0.0; |
| constexpr bool kDefaultBool = false; |
| constexpr int kDefaultInt = 1; |
| |
| namespace { |
| |
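| // Loads the flatbuffer-serialized actions model from the given buffer and |
| // verifies its integrity. Returns nullptr if verification fails. |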
| const ActionsModel* LoadAndVerifyModel(const uint8_t* addr, int size) { |
| flatbuffers::Verifier verifier(addr, size); |
| if (VerifyActionsModelBuffer(verifier)) { |
| return GetActionsModel(addr); |
| } else { |
| return nullptr; |
| } |
| } |
| |
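| // Reads a scalar field from the (possibly null) overlay table, falling back |
| // to the provided default value. |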
| template <typename T> |
| T ValueOrDefault(const flatbuffers::Table* values, const int32 field_offset, |
| const T default_value) { |
| if (values == nullptr) { |
| return default_value; |
| } |
| return values->GetField<T>(field_offset, default_value); |
| } |
| |
| // Returns the number of (tail) messages of a conversation to consider. |
| int NumMessagesToConsider(const Conversation& conversation, |
| const int max_conversation_history_length) { |
| return ((max_conversation_history_length < 0 || |
| conversation.messages.size() < max_conversation_history_length) |
| ? conversation.messages.size() |
| : max_conversation_history_length); |
| } |
| |
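| // Pads or truncates the input vector to exactly `max_length` elements, |
| // appending `pad_value` as needed. |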
| template <typename T> |
| std::vector<T> PadOrTruncateToTargetLength(const std::vector<T>& inputs, |
| const int max_length, |
| const T pad_value) { |
| if (inputs.size() >= max_length) { |
| return std::vector<T>(inputs.begin(), inputs.begin() + max_length); |
| } else { |
| std::vector<T> result; |
| result.reserve(max_length); |
| result.insert(result.begin(), inputs.begin(), inputs.end()); |
| result.insert(result.end(), max_length - inputs.size(), pad_value); |
| return result; |
| } |
| } |
| |
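| // Sets a model input from a Variant that may hold either a scalar or a |
| // vector of type T. |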
| template <typename T> |
| void SetVectorOrScalarAsModelInput( |
| const int param_index, const Variant& param_value, |
| tflite::Interpreter* interpreter, |
| const std::unique_ptr<const TfLiteModelExecutor>& model_executor) { |
| if (param_value.Has<std::vector<T>>()) { |
| model_executor->SetInput<T>( |
| param_index, param_value.ConstRefValue<std::vector<T>>(), interpreter); |
| } else if (param_value.Has<T>()) { |
| model_executor->SetInput<T>(param_index, param_value.Value<T>(), |
| interpreter); |
| } else { |
| TC3_LOG(ERROR) << "Variant type error!"; |
| } |
| } |
| } // namespace |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromUnownedBuffer( |
| const uint8_t* buffer, const int size, const UniLib* unilib, |
| const std::string& triggering_preconditions_overlay) { |
| auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions()); |
| const ActionsModel* model = LoadAndVerifyModel(buffer, size); |
| if (model == nullptr) { |
| return nullptr; |
| } |
| actions->model_ = model; |
| actions->SetOrCreateUnilib(unilib); |
| actions->triggering_preconditions_overlay_buffer_ = |
| triggering_preconditions_overlay; |
| if (!actions->ValidateAndInitialize()) { |
| return nullptr; |
| } |
| return actions; |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap( |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, const UniLib* unilib, |
| const std::string& triggering_preconditions_overlay) { |
| if (!mmap->handle().ok()) { |
| TC3_VLOG(1) << "Mmap failed."; |
| return nullptr; |
| } |
| const ActionsModel* model = LoadAndVerifyModel( |
| reinterpret_cast<const uint8_t*>(mmap->handle().start()), |
| mmap->handle().num_bytes()); |
| if (!model) { |
| TC3_LOG(ERROR) << "Model verification failed."; |
| return nullptr; |
| } |
| auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions()); |
| actions->model_ = model; |
| actions->mmap_ = std::move(mmap); |
| actions->SetOrCreateUnilib(unilib); |
| actions->triggering_preconditions_overlay_buffer_ = |
| triggering_preconditions_overlay; |
| if (!actions->ValidateAndInitialize()) { |
| return nullptr; |
| } |
| return actions; |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromScopedMmap( |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap, |
| std::unique_ptr<UniLib> unilib, |
| const std::string& triggering_preconditions_overlay) { |
| if (!mmap->handle().ok()) { |
| TC3_VLOG(1) << "Mmap failed."; |
| return nullptr; |
| } |
| const ActionsModel* model = LoadAndVerifyModel( |
| reinterpret_cast<const uint8_t*>(mmap->handle().start()), |
| mmap->handle().num_bytes()); |
| if (!model) { |
| TC3_LOG(ERROR) << "Model verification failed."; |
| return nullptr; |
| } |
| auto actions = std::unique_ptr<ActionsSuggestions>(new ActionsSuggestions()); |
| actions->model_ = model; |
| actions->mmap_ = std::move(mmap); |
| actions->owned_unilib_ = std::move(unilib); |
| actions->unilib_ = actions->owned_unilib_.get(); |
| actions->triggering_preconditions_overlay_buffer_ = |
| triggering_preconditions_overlay; |
| if (!actions->ValidateAndInitialize()) { |
| return nullptr; |
| } |
| return actions; |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor( |
| const int fd, const int offset, const int size, const UniLib* unilib, |
| const std::string& triggering_preconditions_overlay) { |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap; |
| if (offset >= 0 && size >= 0) { |
| mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size)); |
| } else { |
| mmap.reset(new libtextclassifier3::ScopedMmap(fd)); |
| } |
| return FromScopedMmap(std::move(mmap), unilib, |
| triggering_preconditions_overlay); |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor( |
| const int fd, const int offset, const int size, |
| std::unique_ptr<UniLib> unilib, |
| const std::string& triggering_preconditions_overlay) { |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap; |
| if (offset >= 0 && size >= 0) { |
| mmap.reset(new libtextclassifier3::ScopedMmap(fd, offset, size)); |
| } else { |
| mmap.reset(new libtextclassifier3::ScopedMmap(fd)); |
| } |
| return FromScopedMmap(std::move(mmap), std::move(unilib), |
| triggering_preconditions_overlay); |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor( |
| const int fd, const UniLib* unilib, |
| const std::string& triggering_preconditions_overlay) { |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( |
| new libtextclassifier3::ScopedMmap(fd)); |
| return FromScopedMmap(std::move(mmap), unilib, |
| triggering_preconditions_overlay); |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromFileDescriptor( |
| const int fd, std::unique_ptr<UniLib> unilib, |
| const std::string& triggering_preconditions_overlay) { |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( |
| new libtextclassifier3::ScopedMmap(fd)); |
| return FromScopedMmap(std::move(mmap), std::move(unilib), |
| triggering_preconditions_overlay); |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath( |
| const std::string& path, const UniLib* unilib, |
| const std::string& triggering_preconditions_overlay) { |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( |
| new libtextclassifier3::ScopedMmap(path)); |
| return FromScopedMmap(std::move(mmap), unilib, |
| triggering_preconditions_overlay); |
| } |
| |
| std::unique_ptr<ActionsSuggestions> ActionsSuggestions::FromPath( |
| const std::string& path, std::unique_ptr<UniLib> unilib, |
| const std::string& triggering_preconditions_overlay) { |
| std::unique_ptr<libtextclassifier3::ScopedMmap> mmap( |
| new libtextclassifier3::ScopedMmap(path)); |
| return FromScopedMmap(std::move(mmap), std::move(unilib), |
| triggering_preconditions_overlay); |
| } |
| |
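| // Uses the provided UniLib instance if given, otherwise creates an owned one. |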
| void ActionsSuggestions::SetOrCreateUnilib(const UniLib* unilib) { |
| if (unilib != nullptr) { |
| unilib_ = unilib; |
| } else { |
| owned_unilib_.reset(new UniLib); |
| unilib_ = owned_unilib_.get(); |
| } |
| } |
| |
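| // Validates the model and initializes all components needed for inference |
| // (executors, rules, feature processor, ranker, sensitive-text models). |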
| bool ActionsSuggestions::ValidateAndInitialize() { |
| if (model_ == nullptr) { |
| TC3_LOG(ERROR) << "No model specified."; |
| return false; |
| } |
| |
| if (model_->smart_reply_action_type() == nullptr) { |
| TC3_LOG(ERROR) << "No smart reply action type specified."; |
| return false; |
| } |
| |
| if (!InitializeTriggeringPreconditions()) { |
| TC3_LOG(ERROR) << "Could not initialize preconditions."; |
| return false; |
| } |
| |
| if (model_->locales() && |
| !ParseLocales(model_->locales()->c_str(), &locales_)) { |
| TC3_LOG(ERROR) << "Could not parse model supported locales."; |
| return false; |
| } |
| |
| if (model_->tflite_model_spec() != nullptr) { |
| model_executor_ = TfLiteModelExecutor::FromBuffer( |
| model_->tflite_model_spec()->tflite_model()); |
| if (!model_executor_) { |
| TC3_LOG(ERROR) << "Could not initialize model executor."; |
| return false; |
| } |
| } |
| |
| // Gather annotation entities for the rules. |
| if (model_->annotation_actions_spec() != nullptr && |
| model_->annotation_actions_spec()->annotation_mapping() != nullptr) { |
| for (const AnnotationActionsSpec_::AnnotationMapping* mapping : |
| *model_->annotation_actions_spec()->annotation_mapping()) { |
| annotation_entity_types_.insert(mapping->annotation_collection()->str()); |
| } |
| } |
| |
| if (model_->actions_entity_data_schema() != nullptr) { |
| entity_data_schema_ = LoadAndVerifyFlatbuffer<reflection::Schema>( |
| model_->actions_entity_data_schema()->Data(), |
| model_->actions_entity_data_schema()->size()); |
| if (entity_data_schema_ == nullptr) { |
| TC3_LOG(ERROR) << "Could not load entity data schema data."; |
| return false; |
| } |
| |
| entity_data_builder_.reset( |
| new MutableFlatbufferBuilder(entity_data_schema_)); |
| } else { |
| entity_data_schema_ = nullptr; |
| } |
| |
| // Initialize regular expressions model. |
| std::unique_ptr<ZlibDecompressor> decompressor = ZlibDecompressor::Instance(); |
| regex_actions_.reset( |
| new RegexActions(unilib_, model_->smart_reply_action_type()->str())); |
| if (!regex_actions_->InitializeRules( |
| model_->rules(), model_->low_confidence_rules(), |
| triggering_preconditions_overlay_, decompressor.get())) { |
| TC3_LOG(ERROR) << "Could not initialize regex rules."; |
| return false; |
| } |
| |
| // Set up the grammar model. |
| if (model_->rules() != nullptr && |
| model_->rules()->grammar_rules() != nullptr) { |
| grammar_actions_.reset(new GrammarActions( |
| unilib_, model_->rules()->grammar_rules(), entity_data_builder_.get(), |
| model_->smart_reply_action_type()->str())); |
| |
| // Gather annotation entities for the grammars. |
| if (auto annotation_nt = model_->rules() |
| ->grammar_rules() |
| ->rules() |
| ->nonterminals() |
| ->annotation_nt()) { |
| for (const grammar::RulesSet_::Nonterminals_::AnnotationNtEntry* entry : |
| *annotation_nt) { |
| annotation_entity_types_.insert(entry->key()->str()); |
| } |
| } |
| } |
| |
| #if !defined(TC3_DISABLE_LUA) |
| std::string actions_script; |
| if (GetUncompressedString(model_->lua_actions_script(), |
| model_->compressed_lua_actions_script(), |
| decompressor.get(), &actions_script) && |
| !actions_script.empty()) { |
| if (!Compile(actions_script, &lua_bytecode_)) { |
| TC3_LOG(ERROR) << "Could not precompile lua actions snippet."; |
| return false; |
| } |
| } |
| #endif // TC3_DISABLE_LUA |
| |
| if (!(ranker_ = ActionsSuggestionsRanker::CreateActionsSuggestionsRanker( |
| model_->ranking_options(), decompressor.get(), |
| model_->smart_reply_action_type()->str()))) { |
| TC3_LOG(ERROR) << "Could not create an action suggestions ranker."; |
| return false; |
| } |
| |
| // Create feature processor if specified. |
| const ActionsTokenFeatureProcessorOptions* options = |
| model_->feature_processor_options(); |
| if (options != nullptr) { |
| if (options->tokenizer_options() == nullptr) { |
| TC3_LOG(ERROR) << "No tokenizer options specified."; |
| return false; |
| } |
| |
| feature_processor_.reset(new ActionsFeatureProcessor(options, unilib_)); |
| embedding_executor_ = TFLiteEmbeddingExecutor::FromBuffer( |
| options->embedding_model(), options->embedding_size(), |
| options->embedding_quantization_bits()); |
| |
| if (embedding_executor_ == nullptr) { |
| TC3_LOG(ERROR) << "Could not initialize embedding executor."; |
| return false; |
| } |
| |
| // Cache embedding of padding, start and end token. |
| if (!EmbedTokenId(options->padding_token_id(), &embedded_padding_token_) || |
| !EmbedTokenId(options->start_token_id(), &embedded_start_token_) || |
| !EmbedTokenId(options->end_token_id(), &embedded_end_token_)) { |
| TC3_LOG(ERROR) << "Could not precompute token embeddings."; |
| return false; |
| } |
| token_embedding_size_ = feature_processor_->GetTokenEmbeddingSize(); |
| } |
| |
| // Create low confidence model if specified. |
| if (model_->low_confidence_ngram_model() != nullptr) { |
| sensitive_model_ = NGramSensitiveModel::Create( |
| unilib_, model_->low_confidence_ngram_model(), |
| feature_processor_ == nullptr ? nullptr |
| : feature_processor_->tokenizer()); |
| if (sensitive_model_ == nullptr) { |
| TC3_LOG(ERROR) << "Could not create ngram linear regression model."; |
| return false; |
| } |
| } |
| if (model_->low_confidence_tflite_model() != nullptr) { |
| sensitive_model_ = |
| TFLiteSensitiveModel::Create(model_->low_confidence_tflite_model()); |
| if (sensitive_model_ == nullptr) { |
| TC3_LOG(ERROR) << "Could not create TFLite sensitive model."; |
| return false; |
| } |
| } |
| |
| return true; |
| } |
| |
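| // Merges the triggering precondition defaults from the model with the |
| // (optional) overlay buffer. |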
| bool ActionsSuggestions::InitializeTriggeringPreconditions() { |
| triggering_preconditions_overlay_ = |
| LoadAndVerifyFlatbuffer<TriggeringPreconditions>( |
| triggering_preconditions_overlay_buffer_); |
| |
| if (triggering_preconditions_overlay_ == nullptr && |
| !triggering_preconditions_overlay_buffer_.empty()) { |
| TC3_LOG(ERROR) << "Could not load triggering preconditions overwrites."; |
| return false; |
| } |
| const flatbuffers::Table* overlay = |
| reinterpret_cast<const flatbuffers::Table*>( |
| triggering_preconditions_overlay_); |
| const TriggeringPreconditions* defaults = model_->preconditions(); |
| if (defaults == nullptr) { |
| TC3_LOG(ERROR) << "No triggering conditions specified."; |
| return false; |
| } |
| |
| preconditions_.min_smart_reply_triggering_score = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_MIN_SMART_REPLY_TRIGGERING_SCORE, |
| defaults->min_smart_reply_triggering_score()); |
| preconditions_.max_sensitive_topic_score = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_MAX_SENSITIVE_TOPIC_SCORE, |
| defaults->max_sensitive_topic_score()); |
| preconditions_.suppress_on_sensitive_topic = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_SUPPRESS_ON_SENSITIVE_TOPIC, |
| defaults->suppress_on_sensitive_topic()); |
| preconditions_.min_input_length = |
| ValueOrDefault(overlay, TriggeringPreconditions::VT_MIN_INPUT_LENGTH, |
| defaults->min_input_length()); |
| preconditions_.max_input_length = |
| ValueOrDefault(overlay, TriggeringPreconditions::VT_MAX_INPUT_LENGTH, |
| defaults->max_input_length()); |
| preconditions_.min_locale_match_fraction = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_MIN_LOCALE_MATCH_FRACTION, |
| defaults->min_locale_match_fraction()); |
| preconditions_.handle_missing_locale_as_supported = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_HANDLE_MISSING_LOCALE_AS_SUPPORTED, |
| defaults->handle_missing_locale_as_supported()); |
| preconditions_.handle_unknown_locale_as_supported = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_HANDLE_UNKNOWN_LOCALE_AS_SUPPORTED, |
| defaults->handle_unknown_locale_as_supported()); |
| preconditions_.suppress_on_low_confidence_input = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_SUPPRESS_ON_LOW_CONFIDENCE_INPUT, |
| defaults->suppress_on_low_confidence_input()); |
| preconditions_.min_reply_score_threshold = ValueOrDefault( |
| overlay, TriggeringPreconditions::VT_MIN_REPLY_SCORE_THRESHOLD, |
| defaults->min_reply_score_threshold()); |
| |
| return true; |
| } |
| |
| bool ActionsSuggestions::EmbedTokenId(const int32 token_id, |
| std::vector<float>* embedding) const { |
| return feature_processor_->AppendFeatures( |
| {token_id}, |
| /*dense_features=*/{}, embedding_executor_.get(), embedding); |
| } |
| |
| std::vector<std::vector<Token>> ActionsSuggestions::Tokenize( |
| const std::vector<std::string>& context) const { |
| std::vector<std::vector<Token>> tokens; |
| tokens.reserve(context.size()); |
| for (const std::string& message : context) { |
| tokens.push_back(feature_processor_->tokenizer()->Tokenize(message)); |
| } |
| return tokens; |
| } |
| |
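| // Embeds the tokens of each message and pads every message to the same |
| // number of tokens, bounded by the configured per-message limits. |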
| bool ActionsSuggestions::EmbedTokensPerMessage( |
| const std::vector<std::vector<Token>>& tokens, |
| std::vector<float>* embeddings, int* max_num_tokens_per_message) const { |
| const int num_messages = tokens.size(); |
| *max_num_tokens_per_message = 0; |
| for (int i = 0; i < num_messages; i++) { |
| const int num_message_tokens = tokens[i].size(); |
| if (num_message_tokens > *max_num_tokens_per_message) { |
| *max_num_tokens_per_message = num_message_tokens; |
| } |
| } |
| |
| if (model_->feature_processor_options()->min_num_tokens_per_message() > |
| *max_num_tokens_per_message) { |
| *max_num_tokens_per_message = |
| model_->feature_processor_options()->min_num_tokens_per_message(); |
| } |
| if (model_->feature_processor_options()->max_num_tokens_per_message() > 0 && |
| *max_num_tokens_per_message > |
| model_->feature_processor_options()->max_num_tokens_per_message()) { |
| *max_num_tokens_per_message = |
| model_->feature_processor_options()->max_num_tokens_per_message(); |
| } |
| |
| // Embed all tokens and pad each message to the maximum number of tokens of |
| // any message in the conversation. |
| // If a maximum number of tokens per message is specified in the model |
| // config, tokens at the beginning of a message are dropped if they do not |
| // fit within the limit. |
| for (int i = 0; i < num_messages; i++) { |
| const int start = |
| std::max<int>(tokens[i].size() - *max_num_tokens_per_message, 0); |
| for (int pos = start; pos < tokens[i].size(); pos++) { |
| if (!feature_processor_->AppendTokenFeatures( |
| tokens[i][pos], embedding_executor_.get(), embeddings)) { |
| TC3_LOG(ERROR) << "Could not run token feature extractor."; |
| return false; |
| } |
| } |
| // Add padding. |
| for (int k = tokens[i].size(); k < *max_num_tokens_per_message; k++) { |
| embeddings->insert(embeddings->end(), embedded_padding_token_.begin(), |
| embedded_padding_token_.end()); |
| } |
| } |
| |
| return true; |
| } |
| |
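| // Embeds the conversation as a single flattened token sequence, adding |
| // start/end-of-message tokens and respecting the total token limits. |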
| bool ActionsSuggestions::EmbedAndFlattenTokens( |
| const std::vector<std::vector<Token>>& tokens, |
| std::vector<float>* embeddings, int* total_token_count) const { |
| const int num_messages = tokens.size(); |
| int start_message = 0; |
| int message_token_offset = 0; |
| |
| // If a maximum model input length is specified, we need to check how |
| // much we need to trim at the start. |
| const int max_num_total_tokens = |
| model_->feature_processor_options()->max_num_total_tokens(); |
| if (max_num_total_tokens > 0) { |
| int total_tokens = 0; |
| start_message = num_messages - 1; |
| for (; start_message >= 0; start_message--) { |
| // Tokens of the message + start and end token. |
| const int num_message_tokens = tokens[start_message].size() + 2; |
| total_tokens += num_message_tokens; |
| |
| // Check whether we exhausted the budget. |
| if (total_tokens >= max_num_total_tokens) { |
| message_token_offset = total_tokens - max_num_total_tokens; |
| break; |
| } |
| } |
| } |
| |
| // Add embeddings. |
| *total_token_count = 0; |
| for (int i = start_message; i < num_messages; i++) { |
| if (message_token_offset == 0) { |
| ++(*total_token_count); |
| // Add `start message` token. |
| embeddings->insert(embeddings->end(), embedded_start_token_.begin(), |
| embedded_start_token_.end()); |
| } |
| |
| for (int pos = std::max(0, message_token_offset - 1); |
| pos < tokens[i].size(); pos++) { |
| ++(*total_token_count); |
| if (!feature_processor_->AppendTokenFeatures( |
| tokens[i][pos], embedding_executor_.get(), embeddings)) { |
| TC3_LOG(ERROR) << "Could not run token feature extractor."; |
| return false; |
| } |
| } |
| |
| // Add `end message` token. |
| ++(*total_token_count); |
| embeddings->insert(embeddings->end(), embedded_end_token_.begin(), |
| embedded_end_token_.end()); |
| |
| // Reset for the subsequent messages. |
| message_token_offset = 0; |
| } |
| |
| // Add optional padding. |
| const int min_num_total_tokens = |
| model_->feature_processor_options()->min_num_total_tokens(); |
| for (; *total_token_count < min_num_total_tokens; ++(*total_token_count)) { |
| embeddings->insert(embeddings->end(), embedded_padding_token_.begin(), |
| embedded_padding_token_.end()); |
| } |
| |
| return true; |
| } |
| |
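| // Resizes the model input tensors to the conversation dimensions (if the |
| // model supports resizing) and allocates them. |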
| bool ActionsSuggestions::AllocateInput(const int conversation_length, |
| const int max_tokens, |
| const int total_token_count, |
| tflite::Interpreter* interpreter) const { |
| if (model_->tflite_model_spec()->resize_inputs()) { |
| if (model_->tflite_model_spec()->input_context() >= 0) { |
| interpreter->ResizeInputTensor( |
| interpreter->inputs()[model_->tflite_model_spec()->input_context()], |
| {1, conversation_length}); |
| } |
| if (model_->tflite_model_spec()->input_user_id() >= 0) { |
| interpreter->ResizeInputTensor( |
| interpreter->inputs()[model_->tflite_model_spec()->input_user_id()], |
| {1, conversation_length}); |
| } |
| if (model_->tflite_model_spec()->input_time_diffs() >= 0) { |
| interpreter->ResizeInputTensor( |
| interpreter |
| ->inputs()[model_->tflite_model_spec()->input_time_diffs()], |
| {1, conversation_length}); |
| } |
| if (model_->tflite_model_spec()->input_num_tokens() >= 0) { |
| interpreter->ResizeInputTensor( |
| interpreter |
| ->inputs()[model_->tflite_model_spec()->input_num_tokens()], |
| {conversation_length, 1}); |
| } |
| if (model_->tflite_model_spec()->input_token_embeddings() >= 0) { |
| interpreter->ResizeInputTensor( |
| interpreter |
| ->inputs()[model_->tflite_model_spec()->input_token_embeddings()], |
| {conversation_length, max_tokens, token_embedding_size_}); |
| } |
| if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) { |
| interpreter->ResizeInputTensor( |
| interpreter->inputs()[model_->tflite_model_spec() |
| ->input_flattened_token_embeddings()], |
| {1, total_token_count}); |
| } |
| } |
| |
| return interpreter->AllocateTensors() == kTfLiteOk; |
| } |
| |
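| // Fills all configured model inputs from the conversation context, user ids, |
| // time differences and additional per-call model parameters. |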
| bool ActionsSuggestions::SetupModelInput( |
| const std::vector<std::string>& context, const std::vector<int>& user_ids, |
| const std::vector<float>& time_diffs, const int num_suggestions, |
| const ActionSuggestionOptions& options, |
| tflite::Interpreter* interpreter) const { |
| // Compute token embeddings. |
| std::vector<std::vector<Token>> tokens; |
| std::vector<float> token_embeddings; |
| std::vector<float> flattened_token_embeddings; |
| int max_tokens = 0; |
| int total_token_count = 0; |
| if (model_->tflite_model_spec()->input_num_tokens() >= 0 || |
| model_->tflite_model_spec()->input_token_embeddings() >= 0 || |
| model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) { |
| if (feature_processor_ == nullptr) { |
| TC3_LOG(ERROR) << "No feature processor specified."; |
| return false; |
| } |
| |
| // Tokenize the messages in the conversation. |
| tokens = Tokenize(context); |
| if (model_->tflite_model_spec()->input_token_embeddings() >= 0) { |
| if (!EmbedTokensPerMessage(tokens, &token_embeddings, &max_tokens)) { |
| TC3_LOG(ERROR) << "Could not extract token features."; |
| return false; |
| } |
| } |
| if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) { |
| if (!EmbedAndFlattenTokens(tokens, &flattened_token_embeddings, |
| &total_token_count)) { |
| TC3_LOG(ERROR) << "Could not extract token features."; |
| return false; |
| } |
| } |
| } |
| |
| if (!AllocateInput(context.size(), max_tokens, total_token_count, |
| interpreter)) { |
| TC3_LOG(ERROR) << "TensorFlow Lite model allocation failed."; |
| return false; |
| } |
| if (model_->tflite_model_spec()->input_context() >= 0) { |
| if (model_->tflite_model_spec()->input_length_to_pad() > 0) { |
| model_executor_->SetInput<std::string>( |
| model_->tflite_model_spec()->input_context(), |
| PadOrTruncateToTargetLength( |
| context, model_->tflite_model_spec()->input_length_to_pad(), |
| std::string("")), |
| interpreter); |
| } else { |
| model_executor_->SetInput<std::string>( |
| model_->tflite_model_spec()->input_context(), context, interpreter); |
| } |
| } |
| if (model_->tflite_model_spec()->input_context_length() >= 0) { |
| model_executor_->SetInput<int>( |
| model_->tflite_model_spec()->input_context_length(), context.size(), |
| interpreter); |
| } |
| if (model_->tflite_model_spec()->input_user_id() >= 0) { |
| if (model_->tflite_model_spec()->input_length_to_pad() > 0) { |
| model_executor_->SetInput<int>( |
| model_->tflite_model_spec()->input_user_id(), |
| PadOrTruncateToTargetLength( |
| user_ids, model_->tflite_model_spec()->input_length_to_pad(), 0), |
| interpreter); |
| } else { |
| model_executor_->SetInput<int>( |
| model_->tflite_model_spec()->input_user_id(), user_ids, interpreter); |
| } |
| } |
| if (model_->tflite_model_spec()->input_num_suggestions() >= 0) { |
| model_executor_->SetInput<int>( |
| model_->tflite_model_spec()->input_num_suggestions(), num_suggestions, |
| interpreter); |
| } |
| if (model_->tflite_model_spec()->input_time_diffs() >= 0) { |
| model_executor_->SetInput<float>( |
| model_->tflite_model_spec()->input_time_diffs(), time_diffs, |
| interpreter); |
| } |
| if (model_->tflite_model_spec()->input_num_tokens() >= 0) { |
| std::vector<int> num_tokens_per_message(tokens.size()); |
| for (int i = 0; i < tokens.size(); i++) { |
| num_tokens_per_message[i] = tokens[i].size(); |
| } |
| model_executor_->SetInput<int>( |
| model_->tflite_model_spec()->input_num_tokens(), num_tokens_per_message, |
| interpreter); |
| } |
| if (model_->tflite_model_spec()->input_token_embeddings() >= 0) { |
| model_executor_->SetInput<float>( |
| model_->tflite_model_spec()->input_token_embeddings(), token_embeddings, |
| interpreter); |
| } |
| if (model_->tflite_model_spec()->input_flattened_token_embeddings() >= 0) { |
| model_executor_->SetInput<float>( |
| model_->tflite_model_spec()->input_flattened_token_embeddings(), |
| flattened_token_embeddings, interpreter); |
| } |
| // Set up additional input parameters. |
| if (const auto* input_name_index = |
| model_->tflite_model_spec()->input_name_index()) { |
| const std::unordered_map<std::string, Variant>& model_parameters = |
| options.model_parameters; |
| for (const TensorflowLiteModelSpec_::InputNameIndexEntry* entry : |
| *input_name_index) { |
| const std::string param_name = entry->key()->str(); |
| const int param_index = entry->value(); |
| const TfLiteType param_type = |
| interpreter->tensor(interpreter->inputs()[param_index])->type; |
| const auto param_value_it = model_parameters.find(param_name); |
| const bool has_value = param_value_it != model_parameters.end(); |
| switch (param_type) { |
| case kTfLiteFloat32: |
| if (has_value) { |
| SetVectorOrScalarAsModelInput<float>(param_index, |
| param_value_it->second, |
| interpreter, model_executor_); |
| } else { |
| model_executor_->SetInput<float>(param_index, kDefaultFloat, |
| interpreter); |
| } |
| break; |
| case kTfLiteInt32: |
| if (has_value) { |
| SetVectorOrScalarAsModelInput<int32_t>( |
| param_index, param_value_it->second, interpreter, |
| model_executor_); |
| } else { |
| model_executor_->SetInput<int32_t>(param_index, kDefaultInt, |
| interpreter); |
| } |
| break; |
| case kTfLiteInt64: |
| model_executor_->SetInput<int64_t>( |
| param_index, |
| has_value ? param_value_it->second.Value<int64>() : kDefaultInt, |
| interpreter); |
| break; |
| case kTfLiteUInt8: |
| model_executor_->SetInput<uint8_t>( |
| param_index, |
| has_value ? param_value_it->second.Value<uint8>() : kDefaultInt, |
| interpreter); |
| break; |
| case kTfLiteInt8: |
| model_executor_->SetInput<int8_t>( |
| param_index, |
| has_value ? param_value_it->second.Value<int8>() : kDefaultInt, |
| interpreter); |
| break; |
| case kTfLiteBool: |
| model_executor_->SetInput<bool>( |
| param_index, |
| has_value ? param_value_it->second.Value<bool>() : kDefaultBool, |
| interpreter); |
| break; |
| default: |
| TC3_LOG(ERROR) << "Unsupported type of additional input parameter: " |
| << param_name; |
| } |
| } |
| } |
| return true; |
| } |
| |
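| // Reads text replies and their scores from the interpreter output and |
| // appends them to the response, applying blocklist and concept mappings. |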
| void ActionsSuggestions::PopulateTextReplies( |
| const tflite::Interpreter* interpreter, int suggestion_index, |
| int score_index, const std::string& type, float priority_score, |
| const absl::flat_hash_set<std::string>& blocklist, |
| const absl::flat_hash_map<std::string, std::vector<std::string>>& |
| concept_mappings, |
| ActionsSuggestionsResponse* response) const { |
| const std::vector<tflite::StringRef> replies = |
| model_executor_->Output<tflite::StringRef>(suggestion_index, interpreter); |
| const TensorView<float> scores = |
| model_executor_->OutputView<float>(score_index, interpreter); |
| |
| for (int i = 0; i < replies.size(); i++) { |
| if (replies[i].len == 0) { |
| continue; |
| } |
| const float score = scores.data()[i]; |
| if (score < preconditions_.min_reply_score_threshold) { |
| continue; |
| } |
| std::string response_text(replies[i].str, replies[i].len); |
| if (blocklist.contains(response_text)) { |
| continue; |
| } |
| if (concept_mappings.contains(response_text)) { |
| const int candidates_size = concept_mappings.at(response_text).size(); |
| // Pick a random candidate index in [0, candidates_size). |
| const int candidate_index = absl::Uniform<int>( |
| absl::IntervalClosedOpen, bit_gen_, 0, candidates_size); |
| response_text = concept_mappings.at(response_text)[candidate_index]; |
| } |
| |
| response->actions.push_back({response_text, type, score, priority_score}); |
| } |
| } |
| |
| void ActionsSuggestions::FillSuggestionFromSpecWithEntityData( |
| const ActionSuggestionSpec* spec, ActionSuggestion* suggestion) const { |
| std::unique_ptr<MutableFlatbuffer> entity_data = |
| entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot() |
| : nullptr; |
| FillSuggestionFromSpec(spec, entity_data.get(), suggestion); |
| } |
| |
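| // Reads a binary intent-triggering prediction from the interpreter output |
| // and, if it triggered, creates an action from the task spec. |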
| void ActionsSuggestions::PopulateIntentTriggering( |
| const tflite::Interpreter* interpreter, int suggestion_index, |
| int score_index, const ActionSuggestionSpec* task_spec, |
| ActionsSuggestionsResponse* response) const { |
| if (!task_spec || task_spec->type()->size() == 0) { |
| TC3_LOG(ERROR) |
| << "Task type for intent (action) triggering cannot be empty!"; |
| return; |
| } |
| const TensorView<bool> intent_prediction = |
| model_executor_->OutputView<bool>(suggestion_index, interpreter); |
| const TensorView<float> intent_scores = |
| model_executor_->OutputView<float>(score_index, interpreter); |
| // Two results corresponding to the binary triggering case. |
| TC3_CHECK_EQ(intent_prediction.size(), 2); |
| TC3_CHECK_EQ(intent_scores.size(), 2); |
| // We rely on the in-graph thresholding logic, so at this point the results |
| // have already been ranked according to the threshold. |
| const bool triggering = intent_prediction.data()[0]; |
| const float trigger_score = intent_scores.data()[0]; |
| |
| if (triggering) { |
| ActionSuggestion suggestion; |
| std::unique_ptr<MutableFlatbuffer> entity_data = |
| entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot() |
| : nullptr; |
| FillSuggestionFromSpecWithEntityData(task_spec, &suggestion); |
| suggestion.score = trigger_score; |
| response->actions.push_back(std::move(suggestion)); |
| } |
| } |
| |
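| // Reads all configured model outputs (triggering and sensitivity scores, |
| // smart replies, actions and multi-task predictions) into the response. |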
| bool ActionsSuggestions::ReadModelOutput( |
| tflite::Interpreter* interpreter, const ActionSuggestionOptions& options, |
| ActionsSuggestionsResponse* response) const { |
| // Read sensitivity and triggering score predictions. |
| if (model_->tflite_model_spec()->output_triggering_score() >= 0) { |
| const TensorView<float> triggering_score = |
| model_executor_->OutputView<float>( |
| model_->tflite_model_spec()->output_triggering_score(), |
| interpreter); |
| if (!triggering_score.is_valid() || triggering_score.size() == 0) { |
| TC3_LOG(ERROR) << "Could not compute triggering score."; |
| return false; |
| } |
| response->triggering_score = triggering_score.data()[0]; |
| response->output_filtered_min_triggering_score = |
| (response->triggering_score < |
| preconditions_.min_smart_reply_triggering_score); |
| } |
| if (model_->tflite_model_spec()->output_sensitive_topic_score() >= 0) { |
| const TensorView<float> sensitive_topic_score = |
| model_executor_->OutputView<float>( |
| model_->tflite_model_spec()->output_sensitive_topic_score(), |
| interpreter); |
| if (!sensitive_topic_score.is_valid() || |
| sensitive_topic_score.dim(0) != 1) { |
| TC3_LOG(ERROR) << "Could not compute sensitive topic score."; |
| return false; |
| } |
| response->sensitivity_score = sensitive_topic_score.data()[0]; |
| response->is_sensitive = (response->sensitivity_score > |
| preconditions_.max_sensitive_topic_score); |
| } |
| |
| // Suppress model outputs. |
| if (response->is_sensitive) { |
| return true; |
| } |
| |
| // Read smart reply predictions. |
| if (!response->output_filtered_min_triggering_score && |
| model_->tflite_model_spec()->output_replies() >= 0) { |
| absl::flat_hash_set<std::string> empty_blocklist; |
| PopulateTextReplies( |
| interpreter, model_->tflite_model_spec()->output_replies(), |
| model_->tflite_model_spec()->output_replies_scores(), |
| model_->smart_reply_action_type()->str(), |
| /* priority_score */ 0.0, empty_blocklist, {}, response); |
| } |
| |
| // Read actions suggestions. |
| if (model_->tflite_model_spec()->output_actions_scores() >= 0) { |
| const TensorView<float> actions_scores = model_executor_->OutputView<float>( |
| model_->tflite_model_spec()->output_actions_scores(), interpreter); |
| for (int i = 0; i < model_->action_type()->size(); i++) { |
| const ActionTypeOptions* action_type = model_->action_type()->Get(i); |
| // Skip disabled action classes, such as the default other category. |
| if (!action_type->enabled()) { |
| continue; |
| } |
| const float score = actions_scores.data()[i]; |
| if (score < action_type->min_triggering_score()) { |
| continue; |
| } |
| |
| // Create action from model output. |
| ActionSuggestion suggestion; |
| suggestion.type = action_type->name()->str(); |
| std::unique_ptr<MutableFlatbuffer> entity_data = |
| entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot() |
| : nullptr; |
| FillSuggestionFromSpecWithEntityData(action_type->action(), &suggestion); |
| suggestion.score = score; |
| response->actions.push_back(std::move(suggestion)); |
| } |
| } |
| |
| // Read multi-task predictions and construct the result properly. |
| if (const auto* prediction_metadata = |
| model_->tflite_model_spec()->prediction_metadata()) { |
| for (const PredictionMetadata* metadata : *prediction_metadata) { |
| const ActionSuggestionSpec* task_spec = metadata->task_spec(); |
| const int suggestions_index = metadata->output_suggestions(); |
| const int suggestions_scores_index = |
| metadata->output_suggestions_scores(); |
| absl::flat_hash_set<std::string> response_text_blocklist; |
| absl::flat_hash_map<std::string, std::vector<std::string>> |
| concept_mappings; |
| switch (metadata->prediction_type()) { |
| case PredictionType_NEXT_MESSAGE_PREDICTION: |
| if (!task_spec || task_spec->type()->size() == 0) { |
| TC3_LOG(WARNING) << "Task type not provided, use default " |
| "smart_reply_action_type!"; |
| } |
| if (task_spec) { |
| if (task_spec->response_text_blocklist()) { |
| for (const auto& val : *task_spec->response_text_blocklist()) { |
| response_text_blocklist.insert(val->str()); |
| } |
| } |
| if (task_spec->concept_mappings()) { |
| for (const auto& concept : *task_spec->concept_mappings()) { |
| std::vector<std::string> candidates; |
| for (const auto& candidate : *concept->candidates()) { |
| candidates.push_back(candidate->str()); |
| } |
| concept_mappings[concept->concept_name()->str()] = candidates; |
| } |
| } |
| } |
| PopulateTextReplies( |
| interpreter, suggestions_index, suggestions_scores_index, |
| task_spec ? task_spec->type()->str() |
| : model_->smart_reply_action_type()->str(), |
| task_spec ? task_spec->priority_score() : 0.0, |
| response_text_blocklist, concept_mappings, response); |
| break; |
| case PredictionType_INTENT_TRIGGERING: |
| PopulateIntentTriggering(interpreter, suggestions_index, |
| suggestions_scores_index, task_spec, |
| response); |
| break; |
| default: |
| TC3_LOG(ERROR) << "Unsupported prediction type!"; |
| return false; |
| } |
| } |
| } |
| |
| return true; |
| } |
| |
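| // Runs the sensitive-text check and the TensorFlow Lite model over the last |
| // `num_messages` messages of the conversation and fills the response. |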
| bool ActionsSuggestions::SuggestActionsFromModel( |
| const Conversation& conversation, const int num_messages, |
| const ActionSuggestionOptions& options, |
| ActionsSuggestionsResponse* response, |
| std::unique_ptr<tflite::Interpreter>* interpreter) const { |
| TC3_CHECK_LE(num_messages, conversation.messages.size()); |
| |
| if (sensitive_model_ != nullptr && |
| sensitive_model_->EvalConversation(conversation, num_messages).first) { |
| response->is_sensitive = true; |
| return true; |
| } |
| |
| if (!model_executor_) { |
| return true; |
| } |
| *interpreter = model_executor_->CreateInterpreter(); |
| |
| if (!*interpreter) { |
| TC3_LOG(ERROR) << "Could not build TensorFlow Lite interpreter for the " |
| "actions suggestions model."; |
| return false; |
| } |
| |
| std::vector<std::string> context; |
| std::vector<int> user_ids; |
| std::vector<float> time_diffs; |
| context.reserve(num_messages); |
| user_ids.reserve(num_messages); |
| time_diffs.reserve(num_messages); |
| |
| // Gather last `num_messages` messages from the conversation. |
| int64 last_message_reference_time_ms_utc = 0; |
| const float second_in_ms = 1000; |
| for (int i = conversation.messages.size() - num_messages; |
| i < conversation.messages.size(); i++) { |
| const ConversationMessage& message = conversation.messages[i]; |
| context.push_back(message.text); |
| user_ids.push_back(message.user_id); |
| |
| float time_diff_secs = 0; |
| if (message.reference_time_ms_utc != 0 && |
| last_message_reference_time_ms_utc != 0) { |
| time_diff_secs = std::max(0.0f, (message.reference_time_ms_utc - |
| last_message_reference_time_ms_utc) / |
| second_in_ms); |
| } |
| if (message.reference_time_ms_utc != 0) { |
| last_message_reference_time_ms_utc = message.reference_time_ms_utc; |
| } |
| time_diffs.push_back(time_diff_secs); |
| } |
| |
| if (!SetupModelInput(context, user_ids, time_diffs, |
| /*num_suggestions=*/model_->num_smart_replies(), options, |
| interpreter->get())) { |
| TC3_LOG(ERROR) << "Failed to setup input for TensorFlow Lite model."; |
| return false; |
| } |
| |
| if ((*interpreter)->Invoke() != kTfLiteOk) { |
| TC3_LOG(ERROR) << "Failed to invoke TensorFlow Lite interpreter."; |
| return false; |
| } |
| |
| return ReadModelOutput(interpreter->get(), options, response); |
| } |
| |
| Status ActionsSuggestions::SuggestActionsFromConversationIntentDetection( |
| const Conversation& conversation, const ActionSuggestionOptions& options, |
| std::vector<ActionSuggestion>* actions) const { |
| TC3_ASSIGN_OR_RETURN( |
| std::vector<ActionSuggestion> new_actions, |
| conversation_intent_detection_->SuggestActions(conversation, options)); |
| for (auto& action : new_actions) { |
| actions->push_back(std::move(action)); |
| } |
| return Status::OK; |
| } |
| |
| AnnotationOptions ActionsSuggestions::AnnotationOptionsForMessage( |
| const ConversationMessage& message) const { |
| AnnotationOptions options; |
| options.detected_text_language_tags = message.detected_text_language_tags; |
| options.reference_time_ms_utc = message.reference_time_ms_utc; |
| options.reference_timezone = message.reference_timezone; |
| options.annotation_usecase = |
| model_->annotation_actions_spec()->annotation_usecase(); |
| options.is_serialized_entity_data_enabled = |
| model_->annotation_actions_spec()->is_serialized_entity_data_enabled(); |
| options.entity_types = annotation_entity_types_; |
| return options; |
| } |
| |
| // Run annotator on the messages of a conversation. |
| Conversation ActionsSuggestions::AnnotateConversation( |
| const Conversation& conversation, const Annotator* annotator) const { |
| if (annotator == nullptr) { |
| return conversation; |
| } |
| const int num_messages_grammar = |
| ((model_->rules() && model_->rules()->grammar_rules() && |
| model_->rules() |
| ->grammar_rules() |
| ->rules() |
| ->nonterminals() |
| ->annotation_nt()) |
| ? 1 |
| : 0); |
| const int num_messages_mapping = |
| (model_->annotation_actions_spec() |
| ? std::max(model_->annotation_actions_spec() |
| ->max_history_from_any_person(), |
| model_->annotation_actions_spec() |
| ->max_history_from_last_person()) |
| : 0); |
| const int num_messages = std::max(num_messages_grammar, num_messages_mapping); |
| if (num_messages == 0) { |
| // No annotations are used. |
| return conversation; |
| } |
| Conversation annotated_conversation = conversation; |
| for (int i = 0, message_index = annotated_conversation.messages.size() - 1; |
| i < num_messages && message_index >= 0; i++, message_index--) { |
| ConversationMessage* message = |
| &annotated_conversation.messages[message_index]; |
| if (message->annotations.empty()) { |
| message->annotations = annotator->Annotate( |
| message->text, AnnotationOptionsForMessage(*message)); |
| ConvertDatetimeToTime(&message->annotations); |
| } |
| } |
| return annotated_conversation; |
| } |
| |
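| // Creates actions from the annotations of recent messages according to the |
| // annotation mapping specified in the model. |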
| void ActionsSuggestions::SuggestActionsFromAnnotations( |
| const Conversation& conversation, |
| std::vector<ActionSuggestion>* actions) const { |
| if (model_->annotation_actions_spec() == nullptr || |
| model_->annotation_actions_spec()->annotation_mapping() == nullptr || |
| model_->annotation_actions_spec()->annotation_mapping()->size() == 0) { |
| return; |
| } |
| |
| // Create actions based on the annotations. |
| const int max_from_any_person = |
| model_->annotation_actions_spec()->max_history_from_any_person(); |
| const int max_from_last_person = |
| model_->annotation_actions_spec()->max_history_from_last_person(); |
| const int last_person = conversation.messages.back().user_id; |
| |
| int num_messages_last_person = 0; |
| int num_messages_any_person = 0; |
| bool all_from_last_person = true; |
| for (int message_index = conversation.messages.size() - 1; message_index >= 0; |
| message_index--) { |
| const ConversationMessage& message = conversation.messages[message_index]; |
| std::vector<AnnotatedSpan> annotations = message.annotations; |
| |
| // Update how many messages we have processed from the last person in the |
| // conversation and from any person in the conversation. |
| num_messages_any_person++; |
| if (all_from_last_person && message.user_id == last_person) { |
| num_messages_last_person++; |
| } else { |
| all_from_last_person = false; |
| } |
| |
| if (num_messages_any_person > max_from_any_person && |
| (!all_from_last_person || |
| num_messages_last_person > max_from_last_person)) { |
| break; |
| } |
| |
| if (message.user_id == kLocalUserId) { |
| if (model_->annotation_actions_spec()->only_until_last_sent()) { |
| break; |
| } |
| if (!model_->annotation_actions_spec()->include_local_user_messages()) { |
| continue; |
| } |
| } |
| |
| std::vector<ActionSuggestionAnnotation> action_annotations; |
| action_annotations.reserve(annotations.size()); |
| for (const AnnotatedSpan& annotation : annotations) { |
| if (annotation.classification.empty()) { |
| continue; |
| } |
| |
| const ClassificationResult& classification_result = |
| annotation.classification[0]; |
| |
| ActionSuggestionAnnotation action_annotation; |
| action_annotation.span = { |
| message_index, annotation.span, |
| UTF8ToUnicodeText(message.text, /*do_copy=*/false) |
| .UTF8Substring(annotation.span.first, annotation.span.second)}; |
| action_annotation.entity = classification_result; |
| action_annotation.name = classification_result.collection; |
| action_annotations.push_back(std::move(action_annotation)); |
| } |
| |
| if (model_->annotation_actions_spec()->deduplicate_annotations()) { |
| // Create actions only for deduplicated annotations. |
| for (const int annotation_id : |
| DeduplicateAnnotations(action_annotations)) { |
| SuggestActionsFromAnnotation( |
| message_index, action_annotations[annotation_id], actions); |
| } |
| } else { |
| // Create actions for all annotations. |
| for (const ActionSuggestionAnnotation& annotation : action_annotations) { |
| SuggestActionsFromAnnotation(message_index, annotation, actions); |
| } |
| } |
| } |
| } |
| |
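| // Creates actions for a single annotation based on the matching annotation |
| // mappings of the model. |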
| void ActionsSuggestions::SuggestActionsFromAnnotation( |
| const int message_index, const ActionSuggestionAnnotation& annotation, |
| std::vector<ActionSuggestion>* actions) const { |
| for (const AnnotationActionsSpec_::AnnotationMapping* mapping : |
| *model_->annotation_actions_spec()->annotation_mapping()) { |
| if (annotation.entity.collection == |
| mapping->annotation_collection()->str()) { |
| if (annotation.entity.score < mapping->min_annotation_score()) { |
| continue; |
| } |
| |
| std::unique_ptr<MutableFlatbuffer> entity_data = |
| entity_data_builder_ != nullptr ? entity_data_builder_->NewRoot() |
| : nullptr; |
| |
| // Set annotation text as (additional) entity data field. |
| if (mapping->entity_field() != nullptr) { |
| TC3_CHECK_NE(entity_data, nullptr); |
| |
| UnicodeText normalized_annotation_text = |
| UTF8ToUnicodeText(annotation.span.text, /*do_copy=*/false); |
| |
| // Apply normalization if specified. |
| if (mapping->normalization_options() != nullptr) { |
| normalized_annotation_text = |
| NormalizeText(*unilib_, mapping->normalization_options(), |
| normalized_annotation_text); |
| } |
| |
| entity_data->ParseAndSet(mapping->entity_field(), |
| normalized_annotation_text.ToUTF8String()); |
| } |
| |
| ActionSuggestion suggestion; |
| FillSuggestionFromSpec(mapping->action(), entity_data.get(), &suggestion); |
| if (mapping->use_annotation_score()) { |
| suggestion.score = annotation.entity.score; |
| } |
| suggestion.annotations = {annotation}; |
| actions->push_back(std::move(suggestion)); |
| } |
| } |
| } |
| |
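| // Returns the indices of annotations that are unique by (name, span text), |
| // keeping only the highest-scoring annotation per key. |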
| std::vector<int> ActionsSuggestions::DeduplicateAnnotations( |
| const std::vector<ActionSuggestionAnnotation>& annotations) const { |
| std::map<std::pair<std::string, std::string>, int> deduplicated_annotations; |
| |
| for (int i = 0; i < annotations.size(); i++) { |
| const std::pair<std::string, std::string> key = {annotations[i].name, |
| annotations[i].span.text}; |
| auto entry = deduplicated_annotations.find(key); |
| if (entry != deduplicated_annotations.end()) { |
| // Keep the annotation with the higher score. |
| if (annotations[entry->second].entity.score < |
| annotations[i].entity.score) { |
| entry->second = i; |
| } |
| continue; |
| } |
| deduplicated_annotations.insert(entry, {key, i}); |
| } |
| |
| std::vector<int> result; |
| result.reserve(deduplicated_annotations.size()); |
| for (const auto& key_and_annotation : deduplicated_annotations) { |
| result.push_back(key_and_annotation.second); |
| } |
| return result; |
| } |
| |
| #if !defined(TC3_DISABLE_LUA) |
| bool ActionsSuggestions::SuggestActionsFromLua( |
| const Conversation& conversation, const TfLiteModelExecutor* model_executor, |
| const tflite::Interpreter* interpreter, |
| const reflection::Schema* annotation_entity_data_schema, |
| std::vector<ActionSuggestion>* actions) const { |
| if (lua_bytecode_.empty()) { |
| return true; |
| } |
| |
| auto lua_actions = LuaActionsSuggestions::CreateLuaActionsSuggestions( |
| lua_bytecode_, conversation, model_executor, model_->tflite_model_spec(), |
| interpreter, entity_data_schema_, annotation_entity_data_schema); |
| if (lua_actions == nullptr) { |
| TC3_LOG(ERROR) << "Could not create lua actions."; |
| return false; |
| } |
| return lua_actions->SuggestActions(actions); |
| } |
| #else |
| bool ActionsSuggestions::SuggestActionsFromLua( |
| const Conversation& conversation, const TfLiteModelExecutor* model_executor, |
| const tflite::Interpreter* interpreter, |
| const reflection::Schema* annotation_entity_data_schema, |
| std::vector<ActionSuggestion>* actions) const { |
| return true; |
| } |
| #endif |
| |
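| // Runs all suggestion sources (annotations, grammar rules, TensorFlow Lite |
| // model, conversation intent detection, Lua script, regex rules) and applies |
| // the triggering preconditions. |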
| bool ActionsSuggestions::GatherActionsSuggestions( |
| const Conversation& conversation, const Annotator* annotator, |
| const ActionSuggestionOptions& options, |
| ActionsSuggestionsResponse* response) const { |
| if (conversation.messages.empty()) { |
| return true; |
| } |
| |
| // Run annotator against messages. |
| const Conversation annotated_conversation = |
| AnnotateConversation(conversation, annotator); |
| |
| const int num_messages = NumMessagesToConsider( |
| annotated_conversation, model_->max_conversation_history_length()); |
| |
| if (num_messages <= 0) { |
| TC3_LOG(INFO) << "No messages provided for actions suggestions."; |
| return false; |
| } |
| |
| SuggestActionsFromAnnotations(annotated_conversation, &response->actions); |
| |
| if (grammar_actions_ != nullptr && |
| !grammar_actions_->SuggestActions(annotated_conversation, |
| &response->actions)) { |
| TC3_LOG(ERROR) << "Could not suggest actions from grammar rules."; |
| return false; |
| } |
| |
| int input_text_length = 0; |
| int num_matching_locales = 0; |
| for (int i = annotated_conversation.messages.size() - num_messages; |
| i < annotated_conversation.messages.size(); i++) { |
| input_text_length += annotated_conversation.messages[i].text.length(); |
| std::vector<Locale> message_languages; |
| if (!ParseLocales( |
| annotated_conversation.messages[i].detected_text_language_tags, |
| &message_languages)) { |
| continue; |
| } |
| if (Locale::IsAnyLocaleSupported( |
| message_languages, locales_, |
| preconditions_.handle_unknown_locale_as_supported)) { |
| ++num_matching_locales; |
| } |
| } |
| |
| // Bail out if we are provided with too little or too much input. |
| if (input_text_length < preconditions_.min_input_length || |
| (preconditions_.max_input_length >= 0 && |
| input_text_length > preconditions_.max_input_length)) { |
| TC3_LOG(INFO) << "Too much or not enough input for inference."; |
| return true; |
| } |
| |
| // Bail out if the text does not look like it can be handled by the model. |
| const float matching_fraction = |
| static_cast<float>(num_matching_locales) / num_messages; |
| if (matching_fraction < preconditions_.min_locale_match_fraction) { |
| TC3_LOG(INFO) << "Not enough locale matches."; |
| response->output_filtered_locale_mismatch = true; |
| return true; |
| } |
| |
| std::vector<const UniLib::RegexPattern*> post_check_rules; |
| if (preconditions_.suppress_on_low_confidence_input) { |
| if (regex_actions_->IsLowConfidenceInput(annotated_conversation, |
| num_messages, &post_check_rules)) { |
| response->output_filtered_low_confidence = true; |
| return true; |
| } |
| } |
| |
| std::unique_ptr<tflite::Interpreter> interpreter; |
| if (!SuggestActionsFromModel(annotated_conversation, num_messages, options, |
| response, &interpreter)) { |
| TC3_LOG(ERROR) << "Could not run model."; |
| return false; |
| } |
| |
| // SuggestActionsFromModel also detects whether the conversation is |
| // sensitive, using either the ngram model or the TFLite sensitive model. |
| // Suppress all predictions if the conversation was deemed sensitive. |
| if (preconditions_.suppress_on_sensitive_topic && response->is_sensitive) { |
| return true; |
| } |
| |
| if (conversation_intent_detection_) { |
| // TODO(zbin): Ensure the deduplication/ranking logic in ranker.cc works. |
| auto actions = SuggestActionsFromConversationIntentDetection( |
| annotated_conversation, options, &response->actions); |
| if (!actions.ok()) { |
| TC3_LOG(ERROR) << "Could not run conversation intent detection: " |
| << actions.error_message(); |
| return false; |
| } |
| } |
| |
| if (!SuggestActionsFromLua( |
| annotated_conversation, model_executor_.get(), interpreter.get(), |
| annotator != nullptr ? annotator->entity_data_schema() : nullptr, |
| &response->actions)) { |
| TC3_LOG(ERROR) << "Could not suggest actions from script."; |
| return false; |
| } |
| |
| if (!regex_actions_->SuggestActions(annotated_conversation, |
| entity_data_builder_.get(), |
| &response->actions)) { |
| TC3_LOG(ERROR) << "Could not suggest actions from regex rules."; |
| return false; |
| } |
| |
| if (preconditions_.suppress_on_low_confidence_input && |
| !regex_actions_->FilterConfidenceOutput(post_check_rules, |
| &response->actions)) { |
| TC3_LOG(ERROR) << "Could not post-check actions."; |
| return false; |
| } |
| |
| return true; |
| } |
| |
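| // Validates the conversation, gathers suggestions from all sources and |
| // ranks them. |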
| ActionsSuggestionsResponse ActionsSuggestions::SuggestActions( |
| const Conversation& conversation, const Annotator* annotator, |
| const ActionSuggestionOptions& options) const { |
| ActionsSuggestionsResponse response; |
| |
| // Check that messages are sorted by reference time, most recent last. |
| for (int i = 1; i < conversation.messages.size(); i++) { |
| if (conversation.messages[i].reference_time_ms_utc < |
| conversation.messages[i - 1].reference_time_ms_utc) { |
| TC3_LOG(ERROR) << "Messages are not sorted most recent last."; |
| return response; |
| } |
| } |
| |
| // Check that messages are not too long and are valid UTF-8. |
| for (const ConversationMessage& message : conversation.messages) { |
| if (message.text.size() > std::numeric_limits<int>::max()) { |
| TC3_LOG(ERROR) << "Rejecting too long input: " << message.text.size(); |
| return {}; |
| } |
| |
| if (!unilib_->IsValidUtf8(UTF8ToUnicodeText( |
| message.text.data(), message.text.size(), /*do_copy=*/false))) { |
| TC3_LOG(ERROR) << "Not valid utf8 provided."; |
| return response; |
| } |
| } |
| |
| if (!GatherActionsSuggestions(conversation, annotator, options, &response)) { |
| TC3_LOG(ERROR) << "Could not gather actions suggestions."; |
| response.actions.clear(); |
| } else if (!ranker_->RankActions(conversation, &response, entity_data_schema_, |
| annotator != nullptr |
| ? annotator->entity_data_schema() |
| : nullptr)) { |
| TC3_LOG(ERROR) << "Could not rank actions."; |
| response.actions.clear(); |
| } |
| return response; |
| } |
| |
| ActionsSuggestionsResponse ActionsSuggestions::SuggestActions( |
| const Conversation& conversation, |
| const ActionSuggestionOptions& options) const { |
| return SuggestActions(conversation, /*annotator=*/nullptr, options); |
| } |
| |
| const ActionsModel* ActionsSuggestions::model() const { return model_; } |
| const reflection::Schema* ActionsSuggestions::entity_data_schema() const { |
| return entity_data_schema_; |
| } |
| |
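| // Returns a view of the actions model stored in the given buffer, or nullptr |
| // if the buffer cannot be verified. |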
| const ActionsModel* ViewActionsModel(const void* buffer, int size) { |
| if (buffer == nullptr) { |
| return nullptr; |
| } |
| return LoadAndVerifyModel(reinterpret_cast<const uint8_t*>(buffer), size); |
| } |
| |
| bool ActionsSuggestions::InitializeConversationIntentDetection( |
| const std::string& serialized_config) { |
| auto conversation_intent_detection = |
| std::make_unique<ConversationIntentDetection>(); |
| if (!conversation_intent_detection->Initialize(serialized_config).ok()) { |
| TC3_LOG(ERROR) << "Failed to initialize conversation intent detection."; |
| return false; |
| } |
| conversation_intent_detection_ = std::move(conversation_intent_detection); |
| return true; |
| } |
| |
| } // namespace libtextclassifier3 |