| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "smartselect/model-params.h" |
| |
| #include "common/memory_image/memory-image-reader.h" |
| |
| namespace libtextclassifier { |
| |
| using nlp_core::EmbeddingNetworkProto; |
| using nlp_core::MemoryImageReader; |
| |
| ModelParams* ModelParamsBuilder( |
| const void* start, uint64 num_bytes, |
| std::shared_ptr<EmbeddingParams> external_embedding_params) { |
| MemoryImageReader<EmbeddingNetworkProto> reader(start, num_bytes); |
| |
| ModelOptions model_options; |
| auto model_options_extension_id = model_options_in_embedding_network_proto; |
| if (reader.trimmed_proto().HasExtension(model_options_extension_id)) { |
| model_options = |
| reader.trimmed_proto().GetExtension(model_options_extension_id); |
| } |
| |
| FeatureProcessorOptions feature_processor_options; |
| auto feature_processor_extension_id = |
| feature_processor_options_in_embedding_network_proto; |
| if (reader.trimmed_proto().HasExtension(feature_processor_extension_id)) { |
| feature_processor_options = |
| reader.trimmed_proto().GetExtension(feature_processor_extension_id); |
| |
| // If no tokenization codepoint config is present, tokenize on space. |
| // TODO(zilka): Remove the default config. |
| if (feature_processor_options.tokenization_codepoint_config_size() == 0) { |
| TokenizationCodepointRange* config; |
| // New line character. |
| config = feature_processor_options.add_tokenization_codepoint_config(); |
| config->set_start(10); |
| config->set_end(11); |
| config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); |
| |
| // Space character. |
| config = feature_processor_options.add_tokenization_codepoint_config(); |
| config->set_start(32); |
| config->set_end(33); |
| config->set_role(TokenizationCodepointRange::WHITESPACE_SEPARATOR); |
| } |
| } else { |
| return nullptr; |
| } |
| |
| SelectionModelOptions selection_options; |
| auto selection_options_extension_id = |
| selection_model_options_in_embedding_network_proto; |
| if (reader.trimmed_proto().HasExtension(selection_options_extension_id)) { |
| selection_options = |
| reader.trimmed_proto().GetExtension(selection_options_extension_id); |
| |
| // For backward compatibility with the current models. |
| if (!feature_processor_options.ignored_span_boundary_codepoints_size()) { |
| *feature_processor_options.mutable_ignored_span_boundary_codepoints() = |
| selection_options.deprecated_punctuation_to_strip(); |
| } |
| } else { |
| selection_options.set_enforce_symmetry(true); |
| selection_options.set_symmetry_context_size( |
| feature_processor_options.context_size() * 2); |
| } |
| |
| SharingModelOptions sharing_options; |
| auto sharing_options_extension_id = |
| sharing_model_options_in_embedding_network_proto; |
| if (reader.trimmed_proto().HasExtension(sharing_options_extension_id)) { |
| sharing_options = |
| reader.trimmed_proto().GetExtension(sharing_options_extension_id); |
| } else { |
| // Default values when SharingModelOptions is not present. |
| sharing_options.set_always_accept_url_hint(true); |
| sharing_options.set_always_accept_email_hint(true); |
| } |
| |
| if (!model_options.use_shared_embeddings()) { |
| std::shared_ptr<EmbeddingParams> embedding_params(new EmbeddingParams( |
| start, num_bytes, feature_processor_options.context_size())); |
| return new ModelParams(start, num_bytes, embedding_params, |
| selection_options, sharing_options, |
| feature_processor_options); |
| } else { |
| return new ModelParams( |
| start, num_bytes, std::move(external_embedding_params), |
| selection_options, sharing_options, feature_processor_options); |
| } |
| } |
| |
| } // namespace libtextclassifier |