[tf.data service] thread_safe_buffer_test tests varying buffer and input sizes.

PiperOrigin-RevId: 376272959
Change-Id: I971666e9ac9185c0ae9db46f8bf76b7aae98d706
diff --git a/tensorflow/core/data/service/BUILD b/tensorflow/core/data/service/BUILD
index 1a8e203..f5d6217 100644
--- a/tensorflow/core/data/service/BUILD
+++ b/tensorflow/core/data/service/BUILD
@@ -574,7 +574,7 @@
     name = "thread_safe_buffer_test",
     size = "small",
     srcs = ["thread_safe_buffer_test.cc"],
-    shard_count = 2,
+    shard_count = 3,
     deps = [
         ":thread_safe_buffer",
         "//tensorflow/core:framework",
diff --git a/tensorflow/core/data/service/thread_safe_buffer_test.cc b/tensorflow/core/data/service/thread_safe_buffer_test.cc
index cf9f125..5a406f0 100644
--- a/tensorflow/core/data/service/thread_safe_buffer_test.cc
+++ b/tensorflow/core/data/service/thread_safe_buffer_test.cc
@@ -15,6 +15,7 @@
 #include "tensorflow/core/data/service/thread_safe_buffer.h"
 
 #include <memory>
+#include <tuple>
 #include <vector>
 
 #include "absl/strings/str_cat.h"
@@ -31,10 +32,31 @@
 namespace data {
 namespace {
 
-using ::testing::UnorderedElementsAre;
+using ::testing::UnorderedElementsAreArray;
 
-TEST(ThreadSafeBufferTest, OneReaderAndOneWriter) {
-  ThreadSafeBuffer<Tensor> buffer(/*buffer_size=*/1);
+class ThreadSafeBufferTest
+    : public ::testing::Test,
+      public ::testing::WithParamInterface<std::tuple<size_t, size_t>> {
+ protected:
+  size_t GetBufferSize() const { return std::get<0>(GetParam()); }
+  size_t GetNumOfElements() const { return std::get<1>(GetParam()); }
+};
+
+std::vector<int> GetRange(const size_t range) {
+  std::vector<int> result;
+  for (int i = 0; i < range; ++i) {
+    result.push_back(i);
+  }
+  return result;
+}
+
+INSTANTIATE_TEST_SUITE_P(VaryingBufferAndInputSizes, ThreadSafeBufferTest,
+                         ::testing::Values(std::make_tuple(1, 2),
+                                           std::make_tuple(2, 10),
+                                           std::make_tuple(10, 2)));
+
+TEST_P(ThreadSafeBufferTest, OneReaderAndOneWriter) {
+  ThreadSafeBuffer<Tensor> buffer(GetBufferSize());
   auto thread = absl::WrapUnique(Env::Default()->StartThread(
       /*thread_options=*/{}, /*name=*/"writer_thread",
       [&buffer]() { TF_EXPECT_OK(buffer.Push(Tensor("Test tensor"))); }));
@@ -43,32 +65,30 @@
   test::ExpectEqual(tensor, Tensor("Test tensor"));
 }
 
-TEST(ThreadSafeBufferTest, OneReaderAndMultipleWriters) {
-  constexpr size_t kNumOfElements = 10;
-  ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, OneReaderAndMultipleWriters) {
+  ThreadSafeBuffer<int> buffer(GetBufferSize());
   std::vector<std::unique_ptr<Thread>> threads;
-  for (int i = 0; i < kNumOfElements; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
         /*thread_options=*/{}, /*name=*/absl::StrCat("writer_thread_", i),
         [&buffer, i] { TF_EXPECT_OK(buffer.Push(i)); })));
   }
 
   std::vector<int> results;
-  for (int i = 0; i < kNumOfElements; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     TF_ASSERT_OK_AND_ASSIGN(int next, buffer.Pop());
     results.push_back(next);
   }
-  EXPECT_THAT(results, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
+  EXPECT_THAT(results, UnorderedElementsAreArray(GetRange(GetNumOfElements())));
 }
 
-TEST(ThreadSafeBufferTest, MultipleReadersAndOneWriter) {
-  constexpr size_t kNumOfElements = 10;
-  ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, MultipleReadersAndOneWriter) {
+  ThreadSafeBuffer<int> buffer(GetBufferSize());
   mutex mu;
   std::vector<int> results;
 
   std::vector<std::unique_ptr<Thread>> threads;
-  for (int i = 0; i < kNumOfElements; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
         /*thread_options=*/{}, /*name=*/absl::StrCat("reader_thread_", i),
         [&buffer, &mu, &results]() {
@@ -78,23 +98,22 @@
         })));
   }
 
-  for (int i = 0; i < kNumOfElements; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     TF_EXPECT_OK(buffer.Push(i));
   }
 
   // Wait for all threads to complete.
   threads.clear();
-  EXPECT_THAT(results, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
+  EXPECT_THAT(results, UnorderedElementsAreArray(GetRange(GetNumOfElements())));
 }
 
-TEST(ThreadSafeBufferTest, MultipleReadersAndWriters) {
-  constexpr size_t kNumOfElements = 10;
-  ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, MultipleReadersAndWriters) {
+  ThreadSafeBuffer<int> buffer(GetBufferSize());
   mutex mu;
   std::vector<int> results;
 
   std::vector<std::unique_ptr<Thread>> threads;
-  for (int i = 0; i < kNumOfElements; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
         /*thread_options=*/{}, /*name=*/absl::StrCat("reader_thread_", i),
         [&buffer, &mu, &results]() {
@@ -104,7 +123,7 @@
         })));
   }
 
-  for (int i = 0; i < kNumOfElements; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
         /*thread_options=*/{}, /*name=*/absl::StrCat("writer_thread_", i),
         [&buffer, i]() { TF_EXPECT_OK(buffer.Push(i)); })));
@@ -112,14 +131,14 @@
 
   // Wait for all threads to complete.
   threads.clear();
-  EXPECT_THAT(results, UnorderedElementsAre(0, 1, 2, 3, 4, 5, 6, 7, 8, 9));
+  EXPECT_THAT(results, UnorderedElementsAreArray(GetRange(GetNumOfElements())));
 }
 
-TEST(ThreadSafeBufferTest, CancelReaders) {
-  ThreadSafeBuffer<int> buffer(/*buffer_size=*/1);
+TEST_P(ThreadSafeBufferTest, CancelReaders) {
+  ThreadSafeBuffer<int> buffer(GetBufferSize());
   std::vector<std::unique_ptr<Thread>> threads;
 
-  for (int i = 0; i < 10; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
         /*thread_options=*/{}, /*name=*/absl::StrCat("reader_thread_", i),
         [&buffer]() {
@@ -129,13 +148,15 @@
   buffer.Cancel(errors::Aborted("Aborted"));
 }
 
-TEST(ThreadSafeBufferTest, CancelWriters) {
-  constexpr size_t kNumOfElements = 10;
-  ThreadSafeBuffer<Tensor> buffer(/*buffer_size=*/1);
-  TF_EXPECT_OK(buffer.Push(Tensor("Test tensor")));
+TEST_P(ThreadSafeBufferTest, CancelWriters) {
+  ThreadSafeBuffer<Tensor> buffer(GetBufferSize());
+  // Fills the buffer so subsequent pushes are all cancelled.
+  for (int i = 0; i < GetBufferSize(); ++i) {
+    TF_EXPECT_OK(buffer.Push(Tensor("Test tensor")));
+  }
 
   std::vector<std::unique_ptr<Thread>> threads;
-  for (int i = 0; i < kNumOfElements; ++i) {
+  for (int i = 0; i < GetNumOfElements(); ++i) {
     threads.push_back(absl::WrapUnique(Env::Default()->StartThread(
         /*thread_options=*/{}, /*name=*/absl::StrCat("writer_thread_", i),
         [&buffer]() {