[tf.data service] thread_safe_buffer_test tests varying buffer and input sizes.
PiperOrigin-RevId: 376272959
Change-Id: I971666e9ac9185c0ae9db46f8bf76b7aae98d706
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 1a8e203..f5d6217 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -574,7 +574,7 @@
name = "thread_safe_buffer_test",
size = "small",
srcs = ["thread_safe_buffer_test.cc"],
- shard_count = 2,
+ shard_count = 3,
deps = [
":thread_safe_buffer",
"//tensorflow/core:framework",
diff --git a/tensorflow/core/data/service/thread_safe_buffer_test.cc b/tensorflow/core/data/service/thread_safe_buffer_test.cc
index cf9f125..5a406f0 100644
--- a/tensorflow/core/data/service/thread_safe_buffer_test.cc
+++ b/tensorflow/core/data/service/thread_safe_buffer_test.cc
@@ -15,6 +15,7 @@
#include "tensorflow/core/data/service/thread_safe_buffer.h"
#include <memory>
+#include <tuple>
#include <vector>
#include "absl/strings/str_cat.h"
@@ -31,10 +32,31 @@
namespace data {
namespace {
-using ::testing::UnorderedElementsAre;
+using ::testing::UnorderedElementsAreArray;
-TEST(ThreadSafeBufferTest, OneReaderAndOneWriter) {
- ThreadSafeBuffer<Tensor> buffer(/*buffer_size=*/1);
+class ThreadSafeBufferTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {
+ protected:
+ size_t GetBufferSize() const { return std::get<0>(GetParam()); }
+ size_t GetNumOfElements() const { return std::get<1>(GetParam()); }
+};
+
+std::vector<int> GetRange(const size_t range) {
+ std::vector<int> result;
+ for (int i = 0; i < range; ++i) {
+ result.push_back(i);
+ }
+ return result;
+}
+
+INSTANTIATE_TEST_SUITE_P(VaryingBufferAndInputSizes, ThreadSafeBufferTest,
+ ::testing::Values(std::make_tuple(1, 2),
+ std::make_tuple(2, 10),
+ std::make_tuple(10, 2)));
+
+TEST_P(ThreadSafeBufferTest, OneReaderAndOneWriter) {
+ ThreadSafeBuffer<Tensor> buffer(GetBufferSize());
auto thread = absl::WrapUnique(Env::Default()->StartThread(
/*thread_options=*/{}, /*name=*/"writer_thread",
[&buffer]() { TF_EXPECT_OK(buffer.Push(Tensor("Test tensor"))); }));
@@ -43,32 +65,30 @@
test::ExpectEqual(tensor, Tensor("Test tensor"));
}
-TEST(ThreadSafeBufferTest, OneReaderAndMultipleWriters) {
- constexpr size_t kNumOfElements = 10;
- ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, OneReaderAndMultipleWriters) {
+ ThreadSafeBuffer<int> buffer(GetBufferSize());
std::vector<std::unique_ptr<Thread>> threads;
- for (int i = 0; i < kNumOfElements; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
/*thread_options=*/{}, /*name=*/absl::StrCat("writer_thread_", i),
[&buffer, i] { TF_EXPECT_OK(buffer.Push(i)); })));
}
std::vector<int> results;
- for (int i = 0; i < kNumOfElements; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
TF_ASSERT_OK_AND_ASSIGN(int next, buffer.Pop());
results.push_back(next);
}
- EXPECT_THAT(results, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
+ EXPECT_THAT(results, UnorderedElementsAreArray(GetRange(GetNumOfElements())));
}
-TEST(ThreadSafeBufferTest, MultipleReadersAndOneWriter) {
- constexpr size_t kNumOfElements = 10;
- ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, MultipleReadersAndOneWriter) {
+ ThreadSafeBuffer<int> buffer(GetBufferSize());
mutex mu;
std::vector<int> results;
std::vector<std::unique_ptr<Thread>> threads;
- for (int i = 0; i < kNumOfElements; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
/*thread_options=*/{}, /*name=*/absl::StrCat("reader_thread_", i),
[&buffer, &mu, &results]() {
@@ -78,23 +98,22 @@
})));
}
- for (int i = 0; i < kNumOfElements; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
TF_EXPECT_OK(buffer.Push(i));
}
// Wait for all threads to complete.
threads.clear();
- EXPECT_THAT(results, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
+ EXPECT_THAT(results, UnorderedElementsAreArray(GetRange(GetNumOfElements())));
}
-TEST(ThreadSafeBufferTest, MultipleReadersAndWriters) {
- constexpr size_t kNumOfElements = 10;
- ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, MultipleReadersAndWriters) {
+ ThreadSafeBuffer<int> buffer(GetBufferSize());
mutex mu;
std::vector<int> results;
std::vector<std::unique_ptr<Thread>> threads;
- for (int i = 0; i < kNumOfElements; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
/*thread_options=*/{}, /*name=*/absl::StrCat("reader_thread_", i),
[&buffer, &mu, &results]() {
@@ -104,7 +123,7 @@
})));
}
- for (int i = 0; i < kNumOfElements; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
/*thread_options=*/{}, /*name=*/absl::StrCat("writer_thread_", i),
[&buffer, i]() { TF_EXPECT_OK(buffer.Push(i)); })));
@@ -112,14 +131,14 @@
// Wait for all threads to complete.
threads.clear();
- EXPECT_THAT(results, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
+ EXPECT_THAT(results, UnorderedElementsAreArray(GetRange(GetNumOfElements())));
}
-TEST(ThreadSafeBufferTest, CancelReaders) {
- ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, CancelReaders) {
+ ThreadSafeBuffer<int> buffer(GetBufferSize());
std::vector<std::unique_ptr<Thread>> threads;
- for (int i = 0; i < 10; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
/*thread_options=*/{}, /*name=*/absl::StrCat("reader_thread_", i),
[&buffer]() {
@@ -129,13 +148,15 @@
buffer.Cancel(errors::Aborted("Aborted"));
}
-TEST(ThreadSafeBufferTest, CancelWriters) {
- constexpr size_t kNumOfElements = 10;
- ThreadSafeBuffer<Tensor> buffer(/*buffer_size=*/1);
- TF_EXPECT_OK(buffer.Push(Tensor("Test tensor")));
+TEST_P(ThreadSafeBufferTest, CancelWriters) {
+ ThreadSafeBuffer<Tensor> buffer(GetBufferSize());
+ // Fills the buffer so subsequent pushes are all cancelled.
+ for (int i = 0; i < GetBufferSize(); ++i) {
+ TF_EXPECT_OK(buffer.Push(Tensor("Test tensor")));
+ }
std::vector<std::unique_ptr<Thread>> threads;
- for (int i = 0; i < kNumOfElements; ++i) {
+ for (int i = 0; i < GetNumOfElements(); ++i) {
threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
/*thread_options=*/{}, /*name=*/absl::StrCat("writer_thread_", i),
[&buffer]() {