blob: dfa6959e529ec2a36de5df9ed6319ad7c5d0dbf4 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/range_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
constexpr char kNodeName[] = "range_dataset";
constexpr char kIteratorPrefix[] = "Iterator";
class RangeDatasetOpTest : public DatasetOpsTestBase {
protected:
// Creates a new RangeDataset op kernel context.
Status CreateRangeDatasetContext(
OpKernel* const range_kernel,
gtl::InlinedVector<TensorValue, 4>* const inputs,
std::unique_ptr<OpKernelContext>* range_context) {
TF_RETURN_IF_ERROR(CheckOpKernelInput(*range_kernel, *inputs));
TF_RETURN_IF_ERROR(
CreateOpKernelContext(range_kernel, inputs, range_context));
return Status::OK();
}
};
class RangeDatasetParams : public DatasetParams {
public:
RangeDatasetParams(int64 start, int64 stop, int64 step,
DataTypeVector output_dtypes,
std::vector<PartialTensorShape> output_shapes,
string node_name)
: DatasetParams(std::move(output_dtypes), std::move(output_shapes),
std::move(node_name)),
start(CreateTensor<int64>(TensorShape({}), {start})),
stop(CreateTensor<int64>(TensorShape({}), {stop})),
step(CreateTensor<int64>(TensorShape({}), {step})) {}
Status MakeInputs(gtl::InlinedVector<TensorValue, 4>* inputs) override {
*inputs = {TensorValue(&start), TensorValue(&stop), TensorValue(&step)};
return Status::OK();
}
Tensor start;
Tensor stop;
Tensor step;
};
RangeDatasetParams PositiveStepRangeDataset() {
return {/*start=*/0,
/*stop=*/10,
/*step=*/3,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({})},
/*node_name=*/kNodeName};
}
RangeDatasetParams NegativeStepRangeDataset() {
return {/*start=*/10,
/*stop=*/0,
/*step=*/-3,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({})},
/*node_name=*/kNodeName};
}
RangeDatasetParams ZeroStepRangeDataset() {
return {/*start=*/10,
/*stop=*/0,
/*step=*/0,
/*output_dtypes=*/{DT_INT64},
/*output_shapes=*/{PartialTensorShape({})},
/*node_name=*/kNodeName};
}
class ParameterizedGetNextRangeDatasetOpTest
: public RangeDatasetOpTest,
public ::testing::WithParamInterface<
GetNextTestCase<RangeDatasetParams>> {};
GetNextTestCase<RangeDatasetParams> GetNextTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}})};
}
GetNextTestCase<RangeDatasetParams> GetNextTestCase2() {
return {/*dataset_params=*/NegativeStepRangeDataset(),
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})};
}
TEST_P(ParameterizedGetNextRangeDatasetOpTest, GetNext) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
GetNextTestCase<RangeDatasetParams> test_case = GetParam();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
kIteratorPrefix, &iterator));
TF_ASSERT_OK(CheckIteratorGetNext(iterator.get(), iterator_context.get(),
test_case.expected_outputs,
/*compare_order=*/true));
}
INSTANTIATE_TEST_SUITE_P(
RangeDatasetOpTest, ParameterizedGetNextRangeDatasetOpTest,
::testing::ValuesIn(std::vector<GetNextTestCase<RangeDatasetParams>>(
{GetNextTestCase1(), GetNextTestCase2()})));
DatasetNodeNameTestCase<RangeDatasetParams> DatasetNodeNameTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_node_name=*/kNodeName};
}
TEST_F(RangeDatasetOpTest, DatasetNodeName) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = DatasetNodeNameTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
TF_ASSERT_OK(
CheckDatasetNodeName(*range_dataset, test_case.expected_node_name));
}
DatasetTypeStringTestCase<RangeDatasetParams> DatasetTypeStringTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_dataset_type_string=*/
name_utils::OpName(RangeDatasetOp::kDatasetType)};
}
TEST_F(RangeDatasetOpTest, DatasetTypeString) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = DatasetTypeStringTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
TF_ASSERT_OK(CheckDatasetTypeString(*range_dataset,
test_case.expected_dataset_type_string));
}
DatasetOutputDtypesTestCase<RangeDatasetParams> DatasetOutputDtypesTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_output_dtypes=*/{DT_INT64}};
}
TEST_F(RangeDatasetOpTest, DatasetOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = DatasetOutputDtypesTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
TF_ASSERT_OK(CheckDatasetOutputDtypes(*range_dataset,
test_case.expected_output_dtypes));
}
DatasetOutputShapesTestCase<RangeDatasetParams> DatasetOutputShapesTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_output_shapes=*/{PartialTensorShape({})}};
}
TEST_F(RangeDatasetOpTest, DatasetOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = DatasetOutputShapesTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
TF_ASSERT_OK(CheckDatasetOutputShapes(*range_dataset,
test_case.expected_output_shapes));
}
class ParameterizedCardinalityRangeDatasetOpTest
: public RangeDatasetOpTest,
public ::testing::WithParamInterface<
CardinalityTestCase<RangeDatasetParams>> {};
CardinalityTestCase<RangeDatasetParams> CardinalityTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_cardinality=*/4};
}
CardinalityTestCase<RangeDatasetParams> CardinalityTestCase2() {
return {/*dataset_params=*/NegativeStepRangeDataset(),
/*expected_cardinality=*/4};
}
TEST_P(ParameterizedCardinalityRangeDatasetOpTest, Cardinality) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = GetParam();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
TF_ASSERT_OK(
CheckDatasetCardinality(*range_dataset, test_case.expected_cardinality));
}
INSTANTIATE_TEST_SUITE_P(
RangeDatasetOpTest, ParameterizedCardinalityRangeDatasetOpTest,
::testing::ValuesIn(std::vector<CardinalityTestCase<RangeDatasetParams>>(
{CardinalityTestCase1(), CardinalityTestCase2()})));
DatasetSaveTestCase<RangeDatasetParams> DatasetSaveTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset()};
}
IsStatefulTestCase<RangeDatasetParams> IsStatefulTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_stateful=*/false};
}
TEST_F(RangeDatasetOpTest, IsStateful) {
int64 thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = IsStatefulTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
TF_ASSERT_OK(
CheckDatasetIsStateful(*range_dataset, test_case.expected_stateful));
}
IteratorOutputDtypesTestCase<RangeDatasetParams>
IteratorOutputDtypesTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_output_dtypes=*/{DT_INT64}};
}
TEST_F(RangeDatasetOpTest, IteratorOutputDtypes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = IteratorOutputDtypesTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
kIteratorPrefix, &iterator));
TF_ASSERT_OK(
CheckIteratorOutputDtypes(*iterator, test_case.expected_output_dtypes));
}
IteratorOutputShapesTestCase<RangeDatasetParams>
IteratorOutputShapesTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_output_shapes=*/{PartialTensorShape({})}};
}
TEST_F(RangeDatasetOpTest, IteratorOutputShapes) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = IteratorOutputShapesTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
kIteratorPrefix, &iterator));
TF_ASSERT_OK(
CheckIteratorOutputShapes(*iterator, test_case.expected_output_shapes));
}
IteratorOutputPrefixTestCase<RangeDatasetParams>
IteratorOutputPrefixTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*expected_iterator_prefix=*/name_utils::IteratorPrefix(
RangeDatasetOp::kDatasetType, kIteratorPrefix)};
}
TEST_F(RangeDatasetOpTest, IteratorOutputPrefix) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = IteratorOutputPrefixTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
kIteratorPrefix, &iterator));
TF_ASSERT_OK(
CheckIteratorPrefix(*iterator, test_case.expected_iterator_prefix));
}
class ParameterizedIteratorSaveAndRestoreRangeDatasetOpTest
: public RangeDatasetOpTest,
public ::testing::WithParamInterface<
IteratorSaveAndRestoreTestCase<RangeDatasetParams>> {};
IteratorSaveAndRestoreTestCase<RangeDatasetParams>
IteratorSaveAndRestoreTestCase1() {
return {/*dataset_params=*/PositiveStepRangeDataset(),
/*breakpoints=*/{0, 1, 4},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({}), {{0}, {3}, {6}, {9}})};
}
IteratorSaveAndRestoreTestCase<RangeDatasetParams>
IteratorSaveAndRestoreTestCase2() {
return {/*dataset_params=*/NegativeStepRangeDataset(),
/*breakpoints=*/{0, 1, 4},
/*expected_outputs=*/
CreateTensors<int64>(TensorShape({}), {{10}, {7}, {4}, {1}})};
}
TEST_P(ParameterizedIteratorSaveAndRestoreRangeDatasetOpTest,
IteratorSaveAndRestore) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = GetParam();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
TF_ASSERT_OK(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset));
core::ScopedUnref scoped_unref(range_dataset);
std::unique_ptr<IteratorContext> iterator_context;
TF_ASSERT_OK(
CreateIteratorContext(range_dataset_context.get(), &iterator_context));
std::unique_ptr<IteratorBase> iterator;
TF_ASSERT_OK(range_dataset->MakeIterator(iterator_context.get(),
kIteratorPrefix, &iterator));
TF_ASSERT_OK(CheckIteratorSaveAndRestore(
*range_dataset, iterator_context.get(), kIteratorPrefix,
test_case.expected_outputs, test_case.breakpoints));
}
INSTANTIATE_TEST_SUITE_P(
RangeDatasetOpTest, ParameterizedIteratorSaveAndRestoreRangeDatasetOpTest,
::testing::ValuesIn(
std::vector<IteratorSaveAndRestoreTestCase<RangeDatasetParams>>(
{IteratorSaveAndRestoreTestCase1(),
IteratorSaveAndRestoreTestCase2()})));
GetNextTestCase<RangeDatasetParams> ZeroStepTestCase1() {
return {/*dataset_params=*/ZeroStepRangeDataset(),
/*expected_outputs=*/{}};
}
TEST_F(RangeDatasetOpTest, ZeroStep) {
int thread_num = 2, cpu_num = 2;
TF_ASSERT_OK(InitThreadPool(thread_num));
TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
auto test_case = ZeroStepTestCase1();
gtl::InlinedVector<TensorValue, 4> inputs;
TF_ASSERT_OK(test_case.dataset_params.MakeInputs(&inputs));
std::unique_ptr<OpKernel> range_dataset_kernel;
TF_ASSERT_OK(CreateRangeDatasetOpKernel<int64>(
test_case.dataset_params.node_name, &range_dataset_kernel));
std::unique_ptr<OpKernelContext> range_dataset_context;
TF_ASSERT_OK(CreateRangeDatasetContext(range_dataset_kernel.get(), &inputs,
&range_dataset_context));
DatasetBase* range_dataset;
EXPECT_EQ(CreateDataset(range_dataset_kernel.get(),
range_dataset_context.get(), &range_dataset)
.code(),
tensorflow::error::INVALID_ARGUMENT);
}
} // namespace
} // namespace data
} // namespace tensorflow