| /* |
| * Copyright (C) 2017 The Android Open Source Project |
| * |
| * Licensed under the Apache License, Version 2.0 (the "License"); |
| * you may not use this file except in compliance with the License. |
| * You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "common/task-context.h" |
| |
| #include <stdlib.h> |
| |
| #include <string> |
| |
| #include "util/base/integral_types.h" |
| #include "util/base/logging.h" |
| #include "util/strings/numbers.h" |
| |
| namespace libtextclassifier { |
| namespace nlp_core { |
| |
| namespace { |
| int32 ParseInt32WithDefault(const std::string &s, int32 defval) { |
| int32 value = defval; |
| return ParseInt32(s.c_str(), &value) ? value : defval; |
| } |
| |
| int64 ParseInt64WithDefault(const std::string &s, int64 defval) { |
| int64 value = defval; |
| return ParseInt64(s.c_str(), &value) ? value : defval; |
| } |
| |
| double ParseDoubleWithDefault(const std::string &s, double defval) { |
| double value = defval; |
| return ParseDouble(s.c_str(), &value) ? value : defval; |
| } |
| } // namespace |
| |
| TaskInput *TaskContext::GetInput(const std::string &name) { |
| // Return existing input if it exists. |
| for (int i = 0; i < spec_.input_size(); ++i) { |
| if (spec_.input(i).name() == name) return spec_.mutable_input(i); |
| } |
| |
| // Create new input. |
| TaskInput *input = spec_.add_input(); |
| input->set_name(name); |
| return input; |
| } |
| |
| TaskInput *TaskContext::GetInput(const std::string &name, |
| const std::string &file_format, |
| const std::string &record_format) { |
| TaskInput *input = GetInput(name); |
| if (!file_format.empty()) { |
| bool found = false; |
| for (int i = 0; i < input->file_format_size(); ++i) { |
| if (input->file_format(i) == file_format) found = true; |
| } |
| if (!found) input->add_file_format(file_format); |
| } |
| if (!record_format.empty()) { |
| bool found = false; |
| for (int i = 0; i < input->record_format_size(); ++i) { |
| if (input->record_format(i) == record_format) found = true; |
| } |
| if (!found) input->add_record_format(record_format); |
| } |
| return input; |
| } |
| |
| void TaskContext::SetParameter(const std::string &name, |
| const std::string &value) { |
| TC_LOG(INFO) << "SetParameter(" << name << ", " << value << ")"; |
| |
| // If the parameter already exists update the value. |
| for (int i = 0; i < spec_.parameter_size(); ++i) { |
| if (spec_.parameter(i).name() == name) { |
| spec_.mutable_parameter(i)->set_value(value); |
| return; |
| } |
| } |
| |
| // Add new parameter. |
| TaskSpec::Parameter *param = spec_.add_parameter(); |
| param->set_name(name); |
| param->set_value(value); |
| } |
| |
| std::string TaskContext::GetParameter(const std::string &name) const { |
| // First try to find parameter in task specification. |
| for (int i = 0; i < spec_.parameter_size(); ++i) { |
| if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); |
| } |
| |
| // Parameter not found, return empty std::string. |
| return ""; |
| } |
| |
| int TaskContext::GetIntParameter(const std::string &name) const { |
| std::string value = GetParameter(name); |
| return ParseInt32WithDefault(value, 0); |
| } |
| |
| int64 TaskContext::GetInt64Parameter(const std::string &name) const { |
| std::string value = GetParameter(name); |
| return ParseInt64WithDefault(value, 0); |
| } |
| |
| bool TaskContext::GetBoolParameter(const std::string &name) const { |
| std::string value = GetParameter(name); |
| return value == "true"; |
| } |
| |
| double TaskContext::GetFloatParameter(const std::string &name) const { |
| std::string value = GetParameter(name); |
| return ParseDoubleWithDefault(value, 0.0); |
| } |
| |
| std::string TaskContext::Get(const std::string &name, |
| const char *defval) const { |
| // First try to find parameter in task specification. |
| for (int i = 0; i < spec_.parameter_size(); ++i) { |
| if (spec_.parameter(i).name() == name) return spec_.parameter(i).value(); |
| } |
| |
| // Parameter not found, return default value. |
| return defval; |
| } |
| |
| std::string TaskContext::Get(const std::string &name, |
| const std::string &defval) const { |
| return Get(name, defval.c_str()); |
| } |
| |
| int TaskContext::Get(const std::string &name, int defval) const { |
| std::string value = Get(name, ""); |
| return ParseInt32WithDefault(value, defval); |
| } |
| |
| int64 TaskContext::Get(const std::string &name, int64 defval) const { |
| std::string value = Get(name, ""); |
| return ParseInt64WithDefault(value, defval); |
| } |
| |
| double TaskContext::Get(const std::string &name, double defval) const { |
| std::string value = Get(name, ""); |
| return ParseDoubleWithDefault(value, defval); |
| } |
| |
| bool TaskContext::Get(const std::string &name, bool defval) const { |
| std::string value = Get(name, ""); |
| return value.empty() ? defval : value == "true"; |
| } |
| |
| std::string TaskContext::InputFile(const TaskInput &input) { |
| if (input.part_size() == 0) { |
| TC_LOG(ERROR) << "No file for TaskInput " << input.name(); |
| return ""; |
| } |
| if (input.part_size() > 1) { |
| TC_LOG(ERROR) << "Ambiguous: multiple files for TaskInput " << input.name(); |
| } |
| return input.part(0).file_pattern(); |
| } |
| |
| bool TaskContext::Supports(const TaskInput &input, |
| const std::string &file_format, |
| const std::string &record_format) { |
| // Check file format. |
| if (input.file_format_size() > 0) { |
| bool found = false; |
| for (int i = 0; i < input.file_format_size(); ++i) { |
| if (input.file_format(i) == file_format) { |
| found = true; |
| break; |
| } |
| } |
| if (!found) return false; |
| } |
| |
| // Check record format. |
| if (input.record_format_size() > 0) { |
| bool found = false; |
| for (int i = 0; i < input.record_format_size(); ++i) { |
| if (input.record_format(i) == record_format) { |
| found = true; |
| break; |
| } |
| } |
| if (!found) return false; |
| } |
| |
| return true; |
| } |
| |
| } // namespace nlp_core |
| } // namespace libtextclassifier |