internal tests cleanup
PiperOrigin-RevId: 340568129
Change-Id: I521739fa73d00d4ef75a61556e54dc0c0805c2a9
diff --git a/tensorflow/core/platform/vmodule_benchmark_test.cc b/tensorflow/core/platform/vmodule_benchmark_test.cc
index 0f9e75b..f164ece 100644
--- a/tensorflow/core/platform/vmodule_benchmark_test.cc
+++ b/tensorflow/core/platform/vmodule_benchmark_test.cc
@@ -18,8 +18,8 @@
namespace tensorflow {
-static void BM_DisabledVlog(int iters) {
- for (int i = 0; i < iters; ++i) {
+static void BM_DisabledVlog(::testing::benchmark::State& state) {
+ for (auto s : state) {
VLOG(1) << "Testing VLOG(1)!";
}
}
diff --git a/tensorflow/core/util/bcast_test.cc b/tensorflow/core/util/bcast_test.cc
index b6e8bcd..63133fc 100644
--- a/tensorflow/core/util/bcast_test.cc
+++ b/tensorflow/core/util/bcast_test.cc
@@ -673,15 +673,17 @@
BCastBatchIndices({3, 1}, {2, 1, 2}));
}
-static void BM_BCastSetup(int iters, int same_shape) {
+void BM_BCastSetup(::testing::benchmark::State& state) {
+ const int same_shape = state.range(0);
+
if (same_shape) {
- testing::SetLabel("same_shapes");
- while (--iters > 0) {
+ state.SetLabel("same_shapes");
+ for (auto s : state) {
class BCast b({1000, 100}, {1000, 100});
}
} else {
- testing::SetLabel("different_shapes");
- while (--iters > 0) {
+ state.SetLabel("different_shapes");
+ for (auto s : state) {
class BCast b({3, 1, 5}, {2, 0, 3, 0, 5});
}
}
diff --git a/tensorflow/core/util/device_name_utils_test.cc b/tensorflow/core/util/device_name_utils_test.cc
index 065fcfb..7824d9f 100644
--- a/tensorflow/core/util/device_name_utils_test.cc
+++ b/tensorflow/core/util/device_name_utils_test.cc
@@ -572,9 +572,9 @@
}
}
-static void BM_ParseFullName(int iters) {
+static void BM_ParseFullName(::testing::benchmark::State& state) {
DeviceNameUtils::ParsedName p;
- while (iters--) {
+ for (auto s : state) {
DeviceNameUtils::ParseFullName("/job:worker/replica:3/task:0/cpu:0", &p);
}
}
diff --git a/tensorflow/core/util/presized_cuckoo_map_test.cc b/tensorflow/core/util/presized_cuckoo_map_test.cc
index f2c7904..36a7642 100644
--- a/tensorflow/core/util/presized_cuckoo_map_test.cc
+++ b/tensorflow/core/util/presized_cuckoo_map_test.cc
@@ -164,13 +164,13 @@
}
}
-static void BM_CuckooFill(int iters, int arg) {
+void BM_CuckooFill(::testing::benchmark::State &state) {
+ const int arg = state.range(0);
+
uint64 table_size = arg;
- testing::StopTiming();
std::vector<uint64> calculated_keys;
CalculateKeys(table_size, &calculated_keys);
- testing::StartTiming();
- for (int iter = 0; iter < iters; iter++) {
+ for (auto s : state) {
PresizedCuckooMap<int> pscm(table_size);
for (uint64 i = 0; i < table_size; i++) {
pscm.InsertUnique(calculated_keys[i], i);
@@ -180,25 +180,27 @@
BENCHMARK(BM_CuckooFill)->Arg(1000)->Arg(10000000);
-static void BM_CuckooRead(int iters, int arg) {
+void BM_CuckooRead(::testing::benchmark::State &state) {
+ const int arg = state.range(0);
+
uint64 table_size = arg;
- testing::StopTiming();
std::vector<uint64> calculated_keys;
CalculateKeys(table_size, &calculated_keys);
PresizedCuckooMap<int> pscm(table_size);
for (uint64 i = 0; i < table_size; i++) {
pscm.InsertUnique(calculated_keys[i], i);
}
- testing::StartTiming();
- uint64_t defeat_optimization = 0;
- for (int i = 0; i < iters; i++) {
- uint64 key_index = i % table_size; // May slow down bench!
+
+ int i = 0;
+ for (auto s : state) {
+ // Avoid using '%', which is expensive.
+ uint64 key_index = i;
+ ++i;
+ if (i == table_size) i = 0;
+
int out = 0;
pscm.Find(calculated_keys[key_index], &out);
- defeat_optimization += out;
- }
- if (defeat_optimization == 0) {
- printf("Preventing the compiler from eliding the inner loop\n");
+ tensorflow::testing::DoNotOptimize(out);
}
}
diff --git a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
index a2ac7c3..dea55f3 100644
--- a/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
+++ b/tensorflow/core/util/tensor_bundle/tensor_bundle_test.cc
@@ -1109,9 +1109,8 @@
}
}
-static void BM_BundleAlignmentByteOff(int iters, int alignment,
- int tensor_size) {
- testing::StopTiming();
+static void BM_BundleAlignmentByteOff(::testing::benchmark::State& state,
+ int alignment, int tensor_size) {
{
BundleWriter::Options opts;
opts.data_alignment = alignment;
@@ -1122,18 +1121,17 @@
}
BundleReader reader(Env::Default(), Prefix("foo"));
TF_CHECK_OK(reader.status());
- testing::StartTiming();
- for (int i = 0; i < iters; ++i) {
+ for (auto s : state) {
Tensor t;
TF_CHECK_OK(reader.Lookup("big", &t));
}
- testing::StopTiming();
}
-#define BM_BundleAlignment(ALIGN, SIZE) \
- static void BM_BundleAlignment_##ALIGN##_##SIZE(int iters) { \
- BM_BundleAlignmentByteOff(iters, ALIGN, SIZE); \
- } \
+#define BM_BundleAlignment(ALIGN, SIZE) \
+ static void BM_BundleAlignment_##ALIGN##_##SIZE( \
+ ::testing::benchmark::State& state) { \
+ BM_BundleAlignmentByteOff(state, ALIGN, SIZE); \
+ } \
BENCHMARK(BM_BundleAlignment_##ALIGN##_##SIZE)
BM_BundleAlignment(1, 512);
diff --git a/tensorflow/core/util/work_sharder_test.cc b/tensorflow/core/util/work_sharder_test.cc
index bc5a1d2..f69572d 100644
--- a/tensorflow/core/util/work_sharder_test.cc
+++ b/tensorflow/core/util/work_sharder_test.cc
@@ -89,12 +89,14 @@
}
}
-void BM_Sharding(int iters, int arg) {
+void BM_Sharding(::testing::benchmark::State& state) {
+ const int arg = state.range(0);
+
thread::ThreadPool threads(Env::Default(), "test", 16);
const int64 total = 1LL << 30;
auto lambda = [](int64 start, int64 limit) {};
auto work = std::cref(lambda);
- for (; iters > 0; iters -= arg) {
+ for (auto s : state) {
Shard(arg - 1, &threads, total, 1, work);
}
}
diff --git a/tensorflow/stream_executor/lib/statusor_test.cc b/tensorflow/stream_executor/lib/statusor_test.cc
index 46bdb9d..6b59eaa 100644
--- a/tensorflow/stream_executor/lib/statusor_test.cc
+++ b/tensorflow/stream_executor/lib/statusor_test.cc
@@ -535,12 +535,10 @@
// Calibrate the amount of time spent just calling DoWork, since each of our
// tests will do this, we can subtract this out of benchmark results.
-void BM_CalibrateWorkLoop(int iters) {
- tensorflow::testing::StopTiming();
+void BM_CalibrateWorkLoop(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
BenchmarkType* result = factory.TrivialFactory();
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
if (result != nullptr) {
result->DoWork();
}
@@ -550,11 +548,9 @@
// Measure the time taken to call into the factory, return the value,
// determine that it is OK, and invoke a trivial function.
-void BM_TrivialFactory(int iters) {
- tensorflow::testing::StopTiming();
+void BM_TrivialFactory(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
BenchmarkType* result = factory.TrivialFactory();
if (result != nullptr) {
result->DoWork();
@@ -566,11 +562,9 @@
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
-void BM_ArgumentFactory(int iters) {
- tensorflow::testing::StopTiming();
+void BM_ArgumentFactory(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactory(&result);
if (status.ok() && result != nullptr) {
@@ -582,11 +576,9 @@
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
-void BM_StatusOrFactory(int iters) {
- tensorflow::testing::StopTiming();
+void BM_StatusOrFactory(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactory();
if (result.ok()) {
result.ValueOrDie()->DoWork();
@@ -598,11 +590,9 @@
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
-void BM_ArgumentFactoryFail(int iters) {
- tensorflow::testing::StopTiming();
+void BM_ArgumentFactoryFail(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFail(&result);
if (status.ok() && result != nullptr) {
@@ -614,11 +604,9 @@
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
-void BM_StatusOrFactoryFail(int iters) {
- tensorflow::testing::StopTiming();
+void BM_StatusOrFactoryFail(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFail();
if (result.ok()) {
result.ValueOrDie()->DoWork();
@@ -630,11 +618,9 @@
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
-void BM_ArgumentFactoryFailShortMsg(int iters) {
- tensorflow::testing::StopTiming();
+void BM_ArgumentFactoryFailShortMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFailShortMsg(&result);
if (status.ok() && result != nullptr) {
@@ -646,11 +632,9 @@
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
-void BM_StatusOrFactoryFailShortMsg(int iters) {
- tensorflow::testing::StopTiming();
+void BM_StatusOrFactoryFailShortMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailShortMsg();
if (result.ok()) {
result.ValueOrDie()->DoWork();
@@ -662,11 +646,9 @@
// Measure the time taken to call into the factory, providing an
// out-param for the result, evaluating the status result and the
// result pointer, and invoking the trivial function.
-void BM_ArgumentFactoryFailLongMsg(int iters) {
- tensorflow::testing::StopTiming();
+void BM_ArgumentFactoryFailLongMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
BenchmarkType* result = nullptr;
Status status = factory.ArgumentFactoryFailLongMsg(&result);
if (status.ok() && result != nullptr) {
@@ -678,11 +660,9 @@
// Measure the time to use the StatusOr<T*> factory, evaluate the result,
// and invoke the trivial function.
-void BM_StatusOrFactoryFailLongMsg(int iters) {
- tensorflow::testing::StopTiming();
+void BM_StatusOrFactoryFailLongMsg(::testing::benchmark::State& state) {
BenchmarkFactory<BenchmarkType> factory;
- tensorflow::testing::StartTiming();
- for (int i = 0; i != iters; ++i) {
+ for (auto s : state) {
StatusOr<BenchmarkType*> result = factory.StatusOrFactoryFailLongMsg();
if (result.ok()) {
result.ValueOrDie()->DoWork();