Support to store the index of argc in ToolParams when parsing command-line flags.

PiperOrigin-RevId: 374997161
Change-Id: Iaa8e4b3401d4ff487c8f158a4f52154a8081a14d
diff --git a/tensorflow/lite/tools/BUILD b/tensorflow/lite/tools/BUILD
index c5e58a6..d9697f3 100644
--- a/tensorflow/lite/tools/BUILD
+++ b/tensorflow/lite/tools/BUILD
@@ -266,6 +266,7 @@
     visibility = ["//visibility:private"],
     deps = [
         ":command_line_flags",
+        ":tool_params",
         "@com_google_googletest//:gtest_main",
     ],
 )
diff --git a/tensorflow/lite/tools/benchmark/benchmark_model.h b/tensorflow/lite/tools/benchmark/benchmark_model.h
index 912e54f..c3ba4c6 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_model.h
@@ -160,7 +160,10 @@
 Flag CreateFlag(const char* name, BenchmarkParams* params,
                 const std::string& usage) {
   return Flag(
-      name, [params, name](const T& val) { params->Set<T>(name, val); },
+      name,
+      [params, name](const T& val, int argv_position) {
+        params->Set<T>(name, val, argv_position);
+      },
       params->Get<T>(name), usage, Flag::kOptional);
 }
 
diff --git a/tensorflow/lite/tools/command_line_flags.cc b/tensorflow/lite/tools/command_line_flags.cc
index c7affe4..1576cb2 100644
--- a/tensorflow/lite/tools/command_line_flags.cc
+++ b/tensorflow/lite/tools/command_line_flags.cc
@@ -34,11 +34,23 @@
   return stream.str();
 }
 
-bool ParseFlag(const std::string& arg, const std::string& flag, bool positional,
-               const std::function<bool(const std::string&)>& parse_func,
+template <>
+std::string ToString(bool val) {
+  return val ? "true" : "false";
+}
+
+template <>
+std::string ToString(const std::string& val) {
+  return val;
+}
+
+bool ParseFlag(const std::string& arg, int argv_position,
+               const std::string& flag, bool positional,
+               const std::function<bool(const std::string&, int argv_position)>&
+                   parse_func,
                bool* value_parsing_ok) {
   if (positional) {
-    *value_parsing_ok = parse_func(arg);
+    *value_parsing_ok = parse_func(arg, argv_position);
     return true;
   }
   *value_parsing_ok = true;
@@ -49,101 +61,76 @@
   bool has_value = arg.size() >= flag_prefix.size();
   *value_parsing_ok = has_value;
   if (has_value) {
-    *value_parsing_ok = parse_func(arg.substr(flag_prefix.size()));
+    *value_parsing_ok =
+        parse_func(arg.substr(flag_prefix.size()), argv_position);
   }
   return true;
 }
 
 template <typename T>
-bool ParseFlag(const std::string& flag_value,
-               const std::function<void(const T&)>& hook) {
+bool ParseFlag(const std::string& flag_value, int argv_position,
+               const std::function<void(const T&, int)>& hook) {
   std::istringstream stream(flag_value);
   T read_value;
   stream >> read_value;
   if (!stream.eof() && !stream.good()) {
     return false;
   }
-  hook(read_value);
+  hook(read_value, argv_position);
   return true;
 }
 
-bool ParseBoolFlag(const std::string& flag_value,
-                   const std::function<void(const bool&)>& hook) {
+template <>
+bool ParseFlag(const std::string& flag_value, int argv_position,
+               const std::function<void(const bool&, int)>& hook) {
   if (flag_value != "true" && flag_value != "false" && flag_value != "0" &&
       flag_value != "1") {
     return false;
   }
 
-  hook(flag_value == "true" || flag_value == "1");
+  hook(flag_value == "true" || flag_value == "1", argv_position);
+  return true;
+}
+
+template <typename T>
+bool ParseFlag(const std::string& flag_value, int argv_position,
+               const std::function<void(const std::string&, int)>& hook) {
+  hook(flag_value, argv_position);
   return true;
 }
 }  // namespace
 
-Flag::Flag(const char* name, const std::function<void(const int32_t&)>& hook,
-           int32_t default_value, const std::string& usage_text,
-           FlagType flag_type)
-    : name_(name),
-      type_(TYPE_INT32),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseFlag<int32_t>(flag_value, hook);
-      }),
-      default_for_display_(ToString(default_value)),
-      usage_text_(usage_text),
-      flag_type_(flag_type) {}
+#define CONSTRUCTOR_IMPLEMENTATION(flag_T, default_value_T, flag_enum_val)     \
+  Flag::Flag(const char* name,                                                 \
+             const std::function<void(const flag_T& /*flag_val*/,              \
+                                      int /*argv_position*/)>& hook,           \
+             default_value_T default_value, const std::string& usage_text,     \
+             FlagType flag_type)                                               \
+      : name_(name),                                                           \
+        type_(flag_enum_val),                                                  \
+        value_hook_([hook](const std::string& flag_value, int argv_position) { \
+          return ParseFlag<flag_T>(flag_value, argv_position, hook);           \
+        }),                                                                    \
+        default_for_display_(ToString<default_value_T>(default_value)),        \
+        usage_text_(usage_text),                                               \
+        flag_type_(flag_type) {}
 
-Flag::Flag(const char* name, const std::function<void(const int64_t&)>& hook,
-           int64_t default_value, const std::string& usage_text,
-           FlagType flag_type)
-    : name_(name),
-      type_(TYPE_INT64),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseFlag<int64_t>(flag_value, hook);
-      }),
-      default_for_display_(ToString(default_value)),
-      usage_text_(usage_text),
-      flag_type_(flag_type) {}
+CONSTRUCTOR_IMPLEMENTATION(int32_t, int32_t, TYPE_INT32)
+CONSTRUCTOR_IMPLEMENTATION(int64_t, int64_t, TYPE_INT64)
+CONSTRUCTOR_IMPLEMENTATION(float, float, TYPE_FLOAT)
+CONSTRUCTOR_IMPLEMENTATION(bool, bool, TYPE_BOOL)
+CONSTRUCTOR_IMPLEMENTATION(std::string, const std::string&, TYPE_STRING)
 
-Flag::Flag(const char* name, const std::function<void(const float&)>& hook,
-           float default_value, const std::string& usage_text,
-           FlagType flag_type)
-    : name_(name),
-      type_(TYPE_FLOAT),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseFlag<float>(flag_value, hook);
-      }),
-      default_for_display_(ToString(default_value)),
-      usage_text_(usage_text),
-      flag_type_(flag_type) {}
+#undef CONSTRUCTOR_IMPLEMENTATION
 
-Flag::Flag(const char* name, const std::function<void(const bool&)>& hook,
-           bool default_value, const std::string& usage_text,
-           FlagType flag_type)
-    : name_(name),
-      type_(TYPE_BOOL),
-      value_hook_([hook](const std::string& flag_value) {
-        return ParseBoolFlag(flag_value, hook);
-      }),
-      default_for_display_(default_value ? "true" : "false"),
-      usage_text_(usage_text),
-      flag_type_(flag_type) {}
-
-Flag::Flag(const char* name,
-           const std::function<void(const std::string&)>& hook,
-           const std::string& default_value, const std::string& usage_text,
-           FlagType flag_type)
-    : name_(name),
-      type_(TYPE_STRING),
-      value_hook_([hook](const std::string& flag_value) {
-        hook(flag_value);
-        return true;
-      }),
-      default_for_display_(default_value),
-      usage_text_(usage_text),
-      flag_type_(flag_type) {}
-
-bool Flag::Parse(const std::string& arg, bool* value_parsing_ok) const {
-  return ParseFlag(arg, name_, flag_type_ == kPositional, value_hook_,
-                   value_parsing_ok);
+bool Flag::Parse(const std::string& arg, int argv_position,
+                 bool* value_parsing_ok) const {
+  return ParseFlag(
+      arg, argv_position, name_, flag_type_ == kPositional,
+      [&](const std::string& read_value, int argv_position) {
+        return value_hook_(read_value, argv_position);
+      },
+      value_parsing_ok);
 }
 
 std::string Flag::GetTypeName() const {
@@ -191,7 +178,7 @@
 #endif
       if (it->second != -1) {
         bool value_parsing_ok;
-        flag.Parse(argv[it->second], &value_parsing_ok);
+        flag.Parse(argv[it->second], it->second, &value_parsing_ok);
         if (!value_parsing_ok) {
           TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
                             << "' against argv '" << argv[it->second] << "'";
@@ -214,7 +201,7 @@
         return false;
       }
       bool value_parsing_ok;
-      flag.Parse(argv[positional_count], &value_parsing_ok);
+      flag.Parse(argv[positional_count], positional_count, &value_parsing_ok);
       if (!value_parsing_ok) {
         TFLITE_LOG(ERROR) << "Failed to parse positional flag: " << flag.name_;
         return false;
@@ -229,7 +216,7 @@
     for (int i = positional_count + 1; i < *argc; ++i) {
       if (!unknown_argvs[i]) continue;
       bool value_parsing_ok;
-      was_found = flag.Parse(argv[i], &value_parsing_ok);
+      was_found = flag.Parse(argv[i], i, &value_parsing_ok);
       if (!value_parsing_ok) {
         TFLITE_LOG(ERROR) << "Failed to parse flag '" << flag.name_
                           << "' against argv '" << argv[i] << "'";
diff --git a/tensorflow/lite/tools/command_line_flags.h b/tensorflow/lite/tools/command_line_flags.h
index 4cc09f4..6c90ed0 100644
--- a/tensorflow/lite/tools/command_line_flags.h
+++ b/tensorflow/lite/tools/command_line_flags.h
@@ -79,26 +79,49 @@
         name, [val](const T& v) { *val = v; }, *val, usage, flag_type);
   }
 
-  Flag(const char* name, const std::function<void(const int32_t&)>& hook,
-       int32_t default_value, const std::string& usage_text,
+// "flag_T" is same as "default_value_T" for trivial types, like int32, bool
+// etc. But when it's a complex type, "default_value_T" is generally a const
+// reference "flag_T".
+#define CONSTRUCTOR_WITH_ARGV_INDEX(flag_T, default_value_T)         \
+  Flag(const char* name,                                             \
+       const std::function<void(const flag_T& /*flag_val*/,          \
+                                int /*argv_position*/)>& hook,       \
+       default_value_T default_value, const std::string& usage_text, \
        FlagType flag_type);
-  Flag(const char* name, const std::function<void(const int64_t&)>& hook,
-       int64_t default_value, const std::string& usage_text,
-       FlagType flag_type);
-  Flag(const char* name, const std::function<void(const float&)>& hook,
-       float default_value, const std::string& usage_text, FlagType flag_type);
-  Flag(const char* name, const std::function<void(const bool&)>& hook,
-       bool default_value, const std::string& usage_text, FlagType flag_type);
-  Flag(const char* name, const std::function<void(const std::string&)>& hook,
-       const std::string& default_value, const std::string& usage_text,
-       FlagType flag_type);
+
+#define CONSTRUCTOR_WITHOUT_ARGV_INDEX(flag_T, default_value_T)            \
+  Flag(const char* name, const std::function<void(const flag_T&)>& hook,   \
+       default_value_T default_value, const std::string& usage_text,       \
+       FlagType flag_type)                                                 \
+      : Flag(                                                              \
+            name, [hook](const flag_T& flag_val, int) { hook(flag_val); }, \
+            default_value, usage_text, flag_type) {}
+
+  CONSTRUCTOR_WITH_ARGV_INDEX(int32_t, int32_t)
+  CONSTRUCTOR_WITHOUT_ARGV_INDEX(int32_t, int32_t)
+
+  CONSTRUCTOR_WITH_ARGV_INDEX(int64_t, int64_t)
+  CONSTRUCTOR_WITHOUT_ARGV_INDEX(int64_t, int64_t)
+
+  CONSTRUCTOR_WITH_ARGV_INDEX(float, float)
+  CONSTRUCTOR_WITHOUT_ARGV_INDEX(float, float)
+
+  CONSTRUCTOR_WITH_ARGV_INDEX(bool, bool)
+  CONSTRUCTOR_WITHOUT_ARGV_INDEX(bool, bool)
+
+  CONSTRUCTOR_WITH_ARGV_INDEX(std::string, const std::string&)
+  CONSTRUCTOR_WITHOUT_ARGV_INDEX(std::string, const std::string&)
+
+#undef CONSTRUCTOR_WITH_ARGV_INDEX
+#undef CONSTRUCTOR_WITHOUT_ARGV_INDEX
 
   FlagType GetFlagType() const { return flag_type_; }
 
  private:
   friend class Flags;
 
-  bool Parse(const std::string& arg, bool* value_parsing_ok) const;
+  bool Parse(const std::string& arg, int argv_position,
+             bool* value_parsing_ok) const;
 
   std::string name_;
   enum {
@@ -111,7 +134,8 @@
 
   std::string GetTypeName() const;
 
-  std::function<bool(const std::string&)> value_hook_;
+  std::function<bool(const std::string& /*read_value*/, int /*argv_position*/)>
+      value_hook_;
   std::string default_for_display_;
 
   std::string usage_text_;
diff --git a/tensorflow/lite/tools/command_line_flags_test.cc b/tensorflow/lite/tools/command_line_flags_test.cc
index afd1264..5248245 100644
--- a/tensorflow/lite/tools/command_line_flags_test.cc
+++ b/tensorflow/lite/tools/command_line_flags_test.cc
@@ -17,6 +17,7 @@
 
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>
+#include "tensorflow/lite/tools/tool_params.h"
 
 namespace tflite {
 namespace {
@@ -364,5 +365,55 @@
   EXPECT_EQ("--some_int=1 --some_int=2", args);
 }
 
+TEST(CommandLineFlagsTest, ArgvPositions) {
+  tools::ToolParams params;
+  params.AddParam("some_int", tools::ToolParam::Create<int>(13));
+  params.AddParam("some_float", tools::ToolParam::Create<float>(17.0f));
+  params.AddParam("some_bool", tools::ToolParam::Create<bool>(true));
+
+  const char* argv_strings[] = {"program_name", "--some_float=42.0",
+                                "--some_bool=false", "--some_int=5"};
+  int argc = 4;
+  tools::ToolParams* const params_ptr = &params;
+  bool parsed_ok = Flags::Parse(
+      &argc, reinterpret_cast<const char**>(argv_strings),
+      {
+          Flag(
+              "some_int",
+              // NOLINT because of needing templating both trivial and complex
+              // types for a Flag.
+              [params_ptr](const int& val, int argv_position) {  // NOLINT
+                params_ptr->Set<int>("some_int", val, argv_position);
+              },
+              13, "some int", Flag::kOptional),
+          Flag(
+              "some_float",
+              [params_ptr](const float& val, int argv_position) {  // NOLINT
+                params_ptr->Set<float>("some_float", val, argv_position);
+              },
+              17.0f, "some float", Flag::kOptional),
+          Flag(
+              "some_bool",
+              [params_ptr](const bool& val, int argv_position) {  // NOLINT
+                params_ptr->Set<bool>("some_bool", val, argv_position);
+              },
+              true, "some bool", Flag::kOptional),
+      });
+
+  EXPECT_TRUE(parsed_ok);
+  EXPECT_EQ(5, params.Get<int>("some_int"));
+  EXPECT_NEAR(42.0f, params.Get<float>("some_float"), 1e-5f);
+  EXPECT_FALSE(params.Get<bool>("some_bool"));
+
+  // The position of a parameter depends on the ordering of the associated flag
+  // specfied in the argv (i.e. 'argv_strings' above), not as the ordering of
+  // the flag in the flag list that's passed to Flags::Parse above.
+  EXPECT_EQ(3, params.GetPosition<int>("some_int"));
+  EXPECT_EQ(1, params.GetPosition<float>("some_float"));
+  EXPECT_EQ(2, params.GetPosition<bool>("some_bool"));
+
+  EXPECT_EQ(argc, 1);
+}
+
 }  // namespace
 }  // namespace tflite
diff --git a/tensorflow/lite/tools/delegates/delegate_provider.h b/tensorflow/lite/tools/delegates/delegate_provider.h
index ef721f4..e891e42 100644
--- a/tensorflow/lite/tools/delegates/delegate_provider.h
+++ b/tensorflow/lite/tools/delegates/delegate_provider.h
@@ -58,7 +58,10 @@
   Flag CreateFlag(const char* name, ToolParams* params,
                   const std::string& usage) const {
     return Flag(
-        name, [params, name](const T& val) { params->Set<T>(name, val); },
+        name,
+        [params, name](const T& val, int argv_position) {
+          params->Set<T>(name, val, argv_position);
+        },
         default_params_.Get<T>(name), usage, Flag::kOptional);
   }
   ToolParams default_params_;
diff --git a/tensorflow/lite/tools/tool_params.h b/tensorflow/lite/tools/tool_params.h
index b3e6975..01202ac 100644
--- a/tensorflow/lite/tools/tool_params.h
+++ b/tensorflow/lite/tools/tool_params.h
@@ -35,8 +35,11 @@
 
  public:
   template <typename T>
-  static std::unique_ptr<ToolParam> Create(const T& default_value) {
-    return std::unique_ptr<ToolParam>(new TypedToolParam<T>(default_value));
+  static std::unique_ptr<ToolParam> Create(const T& default_value,
+                                           int position = 0) {
+    auto* param = new TypedToolParam<T>(default_value);
+    param->SetPosition(position);
+    return std::unique_ptr<ToolParam>(param);
   }
 
   template <typename T>
@@ -52,10 +55,14 @@
   }
 
   virtual ~ToolParam() {}
-  explicit ToolParam(ParamType type) : has_value_set_(false), type_(type) {}
+  explicit ToolParam(ParamType type)
+      : has_value_set_(false), position_(0), type_(type) {}
 
   bool HasValueSet() const { return has_value_set_; }
 
+  int GetPosition() const { return position_; }
+  void SetPosition(int position) { position_ = position; }
+
   virtual void Set(const ToolParam&) {}
 
   virtual std::unique_ptr<ToolParam> Clone() const = 0;
@@ -63,6 +70,13 @@
  protected:
   bool has_value_set_;
 
+  // Represents the relative ordering among a set of params.
+  // Note: in our code, a ToolParam is generally used together with a
+  // tflite::Flag so that its value could be set when parsing commandline flags.
+  // In this case, the `position_` is simply the index of the particular flag
+  // into the list of commandline flags (i.e. named 'argv' in general).
+  int position_;
+
  private:
   static void AssertHasSameType(ParamType a, ParamType b);
 
@@ -84,10 +98,11 @@
 
   void Set(const ToolParam& other) override {
     Set(other.AsConstTyped<T>()->Get());
+    SetPosition(other.AsConstTyped<T>()->GetPosition());
   }
 
   std::unique_ptr<ToolParam> Clone() const override {
-    return std::unique_ptr<ToolParam>(new TypedToolParam<T>(value_));
+    return ToolParam::Create<T>(value_, position_);
   }
 
  private:
@@ -97,6 +112,7 @@
 // A map-like container for holding values of different types.
 class ToolParams {
  public:
+  // Add a ToolParam instance `value` w/ `name` to this container.
   void AddParam(const std::string& name, std::unique_ptr<ToolParam> value) {
     params_[name] = std::move(value);
   }
@@ -114,9 +130,10 @@
   }
 
   template <typename T>
-  void Set(const std::string& name, const T& value) {
+  void Set(const std::string& name, const T& value, int position = 0) {
     AssertParamExists(name);
     params_.at(name)->AsTyped<T>()->Set(value);
+    params_.at(name)->AsTyped<T>()->SetPosition(position);
   }
 
   template <typename T>
@@ -126,6 +143,12 @@
   }
 
   template <typename T>
+  int GetPosition(const std::string& name) const {
+    AssertParamExists(name);
+    return params_.at(name)->AsConstTyped<T>()->GetPosition();
+  }
+
+  template <typename T>
   T Get(const std::string& name) const {
     AssertParamExists(name);
     return params_.at(name)->AsConstTyped<T>()->Get();
diff --git a/tensorflow/lite/tools/tool_params_test.cc b/tensorflow/lite/tools/tool_params_test.cc
index 8c12ea6..248db53 100644
--- a/tensorflow/lite/tools/tool_params_test.cc
+++ b/tensorflow/lite/tools/tool_params_test.cc
@@ -24,18 +24,20 @@
 
 TEST(ToolParams, SetTest) {
   ToolParams params;
-  params.AddParam("some-int1", ToolParam::Create<int>(13));
-  params.AddParam("some-int2", ToolParam::Create<int>(17));
+  params.AddParam("some-int1", ToolParam::Create<int>(13 /*, position=0*/));
+  params.AddParam("some-int2", ToolParam::Create<int>(17 /*, position=0*/));
 
   ToolParams others;
-  others.AddParam("some-int1", ToolParam::Create<int>(19));
-  others.AddParam("some-bool", ToolParam::Create<bool>(true));
+  others.AddParam("some-int1", ToolParam::Create<int>(19, 5));
+  others.AddParam("some-bool", ToolParam::Create<bool>(true, 1));
 
   params.Set(others);
   EXPECT_EQ(19, params.Get<int>("some-int1"));
+  EXPECT_EQ(5, params.GetPosition<int>("some-int1"));
   EXPECT_TRUE(params.HasValueSet<int>("some-int1"));
 
   EXPECT_EQ(17, params.Get<int>("some-int2"));
+  EXPECT_EQ(0, params.GetPosition<int>("some-int2"));
   EXPECT_FALSE(params.HasValueSet<int>("some-int2"));
 
   EXPECT_FALSE(params.HasParam("some-bool"));
@@ -43,30 +45,32 @@
 
 TEST(ToolParams, MergeTestOverwriteTrue) {
   ToolParams params;
-  params.AddParam("some-int1", ToolParam::Create<int>(13));
-  params.AddParam("some-int2", ToolParam::Create<int>(17));
+  params.AddParam("some-int1", ToolParam::Create<int>(13 /*, position=0*/));
+  params.AddParam("some-int2", ToolParam::Create<int>(17 /*, position=0*/));
 
   ToolParams others;
-  others.AddParam("some-int1", ToolParam::Create<int>(19));
-  others.AddParam("some-bool", ToolParam::Create<bool>(true));
+  others.AddParam("some-int1", ToolParam::Create<int>(19, 5));
+  others.AddParam("some-bool", ToolParam::Create<bool>(true /*, position=0*/));
 
   params.Merge(others, true /* overwrite */);
   EXPECT_EQ(19, params.Get<int>("some-int1"));
+  EXPECT_EQ(5, params.GetPosition<int>("some-int1"));
   EXPECT_EQ(17, params.Get<int>("some-int2"));
   EXPECT_TRUE(params.Get<bool>("some-bool"));
 }
 
 TEST(ToolParams, MergeTestOverwriteFalse) {
   ToolParams params;
-  params.AddParam("some-int1", ToolParam::Create<int>(13));
-  params.AddParam("some-int2", ToolParam::Create<int>(17));
+  params.AddParam("some-int1", ToolParam::Create<int>(13 /*, position=0*/));
+  params.AddParam("some-int2", ToolParam::Create<int>(17 /*, position=0*/));
 
   ToolParams others;
-  others.AddParam("some-int1", ToolParam::Create<int>(19));
-  others.AddParam("some-bool", ToolParam::Create<bool>(true));
+  others.AddParam("some-int1", ToolParam::Create<int>(19, 5));
+  others.AddParam("some-bool", ToolParam::Create<bool>(true /*, position=0*/));
 
   params.Merge(others);  // default overwrite is false
   EXPECT_EQ(13, params.Get<int>("some-int1"));
+  EXPECT_EQ(0, params.GetPosition<int>("some-int1"));
   EXPECT_EQ(17, params.Get<int>("some-int2"));
   EXPECT_TRUE(params.Get<bool>("some-bool"));
 }