Refactor ParallelInterleaveTest to be parameterized
diff --git a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
index 14d3c9d..6638c0e 100644
--- a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
@@ -17,7 +17,6 @@
 from __future__ import division
 from __future__ import print_function
 
-import itertools
 import math
 import threading
 import time
@@ -39,7 +38,6 @@
 from tensorflow.python.platform import test
 
 
-# TODO(feihugis): refactor this test to be parameterized.
 class ParallelInterleaveTest(test_base.DatasetTestBase, parameterized.TestCase):
 
   def setUp(self):
@@ -117,49 +115,36 @@
             num_open -= 1
             break
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testPythonImplementation(self):
-    input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
-                   [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
-
-    # Cycle length 1 acts like `Dataset.flat_map()`.
-    expected_elements = itertools.chain(*input_lists)
-    for expected, produced in zip(expected_elements,
-                                  self._interleave(input_lists, 1, 1)):
-      self.assertEqual(expected, produced)
-
-    # Cycle length > 1.
-    expected_elements = [
-        4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
-        6, 5, 6, 5, 6, 6
-    ]
+  @combinations.generate(
+      combinations.times(
+          combinations.combine(
+              input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+                            [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]],
+              expected_elements=[[4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6,
+                                  4, 4, 4, 4, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6]],
+              cycle_length=1, block_length=1) +
+          combinations.combine(
+              input_lists=[[[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+                            [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]],
+              expected_elements=[[4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6,
+                                  4, 6, 4, 6, 5, 6, 5, 6, 5, 6, 5, 6, 5, 6, 6]],
+              cycle_length=2, block_length=1) +
+          combinations.combine(
+              input_lists=[[[4] * 4, [5] * 5, [6] * 6] * 2],
+              expected_elements=[[4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6,
+                                  4, 4, 6, 6, 5, 5, 6, 6, 5, 5, 6, 6, 5, 6, 6]],
+              cycle_length=2, block_length=2) +
+          combinations.combine(
+              input_lists=[[[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4],
+                            [], [6, 6, 6, 6, 6, 6]]],
+              expected_elements=[[4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6,
+                                  6, 6, 6, 6, 6]],
+              cycle_length=2, block_length=1)))
+  def testPythonImplementation(
+      self, input_lists, expected_elements, cycle_length, block_length):
     for index, (expected, produced) in enumerate(
-        zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
-      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
-                       (index, expected, produced))
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testPythonImplementationBlockLength(self):
-    input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
-    expected_elements = [
-        4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
-        5, 6, 6, 5, 6, 6
-    ]
-    for index, (expected, produced) in enumerate(
-        zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
-      self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
-                       (index, expected, produced))
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testPythonImplementationEmptyLists(self):
-    input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
-                   [6, 6, 6, 6, 6, 6]]
-
-    expected_elements = [
-        4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
-    ]
-    for index, (expected, produced) in enumerate(
-        zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+        zip_longest(expected_elements,
+                    self._interleave(input_lists, cycle_length, block_length))):
       self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
                        (index, expected, produced))
 
@@ -172,7 +157,12 @@
     for i in range(4, 7):
       self.write_coordination_events[i].set()
 
-  def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              sloppy=[False, True], prefetch_input_elements=[0, 1])))
+  def testSingleThreaded(self, sloppy, prefetch_input_elements):
     # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
     # `Dataset.flat_map()` and is single-threaded. No synchronization required.
     self.skipTest("b/131722904")
@@ -194,22 +184,6 @@
       self.evaluate(next_element())
 
   @combinations.generate(test_base.default_test_combinations())
-  def testSingleThreaded(self):
-    self._testSingleThreaded()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testSingleThreadedSloppy(self):
-    self._testSingleThreaded(sloppy=True)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testSingleThreadedPrefetch1Itr(self):
-    self._testSingleThreaded(prefetch_input_elements=1)
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testSingleThreadedPrefetch1ItrSloppy(self):
-    self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
-
-  @combinations.generate(test_base.default_test_combinations())
   def testSingleThreadedRagged(self):
     # Tests a sequence with wildly different elements per iterator.
     self.skipTest("b/131722904")
@@ -237,7 +211,11 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  def _testTwoThreadsNoContention(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def testTwoThreadsNoContention(self, sloppy):
     # num_threads > 1.
     # Explicit coordination should result in `Dataset.interleave()` behavior
     self.skipTest("b/131722904")
@@ -268,15 +246,11 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContention(self):
-    self._testTwoThreadsNoContention()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContentionSloppy(self):
-    self._testTwoThreadsNoContention(sloppy=True)
-
-  def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def testTwoThreadsNoContentionWithRaces(self, sloppy):
     """Tests where all the workers race in producing elements.
 
     Note: this is in contrast with the previous test which carefully sequences
@@ -317,15 +291,11 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContentionWithRaces(self):
-    self._testTwoThreadsNoContentionWithRaces()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContentionWithRacesSloppy(self):
-    self._testTwoThreadsNoContentionWithRaces(sloppy=True)
-
-  def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def testTwoThreadsNoContentionBlockLength(self, sloppy):
     # num_threads > 1.
     # Explicit coordination should result in `Dataset.interleave()` behavior
     self.skipTest("b/131722904")
@@ -356,15 +326,11 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContentionBlockLength(self):
-    self._testTwoThreadsNoContentionBlockLength()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContentionBlockLengthSloppy(self):
-    self._testTwoThreadsNoContentionBlockLength(sloppy=True)
-
-  def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy):
     """Tests where all the workers race in producing elements.
 
     Note: this is in contrast with the previous test which carefully sequences
@@ -406,15 +372,11 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContentionWithRacesAndBlocking(self):
-    self._testTwoThreadsNoContentionWithRacesAndBlocking()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
-    self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
-
-  def _testEmptyInput(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def testEmptyInput(self, sloppy):
     # Empty input.
     self._clear_coordination_events()
     next_element = self.getNext(
@@ -428,15 +390,11 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testEmptyInput(self):
-    self._testEmptyInput()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testEmptyInputSloppy(self):
-    self._testEmptyInput(sloppy=True)
-
-  def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def _testNonEmptyInputIntoEmptyOutputs(self, sloppy):
     # Non-empty input leading to empty output.
     self._clear_coordination_events()
     next_element = self.getNext(
@@ -450,15 +408,12 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testNonEmptyInputIntoEmptyOutputs(self):
-    self._testNonEmptyInputIntoEmptyOutputs()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testNonEmptyInputIntoEmptyOutputsSloppy(self):
-    self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
-
-  def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(
+              sloppy=[False, True], prefetch_input_elements=[1, 0])))
+  def testPartiallyEmptyOutputs(self, sloppy, prefetch_input_elements):
     race_indices = {2, 8, 14}  # Sequence points when sloppy mode has race conds
     # Mixture of non-empty and empty interleaved datasets.
     self.skipTest("b/131722904")
@@ -491,14 +446,6 @@
                                                  actual_element))
 
   @combinations.generate(test_base.default_test_combinations())
-  def testPartiallyEmptyOutputs(self):
-    self._testPartiallyEmptyOutputs()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testPartiallyEmptyOutputsSloppy(self):
-    self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
-
-  @combinations.generate(test_base.default_test_combinations())
   def testDelayedOutputSloppy(self):
     # Explicitly control the sequence of events to ensure we correctly avoid
     # head-of-line blocking.
@@ -558,7 +505,11 @@
     with self.assertRaises(errors.OutOfRangeError):
       self.evaluate(next_element())
 
-  def _testEarlyExit(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def testEarlyExit(self, sloppy):
     # Exiting without consuming all input should not block
     self.skipTest("b/131722904")
     self._clear_coordination_events()
@@ -582,15 +533,11 @@
       self.read_coordination_events[i].acquire()
       self.write_coordination_events[i].set()
 
-  @combinations.generate(test_base.default_test_combinations())
-  def testEarlyExit(self):
-    self._testEarlyExit()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testEarlyExitSloppy(self):
-    self._testEarlyExit(sloppy=True)
-
-  def _testTooManyReaders(self, sloppy=False):
+  @combinations.generate(
+      combinations.times(
+          test_base.default_test_combinations(),
+          combinations.combine(sloppy=[False, True])))
+  def testTooManyReaders(self, sloppy=False):
 
     def interleave_fn(x):
       dataset = dataset_ops.Dataset.from_tensors(x)
@@ -612,14 +559,6 @@
     self.assertItemsEqual(output_values, expected_values)
 
   @combinations.generate(test_base.default_test_combinations())
-  def testTooManyReaders(self):
-    self._testTooManyReaders()
-
-  @combinations.generate(test_base.default_test_combinations())
-  def testTooManyReadersSloppy(self):
-    self._testTooManyReaders(sloppy=True)
-
-  @combinations.generate(test_base.default_test_combinations())
   def testSparse(self):
     def _map_fn(i):
       return sparse_tensor.SparseTensor(