Internal tests cleanup: migrate benchmarks to the ::testing::benchmark::State API

PiperOrigin-RevId: 340568129
Change-Id: I521739fa73d00d4ef75a61556e54dc0c0805c2a9
diff --git a/tensorflow/core/platform/vmodule_benchmark_test.cc b/tensorflow/core/platform/vmodule_benchmark_test.cc
index 0f9e75b..f164ece 100644
--- a/tensorflow/core/platform/vmodule_benchmark_test.cc
+++ b/tensorflow/core/platform/vmodule_benchmark_test.cc
@@ -18,8 +18,8 @@
 
 namespace tensorflow {
 
-static void BM_DisabledVlog(int iters) {
-  for (int i = 0; i < iters; ++i) {
+static void BM_DisabledVlog(::testing::benchmark::State& state) {
+  for (auto s : state) {
     VLOG(1) << "Testing VLOG(1)!";
   }
 }
diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc
index b6e8bcd..63133fc 100644
--- a/tensorflow/core/util/bcast_test.cc
+++ b/tensorflow/core/util/bcast_test.cc
@@ -673,15 +673,17 @@
             BCastBatchIndices({3, 1}, {2, 1, 2}));
 }
 
-static void BM_BCastSetup(int iters, int same_shape) {
+void BM_BCastSetup(::testing::benchmark::State& state) {
+  const int same_shape = state.range(0);
+
   if (same_shape) {
-    testing::SetLabel("same_shapes");
-    while (--iters > 0) {
+    state.SetLabel("same_shapes");
+    for (auto s : state) {
       class BCast b({1000, 100}, {1000, 100});
     }
   } else {
-    testing::SetLabel("different_shapes");
-    while (--iters > 0) {
+    state.SetLabel("different_shapes");
+    for (auto s : state) {
       class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
     }
   }
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
index 065fcfb..7824d9f 100644
--- a/tensorflow/core/util/device_name_utils_test.cc
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -572,9 +572,9 @@
   }
 }
 
-static void BM_ParseFullName(int iters) {
+static void BM_ParseFullName(::testing::benchmark::State& state) {
   DeviceNameUtils::ParsedName p;
-  while (iters--) {
+  for (auto s : state) {
     DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
   }
 }
diff --git a/tensorflow/core/util/presized_cuckoo_map_test.cc b/tensorflow/core/util/presized_cuckoo_map_test.cc
index f2c7904..36a7642 100644
--- a/tensorflow/core/util/presized_cuckoo_map_test.cc
+++ b/tensorflow/core/util/presized_cuckoo_map_test.cc
@@ -164,13 +164,13 @@
   }
 }
 
-static void BM_CuckooFill(int iters, int arg) {
+void BM_CuckooFill(::testing::benchmark::State &state) {
+  const int arg = state.range(0);
+
   uint64 table_size = arg;
-  testing::StopTiming();
   std::vector<uint64> calculated_keys;
   CalculateKeys(table_size, &calculated_keys);
-  testing::StartTiming();
-  for (int iter = 0; iter < iters; iter++) {
+  for (auto s : state) {
     PresizedCuckooMap<int> pscm(table_size);
     for (uint64 i = 0; i < table_size; i++) {
       pscm.InsertUnique(calculated_keys[i], i);
@@ -180,25 +180,27 @@
 
 BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000);
 
-static void BM_CuckooRead(int iters, int arg) {
+void BM_CuckooRead(::testing::benchmark::State &state) {
+  const int arg = state.range(0);
+
   uint64 table_size = arg;
-  testing::StopTiming();
   std::vector<uint64> calculated_keys;
   CalculateKeys(table_size, &calculated_keys);
   PresizedCuckooMap<int> pscm(table_size);
   for (uint64 i = 0; i < table_size; i++) {
     pscm.InsertUnique(calculated_keys[i], i);
   }
-  testing::StartTiming();
-  uint64_t defeat_optimization = 0;
-  for (int i = 0; i < iters; i++) {
-    uint64 key_index = i % table_size;  // May slow down bench!
+
+  int i = 0;
+  for (auto s : state) {
+    // Avoid using '%', which is expensive.
+    uint64 key_index = i;
+    ++i;
+    if (i == table_size) i = 0;
+
     int out = 0;
     pscm.Find(calculated_keys[key_index], &out);
-    defeat_optimization += out;
-  }
-  if (defeat_optimization == 0) {
-    printf("Preventing the compiler from eliding the inner loop\n");
+    tensorflow::testing::DoNotOptimize(out);
   }
 }
 
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index a2ac7c3..dea55f3 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -1109,9 +1109,8 @@
   }
 }
 
-static void BM_BundleAlignmentByteOff(int iters, int alignment,
-                                      int tensor_size) {
-  testing::StopTiming();
+static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
+                                      int alignment, int tensor_size) {
   {
     BundleWriter::Options opts;
     opts.data_alignment = alignment;
@@ -1122,18 +1121,17 @@
   }
   BundleReader reader(Env::Default(), Prefix("foo"));
   TF_CHECK_OK(reader.status());
-  testing::StartTiming();
-  for (int i = 0; i < iters; ++i) {
+  for (auto s : state) {
     Tensor t;
     TF_CHECK_OK(reader.Lookup("big", &t));
   }
-  testing::StopTiming();
 }
 
-#define BM_BundleAlignment(ALIGN, SIZE)                        \
-  static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
-    BM_BundleAlignmentByteOff(iters, ALIGN, SIZE);             \
-  }                                                            \
+#define BM_BundleAlignment(ALIGN, SIZE)            \
+  static void BM_BundleAlignment_##ALIGN##_##SIZE( \
+      ::testing::benchmark::State& state) {        \
+    BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
+  }                                                \
   BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
 
 BM_BundleAlignment(1, 512);
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
index bc5a1d2..f69572d 100644
--- a/tensorflow/core/util/work_sharder_test.cc
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -89,12 +89,14 @@
   }
 }
 
-void BM_Sharding(int iters, int arg) {
+void BM_Sharding(::testing::benchmark::State& state) {
+  const int arg = state.range(0);
+
   thread::ThreadPool threads(Env::Default(), "test", 16);
   const int64 total = 1LL << 30;
   auto lambda = [](int64 start, int64 limit) {};
   auto work = std::cref(lambda);
-  for (; iters > 0; iters -= arg) {
+  for (auto s : state) {
     Shard(arg - 1, &threads, total, 1, work);
   }
 }
diff --git a/tensorflow/stream_executor/lib/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc
index 46bdb9d..6b59eaa 100644
--- a/tensorflow/stream_executor/lib/statusor_test.cc
+++ b/tensorflow/stream_executor/lib/statusor_test.cc
@@ -535,12 +535,10 @@
 
 // Calibrate the amount of time spent just calling DoWork, since each of our
 // tests will do this, we can subtract this out of benchmark results.
-void BM_CalibrateWorkLoop(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_CalibrateWorkLoop(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
   BenchmarkType* result = factory.TrivialFactory();
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     if (result != nullptr) {
       result->DoWork();
     }
@@ -550,11 +548,9 @@
 
 // Measure the time taken to call into the factory, return the value,
 // determine that it is OK, and invoke a trivial function.
-void BM_TrivialFactory(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_TrivialFactory(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     BenchmarkType* result = factory.TrivialFactory();
     if (result != nullptr) {
       result->DoWork();
@@ -566,11 +562,9 @@
 // Measure the time taken to call into the factory, providing an
 // out-param for the result, evaluating the status result and the
 // result pointer, and invoking the trivial function.
-void BM_ArgumentFactory(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_ArgumentFactory(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     BenchmarkType* result = nullptr;
     Status status = factory.ArgumentFactory(&result);
     if (status.ok() && result != nullptr) {
@@ -582,11 +576,9 @@
 
 // Measure the time to use the StatusOr<T*> factory, evaluate the result,
 // and invoke the trivial function.
-void BM_StatusOrFactory(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_StatusOrFactory(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
     if (result.ok()) {
       result.ValueOrDie()->DoWork();
@@ -598,11 +590,9 @@
 // Measure the time taken to call into the factory, providing an
 // out-param for the result, evaluating the status result and the
 // result pointer, and invoking the trivial function.
-void BM_ArgumentFactoryFail(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_ArgumentFactoryFail(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     BenchmarkType* result = nullptr;
     Status status = factory.ArgumentFactoryFail(&result);
     if (status.ok() && result != nullptr) {
@@ -614,11 +604,9 @@
 
 // Measure the time to use the StatusOr<T*> factory, evaluate the result,
 // and invoke the trivial function.
-void BM_StatusOrFactoryFail(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_StatusOrFactoryFail(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
     if (result.ok()) {
       result.ValueOrDie()->DoWork();
@@ -630,11 +618,9 @@
 // Measure the time taken to call into the factory, providing an
 // out-param for the result, evaluating the status result and the
 // result pointer, and invoking the trivial function.
-void BM_ArgumentFactoryFailShortMsg(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     BenchmarkType* result = nullptr;
     Status status = factory.ArgumentFactoryFailShortMsg(&result);
     if (status.ok() && result != nullptr) {
@@ -646,11 +632,9 @@
 
 // Measure the time to use the StatusOr<T*> factory, evaluate the result,
 // and invoke the trivial function.
-void BM_StatusOrFactoryFailShortMsg(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
     if (result.ok()) {
       result.ValueOrDie()->DoWork();
@@ -662,11 +646,9 @@
 // Measure the time taken to call into the factory, providing an
 // out-param for the result, evaluating the status result and the
 // result pointer, and invoking the trivial function.
-void BM_ArgumentFactoryFailLongMsg(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     BenchmarkType* result = nullptr;
     Status status = factory.ArgumentFactoryFailLongMsg(&result);
     if (status.ok() && result != nullptr) {
@@ -678,11 +660,9 @@
 
 // Measure the time to use the StatusOr<T*> factory, evaluate the result,
 // and invoke the trivial function.
-void BM_StatusOrFactoryFailLongMsg(int iters) {
-  tensorflow::testing::StopTiming();
+void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) {
   BenchmarkFactory<BenchmarkType> factory;
-  tensorflow::testing::StartTiming();
-  for (int i = 0; i != iters; ++i) {
+  for (auto s : state) {
     StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
     if (result.ok()) {
       result.ValueOrDie()->DoWork();