blob: 936d7e1cf162cb37bd6706c21cfa4c5c7837c0d6 [file] [log] [blame]
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/data/tf_record_dataset_op.h"
#include "tensorflow/core/kernels/data/dataset_test_base.h"
namespace tensorflow {
namespace data {
namespace {
// Name given to the dataset op's NodeDef in CreateTFRecordDatasetOpKernel;
// the DatasetNodeName test checks that the created dataset reports it back.
constexpr char kNodeName[] = "tf_record_dataset";
// Prefix passed to MakeIterator/RestoreIterator when building iterators.
constexpr char kIteratorPrefix[] = "Iterator";
// Test fixture that knows how to build a `TFRecordDataset` op kernel and the
// OpKernelContext needed to run it.
class TFRecordDatasetOpTest : public DatasetOpsTestBase {
 protected:
  // Builds a `TFRecordDataset` op kernel whose NodeDef is named `kNodeName`
  // and whose inputs are, in order: file names, compression type, buffer size.
  Status CreateTFRecordDatasetOpKernel(
      std::unique_ptr<OpKernel>* tf_record_dataset_op_kernel) {
    const string op_name = name_utils::OpName(TFRecordDatasetOp::kDatasetType);
    NodeDef node_def = test::function::NDef(
        kNodeName, op_name,
        /*inputs=*/
        {TFRecordDatasetOp::kFileNames, TFRecordDatasetOp::kCompressionType,
         TFRecordDatasetOp::kBufferSize},
        /*attrs=*/{});
    return CreateOpKernel(node_def, tf_record_dataset_op_kernel);
  }

  // Validates `inputs` against `op_kernel` with CheckOpKernelInput and, if
  // that succeeds, creates the kernel context into `context`.
  Status CreateTFRecordDatasetContext(
      OpKernel* const op_kernel,
      gtl::InlinedVector<TensorValue, 4>* const inputs,
      std::unique_ptr<OpKernelContext>* context) {
    TF_RETURN_IF_ERROR(CheckOpKernelInput(*op_kernel, *inputs));
    return CreateOpKernelContext(op_kernel, inputs, context);
  }
};
// Inputs and expectations for one parameterized TFRecordDataset test run.
struct TestCase {
// Paths of the TFRecord files that CreateTestFiles writes and the op reads.
std::vector<string> filenames;
// contents[i] holds the records written to filenames[i]; must have the same
// length as `filenames`.
std::vector<std::vector<string>> contents;
// Compression used when writing the files; also fed to the op as a string.
CompressionType compression_type;
// Used as the writer's input_buffer_size and as the op's buffer_size input.
int64 buffer_size;
// Tensors the iterator is expected to yield, compared in order.
std::vector<Tensor> expected_outputs;
DataTypeVector expected_output_dtypes;
std::vector<PartialTensorShape> expected_output_shapes;
int64 expected_cardinality;
// Used by Roundtrip: before each entry the iterator is saved and restored,
// then GetNext is called until the 0-based call counter exceeds the entry.
std::vector<int> breakpoints;
};
// Writes `test_case.contents[i]` as the records of the TFRecord file
// `test_case.filenames[i]`, using the test case's compression type and buffer
// size for the writer.
//
// Returns InvalidArgument if `filenames` and `contents` differ in length, or
// the first error reported by WriteDataToTFRecordFile.
Status CreateTestFiles(const TestCase& test_case) {
  if (test_case.filenames.size() != test_case.contents.size()) {
    return tensorflow::errors::InvalidArgument(
        "The number of files does not match with the contents");
  }
  CompressionParams params;
  params.compression_type = test_case.compression_type;
  params.input_buffer_size = test_case.buffer_size;
  // size_t index: comparing an `int` against `size()` mixes signedness.
  for (size_t i = 0; i < test_case.filenames.size(); ++i) {
    const std::vector<string>& file_contents = test_case.contents[i];
    // The string_views borrow from `test_case`, which outlives the write call.
    std::vector<absl::string_view> records(file_contents.begin(),
                                           file_contents.end());
    TF_RETURN_IF_ERROR(
        WriteDataToTFRecordFile(test_case.filenames[i], records, params));
  }
  return Status::OK();
}
// Test case 1: two TFRecord files written with ZLIB compression. The iterator
// should yield every record of the first file, then every record of the
// second, each as a scalar string tensor.
TestCase TestCase1() {
  const string tmp_dir = testing::TmpDir();
  const std::vector<std::vector<string>> contents = {{"1", "22", "333"},
                                                     {"a", "bb", "ccc"}};
  // Expected output is the concatenation of all files' records, in order.
  std::vector<Tensor> expected_outputs;
  for (const std::vector<string>& file_records : contents) {
    for (const string& record : file_records) {
      expected_outputs.push_back(
          CreateTensor<string>(TensorShape({}), {record}));
    }
  }
  return {/*filenames*/ {absl::StrCat(tmp_dir, "/tf_record_ZLIB_1"),
                         absl::StrCat(tmp_dir, "/tf_record_ZLIB_2")},
          /*contents*/ contents,
          /*compression_type*/ CompressionType::ZLIB,
          /*buffer_size*/ 10,
          /*expected_outputs*/ expected_outputs,
          /*expected_output_dtypes*/ {DT_STRING},
          /*expected_output_shapes*/ {PartialTensorShape({})},
          /*expected_cardinality*/ kUnknownCardinality,
          /*breakpoints*/ {0, 2, 7}};
}
// Test case 2: two TFRecord files written with GZIP compression. The iterator
// should yield every record of the first file, then every record of the
// second, each as a scalar string tensor.
TestCase TestCase2() {
  const string tmp_dir = testing::TmpDir();
  const std::vector<std::vector<string>> contents = {{"1", "22", "333"},
                                                     {"a", "bb", "ccc"}};
  // Expected output is the concatenation of all files' records, in order.
  std::vector<Tensor> expected_outputs;
  for (const std::vector<string>& file_records : contents) {
    for (const string& record : file_records) {
      expected_outputs.push_back(
          CreateTensor<string>(TensorShape({}), {record}));
    }
  }
  return {/*filenames*/ {absl::StrCat(tmp_dir, "/tf_record_GZIP_1"),
                         absl::StrCat(tmp_dir, "/tf_record_GZIP_2")},
          /*contents*/ contents,
          /*compression_type*/ CompressionType::GZIP,
          /*buffer_size*/ 10,
          /*expected_outputs*/ expected_outputs,
          /*expected_output_dtypes*/ {DT_STRING},
          /*expected_output_shapes*/ {PartialTensorShape({})},
          /*expected_cardinality*/ kUnknownCardinality,
          /*breakpoints*/ {0, 2, 7}};
}
// Test case 3: two uncompressed TFRecord files. The iterator should yield
// every record of the first file, then every record of the second, each as a
// scalar string tensor.
TestCase TestCase3() {
  const string tmp_dir = testing::TmpDir();
  const std::vector<std::vector<string>> contents = {{"1", "22", "333"},
                                                     {"a", "bb", "ccc"}};
  // Expected output is the concatenation of all files' records, in order.
  std::vector<Tensor> expected_outputs;
  for (const std::vector<string>& file_records : contents) {
    for (const string& record : file_records) {
      expected_outputs.push_back(
          CreateTensor<string>(TensorShape({}), {record}));
    }
  }
  return {/*filenames*/ {absl::StrCat(tmp_dir, "/tf_record_UNCOMPRESSED_1"),
                         absl::StrCat(tmp_dir, "/tf_record_UNCOMPRESSED_2")},
          /*contents*/ contents,
          /*compression_type*/ CompressionType::UNCOMPRESSED,
          /*buffer_size*/ 10,
          /*expected_outputs*/ expected_outputs,
          /*expected_output_dtypes*/ {DT_STRING},
          /*expected_output_shapes*/ {PartialTensorShape({})},
          /*expected_cardinality*/ kUnknownCardinality,
          /*breakpoints*/ {0, 2, 7}};
}
// Value-parameterized variant of the fixture: each TEST_P on it runs once per
// TestCase supplied by INSTANTIATE_TEST_SUITE_P.
struct ParameterizedTFRecordDatasetOpTest
    : TFRecordDatasetOpTest, ::testing::WithParamInterface<TestCase> {};
// Drains the iterator and checks that it yields exactly the expected records,
// in order.
TEST_P(ParameterizedTFRecordDatasetOpTest, GetNext) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  TF_ASSERT_OK(CreateTestFiles(test_case));
  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(
      CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
                                               kIteratorPrefix, &iterator));
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  while (!end_of_sequence) {
    std::vector<Tensor> next;
    // Fatal assertion: a failing GetNext may leave `end_of_sequence` false,
    // and a non-fatal check would then keep this loop spinning forever.
    TF_ASSERT_OK(
        iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
    out_tensors.insert(out_tensors.end(), next.begin(), next.end());
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
                           /*compare_order*/ true));
}
// The created dataset reports the NodeDef name it was built from.
TEST_F(TFRecordDatasetOpTest, DatasetNodeName) {
  const TestCase test_case = TestCase1();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  EXPECT_EQ(dataset->node_name(), kNodeName);
}
// The created dataset reports the op name derived from kDatasetType.
TEST_F(TFRecordDatasetOpTest, DatasetTypeString) {
  const TestCase test_case = TestCase1();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  const string expected_type =
      name_utils::OpName(TFRecordDatasetOp::kDatasetType);
  EXPECT_EQ(dataset->type_string(), expected_type);
}
// The dataset advertises the expected output dtypes.
TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputDtypes) {
  const TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  TF_EXPECT_OK(VerifyTypesMatch(dataset->output_dtypes(),
                                test_case.expected_output_dtypes));
}
// The dataset advertises output shapes compatible with the expected ones.
TEST_P(ParameterizedTFRecordDatasetOpTest, DatasetOutputShapes) {
  const TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  TF_EXPECT_OK(VerifyShapesCompatible(dataset->output_shapes(),
                                      test_case.expected_output_shapes));
}
// The dataset reports the expected cardinality.
TEST_P(ParameterizedTFRecordDatasetOpTest, Cardinality) {
  const TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  EXPECT_EQ(dataset->Cardinality(), test_case.expected_cardinality);
}
// An iterator over the dataset advertises the expected output dtypes.
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputDtypes) {
  const TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(CreateIteratorContext(op_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(
      dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix, &iterator));

  TF_EXPECT_OK(VerifyTypesMatch(iterator->output_dtypes(),
                                test_case.expected_output_dtypes));
}
// An iterator over the dataset advertises output shapes compatible with the
// expected ones.
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputShapes) {
  const TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(CreateIteratorContext(op_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(
      dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix, &iterator));

  TF_EXPECT_OK(VerifyShapesCompatible(iterator->output_shapes(),
                                      test_case.expected_output_shapes));
}
// The iterator's prefix combines the dataset type with kIteratorPrefix.
TEST_P(ParameterizedTFRecordDatasetOpTest, IteratorOutputPrefix) {
  const TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(/*thread_num=*/2));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, /*cpu_num=*/2));
  TF_ASSERT_OK(CreateTestFiles(test_case));

  std::unique_ptr<OpKernel> op_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&op_kernel));

  // Op inputs: file names, compression type, buffer size.
  const int64 file_count = test_case.filenames.size();
  Tensor filenames_input =
      CreateTensor<string>(TensorShape({file_count}), test_case.filenames);
  Tensor compression_input = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size_input =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> op_inputs{
      TensorValue(&filenames_input), TensorValue(&compression_input),
      TensorValue(&buffer_size_input)};

  std::unique_ptr<OpKernelContext> op_context;
  TF_ASSERT_OK(
      CreateTFRecordDatasetContext(op_kernel.get(), &op_inputs, &op_context));
  DatasetBase* dataset = nullptr;
  TF_ASSERT_OK(CreateDataset(op_kernel.get(), op_context.get(), &dataset));
  core::ScopedUnref dataset_unref(dataset);

  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(CreateIteratorContext(op_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(
      dataset->MakeIterator(iterator_ctx.get(), kIteratorPrefix, &iterator));

  const string expected_prefix = name_utils::IteratorPrefix(
      TFRecordDatasetOp::kDatasetType, kIteratorPrefix);
  EXPECT_EQ(iterator->prefix(), expected_prefix);
}
// Saves and restores the iterator before each breakpoint, reading through the
// breakpoint after every restore. Checks that the records collected across
// all checkpoints match the uninterrupted output.
TEST_P(ParameterizedTFRecordDatasetOpTest, Roundtrip) {
  int thread_num = 2, cpu_num = 2;
  TestCase test_case = GetParam();
  TF_ASSERT_OK(InitThreadPool(thread_num));
  TF_ASSERT_OK(InitFunctionLibraryRuntime({}, cpu_num));
  TF_ASSERT_OK(CreateTestFiles(test_case));
  std::unique_ptr<OpKernel> tf_record_dataset_kernel;
  TF_ASSERT_OK(CreateTFRecordDatasetOpKernel(&tf_record_dataset_kernel));
  int64 num_files = test_case.filenames.size();
  Tensor filenames =
      CreateTensor<string>(TensorShape({num_files}), test_case.filenames);
  Tensor compression_type = CreateTensor<string>(
      TensorShape({}), {ToString(test_case.compression_type)});
  Tensor buffer_size =
      CreateTensor<int64>(TensorShape({}), {test_case.buffer_size});
  gtl::InlinedVector<TensorValue, 4> inputs{TensorValue(&filenames),
                                            TensorValue(&compression_type),
                                            TensorValue(&buffer_size)};
  std::unique_ptr<OpKernelContext> tf_record_dataset_context;
  TF_ASSERT_OK(CreateTFRecordDatasetContext(
      tf_record_dataset_kernel.get(), &inputs, &tf_record_dataset_context));
  DatasetBase* tf_record_dataset;
  TF_ASSERT_OK(CreateDataset(tf_record_dataset_kernel.get(),
                             tf_record_dataset_context.get(),
                             &tf_record_dataset));
  core::ScopedUnref scoped_unref(tf_record_dataset);
  std::unique_ptr<IteratorContext> iterator_ctx;
  TF_ASSERT_OK(
      CreateIteratorContext(tf_record_dataset_context.get(), &iterator_ctx));
  std::unique_ptr<IteratorBase> iterator;
  TF_ASSERT_OK(tf_record_dataset->MakeIterator(iterator_ctx.get(),
                                               kIteratorPrefix, &iterator));
  std::unique_ptr<SerializationContext> serialization_ctx;
  TF_ASSERT_OK(CreateSerializationContext(&serialization_ctx));
  bool end_of_sequence = false;
  std::vector<Tensor> out_tensors;
  int cur_iteration = 0;
  const std::vector<int>& breakpoints = test_case.breakpoints;
  for (int breakpoint : breakpoints) {
    VariantTensorData data;
    VariantTensorDataWriter writer(&data);
    // Fatal assertions: if save/restore fails, `iterator` is in an unknown
    // state (possibly reset), and calling GetNext on it below could crash
    // the test binary instead of reporting a clean failure.
    TF_ASSERT_OK(iterator->Save(serialization_ctx.get(), &writer));
    TF_ASSERT_OK(writer.Flush());
    VariantTensorDataReader reader(&data);
    TF_ASSERT_OK(RestoreIterator(iterator_ctx.get(), &reader, kIteratorPrefix,
                                 *tf_record_dataset, &iterator));
    // Breakpoints may exceed the number of records (e.g. 7 with 6 records).
    // Those extra calls deliberately exercise GetNext past end of sequence.
    while (cur_iteration <= breakpoint) {
      std::vector<Tensor> next;
      TF_ASSERT_OK(
          iterator->GetNext(iterator_ctx.get(), &next, &end_of_sequence));
      out_tensors.insert(out_tensors.end(), next.begin(), next.end());
      cur_iteration++;
    }
  }
  TF_EXPECT_OK(ExpectEqual(out_tensors, test_case.expected_outputs,
                           /*compare_order*/ true));
}
// Runs every TEST_P of ParameterizedTFRecordDatasetOpTest once each for the
// ZLIB, GZIP and uncompressed test cases, in that order.
INSTANTIATE_TEST_SUITE_P(
    TFRecordDatasetOpTest, ParameterizedTFRecordDatasetOpTest,
    ::testing::Values(TestCase1(), TestCase2(), TestCase3()));
} // namespace
} // namespace data
} // namespace tensorflow