blob: 522372ce243b65425c06bdd5f48e854279eb44fc [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
// Inference code for the feed-forward text classification models.
#ifndef LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
#define LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "base.h"
#include "common/embedding-network.h"
#include "common/feature-extractor.h"
#include "common/memory_image/embedding-network-params-from-image.h"
#include "common/mmap.h"
#include "smartselect/feature-processor.h"
#include "smartselect/model-params.h"
#include "smartselect/text-classification-model.pb.h"
#include "smartselect/types.h"
namespace libtextclassifier {
// SmartSelection/Sharing feed-forward model.
//
// Holds two sets of inference state side by side: one for the selection model
// (used by SuggestSelection) and one for the sharing model (used by
// ClassifyText). Each set consists of a feature processor, model parameters,
// an embedding network and a feature-vector function.
class TextClassificationModel {
 public:
  // Loads TextClassificationModel from the file identified by the given int
  // file descriptor.
  explicit TextClassificationModel(int fd);

  // Bit flags for the input selection. The values are distinct bits, so
  // several flags can be combined in a single int (see the `input_flags`
  // parameter of ClassifyText).
  enum SelectionInputFlags { SELECTION_IS_URL = 0x1, SELECTION_IS_EMAIL = 0x2 };

  // Runs inference given a context and the current selection (i.e. the
  // indices of the first and one past the last selected characters, as utf8
  // codepoint offsets). Returns the indices (utf8 codepoint offsets) of the
  // selection beginning character and one past the selection end character.
  // Returns the original click_indices if an error occurs.
  // NOTE: The selection indices are passed in and returned in terms of
  // UTF8 codepoints (not bytes).
  // Requires that the model is a smart selection model.
  CodepointSpan SuggestSelection(const std::string& context,
                                 CodepointSpan click_indices) const;

  // Classifies the selected text given the context string.
  // `click_indices` identifies the selection inside `context`; `input_flags`
  // is an int that defaults to 0 (no flags) and is presumably a combination
  // of SelectionInputFlags values -- confirm against the definition.
  // The result is a list of (string, float) pairs, presumably
  // (collection name, score) -- confirm against the definition.
  // Requires that the model is a smart sharing model.
  // Returns an empty result if an error occurs.
  std::vector<std::pair<std::string, float>> ClassifyText(
      const std::string& context, CodepointSpan click_indices,
      int input_flags = 0) const;

 protected:
  // Removes punctuation from the beginning and end of the selection and
  // returns the new selection span.
  CodepointSpan StripPunctuation(CodepointSpan selection,
                                 const std::string& context) const;

  // During evaluation we need access to the feature processor.
  // Returns a non-owning pointer to the selection feature processor; the
  // pointer is null when selection_feature_processor_ holds nothing.
  FeatureProcessor* SelectionFeatureProcessor() const {
    return selection_feature_processor_.get();
  }

  // Collection name when url hint is accepted.
  const std::string kUrlHintCollection = "url";
  // Collection name when email hint is accepted.
  const std::string kEmailHintCollection = "email";
  // Collection name for other.
  const std::string kOtherCollection = "other";
  // Collection name for phone.
  const std::string kPhoneCollection = "phone";

  // Options protos for the selection and sharing models respectively.
  SelectionModelOptions selection_options_;
  SharingModelOptions sharing_options_;

 private:
  // Loads the models from the given mmap handle. Returns a bool status
  // (presumably true on success -- the definition is outside this header).
  bool LoadModels(const nlp_core::MmapHandle& mmap_handle);

  // Shared inference helper: takes the context, a span, and the feature
  // processor / network / feature-vector function to use, and returns the
  // network output vector. `selection_label_spans` is an output pointer.
  nlp_core::EmbeddingNetwork::Vector InferInternal(
      const std::string& context, CodepointSpan span,
      const FeatureProcessor& feature_processor,
      const nlp_core::EmbeddingNetwork& network,
      const FeatureVectorFn& feature_vector_fn,
      std::vector<CodepointSpan>* selection_label_spans) const;

  // Returns a selection suggestion with a score.
  std::pair<CodepointSpan, float> SuggestSelectionInternal(
      const std::string& context, CodepointSpan click_indices) const;

  // Returns a selection suggestion and makes sure it's symmetric. Internally
  // runs SuggestSelectionInternal several times.
  CodepointSpan SuggestSelectionSymmetrical(const std::string& context,
                                            CodepointSpan click_indices) const;

  // NOTE: the data members below are deliberately kept in their original
  // declaration order, because that order fixes their construction and
  // destruction order.

  // Defaults to false so the flag never holds an indeterminate value,
  // regardless of which path the out-of-line constructor takes.
  bool initialized_ = false;
  nlp_core::ScopedMmap mmap_;

  // State of the selection model.
  std::unique_ptr<ModelParams> selection_params_;
  std::unique_ptr<FeatureProcessor> selection_feature_processor_;
  std::unique_ptr<nlp_core::EmbeddingNetwork> selection_network_;
  FeatureVectorFn selection_feature_fn_;

  // State of the sharing model.
  std::unique_ptr<FeatureProcessor> sharing_feature_processor_;
  std::unique_ptr<ModelParams> sharing_params_;
  std::unique_ptr<nlp_core::EmbeddingNetwork> sharing_network_;
  FeatureVectorFn sharing_feature_fn_;

  // Set of ints consulted when stripping punctuation (presumably codepoints
  // -- confirm against StripPunctuation's definition).
  std::set<int> punctuation_to_strip_;
};
// Parses the merged image given as a file descriptor, and reads
// the ModelOptions proto from the selection model.
//
// fd: file descriptor of the merged model image.
// model_options: output pointer; receives the ModelOptions proto that is read.
// Returns a bool status -- presumably true on success; confirm against the
// definition, which is not part of this header.
bool ReadSelectionModelOptions(int fd, ModelOptions* model_options);
} // namespace libtextclassifier
#endif // LIBTEXTCLASSIFIER_SMARTSELECT_TEXT_CLASSIFICATION_MODEL_H_