// blob: 756fcd5b1c0b37b7c4ab1a3f9e8cc4f5905f2380
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/grammar/grammar-annotator.h"
#include "annotator/feature-processor.h"
#include "annotator/grammar/utils.h"
#include "annotator/types.h"
#include "utils/base/logging.h"
#include "utils/grammar/callback-delegate.h"
#include "utils/grammar/match.h"
#include "utils/grammar/matcher.h"
#include "utils/grammar/rules-utils.h"
#include "utils/grammar/types.h"
#include "utils/normalization.h"
#include "utils/optional.h"
#include "utils/utf8/unicodetext.h"
namespace libtextclassifier3 {
namespace {
// Returns the unicode codepoint offsets in a utf8 encoded text.
std::vector<UnicodeText::const_iterator> UnicodeCodepointOffsets(
const UnicodeText& text) {
std::vector<UnicodeText::const_iterator> offsets;
for (auto it = text.begin(); it != text.end(); it++) {
offsets.push_back(it);
}
offsets.push_back(text.end());
return offsets;
}
} // namespace
// Callback delegate for the annotator grammar. It records the rule matches
// reported through MatchFound() that are enabled for the configured mode, and
// afterwards turns the recorded candidates into annotations, a text selection
// or a classification result.
class GrammarAnnotatorCallbackDelegate : public grammar::CallbackDelegate {
 public:
  // `unilib` is dereferenced immediately and must not be null.
  // `entity_data_builder` may be null, in which case results carry no
  // serialized entity data. Only rules whose enabled modes include `mode` are
  // recorded as candidates.
  explicit GrammarAnnotatorCallbackDelegate(
      const UniLib* unilib, const GrammarModel* model,
      const MutableFlatbufferBuilder* entity_data_builder, const ModeFlag mode)
      : unilib_(*unilib),
        model_(model),
        entity_data_builder_(entity_data_builder),
        mode_(mode) {}

  // Handles a grammar rule match in the annotator grammar. Rule matches are
  // recorded as candidates; every other callback type is forwarded to the
  // base class implementation.
  void MatchFound(const grammar::Match* match, grammar::CallbackId type,
                  int64 value, grammar::Matcher* matcher) override {
    switch (static_cast<GrammarAnnotator::Callback>(type)) {
      case GrammarAnnotator::Callback::kRuleMatch: {
        HandleRuleMatch(match, /*rule_id=*/value);
        return;
      }
      default:
        grammar::CallbackDelegate::MatchFound(match, type, value, matcher);
    }
  }

  // Deduplicates the recorded candidates and appends one annotated span per
  // candidate whose assertions hold. `text` is the codepoint offset table of
  // the analyzed text. Returns false as soon as one span cannot be
  // instantiated; spans appended before that point remain in `annotations`.
  bool GetAnnotations(const std::vector<UnicodeText::const_iterator>& text,
                      std::vector<AnnotatedSpan>* annotations) const {
    for (const grammar::Derivation& candidate :
         grammar::DeduplicateDerivations(candidates_)) {
      // Check that assertions are fulfilled.
      if (!grammar::VerifyAssertions(candidate.match)) {
        continue;
      }
      if (!AddAnnotatedSpanFromMatch(text, candidate, annotations)) {
        return false;
      }
    }
    return true;
  }

  // Provides a selection suggestion: picks the best valid candidate whose
  // selection span overlaps `selection` and writes it to `result`.
  // Returns false if there is no such candidate or it cannot be instantiated.
  bool GetTextSelection(const std::vector<UnicodeText::const_iterator>& text,
                        const CodepointSpan& selection,
                        AnnotatedSpan* result) const {
    // Deduplicate and verify matches.
    auto maybe_interpretation = GetBestValidInterpretation(
        grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
            selection, candidates_, /*only_exact_overlap=*/false)));
    if (!maybe_interpretation.has_value()) {
      return false;
    }
    const GrammarModel_::RuleClassificationResult* interpretation;
    const grammar::Match* match;
    std::tie(interpretation, match) = maybe_interpretation.value();
    return InstantiateAnnotatedSpanFromInterpretation(text, interpretation,
                                                      match, result);
  }

  // Provides a classification result from the grammar matches: picks the best
  // valid candidate whose selection span is exactly `selection`.
  // Returns false if there is no such candidate or it cannot be instantiated.
  bool GetClassification(const std::vector<UnicodeText::const_iterator>& text,
                         const CodepointSpan& selection,
                         ClassificationResult* classification) const {
    // Deduplicate and verify matches.
    auto maybe_interpretation = GetBestValidInterpretation(
        grammar::DeduplicateDerivations(GetOverlappingRuleMatches(
            selection, candidates_, /*only_exact_overlap=*/true)));
    if (!maybe_interpretation.has_value()) {
      return false;
    }
    // Instantiate result.
    const GrammarModel_::RuleClassificationResult* interpretation;
    const grammar::Match* match;
    std::tie(interpretation, match) = maybe_interpretation.value();
    return InstantiateClassificationInterpretation(text, interpretation, match,
                                                   classification);
  }

 private:
  // Handles annotation/selection/classification rule matches.
  // A match is recorded only if its rule id indexes a valid entry of the
  // model's rule classification results and that rule is enabled for the
  // current mode. The id is validated here because later stages look the rule
  // up by `candidate.rule_id` without re-checking the range.
  void HandleRuleMatch(const grammar::Match* match, const int64 rule_id) {
    const auto* rules = model_->rule_classification_result();
    if (rules == nullptr || rule_id < 0 || rule_id >= rules->size()) {
      TC3_LOG(ERROR) << "Invalid rule id.";
      return;
    }
    if ((rules->Get(rule_id)->enabled_modes() & mode_) != 0) {
      candidates_.push_back(grammar::Derivation{match, rule_id});
    }
  }

  // Collects the matches of type kMappingMatch that SelectAllOfType yields
  // for `match`, keyed by their id. If several matches share an id, the one
  // visited last wins.
  static std::unordered_map<uint16, const grammar::Match*>
  GatherCapturingMatches(const grammar::Match* match) {
    std::unordered_map<uint16, const grammar::Match*> capturing_matches;
    for (const grammar::MappingMatch* mapping_match :
         grammar::SelectAllOfType<grammar::MappingMatch>(
             match, grammar::Match::kMappingMatch)) {
      capturing_matches[mapping_match->id] = mapping_match;
    }
    return capturing_matches;
  }

  // Computes the selection boundaries from a grammar match.
  // Without capturing groups the full match span is used. Otherwise the
  // result is the union of the spans of all active capturing groups that are
  // flagged to extend the selection; if there is none, the returned span is
  // {kInvalidIndex, kInvalidIndex}.
  CodepointSpan MatchSelectionBoundaries(
      const grammar::Match* match,
      const GrammarModel_::RuleClassificationResult* classification) const {
    if (classification->capturing_group() == nullptr) {
      // Use full match as selection span.
      return match->codepoint_span;
    }

    // Set information from capturing matches.
    CodepointSpan span{kInvalidIndex, kInvalidIndex};

    // Gather active capturing matches.
    const std::unordered_map<uint16, const grammar::Match*> capturing_matches =
        GatherCapturingMatches(match);

    // Compute span boundaries.
    const int num_groups =
        static_cast<int>(classification->capturing_group()->size());
    for (int i = 0; i < num_groups; i++) {
      const auto it = capturing_matches.find(i);
      if (it == capturing_matches.end()) {
        // Capturing group is not active, skip.
        continue;
      }
      const CapturingGroup* group = classification->capturing_group()->Get(i);
      if (!group->extend_selection()) {
        continue;
      }
      const CodepointSpan& group_span = it->second->codepoint_span;
      if (span.first == kInvalidIndex) {
        span = group_span;
      } else {
        span.first = std::min(span.first, group_span.first);
        span.second = std::max(span.second, group_span.second);
      }
    }
    return span;
  }

  // Filters out results that do not overlap with a reference span.
  // With `only_exact_overlap` set, the candidate's selection span must be
  // equal to `selection`.
  std::vector<grammar::Derivation> GetOverlappingRuleMatches(
      const CodepointSpan& selection,
      const std::vector<grammar::Derivation>& candidates,
      const bool only_exact_overlap) const {
    std::vector<grammar::Derivation> result;
    for (const grammar::Derivation& candidate : candidates) {
      // Discard matches that do not match the selection.
      // Simple check on the full match span first.
      if (!SpansOverlap(selection, candidate.match->codepoint_span)) {
        continue;
      }

      // Compute exact selection boundaries (without assertions and
      // non-capturing parts).
      const CodepointSpan span = MatchSelectionBoundaries(
          candidate.match,
          model_->rule_classification_result()->Get(candidate.rule_id));
      if (!SpansOverlap(selection, span) ||
          (only_exact_overlap && span != selection)) {
        continue;
      }
      result.push_back(candidate);
    }
    return result;
  }

  // Returns the best valid interpretation of a set of candidate matches:
  // among the candidates whose assertions hold, the one with the highest
  // priority score (the first one wins on ties). The result is empty if no
  // candidate is valid.
  Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
                     const grammar::Match*>>
  GetBestValidInterpretation(
      const std::vector<grammar::Derivation>& candidates) const {
    const GrammarModel_::RuleClassificationResult* best_interpretation =
        nullptr;
    const grammar::Match* best_match = nullptr;
    for (const grammar::Derivation& candidate : candidates) {
      if (!grammar::VerifyAssertions(candidate.match)) {
        continue;
      }
      const GrammarModel_::RuleClassificationResult*
          rule_classification_result =
              model_->rule_classification_result()->Get(candidate.rule_id);
      if (best_interpretation == nullptr ||
          best_interpretation->priority_score() <
              rule_classification_result->priority_score()) {
        best_interpretation = rule_classification_result;
        best_match = candidate.match;
      }
    }

    // Stays empty if no valid interpretation was found.
    Optional<std::pair<const GrammarModel_::RuleClassificationResult*,
                       const grammar::Match*>>
        result;
    if (best_interpretation != nullptr) {
      result = {best_interpretation, best_match};
    }
    return result;
  }

  // Instantiates an annotated span from a rule match and appends it to the
  // result. Returns false on an out-of-range rule id or if the span cannot be
  // instantiated; in the latter case the partially filled span stays appended
  // to `result`.
  bool AddAnnotatedSpanFromMatch(
      const std::vector<UnicodeText::const_iterator>& text,
      const grammar::Derivation& candidate,
      std::vector<AnnotatedSpan>* result) const {
    if (candidate.rule_id < 0 ||
        candidate.rule_id >= model_->rule_classification_result()->size()) {
      TC3_LOG(INFO) << "Invalid rule id.";
      return false;
    }
    const GrammarModel_::RuleClassificationResult* interpretation =
        model_->rule_classification_result()->Get(candidate.rule_id);
    result->emplace_back();
    return InstantiateAnnotatedSpanFromInterpretation(
        text, interpretation, candidate.match, &result->back());
  }

  // Fills `result` with the selection span of `match` and a single
  // classification derived from `interpretation`. The span is set even when
  // the classification cannot be instantiated and false is returned.
  bool InstantiateAnnotatedSpanFromInterpretation(
      const std::vector<UnicodeText::const_iterator>& text,
      const GrammarModel_::RuleClassificationResult* interpretation,
      const grammar::Match* match, AnnotatedSpan* result) const {
    result->span = MatchSelectionBoundaries(match, interpretation);
    ClassificationResult classification;
    if (!InstantiateClassificationInterpretation(text, interpretation, match,
                                                 &classification)) {
      return false;
    }
    result->classification.push_back(classification);
    return true;
  }

  // Instantiates a classification result from a rule match.
  // Collection name and scores are copied from `interpretation`. If an entity
  // data builder is available, entity data is assembled from the rule's
  // static data, the static data of every active capturing group and the
  // (optionally normalized) text captured by groups with an entity field
  // path. Returns false if a captured text cannot be set on its entity field.
  bool InstantiateClassificationInterpretation(
      const std::vector<UnicodeText::const_iterator>& text,
      const GrammarModel_::RuleClassificationResult* interpretation,
      const grammar::Match* match, ClassificationResult* classification) const {
    classification->collection = interpretation->collection_name()->str();
    classification->score = interpretation->target_classification_score();
    classification->priority_score = interpretation->priority_score();

    // Assemble entity data.
    if (entity_data_builder_ == nullptr) {
      return true;
    }
    std::unique_ptr<MutableFlatbuffer> entity_data =
        entity_data_builder_->NewRoot();
    if (interpretation->serialized_entity_data() != nullptr) {
      entity_data->MergeFromSerializedFlatbuffer(
          StringPiece(interpretation->serialized_entity_data()->data(),
                      interpretation->serialized_entity_data()->size()));
    }
    if (interpretation->entity_data() != nullptr) {
      entity_data->MergeFrom(reinterpret_cast<const flatbuffers::Table*>(
          interpretation->entity_data()));
    }

    // Populate entity data from the capturing matches.
    if (interpretation->capturing_group() != nullptr) {
      // Gather active capturing matches.
      const std::unordered_map<uint16, const grammar::Match*>
          capturing_matches = GatherCapturingMatches(match);
      const int num_groups =
          static_cast<int>(interpretation->capturing_group()->size());
      for (int i = 0; i < num_groups; i++) {
        const auto it = capturing_matches.find(i);
        if (it == capturing_matches.end()) {
          // Capturing group is not active, skip.
          continue;
        }
        const CapturingGroup* group = interpretation->capturing_group()->Get(i);

        // Add static entity data of the group. This must merge the group's
        // own data: merging the interpretation's data here would apply the
        // wrong payload and dereference null when the rule has none.
        if (group->serialized_entity_data() != nullptr) {
          entity_data->MergeFromSerializedFlatbuffer(
              StringPiece(group->serialized_entity_data()->data(),
                          group->serialized_entity_data()->size()));
        }

        // Set entity field from captured text.
        if (group->entity_field_path() != nullptr) {
          const grammar::Match* capturing_match = it->second;
          // The captured text is the byte range between the iterators at the
          // span's first and second codepoint index.
          StringPiece group_text = StringPiece(
              text[capturing_match->codepoint_span.first].utf8_data(),
              text[capturing_match->codepoint_span.second].utf8_data() -
                  text[capturing_match->codepoint_span.first].utf8_data());
          UnicodeText normalized_group_text =
              UTF8ToUnicodeText(group_text, /*do_copy=*/false);
          if (group->normalization_options() != nullptr) {
            normalized_group_text = NormalizeText(
                unilib_, group->normalization_options(), normalized_group_text);
          }
          if (!entity_data->ParseAndSet(group->entity_field_path(),
                                        normalized_group_text.ToUTF8String())) {
            TC3_LOG(ERROR) << "Could not set entity data from capturing match.";
            return false;
          }
        }
      }
    }
    if (entity_data && entity_data->HasExplicitlySetFields()) {
      classification->serialized_entity_data = entity_data->Serialize();
    }
    return true;
  }

  const UniLib& unilib_;
  const GrammarModel* model_;
  const MutableFlatbufferBuilder* entity_data_builder_;
  const ModeFlag mode_;

  // All annotation/selection/classification rule match candidates.
  // Grammar rule matches are recorded, deduplicated and then instantiated.
  std::vector<grammar::Derivation> candidates_;
};
// Constructs a grammar annotator for `model`.
// `unilib` and `model` are dereferenced in the initializer list, so both must
// be non-null here. The lexer, the tokenizer and the parsed rules locales are
// all set up from the model (`model->rules()` / `model->tokenizer_options()`).
// `entity_data_builder` is only stored; the callback delegate in this file
// treats a null builder as "no entity data".
// NOTE(review): Annotate/SuggestSelection/ClassifyText test `model_` against
// nullptr, but a null model would already have been dereferenced here —
// confirm whether null models are meant to be supported.
GrammarAnnotator::GrammarAnnotator(
const UniLib* unilib, const GrammarModel* model,
const MutableFlatbufferBuilder* entity_data_builder)
: unilib_(*unilib),
model_(model),
lexer_(unilib, model->rules()),
tokenizer_(BuildTokenizer(unilib, model->tokenizer_options())),
entity_data_builder_(entity_data_builder),
rules_locales_(grammar::ParseRulesLocales(model->rules())) {}
bool GrammarAnnotator::Annotate(const std::vector<Locale>& locales,
const UnicodeText& text,
std::vector<AnnotatedSpan>* result) const {
if (model_ == nullptr || model_->rules() == nullptr) {
// Nothing to do.
return true;
}
// Select locale matching rules.
std::vector<const grammar::RulesSet_::Rules*> locale_rules =
SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
if (locale_rules.empty()) {
// Nothing to do.
return true;
}
// Run the grammar.
GrammarAnnotatorCallbackDelegate callback_handler(
&unilib_, model_, entity_data_builder_,
/*mode=*/ModeFlag_ANNOTATION);
grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
&callback_handler);
lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
&matcher);
// Populate results.
return callback_handler.GetAnnotations(UnicodeCodepointOffsets(text), result);
}
// Suggests a selection for `selection` in `text` based on the grammar rules
// enabled for selection mode.
// Returns true only if `result` has been populated with a suggestion. Returns
// false when there is nothing to do (no model, no rules, invalid selection,
// or no rule shard matching `locales`) and when no valid rule match overlaps
// the selection.
bool GrammarAnnotator::SuggestSelection(const std::vector<Locale>& locales,
                                        const UnicodeText& text,
                                        const CodepointSpan& selection,
                                        AnnotatedSpan* result) const {
  if (model_ == nullptr || model_->rules() == nullptr ||
      selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
    // Nothing to do.
    return false;
  }

  // Select locale matching rules.
  std::vector<const grammar::RulesSet_::Rules*> locale_rules =
      SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
  if (locale_rules.empty()) {
    // Nothing to do. `result` is left untouched, so this must report false
    // like the other no-suggestion paths; reporting true would make callers
    // consume a span that was never filled in.
    return false;
  }

  // Run the grammar.
  GrammarAnnotatorCallbackDelegate callback_handler(
      &unilib_, model_, entity_data_builder_,
      /*mode=*/ModeFlag_SELECTION);
  grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
                           &callback_handler);
  lexer_.Process(text, tokenizer_.Tokenize(text), /*annotations=*/nullptr,
                 &matcher);

  // Populate the result.
  return callback_handler.GetTextSelection(UnicodeCodepointOffsets(text),
                                           selection, result);
}
bool GrammarAnnotator::ClassifyText(
const std::vector<Locale>& locales, const UnicodeText& text,
const CodepointSpan& selection,
ClassificationResult* classification_result) const {
if (model_ == nullptr || model_->rules() == nullptr ||
selection == CodepointSpan{kInvalidIndex, kInvalidIndex}) {
// Nothing to do.
return false;
}
// Select locale matching rules.
std::vector<const grammar::RulesSet_::Rules*> locale_rules =
SelectLocaleMatchingShards(model_->rules(), rules_locales_, locales);
if (locale_rules.empty()) {
// Nothing to do.
return false;
}
// Run the grammar.
GrammarAnnotatorCallbackDelegate callback_handler(
&unilib_, model_, entity_data_builder_,
/*mode=*/ModeFlag_CLASSIFICATION);
grammar::Matcher matcher(&unilib_, model_->rules(), locale_rules,
&callback_handler);
const std::vector<Token> tokens = tokenizer_.Tokenize(text);
if (model_->context_left_num_tokens() == -1 &&
model_->context_right_num_tokens() == -1) {
// Use all tokens.
lexer_.Process(text, tokens, /*annotations=*/{}, &matcher);
} else {
TokenSpan context_span = CodepointSpanToTokenSpan(
tokens, selection, /*snap_boundaries_to_containing_tokens=*/true);
std::vector<Token>::const_iterator begin = tokens.begin();
std::vector<Token>::const_iterator end = tokens.begin();
if (model_->context_left_num_tokens() != -1) {
std::advance(begin, std::max(0, context_span.first -
model_->context_left_num_tokens()));
}
if (model_->context_right_num_tokens() == -1) {
end = tokens.end();
} else {
std::advance(end, std::min(static_cast<int>(tokens.size()),
context_span.second +
model_->context_right_num_tokens()));
}
lexer_.Process(text, begin, end,
/*annotations=*/nullptr, &matcher);
}
// Populate result.
return callback_handler.GetClassification(UnicodeCodepointOffsets(text),
selection, classification_result);
}
} // namespace libtextclassifier3