| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #ifndef LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ |
| #define LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ |
| |
| #include <algorithm> |
| #include <cmath> |
| #include <functional> |
| #include <map> |
| #include <set> |
| #include <string> |
| #include <utility> |
| #include <vector> |
| |
| #include "utils/base/integral_types.h" |
| #include "utils/base/logging.h" |
| #include "utils/variant.h" |
| |
| namespace libtextclassifier3 { |
| |
// Sentinel value for an index that is invalid or has not been set.
constexpr int kInvalidIndex{-1};

// 0-based position in a sequence of tokens.
using TokenIndex = int;

// 0-based position in a sequence of codepoints.
using CodepointIndex = int;

// Half-open range of codepoint indices: `first` is the index of the span's
// first codepoint and `second` is the index of the codepoint one past its end.
// TODO(b/71982294): Make it a struct.
using CodepointSpan = std::pair<CodepointIndex, CodepointIndex>;
| |
| inline bool SpansOverlap(const CodepointSpan& a, const CodepointSpan& b) { |
| return a.first < b.second && b.first < a.second; |
| } |
| |
| inline bool ValidNonEmptySpan(const CodepointSpan& span) { |
| return span.first < span.second && span.first >= 0 && span.second >= 0; |
| } |
| |
// Returns true if the span of candidates[considered_candidate] overlaps the
// span of an already chosen candidate.
//
// `chosen_indices_set` holds indices into `candidates`, ordered by a
// caller-supplied comparator. Only the two chosen neighbours of
// `considered_candidate` in that order are inspected: the first chosen index
// not ordered before it (via lower_bound), and the one immediately preceding
// that. NOTE(review): this presumably relies on the comparator ordering
// indices by span position and on the chosen spans being mutually
// non-overlapping — confirm at call sites.
//
// T must have a `span` member accepted by SpansOverlap. Nothing is inserted
// into the set; the caller decides what to do with the answer.
template <typename T>
bool DoesCandidateConflict(
    const int considered_candidate, const std::vector<T>& candidates,
    const std::set<int, std::function<bool(int, int)>>& chosen_indices_set) {
  // Nothing chosen yet, so nothing to conflict with.
  if (chosen_indices_set.empty()) {
    return false;
  }

  const auto& considered_span = candidates[considered_candidate].span;
  auto neighbor = chosen_indices_set.lower_bound(considered_candidate);

  // Neighbour at or after the candidate in the set's ordering.
  if (neighbor != chosen_indices_set.end() &&
      SpansOverlap(considered_span, candidates[*neighbor].span)) {
    return true;
  }

  // Neighbour immediately before it, if there is one.
  if (neighbor == chosen_indices_set.begin()) {
    return false;
  }
  --neighbor;
  return SpansOverlap(considered_span, candidates[*neighbor].span);
}
| |
// Marks a half-open span in a sequence of tokens. The first element is the
// index of the first token in the span, and the second element is the index of
// the token one past the end of the span.
// TODO(b/71982294): Make it a struct.
using TokenSpan = std::pair<TokenIndex, TokenIndex>;
| |
| // Returns the size of the token span. Assumes that the span is valid. |
| inline int TokenSpanSize(const TokenSpan& token_span) { |
| return token_span.second - token_span.first; |
| } |
| |
| // Returns a token span consisting of one token. |
| inline TokenSpan SingleTokenSpan(int token_index) { |
| return {token_index, token_index + 1}; |
| } |
| |
| // Returns an intersection of two token spans. Assumes that both spans are valid |
| // and overlapping. |
| inline TokenSpan IntersectTokenSpans(const TokenSpan& token_span1, |
| const TokenSpan& token_span2) { |
| return {std::max(token_span1.first, token_span2.first), |
| std::min(token_span1.second, token_span2.second)}; |
| } |
| |
| // Returns and expanded token span by adding a certain number of tokens on its |
| // left and on its right. |
| inline TokenSpan ExpandTokenSpan(const TokenSpan& token_span, |
| int num_tokens_left, int num_tokens_right) { |
| return {token_span.first - num_tokens_left, |
| token_span.second + num_tokens_right}; |
| } |
| |
| // Token holds a token, its position in the original string and whether it was |
| // part of the input span. |
| struct Token { |
| std::string value; |
| CodepointIndex start; |
| CodepointIndex end; |
| |
| // Whether the token is a padding token. |
| bool is_padding; |
| |
| // Default constructor constructs the padding-token. |
| Token() |
| : value(""), start(kInvalidIndex), end(kInvalidIndex), is_padding(true) {} |
| |
| Token(const std::string& arg_value, CodepointIndex arg_start, |
| CodepointIndex arg_end) |
| : value(arg_value), start(arg_start), end(arg_end), is_padding(false) {} |
| |
| bool operator==(const Token& other) const { |
| return value == other.value && start == other.start && end == other.end && |
| is_padding == other.is_padding; |
| } |
| |
| bool IsContainedInSpan(CodepointSpan span) const { |
| return start >= span.first && end <= span.second; |
| } |
| }; |
| |
| // Pretty-printing function for Token. |
| inline logging::LoggingStringStream& operator<<( |
| logging::LoggingStringStream& stream, const Token& token) { |
| if (!token.is_padding) { |
| return stream << "Token(\"" << token.value << "\", " << token.start << ", " |
| << token.end << ")"; |
| } else { |
| return stream << "Token()"; |
| } |
| } |
| |
// Precision of a parsed datetime. Known granularities run from the coarsest
// (year, 0) to the finest (second, 6).
enum DatetimeGranularity {
  // Used as a proxy for the enclosing structure being uninitialized.
  GRANULARITY_UNKNOWN = -1,
  GRANULARITY_YEAR = 0,
  GRANULARITY_MONTH = 1,
  GRANULARITY_WEEK = 2,
  GRANULARITY_DAY = 3,
  GRANULARITY_HOUR = 4,
  GRANULARITY_MINUTE = 5,
  GRANULARITY_SECOND = 6,
};
| |
| struct DatetimeParseResult { |
| // The absolute time in milliseconds since the epoch in UTC. This is derived |
| // from the reference time and the fields specified in the text - so it may |
| // be imperfect where the time was ambiguous. (e.g. "at 7:30" may be am or pm) |
| int64 time_ms_utc; |
| |
| // The precision of the estimate then in to calculating the milliseconds |
| DatetimeGranularity granularity; |
| |
| DatetimeParseResult() : time_ms_utc(0), granularity(GRANULARITY_UNKNOWN) {} |
| |
| DatetimeParseResult(int64 arg_time_ms_utc, |
| DatetimeGranularity arg_granularity) |
| : time_ms_utc(arg_time_ms_utc), granularity(arg_granularity) {} |
| |
| bool IsSet() const { return granularity != GRANULARITY_UNKNOWN; } |
| |
| bool operator==(const DatetimeParseResult& other) const { |
| return granularity == other.granularity && time_ms_utc == other.time_ms_utc; |
| } |
| }; |
| |
// Absolute tolerance for comparing float scores for equality. constexpr (like
// kInvalidIndex) so that it can also be used in constant expressions.
constexpr float kFloatCompareEpsilon = 1e-5;
| |
| struct DatetimeParseResultSpan { |
| CodepointSpan span; |
| DatetimeParseResult data; |
| float target_classification_score; |
| float priority_score; |
| |
| bool operator==(const DatetimeParseResultSpan& other) const { |
| return span == other.span && data.granularity == other.data.granularity && |
| data.time_ms_utc == other.data.time_ms_utc && |
| std::abs(target_classification_score - |
| other.target_classification_score) < kFloatCompareEpsilon && |
| std::abs(priority_score - other.priority_score) < |
| kFloatCompareEpsilon; |
| } |
| }; |
| |
| // Pretty-printing function for DatetimeParseResultSpan. |
| inline logging::LoggingStringStream& operator<<( |
| logging::LoggingStringStream& stream, |
| const DatetimeParseResultSpan& value) { |
| return stream << "DatetimeParseResultSpan({" << value.span.first << ", " |
| << value.span.second << "}, {/*time_ms_utc=*/ " |
| << value.data.time_ms_utc << ", /*granularity=*/ " |
| << value.data.granularity << "})"; |
| } |
| |
| struct ClassificationResult { |
| std::string collection; |
| float score; |
| DatetimeParseResult datetime_parse_result; |
| std::string serialized_knowledge_result; |
| |
| // Internal score used for conflict resolution. |
| float priority_score; |
| |
| // Extra information. |
| std::map<std::string, Variant> extra; |
| |
| explicit ClassificationResult() : score(-1.0f), priority_score(-1.0) {} |
| |
| ClassificationResult(const std::string& arg_collection, float arg_score) |
| : collection(arg_collection), |
| score(arg_score), |
| priority_score(arg_score) {} |
| |
| ClassificationResult(const std::string& arg_collection, float arg_score, |
| float arg_priority_score) |
| : collection(arg_collection), |
| score(arg_score), |
| priority_score(arg_priority_score) {} |
| }; |
| |
| // Pretty-printing function for ClassificationResult. |
| inline logging::LoggingStringStream& operator<<( |
| logging::LoggingStringStream& stream, const ClassificationResult& result) { |
| return stream << "ClassificationResult(" << result.collection << ", " |
| << result.score << ")"; |
| } |
| |
| // Pretty-printing function for std::vector<ClassificationResult>. |
| inline logging::LoggingStringStream& operator<<( |
| logging::LoggingStringStream& stream, |
| const std::vector<ClassificationResult>& results) { |
| stream = stream << "{\n"; |
| for (const ClassificationResult& result : results) { |
| stream = stream << " " << result << "\n"; |
| } |
| stream = stream << "}"; |
| return stream; |
| } |
| |
| // Represents a result of Annotate call. |
| struct AnnotatedSpan { |
| // Unicode codepoint indices in the input string. |
| CodepointSpan span = {kInvalidIndex, kInvalidIndex}; |
| |
| // Classification result for the span. |
| std::vector<ClassificationResult> classification; |
| }; |
| |
| // Pretty-printing function for AnnotatedSpan. |
| inline logging::LoggingStringStream& operator<<( |
| logging::LoggingStringStream& stream, const AnnotatedSpan& span) { |
| std::string best_class; |
| float best_score = -1; |
| if (!span.classification.empty()) { |
| best_class = span.classification[0].collection; |
| best_score = span.classification[0].score; |
| } |
| return stream << "Span(" << span.span.first << ", " << span.span.second |
| << ", " << best_class << ", " << best_score << ")"; |
| } |
| |
// Non-owning, read-only view of a contiguous range of a std::vector<T>; a
// StringPiece analogue for vectors. Only iterators are stored, so the view must
// not outlive the viewed vector, and it is invalidated by anything that
// invalidates that vector's iterators.
template <class T>
class VectorSpan {
 public:
  // Constructs an empty view.
  VectorSpan() : begin_(), end_() {}
  // Views the whole of `v`. Deliberately implicit so that a vector can be
  // passed wherever a VectorSpan is expected.
  VectorSpan(const std::vector<T>& v)  // NOLINT(runtime/explicit)
      : begin_(v.begin()), end_(v.end()) {}
  // Views the range [begin, end) of a vector.
  VectorSpan(typename std::vector<T>::const_iterator begin,
             typename std::vector<T>::const_iterator end)
      : begin_(begin), end_(end) {}

  // Returns the i-th element of the view; no bounds checking.
  const T& operator[](typename std::vector<T>::size_type i) const {
    return *(begin_ + i);
  }

  // Number of elements in the view.
  int size() const { return static_cast<int>(end_ - begin_); }
  typename std::vector<T>::const_iterator begin() const { return begin_; }
  typename std::vector<T>::const_iterator end() const { return end_; }

  // Pointer to the first viewed element, or nullptr when the view is empty.
  // The return type follows T, so this works for every element type, not just
  // float. The emptiness check avoids dereferencing an end or
  // value-initialized iterator, which is undefined behaviour.
  const T* data() const { return begin_ == end_ ? nullptr : &*begin_; }

 private:
  typename std::vector<T>::const_iterator begin_;
  typename std::vector<T>::const_iterator end_;
};
| |
// Date/time fields extracted from text. `field_set_mask` records which of the
// members below have been set, as a combination of `Fields` bits.
// NOTE(review): no member has a default initializer, so a default-initialized
// instance holds indeterminate values. Value-initialize it (e.g.
// `DateParseData data{};`) and consult field_set_mask before reading a field.
struct DateParseData {
  // Values for the `relation` member.
  enum Relation {
    NEXT = 1,
    NEXT_OR_SAME = 2,
    LAST = 3,
    NOW = 4,
    TOMORROW = 5,
    YESTERDAY = 6,
    PAST = 7,
    FUTURE = 8
  };

  // Values for the `relation_type` member: a day of the week
  // (SUNDAY = 1 ... SATURDAY = 7) or a calendar unit (DAY ... YEAR).
  enum RelationType {
    SUNDAY = 1,
    MONDAY = 2,
    TUESDAY = 3,
    WEDNESDAY = 4,
    THURSDAY = 5,
    FRIDAY = 6,
    SATURDAY = 7,
    DAY = 8,
    WEEK = 9,
    MONTH = 10,
    YEAR = 11
  };

  // Single-bit flags for field_set_mask; each flag is named after the member
  // it marks as set.
  enum Fields {
    YEAR_FIELD = 1 << 0,
    MONTH_FIELD = 1 << 1,
    DAY_FIELD = 1 << 2,
    HOUR_FIELD = 1 << 3,
    MINUTE_FIELD = 1 << 4,
    SECOND_FIELD = 1 << 5,
    AMPM_FIELD = 1 << 6,
    ZONE_OFFSET_FIELD = 1 << 7,
    DST_OFFSET_FIELD = 1 << 8,
    RELATION_FIELD = 1 << 9,
    RELATION_TYPE_FIELD = 1 << 10,
    RELATION_DISTANCE_FIELD = 1 << 11
  };

  // Values for the `ampm` member.
  enum AMPM { AM = 0, PM = 1 };

  // Units of time. NOTE(review): no member of this struct uses this enum;
  // presumably it is used by callers.
  enum TimeUnit {
    DAYS = 1,
    WEEKS = 2,
    MONTHS = 3,
    HOURS = 4,
    MINUTES = 5,
    SECONDS = 6,
    YEARS = 7
  };

  // Bit mask of the fields that have been set on the struct (see `Fields`).
  int field_set_mask;

  // Fields describing absolute date fields.
  // Year of the date seen in the text match.
  int year;
  // Month of the year starting with January = 1.
  int month;
  // Day of the month starting with 1.
  int day_of_month;
  // Hour of the day with a range of 0-23. Values less than 12 need the AMPM
  // field below, or heuristics, to definitively determine the time.
  int hour;
  // Minute of the hour with a range of 0-59.
  int minute;
  // Second of the minute with a range of 0-59.
  int second;
  // 0 == AM, 1 == PM (see the AMPM enum).
  int ampm;
  // Number of hours offset from UTC this date time is in.
  int zone_offset;
  // Number of hours offset for DST.
  int dst_offset;

  // The relation to "now" that was used to find the date time.
  Relation relation;
  // The unit of measure of the change to the date time.
  RelationType relation_type;
  // The number of units of change that were made.
  int relation_distance;
};
| |
| } // namespace libtextclassifier3 |
| |
| #endif // LIBTEXTCLASSIFIER_ANNOTATOR_TYPES_H_ |