// blob: 6c4a902843b0331cbc25505d4522fdf109fdff39 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
#define LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_
#include <string>
#include <vector>
#include "annotator/model_generated.h"
#include "annotator/types.h"
#include "absl/strings/string_view.h"
namespace libtextclassifier3 {
// Converts a SAFT label such as "/saft/person" to its collection name
// ("person" in that example).
// NOTE(review): the handling of labels that do not follow the "/saft/<name>"
// pattern is not visible from this header - confirm in the implementation.
std::string SaftLabelToCollection(absl::string_view saft_label);
// A half-open range [begin, end) of wordpiece indices.
struct WordpieceSpan {
  // Beginning index is inclusive, end index is exclusive.
  // The default constructor yields the empty span [0, 0).
  WordpieceSpan() : begin(0), end(0) {}
  WordpieceSpan(int begin, int end) : begin(begin), end(end) {}

  int begin;
  int end;

  // Two spans compare equal when both of their endpoints match.
  bool operator==(const WordpieceSpan &other) const {
    return this->begin == other.begin && this->end == other.end;
  }

  // Returns end - begin, i.e. the number of wordpieces covered by the span.
  // No validation is performed, so the result is negative when end < begin.
  // Declared const so that it can be called on const spans and through
  // `const WordpieceSpan &` parameters.
  int length() const { return end - begin; }
};
namespace internal {
// Finds the wordpiece window around the given span_of_interest. If the number
// of wordpieces in this window is smaller than max_num_wordpieces_in_window,
// the window is expanded around the span of interest.
// NOTE(review): word_starts presumably holds, per token, the index of that
// token's first wordpiece, and num_wordpieces the total number of wordpieces -
// confirm against the implementation.
WordpieceSpan FindWordpiecesWindowAroundSpan(
const CodepointSpan &span_of_interest, const std::vector<Token> &tokens,
const std::vector<int32_t> &word_starts, int num_wordpieces,
int max_num_wordpieces_in_window);
// Expands the given wordpiece span (wordpiece_span_to_expand) to be as large
// as possible while making sure it includes only full tokens.
// NOTE(review): presumably the expanded window holds at most
// max_num_wordpieces_in_window wordpieces and stays within
// [0, num_wordpieces) - confirm against the implementation.
WordpieceSpan ExpandWindowAndAlign(int max_num_wordpieces_in_window,
int num_wordpieces,
WordpieceSpan wordpiece_span_to_expand);
// Returns the index of the last token which ends before wordpiece_end.
// NOTE(review): word_starts is presumably the per-token index of the first
// wordpiece and num_wordpieces the total wordpiece count; the result when no
// token qualifies is not visible from this header - confirm both.
int FindLastFullTokenIndex(const std::vector<int32_t> &word_starts,
int num_wordpieces, int wordpiece_end);
// Returns the index of the token which includes first_wordpiece_index.
// NOTE(review): word_starts is presumably the per-token index of the first
// wordpiece; behavior for an out-of-range first_wordpiece_index is not visible
// from this header - confirm against the implementation.
int FindFirstFullTokenIndex(const std::vector<int32_t> &word_starts,
int first_wordpiece_index);
// Given wordpiece_span and max_num_wordpieces, finds:
//   1. The first token which includes wordpiece_span.begin.
//   2. The length of the token sequence which starts from this token and:
//      a. Its last token's last wordpiece index ends before
//         wordpiece_span.end.
//      b. Its overall number of wordpieces is at most max_num_wordpieces.
// Returns the updated wordpiece_span: the begin and end wordpieces of this
// token sequence.
// NOTE(review): results (1) and (2) are presumably written through the
// first_token_index and num_tokens pointers - confirm against the
// implementation.
WordpieceSpan FindFullTokensSpanInWindow(
const std::vector<int32_t> &word_starts,
const WordpieceSpan &wordpiece_span, int max_num_wordpieces,
int num_wordpieces, int *first_token_index, int *num_tokens);
} // namespace internal
// Converts a sequence of IOB tags to AnnotatedSpans. Illegal sequences are
// ignored.
// Setting label_filter can also help ignore certain label tags like "NAM" or
// "NOM".
// The inside tag can be ignored when setting relaxed_inside_label_matching,
// e.g. B-NAM-location, I-NAM-other, E-NAM-location would be considered a valid
// sequence.
// The label category matching can be ignored when setting
// relaxed_label_category_matching. The matching will then only operate at the
// entity level, e.g. B-NAM-location, E-NOM-location would be considered a
// valid sequence.
// NOTE(review): the spans are presumably appended to *results, and the bool
// return value presumably signals success - confirm against the
// implementation.
bool ConvertTagsToAnnotatedSpans(const VectorSpan<Token> &tokens,
const std::vector<std::string> &tags,
const std::vector<std::string> &label_filter,
bool relaxed_inside_label_matching,
bool relaxed_label_category_matching,
float priority_score,
std::vector<AnnotatedSpan> *results);
// Like the string-tag overload of ConvertTagsToAnnotatedSpans, but instead of
// getting the tags as strings the input is PodNerModel_::LabelT along with the
// collections vector, which holds the collection names and priorities. E.g. if
// a tag was "B-NAM-location" and the priority_score was 1.0, it would be
// Label(BoiseType_BEGIN, MentionType_NAM, 1) and
// collections={{"xxx", 1., 1.}, {"location", 1., 1.}, {"yyy", 1., 1.}, ...}.
// NOTE(review): mention_filter presumably plays the role of label_filter for
// mention types, and the bool return value presumably signals success -
// confirm against the implementation.
bool ConvertTagsToAnnotatedSpans(
const VectorSpan<Token> &tokens,
const std::vector<PodNerModel_::LabelT> &labels,
const std::vector<PodNerModel_::CollectionT> &collections,
const std::vector<PodNerModel_::Label_::MentionType> &mention_filter,
bool relaxed_inside_label_matching, bool relaxed_mention_type_matching,
std::vector<AnnotatedSpan> *results);
// Merges two overlapping sequences of labels; the result is placed into the
// left sequence (labels_left). In the overlapping part, the labels are taken
// from the left sequence for the first half and from the right sequence for
// the second half.
// NOTE(review): index_first_right_tag_in_left is presumably the position in
// labels_left at which labels_right starts, and the bool return value
// presumably signals whether the merge was possible - confirm against the
// implementation.
bool MergeLabelsIntoLeftSequence(
const std::vector<PodNerModel_::LabelT> &labels_right,
int index_first_right_tag_in_left,
std::vector<PodNerModel_::LabelT> *labels_left);
// This class is used to slide over {wordpiece_indices, token_starts, tokens} in
// windows of at most max_num_wordpieces while assuring that each window
// contains only full tokens.
class WindowGenerator {
public:
// NOTE(review): the members below are raw pointers to const vectors, so the
// generator presumably keeps the addresses of wordpiece_indices, token_starts
// and tokens rather than copying them; if so, the caller must keep those
// vectors alive while the generator is in use - confirm in the implementation.
WindowGenerator(const std::vector<int32_t> &wordpiece_indices,
const std::vector<int32_t> &token_starts,
const std::vector<Token> &tokens, int max_num_wordpieces,
int sliding_window_overlap,
const CodepointSpan &span_of_interest);
// Produces the next window through the three output spans.
// NOTE(review): presumably returns false when no further window is available
// - confirm in the implementation.
bool Next(VectorSpan<int32_t> *cur_wordpiece_indices,
VectorSpan<int32_t> *cur_token_starts,
VectorSpan<Token> *cur_tokens);
// Returns true once the end of the previously recorded window span has
// reached (or passed) the end of the entire wordpiece span.
bool Done() const {
return previous_wordpiece_span_.end >= entire_wordpiece_span_.end;
}
private:
// Pointers to the input sequences; never reassigned or freed in this header.
const std::vector<int32_t> *wordpiece_indices_;
const std::vector<int32_t> *token_starts_;
const std::vector<Token> *tokens_;
// NOTE(review): presumably the per-window wordpiece budget derived from
// max_num_wordpieces - confirm how "effective" is computed.
int max_num_effective_wordpieces_;
// NOTE(review): presumably derived from sliding_window_overlap - confirm.
int sliding_window_num_wordpieces_overlap_;
// Wordpiece spans tracked by the generator; Done() compares the end of
// previous_wordpiece_span_ against the end of entire_wordpiece_span_.
WordpieceSpan entire_wordpiece_span_;
WordpieceSpan next_wordpiece_span_;
WordpieceSpan previous_wordpiece_span_;
};
} // namespace libtextclassifier3
#endif // LIBTEXTCLASSIFIER_ANNOTATOR_POD_NER_UTILS_H_