Refactor ParallelInterleaveTest to be parameterized
diff --git a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
index 14d3c9d..6638c0e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
@@ -17,7 +17,6 @@
from __future__ import division
from __future__ import print_function
-import itertools
import math
import threading
import time
@@ -39,7 +38,6 @@
from tensorflow.python.platform import test
-# TODO(feihugis): refactor this test to be parameterized.
class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
def setUp(self):
@@ -117,49 +115,36 @@
num_open -= 1
break
- @combinations.generate(test_base.default_test_combinations())
- def testPythonImplementation(self):
- input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
- [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
-
- # Cycle length 1 acts like `Dataset.flat_map()`.
- expected_elements = itertools.chain(*input_lists)
- for expected, produced in zip(expected_elements,
- self._interleave(input_lists, 1, 1)):
- self.assertEqual(expected, produced)
-
- # Cycle length > 1.
- expected_elements = [
- 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
- 6, 5, 6, 5, 6, 6
- ]
+ @combinations.generate(
+ combinations.times(
+ combinations.combine(
+ input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+ [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]],
+ expected_elements=[[4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
+ 4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6]],
+ cycle_length=1, block_length=1) +
+ combinations.combine(
+ input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+ [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]],
+ expected_elements=[[4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6,
+ 4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5, 6, 6]],
+ cycle_length=2, block_length=1) +
+ combinations.combine(
+ input_lists=[[[4] * 4, [5] * 5, [6] * 6] * 2],
+ expected_elements=[[4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6,
+ 4, 4, 6, 6, 5, 5, 6, 6, 5, 5, 6, 6, 5, 6, 6]],
+ cycle_length=2, block_length=2) +
+ combinations.combine(
+ input_lists=[[[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4],
+ [], [6, 6, 6, 6, 6, 6]]],
+ expected_elements=[[4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6,
+ 6, 6, 6, 6, 6]],
+ cycle_length=2, block_length=1)))
+ def testPythonImplementation(
+ self, input_lists, expected_elements, cycle_length, block_length):
for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- @combinations.generate(test_base.default_test_combinations())
- def testPythonImplementationBlockLength(self):
- input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
- expected_elements = [
- 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
- 5, 6, 6, 5, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
- self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
- (index, expected, produced))
-
- @combinations.generate(test_base.default_test_combinations())
- def testPythonImplementationEmptyLists(self):
- input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
- [6, 6, 6, 6, 6, 6]]
-
- expected_elements = [
- 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
- ]
- for index, (expected, produced) in enumerate(
- zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+ zip_longest(expected_elements,
+ self._interleave(input_lists, cycle_length, block_length))):
self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
(index, expected, produced))
@@ -172,7 +157,12 @@
for i in range(4, 7):
self.write_coordination_events[i].set()
- def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(
+ sloppy=[False, True], prefetch_input_elements=[0, 1])))
+ def testSingleThreaded(self, sloppy, prefetch_input_elements):
# cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
# `Dataset.flat_map()` and is single-threaded. No synchronization required.
self.skipTest("b/131722904")
@@ -194,22 +184,6 @@
self.evaluate(next_element())
@combinations.generate(test_base.default_test_combinations())
- def testSingleThreaded(self):
- self._testSingleThreaded()
-
- @combinations.generate(test_base.default_test_combinations())
- def testSingleThreadedSloppy(self):
- self._testSingleThreaded(sloppy=True)
-
- @combinations.generate(test_base.default_test_combinations())
- def testSingleThreadedPrefetch1Itr(self):
- self._testSingleThreaded(prefetch_input_elements=1)
-
- @combinations.generate(test_base.default_test_combinations())
- def testSingleThreadedPrefetch1ItrSloppy(self):
- self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
-
- @combinations.generate(test_base.default_test_combinations())
def testSingleThreadedRagged(self):
# Tests a sequence with wildly different elements per iterator.
self.skipTest("b/131722904")
@@ -237,7 +211,11 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- def _testTwoThreadsNoContention(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def testTwoThreadsNoContention(self, sloppy):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
self.skipTest("b/131722904")
@@ -268,15 +246,11 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContention(self):
- self._testTwoThreadsNoContention()
-
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContentionSloppy(self):
- self._testTwoThreadsNoContention(sloppy=True)
-
- def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def testTwoThreadsNoContentionWithRaces(self, sloppy):
"""Tests where all the workers race in producing elements.
Note: this is in contrast with the previous test which carefully sequences
@@ -317,15 +291,11 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContentionWithRaces(self):
- self._testTwoThreadsNoContentionWithRaces()
-
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContentionWithRacesSloppy(self):
- self._testTwoThreadsNoContentionWithRaces(sloppy=True)
-
- def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def testTwoThreadsNoContentionBlockLength(self, sloppy):
# num_threads > 1.
# Explicit coordination should result in `Dataset.interleave()` behavior
self.skipTest("b/131722904")
@@ -356,15 +326,11 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContentionBlockLength(self):
- self._testTwoThreadsNoContentionBlockLength()
-
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContentionBlockLengthSloppy(self):
- self._testTwoThreadsNoContentionBlockLength(sloppy=True)
-
- def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy):
"""Tests where all the workers race in producing elements.
Note: this is in contrast with the previous test which carefully sequences
@@ -406,15 +372,11 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContentionWithRacesAndBlocking(self):
- self._testTwoThreadsNoContentionWithRacesAndBlocking()
-
- @combinations.generate(test_base.default_test_combinations())
- def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
- self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
-
- def _testEmptyInput(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def testEmptyInput(self, sloppy):
# Empty input.
self._clear_coordination_events()
next_element = self.getNext(
@@ -428,15 +390,11 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- @combinations.generate(test_base.default_test_combinations())
- def testEmptyInput(self):
- self._testEmptyInput()
-
- @combinations.generate(test_base.default_test_combinations())
- def testEmptyInputSloppy(self):
- self._testEmptyInput(sloppy=True)
-
- def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def _testNonEmptyInputIntoEmptyOutputs(self, sloppy):
# Non-empty input leading to empty output.
self._clear_coordination_events()
next_element = self.getNext(
@@ -450,15 +408,12 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- @combinations.generate(test_base.default_test_combinations())
- def testNonEmptyInputIntoEmptyOutputs(self):
- self._testNonEmptyInputIntoEmptyOutputs()
-
- @combinations.generate(test_base.default_test_combinations())
- def testNonEmptyInputIntoEmptyOutputsSloppy(self):
- self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
-
- def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(
+ sloppy=[False, True], prefetch_input_elements=[1, 0])))
+ def testPartiallyEmptyOutputs(self, sloppy, prefetch_input_elements):
race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
# Mixture of non-empty and empty interleaved datasets.
self.skipTest("b/131722904")
@@ -491,14 +446,6 @@
actual_element))
@combinations.generate(test_base.default_test_combinations())
- def testPartiallyEmptyOutputs(self):
- self._testPartiallyEmptyOutputs()
-
- @combinations.generate(test_base.default_test_combinations())
- def testPartiallyEmptyOutputsSloppy(self):
- self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
-
- @combinations.generate(test_base.default_test_combinations())
def testDelayedOutputSloppy(self):
# Explicitly control the sequence of events to ensure we correctly avoid
# head-of-line blocking.
@@ -558,7 +505,11 @@
with self.assertRaises(errors.OutOfRangeError):
self.evaluate(next_element())
- def _testEarlyExit(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def testEarlyExit(self, sloppy):
# Exiting without consuming all input should not block
self.skipTest("b/131722904")
self._clear_coordination_events()
@@ -582,15 +533,11 @@
self.read_coordination_events[i].acquire()
self.write_coordination_events[i].set()
- @combinations.generate(test_base.default_test_combinations())
- def testEarlyExit(self):
- self._testEarlyExit()
-
- @combinations.generate(test_base.default_test_combinations())
- def testEarlyExitSloppy(self):
- self._testEarlyExit(sloppy=True)
-
- def _testTooManyReaders(self, sloppy=False):
+ @combinations.generate(
+ combinations.times(
+ test_base.default_test_combinations(),
+ combinations.combine(sloppy=[False, True])))
+ def testTooManyReaders(self, sloppy=False):
def interleave_fn(x):
dataset = dataset_ops.Dataset.from_tensors(x)
@@ -612,14 +559,6 @@
self.assertItemsEqual(output_values, expected_values)
@combinations.generate(test_base.default_test_combinations())
- def testTooManyReaders(self):
- self._testTooManyReaders()
-
- @combinations.generate(test_base.default_test_combinations())
- def testTooManyReadersSloppy(self):
- self._testTooManyReaders(sloppy=True)
-
- @combinations.generate(test_base.default_test_combinations())
def testSparse(self):
def _map_fn(i):
return sparse_tensor.SparseTensor(