Support to store the index of argc in ToolParams when parsing command-line flags.
PiperOrigin-RevId: 374997161
Change-Id: Iaa8e4b3401d4ff487c8f158a4f52154a8081a14d
diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD
index c5e58a6..d9697f3 100644
--- a/tensorflow/lite/tools/BUILD
+++ b/tensorflow/lite/tools/BUILD
@@ -266,6 +266,7 @@
visibility = ["//visibility:private"],
deps = [
":command_line_flags",
+ ":tool_params",
"@com_google_googletest//:gtest_main",
],
)
diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.h b/tensorflow/lite/tools/benchmark/benchmark_model.h
index 912e54f..c3ba4c6 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_model.h
@@ -160,7 +160,10 @@
Flag CreateFlag(const char* name, BenchmarkParams* params,
const std::string& usage) {
return Flag(
- name, [params, name](const T& val) { params->Set<T>(name, val); },
+ name,
+ [params, name](const T& val, int argv_position) {
+ params->Set<T>(name, val, argv_position);
+ },
params->Get<T>(name), usage, Flag::kOptional);
}
diff --git a/tensorflow/lite/tools/command_line_flags.cc b/tensorflow/lite/tools/command_line_flags.cc
index c7affe4..1576cb2 100644
--- a/tensorflow/lite/tools/command_line_flags.cc
+++ b/tensorflow/lite/tools/command_line_flags.cc
@@ -34,11 +34,23 @@
return stream.str();
}
-bool ParseFlag(const std::string& arg, const std::string& flag, bool positional,
- const std::function<bool(const std::string&)>& parse_func,
+template <>
+std::string ToString(bool val) {
+ return val ? "true" : "false";
+}
+
+template <>
+std::string ToString(const std::string& val) {
+ return val;
+}
+
+bool ParseFlag(const std::string& arg, int argv_position,
+ const std::string& flag, bool positional,
+ const std::function<bool(const std::string&, int argv_position)>&
+ parse_func,
bool* value_parsing_ok) {
if (positional) {
- *value_parsing_ok = parse_func(arg);
+ *value_parsing_ok = parse_func(arg, argv_position);
return true;
}
*value_parsing_ok = true;
@@ -49,101 +61,76 @@
bool has_value = arg.size() >= flag_prefix.size();
*value_parsing_ok = has_value;
if (has_value) {
- *value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
+ *value_parsing_ok =
+ parse_func(arg.substr(flag_prefix.size()), argv_position);
}
return true;
}
template <typename T>
-bool ParseFlag(const std::string& flag_value,
- const std::function<void(const T&)>& hook) {
+bool ParseFlag(const std::string& flag_value, int argv_position,
+ const std::function<void(const T&, int)>& hook) {
std::istringstream stream(flag_value);
T read_value;
stream >> read_value;
if (!stream.eof() && !stream.good()) {
return false;
}
- hook(read_value);
+ hook(read_value, argv_position);
return true;
}
-bool ParseBoolFlag(const std::string& flag_value,
- const std::function<void(const bool&)>& hook) {
+template <>
+bool ParseFlag(const std::string& flag_value, int argv_position,
+ const std::function<void(const bool&, int)>& hook) {
if (flag_value != "true" && flag_value != "false" && flag_value != "0" &&
flag_value != "1") {
return false;
}
- hook(flag_value == "true" || flag_value == "1");
+ hook(flag_value == "true" || flag_value == "1", argv_position);
+ return true;
+}
+
+template <typename T>
+bool ParseFlag(const std::string& flag_value, int argv_position,
+ const std::function<void(const std::string&, int)>& hook) {
+ hook(flag_value, argv_position);
return true;
}
} // namespace
-Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
- int32_t default_value, const std::string& usage_text,
- FlagType flag_type)
- : name_(name),
- type_(TYPE_INT32),
- value_hook_([hook](const std::string& flag_value) {
- return ParseFlag<int32_t>(flag_value, hook);
- }),
- default_for_display_(ToString(default_value)),
- usage_text_(usage_text),
- flag_type_(flag_type) {}
+#define CONSTRUCTOR_IMPLEMENTATION(flag_T, default_value_T, flag_enum_val) \
+ Flag::Flag(const char* name, \
+ const std::function<void(const flag_T& /*flag_val*/, \
+ int /*argv_position*/)>& hook, \
+ default_value_T default_value, const std::string& usage_text, \
+ FlagType flag_type) \
+ : name_(name), \
+ type_(flag_enum_val), \
+ value_hook_([hook](const std::string& flag_value, int argv_position) { \
+ return ParseFlag<flag_T>(flag_value, argv_position, hook); \
+ }), \
+ default_for_display_(ToString<default_value_T>(default_value)), \
+ usage_text_(usage_text), \
+ flag_type_(flag_type) {}
-Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
- int64_t default_value, const std::string& usage_text,
- FlagType flag_type)
- : name_(name),
- type_(TYPE_INT64),
- value_hook_([hook](const std::string& flag_value) {
- return ParseFlag<int64_t>(flag_value, hook);
- }),
- default_for_display_(ToString(default_value)),
- usage_text_(usage_text),
- flag_type_(flag_type) {}
+CONSTRUCTOR_IMPLEMENTATION(int32_t, int32_t, TYPE_INT32)
+CONSTRUCTOR_IMPLEMENTATION(int64_t, int64_t, TYPE_INT64)
+CONSTRUCTOR_IMPLEMENTATION(float, float, TYPE_FLOAT)
+CONSTRUCTOR_IMPLEMENTATION(bool, bool, TYPE_BOOL)
+CONSTRUCTOR_IMPLEMENTATION(std::string, const std::string&, TYPE_STRING)
-Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
- float default_value, const std::string& usage_text,
- FlagType flag_type)
- : name_(name),
- type_(TYPE_FLOAT),
- value_hook_([hook](const std::string& flag_value) {
- return ParseFlag<float>(flag_value, hook);
- }),
- default_for_display_(ToString(default_value)),
- usage_text_(usage_text),
- flag_type_(flag_type) {}
+#undef CONSTRUCTOR_IMPLEMENTATION
-Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
- bool default_value, const std::string& usage_text,
- FlagType flag_type)
- : name_(name),
- type_(TYPE_BOOL),
- value_hook_([hook](const std::string& flag_value) {
- return ParseBoolFlag(flag_value, hook);
- }),
- default_for_display_(default_value ? "true" : "false"),
- usage_text_(usage_text),
- flag_type_(flag_type) {}
-
-Flag::Flag(const char* name,
- const std::function<void(const std::string&)>& hook,
- const std::string& default_value, const std::string& usage_text,
- FlagType flag_type)
- : name_(name),
- type_(TYPE_STRING),
- value_hook_([hook](const std::string& flag_value) {
- hook(flag_value);
- return true;
- }),
- default_for_display_(default_value),
- usage_text_(usage_text),
- flag_type_(flag_type) {}
-
-bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
- return ParseFlag(arg, name_, flag_type_ == kPositional, value_hook_,
- value_parsing_ok);
+bool Flag::Parse(const std::string& arg, int argv_position,
+ bool* value_parsing_ok) const {
+ return ParseFlag(
+ arg, argv_position, name_, flag_type_ == kPositional,
+ [&](const std::string& read_value, int argv_position) {
+ return value_hook_(read_value, argv_position);
+ },
+ value_parsing_ok);
}
std::string Flag::GetTypeName() const {
@@ -191,7 +178,7 @@
#endif
if (it->second != -1) {
bool value_parsing_ok;
- flag.Parse(argv[it->second], &value_parsing_ok);
+ flag.Parse(argv[it->second], it->second, &value_parsing_ok);
if (!value_parsing_ok) {
TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
<< "' against argv '" << argv[it->second] << "'";
@@ -214,7 +201,7 @@
return false;
}
bool value_parsing_ok;
- flag.Parse(argv[positional_count], &value_parsing_ok);
+ flag.Parse(argv[positional_count], positional_count, &value_parsing_ok);
if (!value_parsing_ok) {
TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
return false;
@@ -229,7 +216,7 @@
for (int i = positional_count + 1; i < *argc; ++i) {
if (!unknown_argvs[i]) continue;
bool value_parsing_ok;
- was_found = flag.Parse(argv[i], &value_parsing_ok);
+ was_found = flag.Parse(argv[i], i, &value_parsing_ok);
if (!value_parsing_ok) {
TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
<< "' against argv '" << argv[i] << "'";
diff --git a/tensorflow/lite/tools/command_line_flags.h b/tensorflow/lite/tools/command_line_flags.h
index 4cc09f4..6c90ed0 100644
--- a/tensorflow/lite/tools/command_line_flags.h
+++ b/tensorflow/lite/tools/command_line_flags.h
@@ -79,26 +79,49 @@
name, [val](const T& v) { *val = v; }, *val, usage, flag_type);
}
- Flag(const char* name, const std::function<void(const int32_t&)>& hook,
- int32_t default_value, const std::string& usage_text,
+// "flag_T" is same as "default_value_T" for trivial types, like int32, bool
+// etc. But when it's a complex type, "default_value_T" is generally a const
+// reference "flag_T".
+#define CONSTRUCTOR_WITH_ARGV_INDEX(flag_T, default_value_T) \
+ Flag(const char* name, \
+ const std::function<void(const flag_T& /*flag_val*/, \
+ int /*argv_position*/)>& hook, \
+ default_value_T default_value, const std::string& usage_text, \
FlagType flag_type);
- Flag(const char* name, const std::function<void(const int64_t&)>& hook,
- int64_t default_value, const std::string& usage_text,
- FlagType flag_type);
- Flag(const char* name, const std::function<void(const float&)>& hook,
- float default_value, const std::string& usage_text, FlagType flag_type);
- Flag(const char* name, const std::function<void(const bool&)>& hook,
- bool default_value, const std::string& usage_text, FlagType flag_type);
- Flag(const char* name, const std::function<void(const std::string&)>& hook,
- const std::string& default_value, const std::string& usage_text,
- FlagType flag_type);
+
+#define CONSTRUCTOR_WITHOUT_ARGV_INDEX(flag_T, default_value_T) \
+ Flag(const char* name, const std::function<void(const flag_T&)>& hook, \
+ default_value_T default_value, const std::string& usage_text, \
+ FlagType flag_type) \
+ : Flag( \
+ name, [hook](const flag_T& flag_val, int) { hook(flag_val); }, \
+ default_value, usage_text, flag_type) {}
+
+ CONSTRUCTOR_WITH_ARGV_INDEX(int32_t, int32_t)
+ CONSTRUCTOR_WITHOUT_ARGV_INDEX(int32_t, int32_t)
+
+ CONSTRUCTOR_WITH_ARGV_INDEX(int64_t, int64_t)
+ CONSTRUCTOR_WITHOUT_ARGV_INDEX(int64_t, int64_t)
+
+ CONSTRUCTOR_WITH_ARGV_INDEX(float, float)
+ CONSTRUCTOR_WITHOUT_ARGV_INDEX(float, float)
+
+ CONSTRUCTOR_WITH_ARGV_INDEX(bool, bool)
+ CONSTRUCTOR_WITHOUT_ARGV_INDEX(bool, bool)
+
+ CONSTRUCTOR_WITH_ARGV_INDEX(std::string, const std::string&)
+ CONSTRUCTOR_WITHOUT_ARGV_INDEX(std::string, const std::string&)
+
+#undef CONSTRUCTOR_WITH_ARGV_INDEX
+#undef CONSTRUCTOR_WITHOUT_ARGV_INDEX
FlagType GetFlagType() const { return flag_type_; }
private:
friend class Flags;
- bool Parse(const std::string& arg, bool* value_parsing_ok) const;
+ bool Parse(const std::string& arg, int argv_position,
+ bool* value_parsing_ok) const;
std::string name_;
enum {
@@ -111,7 +134,8 @@
std::string GetTypeName() const;
- std::function<bool(const std::string&)> value_hook_;
+ std::function<bool(const std::string& /*read_value*/, int /*argv_position*/)>
+ value_hook_;
std::string default_for_display_;
std::string usage_text_;
diff --git a/tensorflow/lite/tools/command_line_flags_test.cc b/tensorflow/lite/tools/command_line_flags_test.cc
index afd1264..5248245 100644
--- a/tensorflow/lite/tools/command_line_flags_test.cc
+++ b/tensorflow/lite/tools/command_line_flags_test.cc
@@ -17,6 +17,7 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>
+#include "tensorflow/lite/tools/tool_params.h"
namespace tflite {
namespace {
@@ -364,5 +365,55 @@
EXPECT_EQ("--some_int=1 --some_int=2", args);
}
+TEST(CommandLineFlagsTest, ArgvPositions) {
+ tools::ToolParams params;
+ params.AddParam("some_int", tools::ToolParam::Create<int>(13));
+ params.AddParam("some_float", tools::ToolParam::Create<float>(17.0f));
+ params.AddParam("some_bool", tools::ToolParam::Create<bool>(true));
+
+ const char* argv_strings[] = {"program_name", "--some_float=42.0",
+ "--some_bool=false", "--some_int=5"};
+ int argc = 4;
+ tools::ToolParams* const params_ptr = ¶ms;
+ bool parsed_ok = Flags::Parse(
+ &argc, reinterpret_cast<const char**>(argv_strings),
+ {
+ Flag(
+ "some_int",
+ // NOLINT because of needing templating both trivial and complex
+ // types for a Flag.
+ [params_ptr](const int& val, int argv_position) { // NOLINT
+ params_ptr->Set<int>("some_int", val, argv_position);
+ },
+ 13, "some int", Flag::kOptional),
+ Flag(
+ "some_float",
+ [params_ptr](const float& val, int argv_position) { // NOLINT
+ params_ptr->Set<float>("some_float", val, argv_position);
+ },
+ 17.0f, "some float", Flag::kOptional),
+ Flag(
+ "some_bool",
+ [params_ptr](const bool& val, int argv_position) { // NOLINT
+ params_ptr->Set<bool>("some_bool", val, argv_position);
+ },
+ true, "some bool", Flag::kOptional),
+ });
+
+ EXPECT_TRUE(parsed_ok);
+ EXPECT_EQ(5, params.Get<int>("some_int"));
+ EXPECT_NEAR(42.0f, params.Get<float>("some_float"), 1e-5f);
+ EXPECT_FALSE(params.Get<bool>("some_bool"));
+
+ // The position of a parameter depends on the ordering of the associated flag
+ // specfied in the argv (i.e. 'argv_strings' above), not as the ordering of
+ // the flag in the flag list that's passed to Flags::Parse above.
+ EXPECT_EQ(3, params.GetPosition<int>("some_int"));
+ EXPECT_EQ(1, params.GetPosition<float>("some_float"));
+ EXPECT_EQ(2, params.GetPosition<bool>("some_bool"));
+
+ EXPECT_EQ(argc, 1);
+}
+
} // namespace
} // namespace tflite
diff --git a/tensorflow/lite/tools/delegates/delegate_provider.h b/tensorflow/lite/tools/delegates/delegate_provider.h
index ef721f4..e891e42 100644
--- a/tensorflow/lite/tools/delegates/delegate_provider.h
+++ b/tensorflow/lite/tools/delegates/delegate_provider.h
@@ -58,7 +58,10 @@
Flag CreateFlag(const char* name, ToolParams* params,
const std::string& usage) const {
return Flag(
- name, [params, name](const T& val) { params->Set<T>(name, val); },
+ name,
+ [params, name](const T& val, int argv_position) {
+ params->Set<T>(name, val, argv_position);
+ },
default_params_.Get<T>(name), usage, Flag::kOptional);
}
ToolParams default_params_;
diff --git a/tensorflow/lite/tools/tool_params.h b/tensorflow/lite/tools/tool_params.h
index b3e6975..01202ac 100644
--- a/tensorflow/lite/tools/tool_params.h
+++ b/tensorflow/lite/tools/tool_params.h
@@ -35,8 +35,11 @@
public:
template <typename T>
- static std::unique_ptr<ToolParam> Create(const T& default_value) {
- return std::unique_ptr<ToolParam>(new TypedToolParam<T>(default_value));
+ static std::unique_ptr<ToolParam> Create(const T& default_value,
+ int position = 0) {
+ auto* param = new TypedToolParam<T>(default_value);
+ param->SetPosition(position);
+ return std::unique_ptr<ToolParam>(param);
}
template <typename T>
@@ -52,10 +55,14 @@
}
virtual ~ToolParam() {}
- explicit ToolParam(ParamType type) : has_value_set_(false), type_(type) {}
+ explicit ToolParam(ParamType type)
+ : has_value_set_(false), position_(0), type_(type) {}
bool HasValueSet() const { return has_value_set_; }
+ int GetPosition() const { return position_; }
+ void SetPosition(int position) { position_ = position; }
+
virtual void Set(const ToolParam&) {}
virtual std::unique_ptr<ToolParam> Clone() const = 0;
@@ -63,6 +70,13 @@
protected:
bool has_value_set_;
+ // Represents the relative ordering among a set of params.
+ // Note: in our code, a ToolParam is generally used together with a
+ // tflite::Flag so that its value could be set when parsing commandline flags.
+ // In this case, the `position_` is simply the index of the particular flag
+ // into the list of commandline flags (i.e. named 'argv' in general).
+ int position_;
+
private:
static void AssertHasSameType(ParamType a, ParamType b);
@@ -84,10 +98,11 @@
void Set(const ToolParam& other) override {
Set(other.AsConstTyped<T>()->Get());
+ SetPosition(other.AsConstTyped<T>()->GetPosition());
}
std::unique_ptr<ToolParam> Clone() const override {
- return std::unique_ptr<ToolParam>(new TypedToolParam<T>(value_));
+ return ToolParam::Create<T>(value_, position_);
}
private:
@@ -97,6 +112,7 @@
// A map-like container for holding values of different types.
class ToolParams {
public:
+ // Add a ToolParam instance `value` w/ `name` to this container.
void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
params_[name] = std::move(value);
}
@@ -114,9 +130,10 @@
}
template <typename T>
- void Set(const std::string& name, const T& value) {
+ void Set(const std::string& name, const T& value, int position = 0) {
AssertParamExists(name);
params_.at(name)->AsTyped<T>()->Set(value);
+ params_.at(name)->AsTyped<T>()->SetPosition(position);
}
template <typename T>
@@ -126,6 +143,12 @@
}
template <typename T>
+ int GetPosition(const std::string& name) const {
+ AssertParamExists(name);
+ return params_.at(name)->AsConstTyped<T>()->GetPosition();
+ }
+
+ template <typename T>
T Get(const std::string& name) const {
AssertParamExists(name);
return params_.at(name)->AsConstTyped<T>()->Get();
diff --git a/tensorflow/lite/tools/tool_params_test.cc b/tensorflow/lite/tools/tool_params_test.cc
index 8c12ea6..248db53 100644
--- a/tensorflow/lite/tools/tool_params_test.cc
+++ b/tensorflow/lite/tools/tool_params_test.cc
@@ -24,18 +24,20 @@
TEST(ToolParams, SetTest) {
ToolParams params;
- params.AddParam("some-int1", ToolParam::Create<int>(13));
- params.AddParam("some-int2", ToolParam::Create<int>(17));
+ params.AddParam("some-int1", ToolParam::Create<int>(13 /*, position=0*/));
+ params.AddParam("some-int2", ToolParam::Create<int>(17 /*, position=0*/));
ToolParams others;
- others.AddParam("some-int1", ToolParam::Create<int>(19));
- others.AddParam("some-bool", ToolParam::Create<bool>(true));
+ others.AddParam("some-int1", ToolParam::Create<int>(19, 5));
+ others.AddParam("some-bool", ToolParam::Create<bool>(true, 1));
params.Set(others);
EXPECT_EQ(19, params.Get<int>("some-int1"));
+ EXPECT_EQ(5, params.GetPosition<int>("some-int1"));
EXPECT_TRUE(params.HasValueSet<int>("some-int1"));
EXPECT_EQ(17, params.Get<int>("some-int2"));
+ EXPECT_EQ(0, params.GetPosition<int>("some-int2"));
EXPECT_FALSE(params.HasValueSet<int>("some-int2"));
EXPECT_FALSE(params.HasParam("some-bool"));
@@ -43,30 +45,32 @@
TEST(ToolParams, MergeTestOverwriteTrue) {
ToolParams params;
- params.AddParam("some-int1", ToolParam::Create<int>(13));
- params.AddParam("some-int2", ToolParam::Create<int>(17));
+ params.AddParam("some-int1", ToolParam::Create<int>(13 /*, position=0*/));
+ params.AddParam("some-int2", ToolParam::Create<int>(17 /*, position=0*/));
ToolParams others;
- others.AddParam("some-int1", ToolParam::Create<int>(19));
- others.AddParam("some-bool", ToolParam::Create<bool>(true));
+ others.AddParam("some-int1", ToolParam::Create<int>(19, 5));
+ others.AddParam("some-bool", ToolParam::Create<bool>(true /*, position=0*/));
params.Merge(others, true /* overwrite */);
EXPECT_EQ(19, params.Get<int>("some-int1"));
+ EXPECT_EQ(5, params.GetPosition<int>("some-int1"));
EXPECT_EQ(17, params.Get<int>("some-int2"));
EXPECT_TRUE(params.Get<bool>("some-bool"));
}
TEST(ToolParams, MergeTestOverwriteFalse) {
ToolParams params;
- params.AddParam("some-int1", ToolParam::Create<int>(13));
- params.AddParam("some-int2", ToolParam::Create<int>(17));
+ params.AddParam("some-int1", ToolParam::Create<int>(13 /*, position=0*/));
+ params.AddParam("some-int2", ToolParam::Create<int>(17 /*, position=0*/));
ToolParams others;
- others.AddParam("some-int1", ToolParam::Create<int>(19));
- others.AddParam("some-bool", ToolParam::Create<bool>(true));
+ others.AddParam("some-int1", ToolParam::Create<int>(19, 5));
+ others.AddParam("some-bool", ToolParam::Create<bool>(true /*, position=0*/));
params.Merge(others); // default overwrite is false
EXPECT_EQ(13, params.Get<int>("some-int1"));
+ EXPECT_EQ(0, params.GetPosition<int>("some-int1"));
EXPECT_EQ(17, params.Get<int>("some-int2"));
EXPECT_TRUE(params.Get<bool>("some-bool"));
}