blob: 0ba8683abb4a391e25b46f1d12b6fd68a202e254 [file] [log] [blame]
#include <c10/util/Flags.h>
#include <c10/macros/Macros.h>
#include <cerrno>
#include <cstdlib>
#include <iostream>
#include <limits>
#include <sstream>
#include <string>
#ifndef C10_USE_GFLAGS
namespace c10 {
using std::string;
// Defines the registry of command-line flag parsers. Within this file the
// registry is queried by flag name (Has), asked to build a C10FlagParser from
// a flag name and its textual value (Create), and asked for the per-flag help
// text printed by --help (HelpMessage).
C10_DEFINE_REGISTRY(C10FlagsRegistry, C10FlagParser, const string&);
namespace {

// Set to true once ParseCommandLineFlags() has run to completion.
bool gCommandLineFlagsParsed{false};

// Flags are parsed before logging is available, so diagnostics are buffered
// in this process-wide string stream instead of going through caffe logging.
// The stream is created on first use and the same instance is returned on
// every call.
std::stringstream& GlobalInitStream() {
  static std::stringstream buffered_messages;
  return buffered_messages;
}

// Text reported by UsageMessage(); holds a placeholder until
// SetUsageMessage() replaces it.
const char* gUsageMessage{"(Usage message not set.)"};

} // namespace
// Sets the text returned by UsageMessage().
//
// The string is copied into function-local static storage so the stored
// pointer stays valid after the caller's `str` is destroyed.
//
// The copy is assigned on every call. (Initializing the static directly from
// `str` would only take effect on the first call, and every later call would
// silently keep the first message.) A later call may reallocate the storage,
// so a pointer obtained from UsageMessage() should not be cached across calls
// to SetUsageMessage().
C10_EXPORT void SetUsageMessage(const string& str) {
  static string usage_message_safe_copy;
  usage_message_safe_copy = str;
  gUsageMessage = usage_message_safe_copy.c_str();
}
// Returns the current usage text: the placeholder message until
// SetUsageMessage() has been called, and the text it stored afterwards.
C10_EXPORT const char* UsageMessage() {
  const char* const current_message = gUsageMessage;
  return current_message;
}
// Parses c10 flags out of a main()-style argument vector.
//
// Arguments of the form "--name=value" or "--name value" are looked up in
// C10FlagsRegistry and handed to the matching parser. Arguments that do not
// start with "--" are kept: they are compacted towards the front of argv
// (after argv[0]) and *pargc is updated to the number of entries kept.
//
// Any argument containing the substring "--help" prints the usage message and
// the registered flags' help text to stdout and exits the process with
// status 0.
//
// Parsing stops at the first error: a flag with no value, an unregistered
// flag, or a value its parser rejects. In that case the buffered diagnostics
// are written to stderr, and arguments after the failing one are neither
// parsed nor kept in argv.
//
// @param pargc  in/out argument count; a count of 0 returns true immediately
//               without touching anything else.
// @param pargv  pointer to the argument vector; its entries are reordered in
//               place.
// @return true if every "--" argument was parsed successfully, else false.
C10_EXPORT bool ParseCommandLineFlags(int* pargc, char*** pargv) {
  if (*pargc == 0)
    return true;
  char** argv = *pargv;
  bool success = true;
  GlobalInitStream() << "Parsing commandline arguments for c10." << std::endl;
  // write_head is the location we write the unused arguments to.
  int write_head = 1;
  for (int i = 1; i < *pargc; ++i) {
    string arg(argv[i]);
    if (arg.find("--help") != string::npos) {
      // Print the help message, and quit.
      std::cout << UsageMessage() << std::endl;
      std::cout << "Arguments: " << std::endl;
      for (const auto& help_msg : C10FlagsRegistry()->HelpMessage()) {
        std::cout << " " << help_msg.first << ": " << help_msg.second
                  << std::endl;
      }
      exit(0);
    }
    // If the arg does not start with "--", it is not a flag: keep it in argv
    // for the caller. (Indexing up to arg.size() is valid for std::string and
    // yields '\0', so empty and one-character arguments are safe here.)
    if (arg[0] != '-' || arg[1] != '-') {
      GlobalInitStream()
          << "C10 flag: commandline argument does not match --name=var "
             "or --name format: "
          << arg << ". Ignoring this argument." << std::endl;
      argv[write_head++] = argv[i];
      continue;
    }
    string key;
    string value;
    size_t prefix_idx = arg.find('=');
    if (prefix_idx == string::npos) {
      // If there is no equality char in the arg, it means that the
      // value is specified in the next argument.
      key = arg.substr(2, arg.size() - 2);
      ++i;
      if (i == *pargc) {
        // Terminate the line like every other diagnostic so the text written
        // to stderr below does not run into whatever is printed next.
        GlobalInitStream()
            << "C10 flag: reached the last commandline argument, but "
               "I am expecting a value for "
            << arg << std::endl;
        success = false;
        break;
      }
      value = string(argv[i]);
    } else {
      // If there is an equality character, the key is the text between "--"
      // and "=", and the value is everything after the "=".
      key = arg.substr(2, prefix_idx - 2);
      value = arg.substr(prefix_idx + 1, string::npos);
    }
    // An unregistered flag is an error: report it and stop parsing.
    if (!C10FlagsRegistry()->Has(key)) {
      GlobalInitStream() << "C10 flag: unrecognized commandline argument: "
                         << arg << std::endl;
      success = false;
      break;
    }
    // The parser is created from the key and value, and success() then
    // reports whether the value was accepted.
    std::unique_ptr<C10FlagParser> parser(
        C10FlagsRegistry()->Create(key, value));
    if (!parser->success()) {
      GlobalInitStream() << "C10 flag: illegal argument: " << arg << std::endl;
      success = false;
      break;
    }
  }
  *pargc = write_head;
  gCommandLineFlagsParsed = true;
  // TODO: when we fail commandline flag parsing, shall we continue, or
  // shall we just quit loudly? Right now we carry on the computation, but
  // since there are failures in parsing, it is very likely that some
  // downstream things will break, in which case it makes sense to quit loud
  // and early.
  if (!success) {
    std::cerr << GlobalInitStream().str();
  }
  // Clear the global init stream.
  GlobalInitStream().str(std::string());
  return success;
}
// Reports whether ParseCommandLineFlags() has run past its parsing loop at
// least once. That function sets the flag regardless of whether parsing
// succeeded.
C10_EXPORT bool CommandLineFlagsHasBeenParsed() {
  const bool already_parsed = gCommandLineFlagsParsed;
  return already_parsed;
}
// String flags need no conversion: the raw text is copied into *value and
// the parse always succeeds.
template <>
C10_EXPORT bool C10FlagParser::Parse<string>(
    const string& content,
    string* value) {
  value->assign(content);
  return true;
}
// Parses `content` as a base-10 int.
//
// std::atoi never throws and cannot report failure, so wrapping it in
// try/catch left the error path unreachable: malformed text silently became 0
// and out-of-range text was undefined behaviour. std::strtol lets us detect
// both cases explicitly.
//
// @param content  textual flag value; leading whitespace is skipped by
//                 strtol, but the rest of the string must be a complete
//                 integer with no trailing characters.
// @param value    receives the parsed number; left untouched on failure.
// @return true on success; false (with a message appended to the init
//         stream) if the text is empty, malformed, or does not fit in int.
template <>
C10_EXPORT bool C10FlagParser::Parse<int>(const string& content, int* value) {
  const char* const begin = content.c_str();
  char* end = nullptr;
  errno = 0; // strtol only ever sets errno, so clear it to detect ERANGE.
  const long parsed = std::strtol(begin, &end, 10);
  // Require at least one converted character and no leftover text.
  const bool consumed_all = end != begin && end == begin + content.size();
  // long may be wider than int, so check the int range in addition to ERANGE.
  const bool in_range = errno != ERANGE &&
      parsed >= std::numeric_limits<int>::min() &&
      parsed <= std::numeric_limits<int>::max();
  if (!consumed_all || !in_range) {
    GlobalInitStream() << "C10 flag error: Cannot convert argument to int: "
                       << content << std::endl;
    return false;
  }
  *value = static_cast<int>(parsed);
  return true;
}
// Parses `content` as a base-10 64-bit integer.
//
// atoll never throws and cannot report failure, so wrapping it in try/catch
// left the error path unreachable: malformed text silently became 0.
// strtoll reports both malformed and out-of-range input.
//
// @param content  textual flag value; leading whitespace is skipped by
//                 strtoll, but the rest of the string must be a complete
//                 integer with no trailing characters.
// @param value    receives the parsed number; left untouched on failure.
// @return true on success; false (with a message appended to the init
//         stream) if the text is empty, malformed, or out of range.
template <>
C10_EXPORT bool C10FlagParser::Parse<int64_t>(
    const string& content,
    int64_t* value) {
  // strtoll yields a long long; make sure that is exactly what we store.
  static_assert(sizeof(long long) == sizeof(int64_t), "");
  const char* const begin = content.c_str();
  char* end = nullptr;
  errno = 0; // strtoll only ever sets errno, so clear it to detect ERANGE.
#ifdef __ANDROID__
  // Mirror the original Android workaround (std::atoll was unavailable
  // there): use the global-namespace C function instead of the std:: one.
  const long long parsed = strtoll(begin, &end, 10);
#else
  const long long parsed = std::strtoll(begin, &end, 10);
#endif
  // Require at least one converted character and no leftover text.
  const bool consumed_all = end != begin && end == begin + content.size();
  if (!consumed_all || errno == ERANGE) {
    GlobalInitStream() << "C10 flag error: Cannot convert argument to int: "
                       << content << std::endl;
    return false;
  }
  *value = static_cast<int64_t>(parsed);
  return true;
}
// Parses `content` as a double.
//
// std::atof never throws and cannot report failure, so wrapping it in
// try/catch left the error path unreachable: malformed text silently became
// 0.0. std::strtod accepts the same syntax and additionally reports where
// parsing stopped, which lets us reject malformed input.
//
// Values whose magnitude is too large are not rejected here; strtod returns
// +/-HUGE_VAL for them and that value is stored.
//
// @param content  textual flag value; leading whitespace is skipped by
//                 strtod, but the rest of the string must be a complete
//                 number with no trailing characters.
// @param value    receives the parsed number; left untouched on failure.
// @return true on success; false (with a message appended to the init
//         stream) if the text is empty or malformed.
template <>
C10_EXPORT bool C10FlagParser::Parse<double>(
    const string& content,
    double* value) {
  const char* const begin = content.c_str();
  char* end = nullptr;
  const double parsed = std::strtod(begin, &end);
  // Require at least one converted character and no leftover text.
  const bool consumed_all = end != begin && end == begin + content.size();
  if (!consumed_all) {
    GlobalInitStream() << "C10 flag error: Cannot convert argument to double: "
                       << content << std::endl;
    return false;
  }
  *value = parsed;
  return true;
}
// Parses a boolean flag value. The accepted spellings are exactly
// "false"/"False"/"FALSE"/"0" and "true"/"True"/"TRUE"/"1"; anything else is
// rejected, an explanatory message is appended to the init stream, and
// *value is left untouched.
template <>
C10_EXPORT bool C10FlagParser::Parse<bool>(const string& content, bool* value) {
  const bool spells_false = content == "false" || content == "False" ||
      content == "FALSE" || content == "0";
  if (spells_false) {
    *value = false;
    return true;
  }

  const bool spells_true = content == "true" || content == "True" ||
      content == "TRUE" || content == "1";
  if (spells_true) {
    *value = true;
    return true;
  }

  // Neither spelling matched: explain the likely cause, since a bare --flag
  // followed by another argument ends up here.
  GlobalInitStream()
      << "C10 flag error: Cannot convert argument to bool: " << content
      << std::endl
      << "Note that if you are passing in a bool flag, you need to "
         "explicitly specify it, like --arg=True or --arg True. Otherwise, "
         "the next argument may be inadvertently used as the argument, "
         "causing the above error."
      << std::endl;
  return false;
}
} // namespace c10
#endif // C10_USE_GFLAGS