# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for the experimental input pipeline ops."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools

from absl.testing import parameterized
import numpy as np

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test


class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):

  def _interleave(self, lists, cycle_length, block_length):
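    """Pure-Python reference implementation of the interleave transformation.

    Args:
      lists: a list of lists of input elements.
      cycle_length: the number of input lists that are interleaved at once.
      block_length: the number of consecutive elements taken from an open
        input list before moving on to the next one.

    Yields:
      The elements of `lists`, in the order that the tests below expect
      `Dataset.interleave()` to produce them.
    """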
    num_open = 0

    # `all_iterators` acts as a queue of iterators over each element of
    # `lists`.
    all_iterators = [iter(l) for l in lists]

    # `open_iterators` are the iterators whose elements are currently being
    # interleaved.
    open_iterators = []
    for i in range(cycle_length):
      if all_iterators:
        open_iterators.append(all_iterators.pop(0))
        num_open += 1
      else:
        open_iterators.append(None)

    while num_open or all_iterators:
      for i in range(cycle_length):
        if open_iterators[i] is None:
          if all_iterators:
            open_iterators[i] = all_iterators.pop(0)
            num_open += 1
          else:
            continue
        for _ in range(block_length):
          try:
            yield next(open_iterators[i])
          except StopIteration:
            open_iterators[i] = None
            num_open -= 1
            break

  def testPythonImplementation(self):
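    """Checks the `_interleave()` reference against hand-computed sequences."""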
    input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
                   [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]

    # Cycle length 1 acts like `Dataset.flat_map()`.
    expected_elements = itertools.chain(*input_lists)
    for expected, produced in zip(
        expected_elements, self._interleave(input_lists, 1, 1)):
      self.assertEqual(expected, produced)

    # Cycle length > 1.
    # NOTE(mrry): When we cycle back to a list and are already at the end of
    # that list, we move on to the next element.
    expected_elements = [
        4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6,
        4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5
    ]
    for expected, produced in zip(
        expected_elements, self._interleave(input_lists, 2, 1)):
      self.assertEqual(expected, produced)

    # Cycle length > 1 and block length > 1.
    expected_elements = [
        4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6,
        4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6
    ]
    for expected, produced in zip(
        expected_elements, self._interleave(input_lists, 2, 3)):
      self.assertEqual(expected, produced)

    # Cycle length > len(input_lists).
    expected_elements = [
        4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6,
        4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6
    ]
    for expected, produced in zip(
        expected_elements, self._interleave(input_lists, 7, 2)):
      self.assertEqual(expected, produced)

  @parameterized.named_parameters(
      ("1", np.int64([4, 5, 6]), 1, 3, None),
      ("2", np.int64([4, 5, 6]), 1, 3, 1),
      ("3", np.int64([4, 5, 6]), 2, 1, None),
      ("4", np.int64([4, 5, 6]), 2, 1, 1),
      ("5", np.int64([4, 5, 6]), 2, 1, 2),
      ("6", np.int64([4, 5, 6]), 2, 3, None),
      ("7", np.int64([4, 5, 6]), 2, 3, 1),
      ("8", np.int64([4, 5, 6]), 2, 3, 2),
      ("9", np.int64([4, 5, 6]), 7, 2, None),
      ("10", np.int64([4, 5, 6]), 7, 2, 1),
      ("11", np.int64([4, 5, 6]), 7, 2, 3),
      ("12", np.int64([4, 5, 6]), 7, 2, 5),
      ("13", np.int64([4, 5, 6]), 7, 2, 7),
      ("14", np.int64([]), 2, 3, None),
      ("15", np.int64([0, 0, 0]), 2, 3, None),
      ("16", np.int64([4, 0, 6]), 2, 3, None),
      ("17", np.int64([4, 0, 6]), 2, 3, 1),
      ("18", np.int64([4, 0, 6]), 2, 3, 2),
  )
  def testInterleaveDataset(self, input_values, cycle_length, block_length,
                            num_parallel_calls):
    count = 2
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
        count).interleave(
            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
            cycle_length, block_length, num_parallel_calls)
    get_next = dataset.make_one_shot_iterator().get_next()

    def repeat(values, count):
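      """Mirrors the map function above: each value `x` becomes `[x] * x`.

      The resulting list of lists is repeated `count` times so that it lines
      up with the `.repeat(count)` applied to the input dataset.
      """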
      result = []
      for value in values:
        result.append([value] * value)
      return result * count

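    # The expected element order is given by the pure-Python reference
    # implementation `_interleave()` defined above.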
    with self.cached_session() as sess:
      for expected_element in self._interleave(
          repeat(input_values, count), cycle_length, block_length):
        self.assertEqual(expected_element, sess.run(get_next))
      for _ in range(2):
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @parameterized.named_parameters(
      ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
      ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1),
      ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None),
      ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1),
      ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2),
      ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None),
      ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1),
      ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2),
      ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None),
      ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1),
      ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3),
      ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5),
      ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7),
  )
  def testInterleaveErrorDataset(self,
                                 input_values,
                                 cycle_length,
                                 block_length,
                                 num_parallel_calls):
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
        lambda x: array_ops.check_numerics(x, "message")).interleave(
            dataset_ops.Dataset.from_tensors, cycle_length, block_length,
            num_parallel_calls)
    get_next = dataset.make_one_shot_iterator().get_next()

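    # Each NaN input fails `check_numerics()`, so fetching the corresponding
    # element should raise `InvalidArgumentError`; valid values pass through
    # unchanged.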
    with self.cached_session() as sess:
      for value in input_values:
        if np.isnan(value):
          with self.assertRaises(errors.InvalidArgumentError):
            sess.run(get_next)
        else:
          self.assertEqual(value, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSparse(self):

    def _map_fn(i):
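      # Builds a 2x2 sparse tensor with `i` at position (0, 0) and `-i` at
      # (1, 1).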
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])

    def _interleave_fn(x):
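      # Densifies the 2x2 sparse tensor and emits its two rows as separate
      # elements.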
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    iterator = (
        dataset_ops.Dataset.range(10).map(_map_fn).interleave(
            _interleave_fn, cycle_length=1).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      for i in range(10):
        for j in range(2):
          expected = [i, 0] if j % 2 == 0 else [0, -i]
          self.assertAllEqual(expected, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)


if __name__ == "__main__":
  test.main()