Refactor test case to new format

Refactor test case to new format

Refactor initialization of input dataset
diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc
index 8baa68a..9abb222 100644
--- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc
@@ -19,18 +19,7 @@
 namespace {
 
 constexpr char kNodeName[] = "sampling_dataset";
-
-// Parameters for constructing a dataset that returns an ordered sequence
-// of numbers
-struct RangeDatasetParams {
-  int start;
-  int stop;
-  int step;
-};
-
-struct TakeDatasetParams {
-  int count;
-};
+constexpr char kIteratorPrefix[] = "Iterator";
 
 class SamplingDatasetOpTest : public DatasetOpsTestBase {
  protected:
@@ -63,120 +52,135 @@
     TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
     return Status::OK();
   }
-
-  // Build a dataset that will return an ordered sequence of numbers in chunks
-  // of size `params.count`.
-  // Stuffs the returned dataset into a variant tensor.
-  Status MakeRangeAndTakeDatasetTensor(
-      const RangeDatasetParams& range_dataset_params,
-      const TakeDatasetParams& take_dataset_params,
-      Tensor* range_and_take_dataset_tensor) {
-    Tensor range_dataset_tensor;
-    Tensor start =
-        CreateTensor<int64>(TensorShape({}), {range_dataset_params.start});
-    Tensor stop =
-        CreateTensor<int64>(TensorShape({}), {range_dataset_params.stop});
-    Tensor step =
-        CreateTensor<int64>(TensorShape({}), {range_dataset_params.step});
-    TF_RETURN_IF_ERROR(MakeRangeDataset(start, stop, step, {DT_INT64},
-                                        {PartialTensorShape({})},
-                                        &range_dataset_tensor));
-
-    TF_RETURN_IF_ERROR(MakeTakeDataset(
-        range_dataset_tensor, take_dataset_params.count, {DT_INT64},
-        {PartialTensorShape({})}, range_and_take_dataset_tensor));
-    return Status::OK();
-  }
 };
 
-// Common parameters that every test case in this file shares
-struct TestCase {
+// TODO(frreiss): Remove this once #31344 goes in and RangeDatasetParams is
+// defined in dataset_test_base.h
+class LocalRangeDatasetParams : public DatasetParams {
+ public:
+  LocalRangeDatasetParams(int64 start, int64 stop, int64 step,
+                          DataTypeVector output_dtypes,
+                          std::vector<PartialTensorShape> output_shapes,
+                          string node_name)
+      : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
+                      std::move(node_name)),
+        start(CreateTensor<int64>(TensorShape({}), {start})),
+        stop(CreateTensor<int64>(TensorShape({}), {stop})),
+        step(CreateTensor<int64>(TensorShape({}), {step})) {}
+
+  Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
+    *inputs = {TensorValue(&start), TensorValue(&stop), TensorValue(&step)};
+    return Status::OK();
+  }
+
+  Tensor start;
+  Tensor stop;
+  Tensor step;
+};
+
+class SamplingDatasetParams : public DatasetParams {
+ public:
+  SamplingDatasetParams(float rate, int64 seed, int64 seed2, int64 start,
+                        int64 stop, int64 step, DataTypeVector output_dtypes,
+                        std::vector<PartialTensorShape> output_shapes,
+                        string node_name)
+      : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
+                      std::move(node_name)),
+        rate(CreateTensor<float>(TensorShape({}), {rate})),
+        seed(CreateTensor<int64>(TensorShape({}), {seed})),
+        seed2(CreateTensor<int64>(TensorShape({}), {seed2})),
+        range_dataset_params(start, stop, step, {DT_INT64},
+                             {PartialTensorShape({})}, "") {}
+
+  Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
+    if (input_dataset.NumElements() == 0 ||
+        input_dataset.dtype() != DT_VARIANT) {
+      return tensorflow::errors::Internal(
+          "The input dataset is not populated as the dataset tensor yet.");
+    }
+    *inputs = {TensorValue(&input_dataset), TensorValue(&rate),
+               TensorValue(&seed), TensorValue(&seed2)};
+    return Status::OK();
+  }
+
   // Static parameters of the kernel
-  float rate;
-  int64 seed;
-  int64 seed2;
+  Tensor rate;
+  Tensor seed;
+  Tensor seed2;
 
   // Parameters of the sequence of numbers that will serve as the dynamic input
   // of the kernel.
-  RangeDatasetParams range_dataset_params;
-  TakeDatasetParams take_dataset_params;
+  LocalRangeDatasetParams range_dataset_params;
 
-  // The tensors that the kernel is expected to return, in the order they
-  // should be returned
-  std::vector<Tensor> expected_outputs;
-
-  // Information about the returned outputs of the op that the test case
-  // creates.
-  DataTypeVector expected_output_dtypes;
-  std::vector<PartialTensorShape> expected_output_shapes;
-
-  // Value that the dataset's Cardinality() function returns. May be different
-  // from the size of the outputs, as Cardinality() is not supposed to perform
-  // expensive computations.
-  int64 expected_cardinality;
-
-  // When to insert save and restore steps while scanning the dataset in the
-  // "roundtrip" test case.
-  std::vector<int> breakpoints;
+  // RangeDataset kernel wrapped in a variant tensor. Initialized by the test
+  // case itself because the MakeRangeDataset() method requires an instance of
+  // DatasetOpsTestBase.
+  Tensor input_dataset;
 };
 
-// Test case 1: 100% sample should return all inputs
-TestCase TestCase1() {
+SamplingDatasetParams OneHundredPercentSampleDataset() {
   return {/*rate*/ 1.0,
           /*seed*/ 42,
           /*seed2*/ 7,
-          /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
-          /*take_dataset_params*/ {/*count*/ 3},
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({})},
-          /*expected_cardinality*/ kUnknownCardinality,
-          /*breakpoints*/ {0, 2, 5}};
+          /*start*/ 0,
+          /*stop*/ 3,
+          /*step*/ 1,
+          /*output_dtypes*/ {DT_INT64},
+          /*output_shapes*/ {PartialTensorShape({})},
+          /*node_name=*/kNodeName};
+}
+
+SamplingDatasetParams TenPercentSampleDataset() {
+  return {/*rate*/ 0.1,
+          /*seed*/ 42,
+          /*seed2*/ 7,
+          /*start*/ 0,
+          /*stop*/ 20,
+          /*step*/ 1,
+          /*output_dtypes*/ {DT_INT64},
+          /*output_shapes*/ {PartialTensorShape({})},
+          /*node_name=*/kNodeName};
+}
+
+SamplingDatasetParams ZeroPercentSampleDataset() {
+  return {/*rate*/ 0.0,
+          /*seed*/ 42,
+          /*seed2*/ 7,
+          /*start*/ 0,
+          /*stop*/ 20,
+          /*step*/ 1,
+          /*output_dtypes*/ {DT_INT64},
+          /*output_shapes*/ {PartialTensorShape({})},
+          /*node_name=*/kNodeName};
+}
+
+class ParameterizedGetNextSamplingDatasetOpTest
+    : public SamplingDatasetOpTest,
+      public ::testing::WithParamInterface<
+          GetNextTestCase<SamplingDatasetParams>> {};
+
+// Test case 1: 100% sample should return all inputs
+GetNextTestCase<SamplingDatasetParams> GetNextTestCase1() {
+  return {/*dataset_params=*/OneHundredPercentSampleDataset(),
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}})};
 }
 
 // Test case 2: 10% sample should return about 10% of inputs, and the specific
 // inputs returned shouldn't change across build environments.
-TestCase TestCase2() {
-  return {/*rate*/ 0.1,
-          /*seed*/ 42,
-          /*seed2*/ 7,
-          /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 100, /*step*/ 1},
-          /*take_dataset_params*/ {/*count*/ 20},
-          /*expected_outputs*/
-          {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {9}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {11}),
-           DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {19})},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({})},
-          /*expected_cardinality*/ kUnknownCardinality,
-          /*breakpoints*/ {0, 2, 5}};
+GetNextTestCase<SamplingDatasetParams> GetNextTestCase2() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{9}, {11}, {19}})};
 }
 
 // Test case 3: 0% sample should return nothing and should not crash.
-TestCase TestCase3() {
-  return {/*rate*/ 0.0,
-          /*seed*/ 42,
-          /*seed2*/ 7,
-          /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 100, /*step*/ 1},
-          /*take_dataset_params*/ {/*count*/ 20},
-          /*expected_outputs*/
-          {},
-          /*expected_output_dtypes*/ {DT_INT64},
-          /*expected_output_shapes*/ {PartialTensorShape({})},
-          /*expected_cardinality*/ kUnknownCardinality,
-          /*breakpoints*/ {0, 2, 5}};
+GetNextTestCase<SamplingDatasetParams> GetNextTestCase3() {
+  return {/*dataset_params=*/ZeroPercentSampleDataset(),
+          /*expected_outputs=*/{}};
 }
 
-// Parameterized test class shared by the next 6 test cases
-class ParameterizedSamplingDatasetOpTest
-    : public SamplingDatasetOpTest,
-      public ::testing::WithParamInterface<TestCase> {};
-
-// Verify that the GetNext function works and returns the expected outputs
-TEST_P(ParameterizedSamplingDatasetOpTest, GetNext) {
+TEST_P(ParameterizedGetNextSamplingDatasetOpTest, GetNext) {
   // BEGIN INITIALIZATION CODE
   // This test case and all the other test cases in this file go through the
   // same sequence of initialization steps.
@@ -184,31 +188,30 @@
 
   // Step 1: Set up enough of a TF runtime to be able to invoke a kernel.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = GetParam();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
   // Step 2: Create the dataset that will provide input data for the kernel
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
   // Step 3: Box up the four inputs to the kernel inside TensorValue objects
   // inside a vector.
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   // Step 4: Create a SamplingDataset kernel to test, passing in attributes
   // of the kernel.
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   // Step 5: Create a context in which the kernel will operate. This is where
   // the kernel gets initialized with its inputs
@@ -229,11 +232,8 @@
   TF_ASSERT_OK(
       CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
   std::unique_ptr<IteratorBase> iterator;
-  string iterator_prefix = name_utils::IteratorPrefix(
-      TakeDatasetOp::kDatasetType,
-      name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
   TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
-                                              iterator_prefix, &iterator));
+                                              kIteratorPrefix, &iterator));
   // END INITIALIZATION CODE
 
   // Copy the iterator's output into a vector to make comparison easier.
@@ -250,32 +250,42 @@
                            /*compare_order*/ true));
 }
 
+INSTANTIATE_TEST_SUITE_P(
+    SamplingDatasetOpTest, ParameterizedGetNextSamplingDatasetOpTest,
+    ::testing::ValuesIn(std::vector<GetNextTestCase<SamplingDatasetParams>>(
+        {GetNextTestCase1(), GetNextTestCase2(), GetNextTestCase3()})));
+
+DatasetNodeNameTestCase<SamplingDatasetParams> DatasetNodeNameTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_node_name=*/kNodeName};
+}
+
 // Verify that the machinery for creating SamplingDataset kernels runs and
 // correctly creates kernels of with the node name "SamplingDataset".
 TEST_F(SamplingDatasetOpTest, DatasetNodeName) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = TestCase1();
+  auto test_case = DatasetNodeNameTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -288,33 +298,40 @@
   core::ScopedUnref scoped_unref(sampling_dataset);
   // END INITIALIZATION CODE
 
-  EXPECT_EQ(sampling_dataset->node_name(), kNodeName);
+  TF_ASSERT_OK(
+      CheckDatasetNodeName(*sampling_dataset, test_case.expected_node_name));
 }
 
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetTypeString) {
+DatasetTypeStringTestCase<SamplingDatasetParams> DatasetTypeStringTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_dataset_type_string=*/
+          name_utils::OpName(SamplingDatasetOp::kDatasetType)};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetTypeString) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = DatasetTypeStringTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -327,34 +344,40 @@
   core::ScopedUnref scoped_unref(sampling_dataset);
   // END INITIALIZATION CODE
 
-  EXPECT_EQ(sampling_dataset->type_string(),
-            name_utils::OpName(SamplingDatasetOp::kDatasetType));
+  TF_ASSERT_OK(CheckDatasetTypeString(*sampling_dataset,
+                                      test_case.expected_dataset_type_string));
 }
 
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetOutputDtypes) {
+DatasetOutputDtypesTestCase<SamplingDatasetParams>
+DatasetOutputDtypesTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_output_dtypes=*/{DT_INT64}};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetOutputDtypes) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = DatasetOutputDtypesTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -367,34 +390,40 @@
   core::ScopedUnref scoped_unref(sampling_dataset);
   // END INITIALIZATION CODE
 
-  TF_EXPECT_OK(VerifyTypesMatch(sampling_dataset->output_dtypes(),
-                                test_case.expected_output_dtypes));
+  TF_ASSERT_OK(CheckDatasetOutputDtypes(*sampling_dataset,
+                                        test_case.expected_output_dtypes));
 }
 
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetOutputShapes) {
+DatasetOutputShapesTestCase<SamplingDatasetParams>
+DatasetOutputShapesTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_output_shapes=*/{PartialTensorShape({})}};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetOutputShapes) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = DatasetOutputShapesTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -407,34 +436,54 @@
   core::ScopedUnref scoped_unref(sampling_dataset);
   // END INITIALIZATION CODE
 
-  TF_EXPECT_OK(VerifyShapesCompatible(sampling_dataset->output_shapes(),
-                                      test_case.expected_output_shapes));
+  TF_ASSERT_OK(CheckDatasetOutputShapes(*sampling_dataset,
+                                        test_case.expected_output_shapes));
 }
 
-TEST_P(ParameterizedSamplingDatasetOpTest, Cardinality) {
+class ParameterizedCardinalitySamplingDatasetOpTest
+    : public SamplingDatasetOpTest,
+      public ::testing::WithParamInterface<
+          CardinalityTestCase<SamplingDatasetParams>> {};
+
+CardinalityTestCase<SamplingDatasetParams> CardinalityTestCase1() {
+  return {/*dataset_params=*/OneHundredPercentSampleDataset(),
+          /*expected_cardinality=*/kUnknownCardinality};
+}
+
+CardinalityTestCase<SamplingDatasetParams> CardinalityTestCase2() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_cardinality=*/kUnknownCardinality};
+}
+
+CardinalityTestCase<SamplingDatasetParams> CardinalityTestCase3() {
+  return {/*dataset_params=*/ZeroPercentSampleDataset(),
+          /*expected_cardinality=*/kUnknownCardinality};
+}
+
+TEST_P(ParameterizedCardinalitySamplingDatasetOpTest, Cardinality) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = GetParam();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -447,34 +496,44 @@
   core::ScopedUnref scoped_unref(sampling_dataset);
   // END INITIALIZATION CODE
 
-  EXPECT_EQ(sampling_dataset->Cardinality(), test_case.expected_cardinality);
+  TF_ASSERT_OK(CheckDatasetCardinality(*sampling_dataset,
+                                       test_case.expected_cardinality));
 }
 
-// Verify that the Save() function executes without raising an error.
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetSave) {
+INSTANTIATE_TEST_SUITE_P(
+    SamplingDatasetOpTest, ParameterizedCardinalitySamplingDatasetOpTest,
+    ::testing::ValuesIn(std::vector<CardinalityTestCase<SamplingDatasetParams>>(
+        {CardinalityTestCase1(), CardinalityTestCase2(),
+         CardinalityTestCase3()})));
+
+DatasetSaveTestCase<SamplingDatasetParams> DatasetSaveTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset()};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetSave) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = DatasetSaveTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -487,38 +546,84 @@
   core::ScopedUnref scoped_unref(sampling_dataset);
   // END INITIALIZATION CODE
 
-  std::unique_ptr<SerializationContext> serialization_context;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
-  VariantTensorData data;
-  VariantTensorDataWriter writer(&data);
-  TF_ASSERT_OK(sampling_dataset->Save(serialization_context.get(), &writer));
-  TF_ASSERT_OK(writer.Flush());
+  TF_ASSERT_OK(CheckDatasetSave(*sampling_dataset));
 }
 
-TEST_P(ParameterizedSamplingDatasetOpTest, IteratorOutputDtypes) {
+IsStatefulTestCase<SamplingDatasetParams> IsStatefulTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_stateful=*/false};
+}
+
+TEST_F(SamplingDatasetOpTest, IsStateful) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = IsStatefulTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
+
+  std::unique_ptr<OpKernelContext> sampling_dataset_context;
+  TF_ASSERT_OK(CreateSamplingDatasetContext(
+      sampling_dataset_kernel.get(), &inputs, &sampling_dataset_context));
+
+  DatasetBase* sampling_dataset;
+  TF_ASSERT_OK(CreateDataset(sampling_dataset_kernel.get(),
+                             sampling_dataset_context.get(),
+                             &sampling_dataset));
+  core::ScopedUnref scoped_unref(sampling_dataset);
+  // END INITIALIZATION CODE
+
+  TF_ASSERT_OK(
+      CheckDatasetIsStateful(*sampling_dataset, test_case.expected_stateful));
+}
+
+IteratorOutputDtypesTestCase<SamplingDatasetParams>
+IteratorOutputDtypesTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_output_dtypes=*/{DT_INT64}};
+}
+
+TEST_F(SamplingDatasetOpTest, IteratorOutputDtypes) {
+  // BEGIN INITIALIZATION CODE
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
+  const int thread_num = 2, cpu_num = 2;
+  auto test_case = IteratorOutputDtypesTestCase1();
+  TF_ASSERT_OK(InitThreadPool(thread_num));
+  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
+
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+
+  std::unique_ptr<OpKernel> sampling_dataset_kernel;
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -534,41 +639,44 @@
   TF_ASSERT_OK(
       CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
   std::unique_ptr<IteratorBase> iterator;
-  string iterator_prefix = name_utils::IteratorPrefix(
-      TakeDatasetOp::kDatasetType,
-      name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
   TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
-                                              iterator_prefix, &iterator));
+                                              kIteratorPrefix, &iterator));
   // END INITIALIZATION CODE
 
-  TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
-                                test_case.expected_output_dtypes));
+  TF_ASSERT_OK(
+      CheckIteratorOutputDtypes(*iterator, test_case.expected_output_dtypes));
 }
 
-TEST_P(ParameterizedSamplingDatasetOpTest, IteratorOutputShapes) {
+IteratorOutputShapesTestCase<SamplingDatasetParams>
+IteratorOutputShapesTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_output_shapes=*/{PartialTensorShape({})}};
+}
+
+TEST_F(SamplingDatasetOpTest, IteratorOutputShapes) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = IteratorOutputShapesTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -584,41 +692,46 @@
   TF_ASSERT_OK(
       CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
   std::unique_ptr<IteratorBase> iterator;
-  string iterator_prefix = name_utils::IteratorPrefix(
-      TakeDatasetOp::kDatasetType,
-      name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
   TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
-                                              iterator_prefix, &iterator));
+                                              kIteratorPrefix, &iterator));
   // END INITIALIZATION CODE
 
-  TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
-                                      test_case.expected_output_shapes));
+  TF_ASSERT_OK(
+      CheckIteratorOutputShapes(*iterator, test_case.expected_output_shapes));
 }
 
-TEST_P(ParameterizedSamplingDatasetOpTest, IteratorOutputPrefix) {
+IteratorOutputPrefixTestCase<SamplingDatasetParams>
+IteratorOutputPrefixTestCase1() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*expected_iterator_prefix=*/
+          name_utils::IteratorPrefix(SamplingDatasetOp::kDatasetType,
+                                     kIteratorPrefix)};
+}
+
+TEST_F(SamplingDatasetOpTest, IteratorOutputPrefix) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = IteratorOutputPrefixTestCase1();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -634,43 +747,67 @@
   TF_ASSERT_OK(
       CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
   std::unique_ptr<IteratorBase> iterator;
-  string iterator_prefix = name_utils::IteratorPrefix(
-      TakeDatasetOp::kDatasetType,
-      name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
   TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
-                                              iterator_prefix, &iterator));
+                                              kIteratorPrefix, &iterator));
   // END INITIALIZATION CODE
 
-  EXPECT_EQ(iterator->prefix(),
-            name_utils::IteratorPrefix(SamplingDatasetOp::kDatasetType,
-                                       iterator_prefix));
+  TF_ASSERT_OK(
+      CheckIteratorPrefix(*iterator, test_case.expected_iterator_prefix));
+}
+
+class ParameterizedIteratorSaveAndRestoreSamplingDatasetOpTest
+    : public SamplingDatasetOpTest,
+      public ::testing::WithParamInterface<
+          IteratorSaveAndRestoreTestCase<SamplingDatasetParams>> {};
+
+IteratorSaveAndRestoreTestCase<SamplingDatasetParams>
+IteratorSaveAndRestoreTestCase1() {
+  return {/*dataset_params=*/OneHundredPercentSampleDataset(),
+          /*breakpoints=*/{0, 2, 5},
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}})};
+}
+
+IteratorSaveAndRestoreTestCase<SamplingDatasetParams>
+IteratorSaveAndRestoreTestCase2() {
+  return {/*dataset_params=*/TenPercentSampleDataset(),
+          /*breakpoints=*/{0, 2, 5},
+          /*expected_outputs=*/
+          CreateTensors<int64>(TensorShape({}), {{9}, {11}, {19}})};
+}
+
+IteratorSaveAndRestoreTestCase<SamplingDatasetParams>
+IteratorSaveAndRestoreTestCase3() {
+  return {/*dataset_params=*/ZeroPercentSampleDataset(),
+          /*breakpoints=*/{0, 2, 5},
+          /*expected_outputs=*/{}};
 }
 
 // Save and restore the dataset while scanning it. Verify the returned tuples.
-TEST_P(ParameterizedSamplingDatasetOpTest, Roundtrip) {
+TEST_P(ParameterizedIteratorSaveAndRestoreSamplingDatasetOpTest, Roundtrip) {
   // BEGIN INITIALIZATION CODE
-  // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+  // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+  // comments.
   const int thread_num = 2, cpu_num = 2;
-  TestCase test_case = GetParam();
+  auto test_case = GetParam();
   TF_ASSERT_OK(InitThreadPool(thread_num));
   TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
 
-  Tensor range_and_take_dataset_tensor;
-  TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
-                                             test_case.take_dataset_params,
-                                             &range_and_take_dataset_tensor));
+  TF_ASSERT_OK(MakeRangeDataset(
+      test_case.dataset_params.range_dataset_params.start,
+      test_case.dataset_params.range_dataset_params.stop,
+      test_case.dataset_params.range_dataset_params.step,
+      test_case.dataset_params.range_dataset_params.output_dtypes,
+      test_case.dataset_params.range_dataset_params.output_shapes,
+      &test_case.dataset_params.input_dataset));
 
-  Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
-  Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
-  Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
-  gtl::InlinedVector<TensorValue, 4> inputs(
-      {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
-       TensorValue(&seed), TensorValue(&seed2)});
+  gtl::InlinedVector<TensorValue, 4> inputs;
+  TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
 
   std::unique_ptr<OpKernel> sampling_dataset_kernel;
-  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
-                                             test_case.expected_output_shapes,
-                                             &sampling_dataset_kernel));
+  TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+      test_case.dataset_params.output_dtypes,
+      test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
 
   std::unique_ptr<OpKernelContext> sampling_dataset_context;
   TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -686,46 +823,23 @@
   TF_ASSERT_OK(
       CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
   std::unique_ptr<IteratorBase> iterator;
-  string iterator_prefix = name_utils::IteratorPrefix(
-      TakeDatasetOp::kDatasetType,
-      name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
   TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
-                                              iterator_prefix, &iterator));
+                                              kIteratorPrefix, &iterator));
   // END INITIALIZATION CODE
 
-  std::unique_ptr<SerializationContext> serialization_ctx;
-  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
-  bool end_of_sequence = false;
-  std::vector<Tensor> out_tensors;
-  int cur_iteration = 0;
-  const std::vector<int>& breakpoints = test_case.breakpoints;
-  for (int breakpoint : breakpoints) {
-    VariantTensorData data;
-    VariantTensorDataWriter writer(&data);
-    TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
-    TF_EXPECT_OK(writer.Flush());
-    VariantTensorDataReader reader(&data);
-    TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader,
-                                 iterator_prefix, *sampling_dataset,
-                                 &iterator));
-
-    while (cur_iteration <= breakpoint) {
-      std::vector<Tensor> next;
-      TF_EXPECT_OK(
-          iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
-      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
-      ++cur_iteration;
-    }
-  }
-
-  TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
-                           /*compare_order*/ true));
+  TF_ASSERT_OK(CheckIteratorSaveAndRestore(
+      *sampling_dataset, iterator_context.get(), kIteratorPrefix,
+      test_case.expected_outputs, test_case.breakpoints));
 }
 
-INSTANTIATE_TEST_SUITE_P(SamplingDatasetOpTest,
-                         ParameterizedSamplingDatasetOpTest,
-                         ::testing::ValuesIn(std::vector<TestCase>(
-                             {TestCase1(), TestCase2(), TestCase3()})));
+INSTANTIATE_TEST_SUITE_P(
+    SamplingDatasetOpTest,
+    ParameterizedIteratorSaveAndRestoreSamplingDatasetOpTest,
+    ::testing::ValuesIn(
+        std::vector<IteratorSaveAndRestoreTestCase<SamplingDatasetParams>>(
+            {IteratorSaveAndRestoreTestCase1(),
+             IteratorSaveAndRestoreTestCase2(),
+             IteratorSaveAndRestoreTestCase3()})));
 
 }  // namespace
 }  // namespace experimental