| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
| # |
| # Licensed under the Apache License, Version 2.0 (the "License"); |
| # you may not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================== |
| """Tests for the experimental input pipeline ops.""" |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import itertools |
| |
| from absl.testing import parameterized |
| import numpy as np |
| |
| from tensorflow.python.data.ops import dataset_ops |
| from tensorflow.python.framework import errors |
| from tensorflow.python.framework import sparse_tensor |
| from tensorflow.python.ops import array_ops |
| from tensorflow.python.ops import sparse_ops |
| from tensorflow.python.platform import test |
| |
| |
class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
  """Tests for `Dataset.interleave()`, checked against a pure-Python model."""

  def _interleave(self, lists, cycle_length, block_length):
    """Yields the elements of `lists` in the expected interleaved order.

    This is the pure-Python reference model that the dataset tests below
    compare against. Up to `cycle_length` inner lists are open at once; the
    open slots are visited round-robin and each visit yields at most
    `block_length` consecutive elements from that slot. When a slot's list is
    exhausted, the slot is refilled with the next unopened list the next time
    it is visited.

    Args:
      lists: A list of lists; each inner list is one input to interleave.
      cycle_length: Number of slots that are cycled through. Must be at least
        1, otherwise the loop below never consumes a non-empty `lists`.
      block_length: Maximum number of consecutive elements taken from a slot
        per visit. Must be at least 1 for the same reason.

    Yields:
      The elements of the inner lists, in interleaved order.
    """
    num_open = 0

    # `all_iterators` acts as a queue of iterators over each element of `lists`.
    all_iterators = [iter(l) for l in lists]

    # `open_iterators` are the iterators whose elements are currently being
    # interleaved. Slots with no iterator assigned hold `None`.
    open_iterators = []
    for _ in range(cycle_length):
      if all_iterators:
        open_iterators.append(all_iterators.pop(0))
        num_open += 1
      else:
        open_iterators.append(None)

    while num_open or all_iterators:
      for i in range(cycle_length):
        if open_iterators[i] is None:
          if all_iterators:
            # Refill the empty slot with the next unopened input.
            open_iterators[i] = all_iterators.pop(0)
            num_open += 1
          else:
            continue
        for _ in range(block_length):
          try:
            yield next(open_iterators[i])
          except StopIteration:
            # The slot is only marked empty here; it is refilled on the next
            # visit, so a visit that hits the end may yield a short block.
            open_iterators[i] = None
            num_open -= 1
            break

  def testPythonImplementation(self):
    """Checks the `_interleave()` reference model on hand-computed cases."""
    input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
                   [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]

    # Each case compares fully materialized lists rather than zipping the two
    # sequences: `zip()` stops at the shorter input and would silently accept
    # a missing or surplus tail.

    # Cycle length 1 acts like `Dataset.flat_map()`.
    expected_elements = list(itertools.chain(*input_lists))
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 1, 1)))

    # Cycle length > 1.
    expected_elements = [4, 5, 4, 5, 4, 5, 4,
                         5, 5, 6, 6,  # NOTE(mrry): When we cycle back
                                      # to a list and are already at
                                      # the end of that list, we move
                                      # on to the next element.
                         4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5,
                         6, 6]
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 2, 1)))

    # Cycle length > 1 and block length > 1.
    expected_elements = [4, 4, 4, 5, 5, 5, 4, 5, 5, 6, 6, 6, 4, 4, 4, 6, 6, 6,
                         4, 5, 5, 5, 6, 6, 6, 5, 5, 6, 6, 6]
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 2, 3)))

    # Cycle length > len(input_values).
    expected_elements = [4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6, 4, 4, 5, 5, 6, 6,
                         4, 4, 5, 5, 6, 6, 5, 6, 6, 5, 6, 6]
    self.assertEqual(expected_elements,
                     list(self._interleave(input_lists, 7, 2)))

  @parameterized.named_parameters(
      ("1", np.int64([4, 5, 6]), 1, 3, None),
      ("2", np.int64([4, 5, 6]), 1, 3, 1),
      ("3", np.int64([4, 5, 6]), 2, 1, None),
      ("4", np.int64([4, 5, 6]), 2, 1, 1),
      ("5", np.int64([4, 5, 6]), 2, 1, 2),
      ("6", np.int64([4, 5, 6]), 2, 3, None),
      ("7", np.int64([4, 5, 6]), 2, 3, 1),
      ("8", np.int64([4, 5, 6]), 2, 3, 2),
      ("9", np.int64([4, 5, 6]), 7, 2, None),
      ("10", np.int64([4, 5, 6]), 7, 2, 1),
      ("11", np.int64([4, 5, 6]), 7, 2, 3),
      ("12", np.int64([4, 5, 6]), 7, 2, 5),
      ("13", np.int64([4, 5, 6]), 7, 2, 7),
      ("14", np.int64([]), 2, 3, None),
      ("15", np.int64([0, 0, 0]), 2, 3, None),
      ("16", np.int64([4, 0, 6]), 2, 3, None),
      ("17", np.int64([4, 0, 6]), 2, 3, 1),
      ("18", np.int64([4, 0, 6]), 2, 3, 2),
  )
  def testInterleaveDataset(self, input_values, cycle_length, block_length,
                            num_parallel_calls):
    """Compares `Dataset.interleave()` output with the `_interleave()` model.

    Every input value `x` is mapped to a dataset holding `x` copies of `x`,
    and the input sequence is repeated `count` times before interleaving.

    Args:
      input_values: 1-D `np.int64` array of values to slice into elements.
      cycle_length: `cycle_length` argument passed to `interleave()`.
      block_length: `block_length` argument passed to `interleave()`.
      num_parallel_calls: `num_parallel_calls` argument passed to
        `interleave()`; may be `None`.
    """
    count = 2
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
        count).interleave(
            lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
            cycle_length, block_length, num_parallel_calls)
    get_next = dataset.make_one_shot_iterator().get_next()

    def repeat(values, count):
      """Builds the per-element lists fed to the reference model."""
      return [[value] * value for value in values] * count

    with self.cached_session() as sess:
      for expected_element in self._interleave(
          repeat(input_values, count), cycle_length, block_length):
        self.assertEqual(expected_element, sess.run(get_next))

      # Once exhausted, the iterator must keep reporting end-of-sequence.
      for _ in range(2):
        with self.assertRaises(errors.OutOfRangeError):
          sess.run(get_next)

  @parameterized.named_parameters(
      ("1", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, None),
      ("2", np.float32([1., np.nan, 2., np.nan, 3.]), 1, 3, 1),
      ("3", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, None),
      ("4", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 1),
      ("5", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 1, 2),
      ("6", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, None),
      ("7", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 1),
      ("8", np.float32([1., np.nan, 2., np.nan, 3.]), 2, 3, 2),
      ("9", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, None),
      ("10", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 1),
      ("11", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 3),
      ("12", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 5),
      ("13", np.float32([1., np.nan, 2., np.nan, 3.]), 7, 2, 7),
  )
  def testInterleaveErrorDataset(self,
                                 input_values,
                                 cycle_length,
                                 block_length,
                                 num_parallel_calls):
    """Checks that errors from the input dataset surface per element.

    The input is passed through `check_numerics` before interleaving, so the
    test expects an `InvalidArgumentError` for each NaN input and the original
    value, in input order, for every other element.

    Args:
      input_values: 1-D `np.float32` array that may contain NaNs.
      cycle_length: `cycle_length` argument passed to `interleave()`.
      block_length: `block_length` argument passed to `interleave()`.
      num_parallel_calls: `num_parallel_calls` argument passed to
        `interleave()`; may be `None`.
    """
    dataset = dataset_ops.Dataset.from_tensor_slices(input_values).map(
        lambda x: array_ops.check_numerics(x, "message")).interleave(
            dataset_ops.Dataset.from_tensors, cycle_length, block_length,
            num_parallel_calls)
    get_next = dataset.make_one_shot_iterator().get_next()

    with self.cached_session() as sess:
      for value in input_values:
        if np.isnan(value):
          with self.assertRaises(errors.InvalidArgumentError):
            sess.run(get_next)
        else:
          self.assertEqual(value, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSparse(self):
    """Checks that `interleave()` accepts sparse elements from its input."""

    def _map_fn(i):
      # 2x2 sparse value with `i` at [0, 0] and `-i` at [1, 1].
      return sparse_tensor.SparseTensorValue(
          indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])

    def _interleave_fn(x):
      # Densify the sparse element and emit its rows one at a time.
      return dataset_ops.Dataset.from_tensor_slices(
          sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))

    iterator = (
        dataset_ops.Dataset.range(10).map(_map_fn).interleave(
            _interleave_fn, cycle_length=1).make_initializable_iterator())
    init_op = iterator.initializer
    get_next = iterator.get_next()

    with self.cached_session() as sess:
      sess.run(init_op)
      for i in range(10):
        for j in range(2):
          # Row 0 of the dense matrix is [i, 0]; row 1 is [0, -i].
          expected = [i, 0] if j % 2 == 0 else [0, -i]
          self.assertAllEqual(expected, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)
| |
| |
if __name__ == "__main__":
  # Run this module's test cases through the TensorFlow test runner when the
  # file is executed directly as a script.
  test.main()