blob: 65c4f9318be09a65edb4e4319ccd4932472fe051 [file] [log] [blame]
/*
* Copyright (C) 2017 The Android Open Source Project
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "smartselect/model-params.h"
#include "common/memory_image/memory-image-reader.h"
namespace libtextclassifier {
using nlp_core::EmbeddingNetworkProto;
using nlp_core::MemoryImageReader;
// Builds a ModelParams instance from the serialized EmbeddingNetworkProto
// memory image at [start, start + num_bytes).
//
// Proto extensions read from the image:
//   * ModelOptions            - optional; default-constructed when absent.
//   * FeatureProcessorOptions - required; nullptr is returned when absent.
//   * SelectionModelOptions   - optional; symmetry defaults when absent.
//   * SharingModelOptions     - optional; URL/email hint defaults when absent.
//
// Params:
//   start, num_bytes: the memory image. It is also handed to ModelParams (and
//     to EmbeddingParams when the model does not use shared embeddings), so it
//     presumably must outlive the returned object -- TODO confirm.
//   external_embedding_params: embeddings to use when ModelOptions requests
//     shared embeddings; ignored otherwise.
//
// Returns a newly allocated ModelParams that the caller takes ownership of,
// or nullptr when the image has no FeatureProcessorOptions, or when shared
// embeddings are requested but external_embedding_params is null.
ModelParams* ModelParamsBuilder(
    const void* start, uint64 num_bytes,
    std::shared_ptr<EmbeddingParams> external_embedding_params) {
  MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes);
  // Bind the parsed proto once instead of re-fetching it for every lookup.
  const EmbeddingNetworkProto& proto = reader.trimmed_proto();

  ModelOptions model_options;
  if (proto.HasExtension(model_options_in_embedding_network_proto)) {
    model_options =
        proto.GetExtension(model_options_in_embedding_network_proto);
  }

  // FeatureProcessorOptions are required; report failure to the caller when
  // they are missing.
  if (!proto.HasExtension(
          feature_processor_options_in_embedding_network_proto)) {
    return nullptr;
  }
  FeatureProcessorOptions feature_processor_options = proto.GetExtension(
      feature_processor_options_in_embedding_network_proto);

  // If no tokenization codepoint config is present, tokenize on newline and
  // space.
  // TODO(zilka): Remove the default config.
  if (feature_processor_options.tokenization_codepoint_config_size() == 0) {
    // Appends a whitespace-separator range covering exactly one codepoint
    // (end is set one past start; it appears to be exclusive).
    auto add_whitespace_separator = [&feature_processor_options](
                                        int codepoint) {
      TokenizationCodepointRange* config =
          feature_processor_options.add_tokenization_codepoint_config();
      config->set_start(codepoint);
      config->set_end(codepoint + 1);
      config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR);
    };
    add_whitespace_separator(10);  // New line character.
    add_whitespace_separator(32);  // Space character.
  }

  SelectionModelOptions selection_options;
  if (proto.HasExtension(selection_model_options_in_embedding_network_proto)) {
    selection_options = proto.GetExtension(
        selection_model_options_in_embedding_network_proto);
    // For backward compatibility with the current models: fall back to the
    // deprecated punctuation list when no ignored span boundary codepoints
    // are configured.
    if (feature_processor_options.ignored_span_boundary_codepoints_size() ==
        0) {
      *feature_processor_options.mutable_ignored_span_boundary_codepoints() =
          selection_options.deprecated_punctuation_to_strip();
    }
  } else {
    // Default values when SelectionModelOptions is not present.
    selection_options.set_enforce_symmetry(true);
    selection_options.set_symmetry_context_size(
        feature_processor_options.context_size() * 2);
  }

  SharingModelOptions sharing_options;
  if (proto.HasExtension(sharing_model_options_in_embedding_network_proto)) {
    sharing_options = proto.GetExtension(
        sharing_model_options_in_embedding_network_proto);
  } else {
    // Default values when SharingModelOptions is not present.
    sharing_options.set_always_accept_url_hint(true);
    sharing_options.set_always_accept_email_hint(true);
  }

  std::shared_ptr<EmbeddingParams> embedding_params;
  if (model_options.use_shared_embeddings()) {
    // Embeddings are supplied by the caller. Without them the resulting
    // ModelParams would hold a null embedding pointer, so fail the same way
    // as the missing-FeatureProcessorOptions path.
    if (external_embedding_params == nullptr) {
      return nullptr;
    }
    embedding_params = std::move(external_embedding_params);
  } else {
    // Build the embeddings from this model's own memory image.
    embedding_params.reset(new EmbeddingParams(
        start, num_bytes, feature_processor_options.context_size()));
  }
  // Single construction point; move the pointer in to avoid an extra
  // reference-count round trip.
  return new ModelParams(start, num_bytes, std::move(embedding_params),
                         selection_options, sharing_options,
                         feature_processor_options);
}
} // namespace libtextclassifier