// blob: fa7807e1c2b4180f6bd77364cac10772cf1c7eaf
/*
 * Copyright (C) 2018 The Android Open Source Project
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <memory>
#include <string>
#include <vector>
#include "actions/actions_model_generated.h"
#include "annotator/annotator.h"
#include "annotator/types.h"
#include "utils/memory/mmap.h"
#include "utils/tflite-model-executor.h"
namespace libtextclassifier3 {
// An entity associated with an action.
struct ActionSuggestionAnnotation {
  // The referenced message.
  // -1 if not referencing a particular message in the provided input.
  int message_index;

  // The span within the reference message.
  // (-1, -1) if not referencing a particular location.
  CodepointSpan span;

  // The entity this annotation carries.
  // Not set by the constructor below; it keeps whatever state
  // ClassificationResult's own default initialization gives it.
  ClassificationResult entity;

  // Optional annotation name.
  std::string name;

  // Creates an annotation that references neither a message nor a location:
  // `message_index` and both ends of `span` are set to kInvalidIndex.
  // NOTE(review): the field comments above describe this sentinel as -1;
  // kInvalidIndex is declared outside this file -- confirm they agree.
  explicit ActionSuggestionAnnotation()
      : message_index(kInvalidIndex), span({kInvalidIndex, kInvalidIndex}) {}
};
// Action suggestion that contains a response text and the type of the response.
struct ActionSuggestion {
  // Text of the action suggestion.
  std::string response_text;

  // Type of the action suggestion.
  std::string type;

  // Score.
  // Has no default member initializer: value-initialize the struct
  // (`ActionSuggestion{}`) or assign this field before reading it.
  float score;

  // The associated annotations.
  std::vector<ActionSuggestionAnnotation> annotations;
};
// Actions suggestions result containing meta-information and the suggested
// actions.
struct ActionsSuggestionsResponse {
  // Starts with both scores at -1 and both filter flags cleared; `actions`
  // starts empty.
  // NOTE(review): -1 presumably means "not computed" -- confirm against the
  // code that fills in the response.
  ActionsSuggestionsResponse()
      : sensitivity_score(-1),
        triggering_score(-1),
        output_filtered_sensitivity(false),
        output_filtered_min_triggering_score(false) {}

  // The sensitivity assessment.
  float sensitivity_score;

  // NOTE(review): presumably the score that is compared against the minimum
  // triggering threshold (see output_filtered_min_triggering_score) -- confirm.
  float triggering_score;

  // Whether the output was suppressed by the sensitivity threshold.
  bool output_filtered_sensitivity;

  // Whether the output was suppressed by the triggering score threshold.
  bool output_filtered_min_triggering_score;

  // The suggested actions.
  std::vector<ActionSuggestion> actions;
};
// Represents a single message in the conversation.
struct ConversationMessage {
  // User ID distinguishing the user from other users in the conversation.
  // Has no default member initializer; set it explicitly or value-initialize
  // the struct.
  int user_id;

  // Text of the message.
  std::string text;

  // Relative time to previous message (seconds, going by the field name).
  // Has no default member initializer; set it explicitly or value-initialize
  // the struct.
  float time_diff_secs;

  // Annotations on the text.
  std::vector<AnnotatedSpan> annotations;

  // Comma-separated list of locale specification for the text in the
  // conversation (BCP 47 tags).
  std::string locales;
};
// Conversation between multiple users.
struct Conversation {
  // Sequence of messages that were exchanged in the conversation.
  std::vector<ConversationMessage> messages;
};
// Options for suggesting actions.
struct ActionSuggestionOptions {
  // Options for annotation of the messages.
  // Initialized from AnnotationOptions::Default().
  AnnotationOptions annotation_options = AnnotationOptions::Default();

  // Returns a default-constructed options object, i.e. one whose
  // `annotation_options` is AnnotationOptions::Default().
  static ActionSuggestionOptions Default() { return ActionSuggestionOptions(); }
};
// Class for predicting actions following a conversation.
class ActionsSuggestions {
 public:
  // Factory functions. Each returns an owning pointer to the new instance.
  // NOTE(review): what is returned when a model cannot be loaded is not
  // visible in this header -- presumably nullptr; confirm in the
  // implementation.

  // Creates an instance from model data at `buffer` spanning `size` bytes
  // (`buffer` is a uint8_t pointer).
  // NOTE(review): going by the name, the buffer is not owned by the instance
  // and presumably has to outlive it -- confirm.
  static std::unique_ptr<ActionsSuggestions> FromUnownedBuffer(
      const uint8_t* buffer, const int size);

  // Takes ownership of the mmap.
  static std::unique_ptr<ActionsSuggestions> FromScopedMmap(
      std::unique_ptr<libtextclassifier3::ScopedMmap> mmap);

  // Creates an instance from the file descriptor `fd`, using the region
  // described by `offset` and `size`.
  // NOTE(review): `offset` and `size` are presumably in bytes -- confirm.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(
      const int fd, const int offset, const int size);

  // Creates an instance from the file descriptor `fd` without an explicit
  // region.
  static std::unique_ptr<ActionsSuggestions> FromFileDescriptor(const int fd);

  // Creates an instance from the model file at `path`.
  static std::unique_ptr<ActionsSuggestions> FromPath(const std::string& path);

  // Suggests actions for `conversation` and returns them together with the
  // meta-information carried by ActionsSuggestionsResponse. `options`
  // defaults to ActionSuggestionOptions::Default().
  ActionsSuggestionsResponse SuggestActions(
      const Conversation& conversation,
      const ActionSuggestionOptions& options =
          ActionSuggestionOptions::Default()) const;

  // Provide an annotator.
  // NOTE(review): the annotator is passed as a raw pointer, so it is
  // presumably not owned and must outlive this object -- confirm.
  void SetAnnotator(const Annotator* annotator);

  // Should be in sync with those defined in Android.
  // android/frameworks/base/core/java/android/view/textclassifier/
  static const std::string& kViewCalendarType;
  static const std::string& kViewMapType;
  static const std::string& kTrackFlightType;
  static const std::string& kOpenUrlType;
  static const std::string& kSendSmsType;
  static const std::string& kCallPhoneType;
  static const std::string& kSendEmailType;

 private:
  // NOTE(review): the access labels were missing from this listing. The
  // boundary is placed before the helpers that deal with the interpreter and
  // the trailing-underscore data members -- confirm nothing below is meant to
  // be public.

  // Checks that model contains all required fields, and initializes internal
  // datastructures.
  bool ValidateAndInitialize();

  // Feeds the conversation-derived inputs (`context`, `user_ids`,
  // `time_diffs`) and `num_suggestions` to `interpreter`.
  // NOTE(review): the three vectors are presumably parallel, one entry per
  // message -- confirm.
  void SetupModelInput(const std::vector<std::string>& context,
                       const std::vector<int>& user_ids,
                       const std::vector<float>& time_diffs,
                       const int num_suggestions,
                       tflite::Interpreter* interpreter) const;

  // Reads the outputs of `interpreter` into `response`.
  void ReadModelOutput(tflite::Interpreter* interpreter,
                       ActionsSuggestionsResponse* response) const;

  // Produces model-based suggestions for `conversation` and writes them to
  // `response`.
  // NOTE(review): `num_messages` presumably limits how many messages of the
  // conversation are considered -- confirm.
  void SuggestActionsFromModel(const Conversation& conversation,
                               const int num_messages,
                               ActionsSuggestionsResponse* response) const;

  // Produces suggestions derived from annotations of `conversation`, using
  // `options`, and writes them to `suggestions`.
  void SuggestActionsFromAnnotations(
      const Conversation& conversation, const ActionSuggestionOptions& options,
      ActionsSuggestionsResponse* suggestions) const;

  // Creates actions from a single `annotation` of the message at
  // `message_index` and writes them to `suggestions`.
  void CreateActionsFromAnnotationResult(
      const int message_index, const AnnotatedSpan& annotation,
      ActionsSuggestionsResponse* suggestions) const;

  // The model being used. Defaults to nullptr so the pointer is never
  // indeterminate, matching annotator_ below.
  const ActionsModel* model_ = nullptr;

  // Memory mapping owned by this object (see FromScopedMmap).
  std::unique_ptr<libtextclassifier3::ScopedMmap> mmap_;

  // Tensorflow Lite models.
  std::unique_ptr<const TfLiteModelExecutor> model_executor_;

  // Annotator.
  const Annotator* annotator_ = nullptr;
};
// Interprets the buffer as a Model flatbuffer and returns it for reading.
// Only the declaration is visible here.
// NOTE(review): `size` is presumably the length of `buffer` in bytes, and the
// returned pointer presumably refers into `buffer` rather than to a copy --
// confirm both against the definition.
const ActionsModel* ViewActionsModel(const void* buffer, int size);
} // namespace libtextclassifier3