blob: 6d51c199da30017c1e7b6712133224c65b9984b5 [file] [log] [blame]
/*
* Copyright (C) 2018 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "annotator/datetime/grammar-parser.h"
#include <set>
#include <unordered_set>
#include "annotator/datetime/datetime-grounder.h"
#include "annotator/types.h"
#include "utils/grammar/analyzer.h"
#include "utils/grammar/evaluated-derivation.h"
#include "utils/grammar/parsing/derivation.h"
using ::libtextclassifier3::grammar::EvaluatedDerivation;
using ::libtextclassifier3::grammar::datetime::UngroundedDatetime;
namespace libtextclassifier3 {
GrammarDatetimeParser::GrammarDatetimeParser(
const grammar::Analyzer& analyzer,
const DatetimeGrounder& datetime_grounder,
const float target_classification_score, const float priority_score)
: analyzer_(analyzer),
datetime_grounder_(datetime_grounder),
target_classification_score_(target_classification_score),
priority_score_(priority_score) {}
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const std::string& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
bool anchor_start_end) const {
return Parse(UTF8ToUnicodeText(input, /*do_copy=*/false),
reference_time_ms_utc, reference_timezone, locale_list, mode,
annotation_usecase, anchor_start_end);
}
StatusOr<std::vector<DatetimeParseResultSpan>> GrammarDatetimeParser::Parse(
const UnicodeText& input, const int64 reference_time_ms_utc,
const std::string& reference_timezone, const LocaleList& locale_list,
ModeFlag mode, AnnotationUsecase annotation_usecase,
bool anchor_start_end) const {
std::vector<DatetimeParseResultSpan> results;
UnsafeArena arena(/*block_size=*/16 << 10);
std::vector<Locale> locales = locale_list.GetLocales();
// If the locale list is empty then datetime regex expression will still
// execute but in grammar based parser the rules are associated with local
// and engine will not run if the locale list is empty. In an unlikely
// scenario when locale is not mentioned fallback to en-*.
if (locales.empty()) {
locales.emplace_back(Locale::FromBCP47("en"));
}
TC3_ASSIGN_OR_RETURN(
const std::vector<EvaluatedDerivation> evaluated_derivations,
analyzer_.Parse(input, locales, &arena,
/*deduplicate_derivations=*/false));
std::vector<EvaluatedDerivation> valid_evaluated_derivations;
for (const EvaluatedDerivation& evaluated_derivation :
evaluated_derivations) {
if (evaluated_derivation.value) {
if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
const UngroundedDatetime* ungrounded_datetime =
evaluated_derivation.value->Table<UngroundedDatetime>();
if (datetime_grounder_.IsValidUngroundedDatetime(ungrounded_datetime)) {
valid_evaluated_derivations.emplace_back(evaluated_derivation);
}
}
}
}
valid_evaluated_derivations =
grammar::DeduplicateDerivations(valid_evaluated_derivations);
for (const EvaluatedDerivation& evaluated_derivation :
valid_evaluated_derivations) {
if (evaluated_derivation.value) {
if (evaluated_derivation.value->Has<flatbuffers::Table>()) {
const UngroundedDatetime* ungrounded_datetime =
evaluated_derivation.value->Table<UngroundedDatetime>();
if ((ungrounded_datetime->annotation_usecases() &
(1 << annotation_usecase)) == 0) {
continue;
}
const StatusOr<std::vector<DatetimeParseResult>>&
datetime_parse_results = datetime_grounder_.Ground(
reference_time_ms_utc, reference_timezone,
locale_list.GetReferenceLocale(), ungrounded_datetime);
TC3_ASSIGN_OR_RETURN(
const std::vector<DatetimeParseResult>& parse_datetime,
datetime_parse_results);
DatetimeParseResultSpan datetime_parse_result_span;
datetime_parse_result_span.target_classification_score =
target_classification_score_;
datetime_parse_result_span.priority_score = priority_score_;
datetime_parse_result_span.data.reserve(parse_datetime.size());
datetime_parse_result_span.data.insert(
datetime_parse_result_span.data.end(), parse_datetime.begin(),
parse_datetime.end());
datetime_parse_result_span.span =
evaluated_derivation.parse_tree->codepoint_span;
results.emplace_back(datetime_parse_result_span);
}
}
}
return results;
}
} // namespace libtextclassifier3