Refactor test case to new format
Refactor test case to new format
Refactor initialization of input dataset
diff --git a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc
index 8baa68a..9abb222 100644
--- a/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc
+++ b/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc
@@ -19,18 +19,7 @@
namespace {
constexpr char kNodeName[] = "sampling_dataset";
-
-// Parameters for constructing a dataset that returns an ordered sequence
-// of numbers
-struct RangeDatasetParams {
- int start;
- int stop;
- int step;
-};
-
-struct TakeDatasetParams {
- int count;
-};
+constexpr char kIteratorPrefix[] = "Iterator";
class SamplingDatasetOpTest : public DatasetOpsTestBase {
protected:
@@ -63,120 +52,135 @@
TF_RETURN_IF_ERROR(CreateOpKernelContext(op_kernel, inputs, context));
return Status::OK();
}
-
- // Build a dataset that will return an ordered sequence of numbers in chunks
- // of size `params.count`.
- // Stuffs the returned dataset into a variant tensor.
- Status MakeRangeAndTakeDatasetTensor(
- const RangeDatasetParams& range_dataset_params,
- const TakeDatasetParams& take_dataset_params,
- Tensor* range_and_take_dataset_tensor) {
- Tensor range_dataset_tensor;
- Tensor start =
- CreateTensor<int64>(TensorShape({}), {range_dataset_params.start});
- Tensor stop =
- CreateTensor<int64>(TensorShape({}), {range_dataset_params.stop});
- Tensor step =
- CreateTensor<int64>(TensorShape({}), {range_dataset_params.step});
- TF_RETURN_IF_ERROR(MakeRangeDataset(start, stop, step, {DT_INT64},
- {PartialTensorShape({})},
- &range_dataset_tensor));
-
- TF_RETURN_IF_ERROR(MakeTakeDataset(
- range_dataset_tensor, take_dataset_params.count, {DT_INT64},
- {PartialTensorShape({})}, range_and_take_dataset_tensor));
- return Status::OK();
- }
};
-// Common parameters that every test case in this file shares
-struct TestCase {
+// TODO(frreiss): Remove this once #31344 goes in and RangeDatasetParams is
+// defined in dataset_test_base.h
+class LocalRangeDatasetParams : public DatasetParams {
+ public:
+ LocalRangeDatasetParams(int64 start, int64 stop, int64 step,
+ DataTypeVector output_dtypes,
+ std::vector<PartialTensorShape> output_shapes,
+ string node_name)
+ : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
+ std::move(node_name)),
+ start(CreateTensor<int64>(TensorShape({}), {start})),
+ stop(CreateTensor<int64>(TensorShape({}), {stop})),
+ step(CreateTensor<int64>(TensorShape({}), {step})) {}
+
+ Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
+ *inputs = {TensorValue(&start), TensorValue(&stop), TensorValue(&step)};
+ return Status::OK();
+ }
+
+ Tensor start;
+ Tensor stop;
+ Tensor step;
+};
+
+class SamplingDatasetParams : public DatasetParams {
+ public:
+ SamplingDatasetParams(float rate, int64 seed, int64 seed2, int64 start,
+ int64 stop, int64 step, DataTypeVector output_dtypes,
+ std::vector<PartialTensorShape> output_shapes,
+ string node_name)
+ : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
+ std::move(node_name)),
+ rate(CreateTensor<float>(TensorShape({}), {rate})),
+ seed(CreateTensor<int64>(TensorShape({}), {seed})),
+ seed2(CreateTensor<int64>(TensorShape({}), {seed2})),
+ range_dataset_params(start, stop, step, {DT_INT64},
+ {PartialTensorShape({})}, "") {}
+
+ Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
+ if (input_dataset.NumElements() == 0 ||
+ input_dataset.dtype() != DT_VARIANT) {
+ return tensorflow::errors::Internal(
+ "The input dataset is not populated as the dataset tensor yet.");
+ }
+ *inputs = {TensorValue(&input_dataset), TensorValue(&rate),
+ TensorValue(&seed), TensorValue(&seed2)};
+ return Status::OK();
+ }
+
// Static parameters of the kernel
- float rate;
- int64 seed;
- int64 seed2;
+ Tensor rate;
+ Tensor seed;
+ Tensor seed2;
// Parameters of the sequence of numbers that will serve as the dynamic input
// of the kernel.
- RangeDatasetParams range_dataset_params;
- TakeDatasetParams take_dataset_params;
+ LocalRangeDatasetParams range_dataset_params;
- // The tensors that the kernel is expected to return, in the order they
- // should be returned
- std::vector<Tensor> expected_outputs;
-
- // Information about the returned outputs of the op that the test case
- // creates.
- DataTypeVector expected_output_dtypes;
- std::vector<PartialTensorShape> expected_output_shapes;
-
- // Value that the dataset's Cardinality() function returns. May be different
- // from the size of the outputs, as Cardinality() is not supposed to perform
- // expensive computations.
- int64 expected_cardinality;
-
- // When to insert save and restore steps while scanning the dataset in the
- // "roundtrip" test case.
- std::vector<int> breakpoints;
+ // RangeDataset kernel wrapped in a variant tensor. Initialized by the test
+ // case itself because the MakeRangeDataset() method requires an instance of
+ // DatasetOpsTestBase.
+ Tensor input_dataset;
};
-// Test case 1: 100% sample should return all inputs
-TestCase TestCase1() {
+SamplingDatasetParams OneHundredPercentSampleDataset() {
return {/*rate*/ 1.0,
/*seed*/ 42,
/*seed2*/ 7,
- /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 10, /*step*/ 1},
- /*take_dataset_params*/ {/*count*/ 3},
- /*expected_outputs*/
- {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {0}),
- DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {1}),
- DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {2})},
- /*expected_output_dtypes*/ {DT_INT64},
- /*expected_output_shapes*/ {PartialTensorShape({})},
- /*expected_cardinality*/ kUnknownCardinality,
- /*breakpoints*/ {0, 2, 5}};
+ /*start*/ 0,
+ /*stop*/ 3,
+ /*step*/ 1,
+ /*output_dtypes*/ {DT_INT64},
+ /*output_shapes*/ {PartialTensorShape({})},
+ /*node_name=*/kNodeName};
+}
+
+SamplingDatasetParams TenPercentSampleDataset() {
+ return {/*rate*/ 0.1,
+ /*seed*/ 42,
+ /*seed2*/ 7,
+ /*start*/ 0,
+ /*stop*/ 20,
+ /*step*/ 1,
+ /*output_dtypes*/ {DT_INT64},
+ /*output_shapes*/ {PartialTensorShape({})},
+ /*node_name=*/kNodeName};
+}
+
+SamplingDatasetParams ZeroPercentSampleDataset() {
+ return {/*rate*/ 0.0,
+ /*seed*/ 42,
+ /*seed2*/ 7,
+ /*start*/ 0,
+ /*stop*/ 20,
+ /*step*/ 1,
+ /*output_dtypes*/ {DT_INT64},
+ /*output_shapes*/ {PartialTensorShape({})},
+ /*node_name=*/kNodeName};
+}
+
+class ParameterizedGetNextSamplingDatasetOpTest
+ : public SamplingDatasetOpTest,
+ public ::testing::WithParamInterface<
+ GetNextTestCase<SamplingDatasetParams>> {};
+
+// Test case 1: 100% sample should return all inputs
+GetNextTestCase<SamplingDatasetParams> GetNextTestCase1() {
+ return {/*dataset_params=*/OneHundredPercentSampleDataset(),
+ /*expected_outputs=*/
+ CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}})};
}
// Test case 2: 10% sample should return about 10% of inputs, and the specific
// inputs returned shouldn't change across build environments.
-TestCase TestCase2() {
- return {/*rate*/ 0.1,
- /*seed*/ 42,
- /*seed2*/ 7,
- /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 100, /*step*/ 1},
- /*take_dataset_params*/ {/*count*/ 20},
- /*expected_outputs*/
- {DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {9}),
- DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {11}),
- DatasetOpsTestBase::CreateTensor<int64>(TensorShape({}), {19})},
- /*expected_output_dtypes*/ {DT_INT64},
- /*expected_output_shapes*/ {PartialTensorShape({})},
- /*expected_cardinality*/ kUnknownCardinality,
- /*breakpoints*/ {0, 2, 5}};
+GetNextTestCase<SamplingDatasetParams> GetNextTestCase2() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_outputs=*/
+ CreateTensors<int64>(TensorShape({}), {{9}, {11}, {19}})};
}
// Test case 3: 0% sample should return nothing and should not crash.
-TestCase TestCase3() {
- return {/*rate*/ 0.0,
- /*seed*/ 42,
- /*seed2*/ 7,
- /*range_dataset_params*/ {/*start*/ 0, /*stop*/ 100, /*step*/ 1},
- /*take_dataset_params*/ {/*count*/ 20},
- /*expected_outputs*/
- {},
- /*expected_output_dtypes*/ {DT_INT64},
- /*expected_output_shapes*/ {PartialTensorShape({})},
- /*expected_cardinality*/ kUnknownCardinality,
- /*breakpoints*/ {0, 2, 5}};
+GetNextTestCase<SamplingDatasetParams> GetNextTestCase3() {
+ return {/*dataset_params=*/ZeroPercentSampleDataset(),
+ /*expected_outputs=*/{}};
}
-// Parameterized test class shared by the next 6 test cases
-class ParameterizedSamplingDatasetOpTest
- : public SamplingDatasetOpTest,
- public ::testing::WithParamInterface<TestCase> {};
-
-// Verify that the GetNext function works and returns the expected outputs
-TEST_P(ParameterizedSamplingDatasetOpTest, GetNext) {
+TEST_P(ParameterizedGetNextSamplingDatasetOpTest, GetNext) {
// BEGIN INITIALIZATION CODE
// This test case and all the other test cases in this file go through the
// same sequence of initialization steps.
@@ -184,31 +188,30 @@
// Step 1: Set up enough of a TF runtime to be able to invoke a kernel.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
// Step 2: Create the dataset that will provide input data for the kernel
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
// Step 3: Box up the four inputs to the kernel inside TensorValue objects
// inside a vector.
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
// Step 4: Create a SamplingDataset kernel to test, passing in attributes
// of the kernel.
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
// Step 5: Create a context in which the kernel will operate. This is where
// the kernel gets initialized with its inputs
@@ -229,11 +232,8 @@
TF_ASSERT_OK(
CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
- string iterator_prefix = name_utils::IteratorPrefix(
- TakeDatasetOp::kDatasetType,
- name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
- iterator_prefix, &iterator));
+ kIteratorPrefix, &iterator));
// END INITIALIZATION CODE
// Copy the iterator's output into a vector to make comparison easier.
@@ -250,32 +250,42 @@
/*compare_order*/ true));
}
+INSTANTIATE_TEST_SUITE_P(
+ SamplingDatasetOpTest, ParameterizedGetNextSamplingDatasetOpTest,
+ ::testing::ValuesIn(std::vector<GetNextTestCase<SamplingDatasetParams>>(
+ {GetNextTestCase1(), GetNextTestCase2(), GetNextTestCase3()})));
+
+DatasetNodeNameTestCase<SamplingDatasetParams> DatasetNodeNameTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_node_name=*/kNodeName};
+}
+
// Verify that the machinery for creating SamplingDataset kernels runs and
// correctly creates kernels of with the node name "SamplingDataset".
TEST_F(SamplingDatasetOpTest, DatasetNodeName) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = TestCase1();
+ auto test_case = DatasetNodeNameTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -288,33 +298,40 @@
core::ScopedUnref scoped_unref(sampling_dataset);
// END INITIALIZATION CODE
- EXPECT_EQ(sampling_dataset->node_name(), kNodeName);
+ TF_ASSERT_OK(
+ CheckDatasetNodeName(*sampling_dataset, test_case.expected_node_name));
}
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetTypeString) {
+DatasetTypeStringTestCase<SamplingDatasetParams> DatasetTypeStringTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_dataset_type_string=*/
+ name_utils::OpName(SamplingDatasetOp::kDatasetType)};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetTypeString) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = DatasetTypeStringTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -327,34 +344,40 @@
core::ScopedUnref scoped_unref(sampling_dataset);
// END INITIALIZATION CODE
- EXPECT_EQ(sampling_dataset->type_string(),
- name_utils::OpName(SamplingDatasetOp::kDatasetType));
+ TF_ASSERT_OK(CheckDatasetTypeString(*sampling_dataset,
+ test_case.expected_dataset_type_string));
}
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetOutputDtypes) {
+DatasetOutputDtypesTestCase<SamplingDatasetParams>
+DatasetOutputDtypesTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_output_dtypes=*/{DT_INT64}};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetOutputDtypes) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = DatasetOutputDtypesTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -367,34 +390,40 @@
core::ScopedUnref scoped_unref(sampling_dataset);
// END INITIALIZATION CODE
- TF_EXPECT_OK(VerifyTypesMatch(sampling_dataset->output_dtypes(),
- test_case.expected_output_dtypes));
+ TF_ASSERT_OK(CheckDatasetOutputDtypes(*sampling_dataset,
+ test_case.expected_output_dtypes));
}
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetOutputShapes) {
+DatasetOutputShapesTestCase<SamplingDatasetParams>
+DatasetOutputShapesTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_output_shapes=*/{PartialTensorShape({})}};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetOutputShapes) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = DatasetOutputShapesTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -407,34 +436,54 @@
core::ScopedUnref scoped_unref(sampling_dataset);
// END INITIALIZATION CODE
- TF_EXPECT_OK(VerifyShapesCompatible(sampling_dataset->output_shapes(),
- test_case.expected_output_shapes));
+ TF_ASSERT_OK(CheckDatasetOutputShapes(*sampling_dataset,
+ test_case.expected_output_shapes));
}
-TEST_P(ParameterizedSamplingDatasetOpTest, Cardinality) {
+class ParameterizedCardinalitySamplingDatasetOpTest
+ : public SamplingDatasetOpTest,
+ public ::testing::WithParamInterface<
+ CardinalityTestCase<SamplingDatasetParams>> {};
+
+CardinalityTestCase<SamplingDatasetParams> CardinalityTestCase1() {
+ return {/*dataset_params=*/OneHundredPercentSampleDataset(),
+ /*expected_cardinality=*/kUnknownCardinality};
+}
+
+CardinalityTestCase<SamplingDatasetParams> CardinalityTestCase2() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_cardinality=*/kUnknownCardinality};
+}
+
+CardinalityTestCase<SamplingDatasetParams> CardinalityTestCase3() {
+ return {/*dataset_params=*/ZeroPercentSampleDataset(),
+ /*expected_cardinality=*/kUnknownCardinality};
+}
+
+TEST_P(ParameterizedCardinalitySamplingDatasetOpTest, Cardinality) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -447,34 +496,44 @@
core::ScopedUnref scoped_unref(sampling_dataset);
// END INITIALIZATION CODE
- EXPECT_EQ(sampling_dataset->Cardinality(), test_case.expected_cardinality);
+ TF_ASSERT_OK(CheckDatasetCardinality(*sampling_dataset,
+ test_case.expected_cardinality));
}
-// Verify that the Save() function executes without raising an error.
-TEST_P(ParameterizedSamplingDatasetOpTest, DatasetSave) {
+INSTANTIATE_TEST_SUITE_P(
+ SamplingDatasetOpTest, ParameterizedCardinalitySamplingDatasetOpTest,
+ ::testing::ValuesIn(std::vector<CardinalityTestCase<SamplingDatasetParams>>(
+ {CardinalityTestCase1(), CardinalityTestCase2(),
+ CardinalityTestCase3()})));
+
+DatasetSaveTestCase<SamplingDatasetParams> DatasetSaveTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset()};
+}
+
+TEST_F(SamplingDatasetOpTest, DatasetSave) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = DatasetSaveTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -487,38 +546,84 @@
core::ScopedUnref scoped_unref(sampling_dataset);
// END INITIALIZATION CODE
- std::unique_ptr<SerializationContext> serialization_context;
- TF_ASSERT_OK(CreateSerializationContext(&serialization_context));
- VariantTensorData data;
- VariantTensorDataWriter writer(&data);
- TF_ASSERT_OK(sampling_dataset->Save(serialization_context.get(), &writer));
- TF_ASSERT_OK(writer.Flush());
+ TF_ASSERT_OK(CheckDatasetSave(*sampling_dataset));
}
-TEST_P(ParameterizedSamplingDatasetOpTest, IteratorOutputDtypes) {
+IsStatefulTestCase<SamplingDatasetParams> IsStatefulTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_stateful=*/false};
+}
+
+TEST_F(SamplingDatasetOpTest, IsStateful) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = IsStatefulTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
+
+ std::unique_ptr<OpKernelContext> sampling_dataset_context;
+ TF_ASSERT_OK(CreateSamplingDatasetContext(
+ sampling_dataset_kernel.get(), &inputs, &sampling_dataset_context));
+
+ DatasetBase* sampling_dataset;
+ TF_ASSERT_OK(CreateDataset(sampling_dataset_kernel.get(),
+ sampling_dataset_context.get(),
+ &sampling_dataset));
+ core::ScopedUnref scoped_unref(sampling_dataset);
+ // END INITIALIZATION CODE
+
+ TF_ASSERT_OK(
+ CheckDatasetIsStateful(*sampling_dataset, test_case.expected_stateful));
+}
+
+IteratorOutputDtypesTestCase<SamplingDatasetParams>
+IteratorOutputDtypesTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_output_dtypes=*/{DT_INT64}};
+}
+
+TEST_F(SamplingDatasetOpTest, IteratorOutputDtypes) {
+ // BEGIN INITIALIZATION CODE
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
+ const int thread_num = 2, cpu_num = 2;
+ auto test_case = IteratorOutputDtypesTestCase1();
+ TF_ASSERT_OK(InitThreadPool(thread_num));
+ TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
+
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
+
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
+
+ std::unique_ptr<OpKernel> sampling_dataset_kernel;
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -534,41 +639,44 @@
TF_ASSERT_OK(
CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
- string iterator_prefix = name_utils::IteratorPrefix(
- TakeDatasetOp::kDatasetType,
- name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
- iterator_prefix, &iterator));
+ kIteratorPrefix, &iterator));
// END INITIALIZATION CODE
- TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
- test_case.expected_output_dtypes));
+ TF_ASSERT_OK(
+ CheckIteratorOutputDtypes(*iterator, test_case.expected_output_dtypes));
}
-TEST_P(ParameterizedSamplingDatasetOpTest, IteratorOutputShapes) {
+IteratorOutputShapesTestCase<SamplingDatasetParams>
+IteratorOutputShapesTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_output_shapes=*/{PartialTensorShape({})}};
+}
+
+TEST_F(SamplingDatasetOpTest, IteratorOutputShapes) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = IteratorOutputShapesTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -584,41 +692,46 @@
TF_ASSERT_OK(
CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
- string iterator_prefix = name_utils::IteratorPrefix(
- TakeDatasetOp::kDatasetType,
- name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
- iterator_prefix, &iterator));
+ kIteratorPrefix, &iterator));
// END INITIALIZATION CODE
- TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
- test_case.expected_output_shapes));
+ TF_ASSERT_OK(
+ CheckIteratorOutputShapes(*iterator, test_case.expected_output_shapes));
}
-TEST_P(ParameterizedSamplingDatasetOpTest, IteratorOutputPrefix) {
+IteratorOutputPrefixTestCase<SamplingDatasetParams>
+IteratorOutputPrefixTestCase1() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*expected_iterator_prefix=*/
+ name_utils::IteratorPrefix(SamplingDatasetOp::kDatasetType,
+ kIteratorPrefix)};
+}
+
+TEST_F(SamplingDatasetOpTest, IteratorOutputPrefix) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = IteratorOutputPrefixTestCase1();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -634,43 +747,67 @@
TF_ASSERT_OK(
CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
- string iterator_prefix = name_utils::IteratorPrefix(
- TakeDatasetOp::kDatasetType,
- name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
- iterator_prefix, &iterator));
+ kIteratorPrefix, &iterator));
// END INITIALIZATION CODE
- EXPECT_EQ(iterator->prefix(),
- name_utils::IteratorPrefix(SamplingDatasetOp::kDatasetType,
- iterator_prefix));
+ TF_ASSERT_OK(
+ CheckIteratorPrefix(*iterator, test_case.expected_iterator_prefix));
+}
+
+class ParameterizedIteratorSaveAndRestoreSamplingDatasetOpTest
+ : public SamplingDatasetOpTest,
+ public ::testing::WithParamInterface<
+ IteratorSaveAndRestoreTestCase<SamplingDatasetParams>> {};
+
+IteratorSaveAndRestoreTestCase<SamplingDatasetParams>
+IteratorSaveAndRestoreTestCase1() {
+ return {/*dataset_params=*/OneHundredPercentSampleDataset(),
+ /*breakpoints=*/{0, 2, 5},
+ /*expected_outputs=*/
+ CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}})};
+}
+
+IteratorSaveAndRestoreTestCase<SamplingDatasetParams>
+IteratorSaveAndRestoreTestCase2() {
+ return {/*dataset_params=*/TenPercentSampleDataset(),
+ /*breakpoints=*/{0, 2, 5},
+ /*expected_outputs=*/
+ CreateTensors<int64>(TensorShape({}), {{9}, {11}, {19}})};
+}
+
+IteratorSaveAndRestoreTestCase<SamplingDatasetParams>
+IteratorSaveAndRestoreTestCase3() {
+ return {/*dataset_params=*/ZeroPercentSampleDataset(),
+ /*breakpoints=*/{0, 2, 5},
+ /*expected_outputs=*/{}};
}
// Save and restore the dataset while scanning it. Verify the returned tuples.
-TEST_P(ParameterizedSamplingDatasetOpTest, Roundtrip) {
+TEST_P(ParameterizedIteratorSaveAndRestoreSamplingDatasetOpTest, Roundtrip) {
// BEGIN INITIALIZATION CODE
- // See ParameterizedSamplingDatasetOpTest::GetNext for explanatory comments.
+ // See ParameterizedGetNextSamplingDatasetOpTest::GetNext for explanatory
+ // comments.
const int thread_num = 2, cpu_num = 2;
- TestCase test_case = GetParam();
+ auto test_case = GetParam();
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
- Tensor range_and_take_dataset_tensor;
- TF_ASSERT_OK(MakeRangeAndTakeDatasetTensor(test_case.range_dataset_params,
- test_case.take_dataset_params,
- &range_and_take_dataset_tensor));
+ TF_ASSERT_OK(MakeRangeDataset(
+ test_case.dataset_params.range_dataset_params.start,
+ test_case.dataset_params.range_dataset_params.stop,
+ test_case.dataset_params.range_dataset_params.step,
+ test_case.dataset_params.range_dataset_params.output_dtypes,
+ test_case.dataset_params.range_dataset_params.output_shapes,
+ &test_case.dataset_params.input_dataset));
- Tensor rate = CreateTensor<float>(TensorShape({}), {test_case.rate});
- Tensor seed = CreateTensor<int64>(TensorShape({}), {test_case.seed});
- Tensor seed2 = CreateTensor<int64>(TensorShape({}), {test_case.seed2});
- gtl::InlinedVector<TensorValue, 4> inputs(
- {TensorValue(&range_and_take_dataset_tensor), TensorValue(&rate),
- TensorValue(&seed), TensorValue(&seed2)});
+ gtl::InlinedVector<TensorValue, 4> inputs;
+ TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> sampling_dataset_kernel;
- TF_ASSERT_OK(CreateSamplingDatasetOpKernel(test_case.expected_output_dtypes,
- test_case.expected_output_shapes,
- &sampling_dataset_kernel));
+ TF_ASSERT_OK(CreateSamplingDatasetOpKernel(
+ test_case.dataset_params.output_dtypes,
+ test_case.dataset_params.output_shapes, &sampling_dataset_kernel));
std::unique_ptr<OpKernelContext> sampling_dataset_context;
TF_ASSERT_OK(CreateSamplingDatasetContext(
@@ -686,46 +823,23 @@
TF_ASSERT_OK(
CreateIteratorContext(sampling_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
- string iterator_prefix = name_utils::IteratorPrefix(
- TakeDatasetOp::kDatasetType,
- name_utils::IteratorPrefix(RangeDatasetOp::kDatasetType, "Iterator"));
TF_ASSERT_OK(sampling_dataset->MakeIterator(iterator_context.get(),
- iterator_prefix, &iterator));
+ kIteratorPrefix, &iterator));
// END INITIALIZATION CODE
- std::unique_ptr<SerializationContext> serialization_ctx;
- TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
- bool end_of_sequence = false;
- std::vector<Tensor> out_tensors;
- int cur_iteration = 0;
- const std::vector<int>& breakpoints = test_case.breakpoints;
- for (int breakpoint : breakpoints) {
- VariantTensorData data;
- VariantTensorDataWriter writer(&data);
- TF_EXPECT_OK(iterator->Save(serialization_ctx.get(), &writer));
- TF_EXPECT_OK(writer.Flush());
- VariantTensorDataReader reader(&data);
- TF_EXPECT_OK(RestoreIterator(iterator_context.get(), &reader,
- iterator_prefix, *sampling_dataset,
- &iterator));
-
- while (cur_iteration <= breakpoint) {
- std::vector<Tensor> next;
- TF_EXPECT_OK(
- iterator->GetNext(iterator_context.get(), &next, &end_of_sequence));
- out_tensors.insert(out_tensors.end(), next.begin(), next.end());
- ++cur_iteration;
- }
- }
-
- TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
- /*compare_order*/ true));
+ TF_ASSERT_OK(CheckIteratorSaveAndRestore(
+ *sampling_dataset, iterator_context.get(), kIteratorPrefix,
+ test_case.expected_outputs, test_case.breakpoints));
}
-INSTANTIATE_TEST_SUITE_P(SamplingDatasetOpTest,
- ParameterizedSamplingDatasetOpTest,
- ::testing::ValuesIn(std::vector<TestCase>(
- {TestCase1(), TestCase2(), TestCase3()})));
+INSTANTIATE_TEST_SUITE_P(
+ SamplingDatasetOpTest,
+ ParameterizedIteratorSaveAndRestoreSamplingDatasetOpTest,
+ ::testing::ValuesIn(
+ std::vector<IteratorSaveAndRestoreTestCase<SamplingDatasetParams>>(
+ {IteratorSaveAndRestoreTestCase1(),
+ IteratorSaveAndRestoreTestCase2(),
+ IteratorSaveAndRestoreTestCase3()})));
} // namespace
} // namespace experimental