| /* |
| * Copyright (C) 2018 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "annotator/types.h" |
| |
| #include <vector> |
| |
| #include "utils/optional.h" |
| |
| namespace libtextclassifier3 { |
| |
| const CodepointSpan CodepointSpan::kInvalid = |
| CodepointSpan(kInvalidIndex, kInvalidIndex); |
| |
| const TokenSpan TokenSpan::kInvalid = TokenSpan(kInvalidIndex, kInvalidIndex); |
| |
| logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, |
| const CodepointSpan& span) { |
| return stream << "CodepointSpan(" << span.first << ", " << span.second << ")"; |
| } |
| |
| logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, |
| const TokenSpan& span) { |
| return stream << "TokenSpan(" << span.first << ", " << span.second << ")"; |
| } |
| |
| logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, |
| const Token& token) { |
| if (!token.is_padding) { |
| return stream << "Token(\"" << token.value << "\", " << token.start << ", " |
| << token.end << ")"; |
| } else { |
| return stream << "Token()"; |
| } |
| } |
| |
| bool DatetimeComponent::ShouldRoundToGranularity() const { |
| // Don't round to the granularity for relative expressions that specify the |
| // distance. So that, e.g. "in 2 hours" when it's 8:35:03 will result in |
| // 10:35:03. |
| if (relative_qualifier == RelativeQualifier::UNSPECIFIED) { |
| return false; |
| } |
| if (relative_qualifier == RelativeQualifier::NEXT || |
| relative_qualifier == RelativeQualifier::TOMORROW || |
| relative_qualifier == RelativeQualifier::YESTERDAY || |
| relative_qualifier == RelativeQualifier::LAST || |
| relative_qualifier == RelativeQualifier::THIS || |
| relative_qualifier == RelativeQualifier::NOW) { |
| return true; |
| } |
| return false; |
| } |
| |
| namespace { |
| std::string FormatMillis(int64 time_ms_utc) { |
| long time_seconds = time_ms_utc / 1000; // NOLINT |
| char buffer[512]; |
| strftime(buffer, sizeof(buffer), "%a %Y-%m-%d %H:%M:%S %Z", |
| localtime(&time_seconds)); |
| return std::string(buffer); |
| } |
| } // namespace |
| |
| std::string ComponentTypeToString( |
| const DatetimeComponent::ComponentType& component_type) { |
| switch (component_type) { |
| case DatetimeComponent::ComponentType::UNSPECIFIED: |
| return "UNSPECIFIED"; |
| case DatetimeComponent::ComponentType::YEAR: |
| return "YEAR"; |
| case DatetimeComponent::ComponentType::MONTH: |
| return "MONTH"; |
| case DatetimeComponent::ComponentType::WEEK: |
| return "WEEK"; |
| case DatetimeComponent::ComponentType::DAY_OF_WEEK: |
| return "DAY_OF_WEEK"; |
| case DatetimeComponent::ComponentType::DAY_OF_MONTH: |
| return "DAY_OF_MONTH"; |
| case DatetimeComponent::ComponentType::HOUR: |
| return "HOUR"; |
| case DatetimeComponent::ComponentType::MINUTE: |
| return "MINUTE"; |
| case DatetimeComponent::ComponentType::SECOND: |
| return "SECOND"; |
| case DatetimeComponent::ComponentType::MERIDIEM: |
| return "MERIDIEM"; |
| case DatetimeComponent::ComponentType::ZONE_OFFSET: |
| return "ZONE_OFFSET"; |
| case DatetimeComponent::ComponentType::DST_OFFSET: |
| return "DST_OFFSET"; |
| default: |
| return ""; |
| } |
| } |
| |
| std::string RelativeQualifierToString( |
| const DatetimeComponent::RelativeQualifier& relative_qualifier) { |
| switch (relative_qualifier) { |
| case DatetimeComponent::RelativeQualifier::UNSPECIFIED: |
| return "UNSPECIFIED"; |
| case DatetimeComponent::RelativeQualifier::NEXT: |
| return "NEXT"; |
| case DatetimeComponent::RelativeQualifier::THIS: |
| return "THIS"; |
| case DatetimeComponent::RelativeQualifier::LAST: |
| return "LAST"; |
| case DatetimeComponent::RelativeQualifier::NOW: |
| return "NOW"; |
| case DatetimeComponent::RelativeQualifier::TOMORROW: |
| return "TOMORROW"; |
| case DatetimeComponent::RelativeQualifier::YESTERDAY: |
| return "YESTERDAY"; |
| case DatetimeComponent::RelativeQualifier::PAST: |
| return "PAST"; |
| case DatetimeComponent::RelativeQualifier::FUTURE: |
| return "FUTURE"; |
| default: |
| return ""; |
| } |
| } |
| |
| logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, |
| const DatetimeParseResultSpan& value) { |
| stream << "DatetimeParseResultSpan({" << value.span.first << ", " |
| << value.span.second << "}, " |
| << "/*target_classification_score=*/ " |
| << value.target_classification_score << "/*priority_score=*/" |
| << value.priority_score << " {"; |
| for (const DatetimeParseResult& data : value.data) { |
| stream << "{/*time_ms_utc=*/ " << data.time_ms_utc << " /* " |
| << FormatMillis(data.time_ms_utc) << " */, /*granularity=*/ " |
| << data.granularity << ", /*datetime_components=*/ "; |
| for (const DatetimeComponent& datetime_comp : data.datetime_components) { |
| stream << "{/*component_type=*/ " |
| << ComponentTypeToString(datetime_comp.component_type) |
| << " /*relative_qualifier=*/ " |
| << RelativeQualifierToString(datetime_comp.relative_qualifier) |
| << " /*value=*/ " << datetime_comp.value << " /*relative_count=*/ " |
| << datetime_comp.relative_count << "}, "; |
| } |
| stream << "}, "; |
| } |
| stream << "})"; |
| return stream; |
| } |
| |
| bool ClassificationResult::operator==(const ClassificationResult& other) const { |
| return ClassificationResultsEqualIgnoringScoresAndSerializedEntityData( |
| *this, other) && |
| fabs(score - other.score) < 0.001 && |
| fabs(priority_score - other.priority_score) < 0.001 && |
| serialized_entity_data == other.serialized_entity_data; |
| } |
| |
| bool ClassificationResultsEqualIgnoringScoresAndSerializedEntityData( |
| const ClassificationResult& a, const ClassificationResult& b) { |
| return a.collection == b.collection && |
| a.datetime_parse_result == b.datetime_parse_result && |
| a.serialized_knowledge_result == b.serialized_knowledge_result && |
| a.contact_pointer == b.contact_pointer && |
| a.contact_name == b.contact_name && |
| a.contact_given_name == b.contact_given_name && |
| a.contact_family_name == b.contact_family_name && |
| a.contact_nickname == b.contact_nickname && |
| a.contact_email_address == b.contact_email_address && |
| a.contact_phone_number == b.contact_phone_number && |
| a.contact_id == b.contact_id && |
| a.app_package_name == b.app_package_name && |
| a.numeric_value == b.numeric_value && |
| fabs(a.numeric_double_value - b.numeric_double_value) < 0.001 && |
| a.duration_ms == b.duration_ms; |
| } |
| |
| logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, |
| const ClassificationResult& result) { |
| return stream << "ClassificationResult(" << result.collection |
| << ", /*score=*/ " << result.score << ", /*priority_score=*/ " |
| << result.priority_score << ")"; |
| } |
| |
| logging::LoggingStringStream& operator<<( |
| logging::LoggingStringStream& stream, |
| const std::vector<ClassificationResult>& results) { |
| stream = stream << "{\n"; |
| for (const ClassificationResult& result : results) { |
| stream = stream << " " << result << "\n"; |
| } |
| stream = stream << "}"; |
| return stream; |
| } |
| |
| logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, |
| const AnnotatedSpan& span) { |
| std::string best_class; |
| float best_score = -1; |
| if (!span.classification.empty()) { |
| best_class = span.classification[0].collection; |
| best_score = span.classification[0].score; |
| } |
| return stream << "Span(" << span.span.first << ", " << span.span.second |
| << ", " << best_class << ", " << best_score << ")"; |
| } |
| |
| logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream, |
| const DatetimeParsedData& data) { |
| std::vector<DatetimeComponent> date_time_components; |
| data.GetDatetimeComponents(&date_time_components); |
| stream = stream << "DatetimeParsedData { \n"; |
| for (const DatetimeComponent& c : date_time_components) { |
| stream = stream << " DatetimeComponent { \n"; |
| stream = stream << " Component Type:" << static_cast<int>(c.component_type) |
| << "\n"; |
| stream = stream << " Value:" << c.value << "\n"; |
| stream = stream << " Relative Qualifier:" |
| << static_cast<int>(c.relative_qualifier) << "\n"; |
| stream = stream << " Relative Count:" << c.relative_count << "\n"; |
| stream = stream << " } \n"; |
| } |
| stream = stream << "}"; |
| return stream; |
| } |
| |
| void DatetimeParsedData::SetAbsoluteValue( |
| const DatetimeComponent::ComponentType& field_type, int value) { |
| GetOrCreateDatetimeComponent(field_type).value = value; |
| } |
| |
| void DatetimeParsedData::SetRelativeValue( |
| const DatetimeComponent::ComponentType& field_type, |
| const DatetimeComponent::RelativeQualifier& relative_value) { |
| GetOrCreateDatetimeComponent(field_type).relative_qualifier = relative_value; |
| } |
| |
| void DatetimeParsedData::SetRelativeCount( |
| const DatetimeComponent::ComponentType& field_type, int relative_count) { |
| GetOrCreateDatetimeComponent(field_type).relative_count = relative_count; |
| } |
| |
| void DatetimeParsedData::AddDatetimeComponents( |
| const std::vector<DatetimeComponent>& datetime_components) { |
| for (const DatetimeComponent& datetime_component : datetime_components) { |
| date_time_components_.insert( |
| {datetime_component.component_type, datetime_component}); |
| } |
| } |
| |
| bool DatetimeParsedData::HasFieldType( |
| const DatetimeComponent::ComponentType& field_type) const { |
| if (date_time_components_.find(field_type) == date_time_components_.end()) { |
| return false; |
| } |
| return true; |
| } |
| |
| bool DatetimeParsedData::GetFieldValue( |
| const DatetimeComponent::ComponentType& field_type, |
| int* field_value) const { |
| if (HasFieldType(field_type)) { |
| *field_value = date_time_components_.at(field_type).value; |
| return true; |
| } |
| return false; |
| } |
| |
| bool DatetimeParsedData::GetRelativeValue( |
| const DatetimeComponent::ComponentType& field_type, |
| DatetimeComponent::RelativeQualifier* relative_value) const { |
| if (HasFieldType(field_type)) { |
| *relative_value = date_time_components_.at(field_type).relative_qualifier; |
| return true; |
| } |
| return false; |
| } |
| |
| bool DatetimeParsedData::HasRelativeValue( |
| const DatetimeComponent::ComponentType& field_type) const { |
| if (HasFieldType(field_type)) { |
| return date_time_components_.at(field_type).relative_qualifier != |
| DatetimeComponent::RelativeQualifier::UNSPECIFIED; |
| } |
| return false; |
| } |
| |
| bool DatetimeParsedData::HasAbsoluteValue( |
| const DatetimeComponent::ComponentType& field_type) const { |
| return HasFieldType(field_type) && !HasRelativeValue(field_type); |
| } |
| |
| bool DatetimeParsedData::IsEmpty() const { |
| return date_time_components_.empty(); |
| } |
| |
| void DatetimeParsedData::GetRelativeDatetimeComponents( |
| std::vector<DatetimeComponent>* date_time_components) const { |
| for (auto it = date_time_components_.begin(); |
| it != date_time_components_.end(); it++) { |
| if (it->second.relative_qualifier != |
| DatetimeComponent::RelativeQualifier::UNSPECIFIED) { |
| date_time_components->push_back(it->second); |
| } |
| } |
| } |
| |
| void DatetimeParsedData::GetDatetimeComponents( |
| std::vector<DatetimeComponent>* date_time_components) const { |
| for (auto it = date_time_components_.begin(); |
| it != date_time_components_.end(); it++) { |
| date_time_components->push_back(it->second); |
| } |
| } |
| |
| DatetimeComponent& DatetimeParsedData::GetOrCreateDatetimeComponent( |
| const DatetimeComponent::ComponentType& component_type) { |
| auto result = |
| date_time_components_ |
| .insert( |
| {component_type, |
| DatetimeComponent( |
| component_type, |
| DatetimeComponent::RelativeQualifier::UNSPECIFIED, 0, 0)}) |
| .first; |
| return result->second; |
| } |
| |
| namespace { |
| DatetimeGranularity GetFinestGranularityFromComponentTypes( |
| const std::vector<DatetimeComponent::ComponentType>& |
| datetime_component_types) { |
| DatetimeGranularity granularity = DatetimeGranularity::GRANULARITY_UNKNOWN; |
| for (const auto& component_type : datetime_component_types) { |
| switch (component_type) { |
| case DatetimeComponent::ComponentType::YEAR: |
| if (granularity < DatetimeGranularity::GRANULARITY_YEAR) { |
| granularity = DatetimeGranularity::GRANULARITY_YEAR; |
| } |
| break; |
| |
| case DatetimeComponent::ComponentType::MONTH: |
| if (granularity < DatetimeGranularity::GRANULARITY_MONTH) { |
| granularity = DatetimeGranularity::GRANULARITY_MONTH; |
| } |
| break; |
| |
| case DatetimeComponent::ComponentType::WEEK: |
| if (granularity < DatetimeGranularity::GRANULARITY_WEEK) { |
| granularity = DatetimeGranularity::GRANULARITY_WEEK; |
| } |
| break; |
| |
| case DatetimeComponent::ComponentType::DAY_OF_WEEK: |
| case DatetimeComponent::ComponentType::DAY_OF_MONTH: |
| if (granularity < DatetimeGranularity::GRANULARITY_DAY) { |
| granularity = DatetimeGranularity::GRANULARITY_DAY; |
| } |
| break; |
| |
| case DatetimeComponent::ComponentType::HOUR: |
| if (granularity < DatetimeGranularity::GRANULARITY_HOUR) { |
| granularity = DatetimeGranularity::GRANULARITY_HOUR; |
| } |
| break; |
| |
| case DatetimeComponent::ComponentType::MINUTE: |
| if (granularity < DatetimeGranularity::GRANULARITY_MINUTE) { |
| granularity = DatetimeGranularity::GRANULARITY_MINUTE; |
| } |
| break; |
| |
| case DatetimeComponent::ComponentType::SECOND: |
| if (granularity < DatetimeGranularity::GRANULARITY_SECOND) { |
| granularity = DatetimeGranularity::GRANULARITY_SECOND; |
| } |
| break; |
| |
| case DatetimeComponent::ComponentType::MERIDIEM: |
| case DatetimeComponent::ComponentType::ZONE_OFFSET: |
| case DatetimeComponent::ComponentType::DST_OFFSET: |
| default: |
| break; |
| } |
| } |
| return granularity; |
| } |
| } // namespace |
| |
| DatetimeGranularity DatetimeParsedData::GetFinestGranularity() const { |
| std::vector<DatetimeComponent::ComponentType> component_types; |
| std::transform(date_time_components_.begin(), date_time_components_.end(), |
| std::back_inserter(component_types), |
| [](const std::map<DatetimeComponent::ComponentType, |
| DatetimeComponent>::value_type& pair) { |
| return pair.first; |
| }); |
| return GetFinestGranularityFromComponentTypes(component_types); |
| } |
| |
| Optional<DatetimeComponent> GetDatetimeComponent( |
| const std::vector<DatetimeComponent>& datetime_components, |
| const DatetimeComponent::ComponentType& component_type) { |
| for (auto datetime_component : datetime_components) { |
| if (datetime_component.component_type == component_type) { |
| return Optional<DatetimeComponent>(datetime_component); |
| } |
| } |
| return Optional<DatetimeComponent>(); |
| } |
| |
| // Returns the granularity of the DatetimeComponents. |
| DatetimeGranularity GetFinestGranularity( |
| const std::vector<DatetimeComponent>& datetime_component) { |
| std::vector<DatetimeComponent::ComponentType> component_types; |
| std::transform(datetime_component.begin(), datetime_component.end(), |
| std::back_inserter(component_types), |
| [](const DatetimeComponent& component) { |
| return component.component_type; |
| }); |
| return GetFinestGranularityFromComponentTypes(component_types); |
| } |
| |
| } // namespace libtextclassifier3 |