Add tests for DirectedInterleaveDatasetOp
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
index 575b2e4..eea5ae6 100644
--- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -34,7 +34,7 @@
/* static */ constexpr const char* const
DirectedInterleaveDatasetOp::kOutputShapes;
/* static */ constexpr const char* const
- DirectedInterleaveDatasetOp::kNumDatasets;
+ DirectedInterleaveDatasetOp::kNumInputDatasets;
class DirectedInterleaveDatasetOp::Dataset : public DatasetBase {
public:
@@ -192,8 +192,8 @@
if (selector_input_impl_) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
} else {
- TF_RETURN_IF_ERROR(writer->WriteScalar(
- full_name(strings::StrCat("data_input_impl_empty[", i, "]")), ""));
+ TF_RETURN_IF_ERROR(
+ writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
}
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
const auto& data_input_impl = data_input_impls_[i];
@@ -207,55 +207,53 @@
}
return Status::OK();
}
- return Status::OK();
- }
- Status
- RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- if (!reader->Contains(full_name("selector_input_impl_empty"))) {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
- } else {
- selector_input_impl_.reset();
- }
- for (size_t i = 0; i < data_input_impls_.size(); ++i) {
- if (!reader->Contains(
- full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ if (!reader->Contains(full_name("selector_input_impl_empty"))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, selector_input_impl_));
} else {
- data_input_impls_[i].reset();
+ selector_input_impl_.reset();
}
+ for (size_t i = 0; i < data_input_impls_.size(); ++i) {
+ if (!reader->Contains(
+ full_name(strings::StrCat("data_input_impl_empty[", i, "]")))) {
+ TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, data_input_impls_[i]));
+ } else {
+ data_input_impls_[i].reset();
+ }
+ }
+ return Status::OK();
}
- return Status::OK();
- }
- private:
- mutex mu_;
- std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
- std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
- TF_GUARDED_BY(mu_);
- int64 num_active_inputs_ TF_GUARDED_BY(mu_);
-};
+ private:
+ mutex mu_;
+ std::unique_ptr<IteratorBase> selector_input_impl_ TF_GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<IteratorBase>> data_input_impls_
+ TF_GUARDED_BY(mu_);
+ int64 num_active_inputs_ TF_GUARDED_BY(mu_);
+ };
-static PartialTensorShape MostSpecificCompatibleShape(
- const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
- PartialTensorShape output_tensorshape;
- if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+ static PartialTensorShape MostSpecificCompatibleShape(
+ const PartialTensorShape& ts1, const PartialTensorShape& ts2) {
+ PartialTensorShape output_tensorshape;
+ if (ts1.dims() != ts2.dims() || ts1.unknown_rank() || ts2.unknown_rank())
+ return output_tensorshape;
+ auto dims1 = ts1.dim_sizes();
+ auto dims2 = ts2.dim_sizes();
+ for (int d = 0; d < ts1.dims(); ++d) {
+ if (dims1[d] == dims2[d])
+ output_tensorshape.Concatenate(dims1[d]);
+ else
+ output_tensorshape.Concatenate(-1);
+ }
return output_tensorshape;
- auto dims1 = ts1.dim_sizes();
- auto dims2 = ts2.dim_sizes();
- for (int d = 0; d < ts1.dims(); ++d) {
- if (dims1[d] == dims2[d])
- output_tensorshape.Concatenate(dims1[d]);
- else
- output_tensorshape.Concatenate(-1);
}
- return output_tensorshape;
-}
-const DatasetBase* const selector_input_;
-const std::vector<DatasetBase*> data_inputs_;
-std::vector<PartialTensorShape> output_shapes_;
+ const DatasetBase* const selector_input_;
+ const std::vector<DatasetBase*> data_inputs_;
+ std::vector<PartialTensorShape> output_shapes_;
}; // namespace experimental
DirectedInterleaveDatasetOp::DirectedInterleaveDatasetOp(
@@ -302,6 +300,6 @@
Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
DirectedInterleaveDatasetOp);
} // namespace
+} // namespace experimental
} // namespace data
} // namespace tensorflow
-} // namespace tensorflow
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h
index 03ee8ed..3dc689e 100644
--- a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h
@@ -29,7 +29,7 @@
static constexpr const char* const kDataInputDatasets = "data_input_datasets";
static constexpr const char* const kOutputTypes = "output_types";
static constexpr const char* const kOutputShapes = "output_shapes";
- static constexpr const char* const kNumDatasets = "N";
+ static constexpr const char* const kNumInputDatasets = "N";
explicit DirectedInterleaveDatasetOp(OpKernelConstruction* ctx);
diff --git a/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc
new file mode 100644
index 0000000..7aed1d7
--- /dev/null
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op_test.cc
@@ -0,0 +1,364 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.h"
+
+#include "tensorflow/core/kernels/data/dataset_test_base.h"
+
+namespace tensorflow {
+namespace data {
+namespace experimental {
+namespace {
+
+constexpr char kNodeName[] = "directed_interleave_dataset";
+
+class DirectedInterleaveDatasetParams : public DatasetParams {
+ public:
+ template <typename S, typename T>
+ DirectedInterleaveDatasetParams(S selector_input_dataset_params,
+ std::vector<T> input_dataset_params_vec,
+ DataTypeVector output_dtypes,
+ std::vector<PartialTensorShape> output_shapes,
+ int num_input_datasets, string node_name)
+ : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
+ std::move(node_name)),
+ num_input_datasets_(num_input_datasets) {
+ input_dataset_params_.push_back(
+ absl::make_unique<S>(selector_input_dataset_params));
+ for (auto input_dataset_params : input_dataset_params_vec) {
+ input_dataset_params_.push_back(
+ absl::make_unique<T>(input_dataset_params));
+ }
+
+ if (!input_dataset_params_vec.empty()) {
+ iterator_prefix_ = name_utils::IteratorPrefix(
+ input_dataset_params_vec[0].dataset_type(),
+ input_dataset_params_vec[0].iterator_prefix());
+ }
+ }
+
+ std::vector<Tensor> GetInputTensors() const override { return {}; }
+
+ Status GetInputNames(std::vector<string>* input_names) const override {
+ input_names->clear();
+ input_names->emplace_back(
+ DirectedInterleaveDatasetOp::kSelectorInputDataset);
+ for (int i = 0; i < num_input_datasets_; ++i) {
+ input_names->emplace_back(absl::StrCat(
+ DirectedInterleaveDatasetOp::kDataInputDatasets, "_", i));
+ }
+ return Status::OK();
+ }
+
+ Status GetAttributes(AttributeVector* attr_vector) const override {
+ attr_vector->clear();
+ attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputTypes,
+ output_dtypes_);
+ attr_vector->emplace_back(DirectedInterleaveDatasetOp::kOutputShapes,
+ output_shapes_);
+ attr_vector->emplace_back(DirectedInterleaveDatasetOp::kNumInputDatasets,
+ num_input_datasets_);
+ return Status::OK();
+ }
+
+ string dataset_type() const override {
+ return DirectedInterleaveDatasetOp::kDatasetType;
+ }
+
+ private:
+ int32 num_input_datasets_;
+};
+
+class DirectedInterleaveDatasetOpTest : public DatasetOpsTestBase {};
+
+DirectedInterleaveDatasetParams AlternateInputsParams() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
+ RangeDatasetParams(10, 13, 1)},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/2,
+ /*node_name=*/kNodeName);
+}
+
+DirectedInterleaveDatasetParams SelectExhaustedInputParams() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 2, 1),
+ RangeDatasetParams(10, 13, 1)},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/2,
+ /*node_name=*/kNodeName);
+}
+
+DirectedInterleaveDatasetParams OneInputDatasetParams() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 6, 1)},
+ /*output_dtypes=*/{DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({})},
+ /*num_input_datasets=*/1,
+ /*node_name=*/kNodeName);
+}
+
+DirectedInterleaveDatasetParams ZeroInputDatasetParams() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 0, 0, 0, 0, 0})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/std::vector<RangeDatasetParams>{},
+ /*output_dtypes=*/{DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({})},
+ /*num_input_datasets=*/0,
+ /*node_name=*/kNodeName);
+}
+
+// Test case: `num_input_datasets` is larger than the size of
+// `input_dataset_params_vec`.
+DirectedInterleaveDatasetParams LargeNumInputDatasetsParams() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
+ RangeDatasetParams(10, 13, 1)},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/5,
+ /*node_name=*/kNodeName);
+}
+
+// Test case: `num_input_datasets` is smaller than the size of
+// `input_dataset_params_vec`.
+DirectedInterleaveDatasetParams SmallNumInputDatasetsParams() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
+ RangeDatasetParams(10, 13, 1)},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/1,
+ /*node_name=*/kNodeName);
+}
+
+DirectedInterleaveDatasetParams InvalidSelectorOuputDataType() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int32>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
+ RangeDatasetParams(10, 13, 1)},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/2,
+ /*node_name=*/kNodeName);
+}
+
+DirectedInterleaveDatasetParams InvalidSelectorOuputShape() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6, 1},
+ {0, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
+ RangeDatasetParams(10, 13, 1)},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/2,
+ /*node_name=*/kNodeName);
+}
+
+DirectedInterleaveDatasetParams InvalidSelectorValues() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {2, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{RangeDatasetParams(0, 3, 1),
+ RangeDatasetParams(10, 13, 1)},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/2,
+ /*node_name=*/kNodeName);
+}
+
+DirectedInterleaveDatasetParams InvalidInputDatasetsDataType() {
+ auto selector_input_dataset_params = TensorSliceDatasetParams(
+ /*components=*/{CreateTensor<int64>(TensorShape{6}, {0, 1, 0, 1, 0, 1})},
+ /*node_name=*/"tensor_slice");
+ return DirectedInterleaveDatasetParams(
+ selector_input_dataset_params,
+ /*input_dataset_params_vec=*/
+ std::vector<RangeDatasetParams>{
+ RangeDatasetParams(0, 3, 1, {DT_INT32}),
+ RangeDatasetParams(10, 13, 1, {DT_INT64})},
+ /*output_dtypes=*/{DT_INT64, DT_INT64},
+ /*output_shapes=*/{PartialTensorShape({}), PartialTensorShape({})},
+ /*num_input_datasets=*/2,
+ /*node_name=*/kNodeName);
+}
+
+std::vector<GetNextTestCase<DirectedInterleaveDatasetParams>>
+GetNextTestCases() {
+ return {{/*dataset_params=*/AlternateInputsParams(),
+ /*expected_outputs=*/{CreateTensors<int64>(
+ TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
+ {/*dataset_params=*/SelectExhaustedInputParams(),
+ /*expected_outputs=*/{CreateTensors<int64>(
+ TensorShape({}), {{0}, {10}, {1}, {11}, {12}})}},
+ {/*dataset_params=*/OneInputDatasetParams(),
+ /*expected_outputs=*/{CreateTensors<int64>(
+ TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
+ {/*dataset_params=*/LargeNumInputDatasetsParams(),
+ /*expected_outputs=*/{CreateTensors<int64>(
+ TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}},
+ {/*dataset_params=*/SmallNumInputDatasetsParams(),
+ /*expected_outputs=*/{CreateTensors<int64>(
+ TensorShape({}), {{0}, {10}, {1}, {11}, {2}, {12}})}}};
+}
+
+ITERATOR_GET_NEXT_TEST_P(DirectedInterleaveDatasetOpTest,
+ DirectedInterleaveDatasetParams, GetNextTestCases())
+
+TEST_F(DirectedInterleaveDatasetOpTest, DatasetNodeName) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckDatasetNodeName(dataset_params.node_name()));
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, DatasetTypeString) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckDatasetTypeString(
+ name_utils::OpName(DirectedInterleaveDatasetOp::kDatasetType)));
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputDtypes) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckDatasetOutputDtypes({DT_INT64}));
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, DatasetOutputShapes) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckDatasetOutputShapes({PartialTensorShape({})}));
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, Cardinality) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckDatasetCardinality(kUnknownCardinality));
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputDtypes) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckIteratorOutputDtypes({DT_INT64}));
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, IteratorOutputShapes) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckIteratorOutputShapes({PartialTensorShape({})}));
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, IteratorPrefix) {
+ auto dataset_params = AlternateInputsParams();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ TF_ASSERT_OK(CheckIteratorPrefix(
+ name_utils::IteratorPrefix(DirectedInterleaveDatasetOp::kDatasetType,
+ dataset_params.iterator_prefix())));
+}
+
+std::vector<IteratorSaveAndRestoreTestCase<DirectedInterleaveDatasetParams>>
+IteratorSaveAndRestoreTestCases() {
+ return {
+ {/*dataset_params=*/AlternateInputsParams(),
+ /*breakpoints=*/{0, 5, 8},
+ /*expected_outputs=*/
+ CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {2}, {12}}),
+ /*compare_order=*/true},
+ {/*dataset_params=*/SelectExhaustedInputParams(),
+ /*breakpoints=*/{0, 4, 8},
+ /*expected_outputs=*/
+ CreateTensors<int64>(TensorShape{}, {{0}, {10}, {1}, {11}, {12}}),
+ /*compare_order=*/true},
+ {/*dataset_params=*/OneInputDatasetParams(),
+ /*breakpoints=*/{0, 5, 8},
+ /*expected_outputs=*/
+ {CreateTensors<int64>(TensorShape({}), {{0}, {1}, {2}, {3}, {4}, {5}})}},
+ {/*dataset_params=*/LargeNumInputDatasetsParams(),
+ /*breakpoints=*/{0, 5, 8},
+ /*expected_outputs=*/
+ {CreateTensors<int64>(TensorShape({}),
+ {{0}, {10}, {1}, {11}, {2}, {12}})}},
+ {/*dataset_params=*/SmallNumInputDatasetsParams(),
+ /*breakpoints=*/{0, 5, 8},
+ /*expected_outputs=*/
+ {CreateTensors<int64>(TensorShape({}),
+ {{0}, {10}, {1}, {11}, {2}, {12}})}}};
+}
+
+ITERATOR_SAVE_AND_RESTORE_TEST_P(DirectedInterleaveDatasetOpTest,
+ DirectedInterleaveDatasetParams,
+ IteratorSaveAndRestoreTestCases())
+
+TEST_F(DirectedInterleaveDatasetOpTest, InvalidArguments) {
+ std::vector<DirectedInterleaveDatasetParams> invalid_params_vec = {
+ InvalidSelectorOuputDataType(), InvalidSelectorOuputShape(),
+ InvalidInputDatasetsDataType(), ZeroInputDatasetParams()};
+ for (auto& dataset_params : invalid_params_vec) {
+ EXPECT_EQ(Initialize(dataset_params).code(),
+ tensorflow::error::INVALID_ARGUMENT);
+ }
+}
+
+TEST_F(DirectedInterleaveDatasetOpTest, InvalidSelectorValues) {
+ auto dataset_params = InvalidSelectorValues();
+ TF_ASSERT_OK(Initialize(dataset_params));
+ bool end_of_sequence = false;
+ std::vector<Tensor> next;
+ EXPECT_EQ(
+ iterator_->GetNext(iterator_ctx_.get(), &next, &end_of_sequence).code(),
+ tensorflow::error::INVALID_ARGUMENT);
+}
+
+} // namespace
+} // namespace experimental
+} // namespace data
+} // namespace tensorflow