# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.Dataset`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import warnings

from absl.testing import parameterized
import numpy as np

from tensorflow.core.framework import graph_pb2
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import optional_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import structure
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging


@test_util.run_all_in_graph_and_eager_modes
class DatasetTest(test_base.DatasetTestBase, parameterized.TestCase):

  def testAsSerializedGraph(self):
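    """Checks that a dataset serializes to a GraphDef containing its ops."""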
dataset = dataset_ops.Dataset.range(10)
graph = graph_pb2.GraphDef().FromString(
self.evaluate(dataset._as_serialized_graph()))
    self.assertTrue(any(node.op == "RangeDataset" for node in graph.node))

  def testAsFunctionWithMap(self):
if not context.executing_eagerly():
self.skipTest("Only works executing eagerly")
with ops.device("CPU"):
original_dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
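      # `_trace_variant_creation` (private API) traces a function that
      # returns the variant tensor representing the dataset.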
fn = original_dataset._trace_variant_creation()
variant = fn()
revived_dataset = dataset_ops._VariantDataset(
variant, original_dataset.element_spec)
    self.assertDatasetProduces(revived_dataset, range(0, 10, 2))

  def testAsFunctionWithMapInFlatMap(self):
if not context.executing_eagerly():
self.skipTest("Only works executing eagerly")
with ops.device("CPU"):
original_dataset = dataset_ops.Dataset.range(5).flat_map(
lambda x: dataset_ops.Dataset.range(5).map(lambda x: x * 2))
fn = original_dataset._trace_variant_creation()
variant = fn()
revived_dataset = dataset_ops._VariantDataset(
variant, original_dataset.element_spec)
    self.assertDatasetProduces(revived_dataset, list(original_dataset))

  @staticmethod
def make_apply_fn(dataset):
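    """Returns a transformation that applies `Dataset.cache` via `apply`."""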
def apply_fn(dataset):
def _apply_fn(dataset):
return dataset.cache()
return dataset.apply(_apply_fn)
    return apply_fn

  @staticmethod
def make_gen():
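    """Returns a generator function that yields a single element."""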
def gen():
yield 42
    return gen

  @staticmethod
def make_interleave_fn(dataset, num_parallel_calls=None):
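    """Returns a transformation that interleaves into empty datasets."""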
def interleave_fn(dataset):
return dataset.interleave(
lambda x: dataset_ops.Dataset.range(0),
cycle_length=2,
num_parallel_calls=num_parallel_calls)
    return interleave_fn

  @parameterized.named_parameters(
("FixedLengthRecord",
lambda: readers.FixedLengthRecordDataset("", 42)),
("FromGenerator",
lambda: dataset_ops.Dataset.from_generator(
DatasetTest.make_gen(), dtypes.int32),
1),
("FromTensors", lambda: dataset_ops.Dataset.from_tensors([42])),
("FromTensorSlices", lambda: dataset_ops.Dataset.from_tensors([42])),
("Range", lambda: dataset_ops.Dataset.range(10)),
("TextLine", lambda: readers.TextLineDataset("")),
("TFRecord", lambda: readers.TFRecordDataset(""), 1),
)
def testDatasetSimpleSourceInputs(self, dataset_fn, num_inputs=0):
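    """Checks that source datasets have the expected number of inputs."""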
    self.assertLen(dataset_fn()._inputs(), num_inputs)

  @test_util.run_v1_only("deprecated API, no eager or V2 test coverage")
def testDatasetComplexSourceInputs(self):
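    """Checks that `from_sparse_tensor_slices` reports no inputs."""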
dataset_fn = dataset_ops.Dataset.from_sparse_tensor_slices(
sparse_tensor.SparseTensor(
indices=np.array([[0, 0], [1, 0], [2, 0]]),
values=np.array([0, 0, 0]),
dense_shape=np.array([3, 1])))
    self.assertEmpty(dataset_fn._inputs())

  @parameterized.named_parameters(
("Batch",
lambda x: x.batch(10),
lambda: dataset_ops.Dataset.range(0)),
("Cache",
lambda x: x.cache(),
lambda: dataset_ops.Dataset.range(0)),
("Filter",
lambda x: x.filter(lambda x: True),
lambda: dataset_ops.Dataset.range(0)),
("FlatMap",
lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
lambda: dataset_ops.Dataset.range(0)),
("Map",
lambda x: x.map(lambda x: x),
lambda: dataset_ops.Dataset.range(0)),
("PaddedBatch",
lambda x: x.padded_batch(10, []),
lambda: dataset_ops.Dataset.range(0)),
("ParallelMap",
lambda x: x.map(lambda x: x, num_parallel_calls=2),
lambda: dataset_ops.Dataset.range(0)),
("Repeat",
lambda x: x.repeat(),
lambda: dataset_ops.Dataset.range(0)),
("Shuffle",
lambda x: x.shuffle(10),
lambda: dataset_ops.Dataset.range(0)),
("Skip",
lambda x: x.skip(1),
lambda: dataset_ops.Dataset.range(0)),
("Take",
lambda x: x.take(1),
lambda: dataset_ops.Dataset.range(0)),
("Window",
lambda x: x.window(10),
lambda: dataset_ops.Dataset.range(0)),
)
def testUnaryTransformationInputs(self, dataset_fn, input_dataset_fn):
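    """Checks that a unary transformation reports its input dataset."""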
input_dataset = input_dataset_fn()
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testUnaryTransformationInputsApply(self):
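    """Checks that `apply` reports its input dataset."""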
input_dataset = dataset_ops.Dataset.range(0)
dataset_fn = self.make_apply_fn(dataset_ops.Dataset.range(0))
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  @parameterized.named_parameters(
("ParallelInterleave",
[lambda: dataset_ops.Dataset.range(0), 2],
lambda: dataset_ops.Dataset.range(0)),
("Interleave",
[lambda: dataset_ops.Dataset.range(0), None],
lambda: dataset_ops.Dataset.range(0)),
)
def testUnaryTransformationInputsWithInterleaveFn(
self, interleave_fn_args, input_dataset_fn):
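    """Checks that `interleave` reports its input dataset."""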
input_dataset = input_dataset_fn()
dataset_fn = self.make_interleave_fn(*interleave_fn_args)
    self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())

  def testNoWarnings(self):
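    """Checks that building an interleave pipeline emits no warnings."""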
with test.mock.patch.object(warnings, "warn") as mock_log:
dataset_fn = self.make_interleave_fn(dataset_ops.Dataset.range(10))
dataset_fn(dataset_ops.Dataset.range(10))
    self.assertEmpty(mock_log.call_args_list)

  @parameterized.named_parameters(
("Concatenate", lambda x, y: x.concatenate(y),
lambda: dataset_ops.Dataset.range(0),
lambda: dataset_ops.Dataset.range(1)))
def testBinaryTransformationInputs(self, dataset_fn, input1_fn, input2_fn):
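    """Checks that a binary transformation reports both input datasets."""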
input1 = input1_fn()
input2 = input2_fn()
    self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())

  @parameterized.named_parameters(
("ZipOne",
dataset_ops.Dataset.zip,
lambda: (dataset_ops.Dataset.range(0))),
("ZipNest",
dataset_ops.Dataset.zip,
lambda: (dataset_ops.Dataset.range(0),
(dataset_ops.Dataset.range(1),
dataset_ops.Dataset.range(2)))),
("ZipTuple",
dataset_ops.Dataset.zip,
lambda: (dataset_ops.Dataset.range(0),
dataset_ops.Dataset.range(1))),
)
def testVariadicTransformationInputs(self, dataset_fn, input_datasets_fn):
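    """Checks that a variadic transformation reports all input datasets."""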
input_datasets = input_datasets_fn()
self.assertEqual(
nest.flatten(input_datasets),
        dataset_fn(input_datasets)._inputs())

  def testFunctions(self):
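    """Checks that a mapped dataset reports exactly one function."""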
dataset = dataset_ops.Dataset.range(5).map(lambda x: x * 2)
    self.assertLen(dataset._functions(), 1)

  def testCollectInputs(self):
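    """Counts how often each dataset appears among transitive inputs."""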
ds1 = dataset_ops.Dataset.range(0)
ds2 = ds1.concatenate(ds1)
ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
inputs = []
queue = [ds3]
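    # Breadth-first walk over `_inputs()`, recording every dataset visited.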
while queue:
ds = queue[0]
queue = queue[1:]
queue.extend(ds._inputs())
inputs.append(ds)
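    # ds1 is an input of the zip once and of each of the two ds2 inputs
    # twice, so it is visited five times in total.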
self.assertEqual(5, inputs.count(ds1))
self.assertEqual(2, inputs.count(ds2))
    self.assertEqual(1, inputs.count(ds3))

  # pylint: disable=g-long-lambda
@parameterized.named_parameters(
("Tensor", lambda: constant_op.constant(37.0),
tensor_spec.TensorSpec([], dtypes.float32)),
("SparseTensor", lambda: sparse_tensor.SparseTensor(
indices=[[0]],
values=constant_op.constant([0], dtype=dtypes.int32),
dense_shape=[1]), sparse_tensor.SparseTensorSpec([1], dtypes.int32)),
("Nest", lambda: {
"a": constant_op.constant(37.0),
"b": (constant_op.constant(["Foo"]), constant_op.constant("Bar"))
}, {
"a":
tensor_spec.TensorSpec([], dtypes.float32),
"b": (
tensor_spec.TensorSpec([1], dtypes.string),
tensor_spec.TensorSpec([], dtypes.string),
)
}),
("Dataset", lambda: dataset_ops.Dataset.from_tensor_slices(
constant_op.constant([1, 2, 3])),
dataset_ops.DatasetSpec(tensor_spec.TensorSpec([], dtypes.int32))),
("Optional", lambda: optional_ops.Optional.from_value(37.0),
optional_ops.OptionalSpec(
tensor_spec.TensorSpec([], dtypes.float32))),
)
def testDatasetSpec(self, tf_value_fn, expected_element_structure):
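    """Checks `DatasetSpec` properties and the tensor-list round trip."""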
dataset = dataset_ops.Dataset.from_tensors(0).map(lambda _: tf_value_fn())
dataset_structure = structure.type_spec_from_value(dataset)
self.assertIsInstance(dataset_structure, dataset_ops.DatasetSpec)
self.assertTrue(
structure.are_compatible(
dataset_ops.get_structure(dataset), expected_element_structure))
self.assertEqual([dtypes.variant],
structure.get_flat_tensor_types(dataset_structure))
self.assertEqual([tensor_shape.TensorShape([])],
structure.get_flat_tensor_shapes(dataset_structure))
# Assert that the `Dataset` survives a round-trip via _from_tensor_list()
# and _to_tensor_list().
round_trip_dataset = dataset_structure._from_tensor_list(
dataset_structure._to_tensor_list(dataset))
value = tf_value_fn()
if isinstance(value, dataset_ops.Dataset):
      self.assertDatasetsEqual(value, round_trip_dataset.flat_map(lambda x: x))
elif isinstance(value, optional_ops.Optional):
self.assertDatasetProduces(
round_trip_dataset.map(lambda opt: opt.get_value()),
[self.evaluate(value.get_value())],
requires_initialization=True)
else:
self.assertDatasetProduces(
round_trip_dataset, [self.evaluate(tf_value_fn())],
          requires_initialization=True)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
def testSkipEagerSameGraphErrorOneShot(self):
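    """Checks that using a dataset in a different graph raises an error."""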
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
def testSkipEagerSameGraphErrorOneShotSimple(self):
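    """Checks that a cross-graph one-shot iterator logs a warning."""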
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with test.mock.patch.object(logging, "warning") as mock_log:
_ = dataset_ops.make_one_shot_iterator(dataset)
self.assertRegexpMatches(
str(mock_log.call_args), "Please ensure that all datasets in the "
"pipeline are created in the same graph as the iterator.")
@test_util.run_v1_only("graph mode specific, no eager or V2 test coverage")
def testSkipEagerSameGraphErrorInitializable(self):
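    """Checks the cross-graph error for the initializable iterator path."""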
dataset = dataset_ops.Dataset.range(10)
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError, "must be from the same graph"):
        dataset = dataset.batch(2)

  @parameterized.named_parameters(
("Async", context.ASYNC),
("Sync", context.SYNC),
)
def testDatasetEagerIteration(self, execution_mode):
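    """Iterates over a dataset eagerly in sync and async modes."""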
with context.eager_mode(), context.execution_mode(execution_mode):
val = 0
dataset = dataset_ops.Dataset.range(10)
for foo in dataset:
self.assertEqual(val, foo.numpy())
        val += 1

  def testDatasetAsFunctionArgument(self):
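    """Checks that datasets pass into a `tf.function` without capture."""
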
@def_function.function
def _uses_dataset(d):
accumulator = array_ops.zeros([], dtype=dtypes.int64)
for value in d:
accumulator += value
      return accumulator

    with ops.device("CPU"):
first_dataset = dataset_ops.Dataset.range(10)
self.assertEqual(45, self.evaluate(_uses_dataset(first_dataset)))
second_dataset = dataset_ops.Dataset.range(11)
self.assertEqual(55, self.evaluate(_uses_dataset(second_dataset)))
first_concrete = _uses_dataset.get_concrete_function(first_dataset)
# The dataset should not be a captured input
self.assertEmpty(first_concrete.graph.captures)
# The two datasets have the same structure and so should re-use a trace.
self.assertIs(first_concrete,
_uses_dataset.get_concrete_function(second_dataset))
# With a different structure we should use a different trace.
self.assertIsNot(
first_concrete,
_uses_dataset.get_concrete_function(
            dataset_ops.Dataset.zip((first_dataset, second_dataset))))

  def testLimitedRetracing(self):
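    """Checks that datasets with matching specs reuse one function trace."""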
    trace_count = [0]

    @def_function.function
def f(ds):
trace_count[0] += 1
counter = np.int64(0)
for elem in ds:
counter += elem
      return counter

    dataset = dataset_ops.Dataset.range(5)
dataset2 = dataset_ops.Dataset.range(10)
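    # Both datasets have the same `DatasetSpec`, so `f` is traced only once.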
for _ in range(10):
self.assertEqual(self.evaluate(f(dataset)), 10)
self.assertEqual(self.evaluate(f(dataset2)), 45)
    self.assertEqual(trace_count[0], 1)


if __name__ == "__main__":
test.main()