Use the C++11 <random> library to generate better-quality random numbers.
PiperOrigin-RevId: 282586370
Change-Id: I7502bd2fbda2592adebe7abaefdf3c4367ba6e35
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
index 3448dde..82fb62c 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.cc
@@ -20,6 +20,7 @@
#include <cstdlib>
#include <iostream>
#include <memory>
+#include <random>
#include <string>
#include <unordered_set>
#include <vector>
@@ -290,11 +291,9 @@
return default_params;
}
-BenchmarkTfLiteModel::BenchmarkTfLiteModel()
- : BenchmarkTfLiteModel(DefaultParams()) {}
-
BenchmarkTfLiteModel::BenchmarkTfLiteModel(BenchmarkParams params)
- : BenchmarkModel(std::move(params)) {}
+ : BenchmarkModel(std::move(params)),
+ random_engine_(std::random_device()()) {}
void BenchmarkTfLiteModel::CleanUp() {
// Free up any pre-allocated tensor data during PrepareInputData.
@@ -453,22 +452,16 @@
}
InputTensorData t_data;
if (t->type == kTfLiteFloat32) {
- t_data = InputTensorData::Create<float>(num_elements, []() {
- return static_cast<float>(rand()) / RAND_MAX - 0.5f;
- });
+ t_data = CreateInputTensorData<float>(
+ num_elements, std::uniform_real_distribution<float>(-0.5f, 0.5f));
} else if (t->type == kTfLiteFloat16) {
// TODO(b/138843274): Remove this preprocessor guard when bug is fixed.
#if TFLITE_ENABLE_FP16_CPU_BENCHMARKS
#if __GNUC__ && \
(__clang__ || __ARM_FP16_FORMAT_IEEE || __ARM_FP16_FORMAT_ALTERNATIVE)
// __fp16 is available on Clang or when __ARM_FP16_FORMAT_* is defined.
- t_data = InputTensorData::Create<TfLiteFloat16>(
- num_elements, []() -> TfLiteFloat16 {
- __fp16 f16_value = static_cast<float>(rand()) / RAND_MAX - 0.5f;
- TfLiteFloat16 f16_placeholder_value;
- memcpy(&f16_placeholder_value, &f16_value, sizeof(TfLiteFloat16));
- return f16_placeholder_value;
- });
+ t_data = CreateInputTensorData<__fp16>(
+ num_elements, std::uniform_real_distribution<float>(-0.5f, 0.5f));
#else
TFLITE_LOG(FATAL) << "Don't know how to populate tensor " << t->name
<< " of type FLOAT16 on this platform.";
@@ -484,33 +477,28 @@
} else if (t->type == kTfLiteInt64) {
int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 99;
- t_data = InputTensorData::Create<int64_t>(num_elements, [=]() {
- return static_cast<int64_t>(rand() % (high - low + 1) + low);
- });
+ t_data = CreateInputTensorData<int64_t>(
+ num_elements, std::uniform_int_distribution<int64_t>(low, high));
} else if (t->type == kTfLiteInt32) {
int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 99;
- t_data = InputTensorData::Create<int32_t>(num_elements, [=]() {
- return static_cast<int32_t>(rand() % (high - low + 1) + low);
- });
+ t_data = CreateInputTensorData<int32_t>(
+ num_elements, std::uniform_int_distribution<int32_t>(low, high));
} else if (t->type == kTfLiteInt16) {
int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 99;
- t_data = InputTensorData::Create<int16_t>(num_elements, [=]() {
- return static_cast<int16_t>(rand() % (high - low + 1) + low);
- });
+ t_data = CreateInputTensorData<int16_t>(
+ num_elements, std::uniform_int_distribution<int16_t>(low, high));
} else if (t->type == kTfLiteUInt8) {
int low = has_value_range ? low_range : 0;
int high = has_value_range ? high_range : 254;
- t_data = InputTensorData::Create<uint8_t>(num_elements, [=]() {
- return static_cast<uint8_t>(rand() % (high - low + 1) + low);
- });
+      t_data = CreateInputTensorData<uint8_t>(
+          num_elements, std::uniform_int_distribution<int32_t>(low, high));
} else if (t->type == kTfLiteInt8) {
int low = has_value_range ? low_range : -127;
int high = has_value_range ? high_range : 127;
- t_data = InputTensorData::Create<int8_t>(num_elements, [=]() {
- return static_cast<int8_t>(rand() % (high - low + 1) + low);
- });
+      t_data = CreateInputTensorData<int8_t>(
+          num_elements, std::uniform_int_distribution<int32_t>(low, high));
} else if (t->type == kTfLiteString) {
// TODO(haoliang): No need to cache string tensors right now.
} else {
diff --git a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
index 491007f..ca7731e 100644
--- a/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
+++ b/tensorflow/lite/tools/benchmark/benchmark_tflite_model.h
@@ -19,6 +19,7 @@
#include <algorithm>
#include <map>
#include <memory>
+#include <random>
#include <string>
#include <vector>
@@ -47,8 +48,7 @@
int high;
};
- BenchmarkTfLiteModel();
- explicit BenchmarkTfLiteModel(BenchmarkParams params);
+ explicit BenchmarkTfLiteModel(BenchmarkParams params = DefaultParams());
~BenchmarkTfLiteModel() override;
std::vector<Flag> GetFlags() override;
@@ -80,30 +80,33 @@
struct InputTensorData {
InputTensorData() : data(nullptr, nullptr) {}
- template <typename T>
- static InputTensorData Create(int num_elements,
- const std::function<T()>& val_generator) {
- InputTensorData tmp;
- tmp.bytes = sizeof(T) * num_elements;
- T* raw = new T[num_elements];
- std::generate_n(raw, num_elements, val_generator);
- // Now initialize the type-erased unique_ptr (with custom deleter) from
- // 'raw'.
- tmp.data = std::unique_ptr<void, void (*)(void*)>(
- static_cast<void*>(raw),
- [](void* ptr) { delete[] static_cast<T*>(ptr); });
- return tmp;
- }
-
std::unique_ptr<void, void (*)(void*)> data;
size_t bytes;
};
+ template <typename T, typename Distribution>
+ inline InputTensorData CreateInputTensorData(int num_elements,
+ Distribution distribution) {
+ InputTensorData tmp;
+ tmp.bytes = sizeof(T) * num_elements;
+ T* raw = new T[num_elements];
+ std::generate_n(raw, num_elements,
+ [&]() { return distribution(random_engine_); });
+ // Now initialize the type-erased unique_ptr (with custom deleter) from
+ // 'raw'.
+ tmp.data = std::unique_ptr<void, void (*)(void*)>(
+ static_cast<void*>(raw),
+ [](void* ptr) { delete[] static_cast<T*>(ptr); });
+ return tmp;
+ }
+
std::vector<InputLayerInfo> inputs_;
std::vector<InputTensorData> inputs_data_;
std::unique_ptr<BenchmarkListener> profiling_listener_;
std::unique_ptr<BenchmarkListener> gemmlowp_profiling_listener_;
TfLiteDelegatePtrMap delegates_;
+
+ std::mt19937 random_engine_;
};
} // namespace benchmark