// blob: b1dde17c4e337eec0c91a1feface728d6d2cfe1e [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/types.h"
#include <vector>
#include "utils/optional.h"
namespace libtextclassifier3 {
// Out-of-line definitions of the sentinel spans; both endpoints are set to
// kInvalidIndex.
const CodepointSpan CodepointSpan::kInvalid(kInvalidIndex, kInvalidIndex);
const TokenSpan TokenSpan::kInvalid(kInvalidIndex, kInvalidIndex);
// Appends `span` to `stream` in the form "CodepointSpan(<first>, <second>)".
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
                                         const CodepointSpan& span) {
  stream << "CodepointSpan(" << span.first;
  stream << ", " << span.second << ")";
  return stream;
}
// Appends `span` to `stream` in the form "TokenSpan(<first>, <second>)".
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
                                         const TokenSpan& span) {
  stream << "TokenSpan(" << span.first;
  stream << ", " << span.second << ")";
  return stream;
}
// Appends `token` to `stream`. Padding tokens are written as "Token()";
// all other tokens are written with their quoted value followed by their
// start and end fields.
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
                                         const Token& token) {
  if (token.is_padding) {
    return stream << "Token()";
  }
  return stream << "Token(\"" << token.value << "\", " << token.start << ", "
                << token.end << ")";
}
// Returns true only for the relative qualifiers NEXT, TOMORROW, YESTERDAY,
// LAST, THIS and NOW.
//
// Relative expressions that specify a distance are deliberately not rounded
// to the granularity, so that e.g. "in 2 hours" when it's 8:35:03 results in
// 10:35:03. UNSPECIFIED and every qualifier not listed above yield false.
bool DatetimeComponent::ShouldRoundToGranularity() const {
  switch (relative_qualifier) {
    case RelativeQualifier::NEXT:
    case RelativeQualifier::TOMORROW:
    case RelativeQualifier::YESTERDAY:
    case RelativeQualifier::LAST:
    case RelativeQualifier::THIS:
    case RelativeQualifier::NOW:
      return true;
    default:
      return false;
  }
}
namespace {
// Formats a timestamp given in milliseconds as a human-readable string in the
// local time zone, using the pattern "%a %Y-%m-%d %H:%M:%S %Z". The
// millisecond part is dropped (integer division by 1000).
//
// Returns an empty string if the timestamp cannot be converted to a calendar
// time or if the formatted text does not fit into the internal buffer.
std::string FormatMillis(int64 time_ms_utc) {
  // localtime() takes a pointer to time_t; using that exact type avoids
  // passing a pointer to a differently sized integer on platforms where
  // time_t is not `long`.
  const time_t time_seconds = static_cast<time_t>(time_ms_utc / 1000);

  // localtime() reports failure (e.g. an unrepresentable year) with a null
  // pointer, which must not be handed to strftime().
  const struct tm* const local_time = localtime(&time_seconds);
  if (local_time == nullptr) {
    return std::string();
  }

  char buffer[512];
  // strftime() returns 0 and leaves the buffer contents unspecified when the
  // result does not fit, so only the reported number of characters is used.
  const size_t written = strftime(buffer, sizeof(buffer),
                                  "%a %Y-%m-%d %H:%M:%S %Z", local_time);
  return std::string(buffer, written);
}
}  // namespace
// Returns the enumerator name of `component_type` as a string (e.g. "YEAR"),
// or an empty string for a value that is not one of the handled enumerators.
std::string ComponentTypeToString(
    const DatetimeComponent::ComponentType& component_type) {
  using ComponentType = DatetimeComponent::ComponentType;
  const char* name = "";
  switch (component_type) {
    case ComponentType::UNSPECIFIED:
      name = "UNSPECIFIED";
      break;
    case ComponentType::YEAR:
      name = "YEAR";
      break;
    case ComponentType::MONTH:
      name = "MONTH";
      break;
    case ComponentType::WEEK:
      name = "WEEK";
      break;
    case ComponentType::DAY_OF_WEEK:
      name = "DAY_OF_WEEK";
      break;
    case ComponentType::DAY_OF_MONTH:
      name = "DAY_OF_MONTH";
      break;
    case ComponentType::HOUR:
      name = "HOUR";
      break;
    case ComponentType::MINUTE:
      name = "MINUTE";
      break;
    case ComponentType::SECOND:
      name = "SECOND";
      break;
    case ComponentType::MERIDIEM:
      name = "MERIDIEM";
      break;
    case ComponentType::ZONE_OFFSET:
      name = "ZONE_OFFSET";
      break;
    case ComponentType::DST_OFFSET:
      name = "DST_OFFSET";
      break;
    default:
      break;
  }
  return name;
}
// Returns the enumerator name of `relative_qualifier` as a string (e.g.
// "TOMORROW"), or an empty string for a value that is not one of the handled
// enumerators.
std::string RelativeQualifierToString(
    const DatetimeComponent::RelativeQualifier& relative_qualifier) {
  using RelativeQualifier = DatetimeComponent::RelativeQualifier;
  const char* name = "";
  switch (relative_qualifier) {
    case RelativeQualifier::UNSPECIFIED:
      name = "UNSPECIFIED";
      break;
    case RelativeQualifier::NEXT:
      name = "NEXT";
      break;
    case RelativeQualifier::THIS:
      name = "THIS";
      break;
    case RelativeQualifier::LAST:
      name = "LAST";
      break;
    case RelativeQualifier::NOW:
      name = "NOW";
      break;
    case RelativeQualifier::TOMORROW:
      name = "TOMORROW";
      break;
    case RelativeQualifier::YESTERDAY:
      name = "YESTERDAY";
      break;
    case RelativeQualifier::PAST:
      name = "PAST";
      break;
    case RelativeQualifier::FUTURE:
      name = "FUTURE";
      break;
    default:
      break;
  }
  return name;
}
// Appends a debug representation of `value` to `stream`: the span bounds and
// the two scores, followed by one braced entry per parse result. Each entry
// lists the raw timestamp, the same timestamp formatted by FormatMillis(),
// the granularity, and every datetime component with its type and relative
// qualifier spelled out by name.
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
                                         const DatetimeParseResultSpan& value) {
  // Header: span bounds and scores.
  stream << "DatetimeParseResultSpan({" << value.span.first << ", "
         << value.span.second << "}, ";
  stream << "/*target_classification_score=*/ "
         << value.target_classification_score;
  stream << "/*priority_score=*/" << value.priority_score << " {";

  for (const DatetimeParseResult& parse_result : value.data) {
    const std::string readable_time = FormatMillis(parse_result.time_ms_utc);
    stream << "{/*time_ms_utc=*/ " << parse_result.time_ms_utc << " /* "
           << readable_time << " */, /*granularity=*/ "
           << parse_result.granularity << ", /*datetime_components=*/ ";

    for (const DatetimeComponent& component :
         parse_result.datetime_components) {
      const std::string type_name =
          ComponentTypeToString(component.component_type);
      const std::string qualifier_name =
          RelativeQualifierToString(component.relative_qualifier);
      stream << "{/*component_type=*/ " << type_name;
      stream << " /*relative_qualifier=*/ " << qualifier_name;
      stream << " /*value=*/ " << component.value;
      stream << " /*relative_count=*/ " << component.relative_count << "}, ";
    }
    stream << "}, ";
  }

  stream << "})";
  return stream;
}
// Two results are equal when all fields checked by
// ClassificationResultsEqualIgnoringScoresAndSerializedEntityData() match,
// both scores differ by less than 0.001, and the serialized entity data is
// identical.
bool ClassificationResult::operator==(const ClassificationResult& other) const {
  constexpr double kScoreTolerance = 0.001;

  if (!ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(*this,
                                                                       other)) {
    return false;
  }
  // Written as !(x < tolerance) so that a NaN difference compares unequal.
  if (!(fabs(score - other.score) < kScoreTolerance)) {
    return false;
  }
  if (!(fabs(priority_score - other.priority_score) < kScoreTolerance)) {
    return false;
  }
  return serialized_entity_data == other.serialized_entity_data;
}
// Compares two classification results field by field while ignoring `score`,
// `priority_score` and `serialized_entity_data`. `numeric_double_value` is
// compared with an absolute tolerance of 0.001; every other field must be
// exactly equal.
bool ClassificationResultsEqualIgnoringScoresAndSerializedEntityData(
    const ClassificationResult& a, const ClassificationResult& b) {
  // Collection plus datetime / knowledge payloads.
  if (!(a.collection == b.collection &&
        a.datetime_parse_result == b.datetime_parse_result &&
        a.serialized_knowledge_result == b.serialized_knowledge_result)) {
    return false;
  }

  // Contact fields.
  if (!(a.contact_pointer == b.contact_pointer &&
        a.contact_name == b.contact_name &&
        a.contact_given_name == b.contact_given_name &&
        a.contact_family_name == b.contact_family_name &&
        a.contact_nickname == b.contact_nickname &&
        a.contact_email_address == b.contact_email_address &&
        a.contact_phone_number == b.contact_phone_number &&
        a.contact_id == b.contact_id)) {
    return false;
  }

  // App and numeric fields.
  if (!(a.app_package_name == b.app_package_name &&
        a.numeric_value == b.numeric_value &&
        fabs(a.numeric_double_value - b.numeric_double_value) < 0.001)) {
    return false;
  }

  return a.duration_ms == b.duration_ms;
}
// Appends `result` to `stream` as its collection followed by its score and
// priority score.
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
                                         const ClassificationResult& result) {
  stream << "ClassificationResult(" << result.collection;
  stream << ", /*score=*/ " << result.score;
  stream << ", /*priority_score=*/ " << result.priority_score << ")";
  return stream;
}
// Appends every element of `results` to `stream`, one per line, enclosed in
// braces.
//
// The stream is written to in place; assigning the result of `<<` back to
// `stream` would only self-assign the stream object through its own reference.
logging::LoggingStringStream& operator<<(
    logging::LoggingStringStream& stream,
    const std::vector<ClassificationResult>& results) {
  stream << "{\n";
  for (const ClassificationResult& result : results) {
    stream << " " << result << "\n";
  }
  stream << "}";
  return stream;
}
// Appends `span` to `stream` as its span bounds followed by the collection
// and score of its first classification. When the span has no classification
// an empty collection and a score of -1 are written instead.
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
                                         const AnnotatedSpan& span) {
  const bool has_classification = !span.classification.empty();
  const std::string top_collection =
      has_classification ? span.classification[0].collection : std::string();
  const float top_score =
      has_classification ? span.classification[0].score : -1.0f;

  stream << "Span(" << span.span.first << ", " << span.span.second;
  stream << ", " << top_collection << ", " << top_score << ")";
  return stream;
}
// Appends a multi-line debug representation of `data` to `stream`, listing
// every datetime component with its type and relative qualifier as integers
// together with its value and relative count.
//
// The stream is written to in place; assigning the result of `<<` back to
// `stream` would only self-assign the stream object through its own reference.
logging::LoggingStringStream& operator<<(logging::LoggingStringStream& stream,
                                         const DatetimeParsedData& data) {
  std::vector<DatetimeComponent> date_time_components;
  data.GetDatetimeComponents(&date_time_components);
  stream << "DatetimeParsedData { \n";
  for (const DatetimeComponent& c : date_time_components) {
    stream << " DatetimeComponent { \n";
    stream << " Component Type:" << static_cast<int>(c.component_type) << "\n";
    stream << " Value:" << c.value << "\n";
    stream << " Relative Qualifier:" << static_cast<int>(c.relative_qualifier)
           << "\n";
    stream << " Relative Count:" << c.relative_count << "\n";
    stream << " } \n";
  }
  stream << "}";
  return stream;
}
// Stores `value` in the component for `field_type`, creating that component
// first if it does not exist yet.
void DatetimeParsedData::SetAbsoluteValue(
    const DatetimeComponent::ComponentType& field_type, int value) {
  DatetimeComponent& component = GetOrCreateDatetimeComponent(field_type);
  component.value = value;
}
// Stores `relative_value` as the relative qualifier of the component for
// `field_type`, creating that component first if it does not exist yet.
void DatetimeParsedData::SetRelativeValue(
    const DatetimeComponent::ComponentType& field_type,
    const DatetimeComponent::RelativeQualifier& relative_value) {
  DatetimeComponent& component = GetOrCreateDatetimeComponent(field_type);
  component.relative_qualifier = relative_value;
}
// Stores `relative_count` in the component for `field_type`, creating that
// component first if it does not exist yet.
void DatetimeParsedData::SetRelativeCount(
    const DatetimeComponent::ComponentType& field_type, int relative_count) {
  DatetimeComponent& component = GetOrCreateDatetimeComponent(field_type);
  component.relative_count = relative_count;
}
// Adds each of `datetime_components` to this object, keyed by its component
// type. A component whose type is already stored is not inserted; the
// existing entry is kept.
void DatetimeParsedData::AddDatetimeComponents(
    const std::vector<DatetimeComponent>& datetime_components) {
  for (const DatetimeComponent& component : datetime_components) {
    date_time_components_.emplace(component.component_type, component);
  }
}
// Returns true if a component of type `field_type` is stored.
bool DatetimeParsedData::HasFieldType(
    const DatetimeComponent::ComponentType& field_type) const {
  return date_time_components_.find(field_type) != date_time_components_.end();
}
// Writes the value of the component for `field_type` to `*field_value` and
// returns true. Returns false and leaves `*field_value` untouched when no
// such component is stored.
bool DatetimeParsedData::GetFieldValue(
    const DatetimeComponent::ComponentType& field_type,
    int* field_value) const {
  const auto entry = date_time_components_.find(field_type);
  if (entry == date_time_components_.end()) {
    return false;
  }
  *field_value = entry->second.value;
  return true;
}
// Writes the relative qualifier of the component for `field_type` to
// `*relative_value` and returns true. Returns false and leaves
// `*relative_value` untouched when no such component is stored.
bool DatetimeParsedData::GetRelativeValue(
    const DatetimeComponent::ComponentType& field_type,
    DatetimeComponent::RelativeQualifier* relative_value) const {
  const auto entry = date_time_components_.find(field_type);
  if (entry == date_time_components_.end()) {
    return false;
  }
  *relative_value = entry->second.relative_qualifier;
  return true;
}
// Returns true if a component for `field_type` is stored and its relative
// qualifier is anything other than UNSPECIFIED.
bool DatetimeParsedData::HasRelativeValue(
    const DatetimeComponent::ComponentType& field_type) const {
  const auto entry = date_time_components_.find(field_type);
  return entry != date_time_components_.end() &&
         entry->second.relative_qualifier !=
             DatetimeComponent::RelativeQualifier::UNSPECIFIED;
}
// Returns true if a component for `field_type` is stored and it has no
// relative value (see HasRelativeValue()).
bool DatetimeParsedData::HasAbsoluteValue(
    const DatetimeComponent::ComponentType& field_type) const {
  if (!HasFieldType(field_type)) {
    return false;
  }
  return !HasRelativeValue(field_type);
}
// Returns true if no datetime component is stored.
bool DatetimeParsedData::IsEmpty() const {
  const bool has_no_components = date_time_components_.empty();
  return has_no_components;
}
void DatetimeParsedData::GetRelativeDatetimeComponents(
std::vector<DatetimeComponent>* date_time_components) const {
for (auto it = date_time_components_.begin();
it != date_time_components_.end(); it++) {
if (it->second.relative_qualifier !=
DatetimeComponent::RelativeQualifier::UNSPECIFIED) {
date_time_components->push_back(it->second);
}
}
}
// Appends a copy of every stored component to `*date_time_components`.
// Existing elements of the output vector are kept.
void DatetimeParsedData::GetDatetimeComponents(
    std::vector<DatetimeComponent>* date_time_components) const {
  for (const auto& entry : date_time_components_) {
    date_time_components->push_back(entry.second);
  }
}
// Returns a mutable reference to the stored component for `component_type`.
// If none exists yet, one is inserted first with an UNSPECIFIED relative
// qualifier and zeros for the two remaining constructor arguments.
DatetimeComponent& DatetimeParsedData::GetOrCreateDatetimeComponent(
    const DatetimeComponent::ComponentType& component_type) {
  const DatetimeComponent blank_component(
      component_type, DatetimeComponent::RelativeQualifier::UNSPECIFIED, 0, 0);
  // emplace() leaves an already existing entry untouched and, in both cases,
  // yields an iterator to the entry stored under `component_type`.
  const auto insertion =
      date_time_components_.emplace(component_type, blank_component);
  return insertion.first->second;
}
namespace {
// Maps a single component type to the granularity it stands for. Component
// types that do not contribute a granularity (MERIDIEM, ZONE_OFFSET,
// DST_OFFSET and any unhandled value) map to GRANULARITY_UNKNOWN.
DatetimeGranularity GranularityOfComponentType(
    const DatetimeComponent::ComponentType& component_type) {
  using ComponentType = DatetimeComponent::ComponentType;
  switch (component_type) {
    case ComponentType::YEAR:
      return DatetimeGranularity::GRANULARITY_YEAR;
    case ComponentType::MONTH:
      return DatetimeGranularity::GRANULARITY_MONTH;
    case ComponentType::WEEK:
      return DatetimeGranularity::GRANULARITY_WEEK;
    case ComponentType::DAY_OF_WEEK:
    case ComponentType::DAY_OF_MONTH:
      return DatetimeGranularity::GRANULARITY_DAY;
    case ComponentType::HOUR:
      return DatetimeGranularity::GRANULARITY_HOUR;
    case ComponentType::MINUTE:
      return DatetimeGranularity::GRANULARITY_MINUTE;
    case ComponentType::SECOND:
      return DatetimeGranularity::GRANULARITY_SECOND;
    default:
      return DatetimeGranularity::GRANULARITY_UNKNOWN;
  }
}

// Returns the greatest granularity (by `<` on DatetimeGranularity) implied by
// any of `datetime_component_types`, or GRANULARITY_UNKNOWN if none of them
// implies one.
DatetimeGranularity GetFinestGranularityFromComponentTypes(
    const std::vector<DatetimeComponent::ComponentType>&
        datetime_component_types) {
  DatetimeGranularity finest = DatetimeGranularity::GRANULARITY_UNKNOWN;
  for (const DatetimeComponent::ComponentType& type :
       datetime_component_types) {
    const DatetimeGranularity candidate = GranularityOfComponentType(type);
    // `finest` only ever grows from GRANULARITY_UNKNOWN, so an UNKNOWN
    // candidate can never replace it.
    if (finest < candidate) {
      finest = candidate;
    }
  }
  return finest;
}
}  // namespace
// Returns the finest granularity implied by the types of the stored
// components, as computed by GetFinestGranularityFromComponentTypes().
DatetimeGranularity DatetimeParsedData::GetFinestGranularity() const {
  std::vector<DatetimeComponent::ComponentType> component_types;
  component_types.reserve(date_time_components_.size());
  for (const auto& entry : date_time_components_) {
    component_types.push_back(entry.first);
  }
  return GetFinestGranularityFromComponentTypes(component_types);
}
// Returns the first element of `datetime_components` whose type equals
// `component_type`, or an empty Optional if there is none.
Optional<DatetimeComponent> GetDatetimeComponent(
    const std::vector<DatetimeComponent>& datetime_components,
    const DatetimeComponent::ComponentType& component_type) {
  // Iterate by const reference: only the matching element needs to be copied
  // (into the returned Optional), not every element that is inspected.
  for (const auto& datetime_component : datetime_components) {
    if (datetime_component.component_type == component_type) {
      return Optional<DatetimeComponent>(datetime_component);
    }
  }
  return Optional<DatetimeComponent>();
}
// Returns the finest granularity implied by the component types of the given
// DatetimeComponents, as computed by GetFinestGranularityFromComponentTypes().
DatetimeGranularity GetFinestGranularity(
    const std::vector<DatetimeComponent>& datetime_component) {
  std::vector<DatetimeComponent::ComponentType> component_types;
  component_types.reserve(datetime_component.size());
  for (const DatetimeComponent& component : datetime_component) {
    component_types.push_back(component.component_type);
  }
  return GetFinestGranularityFromComponentTypes(component_types);
}
} // namespace libtextclassifier3